From a805e1f27d492ae4e8b8d46a1fb1679c122846bf Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 28 Nov 2023 12:50:43 -0800 Subject: [PATCH] fixes to AC plugin --- src/ast/euf/CMakeLists.txt | 2 + src/ast/euf/euf_ac_plugin.cpp | 474 +++++++++++++++------- src/ast/euf/euf_ac_plugin.h | 44 +- src/ast/euf/euf_arith_plugin.cpp | 17 +- src/ast/euf/euf_arith_plugin.h | 6 +- src/ast/euf/euf_bv_plugin.cpp | 4 +- src/ast/euf/euf_bv_plugin.h | 6 +- src/ast/euf/euf_egraph.cpp | 47 +-- src/ast/euf/euf_egraph.h | 5 +- src/ast/euf/euf_enode.cpp | 11 + src/ast/euf/euf_enode.h | 1 + src/ast/euf/euf_justification.cpp | 54 +++ src/ast/euf/euf_justification.h | 27 +- src/ast/euf/euf_plugin.h | 6 +- src/ast/euf/euf_specrel_plugin.cpp | 71 ++++ src/ast/euf/euf_specrel_plugin.h | 56 +++ src/ast/special_relations_decl_plugin.cpp | 48 ++- src/ast/special_relations_decl_plugin.h | 11 +- src/sat/smt/CMakeLists.txt | 1 + src/sat/smt/dt_solver.cpp | 2 +- src/sat/smt/dt_solver.h | 2 +- src/sat/smt/euf_solver.cpp | 5 +- src/sat/smt/specrel_solver.cpp | 119 ++++++ src/sat/smt/specrel_solver.h | 75 ++++ src/test/euf_arith_plugin.cpp | 37 +- src/test/euf_bv_plugin.cpp | 49 +-- 26 files changed, 887 insertions(+), 293 deletions(-) create mode 100644 src/ast/euf/euf_justification.cpp create mode 100644 src/ast/euf/euf_specrel_plugin.cpp create mode 100644 src/ast/euf/euf_specrel_plugin.h create mode 100644 src/sat/smt/specrel_solver.cpp create mode 100644 src/sat/smt/specrel_solver.h diff --git a/src/ast/euf/CMakeLists.txt b/src/ast/euf/CMakeLists.txt index 430ea2b08..aa71e7fba 100644 --- a/src/ast/euf/CMakeLists.txt +++ b/src/ast/euf/CMakeLists.txt @@ -6,7 +6,9 @@ z3_add_component(euf euf_egraph.cpp euf_enode.cpp euf_etable.cpp + euf_justification.cpp euf_plugin.cpp + euf_specrel_plugin.cpp COMPONENT_DEPENDENCIES ast util diff --git a/src/ast/euf/euf_ac_plugin.cpp b/src/ast/euf/euf_ac_plugin.cpp index 272a14320..5b4b1df66 100644 --- a/src/ast/euf/euf_ac_plugin.cpp +++ b/src/ast/euf/euf_ac_plugin.cpp @@ -67,6 +67,7 @@ TODOs: #include "ast/euf/euf_ac_plugin.h" #include "ast/euf/euf_egraph.h" +#include "ast/ast_pp.h" namespace euf { @@ -74,7 +75,18 @@ namespace euf { plugin(g), m_fid(fid), m_op(op), m_dep_manager(get_region()), m_hash(*this), m_eq(*this), m_monomial_table(m_hash, m_eq) - {} + { + g.set_th_propagates_diseqs(m_fid); + } + + ac_plugin::ac_plugin(egraph& g, func_decl* f) : + plugin(g), m_decl(f), m_fid(f->get_family_id()), + m_dep_manager(get_region()), + m_hash(*this), m_eq(*this), m_monomial_table(m_hash, m_eq) + { + if (m_fid != null_family_id) + g.set_th_propagates_diseqs(m_fid); + } void ac_plugin::register_node(enode* n) { if (is_op(n)) @@ -85,16 +97,19 @@ namespace euf { } void ac_plugin::register_shared(enode* n) { + if (m_shared_nodes.get(n->get_id(), false)) + return; auto m = to_monomial(n); auto const& ns = monomial(m); for (auto arg : ns) { arg->shared.push_back(m); m_node_trail.push_back(arg); - push_undo(is_add_shared); + push_undo(is_add_shared_index); } + m_shared_nodes.setx(n->get_id(), true, false); sort(monomial(m)); + m_shared_todo.insert(m_shared.size()); m_shared.push_back({ n, m, justification::axiom() }); - m_shared_todo.insert(m); push_undo(is_register_shared); } @@ -103,11 +118,6 @@ namespace euf { m_undo.pop_back(); switch (k) { case is_add_eq: { - auto const& eq = m_eqs.back(); - for (auto* n : monomial(eq.l)) - n->eqs.pop_back(); - for (auto* n : monomial(eq.r)) - n->eqs.pop_back(); m_eqs.pop_back(); break; } @@ -138,13 +148,21 @@ namespace euf { m_update_eq_trail.pop_back(); break; } - case is_add_shared: { + case is_add_shared_index: { auto n = m_node_trail.back(); m_node_trail.pop_back(); n->shared.pop_back(); break; } + case is_add_eq_index: { + auto n = m_node_trail.back(); + m_node_trail.pop_back(); + n->eqs.pop_back(); + break; + } case is_register_shared: { + auto s = m_shared.back(); + m_shared_nodes[s.n->get_id()] = false; m_shared.pop_back(); break; } @@ -159,9 +177,13 @@ namespace euf { } } - std::ostream& ac_plugin::display_monomial(std::ostream& out, monomial_t const& m) const { - for (auto n : m) - out << g.bpp(n->n) << " "; + std::ostream& ac_plugin::display_monomial(std::ostream& out, ptr_vector const& m) const { + for (auto n : m) { + if (n->n->num_args() == 0) + out << mk_pp(n->n->get_expr(), g.get_manager()) << " "; + else + out << g.bpp(n->n) << " "; + } return out; } @@ -200,6 +222,8 @@ namespace euf { for (auto n : m_nodes) { if (!n) continue; + if (n->eqs.empty() && n->shared.empty()) + continue; out << g.bpp(n->n) << " r: " << n->root_id() << " "; if (!n->eqs.empty()) { out << "eqs "; @@ -216,25 +240,57 @@ namespace euf { return out; } - void ac_plugin::merge_eh(enode* l, enode* r, justification j) { + void ac_plugin::merge_eh(enode* l, enode* r) { if (l == r) return; + auto j = justification::equality(l, r); if (!is_op(l) && !is_op(r)) merge(mk_node(l), mk_node(r), j); else init_equation(eq(to_monomial(l), to_monomial(r), j)); } + void ac_plugin::diseq_eh(enode* eq) { + SASSERT(g.get_manager().is_eq(eq->get_expr())); + enode* a = eq->get_arg(0), * b = eq->get_arg(1); + a = a->get_closest_th_node(m_fid); + b = b->get_closest_th_node(m_fid); + SASSERT(a && b); + register_shared(a); + register_shared(b); + } + void ac_plugin::init_equation(eq const& e) { m_eqs.push_back(e); auto& eq = m_eqs.back(); if (orient_equation(eq)) { - push_undo(is_add_eq); + unsigned eq_id = m_eqs.size() - 1; - for (auto n : monomial(eq.l)) - n->eqs.push_back(eq_id); + + for (auto n : monomial(eq.l)) { + if (!n->root->n->is_marked1()) { + n->root->eqs.push_back(eq_id); + n->root->n->mark1(); + push_undo(is_add_eq_index); + m_node_trail.push_back(n->root); + } + } + + for (auto n : monomial(eq.r)) { + if (!n->root->n->is_marked1()) { + n->root->eqs.push_back(eq_id); + n->root->n->mark1(); + push_undo(is_add_eq_index); + m_node_trail.push_back(n->root); + } + } + + for (auto n : monomial(eq.l)) + n->root->n->unmark1(); + for (auto n : monomial(eq.r)) - n->eqs.push_back(eq_id); + n->root->n->unmark1(); + m_to_simplify_todo.insert(eq_id); } else @@ -298,6 +354,19 @@ namespace euf { return (f1 | f2) == f2; } + bool ac_plugin::can_be_subset(monomial_t& subset, ptr_vector const& m, bloom& bloom) { + if (subset.size() > m.size()) + return false; + if (bloom.m_tick != m_tick) { + bloom.m_filter = 0; + for (auto n : m) + bloom.m_filter |= (1ull << (n->root_id() % 64ull)); + bloom.m_tick = m_tick; + } + auto f2 = bloom.m_filter; + return (filter(subset) | f2) == f2; + } + void ac_plugin::merge(node* root, node* other, justification j) { for (auto n : equiv(other)) n->root = root; @@ -326,15 +395,10 @@ namespace euf { ns.push_back(n); for (unsigned i = 0; i < ns.size(); ++i) { n = ns[i]; - if (is_op(n)) { - ns.append(n->num_args(), n->args()); - ns[i] = ns.back(); - ns.pop_back(); - --i; - } - else { - m.push_back(mk_node(n)); - } + if (is_op(n)) + ns.append(n->num_args(), n->args()); + else + m.push_back(mk_node(n)); } return to_monomial(n, m); } @@ -367,7 +431,6 @@ namespace euf { } void ac_plugin::propagate() { - TRACE("plugin", display(tout)); while (true) { loop_start: unsigned eq_id = pick_next_eq(); @@ -377,10 +440,14 @@ namespace euf { TRACE("plugin", tout << "propagate " << eq_id << ": " << eq_pp(*this, m_eqs[eq_id]) << "\n"); // simplify eq using processed + for (auto other_eq : backward_iterator(eq_id)) + TRACE("plugin", tout << "backward iterator " << eq_id << " vs " << other_eq << " " << is_processed(other_eq) << "\n"); for (auto other_eq : backward_iterator(eq_id)) if (is_processed(other_eq) && backward_simplify(eq_id, other_eq)) goto loop_start; + set_status(eq_id, eq_status::processed); + // simplify processed using eq for (auto other_eq : forward_iterator(eq_id)) if (is_processed(other_eq)) @@ -395,10 +462,10 @@ namespace euf { for (auto other_eq : forward_iterator(eq_id)) if (is_to_simplify(other_eq)) forward_simplify(eq_id, other_eq); - - set_status(eq_id, eq_status::processed); } propagate_shared(); + + CTRACE("plugin", !m_shared.empty() || !m_eqs.empty(), display(tout)); } unsigned ac_plugin::pick_next_eq() { @@ -456,6 +523,8 @@ namespace euf { auto const& eq = m_eqs[eq_id]; init_ref_counts(monomial(eq.r), m_dst_r_counts); init_ref_counts(monomial(eq.l), m_dst_l_counts); + m_dst_r.reset(); + m_dst_r.append(monomial(eq.r).m_nodes); init_subset_iterator(eq_id, monomial(eq.r)); return m_eq_occurs; } @@ -479,12 +548,20 @@ namespace euf { node* max_n = nullptr; bool has_two = false; for (auto n : m) - if (n->root->eqs.size() > max_use) + if (n->root->eqs.size() >= max_use) has_two |= max_n && (max_n != n->root), max_n = n->root, max_use = n->root->eqs.size(); m_eq_occurs.reset(); - for (auto n : m) - if (n->root != max_n && has_two) + if (has_two) { + for (auto n : m) + if (n->root != max_n) + m_eq_occurs.append(n->root->eqs); + } + else { + for (auto n : m) { m_eq_occurs.append(n->root->eqs); + break; + } + } compress_eq_occurs(eq_id); } @@ -525,10 +602,26 @@ namespace euf { return min_n->eqs; } - void ac_plugin::init_ref_counts(monomial_t const& monomial, ref_counts& counts) { - counts.reset(); - for (auto n : monomial) - counts.inc(n->root_id(), 1); + void ac_plugin::init_ref_counts(monomial_t const& monomial, ref_counts& counts) const { + init_ref_counts(monomial.m_nodes, counts); + } + + void ac_plugin::init_ref_counts(ptr_vector const& monomial, ref_counts& counts) const { + counts.reset(); + for (auto n : monomial) + counts.inc(n->root_id(), 1); + } + + bool ac_plugin::is_correct_ref_count(monomial_t const& m, ref_counts const& counts) const { + return is_correct_ref_count(m.m_nodes, counts); + } + + bool ac_plugin::is_correct_ref_count(ptr_vector const& m, ref_counts const& counts) const { + ref_counts check; + init_ref_counts(m, check); + return + all_of(counts, [&](unsigned i) { return check[i] == counts[i]; }) && + all_of(check, [&](unsigned i) { return check[i] == counts[i]; }); } void ac_plugin::forward_simplify(unsigned src_eq, unsigned dst_eq) { @@ -540,12 +633,14 @@ namespace euf { // dst = A -> BC // src = B -> D // post(dst) := A -> CD - auto& src = m_eqs[src_eq]; + auto& src = m_eqs[src_eq]; // src_r_counts, src_l_counts are initialized auto& dst = m_eqs[dst_eq]; TRACE("plugin", tout << "forward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << "\n"); + if (forward_subsumes(src_eq, dst_eq)) { + TRACE("plugin", tout << "forward subsumed\n"); set_status(dst_eq, eq_status::is_dead); return; } @@ -553,34 +648,49 @@ namespace euf { if (!can_be_subset(monomial(src.l), monomial(dst.r))) return; + m_dst_r_counts.reset(); unsigned src_l_size = monomial(src.l).size(); unsigned src_r_size = m_src_r.size(); + SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts)); // subtract src.l from dst.r if src.l is a subset of dst.r - // new_rhs := old_rhs - src_lhs + src_rhs + // dst_rhs := dst_rhs - src_lhs + src_rhs + // := src_rhs + (dst_rhs - src_lhs) + // := src_rhs + elements from dst_rhs that are in excess of src_lhs unsigned num_overlap = 0; for (auto n : monomial(dst.r)) { unsigned id = n->root_id(); - unsigned count = m_src_l_counts[id]; - if (count == 0) - m_src_r.push_back(n); - else if (m_dst_r_counts[id] >= count) + unsigned dst_count = m_dst_r_counts[id]; + unsigned src_count = m_src_l_counts[id]; + if (dst_count > src_count) { m_src_r.push_back(n); + m_dst_r_counts.dec(id, 1); + } + else if (dst_count < src_count) { + m_src_r.shrink(src_r_size); + return; + } else - m_dst_r_counts.inc(id, 1), ++num_overlap; + ++num_overlap; } // The dst.r has to be a superset of src.l, otherwise simplification does not apply - if (num_overlap == src_l_size) { - auto new_r = to_monomial(nullptr, m_src_r); - m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] }); - m_eqs[dst_eq].r = new_r; - m_eqs[dst_eq].j = justify_rewrite(src_eq, dst_eq); - push_undo(is_update_eq); - TRACE("plugin", tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n"); + if (num_overlap != src_l_size) { + m_src_r.shrink(src_r_size); + return; } - m_src_r.shrink(src_r_size); + auto j = justify_rewrite(src_eq, dst_eq); + reduce(m_src_r, j); + auto new_r = to_monomial(m_src_r); + index_new_r(dst_eq, monomial(m_eqs[dst_eq].r), monomial(new_r)); + m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] }); + m_eqs[dst_eq].r = new_r; + m_eqs[dst_eq].j = j; + push_undo(is_update_eq); + m_src_r.reset(); + m_src_r.append(monomial(src.r).m_nodes); + TRACE("plugin", tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n"); } bool ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) { @@ -588,25 +698,38 @@ namespace euf { return false; auto& src = m_eqs[src_eq]; - auto& dst = m_eqs[dst_eq]; + auto& dst = m_eqs[dst_eq]; // pre-computed dst_r_counts, dst_l_counts // // dst_ids, dst_count contain rhs of dst_eq // - TRACE("plugin", tout << "backward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << "\n"); - // check that src.l is a subset of dst.r - if (!can_be_subset(monomial(src.l), monomial(dst.r))) - return false; - if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) - return false; - if (backward_subsumes(src_eq, dst_eq)) { + TRACE("plugin", tout << "backward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n"); + + if (backward_subsumes(src_eq, dst_eq)) { + TRACE("plugin", tout << "backward subsumed\n"); set_status(dst_eq, eq_status::is_dead); return true; } - // dst_rhs := dst_rhs - src_lhs + src_rhs - auto new_r = rewrite(monomial(src.r), monomial(dst.r)); + // check that src.l is a subset of dst.r + if (!can_be_subset(monomial(src.l), monomial(dst.r))) + return false; + if (!is_subset(m_dst_r_counts, m_src_l_counts, monomial(src.l))) { + TRACE("plugin", tout << "not subset\n"); + return false; + } + + SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); + + ptr_vector m(m_dst_r); + init_ref_counts(monomial(src.l), m_src_l_counts); + + rewrite1(m_src_l_counts, monomial(src.r), m_dst_r_counts, m); + auto j = justify_rewrite(src_eq, dst_eq); + reduce(m, j); + auto new_r = to_monomial(m); + index_new_r(dst_eq, monomial(m_eqs[dst_eq].r), monomial(new_r)); m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] }); m_eqs[dst_eq].r = new_r; - m_eqs[dst_eq].j = justify_rewrite(src_eq, dst_eq); + m_eqs[dst_eq].j = j; TRACE("plugin", tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n"); push_undo(is_update_eq); return true; @@ -618,6 +741,8 @@ namespace euf { bool ac_plugin::backward_subsumes(unsigned src_eq, unsigned dst_eq) { auto& src = m_eqs[src_eq]; auto& dst = m_eqs[dst_eq]; + SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts)); + SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); if (!can_be_subset(monomial(src.l), monomial(dst.l))) return false; if (!can_be_subset(monomial(src.r), monomial(dst.r))) @@ -625,13 +750,14 @@ namespace euf { unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size(); if (size_diff != monomial(dst.r).size() - monomial(src.r).size()) return false; - if (!is_superset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) + if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) return false; - if (!is_superset(m_dst_r_counts, m_src_r_counts, monomial(src.r))) - return false; - // add difference betwen src and dst1 to dst2 - // (also add it to dst1 to make sure same difference isn't counted twice). - for (auto n : monomial(src.l)) { + if (!is_subset(m_dst_r_counts, m_src_r_counts, monomial(src.r))) + return false; + SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts)); + SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts)); + // add difference betwen dst.l and src.l to both src.l, src.r + for (auto n : monomial(dst.l)) { unsigned id = n->root_id(); SASSERT(m_dst_l_counts[id] >= m_src_l_counts[id]); unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id]; @@ -645,10 +771,12 @@ namespace euf { return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); } - // src_counts, src2_counts are initialized for src_eq + // src_l_counts, src_r_counts are initialized for src.l, src.r bool ac_plugin::forward_subsumes(unsigned src_eq, unsigned dst_eq) { auto& src = m_eqs[src_eq]; auto& dst = m_eqs[dst_eq]; + SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts)); + SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts)); if (!can_be_subset(monomial(src.l), monomial(dst.l))) return false; if (!can_be_subset(monomial(src.r), monomial(dst.r))) @@ -658,52 +786,87 @@ namespace euf { return false; if (!is_superset(m_src_l_counts, m_dst_l_counts, monomial(dst.l))) return false; - if (!is_subset(m_src_r_counts, m_dst_r_counts, monomial(dst.r))) + if (!is_superset(m_src_r_counts, m_dst_r_counts, monomial(dst.r))) return false; + SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts)); + SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); for (auto n : monomial(src.l)) { unsigned id = n->root_id(); - SASSERT(m_src_l_counts[id] >= m_dst_l_counts[id]); - unsigned diff = m_src_l_counts[id] - m_dst_l_counts[id]; - if (diff > 0) { - m_dst_l_counts.inc(id, diff); - m_dst_r_counts.inc(id, diff); - } + SASSERT(m_src_l_counts[id] <= m_dst_l_counts[id]); + unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id]; + if (diff == 0) + continue; + m_dst_l_counts.dec(id, diff); + if (m_dst_r_counts[id] < diff) + return false; + m_dst_r_counts.dec(id, diff); } + return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); } - unsigned ac_plugin::rewrite(monomial_t const& src_r, monomial_t const& dst_r) { - // pre-condition: is-subset is invoked so that m_src_count is initialized. - // pre-condition: m_dst_count is also initialized (once). - m_src_r.reset(); - m_src_r.append(src_r.m_nodes); - // add to m_src_r elements of dst.r that are not in src.l - for (auto n : dst_r) { + void ac_plugin::rewrite1(ref_counts const& src_l, monomial_t const& src_r, ref_counts& dst_counts, ptr_vector& dst) { + // pre-condition: is-subset is invoked so that src_l is initialized. + // pre-condition: dst_count is also initialized. + // remove from dst elements that are in src_l + // add elements from src_r + SASSERT(is_correct_ref_count(dst, dst_counts)); + SASSERT(&src_r.m_nodes != &dst); + unsigned sz = dst.size(), j = 0; + for (unsigned i = 0; i < sz; ++i) { + auto* n = dst[i]; unsigned id = n->root_id(); - unsigned count = m_src_l_counts[id]; - if (count == 0) - m_src_r.push_back(n); - else - m_src_l_counts.inc(id, -1); + unsigned dst_count = dst_counts[id]; + unsigned src_count = src_l[id]; + SASSERT(dst_count > 0); + if (src_count == 0) + dst[j++] = n; + else if (src_count < dst_count) { + dst[j++] = n; + dst_counts.dec(id, 1); + } } - return to_monomial(nullptr, m_src_r); + dst.shrink(j); + dst.append(src_r.m_nodes); + } + + // rewrite monomial to normal form. + bool ac_plugin::reduce(ptr_vector& m, justification& j) { + bool change = false; + do { + init_loop: + if (m.size() == 1) + return change; + bloom b; + init_ref_counts(m, m_m_counts); + for (auto n : m) { + for (auto eq : n->root->eqs) { + if (!is_processed(eq)) + continue; + auto& src = m_eqs[eq]; + + if (!can_be_subset(monomial(src.l), m, b)) + continue; + if (!is_subset(m_m_counts, m_eq_counts, monomial(src.l))) + continue; + TRACE("plugin", display_equation(tout << "reduce ", src) << "\n"); + SASSERT(is_correct_ref_count(monomial(src.l), m_eq_counts)); + rewrite1(m_eq_counts, monomial(src.r), m_m_counts, m); + j = join(j, eq); + change = true; + goto init_loop; + } + } + } + while (false); + return change; } // check that src is a subset of dst, where dst_counts are precomputed bool ac_plugin::is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src) { SASSERT(&dst_counts != &src_counts); - src_counts.reset(); - for (auto n : src) { - unsigned id = n->root_id(); - unsigned dst_count = dst_counts[id]; - if (dst_count == 0) - return false; - else if (src_counts[id] >= dst_count) - return false; - else - src_counts.inc(id, 1); - } - return true; + init_ref_counts(src, src_counts); + return all_of(src_counts, [&](unsigned idx) { return src_counts[idx] <= dst_counts[idx]; }); } // check that dst is a superset of src, where src_counts are precomputed @@ -713,6 +876,23 @@ namespace euf { return all_of(src_counts, [&](unsigned idx) { return src_counts[idx] <= dst_counts[idx]; }); } + void ac_plugin::index_new_r(unsigned eq, monomial_t const& old_r, monomial_t const& new_r) { + for (auto n : old_r) + n->root->n->mark1(); + for (auto n : new_r) + if (!n->root->n->is_marked1()) { + n->root->eqs.push_back(eq); + m_node_trail.push_back(n->root); + n->root->n->mark1(); + push_undo(is_add_eq_index); + } + for (auto n : old_r) + n->root->n->unmark1(); + for (auto n : new_r) + n->root->n->unmark1(); + } + + void ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) { if (src_eq == dst_eq) return; @@ -733,12 +913,20 @@ namespace euf { // src_r contains E // compute BE, initialize dst_ids, dst_counts + bool overlap = false; for (auto n : monomial(dst.l)) { unsigned id = n->root_id(); + m_dst_l_counts.inc(id, 1); if (m_src_l_counts[id] < m_dst_l_counts[id]) m_src_r.push_back(n); - m_dst_l_counts.inc(id, 1); + overlap |= m_src_l_counts[id] > 0; } + + if (!overlap) { + m_src_r.shrink(src_r_size); + return; + } + // compute CD for (auto n : monomial(src.l)) { unsigned id = n->root_id(); @@ -753,16 +941,18 @@ namespace euf { return; } - TRACE("plugin", for (auto n : m_src_r) tout << g.bpp(n->n) << " "; tout << "== "; for (auto n : m_dst_r) tout << g.bpp(n->n) << " "; tout << "\n";); - + TRACE("plugin", tout << m_pp(*this, m_src_r) << "== " << m_pp(*this, m_dst_r) << "\n";); justification j = justify_rewrite(src_eq, dst_eq); + reduce(m_dst_r, j); + reduce(m_src_r, j); if (m_src_r.size() == 1 && m_dst_r.size() == 1) push_merge(m_src_r[0]->n, m_dst_r[0]->n, j); else - init_equation(eq(to_monomial(nullptr, m_src_r), to_monomial(nullptr, m_dst_r), j)); + init_equation(eq(to_monomial(m_src_r), to_monomial(m_dst_r), j)); - m_src_r.shrink(src_r_size); + m_src_r.reset(); + m_src_r.append(monomial(src.r).m_nodes); } bool ac_plugin::are_equal(monomial_t& a, monomial_t& b) { @@ -804,52 +994,41 @@ namespace euf { m_monomial_table.reset(); for (auto const& s1 : m_shared) { shared s2; - if (m_monomial_table.find(s1.m, s2)) { - if (s2.n->get_root() != s1.n->get_root()) - push_merge(s1.n, s2.n, justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(s1.j), m_dep_manager.mk_leaf(s2.j)))); - } - else + TRACE("plugin", tout << "shared " << m_pp(*this, monomial(s1.m)) << "\n"); + if (!m_monomial_table.find(s1.m, s2)) m_monomial_table.insert(s1.m, s1); + else if (s2.n->get_root() != s1.n->get_root()) { + TRACE("plugin", tout << m_pp(*this, monomial(s1.m)) << " == " << m_pp(*this, monomial(s2.m)) << "\n"); + push_merge(s1.n, s2.n, justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(s1.j), m_dep_manager.mk_leaf(s2.j)))); + } } } void ac_plugin::simplify_shared(unsigned idx, shared s) { - bool change = true; - while (change) { - change = false; - auto & m = monomial(s.m); - init_ref_counts(m, m_dst_l_counts); - init_subset_iterator(UINT_MAX, m); - for (auto eq : m_eq_occurs) { - auto& src = m_eqs[eq]; - if (!can_be_subset(monomial(src.l), m)) - continue; - if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) - continue; - m_update_shared_trail.push_back({ idx, s }); - push_undo(is_update_shared); - unsigned new_m = rewrite(monomial(src.r), m); - m_shared[idx].m = new_m; - m_shared[idx].j = justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(s.j), justify_equation(eq))); - - // update shared occurrences for members of the new monomial that are not already in the old monomial. - for (auto n : monomial(s.m)) - n->root->n->mark1(); - for (auto n : monomial(new_m)) - if (!n->root->n->is_marked1()) { - n->root->shared.push_back(s.m); - m_shared_todo.insert(s.m); - m_node_trail.push_back(n->root); - push_undo(is_add_shared); - } - for (auto n : monomial(s.m)) - n->root->n->unmark1(); + auto j = s.j; + auto old_m = s.m; + ptr_vector m1(monomial(old_m).m_nodes); + TRACE("plugin", tout << "simplify " << m_pp(*this, monomial(old_m)) << "\n"); + if (!reduce(m1, j)) + return; - s = m_shared[idx]; - change = true; - break; + auto new_m = to_monomial(m1); + // update shared occurrences for members of the new monomial that are not already in the old monomial. + for (auto n : monomial(old_m)) + n->root->n->mark1(); + for (auto n : m1) + if (!n->root->n->is_marked1()) { + n->root->shared.push_back(idx); + m_shared_todo.insert(idx); + m_node_trail.push_back(n->root); + push_undo(is_add_shared_index); } - } + for (auto n : monomial(old_m)) + n->root->n->unmark1(); + m_update_shared_trail.push_back({ idx, s }); + push_undo(is_update_shared); + m_shared[idx].m = new_m; + m_shared[idx].j = j; } justification ac_plugin::justify_rewrite(unsigned eq1, unsigned eq2) { @@ -871,4 +1050,9 @@ namespace euf { j = m_dep_manager.mk_join(j, m_dep_manager.mk_leaf(justification::equality(n->root->n, n->n))); return j; } + + justification ac_plugin::join(justification j, unsigned eq) { + return justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(j), justify_equation(eq))); + } + } diff --git a/src/ast/euf/euf_ac_plugin.h b/src/ast/euf/euf_ac_plugin.h index 5fd1b272b..ea444b60c 100644 --- a/src/ast/euf/euf_ac_plugin.h +++ b/src/ast/euf/euf_ac_plugin.h @@ -101,6 +101,7 @@ namespace euf { bloom m_bloom; node* operator[](unsigned i) const { return m_nodes[i]; } unsigned size() const { return m_nodes.size(); } + void set(ptr_vector const& ns) { m_nodes.reset(); m_nodes.append(ns); m_bloom.m_tick = 0; } node* const* begin() const { return m_nodes.begin(); } node* const* end() const { return m_nodes.end(); } node* * begin() { return m_nodes.begin(); } @@ -136,10 +137,12 @@ namespace euf { } }; - unsigned m_fid; - unsigned m_op; + unsigned m_fid = 0; + unsigned m_op = null_decl_kind; + func_decl* m_decl = nullptr; vector m_eqs; ptr_vector m_nodes; + bool_vector m_shared_nodes; vector m_monomials; svector m_shared; justification::dependency_manager m_dep_manager; @@ -161,7 +164,8 @@ namespace euf { is_add_node, is_merge_node, is_update_eq, - is_add_shared, + is_add_shared_index, + is_add_eq_index, is_register_shared, is_update_shared }; @@ -177,19 +181,21 @@ namespace euf { node* mk_node(enode* n); void merge(node* r1, node* r2, justification j); - bool is_op(enode* n) const { auto d = n->get_decl(); return d && m_fid == d->get_family_id() && m_op == d->get_decl_kind(); } + bool is_op(enode* n) const { auto d = n->get_decl(); return d && (d == m_decl || (m_fid == d->get_family_id() && m_op == d->get_decl_kind())); } std::function m_undo_notify; void push_undo(undo_kind k); enode_vector m_todo; unsigned to_monomial(enode* n); unsigned to_monomial(enode* n, ptr_vector const& ms); + unsigned to_monomial(ptr_vector const& ms) { return to_monomial(nullptr, ms); } monomial_t const& monomial(unsigned i) const { return m_monomials[i]; } monomial_t& monomial(unsigned i) { return m_monomials[i]; } void sort(monomial_t& monomial); bool is_sorted(monomial_t const& monomial) const; uint64_t filter(monomial_t& m); bool can_be_subset(monomial_t& subset, monomial_t& superset); + bool can_be_subset(monomial_t& subset, ptr_vector const& m, bloom& b); bool are_equal(ptr_vector const& a, ptr_vector const& b); bool are_equal(monomial_t& a, monomial_t& b); bool backward_subsumes(unsigned src_eq, unsigned dst_eq); @@ -216,14 +222,15 @@ namespace euf { unsigned const* begin() const { return ids.begin(); } unsigned const* end() const { return ids.end(); } }; - ref_counts m_src_l_counts, m_dst_l_counts, m_src_r_counts, m_dst_r_counts, m_eq_counts; + ref_counts m_src_l_counts, m_dst_l_counts, m_src_r_counts, m_dst_r_counts, m_eq_counts, m_m_counts; unsigned_vector m_eq_occurs; bool_vector m_eq_seen; unsigned_vector const& forward_iterator(unsigned eq); unsigned_vector const& superpose_iterator(unsigned eq); unsigned_vector const& backward_iterator(unsigned eq); - void init_ref_counts(monomial_t const& monomial, ref_counts& counts); + void init_ref_counts(monomial_t const& monomial, ref_counts& counts) const; + void init_ref_counts(ptr_vector const& monomial, ref_counts& counts) const; void init_overlap_iterator(unsigned eq, monomial_t const& m); void init_subset_iterator(unsigned eq, monomial_t const& m); void compress_eq_occurs(unsigned eq_id); @@ -232,7 +239,9 @@ namespace euf { // check that dst is a superset of dst, where src_counts are precomputed bool is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst); - unsigned rewrite(monomial_t const& src_r, monomial_t const& dst_r); + void rewrite1(ref_counts const& src_l, monomial_t const& src_r, ref_counts& dst_r_counts, ptr_vector& dst_r); + bool reduce(ptr_vector& m, justification& j); + void index_new_r(unsigned eq, monomial_t const& old_r, monomial_t const& new_r); bool is_to_simplify(unsigned eq) const { return m_eqs[eq].status == eq_status::to_simplify; } bool is_processed(unsigned eq) const { return m_eqs[eq].status == eq_status::processed; } @@ -241,11 +250,17 @@ namespace euf { justification justify_rewrite(unsigned eq1, unsigned eq2); justification::dependency* justify_equation(unsigned eq); justification::dependency* justify_monomial(justification::dependency* d, monomial_t const& m); + justification join(justification j1, unsigned eq); + bool is_correct_ref_count(monomial_t const& m, ref_counts const& counts) const; + bool is_correct_ref_count(ptr_vector const& m, ref_counts const& counts) const; + + void register_shared(enode* n); void propagate_shared(); void simplify_shared(unsigned idx, shared s); - std::ostream& display_monomial(std::ostream& out, monomial_t const& m) const; + std::ostream& display_monomial(std::ostream& out, monomial_t const& m) const { return display_monomial(out, m.m_nodes); } + std::ostream& display_monomial(std::ostream& out, ptr_vector const& m) const; std::ostream& display_equation(std::ostream& out, eq const& e) const; std::ostream& display_status(std::ostream& out, eq_status s) const; @@ -254,17 +269,17 @@ namespace euf { ac_plugin(egraph& g, unsigned fid, unsigned op); + ac_plugin(egraph& g, func_decl* f); + ~ac_plugin() override {} unsigned get_id() const override { return m_fid; } void register_node(enode* n) override; - void register_shared(enode* n) override; + void merge_eh(enode* n1, enode* n2) override; - void merge_eh(enode* n1, enode* n2, justification j) override; - - void diseq_eh(enode* n1, enode* n2) override {} + void diseq_eh(enode* eq) override; void undo() override; @@ -282,8 +297,9 @@ namespace euf { }; struct m_pp { - ac_plugin& p; monomial_t const& m; - m_pp(ac_plugin& p, monomial_t const& m) : p(p), m(m) {} + ac_plugin& p; ptr_vector const& m; + m_pp(ac_plugin& p, monomial_t const& m) : p(p), m(m.m_nodes) {} + m_pp(ac_plugin& p, ptr_vector const& m) : p(p), m(m) {} std::ostream& display(std::ostream& out) const { return p.display_monomial(out, m); } }; }; diff --git a/src/ast/euf/euf_arith_plugin.cpp b/src/ast/euf/euf_arith_plugin.cpp index 3b134c640..26f8e0bd9 100644 --- a/src/ast/euf/euf_arith_plugin.cpp +++ b/src/ast/euf/euf_arith_plugin.cpp @@ -36,20 +36,9 @@ namespace euf { // no-op } - void arith_plugin::register_shared(enode* n) { - if (a.is_add(n->get_expr())) - m_add.register_shared(n); - if (a.is_mul(n->get_expr())) - m_mul.register_shared(n); - } - - void arith_plugin::merge_eh(enode* n1, enode* n2, justification j) { - m_add.merge_eh(n1, n2, j); - m_mul.merge_eh(n1, n2, j); - } - - void arith_plugin::diseq_eh(enode* n1, enode* n2) { - // no-op + void arith_plugin::merge_eh(enode* n1, enode* n2) { + m_add.merge_eh(n1, n2); + m_mul.merge_eh(n1, n2); } void arith_plugin::propagate() { diff --git a/src/ast/euf/euf_arith_plugin.h b/src/ast/euf/euf_arith_plugin.h index 893b94a74..7cca01f1c 100644 --- a/src/ast/euf/euf_arith_plugin.h +++ b/src/ast/euf/euf_arith_plugin.h @@ -39,11 +39,9 @@ namespace euf { void register_node(enode* n) override; - void register_shared(enode* n) override; + void merge_eh(enode* n1, enode* n2) override; - void merge_eh(enode* n1, enode* n2, justification j) override; - - void diseq_eh(enode* n1, enode* n2) override; + void diseq_eh(enode* eq) override {} void undo() override; diff --git a/src/ast/euf/euf_bv_plugin.cpp b/src/ast/euf/euf_bv_plugin.cpp index d096eb801..99bf8941b 100644 --- a/src/ast/euf/euf_bv_plugin.cpp +++ b/src/ast/euf/euf_bv_plugin.cpp @@ -103,7 +103,7 @@ namespace euf { return mk(e, 0, nullptr); } - void bv_plugin::merge_eh(enode* x, enode* y, justification j) { + void bv_plugin::merge_eh(enode* x, enode* y) { SASSERT(x == x->get_root()); SASSERT(x == y->get_root()); @@ -120,7 +120,7 @@ namespace euf { ys.reset(); xs.push_back(x); ys.push_back(y); - merge(xs, ys, j); + merge(xs, ys, justification::equality(x, y)); } // ensure p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ] diff --git a/src/ast/euf/euf_bv_plugin.h b/src/ast/euf/euf_bv_plugin.h index f7ae53f97..b8d62051e 100644 --- a/src/ast/euf/euf_bv_plugin.h +++ b/src/ast/euf/euf_bv_plugin.h @@ -86,11 +86,9 @@ namespace euf { void register_node(enode* n) override; - void register_shared(enode* n) override {} + void merge_eh(enode* n1, enode* n2) override; - void merge_eh(enode* n1, enode* n2, justification j) override; - - void diseq_eh(enode* n1, enode* n2) override {} + void diseq_eh(enode* eq) override {} void propagate() override {} diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index b0b098fff..caae38f0b 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -20,6 +20,7 @@ Notes: #include "ast/euf/euf_egraph.h" #include "ast/euf/euf_bv_plugin.h" #include "ast/euf/euf_arith_plugin.h" +#include "ast/euf/euf_specrel_plugin.h" #include "ast/ast_pp.h" #include "ast/ast_translation.h" @@ -115,7 +116,6 @@ namespace euf { n->mark_interpreted(); if (m_on_make) m_on_make(n); - register_node(n); if (num_args == 0) return n; @@ -134,22 +134,6 @@ namespace euf { return n; } - void egraph::register_node(enode* n) { - if (m_plugins.empty()) - return; - auto* p = get_plugin(n); - if (p) - p->register_node(n); - if (!n->is_equality()) { - for (auto* arg : enode_args(n)) { - auto* p_arg = get_plugin(arg); - if (p != p_arg) - p_arg->register_shared(arg); - } - } - - } - egraph::egraph(ast_manager& m) : m(m), m_table(m), m_tmp_app(2), m_exprs(m), m_eq_decls(m) { m_tmp_eq = enode::mk_tmp(m_region, 2); } @@ -162,6 +146,9 @@ namespace euf { } void egraph::add_plugins() { + if (!m_plugins.empty()) + return; + auto insert = [&](plugin* p) { m_plugins.reserve(p->get_id() + 1); m_plugins.set(p->get_id(), p); @@ -169,6 +156,7 @@ namespace euf { insert(alloc(bv_plugin, *this)); insert(alloc(arith_plugin, *this)); + insert(alloc(specrel_plugin, *this)); } void egraph::propagate_plugins() { @@ -182,14 +170,20 @@ namespace euf { m_new_th_eqs.push_back(th_eq(id, v1, v2, c, r)); m_updates.push_back(update_record(update_record::new_th_eq())); ++m_stats.m_num_th_eqs; + auto* p = get_plugin(id); + if (p) + p->merge_eh(c, r); } - void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq) { + void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq) { if (!th_propagates_diseqs(id)) return; TRACE("euf_verbose", tout << "eq: " << v1 << " != " << v2 << "\n";); - m_new_th_eqs.push_back(th_eq(id, v1, v2, eq)); + m_new_th_eqs.push_back(th_eq(id, v1, v2, eq->get_expr())); m_updates.push_back(update_record(update_record::new_th_eq())); + auto* p = get_plugin(id); + if (p) + p->diseq_eh(eq); ++m_stats.m_num_th_diseqs; } @@ -238,7 +232,7 @@ namespace euf { return; theory_var v1 = arg1->get_closest_th_var(id); theory_var v2 = arg2->get_closest_th_var(id); - add_th_diseq(id, v1, v2, n->get_expr()); + add_th_diseq(id, v1, v2, n); return; } for (auto const& p : euf::enode_th_vars(r1)) { @@ -246,7 +240,7 @@ namespace euf { continue; for (auto const& q : euf::enode_th_vars(r2)) if (p.get_id() == q.get_id()) - add_th_diseq(p.get_id(), p.get_var(), q.get_var(), n->get_expr()); + add_th_diseq(p.get_id(), p.get_var(), q.get_var(), n); } } @@ -266,7 +260,7 @@ namespace euf { n = n->get_root(); theory_var v2 = n->get_closest_th_var(id); if (v2 != null_theory_var) - add_th_diseq(id, v1, v2, p->get_expr()); + add_th_diseq(id, v1, v2, p); } } } @@ -285,6 +279,10 @@ namespace euf { theory_var w = n->get_th_var(id); enode* r = n->get_root(); + auto* p = get_plugin(id); + if (p) + p->register_node(n); + if (w == null_theory_var) { n->add_th_var(v, id, m_region); m_updates.push_back(update_record(n, id, update_record::add_th_var())); @@ -529,10 +527,7 @@ namespace euf { for (auto& cb : m_on_merge) cb(r2, r1); - - auto* p = get_plugin(r1); - if (p) - p->merge_eh(r2, r1, j); + } void egraph::remove_parents(enode* r) { diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 44bee42ea..b7c9ef537 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -215,8 +215,6 @@ namespace euf { // plugin related methods void push_plugin_undo(unsigned th_id) { m_updates.push_back(update_record(th_id, update_record::plugin_undo())); } void push_merge(enode* a, enode* b, justification j) { m_to_merge.push_back({ a, b, j }); } - plugin* get_plugin(enode* n) { return m_plugins.get(n->get_sort()->get_family_id(), nullptr); } - void register_node(enode* n); void propagate_plugins(); void add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r); @@ -259,6 +257,7 @@ namespace euf { egraph(ast_manager& m); ~egraph(); void add_plugins(); + plugin* get_plugin(family_id fid) const { return m_plugins.get(fid, nullptr); } enode* find(expr* f) const { return m_expr2enode.get(f->get_id(), nullptr); } enode* find(expr* f, unsigned n, enode* const* args); enode* mk(expr* f, unsigned generation, unsigned n, enode *const* args); @@ -302,7 +301,7 @@ namespace euf { where \c n is an enode and \c is_eq indicates whether the enode is an equality consequence. */ - void add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq); + void add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq); bool has_th_eq() const { return m_new_th_eqs_qhead < m_new_th_eqs.size(); } th_eq get_th_eq() const { return m_new_th_eqs[m_new_th_eqs_qhead]; } void next_th_eq() { force_push(); SASSERT(m_new_th_eqs_qhead < m_new_th_eqs.size()); m_new_th_eqs_qhead++; } diff --git a/src/ast/euf/euf_enode.cpp b/src/ast/euf/euf_enode.cpp index 08df9f493..335c8f3e9 100644 --- a/src/ast/euf/euf_enode.cpp +++ b/src/ast/euf/euf_enode.cpp @@ -93,6 +93,17 @@ namespace euf { return null_theory_var; } + enode* enode::get_closest_th_node(theory_id id) { + enode* n = this; + while (n) { + theory_var v = n->get_th_var(id); + if (v != null_theory_var) + return n; + n = n->m_target; + } + return nullptr; + } + bool enode::acyclic() const { enode const* n = this; enode const* p = this; diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index 4fc682f65..e9d917e9e 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -229,6 +229,7 @@ namespace euf { theory_var get_th_var(theory_id id) const { return m_th_vars.find(id); } theory_var get_closest_th_var(theory_id id) const; + enode* get_closest_th_node(theory_id id); bool is_attached_to(theory_id id) const { return get_th_var(id) != null_theory_var; } bool has_th_vars() const { return !m_th_vars.empty(); } bool has_one_th_var() const { return !m_th_vars.empty() && !m_th_vars.get_next();} diff --git a/src/ast/euf/euf_justification.cpp b/src/ast/euf/euf_justification.cpp new file mode 100644 index 000000000..22b52ea84 --- /dev/null +++ b/src/ast/euf/euf_justification.cpp @@ -0,0 +1,54 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + euf_justification.cpp + +Abstract: + + justification structure for euf + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-23 + +--*/ + + +#include "ast/euf/euf_justification.h" +#include "ast/euf/euf_enode.h" + +namespace euf { + + + std::ostream& justification::display(std::ostream& out, std::function const& ext) const { + switch (m_kind) { + case kind_t::external_t: + if (ext) + ext(out, m_external); + else + out << "external"; + return out; + case kind_t::axiom_t: + return out << "axiom"; + case kind_t::congruence_t: + return out << "congruence"; + case kind_t::dependent_t: { + vector js; + out << "dependent"; + for (auto const& j : dependency_manager::s_linearize(m_dependency, js)) + j.display(out << " ", ext); + return out; + } + case kind_t::equality_t: + return out << "equality #" << m_n1->get_id() << " == #" << m_n2->get_id(); + + default: + UNREACHABLE(); + return out; + } + return out; + } + +} diff --git a/src/ast/euf/euf_justification.h b/src/ast/euf/euf_justification.h index b46b14067..f9d3b3637 100644 --- a/src/ast/euf/euf_justification.h +++ b/src/ast/euf/euf_justification.h @@ -118,31 +118,8 @@ namespace euf { } } - std::ostream& display(std::ostream& out, std::function const& ext) const { - switch (m_kind) { - case kind_t::external_t: - if (ext) - ext(out, m_external); - else - out << "external"; - return out; - case kind_t::axiom_t: - return out << "axiom"; - case kind_t::congruence_t: - return out << "congruence"; - case kind_t::dependent_t: { - vector js; - out << "dependent"; - for (auto const& j : dependency_manager::s_linearize(m_dependency, js)) - j.display(out << " ", ext); - return out; - } - default: - UNREACHABLE(); - return out; - } - return out; - } + std::ostream& display(std::ostream& out, std::function const& ext) const; + }; inline std::ostream& operator<<(std::ostream& out, justification const& j) { diff --git a/src/ast/euf/euf_plugin.h b/src/ast/euf/euf_plugin.h index f36ab38f3..ff49d6c40 100644 --- a/src/ast/euf/euf_plugin.h +++ b/src/ast/euf/euf_plugin.h @@ -43,12 +43,10 @@ namespace euf { virtual unsigned get_id() const = 0; virtual void register_node(enode* n) = 0; - - virtual void register_shared(enode* n) = 0; - virtual void merge_eh(enode* n1, enode* n2, justification j) = 0; + virtual void merge_eh(enode* n1, enode* n2) = 0; - virtual void diseq_eh(enode* n1, enode* n2) = 0; + virtual void diseq_eh(enode* eq) {}; virtual void propagate() = 0; diff --git a/src/ast/euf/euf_specrel_plugin.cpp b/src/ast/euf/euf_specrel_plugin.cpp new file mode 100644 index 000000000..3220a24e6 --- /dev/null +++ b/src/ast/euf/euf_specrel_plugin.cpp @@ -0,0 +1,71 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_specrel_plugin.cpp + +Abstract: + + plugin structure for specrel + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +--*/ + +#include "ast/euf/euf_specrel_plugin.h" +#include "ast/euf/euf_egraph.h" +#include + +namespace euf { + + specrel_plugin::specrel_plugin(egraph& g) : + plugin(g), + sp(g.get_manager()) { + } + + void specrel_plugin::register_node(enode* n) { + func_decl* f = n->get_decl(); + if (!f) + return; + if (!sp.is_ac(f)) + return; + ac_plugin* p = nullptr; + if (!m_decl2plugin.find(f, p)) { + p = alloc(ac_plugin, g, f); + m_decl2plugin.insert(f, p); + m_plugins.push_back(p); + std::function undo_op = [&]() { m_undo.push_back(p); }; + p->set_undo(undo_op); + } + } + + void specrel_plugin::merge_eh(enode* n1, enode* n2) { + for (auto * p : m_plugins) + p->merge_eh(n1, n2); + } + + void specrel_plugin::diseq_eh(enode* eq) { + for (auto* p : m_plugins) + p->diseq_eh(eq); + } + + void specrel_plugin::propagate() { + for (auto * p : m_plugins) + p->propagate(); + } + + void specrel_plugin::undo() { + auto p = m_undo.back(); + m_undo.pop_back(); + p->undo(); + } + + std::ostream& specrel_plugin::display(std::ostream& out) const { + for (auto * p : m_plugins) + p->display(out); + return out; + } +} \ No newline at end of file diff --git a/src/ast/euf/euf_specrel_plugin.h b/src/ast/euf/euf_specrel_plugin.h new file mode 100644 index 000000000..228bb5e15 --- /dev/null +++ b/src/ast/euf/euf_specrel_plugin.h @@ -0,0 +1,56 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_specrel_plugin.h + +Abstract: + + plugin structure for specrel functions + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +--*/ + +#pragma once + +#include +#include "util/scoped_ptr_vector.h" +#include "ast/special_relations_decl_plugin.h" +#include "ast/euf/euf_plugin.h" +#include "ast/euf/euf_ac_plugin.h" + +namespace euf { + + class specrel_plugin : public plugin { + scoped_ptr_vector m_plugins; + ptr_vector m_undo; + obj_map m_decl2plugin; + special_relations_util sp; + + public: + + specrel_plugin(egraph& g); + + ~specrel_plugin() override {} + + unsigned get_id() const override { return sp.get_family_id(); } + + void register_node(enode* n) override; + + void merge_eh(enode* n1, enode* n2) override; + + void diseq_eh(enode* eq) override; + + void undo() override; + + void propagate() override; + + std::ostream& display(std::ostream& out) const override; + + }; + +} diff --git a/src/ast/special_relations_decl_plugin.cpp b/src/ast/special_relations_decl_plugin.cpp index 24a756bf7..bbfe819d4 100644 --- a/src/ast/special_relations_decl_plugin.cpp +++ b/src/ast/special_relations_decl_plugin.cpp @@ -26,7 +26,8 @@ special_relations_decl_plugin::special_relations_decl_plugin(): m_po("partial-order"), m_plo("piecewise-linear-order"), m_to("tree-order"), - m_tc("transitive-closure") + m_tc("transitive-closure"), + m_ac("ac-op") {} func_decl * special_relations_decl_plugin::mk_func_decl( @@ -41,24 +42,53 @@ func_decl * special_relations_decl_plugin::mk_func_decl( m_manager->raise_exception("argument sort missmatch. The two arguments should have the same sort"); return nullptr; } + if (!range && k == OP_SPECIAL_RELATION_AC) + range = domain[0]; + if (!range) { range = m_manager->mk_bool_sort(); } - if (!m_manager->is_bool(range)) { - m_manager->raise_exception("range type is expected to be Boolean for special relations"); - } + auto check_bool_range = [&]() { + if (!m_manager->is_bool(range)) + m_manager->raise_exception("range type is expected to be Boolean for special relations"); + }; + + m_has_special_relation = true; func_decl_info info(m_family_id, k, num_parameters, parameters); symbol name; switch(k) { - case OP_SPECIAL_RELATION_PO: name = m_po; break; - case OP_SPECIAL_RELATION_LO: name = m_lo; break; - case OP_SPECIAL_RELATION_PLO: name = m_plo; break; - case OP_SPECIAL_RELATION_TO: name = m_to; break; + case OP_SPECIAL_RELATION_PO: check_bool_range(); name = m_po; break; + case OP_SPECIAL_RELATION_LO: check_bool_range(); name = m_lo; break; + case OP_SPECIAL_RELATION_PLO: check_bool_range(); name = m_plo; break; + case OP_SPECIAL_RELATION_TO: check_bool_range(); name = m_to; break; + case OP_SPECIAL_RELATION_AC: { + if (range != domain[0]) + m_manager->raise_exception("AC operation should have the same range as domain type"); + name = m_ac; + if (num_parameters != 1 || !parameters[0].is_ast() || !is_func_decl(parameters[0].get_ast())) + m_manager->raise_exception("parameter to transitive closure should be a function declaration"); + func_decl* f = to_func_decl(parameters[0].get_ast()); + if (f->get_arity() != 2) + m_manager->raise_exception("ac function should be binary"); + if (f->get_domain(0) != f->get_domain(1)) + m_manager->raise_exception("ac function should have same domain"); + if (f->get_domain(0) != f->get_range()) + m_manager->raise_exception("ac function should have same domain and range"); + break; + } case OP_SPECIAL_RELATION_TC: + check_bool_range(); name = m_tc; if (num_parameters != 1 || !parameters[0].is_ast() || !is_func_decl(parameters[0].get_ast())) m_manager->raise_exception("parameter to transitive closure should be a function declaration"); + func_decl* f = to_func_decl(parameters[0].get_ast()); + if (f->get_arity() != 2) + m_manager->raise_exception("tc relation should be binary"); + if (f->get_domain(0) != f->get_domain(1)) + m_manager->raise_exception("tc relation should have same domain"); + if (!m_manager->is_bool(f->get_range())) + m_manager->raise_exception("tc relation should be Boolean"); break; } return m_manager->mk_func_decl(name, arity, domain, range, info); @@ -71,6 +101,7 @@ void special_relations_decl_plugin::get_op_names(svector & op_name op_names.push_back(builtin_name(m_plo.str(), OP_SPECIAL_RELATION_PLO)); op_names.push_back(builtin_name(m_to.str(), OP_SPECIAL_RELATION_TO)); op_names.push_back(builtin_name(m_tc.str(), OP_SPECIAL_RELATION_TC)); + op_names.push_back(builtin_name(m_ac.str(), OP_SPECIAL_RELATION_AC)); } } @@ -81,6 +112,7 @@ sr_property special_relations_util::get_property(func_decl* f) const { case OP_SPECIAL_RELATION_PLO: return sr_plo; case OP_SPECIAL_RELATION_TO: return sr_to; case OP_SPECIAL_RELATION_TC: return sr_tc; + case OP_SPECIAL_RELATION_AC: return sr_none; default: UNREACHABLE(); return sr_po; diff --git a/src/ast/special_relations_decl_plugin.h b/src/ast/special_relations_decl_plugin.h index c422cbcdc..a65f98758 100644 --- a/src/ast/special_relations_decl_plugin.h +++ b/src/ast/special_relations_decl_plugin.h @@ -16,6 +16,8 @@ Author: Revision History: + 2023-11-27: Added ac-op for E-graph plugin + --*/ #pragma once @@ -28,6 +30,7 @@ enum special_relations_op_kind { OP_SPECIAL_RELATION_PLO, OP_SPECIAL_RELATION_TO, OP_SPECIAL_RELATION_TC, + OP_SPECIAL_RELATION_AC, LAST_SPECIAL_RELATIONS_OP }; @@ -37,6 +40,7 @@ class special_relations_decl_plugin : public decl_plugin { symbol m_plo; symbol m_to; symbol m_tc; + symbol m_ac; bool m_has_special_relation = false; public: special_relations_decl_plugin(); @@ -86,13 +90,16 @@ class special_relations_util { public: special_relations_util(ast_manager& m) : m(m), m_fid(null_family_id) { } + family_id get_family_id() const { return fid(); } + bool has_special_relation() const { return static_cast(m.get_plugin(m.mk_family_id("specrels")))->has_special_relation(); } bool is_special_relation(func_decl* f) const { return f->get_family_id() == fid(); } - bool is_special_relation(app* e) const { return is_special_relation(e->get_decl()); } + bool is_special_relation(expr* e) const { return is_app(e) && is_special_relation(to_app(e)->get_decl()); } sr_property get_property(func_decl* f) const; sr_property get_property(app* e) const { return get_property(e->get_decl()); } func_decl* get_relation(func_decl* f) const { SASSERT(is_special_relation(f)); return to_func_decl(f->get_parameter(0).get_ast()); } + func_decl* get_relation(expr* e) const { SASSERT(is_special_relation(e)); return to_func_decl(to_app(e)->get_parameter(0).get_ast()); } func_decl* mk_to_decl(func_decl* f) { return mk_rel_decl(f, OP_SPECIAL_RELATION_TO); } func_decl* mk_po_decl(func_decl* f) { return mk_rel_decl(f, OP_SPECIAL_RELATION_PO); } @@ -105,12 +112,14 @@ public: bool is_plo(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_PLO); } bool is_to(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_TO); } bool is_tc(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_TC); } + bool is_ac(expr const* e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_AC); } bool is_lo(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_LO); } bool is_po(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_PO); } bool is_plo(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_PLO); } bool is_to(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_TO); } bool is_tc(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_TC); } + bool is_ac(func_decl const* e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_AC); } app * mk_lo (expr * arg1, expr * arg2) { return m.mk_app( fid(), OP_SPECIAL_RELATION_LO, arg1, arg2); } app * mk_po (expr * arg1, expr * arg2) { return m.mk_app( fid(), OP_SPECIAL_RELATION_PO, arg1, arg2); } diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 448e75d26..df0c0dea8 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -45,6 +45,7 @@ z3_add_component(sat_smt q_solver.cpp recfun_solver.cpp sat_th.cpp + specrel_solver.cpp tseitin_theory_checker.cpp user_solver.cpp COMPONENT_DEPENDENCIES diff --git a/src/sat/smt/dt_solver.cpp b/src/sat/smt/dt_solver.cpp index 56a224d36..0d4d01b8e 100644 --- a/src/sat/smt/dt_solver.cpp +++ b/src/sat/smt/dt_solver.cpp @@ -7,7 +7,7 @@ Module Name: Abstract: - Theory plugin for altegraic datatypes + Theory plugin for algebraic datatypes Author: diff --git a/src/sat/smt/dt_solver.h b/src/sat/smt/dt_solver.h index 4118d6780..428dbbde0 100644 --- a/src/sat/smt/dt_solver.h +++ b/src/sat/smt/dt_solver.h @@ -7,7 +7,7 @@ Module Name: Abstract: - Theory plugin for altegraic datatypes + Theory plugin for algebraic datatypes Author: diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index d03f82aee..215321d40 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -28,6 +28,7 @@ Author: #include "sat/smt/fpa_solver.h" #include "sat/smt/dt_solver.h" #include "sat/smt/recfun_solver.h" +#include "sat/smt/specrel_solver.h" namespace euf { @@ -130,6 +131,7 @@ namespace euf { arith_util arith(m); datatype_util dt(m); recfun::util rf(m); + special_relations_util sp(m); if (pb.get_family_id() == fid) ext = alloc(pb::solver, *this, fid); else if (bvu.get_family_id() == fid) @@ -144,7 +146,8 @@ namespace euf { ext = alloc(dt::solver, *this, fid); else if (rf.get_family_id() == fid) ext = alloc(recfun::solver, *this); - + else if (sp.get_family_id() == fid) + ext = alloc(specrel::solver, *this, fid); if (ext) add_solver(ext); else if (f) diff --git a/src/sat/smt/specrel_solver.cpp b/src/sat/smt/specrel_solver.cpp new file mode 100644 index 000000000..ab2c0366a --- /dev/null +++ b/src/sat/smt/specrel_solver.cpp @@ -0,0 +1,119 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + specrel_solver.h + +Abstract: + + Theory plugin for special relations + +Author: + + Nikolaj Bjorner (nbjorner) 2020-09-08 + +--*/ + +#include "sat/smt/specrel_solver.h" +#include "sat/smt/euf_solver.h" + +namespace euf { + class solver; +} + +namespace specrel { + + solver::solver(euf::solver& ctx, theory_id id) : + th_euf_solver(ctx, ctx.get_manager().get_family_name(id), id), + sp(m) + { + ctx.get_egraph().add_plugins(); + } + + solver::~solver() { + } + + void solver::asserted(sat::literal l) { + + } + + sat::check_result solver::check() { + return sat::check_result::CR_DONE; + } + + std::ostream& solver::display(std::ostream& out) const { + return out; + } + + void solver::collect_statistics(statistics& st) const { + } + + euf::th_solver* solver::clone(euf::solver& ctx) { + return alloc(solver, ctx, get_id()); + } + + void solver::new_eq_eh(euf::th_eq const& eq) { + TRACE("specrel", tout << "new-eq\n"); + if (eq.is_eq()) { + auto* p = ctx.get_egraph().get_plugin(sp.get_family_id()); + p->merge_eh(var2enode(eq.v1()), var2enode(eq.v2())); + TRACE("specrel", tout << eq.v1() << " " << eq.v2() << "\n"); + } + } + + void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + } + + bool solver::add_dep(euf::enode* n, top_sort& dep) { + return false; + } + + bool solver::include_func_interp(func_decl* f) const { + return false; + } + + sat::literal solver::internalize(expr* e, bool sign, bool root) { + if (!visit_rec(m, e, sign, root)) + return sat::null_literal; + auto lit = ctx.expr2literal(e); + if (sign) + lit.neg(); + return lit; + } + + void solver::internalize(expr* e) { + visit_rec(m, e, false, false); + } + + bool solver::visit(expr* e) { + if (visited(e)) + return true; + m_stack.push_back(sat::eframe(e)); + return false; + } + + bool solver::visited(expr* e) { + euf::enode* n = expr2enode(e); + return n && n->is_attached_to(get_id()); + } + + bool solver::post_visit(expr* term, bool sign, bool root) { + euf::enode* n = expr2enode(term); + SASSERT(!n || !n->is_attached_to(get_id())); + if (!n) + n = mk_enode(term); + SASSERT(!n->is_attached_to(get_id())); + mk_var(n); + TRACE("specrel", tout << ctx.bpp(n) << "\n"); + return true; + } + + euf::theory_var solver::mk_var(euf::enode* n) { + if (is_attached_to_var(n)) + return n->get_th_var(get_id()); + euf::theory_var r = th_euf_solver::mk_var(n); + ctx.attach_th_var(n, this, r); + return r; + } +} diff --git a/src/sat/smt/specrel_solver.h b/src/sat/smt/specrel_solver.h new file mode 100644 index 000000000..9ebb76916 --- /dev/null +++ b/src/sat/smt/specrel_solver.h @@ -0,0 +1,75 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + specrel_solver.h + +Abstract: + + Theory plugin for special relations + +Author: + + Nikolaj Bjorner (nbjorner) 2020-09-08 + +--*/ +#pragma once + +#include "sat/smt/sat_th.h" +#include "ast/special_relations_decl_plugin.h" + +namespace euf { + class solver; +} + +namespace specrel { + + class solver : public euf::th_euf_solver { + typedef euf::theory_var theory_var; + typedef euf::theory_id theory_id; + typedef euf::enode enode; + typedef euf::enode_pair enode_pair; + typedef euf::enode_pair_vector enode_pair_vector; + typedef sat::bool_var bool_var; + typedef sat::literal literal; + typedef sat::literal_vector literal_vector; + + special_relations_util sp; + + public: + solver(euf::solver& ctx, theory_id id); + ~solver() override; + + bool is_external(bool_var v) override { return false; } + void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) override {} + void asserted(literal l) override; + sat::check_result check() override; + + std::ostream& display(std::ostream& out) const override; + std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { return euf::th_explain::from_index(idx).display(out); } + std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { return display_justification(out, idx); } + void collect_statistics(statistics& st) const override; + euf::th_solver* clone(euf::solver& ctx) override; + void new_eq_eh(euf::th_eq const& eq) override; + bool unit_propagate() override { return false; } + void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; + bool add_dep(euf::enode* n, top_sort& dep) override; + bool include_func_interp(func_decl* f) const override; + sat::literal internalize(expr* e, bool sign, bool root) override; + void internalize(expr* e) override; + bool visit(expr* e) override; + bool visited(expr* e) override; + bool post_visit(expr* e, bool sign, bool root) override; + + euf::theory_var mk_var(euf::enode* n) override; + void apply_sort_cnstr(euf::enode* n, sort* s) override {} + bool is_shared(theory_var v) const override { return false; } + lbool get_phase(bool_var v) override { return l_true; } + bool enable_self_propagate() const override { return true; } + + void merge_eh(theory_var, theory_var, theory_var v1, theory_var v2); + void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) {} + void unmerge_eh(theory_var v1, theory_var v2) {} + }; +} diff --git a/src/test/euf_arith_plugin.cpp b/src/test/euf_arith_plugin.cpp index 217928977..1b85d5dd3 100644 --- a/src/test/euf_arith_plugin.cpp +++ b/src/test/euf_arith_plugin.cpp @@ -11,14 +11,18 @@ Copyright (c) 2023 Microsoft Corporation #include "ast/ast_pp.h" #include -static euf::enode* get_node(euf::egraph& g, expr* e) { +unsigned s_var = 0; + +static euf::enode* get_node(euf::egraph& g, arith_util& a, expr* e) { auto* n = g.find(e); if (n) return n; euf::enode_vector args; for (expr* arg : *to_app(e)) - args.push_back(get_node(g, arg)); - return g.mk(e, 0, args.size(), args.data()); + args.push_back(get_node(g, a, arg)); + n = g.mk(e, 0, args.size(), args.data()); + g.add_th_var(n, s_var++, a.get_family_id()); + return n; } // @@ -32,15 +36,15 @@ static void test1() { expr_ref x(m.mk_const("x", I), m); expr_ref y(m.mk_const("y", I), m); - auto* nx = get_node(g, a.mk_add(a.mk_add(y, y), a.mk_add(x, x))); - auto* ny = get_node(g, a.mk_add(a.mk_add(y, x), x)); + auto* nx = get_node(g, a, a.mk_add(a.mk_add(y, y), a.mk_add(x, x))); + auto* ny = get_node(g, a, a.mk_add(a.mk_add(y, x), x)); TRACE("plugin", tout << "before merge\n" << g << "\n"); g.merge(nx, ny, nullptr); TRACE("plugin", tout << "before propagate\n" << g << "\n"); g.propagate(); TRACE("plugin", tout << "after propagate\n" << g << "\n"); - g.merge(get_node(g, a.mk_add(x, a.mk_add(y, y))), get_node(g, a.mk_add(y, x)), nullptr); + g.merge(get_node(g, a, a.mk_add(x, a.mk_add(y, y))), get_node(g, a, a.mk_add(y, x)), nullptr); g.propagate(); std::cout << g << "\n"; } @@ -55,10 +59,10 @@ static void test2() { expr_ref x(m.mk_const("x", I), m); expr_ref y(m.mk_const("y", I), m); - auto* nxy = get_node(g, a.mk_add(x, y)); - auto* nyx = get_node(g, a.mk_add(y, x)); - auto* nx = get_node(g, x); - auto* ny = get_node(g, y); + auto* nxy = get_node(g, a, a.mk_add(x, y)); + auto* nyx = get_node(g, a, a.mk_add(y, x)); + auto* nx = get_node(g, a, x); + auto* ny = get_node(g, a, y); TRACE("plugin", tout << "before merge\n" << g << "\n"); g.merge(nxy, nx, nullptr); @@ -67,7 +71,7 @@ static void test2() { g.propagate(); TRACE("plugin", tout << "after propagate\n" << g << "\n"); SASSERT(nx->get_root() == ny->get_root()); - g.merge(get_node(g, a.mk_add(x, a.mk_add(y, y))), get_node(g, a.mk_add(y, x)), nullptr); + g.merge(get_node(g, a, a.mk_add(x, a.mk_add(y, y))), get_node(g, a, a.mk_add(y, x)), nullptr); g.propagate(); std::cout << g << "\n"; } @@ -82,22 +86,21 @@ static void test3() { expr_ref x(m.mk_const("x", I), m); expr_ref y(m.mk_const("y", I), m); - auto* nxyy = get_node(g, a.mk_add(a.mk_add(x, y), y)); - auto* nyxx = get_node(g, a.mk_add(a.mk_add(y, x), x)); - auto* nx = get_node(g, x); - auto* ny = get_node(g, y); + auto* nxyy = get_node(g, a, a.mk_add(a.mk_add(x, y), y)); + auto* nyxx = get_node(g, a, a.mk_add(a.mk_add(y, x), x)); + auto* nx = get_node(g, a, x); + auto* ny = get_node(g, a, y); g.merge(nxyy, nx, nullptr); g.merge(nyxx, ny, nullptr); TRACE("plugin", tout << "before propagate\n" << g << "\n"); g.propagate(); TRACE("plugin", tout << "after propagate\n" << g << "\n"); - SASSERT(nx->get_root() == ny->get_root()); std::cout << g << "\n"; } void tst_euf_arith_plugin() { enable_trace("plugin"); - test3(); test1(); test2(); + test3(); } diff --git a/src/test/euf_bv_plugin.cpp b/src/test/euf_bv_plugin.cpp index a0946682b..90b9e1511 100644 --- a/src/test/euf_bv_plugin.cpp +++ b/src/test/euf_bv_plugin.cpp @@ -11,14 +11,17 @@ Copyright (c) 2023 Microsoft Corporation #include "ast/ast_pp.h" #include -static euf::enode* get_node(euf::egraph& g, expr* e) { +static unsigned s_var = 0; +static euf::enode* get_node(euf::egraph& g, bv_util& b, expr* e) { auto* n = g.find(e); if (n) return n; euf::enode_vector args; for (expr* arg : *to_app(e)) - args.push_back(get_node(g, arg)); - return g.mk(e, 0, args.size(), args.data()); + args.push_back(get_node(g, b, arg)); + n = g.mk(e, 0, args.size(), args.data()); + g.add_th_var(n, s_var++, b.get_family_id()); + return n; } // align slices, and propagate extensionality @@ -40,8 +43,8 @@ static void test1() { expr_ref y1(bv.mk_extract(7, 0, y), m); expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m); expr_ref yy(bv.mk_concat(y1, bv.mk_concat(y2, y3)), m); - auto* nx = get_node(g, xx); - auto* ny = get_node(g, yy); + auto* nx = get_node(g, bv, xx); + auto* ny = get_node(g, bv, yy); TRACE("bv", tout << "before merge\n" << g << "\n"); g.merge(nx, ny, nullptr); TRACE("bv", tout << "before propagate\n" << g << "\n"); @@ -65,12 +68,12 @@ static void test2() { expr_ref x2(bv.mk_extract(15, 8, x), m); expr_ref x1(bv.mk_extract(7, 0, x), m); expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m); - g.merge(get_node(g, xx), get_node(g, bv.mk_numeral((1 << 27) + (1 << 17) + (1 << 3), 32)), nullptr); + g.merge(get_node(g, bv, xx), get_node(g, bv, bv.mk_numeral((1 << 27) + (1 << 17) + (1 << 3), 32)), nullptr); g.propagate(); - SASSERT(get_node(g, x1)->get_root()->interpreted()); - SASSERT(get_node(g, x2)->get_root()->interpreted()); - SASSERT(get_node(g, x3)->get_root()->interpreted()); - SASSERT(get_node(g, x)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x1)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x2)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x3)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x)->get_root()->interpreted()); } @@ -89,13 +92,13 @@ static void test3() { expr_ref x1(bv.mk_extract(7, 0, x), m); expr_ref xx(bv.mk_concat(bv.mk_concat(x1, x2), x3), m); expr_ref y(m.mk_const("y", u32), m); - g.merge(get_node(g, xx), get_node(g, y), nullptr); - g.merge(get_node(g, x1), get_node(g, bv.mk_numeral(2, 8)), nullptr); - g.merge(get_node(g, x2), get_node(g, bv.mk_numeral(8, 8)), nullptr); + g.merge(get_node(g, bv, xx), get_node(g, bv, y), nullptr); + g.merge(get_node(g, bv, x1), get_node(g, bv, bv.mk_numeral(2, 8)), nullptr); + g.merge(get_node(g, bv, x2), get_node(g, bv, bv.mk_numeral(8, 8)), nullptr); g.propagate(); - SASSERT(get_node(g, bv.mk_concat(x1, x2))->get_root()->interpreted()); - SASSERT(get_node(g, x1)->get_root()->interpreted()); - SASSERT(get_node(g, x2)->get_root()->interpreted()); + SASSERT(get_node(g, bv, bv.mk_concat(x1, x2))->get_root()->interpreted()); + SASSERT(get_node(g, bv, x1)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x2)->get_root()->interpreted()); } // propagate extract up @@ -114,11 +117,11 @@ static void test4() { expr_ref y(m.mk_const("y", u16), m); expr_ref x1(bv.mk_extract(15, 8, x), m); expr_ref x2(bv.mk_extract(23, 16, x), m); - g.merge(get_node(g, bv.mk_concat(a, x2)), get_node(g, y), nullptr); - g.merge(get_node(g, x1), get_node(g, a), nullptr); + g.merge(get_node(g, bv, bv.mk_concat(a, x2)), get_node(g, bv, y), nullptr); + g.merge(get_node(g, bv, x1), get_node(g, bv, a), nullptr); g.propagate(); TRACE("bv", tout << g << "\n"); - SASSERT(get_node(g, bv.mk_extract(23, 8, x))->get_root() == get_node(g, y)->get_root()); + SASSERT(get_node(g, bv, bv.mk_extract(23, 8, x))->get_root() == get_node(g, bv, y)->get_root()); } // iterative slicing @@ -133,8 +136,8 @@ static void test5() { expr_ref x(m.mk_const("x", u32), m); expr_ref x1(bv.mk_extract(31, 4, x), m); expr_ref x2(bv.mk_extract(27, 0, x), m); - auto* nx = get_node(g, x1); - auto* ny = get_node(g, x2); + auto* nx = get_node(g, bv, x1); + auto* ny = get_node(g, bv, x2); TRACE("bv", tout << "before merge\n" << g << "\n"); g.merge(nx, ny, nullptr); TRACE("bv", tout << "before propagate\n" << g << "\n"); @@ -155,8 +158,8 @@ static void test6() { expr_ref x(m.mk_const("x", u32), m); expr_ref x1(bv.mk_extract(31, 3, x), m); expr_ref x2(bv.mk_extract(28, 0, x), m); - auto* nx = get_node(g, x1); - auto* ny = get_node(g, x2); + auto* nx = get_node(g, bv, x1); + auto* ny = get_node(g, bv, x2); TRACE("bv", tout << "before merge\n" << g << "\n"); g.merge(nx, ny, nullptr); TRACE("bv", tout << "before propagate\n" << g << "\n");