From dbcbc6c3ac4d750620cfd9dcf953e930e94be2fc Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 21 Jul 2025 07:35:06 -0700 Subject: [PATCH] revamp ac plugin and plugin propagation --- src/ast/euf/euf_ac_plugin.cpp | 483 ++++++++++++++++++------- src/ast/euf/euf_ac_plugin.h | 77 ++-- src/ast/euf/euf_arith_plugin.cpp | 19 + src/ast/euf/euf_arith_plugin.h | 5 + src/ast/euf/euf_egraph.cpp | 42 ++- src/ast/euf/euf_egraph.h | 12 +- src/ast/euf/euf_plugin.cpp | 2 +- src/ast/euf/euf_plugin.h | 3 + src/ast/simplifiers/euf_completion.cpp | 185 ++++++++-- src/ast/simplifiers/euf_completion.h | 10 +- src/sat/smt/arith_axioms.cpp | 2 +- src/sat/smt/arith_solver.cpp | 2 +- src/sat/smt/bv_solver.cpp | 2 +- src/util/trace_tags.def | 1 + 14 files changed, 630 insertions(+), 215 deletions(-) diff --git a/src/ast/euf/euf_ac_plugin.cpp b/src/ast/euf/euf_ac_plugin.cpp index 22cf8abbd..b62180df4 100644 --- a/src/ast/euf/euf_ac_plugin.cpp +++ b/src/ast/euf/euf_ac_plugin.cpp @@ -20,11 +20,38 @@ Completion modulo AC Add new equation zu = xyu = vy by j1, j2 - Notes: - - Some equalities come from shared terms, some do not. + Sets P - processed, R - reductions, S - to simplify + + new equality l = r: + reduce l = r modulo R if equation is external + orient l = r - if it cannot be oriented, discard + if l = r is a reduction rule then reduce R, S using l = r, insert into R + else insert into S + + main loop: + for e as (l = r) in S: + remove e from S + backward simplify e + if e is backward subsumed, continue + if e is a reduction rule, then reduce R, S using e, insert into R, continue + insert e into P + superpose with e + forward simplify with e + +backward simplify e as (l = r) using (l' = r') in P u S: + if l' is a subset or r then replace l' by r' in r. + +backward subsumption e as (l = r) using (l' = r') in P u S: + l = r is of the form l'x = r'x + +is reduction rule e as (l = r): + l is a unit, and r is unit, is empty, or is zero. - - V2 can use multiplicities of elements to handle larger domains. - - e.g. 3x + 100000y +superpose e as (l = r) with (l' = r') in P: + if l and l' share a common subset x. + +forward simplify (l' = r') in P u S using e as (l = r): + More notes: @@ -107,6 +134,19 @@ namespace euf { register_shared(arg); } + // unit -> {} + void ac_plugin::add_unit(enode* u) { + m_units.push_back(u); + auto n = mk_node(u); + auto m_id = to_monomial(u, {}); + init_equation(eq(to_monomial(u), m_id, justification::axiom(get_id()))); + } + + // zero x -> zero + void ac_plugin::add_zero(enode* z) { + mk_node(z)->is_zero = true; + } + void ac_plugin::register_shared(enode* n) { if (m_shared_nodes.get(n->get_id(), false)) return; @@ -144,16 +184,6 @@ namespace euf { m_monomials.pop_back(); break; } - case is_merge_node: { - auto [other, old_shared, old_eqs] = m_merge_trail.back(); - auto* root = other->root; - std::swap(other->next, root->next); - root->shared.shrink(old_shared); - root->eqs.shrink(old_eqs); - m_merge_trail.pop_back(); - ++m_tick; - break; - } case is_update_eq: { auto const& [idx, eq] = m_update_eq_trail.back(); m_eqs[idx] = eq; @@ -226,6 +256,7 @@ namespace euf { case eq_status::is_dead: out << "d"; break; case eq_status::processed: out << "p"; break; case eq_status::to_simplify: out << "s"; break; + case eq_status::is_reducing_eq: out << "r"; break; } return out; } @@ -234,15 +265,16 @@ namespace euf { out << m_name << "\n"; unsigned i = 0; for (auto const& eq : m_eqs) { - if (eq.status != eq_status::is_dead) - out << i << ": " << eq_pp_ll(*this, eq) << "\n"; + if (eq.status != eq_status::is_dead) { + out << "["; display_status(out, eq.status) << "] " << i << " : " << eq_pp_ll(*this, eq) << "\n"; + } ++i; } if (!m_shared.empty()) out << "shared monomials:\n"; for (auto const& s : m_shared) { - out << g.bpp(s.n) << ": " << s.m << "\n"; + out << g.bpp(s.n) << ": " << s.m << " r: " << g.bpp(s.n->get_root()) << "\n"; } #if 0 i = 0; @@ -274,13 +306,21 @@ namespace euf { return out; } + void ac_plugin::collect_statistics(statistics& st) const { + + std::string name = m_name.str(); + m_superposition_stats = symbol((std::string("ac ") + name + " superpositions")); + m_eqs_stats = symbol((std::string("ac ") + name + " equations")); + st.update(m_superposition_stats.bare_str(), m_stats.m_num_superpositions); + st.update(m_eqs_stats.bare_str(), m_eqs.size()); + } + void ac_plugin::merge_eh(enode* l, enode* r) { if (l == r) return; + m_fuel += m_fuel_inc; auto j = justification::equality(l, r); TRACE(plugin, tout << "merge: " << m_name << " " << g.bpp(l) << " == " << g.bpp(r) << " " << is_op(l) << " " << is_op(r) << "\n"); - if (!is_op(l) && !is_op(r)) - merge(mk_node(l), mk_node(r), j); init_equation(eq(to_monomial(l), to_monomial(r), j)); } @@ -294,7 +334,7 @@ namespace euf { register_shared(b); } - void ac_plugin::init_equation(eq const& e) { + bool ac_plugin::init_equation(eq const& e) { m_eqs.push_back(e); auto& eq = m_eqs.back(); deduplicate(monomial(eq.l).m_nodes, monomial(eq.r).m_nodes); @@ -303,44 +343,53 @@ namespace euf { auto& ml = monomial(eq.l); auto& mr = monomial(eq.r); - if (ml.size() == 1 && mr.size() == 1) - push_merge(ml[0]->n, mr[0]->n, eq.j); - unsigned eq_id = m_eqs.size() - 1; + if (ml.size() == 1 && mr.size() == 1) + push_merge(ml[0]->n, mr[0]->n, eq.j); + for (auto n : ml) { - if (!n->root->n->is_marked1()) { - n->root->eqs.push_back(eq_id); - n->root->n->mark1(); + if (!n->n->is_marked2()) { + n->eqs.push_back(eq_id); + n->n->mark2(); push_undo(is_add_eq_index); - m_node_trail.push_back(n->root); - for (auto s : n->root->shared) + m_node_trail.push_back(n); + for (auto s : n->shared) m_shared_todo.insert(s); } } for (auto n : mr) { - if (!n->root->n->is_marked1()) { - n->root->eqs.push_back(eq_id); - n->root->n->mark1(); + if (!n->n->is_marked2()) { + n->eqs.push_back(eq_id); + n->n->mark2(); push_undo(is_add_eq_index); - m_node_trail.push_back(n->root); - for (auto s : n->root->shared) + m_node_trail.push_back(n); + for (auto s : n->shared) m_shared_todo.insert(s); } } for (auto n : ml) - n->root->n->unmark1(); + n->n->unmark2(); for (auto n : mr) - n->root->n->unmark1(); + n->n->unmark2(); - TRACE(plugin, display_equation_ll(tout, e) << " shared: " << m_shared_todo << "\n"); + SASSERT(well_formed(eq)); + + TRACE(plugin, display_equation_ll(tout, eq) << " shared: " << m_shared_todo << "\n"); m_to_simplify_todo.insert(eq_id); + m_new_eqs.push_back(eq_id); + + //display_equation_ll(verbose_stream() << "init " << eq_id << ": ", eq) << "\n"; + + return true; } - else + else { m_eqs.pop_back(); + return false; + } } bool ac_plugin::orient_equation(eq& e) { @@ -361,7 +410,7 @@ namespace euf { if (ml[i]->id() < mr[i]->id()) std::swap(e.l, e.r); return true; - } + } return false; } } @@ -436,25 +485,6 @@ namespace euf { return (filter(subset) | f2) == f2; } - void ac_plugin::merge(node* a, node* b, justification j) { - TRACE(plugin, tout << g.bpp(a->n) << " == " << g.bpp(b->n) << " num shared " << b->shared.size() << "\n"); - if (a == b) - return; - if (a->id() < b->id()) - std::swap(a, b); - for (auto n : equiv(a)) - n->root = b; - m_merge_trail.push_back({ a, b->shared.size(), b->eqs.size() }); - for (auto eq_id : a->eqs) - set_status(eq_id, eq_status::to_simplify); - for (auto m : a->shared) - m_shared_todo.insert(m); - b->shared.append(a->shared); - b->eqs.append(a->eqs); - std::swap(b->next, a->next); - push_undo(is_merge_node); - ++m_tick; - } void ac_plugin::push_undo(undo_kind k) { m_undo.push_back(k); @@ -489,10 +519,21 @@ namespace euf { ptr_buffer args; enode_vector nodes; for (auto arg : mon) { - nodes.push_back(arg->root->n); - args.push_back(arg->root->n->get_expr()); + nodes.push_back(arg->n); + args.push_back(arg->n->get_expr()); + } + expr* n = nullptr; + switch (args.size()) { + case 0: + UNREACHABLE(); + break; + case 1: + n = args[0]; + break; + default: + n = m.mk_app(m_fid, m_op, args.size(), args.data()); + break; } - auto n = args.size() == 1 ? args[0] : m.mk_app(m_fid, m_op, args.size(), args.data()); auto r = g.find(n); return r ? r : g.mk(n, 0, nodes.size(), nodes.data()); } @@ -501,8 +542,6 @@ namespace euf { auto* mem = r.allocate(sizeof(node)); node* res = new (mem) node(); res->n = n; - res->root = res; - res->next = res; return res; } @@ -521,13 +560,22 @@ namespace euf { } void ac_plugin::propagate() { + //verbose_stream() << "propagate " << m_name << "\n"; + unsigned ts = m_to_simplify_todo.size(); + unsigned round = 0; while (true) { - loop_start: + loop_start: + //verbose_stream() << "loop_start " << (round++) << " " << m_to_simplify_todo.size() << " ts: " << ts << "\n"; + if (m_fuel == 0) + break; unsigned eq_id = pick_next_eq(); if (eq_id == UINT_MAX) break; TRACE(plugin, tout << "propagate " << eq_id << ": " << eq_pp_ll(*this, m_eqs[eq_id]) << "\n"); + //verbose_stream() << m_name << " propagate eq " << eq_id << ": " << eq_pp_ll(*this, m_eqs[eq_id]) << "\n"; + + SASSERT(well_formed(m_eqs[eq_id])); // simplify eq using processed TRACE(plugin, @@ -543,26 +591,41 @@ namespace euf { set_status(eq_id, eq_status::is_dead); continue; } + if (is_backward_subsumed(eq_id)) { + set_status(eq_id, eq_status::is_dead); + continue; + } + + if (is_reducing(eq)) { + set_status(eq_id, eq_status::is_reducing_eq); + forward_reduce(eq_id); + continue; + } + --m_fuel; set_status(eq_id, eq_status::processed); // simplify processed using eq for (auto other_eq : forward_iterator(eq_id)) - if (is_processed(other_eq)) + if (is_processed(other_eq) || is_reducing(other_eq)) forward_simplify(eq_id, other_eq); + backward_subsume_new_eqs(); // superpose, create new equations - unsigned new_eqs = 0; + unsigned new_sup = 0; + m_new_eqs.reset(); for (auto other_eq : superpose_iterator(eq_id)) if (is_processed(other_eq)) - new_eqs += superpose(eq_id, other_eq); + new_sup += superpose(eq_id, other_eq); + backward_subsume_new_eqs(); - (void)new_eqs; - TRACE(plugin, tout << "added eqs " << new_eqs << "\n"); + m_stats.m_num_superpositions += new_sup; + TRACE(plugin, tout << "new superpositions " << new_sup << "\n"); // simplify to_simplify using eq for (auto other_eq : forward_iterator(eq_id)) if (is_to_simplify(other_eq)) forward_simplify(eq_id, other_eq); + backward_subsume_new_eqs(); } propagate_shared(); @@ -584,7 +647,7 @@ namespace euf { auto& eq = m_eqs[id]; if (eq.status == eq_status::is_dead) return; - if (s == eq_status::to_simplify && are_equal(monomial(eq.l), monomial(eq.r))) + if (are_equal(monomial(eq.l), monomial(eq.r))) s = eq_status::is_dead; if (eq.status != s) { @@ -594,12 +657,15 @@ namespace euf { } switch (s) { case eq_status::processed: + case eq_status::is_reducing_eq: case eq_status::is_dead: m_to_simplify_todo.remove(id); break; case eq_status::to_simplify: m_to_simplify_todo.insert(id); - orient_equation(eq); + if (!orient_equation(eq)) { + set_status(id, eq_status::is_dead); + } break; } } @@ -624,6 +690,12 @@ namespace euf { auto const& eq = m_eqs[eq_id]; init_ref_counts(monomial(eq.r), m_dst_r_counts); init_ref_counts(monomial(eq.l), m_dst_l_counts); + if (monomial(eq.r).size() == 0) { + m_dst_l.reset(); + m_dst_l.append(monomial(eq.l).m_nodes); + init_subset_iterator(eq_id, monomial(eq.l)); + return m_eq_occurs; + } m_dst_r.reset(); m_dst_r.append(monomial(eq.r).m_nodes); init_subset_iterator(eq_id, monomial(eq.r)); @@ -633,7 +705,7 @@ namespace euf { void ac_plugin::init_overlap_iterator(unsigned eq_id, monomial_t const& m) { m_eq_occurs.reset(); for (auto n : m) - m_eq_occurs.append(n->root->eqs); + m_eq_occurs.append(n->eqs); compress_eq_occurs(eq_id); } @@ -649,17 +721,17 @@ namespace euf { node* max_n = nullptr; bool has_two = false; for (auto n : m) - if (n->root->eqs.size() >= max_use) - has_two |= max_n && (max_n != n->root), max_n = n->root, max_use = n->root->eqs.size(); + if (n->eqs.size() >= max_use) + has_two |= max_n && (max_n != n), max_n = n, max_use = n->eqs.size(); m_eq_occurs.reset(); if (has_two) { for (auto n : m) - if (n->root != max_n) - m_eq_occurs.append(n->root->eqs); + if (n != max_n) + m_eq_occurs.append(n->eqs); } else { for (auto n : m) { - m_eq_occurs.append(n->root->eqs); + m_eq_occurs.append(n->eqs); break; } } @@ -676,6 +748,8 @@ namespace euf { continue; if (id == eq_id) continue; + if (!is_alive(id)) + continue; m_eq_occurs[j++] = id; m_eq_seen[id] = true; } @@ -696,8 +770,8 @@ namespace euf { unsigned min_r = UINT_MAX; node* min_n = nullptr; for (auto n : monomial(eq.l)) - if (n->root->eqs.size() < min_r) - min_n = n, min_r = n->root->eqs.size(); + if (n->eqs.size() < min_r) + min_n = n, min_r = n->eqs.size(); // found node that occurs in fewest eqs VERIFY(min_n); return min_n->eqs; @@ -722,7 +796,7 @@ namespace euf { init_ref_counts(m, check); return all_of(counts, [&](unsigned i) { return check[i] == counts[i]; }) && - all_of(check, [&](unsigned i) { return check[i] == counts[i]; }); + all_of(check, [&](unsigned i) { return check[i] == counts[i]; }); } void ac_plugin::forward_simplify(unsigned src_eq, unsigned dst_eq) { @@ -737,10 +811,10 @@ namespace euf { auto& src = m_eqs[src_eq]; // src_r_counts, src_l_counts are initialized auto& dst = m_eqs[dst_eq]; - TRACE(plugin, tout << "forward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << "\n"); + TRACE(plugin_verbose, tout << "forward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << "\n"); if (forward_subsumes(src_eq, dst_eq)) { - TRACE(plugin, tout << "forward subsumed\n"); + TRACE(plugin_verbose, tout << "forward subsumed\n"); set_status(dst_eq, eq_status::is_dead); return; } @@ -761,18 +835,14 @@ namespace euf { unsigned num_overlap = 0; for (auto n : monomial(dst.r)) { unsigned id = n->id(); - unsigned dst_count = m_dst_r_counts[id]; unsigned src_count = m_src_l_counts[id]; - if (dst_count > src_count) { - m_src_r.push_back(n); - m_dst_r_counts.dec(id, 1); - } - else if (dst_count < src_count) { - m_src_r.shrink(src_r_size); - return; + unsigned dst_count = m_dst_r_counts[id]; + if (dst_count < src_count) { + m_dst_r_counts.inc(id, 1); + ++num_overlap; } else - ++num_overlap; + m_src_r.push_back(n); } // The dst.r has to be a superset of src.l, otherwise simplification does not apply if (num_overlap != src_l_size) { @@ -789,7 +859,8 @@ namespace euf { push_undo(is_update_eq); m_src_r.reset(); m_src_r.append(monomial(src.r).m_nodes); - TRACE(plugin, tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n"); + TRACE(plugin_verbose, tout << "rewritten to " << m_pp_ll(*this, monomial(new_r)) << "\n"); + m_new_eqs.push_back(dst_eq); } bool ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) { @@ -804,22 +875,26 @@ namespace euf { TRACE(plugin, tout << "backward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n"); if (backward_subsumes(src_eq, dst_eq)) { - TRACE(plugin, tout << "backward subsumed\n"); set_status(dst_eq, eq_status::is_dead); return true; } + if (!is_equation_oriented(src)) + return false; // check that src.l is a subset of dst.r if (!can_be_subset(monomial(src.l), monomial(dst.r))) return false; - if (!is_subset(m_dst_r_counts, m_src_l_counts, monomial(src.l))) { - TRACE(plugin, tout << "not subset\n"); - return false; - } + if (!is_subset(m_dst_r_counts, m_src_l_counts, monomial(src.l))) + return false; + if (monomial(dst.r).size() == 0) + return false; + SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); ptr_vector m(m_dst_r); init_ref_counts(monomial(src.l), m_src_l_counts); + + //verbose_stream() << "backward simplify " << eq_pp_ll(*this, src_eq) << " for " << eq_pp_ll(*this, dst_eq) << "\n"; rewrite1(m_src_l_counts, monomial(src.r), m_dst_r_counts, m); auto j = justify_rewrite(src_eq, dst_eq); @@ -831,30 +906,60 @@ namespace euf { m_eqs[dst_eq].j = j; TRACE(plugin, tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n"); push_undo(is_update_eq); + return true; } + void ac_plugin::backward_subsume_new_eqs() { + for (auto f_id : m_new_eqs) + if (is_backward_subsumed(f_id)) + set_status(f_id, eq_status::is_dead); + m_new_eqs.reset(); + } + + bool ac_plugin::is_backward_subsumed(unsigned eq_id) { + return any_of(backward_iterator(eq_id), [&](unsigned other_eq) { return backward_subsumes(other_eq, eq_id); }); + } + // dst_eq is fixed, dst_l_count is pre-computed for monomial(dst.l) // dst_r_counts is pre-computed for monomial(dst.r). // is dst_eq subsumed by src_eq? bool ac_plugin::backward_subsumes(unsigned src_eq, unsigned dst_eq) { auto& src = m_eqs[src_eq]; auto& dst = m_eqs[dst_eq]; + TRACE(plugin_verbose, tout << "backward subsumes " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << "\n"); SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts)); SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); - if (!can_be_subset(monomial(src.l), monomial(dst.l))) + if (!can_be_subset(monomial(src.l), monomial(dst.l))) { + TRACE(plugin_verbose, tout << "not subset of dst.l\n"); + SASSERT(!are_equal(m_eqs[src_eq], m_eqs[dst_eq])); return false; - if (!can_be_subset(monomial(src.r), monomial(dst.r))) + } + if (!can_be_subset(monomial(src.r), monomial(dst.r))) { + TRACE(plugin_verbose, tout << "not subset of dst.r\n"); + SASSERT(!are_equal(m_eqs[src_eq], m_eqs[dst_eq])); return false; + } unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size(); - if (size_diff != monomial(dst.r).size() - monomial(src.r).size()) + if (size_diff != monomial(dst.r).size() - monomial(src.r).size()) { + TRACE(plugin_verbose, tout << "size diff does not match: " << size_diff << "\n"); + SASSERT(!are_equal(m_eqs[src_eq], m_eqs[dst_eq])); return false; - if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) + } + if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) { + TRACE(plugin_verbose, tout << "not subset of dst.l counts\n"); + SASSERT(!are_equal(m_eqs[src_eq], m_eqs[dst_eq])); return false; - if (!is_subset(m_dst_r_counts, m_src_r_counts, monomial(src.r))) - return false; + } + if (!is_subset(m_dst_r_counts, m_src_r_counts, monomial(src.r))) { + TRACE(plugin_verbose, tout << "not subset of dst.r counts\n"); + SASSERT(!are_equal(m_eqs[src_eq], m_eqs[dst_eq])); + return false; + } SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts)); SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts)); + SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts)); + SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); // add difference betwen dst.l and src.l to both src.l, src.r for (auto n : monomial(dst.l)) { unsigned id = n->id(); @@ -867,6 +972,13 @@ namespace euf { } // now dst.r and src.r should align and have the same elements. // since src.r is a subset of dst.r we iterate over dst.r + if (!all_of(monomial(src.r), [&](node* n) { + unsigned id = n->id(); + return m_src_r_counts[id] == m_dst_r_counts[id]; })) { + TRACE(plugin_verbose, tout << "dst.r and src.r do not align\n"); + SASSERT(!are_equal(m_eqs[src_eq], m_eqs[dst_eq])); + return false; + } return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); } @@ -937,27 +1049,27 @@ namespace euf { bool change = false; unsigned sz = m.size(); unsigned jj = 0; - //verbose_stream() << "start\n"; do { init_loop: - //verbose_stream() << "loop " << jj++ << "\n"; if (m.size() == 1) return change; bloom b; init_ref_counts(m, m_m_counts); unsigned k = 0; for (auto n : m) { - //verbose_stream() << "inner loop " << k++ << "\n"; - for (auto eq : n->root->eqs) { + if (n->is_zero) { + m[0] = n; + m.shrink(1); + break; + } + for (auto eq : n->eqs) { if (!is_processed(eq)) continue; auto& src = m_eqs[eq]; if (!is_equation_oriented(src)) { + //verbose_stream() << "equation is not oriented: " << m_eq_ll(*this, src) << "\n"; continue; - if (!orient_equation(src)) - continue; - // deduplicate(src.l, src.r); } if (!can_be_subset(monomial(src.l), m, b)) continue; @@ -997,18 +1109,19 @@ namespace euf { void ac_plugin::index_new_r(unsigned eq, monomial_t const& old_r, monomial_t const& new_r) { for (auto n : old_r) - n->root->n->mark1(); - for (auto n : new_r) - if (!n->root->n->is_marked1()) { - n->root->eqs.push_back(eq); - m_node_trail.push_back(n->root); - n->root->n->mark1(); + n->n->mark2(); + for (auto n : new_r) { + if (!n->n->is_marked2()) { + n->eqs.push_back(eq); + m_node_trail.push_back(n); + n->n->mark2(); push_undo(is_add_eq_index); } + } for (auto n : old_r) - n->root->n->unmark1(); + n->n->unmark2(); for (auto n : new_r) - n->root->n->unmark1(); + n->n->unmark2(); } @@ -1018,10 +1131,9 @@ namespace euf { auto& src = m_eqs[src_eq]; auto& dst = m_eqs[dst_eq]; - unsigned max_left = std::max(monomial(src.l).size(), monomial(dst.l).size()); + unsigned max_left = std::max(monomial(src.l).size(), monomial(dst.l).size()); unsigned min_right = std::max(monomial(src.r).size(), monomial(dst.r).size()); - TRACE(plugin, tout << "superpose: "; display_equation_ll(tout, src); tout << " "; display_equation_ll(tout, dst); tout << "\n";); // AB -> C, AD -> E => BE ~ CD // m_src_ids, m_src_counts contains information about src (call it AD -> E) @@ -1066,24 +1178,81 @@ namespace euf { TRACE(plugin, tout << "superpose result: " << m_pp_ll(*this, m_src_r) << "== " << m_pp_ll(*this, m_dst_r) << "\n";); justification j = justify_rewrite(src_eq, dst_eq); - deduplicate(m_src_r, m_dst_r); reduce(m_dst_r, j); reduce(m_src_r, j); + deduplicate(m_src_r, m_dst_r); TRACE(plugin, tout << "superpose result: " << m_pp_ll(*this, m_src_r) << "== " << m_pp_ll(*this, m_dst_r) << "\n";); bool added_eq = false; + auto src_r = src.r; unsigned max_left_new = std::max(m_src_r.size(), m_dst_r.size()); unsigned min_right_new = std::min(m_src_r.size(), m_dst_r.size()); - if (max_left_new <= max_left && min_right_new <= min_right) { - init_equation(eq(to_monomial(m_src_r), to_monomial(m_dst_r), j)); - added_eq = true; - } - + if (max_left_new <= max_left && min_right_new <= min_right) + added_eq = init_equation(eq(to_monomial(m_src_r), to_monomial(m_dst_r), j)); + m_src_r.reset(); - m_src_r.append(monomial(src.r).m_nodes); + m_src_r.append(monomial(src_r).m_nodes); return added_eq; } + bool ac_plugin::is_reducing(eq const& e) const { + auto const& l = monomial(e.l); + auto const& r = monomial(e.r); + return l.size() == 1 && r.size() <= 1; + } + + void ac_plugin::forward_reduce(unsigned eq_id) { + auto const& eq = m_eqs[eq_id]; + if (!is_reducing(eq)) + return; + for (auto other_eq : superpose_iterator(eq_id)) { + SASSERT(is_alive(other_eq)); + forward_reduce(eq, other_eq); + } + } + + void ac_plugin::forward_reduce(eq const& eq, unsigned other_eq_id) { + auto& other_eq = m_eqs[other_eq_id]; + + bool change = false; + + if (forward_reduce_monomial(eq, monomial(other_eq.l))) + change = true; + if (forward_reduce_monomial(eq, monomial(other_eq.r))) + change = true; + + if (change) + set_status(other_eq_id, eq_status::to_simplify); + } + + bool ac_plugin::forward_reduce_monomial(eq const& eq, monomial_t& m) { + auto const& r = monomial(eq.r); + unsigned j = 0; + bool change = false; + for (auto n : m) { + unsigned id = n->id(); + SASSERT(m_src_l_counts[id] <= 1); + if (m_src_l_counts[id] == 0) { + m.m_nodes[j++] = n; + continue; + } + change = true; + if (r.size() == 0) + // if r is empty, we can remove n from l + continue; + SASSERT(r.size() == 1); + if (r[0]->is_zero) { + m.m_nodes[0] = r[0]; + j = 1; + break; + } + m.m_nodes[j++] = r[0]; + } + m.m_nodes.shrink(j); + return change; + } + + bool ac_plugin::are_equal(monomial_t& a, monomial_t& b) { return filter(a) == filter(b) && are_equal(a.m_nodes, b.m_nodes); } @@ -1104,7 +1273,42 @@ namespace euf { return true; } + bool ac_plugin::well_formed(eq const& e) const { + if (e.l == e.r) + return false; // trivial equation + for (auto n : monomial(e.l)) { + if (n->is_zero && monomial(e.l).size() > 1) + return false; // zero is not allowed in equations + } + for (auto n : monomial(e.r)) { + if (n->is_zero && monomial(e.r).size() > 1) + return false; // zero is not allowed in equations + } + return true; + } + void ac_plugin::deduplicate(ptr_vector& a, ptr_vector& b) { + { + unsigned j = 0; + for (auto n : a) { + if (n->is_zero) { + //verbose_stream() << "deduplicate: removing zero from a: " << m_pp(*this, a) << "\n"; + a[0] = n; + a.shrink(1); + break; + } + } + j = 0; + for (auto n : b) { + if (n->is_zero) { + // verbose_stream() << "deduplicate: removing zero from b: " << m_pp(*this, b) << "\n"; + b[0] = n; + b.shrink(1); + break; + } + } + } + if (!m_is_injective) return; m_eq_counts.reset(); @@ -1144,7 +1348,7 @@ namespace euf { // void ac_plugin::propagate_shared() { - TRACE(plugin, tout << "num shared todo " << m_shared_todo.size() << "\n"); + TRACE(plugin_verbose, tout << "num shared todo " << m_shared_todo.size() << "\n"); if (m_shared_todo.empty()) return; while (!m_shared_todo.empty()) { @@ -1156,11 +1360,11 @@ namespace euf { m_monomial_table.reset(); for (auto const& s1 : m_shared) { shared s2; - TRACE(plugin, tout << "shared " << s1.m << ": " << m_pp_ll(*this, monomial(s1.m)) << "\n"); + TRACE(plugin_verbose, tout << "shared " << s1.m << ": " << m_pp_ll(*this, monomial(s1.m)) << "\n"); if (!m_monomial_table.find(s1.m, s2)) m_monomial_table.insert(s1.m, s1); else if (s2.n->get_root() != s1.n->get_root()) { - TRACE(plugin, tout << m_pp(*this, monomial(s1.m)) << " == " << m_pp(*this, monomial(s2.m)) << "\n"); + TRACE(plugin, tout << "merge shared " << g.bpp(s1.n->get_root()) << " and " << g.bpp(s2.n->get_root()) << "\n"); push_merge(s1.n, s2.n, justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(s1.j), m_dep_manager.mk_leaf(s2.j)))); } } @@ -1171,31 +1375,32 @@ namespace euf { auto old_m = s.m; auto old_n = monomial(old_m).m_src; ptr_vector m1(monomial(old_m).m_nodes); - TRACE(plugin, tout << "simplify shared: " << g.bpp(old_n) << ": " << m_pp_ll(*this, monomial(old_m)) << "\n"); + TRACE(plugin_verbose, tout << "simplify shared: " << g.bpp(old_n) << ": " << m_pp_ll(*this, monomial(old_m)) << "\n"); if (!reduce(m1, j)) return; - auto new_n = from_monomial(m1); + enode* new_n = nullptr; + new_n = m1.empty() ? get_unit(s.n) : from_monomial(m1); auto new_m = to_monomial(new_n, m1); // update shared occurrences for members of the new monomial that are not already in the old monomial. for (auto n : monomial(old_m)) - n->root->n->mark1(); + n->n->mark2(); for (auto n : m1) { - if (!n->root->n->is_marked1()) { - n->root->shared.push_back(idx); + if (!n->n->is_marked2()) { + n->shared.push_back(idx); m_shared_todo.insert(idx); - m_node_trail.push_back(n->root); + m_node_trail.push_back(n); push_undo(is_add_shared_index); } } for (auto n : monomial(old_m)) - n->root->n->unmark1(); + n->n->unmark2(); m_update_shared_trail.push_back({ idx, s }); push_undo(is_update_shared); m_shared[idx].m = new_m; m_shared[idx].j = j; - TRACE(plugin, tout << "shared simplified to " << m_pp_ll(*this, monomial(new_m)) << "\n"); + TRACE(plugin_verbose, tout << "shared simplified to " << m_pp_ll(*this, monomial(new_m)) << "\n"); push_merge(old_n, new_n, j); } @@ -1215,8 +1420,8 @@ namespace euf { justification::dependency* ac_plugin::justify_monomial(justification::dependency* j, monomial_t const& m) { for (auto n : m) - if (n->root->n != n->n) - j = m_dep_manager.mk_join(j, m_dep_manager.mk_leaf(justification::equality(n->root->n, n->n))); + if (n->n != n->n) + j = m_dep_manager.mk_join(j, m_dep_manager.mk_leaf(justification::equality(n->n, n->n))); return j; } diff --git a/src/ast/euf/euf_ac_plugin.h b/src/ast/euf/euf_ac_plugin.h index 6dfa483bb..c516e46a3 100644 --- a/src/ast/euf/euf_ac_plugin.h +++ b/src/ast/euf/euf_ac_plugin.h @@ -36,37 +36,19 @@ namespace euf { class ac_plugin : public plugin { - // enode structure for AC equivalences - struct node { - enode* n; // associated enode - node* root; // path compressed root - node* next; // next in equivalence class - justification j; // justification for equality - node* target = nullptr; // justified next - unsigned_vector shared; // shared occurrences - unsigned_vector eqs; // equality occurrences - - unsigned id() const { return root->n->get_id(); } - static node* mk(region& r, enode* n); + struct stats { + unsigned m_num_superpositions = 0;// number of superpositions }; - class equiv { - node& n; - public: - class iterator { - node* m_first; - node* m_last; - public: - iterator(node* n, node* m) : m_first(n), m_last(m) {} - node* operator*() { return m_first; } - iterator& operator++() { if (!m_last) m_last = m_first; m_first = m_first->next; return *this; } - iterator operator++(int) { iterator tmp = *this; ++*this; return tmp; } - bool operator!=(iterator const& other) const { return m_last != other.m_last || m_first != other.m_first; } - }; - equiv(node& _n) :n(_n) {} - equiv(node* _n) :n(*_n) {} - iterator begin() const { return iterator(&n, nullptr); } - iterator end() const { return iterator(&n, &n); } + // enode structure for AC equivalences + struct node { + enode* n; // associated enode + unsigned_vector shared; // shared occurrences + unsigned_vector eqs; // equality occurrences + bool is_zero = false; + + unsigned id() const { return n->get_id(); } + static node* mk(region& r, enode* n); }; struct bloom { @@ -75,7 +57,7 @@ namespace euf { }; enum eq_status { - processed, to_simplify, is_dead + processed, to_simplify, is_reducing_eq, is_dead }; // represent equalities added by merge_eh and by superposition @@ -150,6 +132,10 @@ namespace euf { tracked_uint_set m_shared_todo; uint64_t m_tick = 1; symbol m_name; + unsigned m_fuel = 0; + unsigned m_fuel_inc = 3; + stats m_stats; + mutable symbol m_superposition_stats, m_eqs_stats; @@ -163,7 +149,6 @@ namespace euf { is_add_eq, is_add_monomial, is_add_node, - is_merge_node, is_update_eq, is_add_shared_index, is_add_eq_index, @@ -200,14 +185,35 @@ namespace euf { bool can_be_subset(monomial_t& subset, ptr_vector const& m, bloom& b); bool are_equal(ptr_vector const& a, ptr_vector const& b); bool are_equal(monomial_t& a, monomial_t& b); + bool are_equal(eq const& a, eq const& b) { + return are_equal(monomial(a.l), monomial(b.l)) && are_equal(monomial(a.r), monomial(b.r)); + } + bool well_formed(eq const& e) const; + bool is_reducing(eq const& e) const; + void forward_reduce(unsigned eq_id); + void forward_reduce(eq const& src, unsigned dst); + bool forward_reduce_monomial(eq const& eq, monomial_t& m); + void backward_subsume_new_eqs(); + bool is_backward_subsumed(unsigned dst_eq); bool backward_subsumes(unsigned src_eq, unsigned dst_eq); bool forward_subsumes(unsigned src_eq, unsigned dst_eq); - void init_equation(eq const& e); + enode_vector m_units; + enode* get_unit(enode* n) const { + for (auto u : m_units) { + if (u->get_sort() == n->get_sort()) + return u; + } + UNREACHABLE(); + return nullptr; + } + + bool init_equation(eq const& e); bool orient_equation(eq& e); void set_status(unsigned eq_id, eq_status s); unsigned pick_next_eq(); + unsigned_vector m_new_eqs; void forward_simplify(unsigned eq_id, unsigned using_eq); bool backward_simplify(unsigned eq_id, unsigned using_eq); bool superpose(unsigned src_eq, unsigned dst_eq); @@ -249,6 +255,7 @@ namespace euf { bool is_to_simplify(unsigned eq) const { return m_eqs[eq].status == eq_status::to_simplify; } bool is_processed(unsigned eq) const { return m_eqs[eq].status == eq_status::processed; } + bool is_reducing(unsigned eq) const { return m_eqs[eq].status == eq_status::is_reducing_eq; } bool is_alive(unsigned eq) const { return m_eqs[eq].status != eq_status::is_dead; } justification justify_rewrite(unsigned eq1, unsigned eq2); @@ -279,6 +286,10 @@ namespace euf { ac_plugin(egraph& g, func_decl* f); void set_injective() { m_is_injective = true; } + + void add_unit(enode*); + + void add_zero(enode*); theory_id get_id() const override { return m_fid; } @@ -294,6 +305,8 @@ namespace euf { std::ostream& display(std::ostream& out) const override; + void collect_statistics(statistics& st) const override; + void set_undo(std::function u) { m_undo_notify = u; } struct eq_pp { diff --git a/src/ast/euf/euf_arith_plugin.cpp b/src/ast/euf/euf_arith_plugin.cpp index b1f2bc28e..df9207a3b 100644 --- a/src/ast/euf/euf_arith_plugin.cpp +++ b/src/ast/euf/euf_arith_plugin.cpp @@ -31,6 +31,25 @@ namespace euf { std::function umul = [&]() { m_undo.push_back(undo_t::undo_mul); }; m_mul.set_undo(umul); m_add.set_injective(); + auto e = a.mk_int(0); + auto n = g.find(e) ? g.find(e) : g.mk(e, 0, 0, nullptr); + m_add.add_unit(n); + m_mul.add_zero(n); + + e = a.mk_real(0); + n = g.find(e) ? g.find(e) : g.mk(e, 0, 0, nullptr); + m_add.add_unit(n); + m_mul.add_zero(n); + + e = a.mk_int(1); + n = g.find(e) ? g.find(e) : g.mk(e, 0, 0, nullptr); + m_mul.add_unit(n); + + e = a.mk_real(1); + n = g.find(e) ? g.find(e) : g.mk(e, 0, 0, nullptr); + m_mul.add_unit(n); + + } void arith_plugin::register_node(enode* n) { diff --git a/src/ast/euf/euf_arith_plugin.h b/src/ast/euf/euf_arith_plugin.h index 0cc122d99..1852c1a28 100644 --- a/src/ast/euf/euf_arith_plugin.h +++ b/src/ast/euf/euf_arith_plugin.h @@ -46,6 +46,11 @@ namespace euf { void propagate() override; std::ostream& display(std::ostream& out) const override; + + void collect_statistics(statistics& st) const override { + m_add.collect_statistics(st); + m_mul.collect_statistics(st); + } }; } diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 7dcc49dcf..040a679b6 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -117,6 +117,7 @@ namespace euf { enode* egraph::mk(expr* f, unsigned generation, unsigned num_args, enode *const* args) { SASSERT(!find(f)); + TRACE(euf, tout << "mk: " << mk_bounded_pp(f, m) << " generation: " << generation << " num_args: " << num_args << "\n";); force_push(); enode *n = mk_enode(f, generation, num_args, args); @@ -157,6 +158,21 @@ namespace euf { } void egraph::propagate_plugins() { + if (m_plugins.empty()) + return; + if (m_plugin_qhead < m_new_th_eqs.size()) + m_updates.push_back(update_record(m_plugin_qhead, update_record::plugin_qhead())); + + for (; m_plugin_qhead < m_new_th_eqs.size(); ++m_plugin_qhead) { + auto const& eq = m_new_th_eqs[m_plugin_qhead]; + auto* p = get_plugin(eq.id()); + if (!p) + continue; + if (eq.is_eq()) + p->merge_eh(eq.child(), eq.root()); + else + p->diseq_eh(eq.eq()); + } for (auto* p : m_plugins) if (p) p->propagate(); @@ -167,23 +183,18 @@ namespace euf { m_new_th_eqs.push_back(th_eq(id, v1, v2, c, r)); m_updates.push_back(update_record(update_record::new_th_eq())); ++m_stats.m_num_th_eqs; - auto* p = get_plugin(id); - if (p) - p->merge_eh(c, r); } void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq) { if (!th_propagates_diseqs(id)) return; TRACE(euf_verbose, tout << "eq: " << v1 << " != " << v2 << "\n";); - m_new_th_eqs.push_back(th_eq(id, v1, v2, eq->get_expr())); + m_new_th_eqs.push_back(th_eq(id, v1, v2, eq)); m_updates.push_back(update_record(update_record::new_th_eq())); - auto* p = get_plugin(id); - if (p) - p->diseq_eh(eq); + ++m_stats.m_num_th_diseqs; } - + void egraph::add_literal(enode* n, enode* ante) { TRACE(euf, tout << "propagate " << bpp(n) << " " << bpp(ante) << "\n"); if (!m_on_propagate_literal) @@ -447,6 +458,9 @@ namespace euf { case update_record::tag_t::is_new_th_eq_qhead: m_new_th_eqs_qhead = p.qhead; break; + case update_record::tag_t::is_plugin_qhead: + m_plugin_qhead = p.qhead; + break; case update_record::tag_t::is_inconsistent: m_inconsistent = p.m_inconsistent; break; @@ -546,16 +560,18 @@ namespace euf { void egraph::remove_parents(enode* r) { TRACE(euf_verbose, tout << bpp(r) << "\n"); SASSERT(all_of(enode_parents(r), [&](enode* p) { return !p->is_marked1(); })); + TRACE(euf, tout << "remove_parents " << bpp(r) << "\n"); for (enode* p : enode_parents(r)) { if (p->is_marked1()) continue; if (p->cgc_enabled()) { if (!p->is_cgr()) continue; + TRACE(euf, tout << "removing " << m_table.contains_ptr(p) << " " << bpp(p) << "\n"); SASSERT(m_table.contains_ptr(p)); p->mark1(); erase_from_table(p); - CTRACE(euf_verbose, m_table.contains_ptr(p), tout << bpp(p) << "\n"; display(tout)); + CTRACE(euf, m_table.contains_ptr(p), tout << bpp(p) << "\n"; display(tout)); SASSERT(!m_table.contains_ptr(p)); } else if (p->is_equality()) @@ -564,15 +580,16 @@ namespace euf { } void egraph::reinsert_parents(enode* r1, enode* r2) { + TRACE(euf, tout << "reinsert_parents " << bpp(r1) << " " << bpp(r2) << "\n";); for (enode* p : enode_parents(r1)) { if (!p->is_marked1()) continue; p->unmark1(); - TRACE(euf_verbose, tout << "reinsert " << bpp(r1) << " " << bpp(r2) << " " << bpp(p) << " " << p->cgc_enabled() << "\n";); + TRACE(euf, tout << "reinsert " << bpp(r1) << " " << bpp(r2) << " " << bpp(p) << " " << p->cgc_enabled() << "\n";); if (p->cgc_enabled()) { auto [p_other, comm] = insert_table(p); SASSERT(m_table.contains_ptr(p) == (p_other == p)); - CTRACE(euf_verbose, p_other != p, tout << "reinsert " << bpp(p) << " == " << bpp(p_other) << " " << p->value() << " " << p_other->value() << "\n"); + CTRACE(euf, p_other != p, tout << "reinsert " << bpp(p) << " == " << bpp(p_other) << " " << p->value() << " " << p_other->value() << "\n"); if (p_other != p) m_to_merge.push_back(to_merge(p_other, p, comm)); else @@ -957,6 +974,9 @@ namespace euf { st.update("euf propagations theory eqs", m_stats.m_num_th_eqs); st.update("euf propagations theory diseqs", m_stats.m_num_th_diseqs); st.update("euf propagations literal", m_stats.m_num_lits); + for (auto p : m_plugins) + if (p) + p->collect_statistics(st); } void egraph::copy_from(egraph const& src, std::function& copy_justification) { diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 53a0b7da2..ba0712e3b 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -58,7 +58,7 @@ namespace euf { theory_var m_v2; union { enode* m_child; - expr* m_eq; + enode* m_eq; }; enode* m_root; public: @@ -68,10 +68,10 @@ namespace euf { theory_var v2() const { return m_v2; } enode* child() const { SASSERT(is_eq()); return m_child; } enode* root() const { SASSERT(is_eq()); return m_root; } - expr* eq() const { SASSERT(!is_eq()); return m_eq; } + enode* eq() const { SASSERT(!is_eq()); return m_eq; } th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) : m_id(id), m_v1(v1), m_v2(v2), m_child(c), m_root(r) {} - th_eq(theory_id id, theory_var v1, theory_var v2, expr* eq) : + th_eq(theory_id id, theory_var v1, theory_var v2, enode* eq) : m_id(id), m_v1(v1), m_v2(v2), m_eq(eq), m_root(nullptr) {} }; @@ -116,6 +116,7 @@ namespace euf { struct replace_th_var {}; struct new_th_eq {}; struct new_th_eq_qhead {}; + struct plugin_qhead {}; struct inconsistent {}; struct value_assignment {}; struct lbl_hash {}; @@ -125,7 +126,7 @@ namespace euf { struct plugin_undo {}; enum class tag_t { is_set_parent, is_add_node, is_toggle_cgc, is_toggle_merge_tf, is_update_children, is_add_th_var, is_replace_th_var, is_new_th_eq, - is_lbl_hash, is_new_th_eq_qhead, + is_lbl_hash, is_new_th_eq_qhead, is_plugin_qhead, is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant, is_plugin_undo }; tag_t tag; @@ -158,6 +159,8 @@ namespace euf { tag(tag_t::is_new_th_eq), r1(nullptr), n1(nullptr), r2_num_parents(0) {} update_record(unsigned qh, new_th_eq_qhead): tag(tag_t::is_new_th_eq_qhead), r1(nullptr), n1(nullptr), qhead(qh) {} + update_record(unsigned qh, plugin_qhead) : + tag(tag_t::is_plugin_qhead), r1(nullptr), n1(nullptr), qhead(qh) {} update_record(bool inc, inconsistent) : tag(tag_t::is_inconsistent), r1(nullptr), n1(nullptr), m_inconsistent(inc) {} update_record(enode* n, value_assignment) : @@ -196,6 +199,7 @@ namespace euf { enode *m_n2 = nullptr; justification m_justification; unsigned m_new_th_eqs_qhead = 0; + unsigned m_plugin_qhead = 0; svector m_new_th_eqs; bool_vector m_th_propagates_diseqs; enode_vector m_todo; diff --git a/src/ast/euf/euf_plugin.cpp b/src/ast/euf/euf_plugin.cpp index c6efe521b..4146ea996 100644 --- a/src/ast/euf/euf_plugin.cpp +++ b/src/ast/euf/euf_plugin.cpp @@ -35,7 +35,7 @@ namespace euf { void plugin::push_merge(enode* a, enode* b) { if (a->get_root() == b->get_root()) return; // already merged - TRACE(plugin, tout << g.bpp(a) << " == " << g.bpp(b) << "\n"); + TRACE(plugin, tout << "push-merge " << g.bpp(a) << " == " << g.bpp(b) << "\n"); g.push_merge(a, b, justification::axiom(get_id())); } diff --git a/src/ast/euf/euf_plugin.h b/src/ast/euf/euf_plugin.h index 8dbd4d7e7..edce49150 100644 --- a/src/ast/euf/euf_plugin.h +++ b/src/ast/euf/euf_plugin.h @@ -19,6 +19,7 @@ Author: #pragma once +#include "util/statistics.h" #include "ast/euf/euf_enode.h" #include "ast/euf/euf_justification.h" @@ -53,6 +54,8 @@ namespace euf { virtual void undo() = 0; virtual std::ostream& display(std::ostream& out) const = 0; + + virtual void collect_statistics(statistics& st) const {} }; } diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index f711fc19e..8b9289934 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -68,6 +68,8 @@ namespace euf { m_mam(mam::mk(*this, *this)), m_canonical(m), m_eargs(m), + m_expr_trail(m), + m_consequences(m), m_canonical_proofs(m), // m_infer_patterns(m, m_smt_params), m_deps(m), @@ -135,6 +137,7 @@ namespace euf { }; m_matcher.set_on_match(on_match); + } completion::~completion() { @@ -230,15 +233,51 @@ namespace euf { read_egraph(); IF_VERBOSE(1, verbose_stream() << "(euf.completion :rounds " << rounds << " :instances " << m_stats.m_num_instances << " :stop " << should_stop() << ")\n"); } + map_congruences(); + for (auto c : m_consequences) + add_consequence(c); + + TRACE(euf_completion, m_egraph.display(tout)); + } + + void completion::map_congruences() { + unsigned sz = qtail(); + for (unsigned i = qhead(); i < sz; ++i) { + auto [f, p, d] = m_fmls[i](); + if (is_app(f) && to_app(f)->get_num_args() == 1 && symbol("congruences") == to_app(f)->get_decl()->get_name()) + map_congruence(to_app(f)->get_arg(0)); + } + } + + void completion::map_congruence(expr* t) { + auto n = m_egraph.find(t); + if (!n) + return; + ptr_vector args; + for (auto s : enode_class(n)) { + expr_ref r(s->get_expr(), m); + m_rewriter(r); + args.push_back(r); + } + expr_ref cong(m); + cong = m.mk_app(symbol("congruence"), args.size(), args.data(), m.mk_bool_sort()); + m_fmls.add(dependent_expr(m, cong, nullptr, nullptr)); + } + + void completion::add_consequence(expr* f) { + expr_ref r(f, m); + m_rewriter(r); + f = r.get(); + // verbose_stream() << r << "\n"; + auto cons = m.mk_app(symbol("consequence"), 1, &f, m.mk_bool_sort()); + m_fmls.add(dependent_expr(m, cons, nullptr, nullptr)); } void completion::add_egraph() { m_nodes_to_canonize.reset(); unsigned sz = qtail(); - for (unsigned i = qhead(); i < sz; ++i) { auto [f, p, d] = m_fmls[i](); - add_constraint(f, p, d); } m_should_propagate = true; @@ -248,6 +287,7 @@ namespace euf { m_mam->propagate(); flush_binding_queue(); propagate_rules(); + propagate_closures(); IF_VERBOSE(11, verbose_stream() << "propagate " << m_stats.m_num_instances << "\n"); if (!m_should_propagate && !should_stop()) propagate_all_rules(); @@ -271,7 +311,7 @@ namespace euf { for (auto* ch : enode_args(n)) m_nodes_to_canonize.push_back(ch); }; - expr* x, * y; + expr* x = nullptr, * y = nullptr; if (m.is_eq(f, x, y)) { expr_ref x1(x, m); expr_ref y1(y, m); @@ -285,16 +325,20 @@ namespace euf { if (a->get_root() == b->get_root()) return; m_egraph.merge(a, b, to_ptr(push_pr_dep(pr, d))); + m_egraph.propagate(); add_children(a); add_children(b); auto a1 = mk_enode(x); if (a1->get_root() != a->get_root()) { m_egraph.merge(a, a1, nullptr); + m_egraph.propagate(); add_children(a1); } auto b1 = mk_enode(y); if (b1->get_root() != b->get_root()) { + TRACE(euf, tout << "merge and propagate\n"); m_egraph.merge(b, b1, nullptr); + m_egraph.propagate(); add_children(b1); } @@ -310,6 +354,7 @@ namespace euf { add_quantifiers(f); auto j = to_ptr(push_pr_dep(pr, d)); m_egraph.new_diseq(n, j); + m_egraph.propagate(); add_children(n); m_should_propagate = true; if (m_side_condition_solver) @@ -322,6 +367,7 @@ namespace euf { return; IF_VERBOSE(1, verbose_stream() << "fml: " << mk_pp(f, m) << "\n"); m_egraph.merge(n, m_tt, to_ptr(push_pr_dep(pr, d))); + m_egraph.propagate(); add_children(n); if (is_forall(f)) { quantifier* q = to_quantifier(f); @@ -352,7 +398,7 @@ namespace euf { } add_rule(f, pr, d); - if (!is_forall(f) && !m.is_implies(f)) { + if (!is_forall(f) && !m.is_implies(f) && !m.is_or(f)) { add_quantifiers(f); if (m_side_condition_solver) m_side_condition_solver->add_constraint(f, pr, d); @@ -388,18 +434,27 @@ namespace euf { else if (is_quantifier(t)) { auto q = to_quantifier(t); auto nd = q->get_num_decls(); - verbose_stream() << "bind " << mk_pp(q, m) << "\n"; + IF_VERBOSE(1, verbose_stream() << "bind " << mk_pp(q, m) << "\n"); for (unsigned i = 0; i < nd; ++i) { auto name = std::string("bound!") + std::to_string(bound.size()); auto b = m.mk_const(name, q->get_decl_sort(i)); - // TODO: persist bound variables withn scope to avoid reference count crashes + if (b->get_ref_count() == 0) { + m_expr_trail.push_back(b); + get_trail().push(push_back_vector(m_expr_trail)); + } bound.push_back(b); } expr_ref inst = var_subst(m)(q->get_expr(), bound); + if (!m_egraph.find(inst)) { + expr_ref clos(m); m_closures.insert(q, { bound, inst }); get_trail().push(insert_map(m_closures, q)); - mk_enode(inst); + // ensure that inst occurs in a foreign context to enable equality propagation + // on inst. + func_decl* f = m.mk_func_decl(symbol("clos!"), inst->get_sort(), m.mk_bool_sort()); + clos = m.mk_app(f, inst); + mk_enode(clos); // TODO: handle nested quantifiers after m_closures is updated to // index on sort declaration prefix together with quantifier // add_quantifiers(bound, inst); @@ -445,13 +500,31 @@ namespace euf { void completion::add_rule(expr* f, proof* pr, expr_dependency* d) { expr* x = nullptr, * y = nullptr; - if (!m.is_implies(f, x, y)) - return; expr_ref_vector body(m); proof_ref pr_i(m), pr0(m); expr_ref_vector prs(m); - expr_ref head(y, m); - body.push_back(x); + expr_ref head(m); + if (m.is_implies(f, x, y)) { + head = y; + body.push_back(x); + } + else if (m.is_or(f)) { + auto a = to_app(f); + for (auto arg : *to_app(f)) { + if (m.is_eq(arg)) { + if (head) + return; + head = arg; + } + else + body.push_back(arg); + } + if (!head) + return; + } + else + return; + flatten_and(body); unsigned j = 0; flet _propagate_with_solver(m_propagate_with_solver, true); @@ -552,6 +625,39 @@ namespace euf { } } + void completion::propagate_closures() { + for (auto [q, clos] : m_closures) { + expr* body = clos.second; + auto n = m_egraph.find(body); + SASSERT(n); +#if 0 + verbose_stream() << "class of " << mk_pp(body, m) << "\n"; + for (auto s : euf::enode_class(n)) { + verbose_stream() << mk_pp(s->get_expr(), m) << "\n"; + } +#endif + if (n->is_root()) + continue; + auto qn = m_egraph.find(q); +#if 0 + verbose_stream() << "class of " << mk_pp(q, m) << "\n"; + for (auto s : euf::enode_class(qn)) { + verbose_stream() << mk_pp(s->get_expr(), m) << "\n"; + } +#endif + expr_ref new_body = expr_ref(n->get_root()->get_expr(), m); + expr_ref new_q = expr_abstract(m, clos.first, new_body); + new_q = m.update_quantifier(q, new_q); + auto new_qn = m_egraph.find(new_q); + if (!new_qn) + new_qn = m_egraph.mk(new_q, qn->generation(), 0, nullptr); + if (new_qn->get_root() == qn->get_root()) + continue; + m_egraph.merge(new_qn, qn, nullptr); // todo track dependencies + m_should_propagate = true; + } + } + binding* completion::tmp_binding(quantifier* q, app* pat, euf::enode* const* _binding) { if (q->get_num_decls() > m_tmp_binding_capacity) { void* mem = memory::allocate(sizeof(binding) + q->get_num_decls() * sizeof(euf::enode*)); @@ -643,11 +749,12 @@ namespace euf { void completion::apply_binding(binding& b, quantifier* q, expr_ref_vector const& s) { var_subst subst(m); - expr_ref r = subst(q->get_expr(), s); + expr_ref r = subst(q->get_expr(), s); scoped_generation sg(*this, b.m_max_top_generation + 1); auto [pr, d] = get_dependency(q); if (pr) pr = m.mk_quant_inst(m.mk_or(m.mk_not(q), r), s.size(), s.data()); + m_consequences.push_back(r); add_constraint(r, pr, d); propagate_rules(); m_egraph.propagate(); @@ -788,7 +895,7 @@ namespace euf { if (x1 == y1) r = expr_ref(m.mk_true(), m); else { - expr* c = get_canonical(x, pr3, d); + auto c = get_canonical(x, pr3, d); if (c == x1) r = m_rewriter.mk_eq(y1, c); else if (c == y1) @@ -832,8 +939,6 @@ namespace euf { } expr_ref completion::canonize(expr* f, proof_ref& pr, expr_dependency_ref& d) { - if (is_quantifier(f)) - return expr_ref(canonize(to_quantifier(f), pr, d), m); if (!is_app(f)) return expr_ref(f, m); // todo could normalize ground expressions under quantifiers @@ -862,13 +967,30 @@ namespace euf { return r; } - expr_ref completion::canonize(quantifier* q, proof_ref& pr, expr_dependency_ref& d) { + expr_ref completion::get_canonical(quantifier* q, proof_ref& pr, expr_dependency_ref& d) { std::pair, expr*> clos; +// verbose_stream() << "canonize " << mk_pp(q, m) << "\n"; if (!m_closures.find(q, clos)) return expr_ref(q, m); expr* body = clos.second; - expr_ref new_body = canonize(body, pr, d); + auto n = m_egraph.find(body); + SASSERT(n); +#if 0 + verbose_stream() << "class of " << mk_pp(body, m) << "\n"; + for (auto s : euf::enode_class(n)) { + verbose_stream() << mk_pp(s->get_expr(), m) << "\n"; + } +#endif + n = m_egraph.find(q); +#if 0 + verbose_stream() << "class of " << mk_pp(q, m) << "\n"; + for (auto s : euf::enode_class(n)) { + verbose_stream() << mk_pp(s->get_expr(), m) << "\n"; + } +#endif + expr_ref new_body = get_canonical(body, pr, d); expr_ref result = expr_abstract(m, clos.first, new_body); + result = m.update_quantifier(q, result); if (m.proofs_enabled()) { // add proof rule // @@ -881,10 +1003,28 @@ namespace euf { } - expr* completion::get_canonical(expr* f, proof_ref& pr, expr_dependency_ref& d) { + expr_ref completion::get_canonical(expr* f, proof_ref& pr, expr_dependency_ref& d) { + expr_ref e(m); + if (has_quantifiers(f)) { + if (is_quantifier(f)) + return get_canonical(to_quantifier(f), pr, d); + else if (is_app(f)) { + expr_ref_vector args(m); + for (auto arg : *to_app(f)) { + // TODO: pr reconstruction + args.push_back(get_canonical(arg, pr, d)); + } + e = m.mk_app(to_app(f)->get_decl(), args); + if (!m_egraph.find(e)) + return e; + f = e; + } + else + UNREACHABLE(); + + } enode* n = m_egraph.find(f); - if (!n) verbose_stream() << "not found " << f->get_id() << " " << mk_pp(f, m) << "\n"; enode* r = n->get_root(); d = m.mk_join(d, explain_eq(n, r)); d = m.mk_join(d, m_deps.get(r->get_id(), nullptr)); @@ -894,7 +1034,7 @@ namespace euf { pr = m.mk_transitivity(pr, get_canonical_proof(r)); } SASSERT(m_canonical.get(r->get_id())); - return m_canonical.get(r->get_id()); + return expr_ref(m_canonical.get(r->get_id()), m); } expr* completion::get_canonical(enode* n) { @@ -990,6 +1130,7 @@ namespace euf { void completion::collect_statistics(statistics& st) const { st.update("euf-completion-rewrites", m_stats.m_num_rewrites); st.update("euf-completion-instances", m_stats.m_num_instances); + m_egraph.collect_statistics(st); } bool completion::is_gt(expr* lhs, expr* rhs) const { @@ -1098,8 +1239,8 @@ namespace euf { proof_ref pr(m); prs.reset(); for (enode* arg : enode_args(rep)) { - enode* rarg = arg->get_root(); - expr* c = get_canonical(rarg); + auto rarg = arg->get_root(); + auto c = get_canonical(rarg); if (c) { m_eargs.push_back(c); new_arg |= c != arg->get_expr(); diff --git a/src/ast/simplifiers/euf_completion.h b/src/ast/simplifiers/euf_completion.h index 6893e5327..8d1a936c7 100644 --- a/src/ast/simplifiers/euf_completion.h +++ b/src/ast/simplifiers/euf_completion.h @@ -128,7 +128,7 @@ namespace euf { enode* m_tt, *m_ff; ptr_vector m_todo; enode_vector m_args, m_reps, m_nodes_to_canonize; - expr_ref_vector m_canonical, m_eargs; + expr_ref_vector m_canonical, m_eargs, m_expr_trail, m_consequences; proof_ref_vector m_canonical_proofs; // pattern_inference_rw m_infer_patterns; bindings m_bindings; @@ -166,11 +166,14 @@ namespace euf { void read_egraph(); expr_ref canonize(expr* f, proof_ref& pr, expr_dependency_ref& dep); expr_ref canonize_fml(expr* f, proof_ref& pr, expr_dependency_ref& dep); - expr* get_canonical(expr* f, proof_ref& pr, expr_dependency_ref& d); + expr_ref get_canonical(expr* f, proof_ref& pr, expr_dependency_ref& d); expr* get_canonical(enode* n); proof* get_canonical_proof(enode* n); void set_canonical(enode* n, expr* e, proof* pr); void add_constraint(expr*f, proof* pr, expr_dependency* d); + void map_congruences(); + void map_congruence(expr* t); + void add_consequence(expr* t); // Enable equality propagation inside of quantifiers // add quantifier bodies as closure terms to the E-graph. @@ -181,7 +184,7 @@ namespace euf { // Closure terms are re-abstracted by the canonizer. void add_quantifiers(ptr_vector& bound, expr* t); void add_quantifiers(expr* t); - expr_ref canonize(quantifier* q, proof_ref& pr, expr_dependency_ref& d); + expr_ref get_canonical(quantifier* q, proof_ref& pr, expr_dependency_ref& d); obj_map, expr*>> m_closures; expr_dependency* explain_eq(enode* a, enode* b); @@ -208,6 +211,7 @@ namespace euf { void propagate_rule(conditional_rule& r); void propagate_rules(); void propagate_all_rules(); + void propagate_closures(); void clear_propagation_queue(); ptr_vector m_propagation_queue; struct push_watch_rule; diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 380208abd..e594b8bc6 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -572,7 +572,7 @@ namespace arith { } void solver::new_diseq_eh(euf::th_eq const& e) { - TRACE(artih, tout << mk_bounded_pp(e.eq(), m) << "\n"); + TRACE(artih, tout << mk_bounded_pp(e.eq()->get_expr(), m) << "\n"); ensure_column(e.v1()); ensure_column(e.v2()); m_delayed_eqs.push_back(std::make_pair(e, false)); diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 1632cd3e8..1695f5e41 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -1147,7 +1147,7 @@ namespace arith { new_eq_eh(e); else if (is_eq(e.v1(), e.v2())) { mk_diseq_axiom(e.v1(), e.v2()); - TRACE(arith, tout << mk_bounded_pp(e.eq(), m) << " " << use_nra_model() << "\n"); + TRACE(arith, tout << mk_bounded_pp(e.eq()->get_expr(), m) << " " << use_nra_model() << "\n"); found_diseq = true; break; } diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index a17e1dd31..5ff0ff0ae 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -280,7 +280,7 @@ namespace bv { undef_idx--; sat::literal consequent = m_bits[v1][undef_idx]; sat::literal b = m_bits[v2][undef_idx]; - sat::literal antecedent = ~expr2literal(ne.eq()); + sat::literal antecedent = ~expr2literal(ne.eq()->get_expr()); SASSERT(s().value(antecedent) == l_true); SASSERT(s().value(consequent) == l_undef); SASSERT(s().value(b) != l_undef); diff --git a/src/util/trace_tags.def b/src/util/trace_tags.def index fadb42866..ffa631d7a 100644 --- a/src/util/trace_tags.def +++ b/src/util/trace_tags.def @@ -758,6 +758,7 @@ X(Global, pivot_bug, "pivot bug") X(Global, pivot_shape, "pivot shape") X(Global, pivot_stats, "pivot stats") X(Global, plugin, "plugin") +X(Global, plugin_verbose, "plugin verbose") X(Global, pob_queue, "pob queue") X(Global, poly_rewriter, "poly rewriter") X(Global, polynomial, "polynomial")