From 0995928f6e99a0a89210ada18a1011ef9c7cdebf Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 11 Jul 2025 12:48:27 +0200 Subject: [PATCH] wip - throttle AC completion, enable congruences over bound bodies - AC completion which is exposed as an option to the new congruence closure core used roots of E-Graph which gets ordering of monomials out of sync. - Added injective function handling to AC completion - Move to model where all equations, also unit to unit are in completion - throw in first level bound bodies into the E-graph to enable canonization on them. --- src/ast/euf/euf_ac_plugin.cpp | 247 +++++++++++++++++++------ src/ast/euf/euf_ac_plugin.h | 44 ++++- src/ast/euf/euf_arith_plugin.cpp | 1 + src/ast/euf/euf_plugin.cpp | 4 + src/ast/expr_abstract.h | 2 + src/ast/simplifiers/euf_completion.cpp | 110 ++++++++++- src/ast/simplifiers/euf_completion.h | 13 ++ 7 files changed, 345 insertions(+), 76 deletions(-) diff --git a/src/ast/euf/euf_ac_plugin.cpp b/src/ast/euf/euf_ac_plugin.cpp index 8558eb925..c57507838 100644 --- a/src/ast/euf/euf_ac_plugin.cpp +++ b/src/ast/euf/euf_ac_plugin.cpp @@ -53,6 +53,14 @@ More notes: after merge n' + k = j + k, could be that n' + k < j + k < n + k in term ordering because n' < j, m < n + We filter results of superposition so that the size of monomials in the new equations don't grow. + This filter is used to make the process somewhat tractable. In other words, we never compute a + complete saturation. If l1 = r1 l2 = r2 are used to produce new equation l = r, we ensure + that min(|l|,|r|) <= min(|r1|,|r2|) and max(|l|,|r|) <= max(|l1|,|l2|) + + If the operator is injective we simplify equations + xl = xr to l = r + TODOs: - Efficiency of handling shared terms. @@ -63,6 +71,7 @@ TODOs: - by an epoch counter that can be updated by the egraph class whenever there is a push/pop. - store the epoch as a tick on equations and possibly when updating monomials on equations. + --*/ #include "ast/euf/euf_ac_plugin.h" @@ -195,6 +204,20 @@ namespace euf { return out; } + std::ostream& ac_plugin::display_monomial_ll(std::ostream& out, ptr_vector const& m) const { + for (auto n : m) + out << n->id() << " "; + return out; + } + + std::ostream& ac_plugin::display_equation_ll(std::ostream& out, eq const& e) const { + display_status(out, e.status) << " "; + display_monomial_ll(out, monomial(e.l)); + out << "== "; + display_monomial_ll(out, monomial(e.r)); + return out; + } + std::ostream& ac_plugin::display_status(std::ostream& out, eq_status s) const { switch (s) { case eq_status::is_dead: out << "d"; break; @@ -207,15 +230,14 @@ namespace euf { std::ostream& ac_plugin::display(std::ostream& out) const { unsigned i = 0; for (auto const& eq : m_eqs) { - out << i << ": " << eq.l << " == " << eq.r << ": "; - display_equation(out, eq); - out << "\n"; + if (eq.status != eq_status::is_dead) + out << i << ": " << eq_pp_ll(*this, eq) << "\n"; ++i; } i = 0; for (auto m : m_monomials) { out << i << ": "; - display_monomial(out, m); + display_monomial_ll(out, m); out << "\n"; ++i; } @@ -224,7 +246,7 @@ namespace euf { continue; if (n->eqs.empty() && n->shared.empty()) continue; - out << g.bpp(n->n) << " r: " << n->root_id() << " "; + out << g.bpp(n->n) << " r: " << n->id() << " "; if (!n->eqs.empty()) { out << "eqs "; for (auto l : n->eqs) @@ -247,8 +269,7 @@ namespace euf { TRACE(plugin, tout << g.bpp(l) << " == " << g.bpp(r) << " " << is_op(l) << " " << is_op(r) << "\n"); if (!is_op(l) && !is_op(r)) merge(mk_node(l), mk_node(r), j); - else - init_equation(eq(to_monomial(l), to_monomial(r), j)); + init_equation(eq(to_monomial(l), to_monomial(r), j)); } void ac_plugin::diseq_eh(enode* eq) { @@ -264,12 +285,20 @@ namespace euf { void ac_plugin::init_equation(eq const& e) { m_eqs.push_back(e); auto& eq = m_eqs.back(); - TRACE(plugin, display_equation(tout, e) << "\n"); + TRACE(plugin, display_equation_ll(tout, e) << "\n"); + deduplicate(monomial(eq.l).m_nodes, monomial(eq.r).m_nodes); + TRACE(plugin, display_equation_ll(tout << "dedup ", e) << "\n"); + if (orient_equation(eq)) { + auto& ml = monomial(eq.l); + auto& mr = monomial(eq.r); + + if (ml.size() == 1 && mr.size() == 1) + push_merge(ml[0]->n, mr[0]->n, eq.j); unsigned eq_id = m_eqs.size() - 1; - for (auto n : monomial(eq.l)) { + for (auto n : ml) { if (!n->root->n->is_marked1()) { n->root->eqs.push_back(eq_id); n->root->n->mark1(); @@ -280,7 +309,7 @@ namespace euf { } } - for (auto n : monomial(eq.r)) { + for (auto n : mr) { if (!n->root->n->is_marked1()) { n->root->eqs.push_back(eq_id); n->root->n->mark1(); @@ -291,13 +320,13 @@ namespace euf { } } - for (auto n : monomial(eq.l)) + for (auto n : ml) n->root->n->unmark1(); - for (auto n : monomial(eq.r)) + for (auto n : mr) n->root->n->unmark1(); - TRACE(plugin, display_equation(tout, e) << "\n"); + TRACE(plugin, display_equation_ll(tout, e) << "\n"); m_to_simplify_todo.insert(eq_id); } else @@ -312,14 +341,14 @@ namespace euf { if (ml.size() < mr.size()) { std::swap(e.l, e.r); return true; - } + } else { sort(ml); sort(mr); - for (unsigned i = ml.size(); i-- > 0;) { - if (ml[i]->root_id() == mr[i]->root_id()) + for (unsigned i = 0; i < ml.size(); ++i) { + if (ml[i]->id() == mr[i]->id()) continue; - if (ml[i]->root_id() < mr[i]->root_id()) + if (ml[i]->id() < mr[i]->id()) std::swap(e.l, e.r); return true; } @@ -327,15 +356,38 @@ namespace euf { } } + bool ac_plugin::is_equation_oriented(eq const& e) const { + auto& ml = monomial(e.l); + auto& mr = monomial(e.r); + if (ml.size() > mr.size()) + return true; + if (ml.size() < mr.size()) + return false; + else { + if (!is_sorted(ml)) + return false; + if (!is_sorted(mr)) + return false; + for (unsigned i = 0; i < ml.size(); ++i) { + if (ml[i]->id() == mr[i]->id()) + continue; + if (ml[i]->id() < mr[i]->id()) + return false; + return true; + } + return false; + } + } + void ac_plugin::sort(monomial_t& m) { - std::sort(m.begin(), m.end(), [&](node* a, node* b) { return a->root_id() < b->root_id(); }); + std::sort(m.begin(), m.end(), [&](node* a, node* b) { return a->id() < b->id(); }); } bool ac_plugin::is_sorted(monomial_t const& m) const { if (m.m_bloom.m_tick == m_tick) return true; for (unsigned i = m.size(); i-- > 1; ) - if (m[i - 1]->root_id() > m[i]->root_id()) + if (m[i - 1]->id() > m[i]->id()) return false; return true; } @@ -346,7 +398,7 @@ namespace euf { return bloom.m_filter; bloom.m_filter = 0; for (auto n : m) - bloom.m_filter |= (1ull << (n->root_id() % 64ull)); + bloom.m_filter |= (1ull << (n->id() % 64ull)); if (!is_sorted(m)) sort(m); bloom.m_tick = m_tick; @@ -367,25 +419,29 @@ namespace euf { if (bloom.m_tick != m_tick) { bloom.m_filter = 0; for (auto n : m) - bloom.m_filter |= (1ull << (n->root_id() % 64ull)); + bloom.m_filter |= (1ull << (n->id() % 64ull)); bloom.m_tick = m_tick; } auto f2 = bloom.m_filter; return (filter(subset) | f2) == f2; } - void ac_plugin::merge(node* root, node* other, justification j) { - TRACE(plugin, tout << root << " == " << other << " num shared " << other->shared.size() << "\n"); - for (auto n : equiv(other)) - n->root = root; - m_merge_trail.push_back({ other, root->shared.size(), root->eqs.size() }); - for (auto eq_id : other->eqs) + void ac_plugin::merge(node* a, node* b, justification j) { + TRACE(plugin, tout << a << " == " << b << " num shared " << b->shared.size() << "\n"); + if (a == b) + return; + if (a->id() < b->id()) + std::swap(a, b); + for (auto n : equiv(a)) + n->root = b; + m_merge_trail.push_back({ a, b->shared.size(), b->eqs.size() }); + for (auto eq_id : a->eqs) set_status(eq_id, eq_status::to_simplify); - for (auto m : other->shared) + for (auto m : a->shared) m_shared_todo.insert(m); - root->shared.append(other->shared); - root->eqs.append(other->eqs); - std::swap(root->next, other->next); + b->shared.append(a->shared); + b->eqs.append(a->eqs); + std::swap(b->next, a->next); push_undo(is_merge_node); ++m_tick; } @@ -427,7 +483,7 @@ namespace euf { args.push_back(arg->root->n->get_expr()); } auto n = m.mk_app(m_fid, m_op, args.size(), args.data()); - return g.mk(n, 0, nodes.size(), nodes.data()); + return g.find(n) ? g.find(n) : g.mk(n, 0, nodes.size(), nodes.data()); } ac_plugin::node* ac_plugin::node::mk(region& r, enode* n) { @@ -460,7 +516,7 @@ namespace euf { if (eq_id == UINT_MAX) break; - TRACE(plugin, tout << "propagate " << eq_id << ": " << eq_pp(*this, m_eqs[eq_id]) << "\n"); + TRACE(plugin, tout << "propagate " << eq_id << ": " << eq_pp_ll(*this, m_eqs[eq_id]) << "\n"); // simplify eq using processed TRACE(plugin, @@ -470,6 +526,12 @@ namespace euf { if (is_processed(other_eq) && backward_simplify(eq_id, other_eq)) goto loop_start; + auto& eq = m_eqs[eq_id]; + deduplicate(monomial(eq.l).m_nodes, monomial(eq.r).m_nodes); + if (monomial(eq.l).size() == 0) { + set_status(eq_id, eq_status::is_dead); + continue; + } set_status(eq_id, eq_status::processed); // simplify processed using eq @@ -478,9 +540,13 @@ namespace euf { forward_simplify(eq_id, other_eq); // superpose, create new equations + unsigned new_eqs = 0; for (auto other_eq : superpose_iterator(eq_id)) if (is_processed(other_eq)) - superpose(eq_id, other_eq); + new_eqs += superpose(eq_id, other_eq); + + (void)new_eqs; + TRACE(plugin, tout << "added eqs " << new_eqs << "\n"); // simplify to_simplify using eq for (auto other_eq : forward_iterator(eq_id)) @@ -633,7 +699,7 @@ namespace euf { void ac_plugin::init_ref_counts(ptr_vector const& monomial, ref_counts& counts) const { counts.reset(); for (auto n : monomial) - counts.inc(n->root_id(), 1); + counts.inc(n->id(), 1); } bool ac_plugin::is_correct_ref_count(monomial_t const& m, ref_counts const& counts) const { @@ -660,7 +726,7 @@ namespace euf { auto& src = m_eqs[src_eq]; // src_r_counts, src_l_counts are initialized auto& dst = m_eqs[dst_eq]; - TRACE(plugin, tout << "forward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << "\n"); + TRACE(plugin, tout << "forward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << "\n"); if (forward_subsumes(src_eq, dst_eq)) { @@ -685,7 +751,7 @@ namespace euf { // := src_rhs + elements from dst_rhs that are in excess of src_lhs unsigned num_overlap = 0; for (auto n : monomial(dst.r)) { - unsigned id = n->root_id(); + unsigned id = n->id(); unsigned dst_count = m_dst_r_counts[id]; unsigned src_count = m_src_l_counts[id]; if (dst_count > src_count) { @@ -726,7 +792,7 @@ namespace euf { // // dst_ids, dst_count contain rhs of dst_eq // - TRACE(plugin, tout << "backward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n"); + TRACE(plugin, tout << "backward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n"); if (backward_subsumes(src_eq, dst_eq)) { TRACE(plugin, tout << "backward subsumed\n"); @@ -782,7 +848,7 @@ namespace euf { SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts)); // add difference betwen dst.l and src.l to both src.l, src.r for (auto n : monomial(dst.l)) { - unsigned id = n->root_id(); + unsigned id = n->id(); SASSERT(m_dst_l_counts[id] >= m_src_l_counts[id]); unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id]; if (diff > 0) { @@ -792,7 +858,7 @@ namespace euf { } // now dst.r and src.r should align and have the same elements. // since src.r is a subset of dst.r we iterate over dst.r - return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); + return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); } // src_l_counts, src_r_counts are initialized for src.l, src.r @@ -815,7 +881,7 @@ namespace euf { SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts)); SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); for (auto n : monomial(src.l)) { - unsigned id = n->root_id(); + unsigned id = n->id(); SASSERT(m_src_l_counts[id] <= m_dst_l_counts[id]); unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id]; if (diff == 0) @@ -826,7 +892,7 @@ namespace euf { m_dst_r_counts.dec(id, diff); } - return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); + return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); } void ac_plugin::rewrite1(ref_counts const& src_l, monomial_t const& src_r, ref_counts& dst_counts, ptr_vector& dst) { @@ -837,14 +903,17 @@ namespace euf { SASSERT(is_correct_ref_count(dst, dst_counts)); SASSERT(&src_r.m_nodes != &dst); unsigned sz = dst.size(), j = 0; + bool change = false; for (unsigned i = 0; i < sz; ++i) { auto* n = dst[i]; - unsigned id = n->root_id(); + unsigned id = n->id(); unsigned dst_count = dst_counts[id]; unsigned src_count = src_l[id]; SASSERT(dst_count > 0); - if (src_count == 0) - dst[j++] = n; + + if (src_count == 0) { + dst[j++] = n; + } else if (src_count < dst_count) { dst[j++] = n; dst_counts.dec(id, 1); @@ -857,25 +926,41 @@ namespace euf { // rewrite monomial to normal form. bool ac_plugin::reduce(ptr_vector& m, justification& j) { bool change = false; + unsigned sz = m.size(); + unsigned jj = 0; + //verbose_stream() << "start\n"; do { init_loop: + //verbose_stream() << "loop " << jj++ << "\n"; if (m.size() == 1) return change; bloom b; init_ref_counts(m, m_m_counts); + unsigned k = 0; for (auto n : m) { + //verbose_stream() << "inner loop " << k++ << "\n"; for (auto eq : n->root->eqs) { if (!is_processed(eq)) continue; auto& src = m_eqs[eq]; + if (!is_equation_oriented(src)) { + continue; + if (!orient_equation(src)) + continue; + // deduplicate(src.l, src.r); + } if (!can_be_subset(monomial(src.l), m, b)) continue; if (!is_subset(m_m_counts, m_eq_counts, monomial(src.l))) continue; - TRACE(plugin, display_equation(tout << "reduce ", src) << "\n"); + + TRACE(plugin, display_equation_ll(tout << "reduce ", src) << "\n"); SASSERT(is_correct_ref_count(monomial(src.l), m_eq_counts)); + //display_equation_ll(std::cout << "reduce ", src) << ": "; + //display_monomial_ll(std::cout, m); rewrite1(m_eq_counts, monomial(src.r), m_m_counts, m); + //display_monomial_ll(std::cout << " -> ", m) << "\n"; j = join(j, eq); change = true; goto init_loop; @@ -883,6 +968,7 @@ namespace euf { } } while (false); + VERIFY(sz >= m.size()); return change; } @@ -917,13 +1003,17 @@ namespace euf { } - void ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) { + bool ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) { if (src_eq == dst_eq) - return; + return false; auto& src = m_eqs[src_eq]; auto& dst = m_eqs[dst_eq]; - TRACE(plugin, tout << "superpose: "; display_equation(tout, src); tout << " "; display_equation(tout, dst); tout << "\n";); + unsigned max_left = std::max(monomial(src.l).size(), monomial(dst.l).size()); + unsigned min_right = std::max(monomial(src.r).size(), monomial(dst.r).size()); + + + TRACE(plugin, tout << "superpose: "; display_equation_ll(tout, src); tout << " "; display_equation_ll(tout, dst); tout << "\n";); // AB -> C, AD -> E => BE ~ CD // m_src_ids, m_src_counts contains information about src (call it AD -> E) m_dst_l_counts.reset(); @@ -938,7 +1028,7 @@ namespace euf { // compute BE, initialize dst_ids, dst_counts bool overlap = false; for (auto n : monomial(dst.l)) { - unsigned id = n->root_id(); + unsigned id = n->id(); m_dst_l_counts.inc(id, 1); if (m_src_l_counts[id] < m_dst_l_counts[id]) m_src_r.push_back(n); @@ -947,12 +1037,12 @@ namespace euf { if (!overlap) { m_src_r.shrink(src_r_size); - return; + return false; } // compute CD for (auto n : monomial(src.l)) { - unsigned id = n->root_id(); + unsigned id = n->id(); if (m_dst_l_counts[id] > 0) m_dst_l_counts.dec(id, 1); else @@ -961,21 +1051,28 @@ namespace euf { if (are_equal(m_src_r, m_dst_r)) { m_src_r.shrink(src_r_size); - return; + return false; } - TRACE(plugin, tout << m_pp(*this, m_src_r) << "== " << m_pp(*this, m_dst_r) << "\n";); + TRACE(plugin, tout << "superpose result: " << m_pp_ll(*this, m_src_r) << "== " << m_pp_ll(*this, m_dst_r) << "\n";); justification j = justify_rewrite(src_eq, dst_eq); + deduplicate(m_src_r, m_dst_r); reduce(m_dst_r, j); reduce(m_src_r, j); - if (m_src_r.size() == 1 && m_dst_r.size() == 1) - push_merge(m_src_r[0]->n, m_dst_r[0]->n, j); - else + TRACE(plugin, tout << "superpose result: " << m_pp_ll(*this, m_src_r) << "== " << m_pp_ll(*this, m_dst_r) << "\n";); + + bool added_eq = false; + unsigned max_left_new = std::max(m_src_r.size(), m_dst_r.size()); + unsigned min_right_new = std::min(m_src_r.size(), m_dst_r.size()); + if (max_left_new <= max_left && min_right_new <= min_right) { init_equation(eq(to_monomial(m_src_r), to_monomial(m_dst_r), j)); + added_eq = true; + } m_src_r.reset(); m_src_r.append(monomial(src.r).m_nodes); + return added_eq; } bool ac_plugin::are_equal(monomial_t& a, monomial_t& b) { @@ -987,10 +1084,10 @@ namespace euf { return false; m_eq_counts.reset(); for (auto n : a) - m_eq_counts.inc(n->root_id(), 1); + m_eq_counts.inc(n->id(), 1); for (auto n : b) { - unsigned id = n->root_id(); + unsigned id = n->id(); if (m_eq_counts[id] == 0) return false; m_eq_counts.dec(id, 1); @@ -998,6 +1095,38 @@ namespace euf { return true; } + void ac_plugin::deduplicate(ptr_vector& a, ptr_vector& b) { + if (!m_is_injective) + return; + m_eq_counts.reset(); + for (auto n : a) + m_eq_counts.inc(n->id(), 1); + bool has_dup = any_of(b, [&](node* n) { return m_eq_counts[n->id()] > 0; }); + if (!has_dup) + return; + std::sort(a.begin(), a.end(), [&](node* x, node* y) { return x->id() < y->id(); }); + std::sort(b.begin(), b.end(), [&](node* x, node* y) { return x->id() < y->id(); }); + unsigned i = 0, j = 0, in = 0, jn = 0; + for (; i < a.size() && j < b.size(); ) { + if (a[i]->id() == b[j]->id()) { + ++i; + ++j; + } + else if (a[i]->id() < b[j]->id()) { + a[in++] = a[i++]; + } + else { + b[jn++] = b[j++]; + } + } + for (; i < a.size(); ++i) + a[in++] = a[i]; + for (; j < b.size(); ++j) + b[jn++] = b[j]; + a.shrink(in); + b.shrink(jn); + } + // // simple version based on propagating all shared // todo: version touching only newly processed shared, and maintaining incremental data-structures. diff --git a/src/ast/euf/euf_ac_plugin.h b/src/ast/euf/euf_ac_plugin.h index 290ddb561..1a699fce3 100644 --- a/src/ast/euf/euf_ac_plugin.h +++ b/src/ast/euf/euf_ac_plugin.h @@ -46,7 +46,7 @@ namespace euf { unsigned_vector shared; // shared occurrences unsigned_vector eqs; // equality occurrences - unsigned root_id() const { return root->n->get_id(); } + unsigned id() const { return root->n->get_id(); } static node* mk(region& r, enode* n); }; @@ -117,7 +117,7 @@ namespace euf { if (!p.is_sorted(m)) p.sort(m); for (auto* n : m) - h = combine_hash(h, n->root_id()); + h = combine_hash(h, n->id()); return h; } }; @@ -130,7 +130,7 @@ namespace euf { auto const& m2 = p.monomial(j); if (m1.size() != m2.size()) return false; for (unsigned k = 0; k < m1.size(); ++k) - if (m1[k]->root_id() != m2[k]->root_id()) + if (m1[k]->id() != m2[k]->id()) return false; return true; } @@ -139,6 +139,7 @@ namespace euf { theory_id m_fid = 0; decl_kind m_op = null_decl_kind; func_decl* m_decl = nullptr; + bool m_is_injective = false; vector m_eqs; ptr_vector m_nodes; bool_vector m_shared_nodes; @@ -208,7 +209,8 @@ namespace euf { void forward_simplify(unsigned eq_id, unsigned using_eq); bool backward_simplify(unsigned eq_id, unsigned using_eq); - void superpose(unsigned src_eq, unsigned dst_eq); + bool superpose(unsigned src_eq, unsigned dst_eq); + void deduplicate(ptr_vector& a, ptr_vector& b); ptr_vector m_src_r, m_src_l, m_dst_r, m_dst_l; @@ -236,6 +238,7 @@ namespace euf { void compress_eq_occurs(unsigned eq_id); // check that src is a subset of dst, where dst_counts are precomputed bool is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src); + bool is_equation_oriented(eq const& e) const; // check that dst is a superset of dst, where src_counts are precomputed bool is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst); @@ -262,6 +265,9 @@ namespace euf { std::ostream& display_monomial(std::ostream& out, monomial_t const& m) const { return display_monomial(out, m.m_nodes); } std::ostream& display_monomial(std::ostream& out, ptr_vector const& m) const; std::ostream& display_equation(std::ostream& out, eq const& e) const; + std::ostream& display_monomial_ll(std::ostream& out, monomial_t const& m) const { return display_monomial_ll(out, m.m_nodes); } + std::ostream& display_monomial_ll(std::ostream& out, ptr_vector const& m) const; + std::ostream& display_equation_ll(std::ostream& out, eq const& e) const; std::ostream& display_status(std::ostream& out, eq_status s) const; @@ -270,6 +276,8 @@ namespace euf { ac_plugin(egraph& g, unsigned fid, unsigned op); ac_plugin(egraph& g, func_decl* f); + + void set_injective() { m_is_injective = true; } theory_id get_id() const override { return m_fid; } @@ -288,20 +296,36 @@ namespace euf { void set_undo(std::function u) { m_undo_notify = u; } struct eq_pp { - ac_plugin& p; eq const& e; - eq_pp(ac_plugin& p, eq const& e) : p(p), e(e) {}; - eq_pp(ac_plugin& p, unsigned eq_id): p(p), e(p.m_eqs[eq_id]) {} + ac_plugin const& p; eq const& e; + eq_pp(ac_plugin const& p, eq const& e) : p(p), e(e) {}; + eq_pp(ac_plugin const& p, unsigned eq_id): p(p), e(p.m_eqs[eq_id]) {} std::ostream& display(std::ostream& out) const { return p.display_equation(out, e); } }; + struct eq_pp_ll { + ac_plugin const& p; eq const& e; + eq_pp_ll(ac_plugin const& p, eq const& e) : p(p), e(e) {}; + eq_pp_ll(ac_plugin const& p, unsigned eq_id) : p(p), e(p.m_eqs[eq_id]) {} + std::ostream& display(std::ostream& out) const { return p.display_equation_ll(out, e); } + }; + struct m_pp { - ac_plugin& p; ptr_vector const& m; - m_pp(ac_plugin& p, monomial_t const& m) : p(p), m(m.m_nodes) {} - m_pp(ac_plugin& p, ptr_vector const& m) : p(p), m(m) {} + ac_plugin const& p; ptr_vector const& m; + m_pp(ac_plugin const& p, monomial_t const& m) : p(p), m(m.m_nodes) {} + m_pp(ac_plugin const& p, ptr_vector const& m) : p(p), m(m) {} std::ostream& display(std::ostream& out) const { return p.display_monomial(out, m); } }; + + struct m_pp_ll { + ac_plugin const& p; ptr_vector const& m; + m_pp_ll(ac_plugin const& p, monomial_t const& m) : p(p), m(m.m_nodes) {} + m_pp_ll(ac_plugin const& p, ptr_vector const& m) : p(p), m(m) {} + std::ostream& display(std::ostream& out) const { return p.display_monomial_ll(out, m); } + }; }; inline std::ostream& operator<<(std::ostream& out, ac_plugin::eq_pp const& d) { return d.display(out); } + inline std::ostream& operator<<(std::ostream& out, ac_plugin::eq_pp_ll const& d) { return d.display(out); } inline std::ostream& operator<<(std::ostream& out, ac_plugin::m_pp const& d) { return d.display(out); } + inline std::ostream& operator<<(std::ostream& out, ac_plugin::m_pp_ll const& d) { return d.display(out); } } diff --git a/src/ast/euf/euf_arith_plugin.cpp b/src/ast/euf/euf_arith_plugin.cpp index 317b192c6..df7299c73 100644 --- a/src/ast/euf/euf_arith_plugin.cpp +++ b/src/ast/euf/euf_arith_plugin.cpp @@ -30,6 +30,7 @@ namespace euf { m_add.set_undo(uadd); std::function umul = [&]() { m_undo.push_back(undo_t::undo_mul); }; m_mul.set_undo(umul); + m_add.set_injective(); } void arith_plugin::register_node(enode* n) { diff --git a/src/ast/euf/euf_plugin.cpp b/src/ast/euf/euf_plugin.cpp index 198bc86d8..c6efe521b 100644 --- a/src/ast/euf/euf_plugin.cpp +++ b/src/ast/euf/euf_plugin.cpp @@ -26,11 +26,15 @@ namespace euf { } void plugin::push_merge(enode* a, enode* b, justification j) { + if (a->get_root() == b->get_root()) + return; // already merged TRACE(euf, tout << "push-merge " << g.bpp(a) << " == " << g.bpp(b) << " " << j << "\n"); g.push_merge(a, b, j); } void plugin::push_merge(enode* a, enode* b) { + if (a->get_root() == b->get_root()) + return; // already merged TRACE(plugin, tout << g.bpp(a) << " == " << g.bpp(b) << "\n"); g.push_merge(a, b, justification::axiom(get_id())); } diff --git a/src/ast/expr_abstract.h b/src/ast/expr_abstract.h index 75a10a9ef..7b7e4672d 100644 --- a/src/ast/expr_abstract.h +++ b/src/ast/expr_abstract.h @@ -35,6 +35,8 @@ void expr_abstract(ast_manager& m, unsigned base, unsigned num_bound, expr* cons inline expr_ref expr_abstract(ast_manager& m, unsigned base, unsigned num_bound, expr* const* bound, expr* n) { expr_ref r(m); expr_abstract(m, base, num_bound, bound, n, r); return r; } inline expr_ref expr_abstract(expr_ref_vector const& bound, expr* n) { return expr_abstract(bound.m(), 0, bound.size(), bound.data(), n); } inline expr_ref expr_abstract(app_ref_vector const& bound, expr* n) { return expr_abstract(bound.m(), 0, bound.size(), (expr*const*)bound.data(), n); } +inline expr_ref expr_abstract(ast_manager& m, ptr_vector const& bound, expr* n) { return expr_abstract(m, 0, bound.size(), bound.data(), n); } + expr_ref mk_forall(ast_manager& m, unsigned num_bound, app* const* bound, expr* n); expr_ref mk_exists(ast_manager& m, unsigned num_bound, app* const* bound, expr* n); inline expr_ref mk_forall(ast_manager& m, app* b, expr* n) { return mk_forall(m, 1, &b, n); } diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index 9ff6ecc13..f711fc19e 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -50,6 +50,7 @@ Mam optimization? #include "ast/ast_pp.h" #include "ast/ast_util.h" +#include "ast/expr_abstract.h" #include "ast/euf/euf_egraph.h" #include "ast/euf/euf_arith_plugin.h" #include "ast/euf/euf_bv_plugin.h" @@ -276,6 +277,9 @@ namespace euf { expr_ref y1(y, m); m_rewriter(x1); m_rewriter(y1); + + add_quantifiers(x1); + add_quantifiers(y1); enode* a = mk_enode(x1); enode* b = mk_enode(y1); if (a->get_root() == b->get_root()) @@ -283,25 +287,40 @@ namespace euf { m_egraph.merge(a, b, to_ptr(push_pr_dep(pr, d))); add_children(a); add_children(b); + auto a1 = mk_enode(x); + if (a1->get_root() != a->get_root()) { + m_egraph.merge(a, a1, nullptr); + add_children(a1); + } + auto b1 = mk_enode(y); + if (b1->get_root() != b->get_root()) { + m_egraph.merge(b, b1, nullptr); + add_children(b1); + } + m_should_propagate = true; if (m_side_condition_solver) m_side_condition_solver->add_constraint(f, pr, d); + IF_VERBOSE(1, verbose_stream() << "eq: " << mk_pp(x1, m) << " == " << mk_pp(y1, m) << "\n"); } else if (m.is_not(f, f)) { enode* n = mk_enode(f); if (m.is_false(n->get_root()->get_expr())) return; + add_quantifiers(f); auto j = to_ptr(push_pr_dep(pr, d)); m_egraph.new_diseq(n, j); add_children(n); m_should_propagate = true; if (m_side_condition_solver) m_side_condition_solver->add_constraint(f, pr, d); + IF_VERBOSE(1, verbose_stream() << "not: " << mk_pp(f, m) << "\n"); } else { enode* n = mk_enode(f); if (m.is_true(n->get_root()->get_expr())) - return; + return; + IF_VERBOSE(1, verbose_stream() << "fml: " << mk_pp(f, m) << "\n"); m_egraph.merge(n, m_tt, to_ptr(push_pr_dep(pr, d))); add_children(n); if (is_forall(f)) { @@ -314,7 +333,7 @@ namespace euf { q = to_quantifier(tmp); } #endif - + for (unsigned i = 0; i < q->get_num_patterns(); ++i) { auto p = to_app(q->get_pattern(i)); auto [q1, p1] = m_matcher.compile_ho_pattern(q, p); @@ -326,14 +345,67 @@ namespace euf { mk_enode(g); m_mam->add_pattern(q, p); if (p != p1) - m_mam->add_pattern(q1, p1); + m_mam->add_pattern(q1, p1); } - m_q2dep.insert(q, { pr, d}); - get_trail().push(insert_obj_map(m_q2dep, q)); + m_q2dep.insert(q, { pr, d }); + get_trail().push(insert_obj_map(m_q2dep, q)); } + add_rule(f, pr, d); - if (!is_forall(f) && !m.is_implies(f) && m_side_condition_solver) - m_side_condition_solver->add_constraint(f, pr, d); + if (!is_forall(f) && !m.is_implies(f)) { + add_quantifiers(f); + if (m_side_condition_solver) + m_side_condition_solver->add_constraint(f, pr, d); + } + } + } + + void completion::add_quantifiers(expr* f) { + if (!has_quantifiers(f)) + return; + ptr_vector bound; + add_quantifiers(bound, f); + } + + void completion::add_quantifiers(ptr_vector& bound, expr* f) { + if (!has_quantifiers(f)) + return; + + ptr_vector todo; + todo.push_back(f); + expr_fast_mark1 visited; + for (unsigned j = 0; j < todo.size(); ++j) { + expr* t = todo[j]; + if (visited.is_marked(t)) + continue; + visited.mark(t); + if (!has_quantifiers(t)) + continue; + if (is_app(t)) { + for (auto arg : *to_app(t)) + todo.push_back(arg); + } + else if (is_quantifier(t)) { + auto q = to_quantifier(t); + auto nd = q->get_num_decls(); + verbose_stream() << "bind " << mk_pp(q, m) << "\n"; + for (unsigned i = 0; i < nd; ++i) { + auto name = std::string("bound!") + std::to_string(bound.size()); + auto b = m.mk_const(name, q->get_decl_sort(i)); + // TODO: persist bound variables withn scope to avoid reference count crashes + bound.push_back(b); + } + expr_ref inst = var_subst(m)(q->get_expr(), bound); + if (!m_egraph.find(inst)) { + m_closures.insert(q, { bound, inst }); + get_trail().push(insert_map(m_closures, q)); + mk_enode(inst); + // TODO: handle nested quantifiers after m_closures is updated to + // index on sort declaration prefix together with quantifier + // add_quantifiers(bound, inst); + } + bound.shrink(bound.size() - nd); + } } } @@ -760,6 +832,9 @@ namespace euf { } expr_ref completion::canonize(expr* f, proof_ref& pr, expr_dependency_ref& d) { + if (is_quantifier(f)) + return expr_ref(canonize(to_quantifier(f), pr, d), m); + if (!is_app(f)) return expr_ref(f, m); // todo could normalize ground expressions under quantifiers @@ -787,8 +862,29 @@ namespace euf { return r; } + expr_ref completion::canonize(quantifier* q, proof_ref& pr, expr_dependency_ref& d) { + std::pair, expr*> clos; + if (!m_closures.find(q, clos)) + return expr_ref(q, m); + expr* body = clos.second; + expr_ref new_body = canonize(body, pr, d); + expr_ref result = expr_abstract(m, clos.first, new_body); + if (m.proofs_enabled()) { + // add proof rule + // + // body = new_body + // --------------------------- + // Q x . body = Q x . new_body + NOT_IMPLEMENTED_YET(); + } + return result; + } + + expr* completion::get_canonical(expr* f, proof_ref& pr, expr_dependency_ref& d) { enode* n = m_egraph.find(f); + + if (!n) verbose_stream() << "not found " << f->get_id() << " " << mk_pp(f, m) << "\n"; enode* r = n->get_root(); d = m.mk_join(d, explain_eq(n, r)); d = m.mk_join(d, m_deps.get(r->get_id(), nullptr)); diff --git a/src/ast/simplifiers/euf_completion.h b/src/ast/simplifiers/euf_completion.h index cda12aebf..6893e5327 100644 --- a/src/ast/simplifiers/euf_completion.h +++ b/src/ast/simplifiers/euf_completion.h @@ -171,6 +171,19 @@ namespace euf { proof* get_canonical_proof(enode* n); void set_canonical(enode* n, expr* e, proof* pr); void add_constraint(expr*f, proof* pr, expr_dependency* d); + + // Enable equality propagation inside of quantifiers + // add quantifier bodies as closure terms to the E-graph. + // use fresh variables for bound variables, but such that the fresh variables are + // the same when the quantifier prefix is the same. + // Thus, we are going to miss equalities of quantifier bodies + // if the prefixes are different but the bodies are the same. + // Closure terms are re-abstracted by the canonizer. + void add_quantifiers(ptr_vector& bound, expr* t); + void add_quantifiers(expr* t); + expr_ref canonize(quantifier* q, proof_ref& pr, expr_dependency_ref& d); + obj_map, expr*>> m_closures; + expr_dependency* explain_eq(enode* a, enode* b); proof_ref prove_eq(enode* a, enode* b); proof_ref prove_conflict();