diff --git a/src/ast/euf/euf_ac_plugin.cpp b/src/ast/euf/euf_ac_plugin.cpp index 99c9a2fd8..8558eb925 100644 --- a/src/ast/euf/euf_ac_plugin.cpp +++ b/src/ast/euf/euf_ac_plugin.cpp @@ -93,7 +93,7 @@ namespace euf { return; for (auto arg : enode_args(n)) if (is_op(arg)) - register_shared(arg); // TODO optimization to avoid registering shared terms twice + register_shared(arg); } void ac_plugin::register_shared(enode* n) { @@ -180,7 +180,7 @@ namespace euf { std::ostream& ac_plugin::display_monomial(std::ostream& out, ptr_vector const& m) const { for (auto n : m) { if (n->n->num_args() == 0) - out << mk_pp(n->n->get_expr(), g.get_manager()) << " "; + out << n->n->get_expr_id() << ": " << mk_pp(n->n->get_expr(), g.get_manager()) << " "; else out << g.bpp(n->n) << " "; } @@ -244,6 +244,7 @@ namespace euf { if (l == r) return; auto j = justification::equality(l, r); + TRACE(plugin, tout << g.bpp(l) << " == " << g.bpp(r) << " " << is_op(l) << " " << is_op(r) << "\n"); if (!is_op(l) && !is_op(r)) merge(mk_node(l), mk_node(r), j); else @@ -263,6 +264,7 @@ namespace euf { void ac_plugin::init_equation(eq const& e) { m_eqs.push_back(e); auto& eq = m_eqs.back(); + TRACE(plugin, display_equation(tout, e) << "\n"); if (orient_equation(eq)) { unsigned eq_id = m_eqs.size() - 1; @@ -273,6 +275,8 @@ namespace euf { n->root->n->mark1(); push_undo(is_add_eq_index); m_node_trail.push_back(n->root); + for (auto s : n->root->shared) + m_shared_todo.insert(s); } } @@ -282,6 +286,8 @@ namespace euf { n->root->n->mark1(); push_undo(is_add_eq_index); m_node_trail.push_back(n->root); + for (auto s : n->root->shared) + m_shared_todo.insert(s); } } @@ -291,6 +297,7 @@ namespace euf { for (auto n : monomial(eq.r)) n->root->n->unmark1(); + TRACE(plugin, display_equation(tout, e) << "\n"); m_to_simplify_todo.insert(eq_id); } else @@ -368,6 +375,7 @@ namespace euf { } void ac_plugin::merge(node* root, node* other, justification j) { + TRACE(plugin, tout << root << " == " << other << " num shared " << other->shared.size() << "\n"); for (auto n : equiv(other)) n->root = root; m_merge_trail.push_back({ other, root->shared.size(), root->eqs.size() }); @@ -394,22 +402,34 @@ namespace euf { ptr_vector m; ns.push_back(n); for (unsigned i = 0; i < ns.size(); ++i) { - n = ns[i]; - if (is_op(n)) - ns.append(n->num_args(), n->args()); + auto k = ns[i]; + if (is_op(k)) + ns.append(k->num_args(), k->args()); else - m.push_back(mk_node(n)); + m.push_back(mk_node(k)); } return to_monomial(n, m); } unsigned ac_plugin::to_monomial(enode* e, ptr_vector const& ms) { unsigned id = m_monomials.size(); - m_monomials.push_back({ ms, bloom() }); + m_monomials.push_back({ ms, bloom(), e }); push_undo(is_add_monomial); return id; } + enode* ac_plugin::from_monomial(ptr_vector const& mon) { + auto& m = g.get_manager(); + ptr_buffer args; + enode_vector nodes; + for (auto arg : mon) { + nodes.push_back(arg->root->n); + args.push_back(arg->root->n->get_expr()); + } + auto n = m.mk_app(m_fid, m_op, args.size(), args.data()); + return g.mk(n, 0, nodes.size(), nodes.data()); + } + ac_plugin::node* ac_plugin::node::mk(region& r, enode* n) { auto* mem = r.allocate(sizeof(node)); node* res = new (mem) node(); @@ -427,6 +447,9 @@ namespace euf { push_undo(is_add_node); m_nodes.setx(id, r, nullptr); m_node_trail.push_back(r); + if (is_op(n)) { + // extract shared sub-expressions + } return r; } @@ -983,6 +1006,7 @@ namespace euf { // void ac_plugin::propagate_shared() { + TRACE(plugin, tout << "num shared todo " << m_shared_todo.size() << "\n"); if (m_shared_todo.empty()) return; while (!m_shared_todo.empty()) { @@ -1007,12 +1031,15 @@ namespace euf { void ac_plugin::simplify_shared(unsigned idx, shared s) { auto j = s.j; auto old_m = s.m; + auto old_n = monomial(old_m).m_src; ptr_vector m1(monomial(old_m).m_nodes); - TRACE(plugin, tout << "simplify " << m_pp(*this, monomial(old_m)) << "\n"); + TRACE(plugin, tout << "simplify " << g.bpp(old_n) << ": " << m_pp(*this, monomial(old_m)) << "\n"); if (!reduce(m1, j)) return; - auto new_m = to_monomial(m1); + + auto new_n = from_monomial(m1); + auto new_m = to_monomial(new_n, m1); // update shared occurrences for members of the new monomial that are not already in the old monomial. for (auto n : monomial(old_m)) n->root->n->mark1(); @@ -1029,6 +1056,10 @@ namespace euf { push_undo(is_update_shared); m_shared[idx].m = new_m; m_shared[idx].j = j; + + TRACE(plugin, tout << "shared simplified to " << m_pp(*this, monomial(new_m)) << "\n"); + + push_merge(old_n, new_n, j); } justification ac_plugin::justify_rewrite(unsigned eq1, unsigned eq2) { diff --git a/src/ast/euf/euf_ac_plugin.h b/src/ast/euf/euf_ac_plugin.h index 809ac55cf..290ddb561 100644 --- a/src/ast/euf/euf_ac_plugin.h +++ b/src/ast/euf/euf_ac_plugin.h @@ -97,6 +97,7 @@ namespace euf { struct monomial_t { ptr_vector m_nodes; bloom m_bloom; + enode* m_src = nullptr; node* operator[](unsigned i) const { return m_nodes[i]; } unsigned size() const { return m_nodes.size(); } void set(ptr_vector const& ns) { m_nodes.reset(); m_nodes.append(ns); m_bloom.m_tick = 0; } @@ -187,6 +188,7 @@ namespace euf { unsigned to_monomial(enode* n); unsigned to_monomial(enode* n, ptr_vector const& ms); unsigned to_monomial(ptr_vector const& ms) { return to_monomial(nullptr, ms); } + enode* from_monomial(ptr_vector const& m); monomial_t const& monomial(unsigned i) const { return m_monomials[i]; } monomial_t& monomial(unsigned i) { return m_monomials[i]; } void sort(monomial_t& monomial); diff --git a/src/ast/euf/euf_arith_plugin.cpp b/src/ast/euf/euf_arith_plugin.cpp index 268eff38d..317b192c6 100644 --- a/src/ast/euf/euf_arith_plugin.cpp +++ b/src/ast/euf/euf_arith_plugin.cpp @@ -33,10 +33,13 @@ namespace euf { } void arith_plugin::register_node(enode* n) { - // no-op + TRACE(plugin, tout << g.bpp(n) << "\n"); + m_add.register_node(n); + m_mul.register_node(n); } void arith_plugin::merge_eh(enode* n1, enode* n2) { + TRACE(plugin, tout << g.bpp(n1) << " == " << g.bpp(n2) << "\n"); m_add.merge_eh(n1, n2); m_mul.merge_eh(n1, n2); } diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index a3f07e9cb..10ed000c6 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -310,6 +310,13 @@ namespace euf { } } + void egraph::register_shared(enode* n, theory_id id) { + force_push(); + auto* p = get_plugin(id); + if (p) + p->register_node(n); + } + void egraph::undo_add_th_var(enode* n, theory_id tid) { theory_var v = n->get_th_var(tid); SASSERT(v != null_theory_var); diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 4280b4780..53a0b7da2 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -318,6 +318,7 @@ namespace euf { void add_th_var(enode* n, theory_var v, theory_id id); + void register_shared(enode* n, theory_id id); void set_th_propagates_diseqs(theory_id id); void set_cgc_enabled(enode* n, bool enable_cgc); void set_merge_tf_enabled(enode* n, bool enable_merge_tf); diff --git a/src/ast/euf/euf_mam.cpp b/src/ast/euf/euf_mam.cpp index c87cd4197..ea928415c 100644 --- a/src/ast/euf/euf_mam.cpp +++ b/src/ast/euf/euf_mam.cpp @@ -649,7 +649,7 @@ namespace euf { } bool is_ac(func_decl* f) const { - return false && f->is_associative() && f->is_commutative(); + return f->is_associative() && f->is_commutative(); } instruction * mk_init(func_decl* f, unsigned n) { @@ -1777,6 +1777,10 @@ namespace euf { m_use_filters(use_filters) { } + bool is_ac(func_decl* f) const { + return f->is_associative() && f->is_commutative(); + } + /** \brief Create a new code tree for the given quantifier. @@ -1791,6 +1795,8 @@ namespace euf { code_tree * r = m_ct_manager.mk_code_tree(p->get_decl(), num_args, filter_candidates); init(r, qa, mp, first_idx); linearise(r->m_root, first_idx); + if (is_ac(p->get_decl())) + ++m_num_choices; r->m_num_choices = m_num_choices; TRACE(mam_compiler, tout << "new tree for:\n" << mk_pp(mp, m) << "\n" << *r;); return r; @@ -1861,9 +1867,6 @@ namespace euf { unsigned m_old_max_generation; union { enode * m_curr; - struct { - unsigned m_next_pattern; - }; struct { enode_vector * m_to_recycle; enode * const * m_it; @@ -2009,7 +2012,7 @@ namespace euf { void display_pc_info(std::ostream & out); - bool match_ac(initn const* pc); + bool next_ac_match(initn const* pc); #define INIT_ARGS_SIZE 16 @@ -2291,9 +2294,53 @@ namespace euf { // Established: use Diophantine equations to capture matchability. // - bool interpreter::match_ac(initn const* pc) { + bool interpreter::next_ac_match(initn const* pc) { unsigned f_args = pc->m_num_args; SASSERT(f_args <= m_acargs.size()); + for (unsigned i = f_args; i-- > 0;) { + unsigned j = m_acpatarg[i]; + m_acbitset[j] = false; + next_j: + ++j; + for (; j < m_acargs.size(); ++j) { + if (m_acbitset[j]) + continue; + m_registers[i + 1] = m_acargs[j]; + m_acbitset[j] = true; + m_acpatarg[i] = j; + break; + } + if (j == m_acargs.size()) + continue; + + for (unsigned ii = i + 1; ii < f_args; ++ii) { + unsigned k = 0; + // populate arguments after i + for (; k < m_acargs.size(); ++k) { + if (!m_acbitset[k]) { + m_registers[ii + 1] = m_acargs[k]; + m_acbitset[k] = true; + m_acpatarg[ii] = k; + break; + } + } + if (k == m_acargs.size()) { + --ii; + // clean up + for (; ii >= i; --ii) { + k = m_acpatarg[ii]; + m_acbitset[k] = false; + } + goto next_j; + } + } + IF_VERBOSE(2, + verbose_stream() << "next ac: "; + for (unsigned j = 0; j < f_args; ++j) + verbose_stream() << m_acpatarg[j] << " "; + verbose_stream() << "\n";); + return true; + } return false; } @@ -2412,6 +2459,7 @@ namespace euf { m_acargs.reset(); m_acargs.push_back(m_app); auto* f = m_app->get_decl(); + auto num_pat_args = static_cast(m_pc)->m_num_args; for (unsigned i = 0; i < m_acargs.size(); ++i) { auto* arg = m_acargs[i]; if (is_app(arg->get_expr()) && f == arg->get_decl()) { @@ -2421,7 +2469,7 @@ namespace euf { --i; } } - if (static_cast(m_pc)->m_num_args > m_acargs.size()) + if (num_pat_args > m_acargs.size()) goto backtrack; m_acbitset.reset(); m_acbitset.reserve(m_acargs.size(), false); @@ -2429,11 +2477,12 @@ namespace euf { m_acpatarg.reserve(m_acargs.size(), 0); m_backtrack_stack[m_top].m_instr = m_pc; m_backtrack_stack[m_top].m_old_max_generation = m_curr_max_generation; - m_backtrack_stack[m_top].m_next_pattern = 0; - ++m_top; - // perform the match relative index - if (!match_ac(static_cast(m_pc))) - goto backtrack; + ++m_top; + for (unsigned i = 0; i < num_pat_args; ++i) { + m_acpatarg[i] = i; + m_acbitset[i] = true; + m_registers[i + 1] = m_acargs[i]; + } m_pc = m_pc->m_next; goto main_loop; } @@ -2499,7 +2548,7 @@ namespace euf { m_app = get_first_f_app(static_cast(m_pc)->m_label, static_cast(m_pc)->m_num_args, m_n1); \ if (!m_app) \ goto backtrack; \ - TRACE(mam_int, tout << "bind candidate: " << mk_pp(m_app->get_expr(), m) << "\n";); \ + TRACE(mam_int, tout << "bind candidate: " << mk_pp(m_app->get_expr(), m) << " " << m_top << " " << m_backtrack_stack.size() << "\n";); \ m_backtrack_stack[m_top].m_instr = m_pc; \ m_backtrack_stack[m_top].m_old_max_generation = m_curr_max_generation; \ m_backtrack_stack[m_top].m_curr = m_app; \ @@ -2832,7 +2881,11 @@ namespace euf { case INITAC: // this is a backtracking point. - NOT_IMPLEMENTED_YET(); + if (!next_ac_match(static_cast(bp.m_instr))) { + --m_top; + goto backtrack; + } + m_pc = bp.m_instr->m_next; goto main_loop; case CONTINUE: diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index d4eee6832..0d685e903 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -51,6 +51,8 @@ Mam optimization? #include "ast/ast_pp.h" #include "ast/ast_util.h" #include "ast/euf/euf_egraph.h" +#include "ast/euf/euf_arith_plugin.h" +#include "ast/euf/euf_bv_plugin.h" #include "ast/rewriter/var_subst.h" #include "ast/simplifiers/euf_completion.h" #include "ast/shared_occs.h" @@ -87,6 +89,9 @@ namespace euf { m_egraph.set_on_merge(_on_merge); m_egraph.set_on_make(_on_make); + + m_egraph.add_plugin(alloc(arith_plugin, m_egraph)); + m_egraph.add_plugin(alloc(bv_plugin, m_egraph)); } completion::~completion() { @@ -203,6 +208,7 @@ namespace euf { if (!m_should_propagate && !should_stop()) propagate_all_rules(); } + TRACE(euf, m_egraph.display(tout)); } unsigned completion::push_pr_dep(proof* pr, expr_dependency* d) { @@ -520,7 +526,7 @@ namespace euf { if (g != f) { m_fmls.update(i, dependent_expr(m, g, pr, dep)); m_stats.m_num_rewrites++; - IF_VERBOSE(0, verbose_stream() << mk_bounded_pp(f, m, 3) << " -> " << mk_bounded_pp(g, m, 3) << "\n"); + IF_VERBOSE(2, verbose_stream() << mk_bounded_pp(f, m, 3) << " -> " << mk_bounded_pp(g, m, 3) << "\n"); update_has_new_eq(g); } CTRACE(euf_completion, g != f, tout << mk_bounded_pp(f, m) << " -> " << mk_bounded_pp(g, m) << "\n"); @@ -579,7 +585,16 @@ namespace euf { m_todo.push_back(arg); } if (sz == m_todo.size()) { - m_nodes_to_canonize.push_back(m_egraph.mk(e, m_generation, m_args.size(), m_args.data())); + n = m_egraph.mk(e, m_generation, m_args.size(), m_args.data()); + if (m_egraph.get_plugin(e->get_sort()->get_family_id())) + m_egraph.add_th_var(n, m_th_var++, e->get_sort()->get_family_id()); + if (!m.is_eq(e)) { + for (auto ch : m_args) + for (auto idv : euf::enode_th_vars(*ch)) + m_egraph.register_shared(n, idv.get_id()); + } + + m_nodes_to_canonize.push_back(n); m_todo.pop_back(); } } diff --git a/src/ast/simplifiers/euf_completion.h b/src/ast/simplifiers/euf_completion.h index 5bfecb86c..d6a02ddd1 100644 --- a/src/ast/simplifiers/euf_completion.h +++ b/src/ast/simplifiers/euf_completion.h @@ -122,6 +122,7 @@ namespace euf { smt_params m_smt_params; egraph m_egraph; + unsigned m_th_var = 0; scoped_ptr m_mam; enode* m_tt, *m_ff; ptr_vector m_todo;