diff --git a/src/ast/rewriter/CMakeLists.txt b/src/ast/rewriter/CMakeLists.txt index cfcc179bc..6db1320ab 100644 --- a/src/ast/rewriter/CMakeLists.txt +++ b/src/ast/rewriter/CMakeLists.txt @@ -43,6 +43,7 @@ z3_add_component(rewriter seq_rewriter.cpp seq_regex_bisim.cpp seq_skolem.cpp + term_enumeration.cpp th_rewriter.cpp value_sweep.cpp var_subst.cpp diff --git a/src/ast/rewriter/term_enumeration.cpp b/src/ast/rewriter/term_enumeration.cpp new file mode 100644 index 000000000..9b17ac1dc --- /dev/null +++ b/src/ast/rewriter/term_enumeration.cpp @@ -0,0 +1,674 @@ +/** + * term_enumeration.cpp - Bottom-up term enumeration module for Z3 + * + * Inspired by the Probe synthesizer (Barke et al., "Just-in-Time Learning + * for Bottom-Up Enumerative Synthesis"). Adapted to use Z3's internal APIs. + * + * Key ideas: + * - Terms are enumerated bottom-up by "cost" (calculated by tree size). + * - A grammar describes which function symbols (operators) and leaves + * (constants, variables) are available for enumeration. + */ + +#include +#include +#include +#include "util/vector.h" +#include "util/scoped_ptr_vector.h" +#include "util/obj_hashtable.h" +#include "util/uint_set.h" +#include "ast/ast.h" +#include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" +#include "ast/rewriter/th_rewriter.h" +#include "ast/rewriter/term_enumeration.h" + + +namespace term_enum { + +// ============================================================================ +// grammar production rule +// ============================================================================ + +/** + * A production describes how to construct a term from child terms. 
+ * - domain: the sort required for each child
+ * - range: the sort of the produced term
+ * - builder: given a vector of child exprs, produce the result expr
+ */
+struct production {
+    std::string name;
+    sort_ref range;
+    sort_ref_vector domain;
+    std::function<expr_ref(expr_ref_vector const&)> builder;
+
+    // A production without argument sorts is a leaf.
+    bool is_leaf() const { return domain.empty(); }
+};
+
+// ============================================================================
+// grammar
+// ============================================================================
+
+/**
+ * A grammar groups productions into leaves (arity 0) and operators (arity > 0).
+ * It owns the productions and pins every declaration/expression they refer to.
+ */
+class grammar {
+public:
+    grammar(ast_manager& m) : m(m), m_pinned(m) {}
+
+    // Takes ownership of p and files it under leaves or operators.
+    void add_production(production* p) {
+        if (p->is_leaf())
+            m_leaves.push_back(p);
+        else
+            m_operators.push_back(p);
+    }
+
+    scoped_ptr_vector<production> const& leaves() const { return m_leaves; }
+    scoped_ptr_vector<production> const& operators() const { return m_operators; }
+    ast_manager& mgr() const { return m; }
+
+    // Register f as a production whose builder applies f to its arguments.
+    // A declaration that was registered before is ignored.
+    void add_func_decl(func_decl *f) {
+        if (m_seen.contains(f))
+            return;
+        m_pinned.push_back(f);
+        m_seen.insert(f);
+        sort_ref range(f->get_range(), m);
+        sort_ref_vector dom(m);
+        for (unsigned i = 0; i < f->get_arity(); ++i)
+            dom.push_back(sort_ref(f->get_domain(i), m));
+        // f stays valid inside the closure because it is pinned above.
+        add_production(alloc(production, {f->get_name().str(), range, dom, [this, f](expr_ref_vector const &args) {
+            return expr_ref(m.mk_app(f, args), m);
+        }}));
+    }
+
+    // Register e as a leaf production that always yields e itself.
+    // An expression that was registered before is ignored.
+    void add_expr(expr *e) {
+        if (m_seen.contains(e))
+            return;
+        m_pinned.push_back(e);
+        m_seen.insert(e);
+        sort_ref range(e->get_sort(), m);
+        sort_ref_vector dom(m);
+        std::stringstream ss;
+        ss << mk_bounded_pp(e, m);
+        std::string name = ss.str();
+        add_production(alloc(production, {name, range, dom, [this, e](expr_ref_vector const&) { return expr_ref(e, m); }}));
+    }
+
+    std::ostream& display(std::ostream& out) const {
+        out << "Leaves:\n";
+        for (auto const *p : m_leaves) {
+            out << " " << p->name << " : " << mk_pp(p->range, m) << "\n";
+        }
+        out << "Operators:\n";
+        for (auto const *p : m_operators) {
+            out << " " << p->name << " : (";
+            for (unsigned i = 0; i < p->domain.size(); ++i) {
+                if (i > 0)
+                    out << ", ";
+                out << mk_pp(p->domain[i], m);
+            }
+            out << ") -> " << mk_pp(p->range, m) << "\n";
+        }
+        return out;
+    }
+
+private:
+    ast_manager& m;
+    ast_ref_vector m_pinned;
+    scoped_ptr_vector<production> m_leaves;
+    scoped_ptr_vector<production> m_operators;
+    obj_hashtable<ast> m_seen;
+};
+
+// ============================================================================
+// Term Bank - stores enumerated terms by cost and sort
+// ============================================================================
+
+using cost_terms = vector<std::pair<expr*, unsigned>>;
+
+/**
+ * Stores enumerated terms indexed by cost and then by sort. Every stored term
+ * is pinned until reset() or destruction. A cost level that never received a
+ * term is represented by a null entry in m_terms.
+ */
+class term_bank {
+    using sort_term_map = obj_map<sort, ptr_vector<expr>>;
+public:
+    term_bank(ast_manager& m) : m(m), m_pinned(m) {}
+
+    ~term_bank() {
+        dealloc_maps();
+    }
+
+    // Forget all terms. The per-cost maps are allocated in add() and owned by
+    // the bank, so they are freed before the index is cleared.
+    void reset() {
+        dealloc_maps();
+        m_terms.clear();
+        m_pinned.reset();
+    }
+
+    // Store term under (cost, sort of term), growing the cost index on demand.
+    void add(expr* term, unsigned cost) {
+        sort* s = term->get_sort();
+        m_pinned.push_back(term);
+        if (cost >= m_terms.size())
+            m_terms.resize(cost + 1);
+        if (!m_terms[cost])
+            m_terms[cost] = alloc(sort_term_map);
+        m_terms[cost]->insert_if_not_there(s, ptr_vector<expr>()).push_back(term);
+    }
+
+    /** Get all terms of a given sort up to (and including) max_cost */
+    cost_terms get_by_sort(sort* s, unsigned max_cost) const {
+        cost_terms result;
+        for (unsigned c = 0; c <= max_cost && c < m_terms.size(); ++c) {
+            // add() can skip cost levels, which leaves null entries behind
+            if (!m_terms[c] || !m_terms[c]->contains(s))
+                continue;
+            for (auto t : m_terms[c]->find(s))
+                result.push_back({t, c});
+        }
+        return result;
+    }
+
+    // Return true if there is at least one term at/above `cost` whose sort is
+    // not in `sorts` (i.e., enumeration can still produce a new requested sort).
+    bool is_productive(unsigned cost, uint_set const& sorts) const {
+        for (unsigned i = cost; i < m_terms.size(); ++i) {
+            if (!m_terms[i])
+                continue;
+            for (auto const& entry : *m_terms[i]) {
+                sort* term_sort = entry.m_key;
+                if (!sorts.contains(term_sort->get_small_id()))
+                    return true;
+            }
+        }
+        return false;
+    }
+
+    // Empty vector handed out by get_by_cost_and_sort when nothing is stored.
+    ptr_vector<expr> null_ptr_vector;
+    // Terms stored at exactly `cost` with sort `s`; empty if there are none
+    // (or if no sort is given).
+    ptr_vector<expr> const &get_by_cost_and_sort(unsigned cost, sort *s) const {
+        if (!s || cost >= m_terms.size() || !m_terms[cost] || !m_terms[cost]->contains(s))
+            return null_ptr_vector;
+        return m_terms[cost]->find(s);
+    }
+
+    std::ostream& display(std::ostream& out) const {
+        for (unsigned cost = 0; cost < m_terms.size(); ++cost) {
+            if (!m_terms[cost])
+                continue;
+            out << "cost " << cost << ":\n";
+            for (auto& [s, terms] : *m_terms[cost]) {
+                out << " sort " << mk_pp(s, m) << ":\n";
+                for (expr* e : terms) {
+                    out << " #" << e->get_id() << " ";
+                    if (cost == 0) {
+                        out << mk_bounded_pp(e, m);
+                    }
+                    else if (is_app(e)) {
+                        // print operator applications by the ids of their children
+                        app* a = to_app(e);
+                        out << a->get_decl()->get_name() << "(";
+                        bool first = true;
+                        for (expr* arg : *a) {
+                            if (!first) out << ", ";
+                            first = false;
+                            out << "#" << arg->get_id();
+                        }
+                        out << ")";
+                    }
+                    out << "\n";
+                }
+            }
+        }
+        return out;
+    }
+
+private:
+    ast_manager& m;
+    expr_ref_vector m_pinned;
+    // cost -> sort -> terms
+    ptr_vector<sort_term_map> m_terms;
+
+    // Free every per-cost map; dealloc accepts the null entries.
+    void dealloc_maps() {
+        for (auto* s : m_terms)
+            dealloc(s);
+    }
+};
+
+// ============================================================================
+// Children Iterator - generates all combinations of child terms
+// ============================================================================
+
+/**
+ * Iterates over all tuples (c1, c2, ..., cn) where each ci has the required
+ * sort, drawn from the term bank, with at least one child at the current
+ * cost - 1 (to avoid regenerating previously seen terms).
+ */
+class children_iterator {
+public:
+    // Collects, for every argument position of prod, the bank terms of the
+    // required sort with cost at most current_cost - 1. The iterator is
+    // exhausted from the start if prod is a leaf or a position has no candidate.
+    children_iterator(ast_manager& m, production const& prod, term_bank const& bank, unsigned current_cost)
+        : m(m), m_prod(prod), m_current_cost(current_cost), m_done(false)
+    {
+        m_arity = prod.domain.size();
+        if (m_arity == 0) {
+            m_done = true;
+            return;
+        }
+        for (unsigned i = 0; i < m_arity; ++i) {
+            m_candidates.push_back(bank.get_by_sort(prod.domain[i], current_cost - 1));
+            if (m_candidates.back().empty()) {
+                m_done = true;
+                return;
+            }
+        }
+        m_indices.resize(m_arity, 0);
+    }
+
+    // Skip tuples that have no child of cost `cost - 1`; true if one remains.
+    bool has_next(unsigned cost) {
+        while (!m_done) {
+            if (has_child_at_cost(cost))
+                return true;
+            advance();
+        }
+        return false;
+    }
+
+    // Return the current tuple and step to the following one. `cost` is set
+    // to 1 plus the sum of the child costs. Call only after has_next() was true.
+    expr_ref_vector next(unsigned& cost) {
+        expr_ref_vector result(m);
+        cost = 1;
+        for (unsigned i = 0; i < m_arity; ++i) {
+            auto [e, c] = m_candidates[i].get(m_indices[i]);
+            cost += c;
+            result.push_back(e);
+        }
+        advance();
+        return result;
+    }
+
+private:
+    ast_manager& m;
+    production const& m_prod;
+    unsigned m_current_cost;
+    unsigned m_arity;
+    bool m_done;
+    vector<cost_terms> m_candidates;
+    svector<unsigned> m_indices;
+
+    bool has_child_at_cost(unsigned cost) const {
+        for (unsigned i = 0; i < m_arity; ++i) {
+            auto [e, c] = m_candidates[i].get(m_indices[i]);
+            if (c + 1 == cost)
+                return true;
+        }
+        return false;
+    }
+
+    // Odometer step over m_indices, last position fastest.
+    void advance() {
+        for (auto i = m_arity; i-- > 0;) {
+            m_indices[i]++;
+            if (m_indices[i] < m_candidates[i].size()) return;
+            m_indices[i] = 0;
+        }
+        m_done = true;
+    }
+};
+
+// ============================================================================
+// bottom_up_enumerator - the main bottom-up term enumeration engine
+// ============================================================================
+
+
+/**
+ * Yields terms one at a time: first the leaves of the grammar (cost 0), then,
+ * for cost 1, 2, ..., the terms built by operator productions. Every candidate
+ * is simplified with th_rewriter and dropped if its simplified form was seen
+ * before. When a target sort is set only terms of that sort are returned.
+ */
+class bottom_up_enumerator {
+public:
+    bottom_up_enumerator(grammar& grammar)
+        : m_grammar(grammar), m(grammar.mgr()),
+          m_bank(grammar.mgr()), m_pending(grammar.mgr()), m_rewriter(grammar.mgr())
+    {}
+
+    void set_target_sort(sort *s) {
+        m_target_sort = s;
+    }
+    bool has_next() {
+        if (m_pending) return true;
+        m_pending = find_next();
+        return m_pending != nullptr;
+    }
+
+    // Next term, or a null reference once the enumeration is exhausted.
+    expr_ref next() {
+        if (!m_pending)
+            m_pending = find_next();
+        expr_ref result(m_pending, m);
+        m_pending = nullptr;
+        return result;
+    }
+
+    term_bank const& bank() const { return m_bank; }
+
+    std::ostream& display(std::ostream& out) const {
+        m_grammar.display(out);
+        return m_bank.display(out);
+    }
+
+    // Restart the enumeration from the leaves. The target sort is kept.
+    void reset() {
+        m_cost = 0;
+        m_leaf_idx = 0;
+        m_op_idx = 0;
+        m_bank_idx = 0;
+        m_bank_size = 0;
+        m_made_progress = false;
+        m_sorts_produced.reset();
+        m_state = State::Leaves;
+        m_bank.reset();
+        m_pending = nullptr;
+        m_rewriter.reset();
+        m_seen_terms.reset();
+        m_children_iter.reset();
+    }
+
+    // Simplify term and store it at `cost` unless the simplified term is
+    // already known. Returns the stored term, or nullptr for a duplicate.
+    expr* add_term(expr_ref const& term, unsigned cost) {
+        expr_ref simplified(m);
+        m_rewriter(term, simplified);
+        if (m_seen_terms.contains(simplified))
+            return nullptr;
+        m_seen_terms.insert(simplified);
+        // the bank pins the term, so the raw pointer returned below stays valid
+        m_bank.add(simplified, cost);
+        return simplified;
+    }
+
+private:
+    enum class State { Leaves, Operators, Done };
+
+    grammar& m_grammar;
+    ast_manager& m;
+    term_bank m_bank;
+    unsigned m_cost = 0;
+    unsigned m_leaf_idx = 0;
+    unsigned m_op_idx = 0;
+    unsigned m_bank_idx = 0;
+    unsigned m_bank_size = 0;
+    bool m_made_progress = false;
+    uint_set m_sorts_produced;
+    State m_state = State::Leaves;
+    expr_ref m_pending;
+    th_rewriter m_rewriter;
+    obj_hashtable<expr> m_seen_terms;
+    std::unique_ptr<children_iterator> m_children_iter;
+    sort *m_target_sort = nullptr;
+
+    bool sort_matches(expr* e) const {
+        return !m_target_sort || e->get_sort() == m_target_sort;
+    }
+
+    expr* find_next() {
+        while (true) {
+            switch (m_state) {
+            case State::Leaves:
+                while (m_leaf_idx < m_grammar.leaves().size()) {
+                    production const &prod = *m_grammar.leaves()[m_leaf_idx];
+                    m_leaf_idx++;
+                    expr_ref_vector empty_args(m);
+                    expr_ref term = prod.builder(empty_args);
+                    expr* r = add_term(term, 0);
+                    if (r && sort_matches(r))
+                        return r;
+                }
+                m_state = State::Operators;
+                m_cost = 1;
+                m_op_idx = 0;
+                m_bank_idx = 0;
+                m_bank_size = get_bank_size();
+                m_made_progress = false;
+                m_sorts_produced.reset();
+                m_children_iter.reset();
+                break;
+
+            case State::Operators: {
+                expr* result = enumerate_operators();
+                if (result)
+                    return result;
+
+                // the current cost level is exhausted: move to the next one
+                m_cost++;
+                m_op_idx = 0;
+                m_bank_idx = 0;
+                m_bank_size = get_bank_size();
+                m_children_iter.reset();
+                if (!m_made_progress && !m_bank.is_productive(m_cost, m_sorts_produced)) {
+                    m_state = State::Done;
+                    return nullptr;
+                }
+                // m_target_sort may be null (see sort_matches), so guard the access
+                if (m_target_sort && m_sorts_produced.contains(m_target_sort->get_small_id()))
+                    m_sorts_produced.reset();
+                m_made_progress = false;
+                break;
+            }
+            case State::Done:
+                return nullptr;
+            }
+        }
+    }
+
+    // Number of target-sort terms already stored at m_cost. Without a target
+    // sort there is nothing to replay.
+    unsigned get_bank_size() const {
+        if (!m_target_sort)
+            return 0;
+        auto const &terms = m_bank.get_by_cost_and_sort(m_cost, m_target_sort);
+        return terms.size();
+    }
+
+    expr *enumerate_operators() {
+        auto const &ops = m_grammar.operators();
+        while (true) {
+
+            // first find terms at m_cost that were already created
+            if (m_bank_idx < m_bank_size) {
+                auto const &terms = m_bank.get_by_cost_and_sort(m_cost, m_target_sort);
+                auto t = terms.get(m_bank_idx);
+                m_bank_idx++;
+                SASSERT(sort_matches(t));
+                return t;
+            }
+
+            // then create new terms using children at cost below current m_cost.
+            if (m_children_iter && m_children_iter->has_next(m_cost)) {
+                unsigned new_cost = 0;
+                expr_ref_vector children = m_children_iter->next(new_cost);
+                // m_op_idx was advanced when the iterator for this operator was made
+                production const &prod = *ops[m_op_idx - 1];
+                expr_ref term = prod.builder(children);
+                // IF_VERBOSE(0, verbose_stream() << term << "\n");
+                SASSERT(new_cost >= m_cost);
+                expr* r = add_term(term, new_cost);
+                if (!r)
+                    continue;
+                unsigned sort_id = r->get_sort()->get_small_id();
+                if (!m_sorts_produced.contains(sort_id))
+                    m_made_progress = true;
+                m_sorts_produced.insert(sort_id);
+                // terms above m_cost stay in the bank and are replayed later
+                if (sort_matches(r) && new_cost == m_cost) {
+                    return r;
+                }
+                continue;
+            }
+
+            if (m_op_idx >= ops.size())
+                return nullptr;
+
+            production const &prod = *ops[m_op_idx];
+            m_op_idx++;
+            m_children_iter = std::make_unique<children_iterator>(m, prod, m_bank, m_cost);
+        }
+    }
+};
+
+} // namespace term_enum
+
+// ============================================================================
+// term_enumeration public interface implementation
+// ============================================================================
+
+struct term_enumeration::imp {
+    ast_manager& m;
+    term_enum::grammar m_grammar;
+    term_enum::bottom_up_enumerator m_bottom_up_enumerator;
+    // NOTE(review): the std::function signature was lost in this view of the
+    // patch; unsigned(expr*) is an assumption - confirm against the header.
+    std::function<unsigned(expr*)> m_cost;
+
+    imp(ast_manager& m) :
+        m(m), m_grammar(m), m_bottom_up_enumerator(m_grammar) {}
+
+    void add_production(func_decl* f) {
+        m_grammar.add_func_decl(f);
+    }
+
+    void add_production(expr* e) {
+        m_grammar.add_expr(e);
+    }
+
+    void set_cost(std::function<unsigned(expr*)> const& cost) {
+        // TODO: the enumerator does not consult m_cost yet; costs are still
+        // computed from the size of the term.
+        m_cost = cost;
+    }
+
+    std::ostream& display(std::ostream& out) const {
+        return m_bottom_up_enumerator.display(out);
+    }
+};
+
+// -- iterator implementation --
+
+// State behind one iterator. The enumerator itself lives in imp and is shared:
+// constructing an iter_imp for a sort resets it.
+struct term_enumeration::iterator::iter_imp {
+    imp& m_imp;
+    ast_manager & m;
+    sort* m_sort;
+    expr_ref m_current;
+    bool m_end;
+    // one entry per array layer of the requested sort, outermost first
+    vector<expr_ref_vector> m_vars;
+    vector<ptr_vector<sort>> m_decls;
+    vector<vector<symbol>> m_names;
+
+    iter_imp(imp& i, sort* s) : m_imp(i), m(i.m), m_sort(s), m_current(i.m), m_end(false) {
+        m_imp.m_bottom_up_enumerator.reset();
+        init_sort();
+        advance();
+    }
+
+    // Sentinel constructor
+    iter_imp(imp& i) :
+        m_imp(i), m(i.m), m_sort(nullptr), m_current(i.m), m_end(true) {
+        UNREACHABLE();
+    }
+
+
+    // Peel the array layers off m_sort. For every layer the select declaration
+    // and one variable per index position are added to the grammar; the
+    // enumerator then targets the innermost range sort.
+    void init_sort() {
+        array_util autil(m);
+        sort *range = m_sort;
+
+        while (autil.is_array(range)) {
+            m_vars.push_back(expr_ref_vector(m));
+            m_decls.push_back(ptr_vector<sort>());
+            m_names.push_back(vector<symbol>());
+            for (unsigned i = 0; i < get_array_arity(range); ++i) {
+                m_decls.back().push_back(get_array_domain(range, i));
+                m_vars.back().push_back(nullptr);
+                m_names.back().push_back(symbol());
+            }
+
+            // build a throw-away select term only to obtain its declaration
+            expr_ref_vector args(m);
+            args.push_back(m.mk_const("a", range));
+            for (unsigned i = 0; i < m_decls.back().size(); ++i) {
+                args.push_back(m.mk_var(i, m_decls.back().get(i)));
+            }
+            app_ref sel(autil.mk_select(args), m);
+            m_imp.m_grammar.add_func_decl(sel->get_decl());
+
+            range = get_array_range(range);
+        }
+        // number the bound variables from the innermost layer outwards
+        unsigned n = 0;
+        for (unsigned i = m_decls.size(); i-- > 0;) {
+            for (unsigned j = m_decls[i].size(); j-- > 0;) {
+                m_vars[i][j] = m.mk_var(n, m_decls[i][j]);
+                m_names[i][j] = symbol(n);
+                m_imp.add_production(m_vars[i].get(j));
+                n++;
+            }
+        }
+        m_sort = range;
+        m_imp.m_bottom_up_enumerator.set_target_sort(range);
+    }
+
+    // Wrap the current term in one lambda per array layer, innermost first.
+    void mk_lambda() {
+        if (!m_current)
+            return;
+        for (unsigned i = m_decls.size(); i-- > 0;)
+            m_current = m.mk_lambda(m_decls[i].size(), m_decls[i].data(), m_names[i].data(), m_current);
+    }
+
+    void advance() {
+        if (m_end)
+            return;
+        m_current = m_imp.m_bottom_up_enumerator.next();
+        SASSERT(!m_current || m_current->get_sort() == m_sort);
+        mk_lambda();
+        if (!m_current)
+            m_end = true;
+    }
+};
+
+term_enumeration::iterator::iterator(imp& i, sort* s) {
+    m_imp = alloc(iter_imp, i, s);
+}
+
+term_enumeration::iterator::iterator(std::nullptr_t) {
+    m_imp = nullptr;
+}
+
+// Clone the private state so that both iterators own their own iter_imp.
+// Sharing the raw pointer would free it twice.
+term_enumeration::iterator::iterator(iterator const& other) {
+    m_imp = other.m_imp ? alloc(iter_imp, *other.m_imp) : nullptr;
+}
+
+term_enumeration::iterator::~iterator() {
+    dealloc(m_imp);
+}
+
+expr* term_enumeration::iterator::operator*() {
+    return m_imp ? m_imp->m_current.get() : nullptr;
+}
+
+term_enumeration::iterator& term_enumeration::iterator::operator++() {
+    if (m_imp) m_imp->advance();
+    return *this;
+}
+
+// The returned copy keeps its own reference to the term that was current
+// before the increment.
+term_enumeration::iterator term_enumeration::iterator::operator++(int) {
+    iterator tmp(*this);
+    ++(*this);
+    return tmp;
+}
+
+// A null m_imp is the end() sentinel and equals any exhausted iterator.
+bool term_enumeration::iterator::operator==(iterator const& other) const {
+    if (!m_imp && !other.m_imp) return true;
+    if (!m_imp) return other.m_imp->m_end;
+    if (!other.m_imp) return m_imp->m_end;
+    return m_imp->m_end == other.m_imp->m_end &&
+           m_imp->m_current == other.m_imp->m_current;
+}
+
+// -- terms implementation --
+
+term_enumeration::terms::terms(imp* i, sort* s) : m_imp(i), m_sort(s) {}
+
+term_enumeration::iterator term_enumeration::terms::begin() {
+    return iterator(*m_imp, m_sort);
+}
+
+term_enumeration::iterator term_enumeration::terms::end() {
+    return iterator(nullptr);
+}
+
+// -- term_enumeration implementation --
+
+term_enumeration::term_enumeration(ast_manager& m) {
+    m_imp = alloc(imp, m);
+}
+
+term_enumeration::~term_enumeration() {
+    dealloc(m_imp);
+}
+
+void term_enumeration::add_production(func_decl* f) {
+    m_imp->add_production(f);
+}
+
+void term_enumeration::add_production(expr* e) {
+    m_imp->add_production(e);
+}
+
+void term_enumeration::set_cost(std::function<unsigned(expr*)> const& cost) {
+    m_imp->set_cost(cost);
+}
+
+term_enumeration::terms term_enumeration::enum_terms(sort* s) {
+    return terms(m_imp, s);
+}
+
+std::ostream& term_enumeration::display(std::ostream& out) const {
+    return m_imp->display(out);
+}
diff --git a/src/ast/rewriter/term_enumeration.h b/src/ast/rewriter/term_enumeration.h
new file mode 100644
index 000000000..865b8d402
--- /dev/null
+++ b/src/ast/rewriter/term_enumeration.h
@@ -0,0 +1,62 @@
+#pragma once
+
+#include "ast/ast.h"
+#include <functional>
+
+class term_enumeration {
+    struct imp;
+    imp* m_imp;
+public:
+    term_enumeration(ast_manager& m);
+    ~term_enumeration();
+
+    void add_production(func_decl* f);
+    void add_production(expr* e);
+    // void add_production(sort *s, std::function g);
+
+    // cost function associated with expressions.
+    // terms are enumerated with increasing cost.
+    // NOTE(review): the template argument of std::function was lost in this
+    // view of the patch; unsigned(expr*) is an assumption - confirm.
+
+    void set_cost(std::function<unsigned(expr*)> const& cost);
+
+    class iterator {
+        struct iter_imp;
+        iter_imp* m_imp; // owned; nullptr for the end() sentinel
+    public:
+        iterator(imp& i, sort* s);
+        iterator(std::nullptr_t);
+        // The iterator owns m_imp: a copy clones the private state, a move
+        // takes it over and leaves a sentinel behind.
+        iterator(iterator const& other);
+        iterator(iterator&& other) noexcept : m_imp(other.m_imp) { other.m_imp = nullptr; }
+        iterator& operator=(iterator other) noexcept {
+            iter_imp* tmp = m_imp;
+            m_imp = other.m_imp;
+            other.m_imp = tmp;
+            return *this;
+        }
+        ~iterator();
+        expr* operator*();
+        iterator operator++(int);
+        iterator& operator++();
+        bool operator!=(iterator const& other) const {
+            return !(*this == other);
+        }
+        bool operator==(iterator const &other) const;
+    };
+
+    class terms {
+        imp* m_imp;
+        sort* m_sort;
+    public:
+        terms(imp* i, sort* s);
+        iterator begin();
+        iterator end();
+    };
+
+    terms enum_terms(sort* s);
+
+    std::ostream& display(std::ostream& out) const;
+};
\ No newline at end of file
diff --git a/src/model/model_macro_solver.cpp b/src/model/model_macro_solver.cpp
index 64881482e..0e1c44900 100644
--- a/src/model/model_macro_solver.cpp
+++ b/src/model/model_macro_solver.cpp
@@ -513,7 +513,7 @@ void non_auf_macro_solver::collect_candidates(ptr_vector const& qs,
 
             TRACE(model_finder, tout << "considering macro for: " << f->get_name() << "\n"; m->display(tout); tout << "\n";);
             if (m->is_unconditional() && (!qi->is_auf() || m->get_weight() >= m_mbqi_force_template)) {
-                full_macros.insert(f, std::make_pair(m, q));
+                full_macros.insert(f, {m, q});
                 cond_macros.erase(f);
             }
             else if (!full_macros.contains(f) && !qi->is_auf())
@@ -524,10 +524,8 @@ void non_auf_macro_solver::collect_candidates(ptr_vector const& qs,
 }
 
 void non_auf_macro_solver::process_full_macros(obj_map const& full_macros, obj_hashtable& removed) {
-    for (auto const& kv : full_macros) {
-        func_decl* f = kv.m_key;
-        cond_macro* m = kv.m_value.first;
-        quantifier* q = kv.m_value.second;
+    for (auto const &[f, v] : full_macros) {
+        auto [m, q] = v;
         SASSERT(m->is_unconditional());
         if (add_macro(f, m->get_def())) {
             get_qinfo(q)->set_the_one(f);
diff --git a/src/smt/smt_model_checker.cpp b/src/smt/smt_model_checker.cpp
index c23cfe01e..f0196baad 100644 --- a/src/smt/smt_model_checker.cpp +++ b/src/smt/smt_model_checker.cpp @@ -219,7 +219,7 @@ namespace smt { if (use_inv) { unsigned sk_term_gen = 0; - expr * sk_term = m_model_finder.get_inv(q, i, sk_value, sk_term_gen); + expr * sk_term = m_model_finder.get_inv(q, i, sk_value, *cex, sk_term_gen); if (sk_term != nullptr) { TRACE(model_checker, tout << "Found inverse " << mk_pp(sk_term, m) << "\n";); SASSERT(!m.is_model_value(sk_term)); @@ -238,10 +238,6 @@ namespace smt { TRACE(model_checker, tout << "sk term " << mk_pp(sk_term, m) << "\n"); sk_value = sk_term; } - // last ditch: am I an array? - else if (false && autil.is_as_array(sk_value, f) && cex->get_func_interp(f) && cex->get_func_interp(f)->get_array_interp(f)) { - sk_value = cex->get_func_interp(f)->get_array_interp(f); - } } if (contains_model_value(sk_value)) { diff --git a/src/smt/smt_model_finder.cpp b/src/smt/smt_model_finder.cpp index 8dd85f0b9..3b9cc31b3 100644 --- a/src/smt/smt_model_finder.cpp +++ b/src/smt/smt_model_finder.cpp @@ -18,6 +18,7 @@ Revision History: --*/ #include "util/backtrackable_set.h" #include "ast/ast_util.h" +#include "ast/has_free_vars.h" #include "ast/macros/macro_util.h" #include "ast/arith_decl_plugin.h" #include "ast/bv_decl_plugin.h" @@ -31,6 +32,7 @@ Revision History: #include "ast/ast_ll_pp.h" #include "ast/well_sorted.h" #include "ast/ast_smt2_pp.h" +#include "ast/rewriter/term_enumeration.h" #include "model/model_pp.h" #include "model/model_macro_solver.h" #include "smt/smt_model_finder.h" @@ -107,9 +109,15 @@ namespace smt { } } - expr* get_inv(expr* v) const { + expr* get_inv(expr* v, model& mdl) const { expr* t = nullptr; m_inv.find(v, t); + if (!t) { + for (auto [k, term] : m_inv) { + if (mdl.are_equal(k, v)) + return term; + } + } return t; } @@ -120,14 +128,11 @@ namespace smt { } void mk_inverse(evaluator& ev) { - for (auto const& kv : m_elems) { - expr* t = kv.m_key; + for (auto const &[t, gen] : m_elems) { 
SASSERT(!contains_model_value(t)); - unsigned gen = kv.m_value; expr* t_val = ev.eval(t, true); if (!t_val) break; TRACE(model_finder, tout << mk_pp(t, m) << " " << mk_pp(t_val, m) << "\n";); - expr* old_t = nullptr; if (m_inv.find(t_val, old_t)) { unsigned old_t_gen = 0; @@ -187,14 +192,14 @@ namespace smt { \brief Base class used to solve model construction constraints. */ class node { - unsigned m_id; - node* m_find{ nullptr }; - unsigned m_eqc_size{ 1 }; + unsigned m_id = 0; + node* m_find = nullptr; + unsigned m_eqc_size = 1; - sort* m_sort; // sort of the elements in the instantiation set. + sort* m_sort = nullptr; // sort of the elements in the instantiation set. - bool m_mono_proj{ false }; // relevant for integers & reals & bit-vectors - bool m_signed_proj{ false }; // relevant for bit-vectors. + bool m_mono_proj = false; // relevant for integers & reals & bit-vectors + bool m_signed_proj = false; // relevant for bit-vectors. ptr_vector m_avoid_set; ptr_vector m_exceptions; @@ -291,7 +296,7 @@ namespace smt { } void insert(expr* n, unsigned generation) { - if (is_ground(n)) + if (is_ground(n) || (has_quantifiers(n) && !has_free_vars(n))) // this is a closed term get_root()->m_set->insert(n, generation); } @@ -599,7 +604,10 @@ namespace smt { } else { r = tmp; - TRACE(model_finder, tout << "eval\n" << mk_pp(n, m) << "\n----->\n" << mk_pp(r, m) << "\n";); + TRACE(model_finder, tout << "eval-failed\n" << mk_pp(n, m) << "\n----->\n" << mk_pp(r, m) << "\n";); + if (is_lambda(tmp)) { + r = m.mk_fresh_const("lambda", tmp->get_sort()); + } } m_eval_cache[model_completion].insert(n, r); m_eval_cache_range.push_back(r); @@ -1235,8 +1243,8 @@ namespace smt { void populate_inst_sets(quantifier* q, func_decl* mhead, ptr_vector& uvar_inst_sets, context* ctx) override { if (m_f != mhead) return; - uvar_inst_sets.reserve(m_var_j + 1, 0); - if (uvar_inst_sets[m_var_j] == 0) + uvar_inst_sets.reserve(m_var_j + 1, nullptr); + if (uvar_inst_sets[m_var_j] == nullptr) 
uvar_inst_sets[m_var_j] = alloc(instantiation_set, ctx->get_manager());
             instantiation_set* s = uvar_inst_sets[m_var_j];
             SASSERT(s != nullptr);
@@ -1369,6 +1377,85 @@ namespace smt {
     };
 
 
+    /**
+       \brief ho_var records an occurrence of the quantified variable with
+       index m_var_i whose sort is an array sort. populate_inst_sets seeds the
+       variable's instantiation set with relevant terms of that sort from the
+       E-graph and with a bounded number of enumerated terms.
+    */
+    class ho_var : public qinfo {
+        unsigned m_var_i;
+    public:
+        ho_var(ast_manager& m, unsigned i) : qinfo(m), m_var_i(i) {
+        }
+
+        char const *get_kind() const override {
+            return "ho_var";
+        }
+
+        bool is_equal(qinfo const *qi) const override {
+            if (qi->get_kind() != get_kind())
+                return false;
+            ho_var const *other = static_cast<ho_var const *>(qi);
+            return m_var_i == other->m_var_i;
+        }
+
+        void display(std::ostream &out) const override {
+            out << "(" << "ho-var: " << m_var_i << ")";
+        }
+
+        void process_auf(quantifier *q, auf_solver &s, context *ctx) override {
+            /* node * S_i = */ s.get_uvar(q, m_var_i);
+        }
+
+        void populate_inst_sets(quantifier *q, auf_solver &s, context *ctx) override {
+            node *S = s.get_uvar(q, m_var_i);
+            sort *srt = S->get_sort();
+
+            IF_VERBOSE(3, verbose_stream() << "ho_var::populate_inst_sets: " << q->get_id() << " " << mk_pp(srt, m) << "\n";);
+            term_enumeration tn(m);
+            // Add ground terms of type S.
+            // Add productions for functions in E-graph
+            // add other possible relevant functions such as equality over srt, Boolean operators
+
+            ast_mark visited;
+            for (enode *n : ctx->enodes()) {
+                if (!ctx->is_relevant(n))
+                    continue;
+                auto e = n->get_expr();
+                if (srt == n->get_sort()) {
+                    TRACE(model_finder, tout << "inserting " << mk_pp(e, m) << " into inst set\n");
+                    S->insert(e, n->get_generation());
+                }
+                else if (is_uninterp_const(e)) {
+                    TRACE(model_finder, tout << "add production " << mk_pp(e, m) << "\n");
+                    tn.add_production(e);
+                }
+                else if (is_uninterp(e)) {
+                    auto f = to_app(e)->get_decl();
+                    if (visited.is_marked(f))
+                        continue;
+                    visited.mark(f, true);
+                    TRACE(model_finder, tout << "add function " << mk_pp(f, m) << "\n");
+                    tn.add_production(f);
+                }
+            }
+
+            // Cap the number of enumerated candidates: the enumerator itself
+            // puts no limit on how many terms it yields.
+            unsigned const max_count = 20;
+            unsigned count = 0;
+            for (auto t : tn.enum_terms(srt)) {
+                if (count++ >= max_count)
+                    break;
+                unsigned generation = 0; // todo - inherited from sub-term of t?
+                TRACE(model_finder, tout << "ho_var: adding term " << mk_ismt2_pp(t, m)
+                      << " to instantiation set of S" << std::endl;);
+                S->insert(t, generation);
+            }
+        }
+    };
+
     /**
        \brief auf_arr is a term (pattern) of the form:
 
@@ -2105,7 +2181,12 @@ namespace smt {
                     process_app(to_app(curr));
                 }
                 else if (is_var(curr)) {
-                    m_info->m_is_auf = false; // unexpected occurrence of variable.
+                    if (m_array_util.is_array(curr)) {
+                        insert_qinfo(alloc(ho_var, m, to_var(curr)->get_idx()));
+                    }
+                    else {
+                        m_info->m_is_auf = false; // unexpected occurrence of variable.
+ } } else { SASSERT(is_lambda(curr)); @@ -2520,11 +2601,12 @@ namespace smt { Store in generation the generation of the result */ - expr* model_finder::get_inv(quantifier* q, unsigned i, expr* val, unsigned& generation) { + expr* model_finder::get_inv(quantifier* q, unsigned i, expr* val, model& mdl,unsigned& generation) { instantiation_set const* s = get_uvar_inst_set(q, i); if (s == nullptr) return nullptr; - expr* t = s->get_inv(val); + + expr* t = s->get_inv(val, mdl); if (m_auf_solver->is_default_representative(t)) return val; if (t != nullptr) { @@ -2560,16 +2642,27 @@ namespace smt { obj_map const& inv = s->get_inv_map(); if (inv.empty()) continue; // nothing to do - ptr_buffer eqs; - for (auto const& [val, _] : inv) { - if (val->get_sort() == sk->get_sort()) - eqs.push_back(m.mk_eq(sk, val)); + expr_ref_vector eqs(m), defs(m); + + for (auto const& [val, term] : inv) { + if (val->get_sort() == sk->get_sort()) { + if (is_lambda(term)) { + eqs.push_back(m.mk_eq(sk, val)); + defs.push_back(m.mk_eq(val, term)); + } + else + eqs.push_back(m.mk_eq(sk, val)); + } } if (!eqs.empty()) { expr_ref new_cnstr(m); new_cnstr = m.mk_or(eqs); TRACE(model_finder, tout << "assert_restriction:\n" << mk_pp(new_cnstr, m) << "\n";); aux_ctx->assert_expr(new_cnstr); + for (auto def : defs) { + TRACE(model_finder, tout << "assert_def:\n" << mk_pp(def, m) << "\n";); + aux_ctx->assert_expr(def); + } asserted_something = true; } } diff --git a/src/smt/smt_model_finder.h b/src/smt/smt_model_finder.h index 1c468fc64..3b34e0192 100644 --- a/src/smt/smt_model_finder.h +++ b/src/smt/smt_model_finder.h @@ -113,7 +113,7 @@ namespace smt { void fix_model(proto_model * m); quantifier * get_flat_quantifier(quantifier * q); - expr * get_inv(quantifier * q, unsigned i, expr * val, unsigned & generation); + expr * get_inv(quantifier * q, unsigned i, expr * val, model& m, unsigned & generation); bool restrict_sks_to_inst_set(context * aux_ctx, quantifier * q, expr_ref_vector const & sks); void 
restart_eh();
diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt
index d84892cd5..404cf4553 100644
--- a/src/test/CMakeLists.txt
+++ b/src/test/CMakeLists.txt
@@ -145,6 +145,7 @@ add_executable(test-z3
   symbol.cpp
   symbol_table.cpp
   tbv.cpp
+  term_enumeration.cpp
   theory_dl.cpp
   theory_pb.cpp
   timeout.cpp
diff --git a/src/test/main.cpp b/src/test/main.cpp
index 4eb66798f..b78e38789 100644
--- a/src/test/main.cpp
+++ b/src/test/main.cpp
@@ -195,6 +195,7 @@
     X(finite_set) \
     X(finite_set_rewriter) \
     X(fpa) \
+    X(term_enumeration) \
     X(lcube)
 
 #define FOR_EACH_TEST(X, X_ARGV) \
diff --git a/src/test/term_enumeration.cpp b/src/test/term_enumeration.cpp
new file mode 100644
index 000000000..57b5da852
--- /dev/null
+++ b/src/test/term_enumeration.cpp
@@ -0,0 +1,309 @@
+/*++
+Copyright (c) 2024 Microsoft Corporation
+
+Module Name:
+
+    tst_term_enumeration.cpp
+
+Abstract:
+
+    Test term enumeration module
+
+--*/
+
+
+#include "ast/rewriter/term_enumeration.h"
+#include "ast/ast_pp.h"
+#include "ast/arith_decl_plugin.h"
+#include "ast/bv_decl_plugin.h"
+#include "ast/array_decl_plugin.h"
+#include "ast/reg_decl_plugins.h"
+#include "ast/rewriter/th_rewriter.h"
+#include "util/obj_hashtable.h"
+#include <iostream>
+#include <sstream>
+
+static void tst_basic_enumeration() {
+    std::cout << "=== test basic enumeration ===\n";
+    ast_manager m;
+    reg_decl_plugins(m);
+    arith_util arith(m);
+
+    term_enumeration enumerator(m);
+
+    // The grammar consists of the two integer literals 0 and 1
+    expr_ref lit0(arith.mk_int(0), m);
+    expr_ref lit1(arith.mk_int(1), m);
+    enumerator.add_production(lit0);
+    enumerator.add_production(lit1);
+
+    // Print the Int-sorted terms, at most five of them
+    unsigned num_terms = 0;
+    auto int_terms = enumerator.enum_terms(arith.mk_int());
+    for (expr* term : int_terms) {
+        std::cout << "Term: " << mk_pp(term, m) << "\n";
+        if (++num_terms >= 5)
+            break;
+    }
+
+    ENSURE(num_terms >= 2); // both literals are expected
+    std::cout << "Enumerated " << num_terms << " terms\n";
+}
+
+static void tst_enumeration_with_operators() {
+    std::cout << "=== test enumeration with operators ===\n";
+    ast_manager m;
+    reg_decl_plugins(m);
+    arith_util arith(m);
+
+    term_enumeration enumerator(m);
+
+    // Leaves: the integer literals 0 and 1
+    expr_ref lit0(arith.mk_int(0), m);
+    expr_ref lit1(arith.mk_int(1), m);
+    enumerator.add_production(lit0);
+    enumerator.add_production(lit1);
+
+    // Operators: + and *.
+    // Each declaration is taken from a throw-away application of it.
+    app_ref sample_add(arith.mk_add(lit0, lit1), m);
+    app_ref sample_mul(arith.mk_mul(lit0, lit1), m);
+    func_decl* plus = sample_add->get_decl();
+    func_decl* times = sample_mul->get_decl();
+    enumerator.add_production(plus);
+    enumerator.add_production(times);
+
+    unsigned num_terms = 0;
+    auto int_terms = enumerator.enum_terms(arith.mk_int());
+    for (expr* term : int_terms) {
+        std::cout << "Term: " << mk_pp(term, m) << "\n";
+        if (++num_terms >= 20)
+            break;
+    }
+
+    ENSURE(num_terms >= 2); // the leaves alone give two terms
+    std::cout << "Enumerated " << num_terms << " terms with operators\n";
+}
+
+static void tst_observational_equivalence_filter() {
+    std::cout << "=== test observational equivalence filter ===\n";
+    ast_manager m;
+    reg_decl_plugins(m);
+    arith_util arith(m);
+    th_rewriter simplify(m);
+
+    term_enumeration enumerator(m);
+
+    expr_ref lit0(arith.mk_int(0), m);
+    expr_ref lit1(arith.mk_int(1), m);
+    enumerator.add_production(lit0);
+    enumerator.add_production(lit1);
+
+    app_ref sample_add(arith.mk_add(lit0, lit1), m);
+    enumerator.add_production(sample_add->get_decl());
+
+    obj_hashtable<expr> distinct;
+    unsigned num_terms = 0;
+    auto int_terms = enumerator.enum_terms(arith.mk_int());
+    for (expr* term : int_terms) {
+        expr_ref normal_form(m);
+        simplify(term, normal_form);
+        ENSURE(normal_form == term);
+        ENSURE(!distinct.contains(normal_form));
+        distinct.insert(normal_form);
+        if (++num_terms >= 20)
+            break;
+    }
+
+    ENSURE(num_terms >= 2);
+}
+
+static void tst_display() {
+    std::cout << "=== test display ===\n";
+    ast_manager m;
+    reg_decl_plugins(m);
+    arith_util arith(m);
+
+    term_enumeration enumerator(m);
+
+    // Leaves: the integer literals 0 and 1
+    expr_ref lit0(arith.mk_int(0), m);
+    expr_ref lit1(arith.mk_int(1), m);
+    enumerator.add_production(lit0);
+    enumerator.add_production(lit1);
+
+    // Operator: +
+    app_ref sample_add(arith.mk_add(lit0, lit1), m);
+    func_decl* plus = sample_add->get_decl();
+    enumerator.add_production(plus);
+
+    unsigned num_terms = 0;
+    auto int_terms = enumerator.enum_terms(arith.mk_int());
+    for (expr* term : int_terms) {
+        (void)term;
+        if (++num_terms >= 10)
+            break;
+    }
+
+    std::cout << "Internal state after enumeration:\n";
+    std::ostringstream dump;
+    enumerator.display(dump);
+    std::cout << dump.str();
+
+    // display must have written something
+    ENSURE(!dump.str().empty());
+}
+
+static void tst_bitvector_enumeration() {
+    std::cout << "=== test bitvector enumeration ===\n";
+    ast_manager m;
+    reg_decl_plugins(m);
+    bv_util bvu(m);
+
+    term_enumeration enumerator(m);
+
+    // Leaves: the 8-bit numerals 0 and 1
+    unsigned width = 8;
+    expr_ref num0(bvu.mk_numeral(0, width), m);
+    expr_ref num1(bvu.mk_numeral(1, width), m);
+    enumerator.add_production(num0);
+    enumerator.add_production(num1);
+
+    // Operator: bvadd
+    app_ref sample_add(bvu.mk_bv_add(num0, num1), m);
+    func_decl* bvadd_decl = sample_add->get_decl();
+    enumerator.add_production(bvadd_decl);
+
+    unsigned num_terms = 0;
+    auto bv_terms = enumerator.enum_terms(bvu.mk_sort(width));
+    for (expr* term : bv_terms) {
+        std::cout << "BV Term: " << mk_pp(term, m) << "\n";
+        if (++num_terms >= 10)
+            break;
+    }
+
+    ENSURE(num_terms >= 2);
+    std::cout << "Enumerated " << num_terms << " bitvector terms\n";
+}
+
+static void tst_multiple_sorts() {
+    std::cout << "=== test multiple sorts ===\n";
+    ast_manager m;
+    reg_decl_plugins(m);
+    arith_util a(m);
+
+    term_enumeration te(m);
+
+    // Add Int constants
+    expr_ref i_zero(a.mk_int(0), m);
+    expr_ref i_one(a.mk_int(1), m);
+    te.add_production(i_zero);
+    te.add_production(i_one);
+
+    // Add Real constants
+    expr_ref r_zero(a.mk_real(0), m);
+    expr_ref r_one(a.mk_real(1), m);
+    te.add_production(r_zero);
+    te.add_production(r_one);
+
+    // Enumerate Int terms
+    sort* int_sort = a.mk_int();
+    unsigned int_count = 0;
+    for (expr* e : te.enum_terms(int_sort)) {
+        std::cout << "Int Term: " << mk_pp(e, m) << "\n";
+        int_count++;
+        if (int_count >= 5) break;
+    }
+
ENSURE(int_count >= 2); + std::cout << "Enumerated " << int_count << " Int terms\n"; +} + +static void tst_nested_array_enumeration() { + std::cout << "=== test nested array enumeration (Array(A, Array(B, A))) ===\n"; + ast_manager m; + reg_decl_plugins(m); + array_util arr(m); + + term_enumeration te(m); + + // Create uninterpreted sorts A and B + sort_ref sort_A(m.mk_uninterpreted_sort(symbol("A")), m); + sort_ref sort_B(m.mk_uninterpreted_sort(symbol("B")), m); + + // Create nested array sort: Array(B, A) - arrays indexed by B returning A + sort_ref array_B_A(arr.mk_array_sort(sort_B, sort_A), m); + + // Create outer array sort: Array(A, Array(B, A)) - arrays indexed by A returning Array(B,A) + sort_ref array_A_arrayBA(arr.mk_array_sort(sort_A, array_B_A), m); + + std::cout << "Sort A: " << mk_pp(sort_A.get(), m) << "\n"; + std::cout << "Sort B: " << mk_pp(sort_B.get(), m) << "\n"; + std::cout << "Sort Array(B, A): " << mk_pp(array_B_A.get(), m) << "\n"; + std::cout << "Sort Array(A, Array(B, A)): " << mk_pp(array_A_arrayBA.get(), m) << "\n"; + + // Add constants of sort A + app_ref a0(m.mk_const(symbol("a0"), sort_A), m); + app_ref a1(m.mk_const(symbol("a1"), sort_A), m); + te.add_production(a0); + te.add_production(a1); + + // Add constants of sort B + app_ref b0(m.mk_const(symbol("b0"), sort_B), m); + app_ref b1(m.mk_const(symbol("b1"), sort_B), m); + te.add_production(b0); + te.add_production(b1); + + // Add a constant array of inner type Array(B, A) - const_array(a0) : Array(B, A) + app_ref const_inner(arr.mk_const_array(array_B_A, a0), m); + te.add_production(const_inner); + + // Add a constant array of outer type Array(A, Array(B, A)) + app_ref const_outer(arr.mk_const_array(array_A_arrayBA, const_inner), m); + te.add_production(const_outer); + + // Add store operator for the inner array type Array(B, A) + // store(array, index, value) : store(Array(B,A), B, A) -> Array(B,A) + expr* store_inner_args[3] = { const_inner.get(), b0.get(), a0.get() }; + 
app_ref tmp_store_inner(arr.mk_store(3, store_inner_args), m); + func_decl* store_inner_decl = tmp_store_inner->get_decl(); + te.add_production(store_inner_decl); + + // Add store operator for the outer array type Array(A, Array(B, A)) + // store(array, index, value) : store(Array(A, Array(B,A)), A, Array(B,A)) -> Array(A, Array(B,A)) + expr* store_outer_args[3] = { const_outer.get(), a0.get(), const_inner.get() }; + app_ref tmp_store_outer(arr.mk_store(3, store_outer_args), m); + func_decl* store_outer_decl = tmp_store_outer->get_decl(); + te.add_production(store_outer_decl); + + // Add select operator for the outer array (returns Array(B, A)) + // select(Array(A, Array(B,A)), A) -> Array(B, A) + app_ref tmp_select_outer(arr.mk_select(const_outer.get(), a0.get()), m); + func_decl* select_outer_decl = tmp_select_outer->get_decl(); + te.add_production(select_outer_decl); + + // Enumerate terms of the nested array sort Array(A, Array(B, A)) + std::cout << "\nEnumerating terms of sort Array(A, Array(B, A)):\n"; + unsigned count = 0; + for (expr* e : te.enum_terms(array_A_arrayBA)) { + std::cout << " Term " << count << ": " << mk_pp(e, m) << "\n"; + count++; + if (count >= 15) break; // Limit output + } + + ENSURE(count >= 1); // At least the constant array + std::cout << "Enumerated " << count << " terms of sort Array(A, Array(B, A))\n"; + + te.display(std::cout); +} + +void tst_term_enumeration() { + tst_basic_enumeration(); + tst_enumeration_with_operators(); + tst_observational_equivalence_filter(); + tst_display(); + tst_bitvector_enumeration(); + tst_multiple_sorts(); + tst_nested_array_enumeration(); + std::cout << "All term_enumeration tests passed!\n"; +}