From 04ddb669310aa41390825b56967da6f1f5dd869a Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 19 Jun 2026 16:18:07 -0700 Subject: [PATCH] updated term enumerator --- src/ast/term_enumeration.cpp | 452 +++++++++++++--------------------- src/ast/term_enumeration.h | 10 +- src/smt/smt_model_checker.cpp | 6 +- src/smt/smt_model_finder.cpp | 74 +++--- src/smt/smt_model_finder.h | 2 +- src/test/CMakeLists.txt | 1 + src/test/main.cpp | 3 +- src/test/term_enumeration.cpp | 283 +++++++++++++++++++++ 8 files changed, 499 insertions(+), 332 deletions(-) create mode 100644 src/test/term_enumeration.cpp diff --git a/src/ast/term_enumeration.cpp b/src/ast/term_enumeration.cpp index 458bc4f61..84f32e967 100644 --- a/src/ast/term_enumeration.cpp +++ b/src/ast/term_enumeration.cpp @@ -22,6 +22,7 @@ #include "util/obj_hashtable.h" #include "ast/ast.h" #include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" #include "ast/arith_decl_plugin.h" #include "ast/bv_decl_plugin.h" #include "ast/seq_decl_plugin.h" @@ -60,7 +61,7 @@ struct production { */ class grammar { public: - grammar(ast_manager& m) : m(m) {} + grammar(ast_manager& m) : m(m), m_pinned(m) {} void add_production(production p) { if (p.is_leaf()) @@ -73,46 +74,8 @@ public: vector const& operators() const { return m_operators; } ast_manager& mgr() const { return m; } - void add_variable(char const* name, sort* s) { - expr_ref var(m.mk_const(name, s), m); - sort_ref sr(s, m); - sort_ref_vector dom(m); - add_production({name, sr, dom, [var](expr_ref_vector const&) { return var; }}); - } - - void add_int_const(int val) { - arith_util a(m); - expr_ref e(a.mk_int(val), m); - sort_ref sr(a.mk_int(), m); - sort_ref_vector dom(m); - add_production({std::to_string(val), sr, dom, [e](expr_ref_vector const&) { return e; }}); - } - - void add_bv_const(int val, unsigned bits) { - bv_util bv(m); - expr_ref e(bv.mk_numeral(rational(val), bits), m); - sort_ref sr(bv.mk_sort(bits), m); - sort_ref_vector dom(m); - add_production({std::to_string(val), sr, dom, [e](expr_ref_vector const&) { return e; }}); - } - - void add_string_const(std::string const& val) { - seq_util seq(m); - expr_ref e(seq.str.mk_string(zstring(val.c_str())), m); - sort_ref sr(seq.str.mk_string_sort(), m); - sort_ref_vector dom(m); - add_production({"\"" + val + "\"", sr, dom, [e](expr_ref_vector const&) { return e; }}); - } - - void add_bool_const(bool val) { - expr_ref e(val ? m.mk_true() : m.mk_false(), m); - sort_ref sr(m.mk_bool_sort(), m); - sort_ref_vector dom(m); - std::string n = val ? "true" : "false"; - add_production({n, sr, dom, [e](expr_ref_vector const&) { return e; }}); - } - void add_func_decl(func_decl *f) { + m_pinned.push_back(f); sort_ref range(f->get_range(), m); sort_ref_vector dom(m); for (unsigned i = 0; i < f->get_arity(); ++i) @@ -123,6 +86,7 @@ public: } void add_expr(expr *e) { + m_pinned.push_back(e); sort_ref range(e->get_sort(), m); sort_ref_vector dom(m); std::stringstream ss; @@ -131,72 +95,31 @@ public: add_production({name, range, dom, [this, e](expr_ref_vector const&) { return expr_ref(e, m); }}); } + std::ostream& display(std::ostream& out) const { + out << "Leaves:\n"; + for (auto const &p : m_leaves) { + out << " " << p.name << " : " << mk_pp(p.range, m) << "\n"; + } + out << "Operators:\n"; + for (auto const &p : m_operators) { + out << " " << p.name << " : ("; + for (unsigned i = 0; i < p.domain.size(); ++i) { + if (i > 0) + out << ", "; + out << mk_pp(p.domain[i], m); + } + out << ") -> " << mk_pp(p.range, m) << "\n"; + } + return out; + } + private: ast_manager& m; + ast_ref_vector m_pinned; vector m_leaves; vector m_operators; }; -// ============================================================================ -// Observational Equivalence Manager -// ============================================================================ - -/** - * Evaluates candidate terms on a set of sample inputs and keeps only one - * representative per equivalence class (the one encountered first). - * - * Uses Z3's model evaluation to reduce terms to concrete values. - */ -class oe_manager { -public: - oe_manager(ast_manager& m) : m(m) {} - - void set_samples(vector samples) { - m_samples = std::move(samples); - m_seen.clear(); - } - - void add_sample(model_ref mdl) { - m_samples.push_back(std::move(mdl)); - m_seen.clear(); - } - - /** - * Returns true if `term` is a new representative (its fingerprint has - * not been seen before). - */ - bool is_representative(expr* term) { - auto fingerprint = compute_fingerprint(term); - if (fingerprint == 0) - return false; - return m_seen.insert(fingerprint).second; - } - - void clear() { m_seen.clear(); } - - size_t num_classes() const { return m_seen.size(); } - unsigned num_samples() const { return m_samples.size(); } - -private: - ast_manager& m; - vector m_samples; - std::unordered_set m_seen; - - uint64_t compute_fingerprint(expr* term) { - uint64_t a = 1, b = 2, c = 0; - for (auto& mdl : m_samples) { - expr_ref val(m); - model_evaluator eval(*mdl); - eval.set_model_completion(true); - if (!eval.eval(term, val, true)) - continue; - a *= val->hash(); - mix(a, b, c); - } - return c; - } -}; - // ============================================================================ // Term Bank - stores enumerated terms by cost and sort // ============================================================================ @@ -249,6 +172,36 @@ public: return m_terms[cost]->find(s); } + std::ostream& display(std::ostream& out) const { + for (unsigned cost = 0; cost < m_terms.size(); ++cost) { + if (!m_terms[cost]) + continue; + out << "cost " << cost << ":\n"; + for (auto& [s, terms] : *m_terms[cost]) { + out << " sort " << mk_pp(s, m) << ":\n"; + for (expr* e : terms) { + out << " #" << e->get_id() << " "; + if (cost == 0) { + out << mk_bounded_pp(e, m); + } + else if (is_app(e)) { + app* a = to_app(e); + out << a->get_decl()->get_name() << "("; + bool first = true; + for (expr* arg : *a) { + if (!first) out << ", "; + first = false; + out << "#" << arg->get_id(); + } + out << ")"; + } + out << "\n"; + } + } + } + return out; + } + private: ast_manager& m; expr_ref_vector m_pinned; @@ -338,33 +291,17 @@ private: // bottom_up_enumerator - the main bottom-up term enumeration engine // ============================================================================ -/** - * Enumerates terms bottom-up by cost, applying observational equivalence - * pruning. Users iterate via has_next() / next(). - * - * Usage: - * ast_manager m; - * grammar g(m); - * // ... add productions ... - * oe_manager oe(m); - * // ... set samples ... - * bottom_up_enumerator en(g, oe); - * arith_util a(m); - * en.set_target_sort(a.mk_int()); - * while (en.has_next()) { - * expr_ref term = en.next(); - * // ... check if term satisfies specification ... - * } - */ + class bottom_up_enumerator { public: - bottom_up_enumerator(grammar& grammar, oe_manager& oe) - : m_grammar(grammar), m(grammar.mgr()), m_oe(oe), + bottom_up_enumerator(grammar& grammar) + : m_grammar(grammar), m(grammar.mgr()), m_bank(grammar.mgr()), m_pending(grammar.mgr()) {} - void set_target_sort(sort* s) { m_target_sort = s; } - + void set_target_sort(sort *s) { + m_target_sort = s; + } bool has_next() { if (m_pending) return true; m_pending = find_next(); @@ -377,17 +314,21 @@ public: expr_ref result(m_pending, m); m_pending = nullptr; return result; - } + } term_bank const& bank() const { return m_bank; } + std::ostream& display(std::ostream& out) const { + m_grammar.display(out); + return m_bank.display(out); + } + void reset() { m_cost = 0; m_leaf_idx = 0; m_op_idx = 0; m_state = State::Leaves; m_bank.reset(); - m_oe.clear(); m_pending = nullptr; m_children_iter.reset(); } @@ -397,17 +338,17 @@ private: grammar& m_grammar; ast_manager& m; - oe_manager& m_oe; term_bank m_bank; unsigned m_cost = 0; unsigned m_leaf_idx = 0; unsigned m_op_idx = 0; unsigned m_bank_idx = 0; unsigned m_bank_size = 0; + bool m_has_range = false; State m_state = State::Leaves; - sort* m_target_sort = nullptr; expr_ref m_pending; std::unique_ptr m_children_iter; + sort *m_target_sort = nullptr; bool sort_matches(expr* e) const { return !m_target_sort || e->get_sort() == m_target_sort; @@ -418,21 +359,20 @@ private: switch (m_state) { case State::Leaves: while (m_leaf_idx < m_grammar.leaves().size()) { - production const& prod = m_grammar.leaves()[m_leaf_idx]; + production const &prod = m_grammar.leaves()[m_leaf_idx]; m_leaf_idx++; expr_ref_vector empty_args(m); expr_ref term = prod.builder(empty_args); - if (m_oe.is_representative(term)) { - m_bank.add(term, 0); - if (sort_matches(term)) - return term; - } + m_bank.add(term, 0); + if (sort_matches(term)) + return term; } m_state = State::Operators; m_cost = 1; m_op_idx = 0; m_bank_idx = 0; - m_bank_size = get_bank_size(); + m_bank_size = get_bank_size(); + m_has_range = false; m_children_iter.reset(); break; @@ -440,6 +380,7 @@ private: expr* result = enumerate_operators(); if (result) return result; + m_cost++; m_op_idx = 0; m_bank_idx = 0; @@ -477,16 +418,25 @@ private: expr_ref_vector children = m_children_iter->next(new_cost); production const &prod = ops[m_op_idx - 1]; expr_ref term = prod.builder(children); + // IF_VERBOSE(0, verbose_stream() << term << "\n"); SASSERT(new_cost >= m_cost); - if (m_oe.is_representative(term)) { - m_bank.add(term, new_cost); - if (sort_matches(term) && new_cost == m_cost) - return term; - } + m_bank.add(term, new_cost); + if (sort_matches(term) && new_cost == m_cost) { + m_has_range = true; + return term; + } continue; } - if (m_op_idx >= ops.size()) + if (ops.empty()) { + m_state = State::Done; return nullptr; + } + + if (m_op_idx >= ops.size()) { + if (!m_has_range) + m_state = State::Done; + return nullptr; + } production const &prod = ops[m_op_idx]; m_op_idx++; m_children_iter = std::make_unique(m, prod, m_bank, m_cost); @@ -503,12 +453,11 @@ private: struct term_enumeration::imp { ast_manager& m; term_enum::grammar m_grammar; - term_enum::oe_manager m_oe; term_enum::bottom_up_enumerator m_bottom_up_enumerator; std::function m_cost; imp(ast_manager& m) : - m(m), m_grammar(m), m_oe(m), m_bottom_up_enumerator(m_grammar, m_oe) {} + m(m), m_grammar(m), m_bottom_up_enumerator(m_grammar) {} void add_production(func_decl* f) { m_grammar.add_func_decl(f); @@ -519,71 +468,11 @@ struct term_enumeration::imp { } void set_cost(std::function const& cost) { - m_cost = cost; + // TODO } - // Enumerate terms of given sort up to a cost, ordered by cost. - // Returns the next term in cost order, or nullptr if exhausted at current cost. - expr* next_term(sort* s, unsigned& cost_state, unsigned& idx_state, - vector& levels) { - // Expand levels as needed - while (idx_state >= level_size(levels, cost_state)) { - cost_state++; - expand_level(s, cost_state, levels); - idx_state = 0; - if (level_size(levels, cost_state) > 0) - break; - } - if (cost_state >= levels.size() || idx_state >= levels[cost_state].size()) - return nullptr; - return levels[cost_state].get(idx_state++); - } - -private: - unsigned level_size(vector const& levels, unsigned cost) const { - if (cost >= levels.size()) return 0; - return levels[cost].size(); - } - - void expand_level(sort* s, unsigned cost, vector& levels) { - if (cost >= levels.size()) - levels.resize(cost + 1, expr_ref_vector(m)); - - // Collect terms at this cost - if (cost == 0) { - // Leaves - for (auto const& prod : m_grammar.leaves()) { - if (prod.range.get() != s) continue; - expr_ref_vector empty_args(m); - expr_ref term = prod.builder(empty_args); - if (m_oe.is_representative(term)) { - m_bottom_up_enumerator.bank(); // just to ensure bank is populated - levels[0].push_back(term); - } - } - } - else { - // operators - for (auto const& prod : m_grammar.operators()) { - if (prod.range.get() != s) continue; - term_enum::children_iterator iter(m, prod, m_bottom_up_enumerator.bank(), cost); - while (iter.has_next(cost)) { - unsigned new_cost = 0; - expr_ref_vector children = iter.next(new_cost); - expr_ref term = prod.builder(children); - levels.reserve(new_cost + 1, expr_ref_vector(m)); - if (m_oe.is_representative(term)) - levels[new_cost].push_back(term); - } - } - } - - // Sort by cost if cost function is set - if (m_cost && !levels[cost].empty()) { - expr_ref_vector& lv = levels[cost]; - std::sort(lv.data(), lv.data() + lv.size(), - [&](expr* a, expr* b) { return m_cost(a) < m_cost(b); }); - } + std::ostream& display(std::ostream& out) const { + return m_bottom_up_enumerator.display(out); } }; @@ -591,84 +480,85 @@ private: struct term_enumeration::iterator::iter_imp { imp& m_imp; + ast_manager & m; sort* m_sort; unsigned m_cost = 0; unsigned m_idx = 0; vector m_levels; - expr* m_current = nullptr; + expr_ref m_current; bool m_end; + vector m_vars; + vector> m_decls; + vector> m_names; - iter_imp(imp& i, sort* s) : m_imp(i), m_sort(s), m_end(false) { - expand_current_level(); - advance_to_valid(); + iter_imp(imp& i, sort* s) : m_imp(i), m(i.m), m_sort(s), m_current(i.m), m_end(false) { + m_imp.m_bottom_up_enumerator.reset(); + init_sort(); + advance(); } - + // Sentinel constructor iter_imp(imp& i) : - m_imp(i), m_sort(nullptr), m_end(true) {} - - void expand_current_level() { - if (m_cost >= m_levels.size()) - m_levels.resize(m_cost + 1, expr_ref_vector(m_imp.m)); - - if (!m_levels[m_cost].empty()) - return; - - if (m_cost == 0) { - for (auto const& prod : m_imp.m_grammar.leaves()) { - if (prod.range.get() != m_sort) continue; - expr_ref_vector empty_args(m_imp.m); - expr_ref term = prod.builder(empty_args); - if (m_imp.m_oe.is_representative(term)) - m_levels[0].push_back(term); - } - } - else { - for (auto const& prod : m_imp.m_grammar.operators()) { - if (prod.range.get() != m_sort) continue; - term_enum::children_iterator iter(m_imp.m, prod, m_imp.m_bottom_up_enumerator.bank(), m_cost); - while (iter.has_next(m_cost)) { - unsigned new_cost = 0; - expr_ref_vector children = iter.next(new_cost); - expr_ref term = prod.builder(children); - m_levels.reserve(new_cost + 1, expr_ref_vector(m_imp.m)); - if (m_imp.m_oe.is_representative(term)) - m_levels[new_cost].push_back(term); - } - } - } - - // Sort by cost if cost function is set - if (m_imp.m_cost && !m_levels[m_cost].empty()) { - expr_ref_vector& lv = m_levels[m_cost]; - std::sort(lv.data(), lv.data() + lv.size(), - [&](expr* a, expr* b) { return m_imp.m_cost(a) < m_imp.m_cost(b); }); - } + m_imp(i), m(i.m), m_sort(nullptr), m_current(i.m), m_end(true) { + UNREACHABLE(); } - void advance_to_valid() { - while (true) { - if (m_cost >= m_levels.size()) - expand_current_level(); - if (m_idx < m_levels[m_cost].size()) { - m_current = m_levels[m_cost].get(m_idx); - return; + + void init_sort() { + array_util autil(m); + sort *range = m_sort; + + while (autil.is_array(range)) { + m_vars.push_back(expr_ref_vector(m)); + m_decls.push_back(ptr_vector()); + m_names.push_back(vector()); + for (unsigned i = 0; i < get_array_arity(range); ++i) { + m_decls.back().push_back(get_array_domain(range, i)); + m_vars.back().push_back(nullptr); + m_names.back().push_back(symbol()); } - m_cost++; - m_idx = 0; - if (m_cost > 100) { - m_end = true; - m_current = nullptr; - return; + + #if 0 + // TODO: don't enable this until we ensure only generating whnf (beta-redex free) expressions. + expr_ref_vector args(m); + args.push_back(m.mk_const("a", range)); + for (unsigned i = 0; i < m_decls.back().size(); ++i) { + args.push_back(m.mk_var(i, m_decls.back().get(i))); } - expand_current_level(); + app_ref sel(autil.mk_select(args), m); + m_imp.m_grammar.add_func_decl(sel->get_decl()); + #endif + + range = get_array_range(range); } + unsigned n = 0; + for (unsigned i = m_decls.size(); i-- > 0;) { + for (unsigned j = m_decls[i].size(); j-- > 0;) { + m_vars[i][j] = m.mk_var(n, m_decls[i][j]); + m_names[i][j] = symbol(n); + m_imp.add_production(m_vars[i].get(j)); + n++; + } + } + m_sort = range; + m_imp.m_bottom_up_enumerator.set_target_sort(range); + } + + void mk_lambda() { + if (!m_current) + return; + for (unsigned i = m_decls.size(); i-- > 0;) + m_current = m.mk_lambda(m_decls[i].size(), m_decls[i].data(), m_names[i].data(), m_current); } void advance() { - if (m_end) return; - m_idx++; - advance_to_valid(); + if (m_end) + return; + m_current = m_imp.m_bottom_up_enumerator.next(); + SASSERT(!m_current || m_current->get_sort() == m_sort); + mk_lambda(); + if (!m_current) + m_end = true; } }; @@ -680,22 +570,6 @@ term_enumeration::iterator::iterator(std::nullptr_t) { m_imp = nullptr; } -term_enumeration::iterator::iterator(iterator const& other) { - m_imp = nullptr; - if (other.m_imp) - m_imp = alloc(iter_imp, *other.m_imp); -} - -term_enumeration::iterator& term_enumeration::iterator::operator=(iterator const& other) { - if (this != &other) { - dealloc(m_imp); - m_imp = nullptr; - if (other.m_imp) - m_imp = alloc(iter_imp, *other.m_imp); - } - return *this; -} - term_enumeration::iterator::~iterator() { dealloc(m_imp); } @@ -715,12 +589,12 @@ term_enumeration::iterator term_enumeration::iterator::operator++(int) { return tmp; } -bool term_enumeration::iterator::operator!=(iterator const& other) const { - if (!m_imp && !other.m_imp) return false; - if (!m_imp) return !other.m_imp->m_end; - if (!other.m_imp) return !m_imp->m_end; - return m_imp->m_end != other.m_imp->m_end || - m_imp->m_current != other.m_imp->m_current; +bool term_enumeration::iterator::operator==(iterator const& other) const { + if (!m_imp && !other.m_imp) return true; + if (!m_imp) return other.m_imp->m_end; + if (!other.m_imp) return m_imp->m_end; + return m_imp->m_end == other.m_imp->m_end && + m_imp->m_current == other.m_imp->m_current; } // -- terms implementation -- @@ -760,3 +634,7 @@ void term_enumeration::set_cost(std::function const& cost) { term_enumeration::terms term_enumeration::enum_terms(sort* s) { return terms(m_imp, s); } + +std::ostream& term_enumeration::display(std::ostream& out) const { + return m_imp->display(out); +} diff --git a/src/ast/term_enumeration.h b/src/ast/term_enumeration.h index 7934c0f4e..865b8d402 100644 --- a/src/ast/term_enumeration.h +++ b/src/ast/term_enumeration.h @@ -12,6 +12,7 @@ public: void add_production(func_decl* f); void add_production(expr* e); + // void add_production(sort *s, std::function g); // cost function associated with expressions. // terms are enumerated with increasing cost. @@ -24,13 +25,14 @@ public: public: iterator(imp& i, sort* s); iterator(std::nullptr_t); - iterator(iterator const& other); - iterator& operator=(iterator const& other); ~iterator(); expr* operator*(); iterator operator++(int); iterator& operator++(); - bool operator!=(iterator const& other) const; + bool operator!=(iterator const& other) const { + return !(*this == other); + } + bool operator==(iterator const &other) const; }; class terms { @@ -43,4 +45,6 @@ public: }; terms enum_terms(sort* s); + + std::ostream& display(std::ostream& out) const; }; \ No newline at end of file diff --git a/src/smt/smt_model_checker.cpp b/src/smt/smt_model_checker.cpp index c23cfe01e..f0196baad 100644 --- a/src/smt/smt_model_checker.cpp +++ b/src/smt/smt_model_checker.cpp @@ -219,7 +219,7 @@ namespace smt { if (use_inv) { unsigned sk_term_gen = 0; - expr * sk_term = m_model_finder.get_inv(q, i, sk_value, sk_term_gen); + expr * sk_term = m_model_finder.get_inv(q, i, sk_value, *cex, sk_term_gen); if (sk_term != nullptr) { TRACE(model_checker, tout << "Found inverse " << mk_pp(sk_term, m) << "\n";); SASSERT(!m.is_model_value(sk_term)); @@ -238,10 +238,6 @@ namespace smt { TRACE(model_checker, tout << "sk term " << mk_pp(sk_term, m) << "\n"); sk_value = sk_term; } - // last ditch: am I an array? - else if (false && autil.is_as_array(sk_value, f) && cex->get_func_interp(f) && cex->get_func_interp(f)->get_array_interp(f)) { - sk_value = cex->get_func_interp(f)->get_array_interp(f); - } } if (contains_model_value(sk_value)) { diff --git a/src/smt/smt_model_finder.cpp b/src/smt/smt_model_finder.cpp index 675a35200..94c9fa0ba 100644 --- a/src/smt/smt_model_finder.cpp +++ b/src/smt/smt_model_finder.cpp @@ -18,6 +18,7 @@ Revision History: --*/ #include "util/backtrackable_set.h" #include "ast/ast_util.h" +#include "ast/has_free_vars.h" #include "ast/macros/macro_util.h" #include "ast/arith_decl_plugin.h" #include "ast/bv_decl_plugin.h" @@ -108,9 +109,15 @@ namespace smt { } } - expr* get_inv(expr* v) const { + expr* get_inv(expr* v, model& mdl) const { expr* t = nullptr; m_inv.find(v, t); + if (!t) { + for (auto [k, term] : m_inv) { + if (mdl.are_equal(k, v)) + return term; + } + } return t; } @@ -121,14 +128,11 @@ namespace smt { } void mk_inverse(evaluator& ev) { - for (auto const& kv : m_elems) { - expr* t = kv.m_key; + for (auto const &[t, gen] : m_elems) { SASSERT(!contains_model_value(t)); - unsigned gen = kv.m_value; expr* t_val = ev.eval(t, true); if (!t_val) break; TRACE(model_finder, tout << mk_pp(t, m) << " " << mk_pp(t_val, m) << "\n";); - expr* old_t = nullptr; if (m_inv.find(t_val, old_t)) { unsigned old_t_gen = 0; @@ -292,7 +296,7 @@ namespace smt { } void insert(expr* n, unsigned generation) { - if (is_ground(n)) + if (is_ground(n) || (has_quantifiers(n) && !has_free_vars(n))) // this is a closed term get_root()->m_set->insert(n, generation); } @@ -600,7 +604,10 @@ namespace smt { } else { r = tmp; - TRACE(model_finder, tout << "eval\n" << mk_pp(n, m) << "\n----->\n" << mk_pp(r, m) << "\n";); + TRACE(model_finder, tout << "eval-failed\n" << mk_pp(n, m) << "\n----->\n" << mk_pp(r, m) << "\n";); + if (is_lambda(tmp)) { + r = m.mk_fresh_const("lambda", tmp->get_sort()); + } } m_eval_cache[model_completion].insert(n, r); m_eval_cache_range.push_back(r); @@ -1388,43 +1395,30 @@ namespace smt { } void display(std::ostream &out) const override { - out << "(" << "ho-var " << ":" << m_var_i << ")"; + out << "(" << "ho-var: " << m_var_i << ")"; } void process_auf(quantifier *q, auf_solver &s, context *ctx) override { + /* node * S_i = */ s.get_uvar(q, m_var_i); } void populate_inst_sets(quantifier *q, auf_solver &s, context *ctx) override { - node *S = s.get_uvar(q, m_var_i); sort *srt = S->get_sort(); - sort* range = get_array_range(srt); - unsigned arity = get_array_arity(srt); + IF_VERBOSE(0, verbose_stream() << "ho_var::populate_inst_sets: " << q->get_id() << " " << mk_pp(srt, m) << "\n";); term_enumeration tn(m); // Add ground terms of type S. // Add productions for functions in E-graph // add other possible relevant functions such as equality over srt, Boolean operators - // TODO: use term_enumerator to produce instances int the instantiation set of S. - expr_ref_vector vars(m); - ptr_vector sorts; - vector names; - for (unsigned i = 0; i < arity; ++i) { - vars.push_back(m.mk_var(i, get_array_domain(srt, i))); - auto v = vars.back(); - tn.add_production(v); - sorts.push_back(v->get_sort()); - names.push_back(symbol(i)); - } - auto mk_lambda = [&](expr* body) { - return m.mk_lambda(vars.size(), sorts.data(), names.data(), body); - }; + ast_mark visited; for (enode *n : ctx->enodes()) { - if (false && !ctx->is_relevant(n)) + if (!ctx->is_relevant(n)) continue; auto e = n->get_expr(); if (srt == n->get_sort()) { + IF_VERBOSE(0, verbose_stream() << "inserting " << mk_pp(e, m) << " into inst set\n"); S->insert(e, n->get_generation()); } else if (is_uninterp_const(e)) { @@ -1443,11 +1437,10 @@ namespace smt { unsigned max_count = 20; for (auto t : tn.enum_terms(srt)) { - auto lam = mk_lambda(t); unsigned generation = 0; // todo - inherited from sub-term of t? IF_VERBOSE(0, verbose_stream() << "ho_var: adding term " << mk_ismt2_pp(t, m) << " to instantiation set of S" << std::endl;); - S->insert(lam, generation); + S->insert(t, generation); } } }; @@ -2251,7 +2244,6 @@ namespace smt { } SASSERT(is_quantifier(atom)); - UNREACHABLE(); } void process_literal(expr* atom, polarity pol) { @@ -2603,11 +2595,12 @@ namespace smt { Store in generation the generation of the result */ - expr* model_finder::get_inv(quantifier* q, unsigned i, expr* val, unsigned& generation) { + expr* model_finder::get_inv(quantifier* q, unsigned i, expr* val, model& mdl,unsigned& generation) { instantiation_set const* s = get_uvar_inst_set(q, i); if (s == nullptr) return nullptr; - expr* t = s->get_inv(val); + + expr* t = s->get_inv(val, mdl); if (m_auf_solver->is_default_representative(t)) return val; if (t != nullptr) { @@ -2643,16 +2636,27 @@ namespace smt { obj_map const& inv = s->get_inv_map(); if (inv.empty()) continue; // nothing to do - ptr_buffer eqs; - for (auto const& [val, _] : inv) { - if (val->get_sort() == sk->get_sort()) - eqs.push_back(m.mk_eq(sk, val)); + expr_ref_vector eqs(m), defs(m); + + for (auto const& [val, term] : inv) { + if (val->get_sort() == sk->get_sort()) { + if (is_lambda(term)) { + eqs.push_back(m.mk_eq(sk, val)); + defs.push_back(m.mk_eq(val, term)); + } + else + eqs.push_back(m.mk_eq(sk, val)); + } } if (!eqs.empty()) { expr_ref new_cnstr(m); new_cnstr = m.mk_or(eqs); TRACE(model_finder, tout << "assert_restriction:\n" << mk_pp(new_cnstr, m) << "\n";); aux_ctx->assert_expr(new_cnstr); + for (auto def : defs) { + TRACE(model_finder, tout << "assert_def:\n" << mk_pp(def, m) << "\n";); + aux_ctx->assert_expr(def); + } asserted_something = true; } } diff --git a/src/smt/smt_model_finder.h b/src/smt/smt_model_finder.h index 1c468fc64..3b34e0192 100644 --- a/src/smt/smt_model_finder.h +++ b/src/smt/smt_model_finder.h @@ -113,7 +113,7 @@ namespace smt { void fix_model(proto_model * m); quantifier * get_flat_quantifier(quantifier * q); - expr * get_inv(quantifier * q, unsigned i, expr * val, unsigned & generation); + expr * get_inv(quantifier * q, unsigned i, expr * val, model& m, unsigned & generation); bool restrict_sks_to_inst_set(context * aux_ctx, quantifier * q, expr_ref_vector const & sks); void restart_eh(); diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 39050620a..d3b9b245a 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -143,6 +143,7 @@ add_executable(test-z3 symbol.cpp symbol_table.cpp tbv.cpp + term_enumeration.cpp theory_dl.cpp theory_pb.cpp timeout.cpp diff --git a/src/test/main.cpp b/src/test/main.cpp index 1727cb9dc..a6cb40970 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -193,7 +193,8 @@ X(ho_matcher) \ X(finite_set) \ X(finite_set_rewriter) \ - X(fpa) + X(fpa) \ + X(term_enumeration) #define FOR_EACH_TEST(X, X_ARGV) \ FOR_EACH_ALL_TEST(X, X_ARGV) \ diff --git a/src/test/term_enumeration.cpp b/src/test/term_enumeration.cpp new file mode 100644 index 000000000..5ab7069a1 --- /dev/null +++ b/src/test/term_enumeration.cpp @@ -0,0 +1,283 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + tst_term_enumeration.cpp + +Abstract: + + Test term enumeration module + +--*/ + +#include "ast/term_enumeration.h" +#include "ast/ast_pp.h" +#include "ast/arith_decl_plugin.h" +#include "ast/bv_decl_plugin.h" +#include "ast/array_decl_plugin.h" +#include "ast/reg_decl_plugins.h" +#include +#include + +static void tst_basic_enumeration() { + std::cout << "=== test basic enumeration ===\n"; + ast_manager m; + reg_decl_plugins(m); + arith_util a(m); + + term_enumeration te(m); + + // Add some leaf productions (constants) + expr_ref zero(a.mk_int(0), m); + expr_ref one(a.mk_int(1), m); + te.add_production(zero); + te.add_production(one); + + // Enumerate terms of Int sort + sort* int_sort = a.mk_int(); + unsigned count = 0; + for (expr* e : te.enum_terms(int_sort)) { + std::cout << "Term: " << mk_pp(e, m) << "\n"; + count++; + if (count >= 5) break; // Limit output + } + + ENSURE(count >= 2); // At least 0 and 1 + std::cout << "Enumerated " << count << " terms\n"; +} + +static void tst_enumeration_with_operators() { + std::cout << "=== test enumeration with operators ===\n"; + ast_manager m; + reg_decl_plugins(m); + arith_util a(m); + + term_enumeration te(m); + + // Add leaf productions + expr_ref zero(a.mk_int(0), m); + expr_ref one(a.mk_int(1), m); + te.add_production(zero); + te.add_production(one); + + // Add operator productions (+ and *) + // Get func_decl by creating an app and extracting the decl + app_ref tmp_add(a.mk_add(zero, one), m); + app_ref tmp_mul(a.mk_mul(zero, one), m); + func_decl* add_decl = tmp_add->get_decl(); + func_decl* mul_decl = tmp_mul->get_decl(); + te.add_production(add_decl); + te.add_production(mul_decl); + + sort* int_sort = a.mk_int(); + unsigned count = 0; + for (expr* e : te.enum_terms(int_sort)) { + std::cout << "Term: " << mk_pp(e, m) << "\n"; + count++; + if (count >= 20) break; // Limit output + } + + ENSURE(count >= 2); // At least the leaves + std::cout << "Enumerated " << count << " terms with operators\n"; +} + +static void tst_display() { + std::cout << "=== test display ===\n"; + ast_manager m; + reg_decl_plugins(m); + arith_util a(m); + + term_enumeration te(m); + + // Add leaf productions + expr_ref zero(a.mk_int(0), m); + expr_ref one(a.mk_int(1), m); + te.add_production(zero); + te.add_production(one); + + // Add operator productions + app_ref tmp_add(a.mk_add(zero, one), m); + func_decl* add_decl = tmp_add->get_decl(); + te.add_production(add_decl); + + sort* int_sort = a.mk_int(); + unsigned count = 0; + for (expr* e : te.enum_terms(int_sort)) { + (void)e; + count++; + if (count >= 10) break; + } + + std::cout << "Internal state after enumeration:\n"; + std::ostringstream oss; + te.display(oss); + std::cout << oss.str(); + + // Verify display produced some output + ENSURE(!oss.str().empty()); +} + +static void tst_bitvector_enumeration() { + std::cout << "=== test bitvector enumeration ===\n"; + ast_manager m; + reg_decl_plugins(m); + bv_util bv(m); + + term_enumeration te(m); + + // Add bitvector constants + unsigned bv_size = 8; + expr_ref bv_zero(bv.mk_numeral(0, bv_size), m); + expr_ref bv_one(bv.mk_numeral(1, bv_size), m); + te.add_production(bv_zero); + te.add_production(bv_one); + + // Add bvadd operator + app_ref tmp_add(bv.mk_bv_add(bv_zero, bv_one), m); + func_decl* bvadd = tmp_add->get_decl(); + te.add_production(bvadd); + + sort* bv8 = bv.mk_sort(bv_size); + unsigned count = 0; + for (expr* e : te.enum_terms(bv8)) { + std::cout << "BV Term: " << mk_pp(e, m) << "\n"; + count++; + if (count >= 10) break; + } + + ENSURE(count >= 2); + std::cout << "Enumerated " << count << " bitvector terms\n"; +} + +static void tst_multiple_sorts() { + std::cout << "=== test multiple sorts ===\n"; + ast_manager m; + reg_decl_plugins(m); + arith_util a(m); + + term_enumeration te(m); + + // Add Int constants + expr_ref i_zero(a.mk_int(0), m); + expr_ref i_one(a.mk_int(1), m); + te.add_production(i_zero); + te.add_production(i_one); + + // Add Real constants + expr_ref r_zero(a.mk_real(0), m); + expr_ref r_one(a.mk_real(1), m); + te.add_production(r_zero); + te.add_production(r_one); + + // Enumerate Int terms + sort* int_sort = a.mk_int(); + unsigned int_count = 0; + for (expr* e : te.enum_terms(int_sort)) { + std::cout << "Int Term: " << mk_pp(e, m) << "\n"; + int_count++; + if (int_count >= 5) break; + } + + ENSURE(int_count >= 2); + std::cout << "Enumerated " << int_count << " Int terms\n"; +} + +static void tst_nested_array_enumeration() { + std::cout << "=== test nested array enumeration (Array(A, Array(B, A))) ===\n"; + ast_manager m; + reg_decl_plugins(m); + array_util arr(m); + + term_enumeration te(m); + + // Create uninterpreted sorts A and B + sort_ref sort_A(m.mk_uninterpreted_sort(symbol("A")), m); + sort_ref sort_B(m.mk_uninterpreted_sort(symbol("B")), m); + + // Create nested array sort: Array(B, A) - arrays indexed by B returning A + sort_ref array_B_A(arr.mk_array_sort(sort_B, sort_A), m); + + // Create outer array sort: Array(A, Array(B, A)) - arrays indexed by A returning Array(B,A) + sort_ref array_A_arrayBA(arr.mk_array_sort(sort_A, array_B_A), m); + + std::cout << "Sort A: " << mk_pp(sort_A.get(), m) << "\n"; + std::cout << "Sort B: " << mk_pp(sort_B.get(), m) << "\n"; + std::cout << "Sort Array(B, A): " << mk_pp(array_B_A.get(), m) << "\n"; + std::cout << "Sort Array(A, Array(B, A)): " << mk_pp(array_A_arrayBA.get(), m) << "\n"; + + // Add constants of sort A + app_ref a0(m.mk_const(symbol("a0"), sort_A), m); + app_ref a1(m.mk_const(symbol("a1"), sort_A), m); + te.add_production(a0); + te.add_production(a1); + + // Add constants of sort B + app_ref b0(m.mk_const(symbol("b0"), sort_B), m); + app_ref b1(m.mk_const(symbol("b1"), sort_B), m); + te.add_production(b0); + te.add_production(b1); + + // Add a constant array of inner type Array(B, A) - const_array(a0) : Array(B, A) + app_ref const_inner(arr.mk_const_array(array_B_A, a0), m); + te.add_production(const_inner); + + // Add a constant array of outer type Array(A, Array(B, A)) + app_ref const_outer(arr.mk_const_array(array_A_arrayBA, const_inner), m); + te.add_production(const_outer); + + // Add store operator for the inner array type Array(B, A) + // store(array, index, value) : store(Array(B,A), B, A) -> Array(B,A) + expr* store_inner_args[3] = { const_inner.get(), b0.get(), a0.get() }; + app_ref tmp_store_inner(arr.mk_store(3, store_inner_args), m); + func_decl* store_inner_decl = tmp_store_inner->get_decl(); + te.add_production(store_inner_decl); + + // Add store operator for the outer array type Array(A, Array(B, A)) + // store(array, index, value) : store(Array(A, Array(B,A)), A, Array(B,A)) -> Array(A, Array(B,A)) + expr* store_outer_args[3] = { const_outer.get(), a0.get(), const_inner.get() }; + app_ref tmp_store_outer(arr.mk_store(3, store_outer_args), m); + func_decl* store_outer_decl = tmp_store_outer->get_decl(); + te.add_production(store_outer_decl); + + // Add select operator for the outer array (returns Array(B, A)) + // select(Array(A, Array(B,A)), A) -> Array(B, A) + app_ref tmp_select_outer(arr.mk_select(const_outer.get(), a0.get()), m); + func_decl* select_outer_decl = tmp_select_outer->get_decl(); + te.add_production(select_outer_decl); + + // Enumerate terms of the nested array sort Array(A, Array(B, A)) + std::cout << "\nEnumerating terms of sort Array(A, Array(B, A)):\n"; + unsigned count = 0; + for (expr* e : te.enum_terms(array_A_arrayBA)) { + std::cout << " Term " << count << ": " << mk_pp(e, m) << "\n"; + count++; + if (count >= 15) break; // Limit output + } + + ENSURE(count >= 1); // At least the constant array + std::cout << "Enumerated " << count << " terms of sort Array(A, Array(B, A))\n"; + + // Also enumerate terms of the inner array sort Array(B, A) + std::cout << "\nEnumerating terms of sort Array(B, A):\n"; + unsigned inner_count = 0; + for (expr* e : te.enum_terms(array_B_A)) { + std::cout << " Term " << inner_count << ": " << mk_pp(e, m) << "\n"; + inner_count++; + if (inner_count >= 10) break; + } + + // ENSURE(inner_count >= 1); + std::cout << "Enumerated " << inner_count << " terms of sort Array(B, A)\n"; + te.display(std::cout); +} + +void tst_term_enumeration() { + tst_basic_enumeration(); + tst_enumeration_with_operators(); + tst_display(); + tst_bitvector_enumeration(); + tst_multiple_sorts(); + tst_nested_array_enumeration(); + std::cout << "All term_enumeration tests passed!\n"; +}