diff --git a/src/ast/euf/CMakeLists.txt b/src/ast/euf/CMakeLists.txt index 8d3fa2e74..aa71e7fba 100644 --- a/src/ast/euf/CMakeLists.txt +++ b/src/ast/euf/CMakeLists.txt @@ -1,8 +1,14 @@ z3_add_component(euf SOURCES + euf_ac_plugin.cpp + euf_arith_plugin.cpp + euf_bv_plugin.cpp + euf_egraph.cpp euf_enode.cpp euf_etable.cpp - euf_egraph.cpp + euf_justification.cpp + euf_plugin.cpp + euf_specrel_plugin.cpp COMPONENT_DEPENDENCIES ast util diff --git a/src/ast/euf/euf_ac_plugin.cpp b/src/ast/euf/euf_ac_plugin.cpp new file mode 100644 index 000000000..5b4b1df66 --- /dev/null +++ b/src/ast/euf/euf_ac_plugin.cpp @@ -0,0 +1,1058 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_ac_plugin.cpp + +Abstract: + + plugin structure for ac functions + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +Completion modulo AC + + E set of eqs + pick critical pair xy = z by j1 xu = v by j2 in E + Add new equation zu = xyu = vy by j1, j2 + + + Notes: + - Some equalities come from shared terms, some do not. + + - V2 can use multiplicities of elements to handle larger domains. + - e.g. 3x + 100000y + +More notes: + + Justifications for new equations are joined (requires extension to egraph/justification) + + Process new merges so use list is updated + Justifications for processed merges are recorded + + Updated equations are recorded for restoration on backtracking + + Keep track of foreign / shared occurrences of AC functions. + - use register_shared to accumulate shared occurrences. + + Shared occurrences are rewritten modulo completion. + When equal to a different shared occurrence, propagate equality. + + - Elimination of redundant rules. + -> forward and backward subsumption + - apply forward subsumption when simplifying equality using processed + - apply backward subsumption when simplifying processed and to_simplify + + Rewrite rules are reoriented after a merge of enodes. 
+ It simulates creating a critical pair: + n -> n' + n + k = j + k + after merge + n' + k = j + k, could be that n' + k < j + k < n + k in term ordering because n' < j, m < n + +TODOs: + +- Efficiency of handling shared terms. + - The shared terms hash table is not incremental. + It could be made incremental by updating it on every merge similar to how the egraph handles it. +- V2 using multiplicities instead of repeated values in monomials. +- Squash trail updates when equations or monomials are modified within the same epoque. + - by an epoque counter that can be updated by the egraph class whenever there is a push/pop. + - store the epoque as a tick on equations and possibly when updating monomials on equations. + +--*/ + +#include "ast/euf/euf_ac_plugin.h" +#include "ast/euf/euf_egraph.h" +#include "ast/ast_pp.h" + +namespace euf { + + ac_plugin::ac_plugin(egraph& g, unsigned fid, unsigned op) : + plugin(g), m_fid(fid), m_op(op), + m_dep_manager(get_region()), + m_hash(*this), m_eq(*this), m_monomial_table(m_hash, m_eq) + { + g.set_th_propagates_diseqs(m_fid); + } + + ac_plugin::ac_plugin(egraph& g, func_decl* f) : + plugin(g), m_decl(f), m_fid(f->get_family_id()), + m_dep_manager(get_region()), + m_hash(*this), m_eq(*this), m_monomial_table(m_hash, m_eq) + { + if (m_fid != null_family_id) + g.set_th_propagates_diseqs(m_fid); + } + + void ac_plugin::register_node(enode* n) { + if (is_op(n)) + return; + for (auto arg : enode_args(n)) + if (is_op(arg)) + register_shared(arg); // TODO optimization to avoid registering shared terms twice + } + + void ac_plugin::register_shared(enode* n) { + if (m_shared_nodes.get(n->get_id(), false)) + return; + auto m = to_monomial(n); + auto const& ns = monomial(m); + for (auto arg : ns) { + arg->shared.push_back(m); + m_node_trail.push_back(arg); + push_undo(is_add_shared_index); + } + m_shared_nodes.setx(n->get_id(), true, false); + sort(monomial(m)); + m_shared_todo.insert(m_shared.size()); + m_shared.push_back({ n, m, 
justification::axiom() }); + push_undo(is_register_shared); + } + + void ac_plugin::undo() { + auto k = m_undo.back(); + m_undo.pop_back(); + switch (k) { + case is_add_eq: { + m_eqs.pop_back(); + break; + } + case is_add_node: { + auto* n = m_node_trail.back(); + m_node_trail.pop_back(); + m_nodes[n->n->get_id()] = nullptr; + n->~node(); + break; + } + case is_add_monomial: { + m_monomials.pop_back(); + break; + } + case is_merge_node: { + auto [other, old_shared, old_eqs] = m_merge_trail.back(); + auto* root = other->root; + std::swap(other->next, root->next); + root->shared.shrink(old_shared); + root->eqs.shrink(old_eqs); + m_merge_trail.pop_back(); + ++m_tick; + break; + } + case is_update_eq: { + auto const& [idx, eq] = m_update_eq_trail.back(); + m_eqs[idx] = eq; + m_update_eq_trail.pop_back(); + break; + } + case is_add_shared_index: { + auto n = m_node_trail.back(); + m_node_trail.pop_back(); + n->shared.pop_back(); + break; + } + case is_add_eq_index: { + auto n = m_node_trail.back(); + m_node_trail.pop_back(); + n->eqs.pop_back(); + break; + } + case is_register_shared: { + auto s = m_shared.back(); + m_shared_nodes[s.n->get_id()] = false; + m_shared.pop_back(); + break; + } + case is_update_shared: { + auto [id, s] = m_update_shared_trail.back(); + m_shared[id] = s; + m_update_shared_trail.pop_back(); + break; + } + default: + UNREACHABLE(); + } + } + + std::ostream& ac_plugin::display_monomial(std::ostream& out, ptr_vector const& m) const { + for (auto n : m) { + if (n->n->num_args() == 0) + out << mk_pp(n->n->get_expr(), g.get_manager()) << " "; + else + out << g.bpp(n->n) << " "; + } + return out; + } + + std::ostream& ac_plugin::display_equation(std::ostream& out, eq const& e) const { + display_status(out, e.status) << " "; + display_monomial(out, monomial(e.l)); + out << "== "; + display_monomial(out, monomial(e.r)); + return out; + } + + std::ostream& ac_plugin::display_status(std::ostream& out, eq_status s) const { + switch (s) { + case 
eq_status::is_dead: out << "d"; break; + case eq_status::processed: out << "p"; break; + case eq_status::to_simplify: out << "s"; break; + } + return out; + } + + std::ostream& ac_plugin::display(std::ostream& out) const { + unsigned i = 0; + for (auto const& eq : m_eqs) { + out << i << ": " << eq.l << " == " << eq.r << ": "; + display_equation(out, eq); + out << "\n"; + ++i; + } + i = 0; + for (auto m : m_monomials) { + out << i << ": "; + display_monomial(out, m); + out << "\n"; + ++i; + } + for (auto n : m_nodes) { + if (!n) + continue; + if (n->eqs.empty() && n->shared.empty()) + continue; + out << g.bpp(n->n) << " r: " << n->root_id() << " "; + if (!n->eqs.empty()) { + out << "eqs "; + for (auto l : n->eqs) + out << l << " "; + } + if (!n->shared.empty()) { + out << "shared "; + for (auto s : n->shared) + out << s << " "; + } + out << "\n"; + } + return out; + } + + void ac_plugin::merge_eh(enode* l, enode* r) { + if (l == r) + return; + auto j = justification::equality(l, r); + if (!is_op(l) && !is_op(r)) + merge(mk_node(l), mk_node(r), j); + else + init_equation(eq(to_monomial(l), to_monomial(r), j)); + } + + void ac_plugin::diseq_eh(enode* eq) { + SASSERT(g.get_manager().is_eq(eq->get_expr())); + enode* a = eq->get_arg(0), * b = eq->get_arg(1); + a = a->get_closest_th_node(m_fid); + b = b->get_closest_th_node(m_fid); + SASSERT(a && b); + register_shared(a); + register_shared(b); + } + + void ac_plugin::init_equation(eq const& e) { + m_eqs.push_back(e); + auto& eq = m_eqs.back(); + if (orient_equation(eq)) { + + unsigned eq_id = m_eqs.size() - 1; + + for (auto n : monomial(eq.l)) { + if (!n->root->n->is_marked1()) { + n->root->eqs.push_back(eq_id); + n->root->n->mark1(); + push_undo(is_add_eq_index); + m_node_trail.push_back(n->root); + } + } + + for (auto n : monomial(eq.r)) { + if (!n->root->n->is_marked1()) { + n->root->eqs.push_back(eq_id); + n->root->n->mark1(); + push_undo(is_add_eq_index); + m_node_trail.push_back(n->root); + } + } + + for (auto n : 
monomial(eq.l)) + n->root->n->unmark1(); + + for (auto n : monomial(eq.r)) + n->root->n->unmark1(); + + m_to_simplify_todo.insert(eq_id); + } + else + m_eqs.pop_back(); + } + + bool ac_plugin::orient_equation(eq& e) { + auto& ml = monomial(e.l); + auto& mr = monomial(e.r); + if (ml.size() > mr.size()) + return true; + if (ml.size() < mr.size()) { + std::swap(e.l, e.r); + return true; + } + else { + sort(ml); + sort(mr); + for (unsigned i = ml.size(); i-- > 0;) { + if (ml[i]->root_id() == mr[i]->root_id()) + continue; + if (ml[i]->root_id() < mr[i]->root_id()) + std::swap(e.l, e.r); + return true; + } + return false; + } + } + + void ac_plugin::sort(monomial_t& m) { + std::sort(m.begin(), m.end(), [&](node* a, node* b) { return a->root_id() < b->root_id(); }); + } + + bool ac_plugin::is_sorted(monomial_t const& m) const { + if (m.m_bloom.m_tick == m_tick) + return true; + for (unsigned i = m.size(); i-- > 1; ) + if (m[i - 1]->root_id() > m[i]->root_id()) + return false; + return true; + } + + uint64_t ac_plugin::filter(monomial_t& m) { + auto& bloom = m.m_bloom; + if (bloom.m_tick == m_tick) + return bloom.m_filter; + bloom.m_filter = 0; + for (auto n : m) + bloom.m_filter |= (1ull << (n->root_id() % 64ull)); + if (!is_sorted(m)) + sort(m); + bloom.m_tick = m_tick; + return bloom.m_filter; + } + + bool ac_plugin::can_be_subset(monomial_t& subset, monomial_t& superset) { + if (subset.size() > superset.size()) + return false; + auto f1 = filter(subset); + auto f2 = filter(superset); + return (f1 | f2) == f2; + } + + bool ac_plugin::can_be_subset(monomial_t& subset, ptr_vector const& m, bloom& bloom) { + if (subset.size() > m.size()) + return false; + if (bloom.m_tick != m_tick) { + bloom.m_filter = 0; + for (auto n : m) + bloom.m_filter |= (1ull << (n->root_id() % 64ull)); + bloom.m_tick = m_tick; + } + auto f2 = bloom.m_filter; + return (filter(subset) | f2) == f2; + } + + void ac_plugin::merge(node* root, node* other, justification j) { + for (auto n : equiv(other)) 
+ n->root = root; + m_merge_trail.push_back({ other, root->shared.size(), root->eqs.size() }); + for (auto eq_id : other->eqs) + set_status(eq_id, eq_status::to_simplify); + for (auto m : other->shared) + m_shared_todo.insert(m); + root->shared.append(other->shared); + root->eqs.append(other->eqs); + std::swap(root->next, other->next); + push_undo(is_merge_node); + ++m_tick; + } + + void ac_plugin::push_undo(undo_kind k) { + m_undo.push_back(k); + push_plugin_undo(get_id()); + m_undo_notify(); // tell main plugin to dispatch undo to this module. + } + + unsigned ac_plugin::to_monomial(enode* n) { + enode_vector& ns = m_todo; + ns.reset(); + ptr_vector m; + ns.push_back(n); + for (unsigned i = 0; i < ns.size(); ++i) { + n = ns[i]; + if (is_op(n)) + ns.append(n->num_args(), n->args()); + else + m.push_back(mk_node(n)); + } + return to_monomial(n, m); + } + + unsigned ac_plugin::to_monomial(enode* e, ptr_vector const& ms) { + unsigned id = m_monomials.size(); + m_monomials.push_back({ ms, bloom() }); + push_undo(is_add_monomial); + return id; + } + + ac_plugin::node* ac_plugin::node::mk(region& r, enode* n) { + auto* mem = r.allocate(sizeof(node)); + node* res = new (mem) node(); + res->n = n; + res->root = res; + res->next = res; + return res; + } + + ac_plugin::node* ac_plugin::mk_node(enode* n) { + unsigned id = n->get_id(); + if (m_nodes.size() > id && m_nodes[id]) + return m_nodes[id]; + auto* r = node::mk(get_region(), n); + push_undo(is_add_node); + m_nodes.setx(id, r, nullptr); + m_node_trail.push_back(r); + return r; + } + + void ac_plugin::propagate() { + while (true) { + loop_start: + unsigned eq_id = pick_next_eq(); + if (eq_id == UINT_MAX) + break; + + TRACE("plugin", tout << "propagate " << eq_id << ": " << eq_pp(*this, m_eqs[eq_id]) << "\n"); + + // simplify eq using processed + for (auto other_eq : backward_iterator(eq_id)) + TRACE("plugin", tout << "backward iterator " << eq_id << " vs " << other_eq << " " << is_processed(other_eq) << "\n"); + for 
(auto other_eq : backward_iterator(eq_id)) + if (is_processed(other_eq) && backward_simplify(eq_id, other_eq)) + goto loop_start; + + set_status(eq_id, eq_status::processed); + + // simplify processed using eq + for (auto other_eq : forward_iterator(eq_id)) + if (is_processed(other_eq)) + forward_simplify(eq_id, other_eq); + + // superpose, create new equations + for (auto other_eq : superpose_iterator(eq_id)) + if (is_processed(other_eq)) + superpose(eq_id, other_eq); + + // simplify to_simplify using eq + for (auto other_eq : forward_iterator(eq_id)) + if (is_to_simplify(other_eq)) + forward_simplify(eq_id, other_eq); + } + propagate_shared(); + + CTRACE("plugin", !m_shared.empty() || !m_eqs.empty(), display(tout)); + } + + unsigned ac_plugin::pick_next_eq() { + while (!m_to_simplify_todo.empty()) { + unsigned id = *m_to_simplify_todo.begin(); + if (id < m_eqs.size() && is_to_simplify(id)) + return id; + m_to_simplify_todo.remove(id); + } + return UINT_MAX; + } + + // reorient equations when the status of equations are set to to_simplify. + void ac_plugin::set_status(unsigned id, eq_status s) { + auto& eq = m_eqs[id]; + if (eq.status == eq_status::is_dead) + return; + if (s == eq_status::to_simplify && are_equal(monomial(eq.l), monomial(eq.r))) + s = eq_status::is_dead; + + if (eq.status != s) { + m_update_eq_trail.push_back({ id, eq }); + eq.status = s; + push_undo(is_update_eq); + } + switch (s) { + case eq_status::processed: + case eq_status::is_dead: + m_to_simplify_todo.remove(id); + break; + case eq_status::to_simplify: + m_to_simplify_todo.insert(id); + orient_equation(eq); + break; + } + } + + // + // superpose iterator enumerates all equations where lhs of eq have element in common. 
+ // + unsigned_vector const& ac_plugin::superpose_iterator(unsigned eq_id) { + auto const& eq = m_eqs[eq_id]; + m_src_r.reset(); + m_src_r.append(monomial(eq.r).m_nodes); + init_ref_counts(monomial(eq.l), m_src_l_counts); + init_overlap_iterator(eq_id, monomial(eq.l)); + return m_eq_occurs; + } + + // + // backward iterator allows simplification of eq + // The rhs of eq is a super-set of lhs of other eq. + // + unsigned_vector const& ac_plugin::backward_iterator(unsigned eq_id) { + auto const& eq = m_eqs[eq_id]; + init_ref_counts(monomial(eq.r), m_dst_r_counts); + init_ref_counts(monomial(eq.l), m_dst_l_counts); + m_dst_r.reset(); + m_dst_r.append(monomial(eq.r).m_nodes); + init_subset_iterator(eq_id, monomial(eq.r)); + return m_eq_occurs; + } + + void ac_plugin::init_overlap_iterator(unsigned eq_id, monomial_t const& m) { + m_eq_occurs.reset(); + for (auto n : m) + m_eq_occurs.append(n->root->eqs); + compress_eq_occurs(eq_id); + } + + // + // add all but one of the use lists. Identify the largest use list and skip it. + // The rationale is that [a, b] is a subset of [a, b, c, d, e] if + // it has at least two elements (otherwise it would not apply as a rewrite over AC). + // then one of the two elements has to be in the set of [a, b, c, d, e] \ { x } + // where x is an arbitrary value from a, b, c, d, e. Not a two-element watch list, but still. 
+ // + void ac_plugin::init_subset_iterator(unsigned eq_id, monomial_t const& m) { + unsigned max_use = 0; + node* max_n = nullptr; + bool has_two = false; + for (auto n : m) + if (n->root->eqs.size() >= max_use) + has_two |= max_n && (max_n != n->root), max_n = n->root, max_use = n->root->eqs.size(); + m_eq_occurs.reset(); + if (has_two) { + for (auto n : m) + if (n->root != max_n) + m_eq_occurs.append(n->root->eqs); + } + else { + for (auto n : m) { + m_eq_occurs.append(n->root->eqs); + break; + } + } + compress_eq_occurs(eq_id); + } + + // prune m_eq_occurs to single occurrences + void ac_plugin::compress_eq_occurs(unsigned eq_id) { + unsigned j = 0; + m_eq_seen.reserve(m_eqs.size() + 1, false); + for (unsigned i = 0; i < m_eq_occurs.size(); ++i) { + unsigned id = m_eq_occurs[i]; + if (m_eq_seen[id]) + continue; + if (id == eq_id) + continue; + m_eq_occurs[j++] = id; + m_eq_seen[id] = true; + } + m_eq_occurs.shrink(j); + for (auto id : m_eq_occurs) + m_eq_seen[id] = false; + } + + // + // forward iterator simplifies other eqs where their rhs is a superset of lhs of eq + // + unsigned_vector const& ac_plugin::forward_iterator(unsigned eq_id) { + auto& eq = m_eqs[eq_id]; + m_src_r.reset(); + m_src_r.append(monomial(eq.r).m_nodes); + init_ref_counts(monomial(eq.l), m_src_l_counts); + init_ref_counts(monomial(eq.r), m_src_r_counts); + unsigned min_r = UINT_MAX; + node* min_n = nullptr; + for (auto n : monomial(eq.l)) + if (n->root->eqs.size() < min_r) + min_n = n, min_r = n->root->eqs.size(); + // found node that occurs in fewest eqs + VERIFY(min_n); + return min_n->eqs; + } + + void ac_plugin::init_ref_counts(monomial_t const& monomial, ref_counts& counts) const { + init_ref_counts(monomial.m_nodes, counts); + } + + void ac_plugin::init_ref_counts(ptr_vector const& monomial, ref_counts& counts) const { + counts.reset(); + for (auto n : monomial) + counts.inc(n->root_id(), 1); + } + + bool ac_plugin::is_correct_ref_count(monomial_t const& m, ref_counts const& 
counts) const { + return is_correct_ref_count(m.m_nodes, counts); + } + + bool ac_plugin::is_correct_ref_count(ptr_vector const& m, ref_counts const& counts) const { + ref_counts check; + init_ref_counts(m, check); + return + all_of(counts, [&](unsigned i) { return check[i] == counts[i]; }) && + all_of(check, [&](unsigned i) { return check[i] == counts[i]; }); + } + + void ac_plugin::forward_simplify(unsigned src_eq, unsigned dst_eq) { + + if (src_eq == dst_eq) + return; + + // check that left src.l is a subset of dst.r + // dst = A -> BC + // src = B -> D + // post(dst) := A -> CD + auto& src = m_eqs[src_eq]; // src_r_counts, src_l_counts are initialized + auto& dst = m_eqs[dst_eq]; + + TRACE("plugin", tout << "forward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << "\n"); + + + if (forward_subsumes(src_eq, dst_eq)) { + TRACE("plugin", tout << "forward subsumed\n"); + set_status(dst_eq, eq_status::is_dead); + return; + } + + if (!can_be_subset(monomial(src.l), monomial(dst.r))) + return; + + + m_dst_r_counts.reset(); + + unsigned src_l_size = monomial(src.l).size(); + unsigned src_r_size = m_src_r.size(); + + SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts)); + // subtract src.l from dst.r if src.l is a subset of dst.r + // dst_rhs := dst_rhs - src_lhs + src_rhs + // := src_rhs + (dst_rhs - src_lhs) + // := src_rhs + elements from dst_rhs that are in excess of src_lhs + unsigned num_overlap = 0; + for (auto n : monomial(dst.r)) { + unsigned id = n->root_id(); + unsigned dst_count = m_dst_r_counts[id]; + unsigned src_count = m_src_l_counts[id]; + if (dst_count > src_count) { + m_src_r.push_back(n); + m_dst_r_counts.dec(id, 1); + } + else if (dst_count < src_count) { + m_src_r.shrink(src_r_size); + return; + } + else + ++num_overlap; + } + // The dst.r has to be a superset of src.l, otherwise simplification does not apply + if (num_overlap != src_l_size) { + m_src_r.shrink(src_r_size); + return; + } + auto j = justify_rewrite(src_eq, 
dst_eq); + reduce(m_src_r, j); + auto new_r = to_monomial(m_src_r); + index_new_r(dst_eq, monomial(m_eqs[dst_eq].r), monomial(new_r)); + m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] }); + m_eqs[dst_eq].r = new_r; + m_eqs[dst_eq].j = j; + push_undo(is_update_eq); + m_src_r.reset(); + m_src_r.append(monomial(src.r).m_nodes); + TRACE("plugin", tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n"); + } + + bool ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) { + if (src_eq == dst_eq) + return false; + + auto& src = m_eqs[src_eq]; + auto& dst = m_eqs[dst_eq]; // pre-computed dst_r_counts, dst_l_counts + // + // dst_ids, dst_count contain rhs of dst_eq + // + TRACE("plugin", tout << "backward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n"); + + if (backward_subsumes(src_eq, dst_eq)) { + TRACE("plugin", tout << "backward subsumed\n"); + set_status(dst_eq, eq_status::is_dead); + return true; + } + // check that src.l is a subset of dst.r + if (!can_be_subset(monomial(src.l), monomial(dst.r))) + return false; + if (!is_subset(m_dst_r_counts, m_src_l_counts, monomial(src.l))) { + TRACE("plugin", tout << "not subset\n"); + return false; + } + + SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); + + ptr_vector m(m_dst_r); + init_ref_counts(monomial(src.l), m_src_l_counts); + + rewrite1(m_src_l_counts, monomial(src.r), m_dst_r_counts, m); + auto j = justify_rewrite(src_eq, dst_eq); + reduce(m, j); + auto new_r = to_monomial(m); + index_new_r(dst_eq, monomial(m_eqs[dst_eq].r), monomial(new_r)); + m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] }); + m_eqs[dst_eq].r = new_r; + m_eqs[dst_eq].j = j; + TRACE("plugin", tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n"); + push_undo(is_update_eq); + return true; + } + + // dst_eq is fixed, dst_l_count is pre-computed for monomial(dst.l) + // dst_r_counts is pre-computed for 
monomial(dst.r). + // is dst_eq subsumed by src_eq? + bool ac_plugin::backward_subsumes(unsigned src_eq, unsigned dst_eq) { + auto& src = m_eqs[src_eq]; + auto& dst = m_eqs[dst_eq]; + SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts)); + SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); + if (!can_be_subset(monomial(src.l), monomial(dst.l))) + return false; + if (!can_be_subset(monomial(src.r), monomial(dst.r))) + return false; + unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size(); + if (size_diff != monomial(dst.r).size() - monomial(src.r).size()) + return false; + if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) + return false; + if (!is_subset(m_dst_r_counts, m_src_r_counts, monomial(src.r))) + return false; + SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts)); + SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts)); + // add difference betwen dst.l and src.l to both src.l, src.r + for (auto n : monomial(dst.l)) { + unsigned id = n->root_id(); + SASSERT(m_dst_l_counts[id] >= m_src_l_counts[id]); + unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id]; + if (diff > 0) { + m_src_l_counts.inc(id, diff); + m_src_r_counts.inc(id, diff); + } + } + // now dst.r and src.r should align and have the same elements. 
+ // since src.r is a subset of dst.r we iterate over dst.r + return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); + } + + // src_l_counts, src_r_counts are initialized for src.l, src.r + bool ac_plugin::forward_subsumes(unsigned src_eq, unsigned dst_eq) { + auto& src = m_eqs[src_eq]; + auto& dst = m_eqs[dst_eq]; + SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts)); + SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts)); + if (!can_be_subset(monomial(src.l), monomial(dst.l))) + return false; + if (!can_be_subset(monomial(src.r), monomial(dst.r))) + return false; + unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size(); + if (size_diff != monomial(dst.r).size() - monomial(src.r).size()) + return false; + if (!is_superset(m_src_l_counts, m_dst_l_counts, monomial(dst.l))) + return false; + if (!is_superset(m_src_r_counts, m_dst_r_counts, monomial(dst.r))) + return false; + SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts)); + SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); + for (auto n : monomial(src.l)) { + unsigned id = n->root_id(); + SASSERT(m_src_l_counts[id] <= m_dst_l_counts[id]); + unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id]; + if (diff == 0) + continue; + m_dst_l_counts.dec(id, diff); + if (m_dst_r_counts[id] < diff) + return false; + m_dst_r_counts.dec(id, diff); + } + + return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); + } + + void ac_plugin::rewrite1(ref_counts const& src_l, monomial_t const& src_r, ref_counts& dst_counts, ptr_vector& dst) { + // pre-condition: is-subset is invoked so that src_l is initialized. + // pre-condition: dst_count is also initialized. 
+ // remove from dst elements that are in src_l + // add elements from src_r + SASSERT(is_correct_ref_count(dst, dst_counts)); + SASSERT(&src_r.m_nodes != &dst); + unsigned sz = dst.size(), j = 0; + for (unsigned i = 0; i < sz; ++i) { + auto* n = dst[i]; + unsigned id = n->root_id(); + unsigned dst_count = dst_counts[id]; + unsigned src_count = src_l[id]; + SASSERT(dst_count > 0); + if (src_count == 0) + dst[j++] = n; + else if (src_count < dst_count) { + dst[j++] = n; + dst_counts.dec(id, 1); + } + } + dst.shrink(j); + dst.append(src_r.m_nodes); + } + + // rewrite monomial to normal form. + bool ac_plugin::reduce(ptr_vector& m, justification& j) { + bool change = false; + do { + init_loop: + if (m.size() == 1) + return change; + bloom b; + init_ref_counts(m, m_m_counts); + for (auto n : m) { + for (auto eq : n->root->eqs) { + if (!is_processed(eq)) + continue; + auto& src = m_eqs[eq]; + + if (!can_be_subset(monomial(src.l), m, b)) + continue; + if (!is_subset(m_m_counts, m_eq_counts, monomial(src.l))) + continue; + TRACE("plugin", display_equation(tout << "reduce ", src) << "\n"); + SASSERT(is_correct_ref_count(monomial(src.l), m_eq_counts)); + rewrite1(m_eq_counts, monomial(src.r), m_m_counts, m); + j = join(j, eq); + change = true; + goto init_loop; + } + } + } + while (false); + return change; + } + + // check that src is a subset of dst, where dst_counts are precomputed + bool ac_plugin::is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src) { + SASSERT(&dst_counts != &src_counts); + init_ref_counts(src, src_counts); + return all_of(src_counts, [&](unsigned idx) { return src_counts[idx] <= dst_counts[idx]; }); + } + + // check that dst is a superset of src, where src_counts are precomputed + bool ac_plugin::is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst) { + SASSERT(&dst_counts != &src_counts); + init_ref_counts(dst, dst_counts); + return all_of(src_counts, [&](unsigned idx) { return 
src_counts[idx] <= dst_counts[idx]; }); + } + + void ac_plugin::index_new_r(unsigned eq, monomial_t const& old_r, monomial_t const& new_r) { + for (auto n : old_r) + n->root->n->mark1(); + for (auto n : new_r) + if (!n->root->n->is_marked1()) { + n->root->eqs.push_back(eq); + m_node_trail.push_back(n->root); + n->root->n->mark1(); + push_undo(is_add_eq_index); + } + for (auto n : old_r) + n->root->n->unmark1(); + for (auto n : new_r) + n->root->n->unmark1(); + } + + + void ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) { + if (src_eq == dst_eq) + return; + auto& src = m_eqs[src_eq]; + auto& dst = m_eqs[dst_eq]; + + TRACE("plugin", tout << "superpose: "; display_equation(tout, src); tout << " "; display_equation(tout, dst); tout << "\n";); + // AB -> C, AD -> E => BE ~ CD + // m_src_ids, m_src_counts contains information about src (call it AD -> E) + m_dst_l_counts.reset(); + + m_dst_r.reset(); + m_dst_r.append(monomial(dst.r).m_nodes); + unsigned src_r_size = m_src_r.size(); + unsigned dst_r_size = m_dst_r.size(); + SASSERT(src_r_size == monomial(src.r).size()); + // dst_r contains C + // src_r contains E + + // compute BE, initialize dst_ids, dst_counts + bool overlap = false; + for (auto n : monomial(dst.l)) { + unsigned id = n->root_id(); + m_dst_l_counts.inc(id, 1); + if (m_src_l_counts[id] < m_dst_l_counts[id]) + m_src_r.push_back(n); + overlap |= m_src_l_counts[id] > 0; + } + + if (!overlap) { + m_src_r.shrink(src_r_size); + return; + } + + // compute CD + for (auto n : monomial(src.l)) { + unsigned id = n->root_id(); + if (m_dst_l_counts[id] > 0) + m_dst_l_counts.dec(id, 1); + else + m_dst_r.push_back(n); + } + + if (are_equal(m_src_r, m_dst_r)) { + m_src_r.shrink(src_r_size); + return; + } + + TRACE("plugin", tout << m_pp(*this, m_src_r) << "== " << m_pp(*this, m_dst_r) << "\n";); + + justification j = justify_rewrite(src_eq, dst_eq); + reduce(m_dst_r, j); + reduce(m_src_r, j); + if (m_src_r.size() == 1 && m_dst_r.size() == 1) + 
push_merge(m_src_r[0]->n, m_dst_r[0]->n, j); + else + init_equation(eq(to_monomial(m_src_r), to_monomial(m_dst_r), j)); + + m_src_r.reset(); + m_src_r.append(monomial(src.r).m_nodes); + } + + bool ac_plugin::are_equal(monomial_t& a, monomial_t& b) { + return filter(a) == filter(b) && are_equal(a.m_nodes, b.m_nodes); + } + + bool ac_plugin::are_equal(ptr_vector const& a, ptr_vector const& b) { + if (a.size() != b.size()) + return false; + m_eq_counts.reset(); + for (auto n : a) + m_eq_counts.inc(n->root_id(), 1); + + for (auto n : b) { + unsigned id = n->root_id(); + if (m_eq_counts[id] == 0) + return false; + m_eq_counts.dec(id, 1); + } + return true; + } + + // + // simple version based on propagating all shared + // todo: version touching only newly processed shared, and maintaining incremental data-structures. + // - hash-tables for shared monomials similar to the ones used for euf_table. + // the tables have to be updated (and re-sorted) whenever a child changes root. + // + + void ac_plugin::propagate_shared() { + if (m_shared_todo.empty()) + return; + while (!m_shared_todo.empty()) { + auto idx = *m_shared_todo.begin(); + m_shared_todo.remove(idx); + if (idx < m_shared.size()) + simplify_shared(idx, m_shared[idx]); + } + m_monomial_table.reset(); + for (auto const& s1 : m_shared) { + shared s2; + TRACE("plugin", tout << "shared " << m_pp(*this, monomial(s1.m)) << "\n"); + if (!m_monomial_table.find(s1.m, s2)) + m_monomial_table.insert(s1.m, s1); + else if (s2.n->get_root() != s1.n->get_root()) { + TRACE("plugin", tout << m_pp(*this, monomial(s1.m)) << " == " << m_pp(*this, monomial(s2.m)) << "\n"); + push_merge(s1.n, s2.n, justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(s1.j), m_dep_manager.mk_leaf(s2.j)))); + } + } + } + + void ac_plugin::simplify_shared(unsigned idx, shared s) { + auto j = s.j; + auto old_m = s.m; + ptr_vector m1(monomial(old_m).m_nodes); + TRACE("plugin", tout << "simplify " << m_pp(*this, monomial(old_m)) << "\n"); + 
if (!reduce(m1, j)) + return; + + auto new_m = to_monomial(m1); + // update shared occurrences for members of the new monomial that are not already in the old monomial. + for (auto n : monomial(old_m)) + n->root->n->mark1(); + for (auto n : m1) + if (!n->root->n->is_marked1()) { + n->root->shared.push_back(idx); + m_shared_todo.insert(idx); + m_node_trail.push_back(n->root); + push_undo(is_add_shared_index); + } + for (auto n : monomial(old_m)) + n->root->n->unmark1(); + m_update_shared_trail.push_back({ idx, s }); + push_undo(is_update_shared); + m_shared[idx].m = new_m; + m_shared[idx].j = j; + } + + justification ac_plugin::justify_rewrite(unsigned eq1, unsigned eq2) { + auto* j = m_dep_manager.mk_join(justify_equation(eq1), justify_equation(eq2)); + return justification::dependent(j); + } + + justification::dependency* ac_plugin::justify_equation(unsigned eq) { + auto const& e = m_eqs[eq]; + auto* j = m_dep_manager.mk_leaf(e.j); + j = justify_monomial(j, monomial(e.l)); + j = justify_monomial(j, monomial(e.r)); + return j; + } + + justification::dependency* ac_plugin::justify_monomial(justification::dependency* j, monomial_t const& m) { + for (auto n : m) + if (n->root->n != n->n) + j = m_dep_manager.mk_join(j, m_dep_manager.mk_leaf(justification::equality(n->root->n, n->n))); + return j; + } + + justification ac_plugin::join(justification j, unsigned eq) { + return justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(j), justify_equation(eq))); + } + +} diff --git a/src/ast/euf/euf_ac_plugin.h b/src/ast/euf/euf_ac_plugin.h new file mode 100644 index 000000000..ea444b60c --- /dev/null +++ b/src/ast/euf/euf_ac_plugin.h @@ -0,0 +1,309 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_ac_plugin.h + +Abstract: + + plugin structure for ac functions + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +ex: +xyz -> xy, then xyzz -> xy by repeated rewriting + +monomials = [0 |-> xyz, 1 |-> xy, 2 |-> xyzz] +parents(x) = [0, 1, 
2] +parents(z) = [0, 1] +for p in parents(xyzz): + p != xyzz + p' := simplify_using(xyzz, p) + if p != p': + repeat reduction using p := p' + +--*/ + +#pragma once + +#include +#include "ast/euf/euf_plugin.h" + +namespace euf { + + class ac_plugin : public plugin { + + // enode structure for AC equivalences + struct node { + enode* n; // associated enode + node* root; // path compressed root + node* next; // next in equivalence class + justification j; // justification for equality + node* target = nullptr; // justified next + unsigned_vector shared; // shared occurrences + unsigned_vector eqs; // equality occurrences + + unsigned root_id() const { return root->n->get_id(); } + ~node() {} + static node* mk(region& r, enode* n); + }; + + class equiv { + node& n; + public: + class iterator { + node* m_first; + node* m_last; + public: + iterator(node* n, node* m) : m_first(n), m_last(m) {} + node* operator*() { return m_first; } + iterator& operator++() { if (!m_last) m_last = m_first; m_first = m_first->next; return *this; } + iterator operator++(int) { iterator tmp = *this; ++*this; return tmp; } + bool operator==(iterator const& other) const { return m_last == other.m_last && m_first == other.m_first; } + bool operator!=(iterator const& other) const { return !(*this == other); } + }; + equiv(node& _n) :n(_n) {} + equiv(node* _n) :n(*_n) {} + iterator begin() const { return iterator(&n, nullptr); } + iterator end() const { return iterator(&n, &n); } + }; + + struct bloom { + uint64_t m_tick = 0; + uint64_t m_filter = 0; + }; + + enum eq_status { + processed, to_simplify, is_dead + }; + + // represent equalities added by merge_eh and by superposition + struct eq { + eq(unsigned l, unsigned r, justification j): + l(l), r(r), j(j) {} + unsigned l, r; // refer to monomials + eq_status status = to_simplify; + justification j; // justification for equality + }; + + // represent shared enodes that use the AC symbol. 
+ struct shared { + enode* n; // original shared enode + unsigned m; // monomial index + justification j; // justification for current simplification of monomial + }; + + struct monomial_t { + ptr_vector m_nodes; + bloom m_bloom; + node* operator[](unsigned i) const { return m_nodes[i]; } + unsigned size() const { return m_nodes.size(); } + void set(ptr_vector const& ns) { m_nodes.reset(); m_nodes.append(ns); m_bloom.m_tick = 0; } + node* const* begin() const { return m_nodes.begin(); } + node* const* end() const { return m_nodes.end(); } + node* * begin() { return m_nodes.begin(); } + node* * end() { return m_nodes.end(); } + }; + + + struct monomial_hash { + ac_plugin& p; + monomial_hash(ac_plugin& p) :p(p) {} + unsigned operator()(unsigned i) const { + unsigned h = 0; + auto& m = p.monomial(i); + if (!p.is_sorted(m)) + p.sort(m); + for (auto* n : m) + h = combine_hash(h, n->root_id()); + return h; + } + }; + + struct monomial_eq { + ac_plugin& p; + monomial_eq(ac_plugin& p) :p(p) {} + bool operator()(unsigned i, unsigned j) const { + auto const& m1 = p.monomial(i); + auto const& m2 = p.monomial(j); + if (m1.size() != m2.size()) return false; + for (unsigned k = 0; k < m1.size(); ++k) + if (m1[k]->root_id() != m2[k]->root_id()) + return false; + return true; + } + }; + + unsigned m_fid = 0; + unsigned m_op = null_decl_kind; + func_decl* m_decl = nullptr; + vector m_eqs; + ptr_vector m_nodes; + bool_vector m_shared_nodes; + vector m_monomials; + svector m_shared; + justification::dependency_manager m_dep_manager; + tracked_uint_set m_to_simplify_todo; + tracked_uint_set m_shared_todo; + uint64_t m_tick = 1; + + + + monomial_hash m_hash; + monomial_eq m_eq; + map m_monomial_table; + + + // backtrackable state + enum undo_kind { + is_add_eq, + is_add_monomial, + is_add_node, + is_merge_node, + is_update_eq, + is_add_shared_index, + is_add_eq_index, + is_register_shared, + is_update_shared + }; + svector m_undo; + ptr_vector m_node_trail; + + svector> 
m_update_shared_trail; + svector> m_merge_trail; + svector> m_update_eq_trail; + + + + node* mk_node(enode* n); + void merge(node* r1, node* r2, justification j); + + bool is_op(enode* n) const { auto d = n->get_decl(); return d && (d == m_decl || (m_fid == d->get_family_id() && m_op == d->get_decl_kind())); } + + std::function m_undo_notify; + void push_undo(undo_kind k); + enode_vector m_todo; + unsigned to_monomial(enode* n); + unsigned to_monomial(enode* n, ptr_vector const& ms); + unsigned to_monomial(ptr_vector const& ms) { return to_monomial(nullptr, ms); } + monomial_t const& monomial(unsigned i) const { return m_monomials[i]; } + monomial_t& monomial(unsigned i) { return m_monomials[i]; } + void sort(monomial_t& monomial); + bool is_sorted(monomial_t const& monomial) const; + uint64_t filter(monomial_t& m); + bool can_be_subset(monomial_t& subset, monomial_t& superset); + bool can_be_subset(monomial_t& subset, ptr_vector const& m, bloom& b); + bool are_equal(ptr_vector const& a, ptr_vector const& b); + bool are_equal(monomial_t& a, monomial_t& b); + bool backward_subsumes(unsigned src_eq, unsigned dst_eq); + bool forward_subsumes(unsigned src_eq, unsigned dst_eq); + + void init_equation(eq const& e); + bool orient_equation(eq& e); + void set_status(unsigned eq_id, eq_status s); + unsigned pick_next_eq(); + + void forward_simplify(unsigned eq_id, unsigned using_eq); + bool backward_simplify(unsigned eq_id, unsigned using_eq); + void superpose(unsigned src_eq, unsigned dst_eq); + + ptr_vector m_src_r, m_src_l, m_dst_r, m_dst_l; + + struct ref_counts { + unsigned_vector ids; + unsigned_vector counts; + void reset() { for (auto idx : ids) counts[idx] = 0; ids.reset(); } + unsigned operator[](unsigned idx) const { return counts.get(idx, 0); } + void inc(unsigned idx, unsigned amount) { counts.reserve(idx + 1, 0); ids.push_back(idx); counts[idx] += amount; } + void dec(unsigned idx, unsigned amount) { counts.reserve(idx + 1, 0); ids.push_back(idx); counts[idx] 
-= amount; } + unsigned const* begin() const { return ids.begin(); } + unsigned const* end() const { return ids.end(); } + }; + ref_counts m_src_l_counts, m_dst_l_counts, m_src_r_counts, m_dst_r_counts, m_eq_counts, m_m_counts; + unsigned_vector m_eq_occurs; + bool_vector m_eq_seen; + + unsigned_vector const& forward_iterator(unsigned eq); + unsigned_vector const& superpose_iterator(unsigned eq); + unsigned_vector const& backward_iterator(unsigned eq); + void init_ref_counts(monomial_t const& monomial, ref_counts& counts) const; + void init_ref_counts(ptr_vector const& monomial, ref_counts& counts) const; + void init_overlap_iterator(unsigned eq, monomial_t const& m); + void init_subset_iterator(unsigned eq, monomial_t const& m); + void compress_eq_occurs(unsigned eq_id); + // check that src is a subset of dst, where dst_counts are precomputed + bool is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src); + + // check that dst is a superset of src, where src_counts are precomputed + bool is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst); + void rewrite1(ref_counts const& src_l, monomial_t const& src_r, ref_counts& dst_r_counts, ptr_vector& dst_r); + bool reduce(ptr_vector& m, justification& j); + void index_new_r(unsigned eq, monomial_t const& old_r, monomial_t const& new_r); + + bool is_to_simplify(unsigned eq) const { return m_eqs[eq].status == eq_status::to_simplify; } + bool is_processed(unsigned eq) const { return m_eqs[eq].status == eq_status::processed; } + bool is_alive(unsigned eq) const { return m_eqs[eq].status != eq_status::is_dead; } + + justification justify_rewrite(unsigned eq1, unsigned eq2); + justification::dependency* justify_equation(unsigned eq); + justification::dependency* justify_monomial(justification::dependency* d, monomial_t const& m); + justification join(justification j1, unsigned eq); + + bool is_correct_ref_count(monomial_t const& m, ref_counts const& counts) 
const; + bool is_correct_ref_count(ptr_vector const& m, ref_counts const& counts) const; + + void register_shared(enode* n); + void propagate_shared(); + void simplify_shared(unsigned idx, shared s); + + std::ostream& display_monomial(std::ostream& out, monomial_t const& m) const { return display_monomial(out, m.m_nodes); } + std::ostream& display_monomial(std::ostream& out, ptr_vector const& m) const; + std::ostream& display_equation(std::ostream& out, eq const& e) const; + std::ostream& display_status(std::ostream& out, eq_status s) const; + + + public: + + ac_plugin(egraph& g, unsigned fid, unsigned op); + + ac_plugin(egraph& g, func_decl* f); + + ~ac_plugin() override {} + + unsigned get_id() const override { return m_fid; } + + void register_node(enode* n) override; + + void merge_eh(enode* n1, enode* n2) override; + + void diseq_eh(enode* eq) override; + + void undo() override; + + void propagate() override; + + std::ostream& display(std::ostream& out) const override; + + void set_undo(std::function u) { m_undo_notify = u; } + + struct eq_pp { + ac_plugin& p; eq const& e; + eq_pp(ac_plugin& p, eq const& e) : p(p), e(e) {}; + eq_pp(ac_plugin& p, unsigned eq_id): p(p), e(p.m_eqs[eq_id]) {} + std::ostream& display(std::ostream& out) const { return p.display_equation(out, e); } + }; + + struct m_pp { + ac_plugin& p; ptr_vector const& m; + m_pp(ac_plugin& p, monomial_t const& m) : p(p), m(m.m_nodes) {} + m_pp(ac_plugin& p, ptr_vector const& m) : p(p), m(m) {} + std::ostream& display(std::ostream& out) const { return p.display_monomial(out, m); } + }; + }; + + inline std::ostream& operator<<(std::ostream& out, ac_plugin::eq_pp const& d) { return d.display(out); } + inline std::ostream& operator<<(std::ostream& out, ac_plugin::m_pp const& d) { return d.display(out); } +} diff --git a/src/ast/euf/euf_arith_plugin.cpp b/src/ast/euf/euf_arith_plugin.cpp new file mode 100644 index 000000000..26f8e0bd9 --- /dev/null +++ b/src/ast/euf/euf_arith_plugin.cpp @@ -0,0 +1,71 @@ 
+/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_arith_plugin.cpp + +Abstract: + + plugin structure for arithmetic + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +--*/ + +#include "ast/euf/euf_arith_plugin.h" +#include "ast/euf/euf_egraph.h" +#include + +namespace euf { + + arith_plugin::arith_plugin(egraph& g) : + plugin(g), + a(g.get_manager()), + m_add(g, get_id(), OP_ADD), + m_mul(g, get_id(), OP_MUL) { + std::function uadd = [&]() { m_undo.push_back(undo_t::undo_add); }; + m_add.set_undo(uadd); + std::function umul = [&]() { m_undo.push_back(undo_t::undo_mul); }; + m_mul.set_undo(umul); + } + + void arith_plugin::register_node(enode* n) { + // no-op + } + + void arith_plugin::merge_eh(enode* n1, enode* n2) { + m_add.merge_eh(n1, n2); + m_mul.merge_eh(n1, n2); + } + + void arith_plugin::propagate() { + m_add.propagate(); + m_mul.propagate(); + } + + void arith_plugin::undo() { + auto k = m_undo.back(); + m_undo.pop_back(); + switch (k) { + case undo_t::undo_add: + m_add.undo(); + break; + case undo_t::undo_mul: + m_mul.undo(); + break; + default: + UNREACHABLE(); + } + } + + std::ostream& arith_plugin::display(std::ostream& out) const { + out << "add\n"; + m_add.display(out); + out << "mul\n"; + m_mul.display(out); + return out; + } +} diff --git a/src/ast/euf/euf_arith_plugin.h b/src/ast/euf/euf_arith_plugin.h new file mode 100644 index 000000000..7cca01f1c --- /dev/null +++ b/src/ast/euf/euf_arith_plugin.h @@ -0,0 +1,53 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_arith_plugin.h + +Abstract: + + plugin structure for arithmetic +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +--*/ + +#pragma once + +#include "ast/arith_decl_plugin.h" +#include "ast/euf/euf_plugin.h" +#include "ast/euf/euf_ac_plugin.h" + +namespace euf { + + class egraph; + + class arith_plugin : public plugin { + enum undo_t { undo_add, undo_mul }; + arith_util a; + svector m_undo; + ac_plugin m_add, m_mul; + + public: + 
arith_plugin(egraph& g); + + ~arith_plugin() override {} + + unsigned get_id() const override { return a.get_family_id(); } + + void register_node(enode* n) override; + + void merge_eh(enode* n1, enode* n2) override; + + void diseq_eh(enode* eq) override {} + + void undo() override; + + void propagate() override; + + std::ostream& display(std::ostream& out) const override; + + }; +} diff --git a/src/ast/euf/euf_bv_plugin.cpp b/src/ast/euf/euf_bv_plugin.cpp new file mode 100644 index 000000000..99bf8941b --- /dev/null +++ b/src/ast/euf/euf_bv_plugin.cpp @@ -0,0 +1,361 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_bv_plugin.cpp + +Abstract: + + plugin structure for bit-vectors + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-08 + Jakob Rath 2023-11-08 + +Objective: + +satisfies extract/concat axioms. + - concat(n[I],n[J]) = n[IJ] for I, J consecutive. + - concat(v1, v2) = 2^width(v1)*v2 + v1 + - concat(n[width(n)-1:0]) = n + - concat(a, b)[I] = concat(a[I1], b[I2]) + - concat(a, concat(b, c)) = concat(concat(a, b), c) + +E-graph: + +The E-graph contains node definitions of the form + + n := f(n1,n2,..) + +and congruences: + + n ~ n' means root(n) = root(n') + +Saturated state: + + 1. n := n1[I], n' := n2[J], n1 ~ n2 => root(n1) contains tree refining both I, J from smaller intervals + + 2. n := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => n ~ n3[IJ] + + 3. n := concat(n1[I], n2[J]), I and J are consecutive & n1 ~ n2, n1[I] ~ v1, n1[J] ~ v2 => n ~ 2^width(v1)*v2 + v1 + + 4. n := concat(n1[I], n2[J]), I, J are consecutive, n1 ~ n2, n ~ v => n1[I] ~ v mod 2^width(n1[I]), n2[J] ~ v div 2^width(n1[I]) + + 5. n' := n[I] => n ~ n[width(n)-1:0] + + 6. 
n := concat(a, concat(b, c)) => n ~ concat(concat(a, b), c) + - handled by rewriter pre-processing for inputs + - terms created internally are not equated modulo associativity + + 7. n := concat(n1, n2)[I] => n ~ concat(n1[I1],n2[I2]) or n[I1] or n[I2] + - handled by rewriter pre-processing + +Example: + x == (x1 x2) x3 + y == y1 (y2 y3) + x1 == y1, x2 == y2, x3 == y3 + => + x = y + + by x2 == y2, x3 == y3 => (x2 x3) = (y2 y3) + by (2) => x[I23] = (x2 x3) + by (2) => x[I123] = (x1 (x2 x3)) + by (5) => x = x[I123] + +The formal properties of saturation have to be established. + +- Saturation does not complete with respect to associativity. +Instead the claim is along the lines that the resulting E-graph can be used as a canonizer. +If given a set of equations E that are saturated, and terms t1, t2 that are +both simplified with respect to left-associativity of concatenation, and t1, t2 belong to the E-graph, +then t1 = t2 iff t1 ~ t2 in the E-graph. + +TODO: Is saturation for (7) overkill for the purpose of canonization? + +TODO: revisit re-entrancy during register_node. It can be called when creating internal extract terms. +Instead of allowing re-entrancy we can accumulate nodes that are registered during recursive calls +and have the main call perform recursive slicing. 
+ +--*/ + + +#include "ast/euf/euf_bv_plugin.h" +#include "ast/euf/euf_egraph.h" + +namespace euf { + + bv_plugin::bv_plugin(egraph& g): + plugin(g), + bv(g.get_manager()) + {} + + enode* bv_plugin::mk_value_concat(enode* a, enode* b) { + auto v1 = get_value(a); + auto v2 = get_value(b); + auto v3 = v1 + v2 * power(rational(2), width(a)); + return mk_value(v3, width(a) + width(b)); + } + + enode* bv_plugin::mk_value(rational const& v, unsigned sz) { + auto e = bv.mk_numeral(v, sz); + return mk(e, 0, nullptr); + } + + void bv_plugin::merge_eh(enode* x, enode* y) { + SASSERT(x == x->get_root()); + SASSERT(x == y->get_root()); + + TRACE("bv", tout << "merge_eh " << g.bpp(x) << " == " << g.bpp(y) << "\n"); + SASSERT(!m_internal); + flet _internal(m_internal, true); + + propagate_values(x); + + // ensure slices align + if (has_sub(x) || has_sub(y)) { + enode_vector& xs = m_xs, & ys = m_ys; + xs.reset(); + ys.reset(); + xs.push_back(x); + ys.push_back(y); + merge(xs, ys, justification::equality(x, y)); + } + + // ensure p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ] + for (auto* n : enode_class(x)) + propagate_extract(n); + } + + // enforce concat(v1, v2) = v2*2^|v1| + v1 + void bv_plugin::propagate_values(enode* x) { + if (!is_value(x)) + return; + + enode* a, * b; + for (enode* p : enode_parents(x)) + if (is_concat(p, a, b) && is_value(a) && is_value(b) && !is_value(p)) + push_merge(mk_concat(a->get_interpreted(), b->get_interpreted()), mk_value_concat(a, b)); + + for (enode* sib : enode_class(x)) { + if (is_concat(sib, a, b)) { + if (!is_value(a) || !is_value(b)) { + auto val = get_value(x); + auto v1 = mod2k(val, width(a)); + auto v2 = machine_div2k(val, width(a)); + push_merge(mk_concat(mk_value(v1, width(a)), mk_value(v2, width(b))), x->get_interpreted()); + } + } + } + } + + // + // p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ] + // + // n is of form arg[I] + // p is of form concat(n, b) or concat(a, 
n) + // b is congruent to arg[J], I is consecutive with J => ensure that arg[IJ] = p + // a is congruent to arg[J], J is consecutive with I => ensure that arg[JI] = p + // + + void bv_plugin::propagate_extract(enode* n) { + unsigned lo1, hi1, lo2, hi2; + enode* a, * b; + if (!is_extract(n, lo1, hi1)) + return; + + enode* arg = n->get_arg(0); + enode* arg_r = arg->get_root(); + enode* n_r = n->get_root(); + + auto ensure_concat = [&](unsigned lo, unsigned mid, unsigned hi) { + TRACE("bv", tout << "ensure-concat " << lo << " " << mid << " " << hi << "\n"); + unsigned lo_, hi_; + for (enode* p1 : enode_parents(n)) + if (is_extract(p1, lo_, hi_) && lo_ == lo && hi_ == hi && p1->get_arg(0)->get_root() == arg_r) + return; + // add the axiom instead of merge(p, mk_extract(arg, lo, hi)), which would require tracking justifications + push_merge(mk_concat(mk_extract(arg, lo, mid), mk_extract(arg, mid + 1, hi)), mk_extract(arg, lo, hi)); + }; + + auto propagate_left = [&](enode* b) { + TRACE("bv", tout << "propagate-left " << g.bpp(b) << "\n"); + for (enode* sib : enode_class(b)) + if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi1 + 1 == lo2) + ensure_concat(lo1, hi1, hi2); + }; + + auto propagate_right = [&](enode* a) { + TRACE("bv", tout << "propagate-right " << g.bpp(a) << "\n"); + for (enode* sib : enode_class(a)) + if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi2 + 1 == lo1) + ensure_concat(lo2, hi2, hi1); + }; + + for (enode* p : enode_parents(n)) { + if (is_concat(p, a, b)) { + if (a->get_root() == n_r) + propagate_left(b); + if (b->get_root() == n_r) + propagate_right(a); + } + } + } + + void bv_plugin::push_undo_split(enode* n) { + m_undo_split.push_back(n); + push_plugin_undo(bv.get_family_id()); + } + + void bv_plugin::undo() { + enode* n = m_undo_split.back(); + m_undo_split.pop_back(); + auto& i = info(n); + i.lo = nullptr; + i.hi = nullptr; + i.cut = null_cut; + } + + void bv_plugin::register_node(enode* n) 
{ + TRACE("bv", tout << "register " << g.bpp(n) << "\n"); + auto& i = info(n); + i.value = n; + enode* a, * b; + if (is_concat(n, a, b)) { + i.lo = a; + i.hi = b; + i.cut = width(a); + push_undo_split(n); + } + unsigned lo, hi; + if (is_extract(n, lo, hi) && (lo != 0 || hi + 1 != width(n->get_arg(0)))) { + enode* arg = n->get_arg(0); + unsigned w = width(arg); + if (all_of(enode_parents(arg), [&](enode* p) { unsigned _lo, _hi; return !is_extract(p, _lo, _hi) || _lo != 0 || _hi + 1 != w; })) + push_merge(mk_extract(arg, 0, w - 1), arg); + ensure_slice(arg, lo, hi); + } + } + + // + // Ensure that there are slices at boundaries of n[hi:lo] + // + void bv_plugin::ensure_slice(enode* n, unsigned lo, unsigned hi) { + enode* r = n; + unsigned lb = 0, ub = width(n) - 1; + while (true) { + TRACE("bv", tout << "ensure slice " << g.bpp(n) << " " << lb << " [" << lo << ", " << hi << "] " << ub << "\n"); + SASSERT(lb <= lo && hi <= ub); + SASSERT(ub - lb + 1 == width(r)); + if (lb == lo && ub == hi) + return; + slice_info& i = info(r); + if (!i.lo) { + if (lo > lb) { + split(r, lo - lb); + if (hi < ub) // or split(info(r).hi, ...) 
+ ensure_slice(n, lo, hi); + } + else if (hi < ub) + split(r, ub - hi); + break; + } + auto cut = i.cut; + if (cut + lb <= lo) { + lb += cut; + r = i.hi; + continue; + } + if (cut + lb > hi) { + ub = cut + lb - 1; + r = i.lo; + continue; + } + SASSERT(lo < cut + lb && cut + lb <= hi); + ensure_slice(n, lo, cut + lb - 1); + ensure_slice(n, cut + lb, hi); + break; + } + } + + enode* bv_plugin::mk_extract(enode* n, unsigned lo, unsigned hi) { + SASSERT(lo <= hi && width(n) > hi - lo); + unsigned lo1, hi1; + while (is_extract(n, lo1, hi1)) { + lo += lo1; + hi += lo1; + n = n->get_arg(0); + } + return mk(bv.mk_extract(hi, lo, n->get_expr()), 1, &n); + } + + enode* bv_plugin::mk_concat(enode* lo, enode* hi) { + enode* args[2] = { lo, hi }; + return mk(bv.mk_concat(lo->get_expr(), hi->get_expr()), 2, args); + } + + void bv_plugin::merge(enode_vector& xs, enode_vector& ys, justification dep) { + while (!xs.empty()) { + SASSERT(!ys.empty()); + auto x = xs.back(); + auto y = ys.back(); + if (unfold_sub(x, xs)) + continue; + else if (unfold_sub(y, ys)) + continue; + else if (unfold_width(x, xs, y, ys)) + continue; + else if (unfold_width(y, ys, x, xs)) + continue; + else if (x->get_root() != y->get_root()) + push_merge(x, y, dep); + xs.pop_back(); + ys.pop_back(); + } + SASSERT(ys.empty()); + } + + bool bv_plugin::unfold_sub(enode* x, enode_vector& xs) { + if (!has_sub(x)) + return false; + xs.pop_back(); + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + return true; + } + + bool bv_plugin::unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys) { + if (width(x) <= width(y)) + return false; + split(x, width(y)); + xs.pop_back(); + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + return true; + } + + void bv_plugin::split(enode* n, unsigned cut) { + TRACE("bv", tout << "split: " << g.bpp(n) << " " << cut << "\n"); + unsigned w = width(n); + SASSERT(!info(n).hi); + SASSERT(0 < cut && cut < w); + enode* hi = mk_extract(n, cut, w - 1); + enode* lo = 
mk_extract(n, 0, cut - 1); + auto& i = info(n); + SASSERT(i.value); + i.hi = hi; + i.lo = lo; + i.cut = cut; + push_undo_split(n); + push_merge(mk_concat(lo, hi), n); + } + + std::ostream& bv_plugin::display(std::ostream& out) const { + out << "bv\n"; + for (auto const& i : m_info) + if (i.lo) + out << g.bpp(i.value) << " cut " << i.cut << " lo " << g.bpp(i.lo) << " hi " << g.bpp(i.hi) << "\n"; + return out; + } +} diff --git a/src/ast/euf/euf_bv_plugin.h b/src/ast/euf/euf_bv_plugin.h new file mode 100644 index 000000000..b8d62051e --- /dev/null +++ b/src/ast/euf/euf_bv_plugin.h @@ -0,0 +1,100 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_bv_plugin.h + +Abstract: + + plugin structure for bit-vectors + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-08 + Jakob Rath 2023-11-08 + + +--*/ + +#pragma once + +#include "ast/bv_decl_plugin.h" +#include "ast/euf/euf_plugin.h" + +namespace euf { + + class egraph; + + class bv_plugin : public plugin { + static constexpr unsigned null_cut = std::numeric_limits::max(); + + struct slice_info { + unsigned cut = null_cut; // = bv.get_bv_size(lo) + enode* hi = nullptr; // + enode* lo = nullptr; // + enode* value = nullptr; + void reset() { *this = slice_info(); } + }; + using slice_info_vector = svector; + + bv_util bv; + slice_info_vector m_info; // indexed by enode::get_id() + + enode_vector m_xs, m_ys; + + bool is_concat(enode* n) const { return bv.is_concat(n->get_expr()); } + bool is_concat(enode* n, enode*& a, enode*& b) { return is_concat(n) && (a = n->get_arg(0), b = n->get_arg(1), true); } + bool is_extract(enode* n, unsigned& lo, unsigned& hi) { expr* body; return bv.is_extract(n->get_expr(), lo, hi, body); } + bool is_extract(enode* n) const { return bv.is_extract(n->get_expr()); } + unsigned width(enode* n) const { return bv.get_bv_size(n->get_expr()); } + + enode* mk_extract(enode* n, unsigned lo, unsigned hi); + enode* mk_concat(enode* lo, enode* hi); + enode* mk_value_concat(enode* lo, 
enode* hi); + enode* mk_value(rational const& v, unsigned sz); + unsigned width(enode* n) { return bv.get_bv_size(n->get_expr()); } + bool is_value(enode* n) { return n->get_root()->interpreted(); } + rational get_value(enode* n) { rational val; VERIFY(bv.is_numeral(n->get_interpreted()->get_expr(), val)); return val; } + slice_info& info(enode* n) { unsigned id = n->get_id(); m_info.reserve(id + 1); return m_info[id]; } + slice_info& root_info(enode* n) { unsigned id = n->get_root_id(); m_info.reserve(id + 1); return m_info[id]; } + bool has_sub(enode* n) { return !!info(n).lo; } + enode* sub_lo(enode* n) { return info(n).lo; } + enode* sub_hi(enode* n) { return info(n).hi; } + + bool m_internal = false; + void ensure_slice(enode* n, unsigned lo, unsigned hi); + + + void split(enode* n, unsigned cut); + + bool unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys); + bool unfold_sub(enode* x, enode_vector& xs); + void merge(enode_vector& xs, enode_vector& ys, justification j); + void propagate_extract(enode* n); + void propagate_values(enode* n); + + enode_vector m_undo_split; + void push_undo_split(enode* n); + + public: + bv_plugin(egraph& g); + + ~bv_plugin() override {} + + unsigned get_id() const override { return bv.get_family_id(); } + + void register_node(enode* n) override; + + void merge_eh(enode* n1, enode* n2) override; + + void diseq_eh(enode* eq) override {} + + void propagate() override {} + + void undo() override; + + std::ostream& display(std::ostream& out) const override; + + }; +} diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 6c8cbda06..154106e23 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -130,8 +130,8 @@ namespace euf { if (n2 == n) update_children(n); else - merge(n, n2, justification::congruence(comm, m_congruence_timestamp++)); - + push_merge(n, n2, comm); + return n; } @@ -146,19 +146,36 @@ namespace euf { memory::deallocate(m_tmp_node); } + void 
egraph::add_plugin(plugin* p) { + m_plugins.reserve(p->get_id() + 1); + m_plugins.set(p->get_id(), p); + } + + void egraph::propagate_plugins() { + for (auto* p : m_plugins) + if (p) + p->propagate(); + } + void egraph::add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) { TRACE("euf_verbose", tout << "eq: " << v1 << " == " << v2 << "\n";); m_new_th_eqs.push_back(th_eq(id, v1, v2, c, r)); m_updates.push_back(update_record(update_record::new_th_eq())); ++m_stats.m_num_th_eqs; + auto* p = get_plugin(id); + if (p) + p->merge_eh(c, r); } - void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq) { + void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq) { if (!th_propagates_diseqs(id)) return; TRACE("euf_verbose", tout << "eq: " << v1 << " != " << v2 << "\n";); - m_new_th_eqs.push_back(th_eq(id, v1, v2, eq)); + m_new_th_eqs.push_back(th_eq(id, v1, v2, eq->get_expr())); m_updates.push_back(update_record(update_record::new_th_eq())); + auto* p = get_plugin(id); + if (p) + p->diseq_eh(eq); ++m_stats.m_num_th_diseqs; } @@ -202,7 +219,7 @@ namespace euf { return; theory_var v1 = arg1->get_closest_th_var(id); theory_var v2 = arg2->get_closest_th_var(id); - add_th_diseq(id, v1, v2, n->get_expr()); + add_th_diseq(id, v1, v2, n); return; } for (auto const& p : euf::enode_th_vars(r1)) { @@ -210,8 +227,8 @@ namespace euf { continue; for (auto const& q : euf::enode_th_vars(r2)) if (p.get_id() == q.get_id()) - add_th_diseq(p.get_id(), p.get_var(), q.get_var(), n->get_expr()); - } + add_th_diseq(p.get_id(), p.get_var(), q.get_var(), n); + } } @@ -230,7 +247,7 @@ namespace euf { n = n->get_root(); theory_var v2 = n->get_closest_th_var(id); if (v2 != null_theory_var) - add_th_diseq(id, v1, v2, p->get_expr()); + add_th_diseq(id, v1, v2, p); } } } @@ -249,6 +266,10 @@ namespace euf { theory_var w = n->get_th_var(id); enode* r = n->get_root(); + auto* p = get_plugin(id); + if (p) + p->register_node(n); + if (w == 
null_theory_var) { n->add_th_var(v, id, m_region); m_updates.push_back(update_record(n, id, update_record::add_th_var())); @@ -424,6 +445,9 @@ namespace euf { p.r1->m_args[i]->get_root()->m_parents.pop_back(); } break; + case update_record::tag_t::is_plugin_undo: + m_plugins[p.m_th_id]->undo(); + break; default: UNREACHABLE(); break; @@ -589,6 +613,9 @@ namespace euf { case to_merge_comm: merge(w.a, w.b, justification::congruence(w.commutativity(), m_congruence_timestamp++)); break; + case to_justified: + merge(w.a, w.b, w.j); + break; case to_add_literal: add_literal(w.a, w.b); break; @@ -760,6 +787,13 @@ namespace euf { justifications.push_back(j.ext()); else if (j.is_congruence()) push_congruence(a, b, j.is_commutative()); + else if (j.is_dependent()) { + vector js; + for (auto const& j2 : justification::dependency_manager::s_linearize(j.get_dependency(), js)) + explain_eq(justifications, cc, a, b, j2); + } + else if (j.is_equality()) + explain_eq(justifications, cc, j.lhs(), j.rhs()); if (cc && j.is_congruence()) cc->push_back(std::tuple(a->get_app(), b->get_app(), j.timestamp(), j.is_commutative())); } @@ -879,6 +913,9 @@ namespace euf { max_args = std::max(max_args, n->num_args()); for (enode* n : m_nodes) display(out, max_args, n); + for (auto* p : m_plugins) + if (p) + p->display(out); return out; } diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 83a3574ce..f2eedc21d 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -29,8 +29,10 @@ Notes: #include "util/statistics.h" #include "util/trail.h" #include "util/lbool.h" +#include "util/scoped_ptr_vector.h" #include "ast/euf/euf_enode.h" #include "ast/euf/euf_etable.h" +#include "ast/euf/euf_plugin.h" #include "ast/ast_ll_pp.h" #include @@ -82,14 +84,18 @@ namespace euf { class egraph { + friend class plugin; + typedef ptr_vector trail_stack; - enum to_merge_t { to_merge_plain, to_merge_comm, to_add_literal }; + enum to_merge_t { to_merge_plain, to_merge_comm, 
to_justified, to_add_literal }; struct to_merge { enode* a, * b; to_merge_t t; + justification j; bool commutativity() const { return t == to_merge_comm; } to_merge(enode* a, enode* b, bool c) : a(a), b(b), t(c ? to_merge_comm : to_merge_plain) {} + to_merge(enode* a, enode* b, justification j): a(a), b(b), t(to_justified), j(j) {} to_merge(enode* p, enode* ante): a(p), b(ante), t(to_add_literal) {} }; @@ -116,10 +122,12 @@ namespace euf { struct lbl_set {}; struct update_children {}; struct set_relevant {}; + struct plugin_undo {}; enum class tag_t { is_set_parent, is_add_node, is_toggle_cgc, is_toggle_merge_tf, is_update_children, is_add_th_var, is_replace_th_var, is_new_th_eq, is_lbl_hash, is_new_th_eq_qhead, - is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant }; + is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant, + is_plugin_undo }; tag_t tag; enode* r1; enode* n1; @@ -162,11 +170,14 @@ namespace euf { tag(tag_t::is_update_children), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} update_record(enode* n, set_relevant) : tag(tag_t::is_set_relevant), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} + update_record(unsigned th_id, plugin_undo) : + tag(tag_t::is_plugin_undo), r1(nullptr), n1(nullptr), m_th_id(th_id) {} }; ast_manager& m; svector m_to_merge; etable m_table; region m_region; + scoped_ptr_vector m_plugins; svector m_updates; unsigned_vector m_scopes; enode_vector m_expr2enode; @@ -205,6 +216,12 @@ namespace euf { } void push_node(enode* n) { m_updates.push_back(update_record(n)); } + // plugin related methods + void push_plugin_undo(unsigned th_id) { m_updates.push_back(update_record(th_id, update_record::plugin_undo())); } + void push_merge(enode* a, enode* b, justification j) { m_to_merge.push_back({ a, b, j }); } + void push_merge(enode* a, enode* b, bool comm) { m_to_merge.push_back({ a, b, comm }); } + void propagate_plugins(); + void add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r); void 
add_th_diseqs(theory_id id, theory_var v1, enode* r); @@ -245,11 +262,15 @@ namespace euf { public: egraph(ast_manager& m); ~egraph(); + + void add_plugin(plugin* p); + plugin* get_plugin(family_id fid) const { return m_plugins.get(fid, nullptr); } + enode* find(expr* f) const { return m_expr2enode.get(f->get_id(), nullptr); } enode* find(expr* f, unsigned n, enode* const* args); enode* mk(expr* f, unsigned generation, unsigned n, enode *const* args); enode_vector const& enodes_of(func_decl* f); - void push() { if (!m_to_merge.empty()) propagate(); ++m_num_scopes; } + void push() { if (can_propagate()) propagate(); ++m_num_scopes; } void pop(unsigned num_scopes); /** @@ -269,6 +290,7 @@ namespace euf { of new equalities. */ bool propagate(); + bool can_propagate() const { return !m_to_merge.empty(); } bool inconsistent() const { return m_inconsistent; } /** @@ -286,7 +308,7 @@ namespace euf { where \c n is an enode and \c is_eq indicates whether the enode is an equality consequence. */ - void add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq); + void add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq); bool has_th_eq() const { return m_new_th_eqs_qhead < m_new_th_eqs.size(); } th_eq get_th_eq() const { return m_new_th_eqs[m_new_th_eqs_qhead]; } void next_th_eq() { force_push(); SASSERT(m_new_th_eqs_qhead < m_new_th_eqs.size()); m_new_th_eqs_qhead++; } diff --git a/src/ast/euf/euf_enode.cpp b/src/ast/euf/euf_enode.cpp index 08df9f493..335c8f3e9 100644 --- a/src/ast/euf/euf_enode.cpp +++ b/src/ast/euf/euf_enode.cpp @@ -93,6 +93,17 @@ namespace euf { return null_theory_var; } + enode* enode::get_closest_th_node(theory_id id) { + enode* n = this; + while (n) { + theory_var v = n->get_th_var(id); + if (v != null_theory_var) + return n; + n = n->m_target; + } + return nullptr; + } + bool enode::acyclic() const { enode const* n = this; enode const* p = this; diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index 
064688d47..0f22ff20e 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -207,6 +207,7 @@ namespace euf { enode* get_root() const { return m_root; } expr* get_expr() const { return m_expr; } sort* get_sort() const { return m_expr->get_sort(); } + enode* get_interpreted() const { return get_root(); } app* get_app() const { return to_app(m_expr); } func_decl* get_decl() const { return is_app(m_expr) ? to_app(m_expr)->get_decl() : nullptr; } unsigned get_expr_id() const { return m_expr->get_id(); } @@ -216,6 +217,10 @@ namespace euf { bool children_are_roots() const; enode* get_next() const { return m_next; } + enode* get_target() const { return m_target; } + justification get_justification() const { return m_justification; } + justification get_lit_justification() const { return m_lit_justification; } + bool has_lbl_hash() const { return m_lbl_hash >= 0; } unsigned char get_lbl_hash() const { SASSERT(m_lbl_hash >= 0 && static_cast(m_lbl_hash) < approx_set_traits::capacity); @@ -229,6 +234,7 @@ namespace euf { theory_var get_th_var(theory_id id) const { return m_th_vars.find(id); } theory_var get_closest_th_var(theory_id id) const; + enode* get_closest_th_node(theory_id id); bool is_attached_to(theory_id id) const { return get_th_var(id) != null_theory_var; } bool has_th_vars() const { return !m_th_vars.empty(); } bool has_one_th_var() const { return !m_th_vars.empty() && !m_th_vars.get_next();} diff --git a/src/ast/euf/euf_justification.cpp b/src/ast/euf/euf_justification.cpp new file mode 100644 index 000000000..22b52ea84 --- /dev/null +++ b/src/ast/euf/euf_justification.cpp @@ -0,0 +1,54 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + euf_justification.cpp + +Abstract: + + justification structure for euf + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-23 + +--*/ + + +#include "ast/euf/euf_justification.h" +#include "ast/euf/euf_enode.h" + +namespace euf { + + + std::ostream& justification::display(std::ostream& out, 
std::function const& ext) const { + switch (m_kind) { + case kind_t::external_t: + if (ext) + ext(out, m_external); + else + out << "external"; + return out; + case kind_t::axiom_t: + return out << "axiom"; + case kind_t::congruence_t: + return out << "congruence"; + case kind_t::dependent_t: { + vector js; + out << "dependent"; + for (auto const& j : dependency_manager::s_linearize(m_dependency, js)) + j.display(out << " ", ext); + return out; + } + case kind_t::equality_t: + return out << "equality #" << m_n1->get_id() << " == #" << m_n2->get_id(); + + default: + UNREACHABLE(); + return out; + } + return out; + } + +} diff --git a/src/ast/euf/euf_justification.h b/src/ast/euf/euf_justification.h index c98002396..f9d3b3637 100644 --- a/src/ast/euf/euf_justification.h +++ b/src/ast/euf/euf_justification.h @@ -22,19 +22,34 @@ Notes: #pragma once +#include "util/dependency.h" + namespace euf { + class enode; + class justification { + public: + typedef stacked_dependency_manager dependency_manager; + typedef stacked_dependency_manager::dependency dependency; + private: enum class kind_t { axiom_t, congruence_t, - external_t + external_t, + dependent_t, + equality_t }; kind_t m_kind; - bool m_comm; + union { + bool m_comm; + enode* m_n1; + }; union { void* m_external; uint64_t m_timestamp; + dependency* m_dependency; + enode* m_n2; }; justification(bool comm, uint64_t ts): @@ -49,6 +64,18 @@ namespace euf { m_external(ext) {} + justification(dependency* dep, int): + m_kind(kind_t::dependent_t), + m_comm(false), + m_dependency(dep) + {} + + justification(enode* n1, enode* n2): + m_kind(kind_t::equality_t), + m_n1(n1), + m_n2(n2) + {} + public: justification(): m_kind(kind_t::axiom_t), @@ -59,10 +86,17 @@ namespace euf { static justification axiom() { return justification(); } static justification congruence(bool c, uint64_t ts) { return justification(c, ts); } static justification external(void* ext) { return justification(ext); } + static justification 
dependent(dependency* d) { return justification(d, 1); } + static justification equality(enode* a, enode* b) { return justification(a, b); } bool is_external() const { return m_kind == kind_t::external_t; } bool is_congruence() const { return m_kind == kind_t::congruence_t; } bool is_commutative() const { return m_comm; } + bool is_dependent() const { return m_kind == kind_t::dependent_t; } + bool is_equality() const { return m_kind == kind_t::equality_t; } + dependency* get_dependency() const { SASSERT(is_dependent()); return m_dependency; } + enode* lhs() const { SASSERT(is_equality()); return m_n1; } + enode* rhs() const { SASSERT(is_equality()); return m_n2; } uint64_t timestamp() const { SASSERT(is_congruence()); return m_timestamp; } template T* ext() const { SASSERT(is_external()); return static_cast(m_external); } @@ -75,30 +109,17 @@ namespace euf { return axiom(); case kind_t::congruence_t: return congruence(m_comm, m_timestamp); + case kind_t::dependent_t: + NOT_IMPLEMENTED_YET(); + return dependent(m_dependency); default: UNREACHABLE(); return axiom(); } } - std::ostream& display(std::ostream& out, std::function const& ext) const { - switch (m_kind) { - case kind_t::external_t: - if (ext) - ext(out, m_external); - else - out << "external"; - return out; - case kind_t::axiom_t: - return out << "axiom"; - case kind_t::congruence_t: - return out << "congruence"; - default: - UNREACHABLE(); - return out; - } - return out; - } + std::ostream& display(std::ostream& out, std::function const& ext) const; + }; inline std::ostream& operator<<(std::ostream& out, justification const& j) { diff --git a/src/ast/euf/euf_plugin.cpp b/src/ast/euf/euf_plugin.cpp new file mode 100644 index 000000000..57c1849fd --- /dev/null +++ b/src/ast/euf/euf_plugin.cpp @@ -0,0 +1,47 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_plugin.cpp + +Abstract: + + plugin structure for euf + + Plugins allow adding equality saturation for theories. 
+ +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-08 + +--*/ + +#include "ast/euf/euf_egraph.h" + +namespace euf { + + void plugin::push_plugin_undo(unsigned th_id) { + g.push_plugin_undo(th_id); + } + + void plugin::push_merge(enode* a, enode* b, justification j) { + g.push_merge(a, b, j); + } + + void plugin::push_merge(enode* a, enode* b) { + TRACE("plugin", tout << g.bpp(a) << " == " << g.bpp(b) << "\n"); + g.push_merge(a, b, justification::axiom()); + } + + enode* plugin::mk(expr* e, unsigned n, enode* const* args) { + enode* r = g.find(e); + if (!r) + r = g.mk(e, 0, n, args); + return r; + } + + region& plugin::get_region() { + return g.m_region; + } +} diff --git a/src/ast/euf/euf_plugin.h b/src/ast/euf/euf_plugin.h new file mode 100644 index 000000000..ff49d6c40 --- /dev/null +++ b/src/ast/euf/euf_plugin.h @@ -0,0 +1,58 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_plugin.h + +Abstract: + + plugin structure for euf + + Plugins allow adding equality saturation for theories. 
+
+Author:
+
+    Nikolaj Bjorner (nbjorner) 2023-11-08
+
+--*/
+
+#pragma once
+
+#include "ast/euf/euf_enode.h"
+#include "ast/euf/euf_justification.h"
+
+namespace euf {
+
+
+    class plugin {
+    protected:
+        egraph& g;
+        void push_plugin_undo(unsigned th_id);
+        void push_merge(enode* a, enode* b, justification j);
+        void push_merge(enode* a, enode* b);
+        enode* mk(expr* e, unsigned n, enode* const* args);
+        region& get_region();
+    public:
+        plugin(egraph& g):
+            g(g)
+        {}
+
+        virtual ~plugin() {}
+
+        virtual unsigned get_id() const = 0;
+
+        virtual void register_node(enode* n) = 0;
+
+        virtual void merge_eh(enode* n1, enode* n2) = 0;
+
+        virtual void diseq_eh(enode* eq) {};
+
+        virtual void propagate() = 0;
+
+        virtual void undo() = 0;
+
+        virtual std::ostream& display(std::ostream& out) const = 0;
+
+    };
+}
diff --git a/src/ast/euf/euf_specrel_plugin.cpp b/src/ast/euf/euf_specrel_plugin.cpp
new file mode 100644
index 000000000..3220a24e6
--- /dev/null
+++ b/src/ast/euf/euf_specrel_plugin.cpp
@@ -0,0 +1,76 @@
+/*++
+Copyright (c) 2023 Microsoft Corporation
+
+Module Name:
+
+    euf_specrel_plugin.cpp
+
+Abstract:
+
+    plugin structure for specrel
+
+Author:
+
+    Nikolaj Bjorner (nbjorner) 2023-11-11
+
+--*/
+
+#include "ast/euf/euf_specrel_plugin.h"
+#include "ast/euf/euf_egraph.h"
+#include
+
+namespace euf {
+
+    specrel_plugin::specrel_plugin(egraph& g) :
+        plugin(g),
+        sp(g.get_manager()) {
+    }
+
+    // Create and record an ac_plugin the first time a node whose declaration
+    // is an AC operator (sp.is_ac) is registered; other nodes are ignored.
+    void specrel_plugin::register_node(enode* n) {
+        func_decl* f = n->get_decl();
+        if (!f)
+            return;
+        if (!sp.is_ac(f))
+            return;
+        ac_plugin* p = nullptr;
+        if (!m_decl2plugin.find(f, p)) {
+            p = alloc(ac_plugin, g, f);
+            m_decl2plugin.insert(f, p);
+            m_plugins.push_back(p);
+            // Capture p by value (and this explicitly): the callback is handed to
+            // set_undo and presumably runs after this function returns, when a
+            // by-reference capture of the local p would dangle.
+            std::function undo_op = [this, p]() { m_undo.push_back(p); };
+            p->set_undo(undo_op);
+        }
+    }
+
+    void specrel_plugin::merge_eh(enode* n1, enode* n2) {
+        for (auto * p : m_plugins)
+            p->merge_eh(n1, n2);
+    }
+
+    void specrel_plugin::diseq_eh(enode* eq) {
+        for (auto* p : m_plugins)
+            p->diseq_eh(eq);
+    }
+
+    
void specrel_plugin::propagate() { + for (auto * p : m_plugins) + p->propagate(); + } + + void specrel_plugin::undo() { + auto p = m_undo.back(); + m_undo.pop_back(); + p->undo(); + } + + std::ostream& specrel_plugin::display(std::ostream& out) const { + for (auto * p : m_plugins) + p->display(out); + return out; + } +} \ No newline at end of file diff --git a/src/ast/euf/euf_specrel_plugin.h b/src/ast/euf/euf_specrel_plugin.h new file mode 100644 index 000000000..228bb5e15 --- /dev/null +++ b/src/ast/euf/euf_specrel_plugin.h @@ -0,0 +1,56 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_specrel_plugin.h + +Abstract: + + plugin structure for specrel functions + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +--*/ + +#pragma once + +#include +#include "util/scoped_ptr_vector.h" +#include "ast/special_relations_decl_plugin.h" +#include "ast/euf/euf_plugin.h" +#include "ast/euf/euf_ac_plugin.h" + +namespace euf { + + class specrel_plugin : public plugin { + scoped_ptr_vector m_plugins; + ptr_vector m_undo; + obj_map m_decl2plugin; + special_relations_util sp; + + public: + + specrel_plugin(egraph& g); + + ~specrel_plugin() override {} + + unsigned get_id() const override { return sp.get_family_id(); } + + void register_node(enode* n) override; + + void merge_eh(enode* n1, enode* n2) override; + + void diseq_eh(enode* eq) override; + + void undo() override; + + void propagate() override; + + std::ostream& display(std::ostream& out) const override; + + }; + +} diff --git a/src/ast/special_relations_decl_plugin.cpp b/src/ast/special_relations_decl_plugin.cpp index 24a756bf7..bbfe819d4 100644 --- a/src/ast/special_relations_decl_plugin.cpp +++ b/src/ast/special_relations_decl_plugin.cpp @@ -26,7 +26,8 @@ special_relations_decl_plugin::special_relations_decl_plugin(): m_po("partial-order"), m_plo("piecewise-linear-order"), m_to("tree-order"), - m_tc("transitive-closure") + m_tc("transitive-closure"), + m_ac("ac-op") {} func_decl * 
special_relations_decl_plugin::mk_func_decl(
@@ -41,24 +42,53 @@ func_decl * special_relations_decl_plugin::mk_func_decl(
         m_manager->raise_exception("argument sort missmatch. The two arguments should have the same sort");
         return nullptr;
     }
+    if (!range && k == OP_SPECIAL_RELATION_AC)
+        range = domain[0];
+
     if (!range) {
         range = m_manager->mk_bool_sort();
     }
-    if (!m_manager->is_bool(range)) {
-        m_manager->raise_exception("range type is expected to be Boolean for special relations");
-    }
+    auto check_bool_range = [&]() {
+        if (!m_manager->is_bool(range))
+            m_manager->raise_exception("range type is expected to be Boolean for special relations");
+    };
+
+
     m_has_special_relation = true;
     func_decl_info info(m_family_id, k, num_parameters, parameters);
     symbol name;
     switch(k) {
-    case OP_SPECIAL_RELATION_PO: name = m_po; break;
-    case OP_SPECIAL_RELATION_LO: name = m_lo; break;
-    case OP_SPECIAL_RELATION_PLO: name = m_plo; break;
-    case OP_SPECIAL_RELATION_TO: name = m_to; break;
+    case OP_SPECIAL_RELATION_PO: check_bool_range(); name = m_po; break;
+    case OP_SPECIAL_RELATION_LO: check_bool_range(); name = m_lo; break;
+    case OP_SPECIAL_RELATION_PLO: check_bool_range(); name = m_plo; break;
+    case OP_SPECIAL_RELATION_TO: check_bool_range(); name = m_to; break;
+    case OP_SPECIAL_RELATION_AC: {
+        if (range != domain[0])
+            m_manager->raise_exception("AC operation should have the same range as domain type");
+        name = m_ac;
+        if (num_parameters != 1 || !parameters[0].is_ast() || !is_func_decl(parameters[0].get_ast()))
+            m_manager->raise_exception("parameter to ac-op should be a function declaration"); // this case declares ac-op, so the message names it rather than transitive closure
+        func_decl* f = to_func_decl(parameters[0].get_ast());
+        if (f->get_arity() != 2)
+            m_manager->raise_exception("ac function should be binary");
+        if (f->get_domain(0) != f->get_domain(1))
+            m_manager->raise_exception("ac function should have same domain");
+        if (f->get_domain(0) != f->get_range())
+            m_manager->raise_exception("ac function should have same domain and 
range"); + break; + } case OP_SPECIAL_RELATION_TC: + check_bool_range(); name = m_tc; if (num_parameters != 1 || !parameters[0].is_ast() || !is_func_decl(parameters[0].get_ast())) m_manager->raise_exception("parameter to transitive closure should be a function declaration"); + func_decl* f = to_func_decl(parameters[0].get_ast()); + if (f->get_arity() != 2) + m_manager->raise_exception("tc relation should be binary"); + if (f->get_domain(0) != f->get_domain(1)) + m_manager->raise_exception("tc relation should have same domain"); + if (!m_manager->is_bool(f->get_range())) + m_manager->raise_exception("tc relation should be Boolean"); break; } return m_manager->mk_func_decl(name, arity, domain, range, info); @@ -71,6 +101,7 @@ void special_relations_decl_plugin::get_op_names(svector & op_name op_names.push_back(builtin_name(m_plo.str(), OP_SPECIAL_RELATION_PLO)); op_names.push_back(builtin_name(m_to.str(), OP_SPECIAL_RELATION_TO)); op_names.push_back(builtin_name(m_tc.str(), OP_SPECIAL_RELATION_TC)); + op_names.push_back(builtin_name(m_ac.str(), OP_SPECIAL_RELATION_AC)); } } @@ -81,6 +112,7 @@ sr_property special_relations_util::get_property(func_decl* f) const { case OP_SPECIAL_RELATION_PLO: return sr_plo; case OP_SPECIAL_RELATION_TO: return sr_to; case OP_SPECIAL_RELATION_TC: return sr_tc; + case OP_SPECIAL_RELATION_AC: return sr_none; default: UNREACHABLE(); return sr_po; diff --git a/src/ast/special_relations_decl_plugin.h b/src/ast/special_relations_decl_plugin.h index c422cbcdc..a65f98758 100644 --- a/src/ast/special_relations_decl_plugin.h +++ b/src/ast/special_relations_decl_plugin.h @@ -16,6 +16,8 @@ Author: Revision History: + 2023-11-27: Added ac-op for E-graph plugin + --*/ #pragma once @@ -28,6 +30,7 @@ enum special_relations_op_kind { OP_SPECIAL_RELATION_PLO, OP_SPECIAL_RELATION_TO, OP_SPECIAL_RELATION_TC, + OP_SPECIAL_RELATION_AC, LAST_SPECIAL_RELATIONS_OP }; @@ -37,6 +40,7 @@ class special_relations_decl_plugin : public decl_plugin { symbol m_plo; 
symbol m_to; symbol m_tc; + symbol m_ac; bool m_has_special_relation = false; public: special_relations_decl_plugin(); @@ -86,13 +90,16 @@ class special_relations_util { public: special_relations_util(ast_manager& m) : m(m), m_fid(null_family_id) { } + family_id get_family_id() const { return fid(); } + bool has_special_relation() const { return static_cast(m.get_plugin(m.mk_family_id("specrels")))->has_special_relation(); } bool is_special_relation(func_decl* f) const { return f->get_family_id() == fid(); } - bool is_special_relation(app* e) const { return is_special_relation(e->get_decl()); } + bool is_special_relation(expr* e) const { return is_app(e) && is_special_relation(to_app(e)->get_decl()); } sr_property get_property(func_decl* f) const; sr_property get_property(app* e) const { return get_property(e->get_decl()); } func_decl* get_relation(func_decl* f) const { SASSERT(is_special_relation(f)); return to_func_decl(f->get_parameter(0).get_ast()); } + func_decl* get_relation(expr* e) const { SASSERT(is_special_relation(e)); return to_func_decl(to_app(e)->get_parameter(0).get_ast()); } func_decl* mk_to_decl(func_decl* f) { return mk_rel_decl(f, OP_SPECIAL_RELATION_TO); } func_decl* mk_po_decl(func_decl* f) { return mk_rel_decl(f, OP_SPECIAL_RELATION_PO); } @@ -105,12 +112,14 @@ public: bool is_plo(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_PLO); } bool is_to(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_TO); } bool is_tc(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_TC); } + bool is_ac(expr const* e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_AC); } bool is_lo(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_LO); } bool is_po(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_PO); } bool is_plo(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_PLO); } bool is_to(func_decl const * e) const { 
return is_decl_of(e, fid(), OP_SPECIAL_RELATION_TO); } bool is_tc(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_TC); } + bool is_ac(func_decl const* e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_AC); } app * mk_lo (expr * arg1, expr * arg2) { return m.mk_app( fid(), OP_SPECIAL_RELATION_LO, arg1, arg2); } app * mk_po (expr * arg1, expr * arg2) { return m.mk_app( fid(), OP_SPECIAL_RELATION_PO, arg1, arg2); } diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 4a899ca9d..7caccded6 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -44,6 +44,7 @@ z3_add_component(sat_smt q_solver.cpp recfun_solver.cpp sat_th.cpp + specrel_solver.cpp tseitin_theory_checker.cpp user_solver.cpp COMPONENT_DEPENDENCIES diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index d21e7a12a..8e4bd765d 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -28,6 +28,7 @@ Author: #include "sat/smt/fpa_solver.h" #include "sat/smt/dt_solver.h" #include "sat/smt/recfun_solver.h" +#include "sat/smt/specrel_solver.h" namespace euf { @@ -130,6 +131,7 @@ namespace euf { arith_util arith(m); datatype_util dt(m); recfun::util rf(m); + special_relations_util sp(m); if (pb.get_family_id() == fid) ext = alloc(pb::solver, *this, fid); else if (bvu.get_family_id() == fid) @@ -144,6 +146,8 @@ namespace euf { ext = alloc(dt::solver, *this, fid); else if (rf.get_family_id() == fid) ext = alloc(recfun::solver, *this); + else if (sp.get_family_id() == fid) + ext = alloc(specrel::solver, *this, fid); if (ext) add_solver(ext); diff --git a/src/sat/smt/specrel_solver.cpp b/src/sat/smt/specrel_solver.cpp new file mode 100644 index 000000000..d59029e6b --- /dev/null +++ b/src/sat/smt/specrel_solver.cpp @@ -0,0 +1,120 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + specrel_solver.h + +Abstract: + + Theory plugin for special relations + +Author: + + Nikolaj Bjorner 
(nbjorner) 2020-09-08 + +--*/ + +#include "sat/smt/specrel_solver.h" +#include "sat/smt/euf_solver.h" +#include "ast/euf/euf_specrel_plugin.h" + +namespace euf { + class solver; +} + +namespace specrel { + + solver::solver(euf::solver& ctx, theory_id id) : + th_euf_solver(ctx, ctx.get_manager().get_family_name(id), id), + sp(m) + { + ctx.get_egraph().add_plugin(alloc(euf::specrel_plugin, ctx.get_egraph())); + } + + solver::~solver() { + } + + void solver::asserted(sat::literal l) { + + } + + sat::check_result solver::check() { + return sat::check_result::CR_DONE; + } + + std::ostream& solver::display(std::ostream& out) const { + return out; + } + + void solver::collect_statistics(statistics& st) const { + } + + euf::th_solver* solver::clone(euf::solver& ctx) { + return alloc(solver, ctx, get_id()); + } + + void solver::new_eq_eh(euf::th_eq const& eq) { + TRACE("specrel", tout << "new-eq\n"); + if (eq.is_eq()) { + auto* p = ctx.get_egraph().get_plugin(sp.get_family_id()); + p->merge_eh(var2enode(eq.v1()), var2enode(eq.v2())); + TRACE("specrel", tout << eq.v1() << " " << eq.v2() << "\n"); + } + } + + void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + } + + bool solver::add_dep(euf::enode* n, top_sort& dep) { + return false; + } + + bool solver::include_func_interp(func_decl* f) const { + return false; + } + + sat::literal solver::internalize(expr* e, bool sign, bool root) { + if (!visit_rec(m, e, sign, root)) + return sat::null_literal; + auto lit = ctx.expr2literal(e); + if (sign) + lit.neg(); + return lit; + } + + void solver::internalize(expr* e) { + visit_rec(m, e, false, false); + } + + bool solver::visit(expr* e) { + if (visited(e)) + return true; + m_stack.push_back(sat::eframe(e)); + return false; + } + + bool solver::visited(expr* e) { + euf::enode* n = expr2enode(e); + return n && n->is_attached_to(get_id()); + } + + bool solver::post_visit(expr* term, bool sign, bool root) { + euf::enode* n = expr2enode(term); + SASSERT(!n || 
!n->is_attached_to(get_id())); + if (!n) + n = mk_enode(term); + SASSERT(!n->is_attached_to(get_id())); + mk_var(n); + TRACE("specrel", tout << ctx.bpp(n) << "\n"); + return true; + } + + euf::theory_var solver::mk_var(euf::enode* n) { + if (is_attached_to_var(n)) + return n->get_th_var(get_id()); + euf::theory_var r = th_euf_solver::mk_var(n); + ctx.attach_th_var(n, this, r); + return r; + } +} diff --git a/src/sat/smt/specrel_solver.h b/src/sat/smt/specrel_solver.h new file mode 100644 index 000000000..9ebb76916 --- /dev/null +++ b/src/sat/smt/specrel_solver.h @@ -0,0 +1,75 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + specrel_solver.h + +Abstract: + + Theory plugin for special relations + +Author: + + Nikolaj Bjorner (nbjorner) 2020-09-08 + +--*/ +#pragma once + +#include "sat/smt/sat_th.h" +#include "ast/special_relations_decl_plugin.h" + +namespace euf { + class solver; +} + +namespace specrel { + + class solver : public euf::th_euf_solver { + typedef euf::theory_var theory_var; + typedef euf::theory_id theory_id; + typedef euf::enode enode; + typedef euf::enode_pair enode_pair; + typedef euf::enode_pair_vector enode_pair_vector; + typedef sat::bool_var bool_var; + typedef sat::literal literal; + typedef sat::literal_vector literal_vector; + + special_relations_util sp; + + public: + solver(euf::solver& ctx, theory_id id); + ~solver() override; + + bool is_external(bool_var v) override { return false; } + void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) override {} + void asserted(literal l) override; + sat::check_result check() override; + + std::ostream& display(std::ostream& out) const override; + std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { return euf::th_explain::from_index(idx).display(out); } + std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { return display_justification(out, 
idx); } + void collect_statistics(statistics& st) const override; + euf::th_solver* clone(euf::solver& ctx) override; + void new_eq_eh(euf::th_eq const& eq) override; + bool unit_propagate() override { return false; } + void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; + bool add_dep(euf::enode* n, top_sort& dep) override; + bool include_func_interp(func_decl* f) const override; + sat::literal internalize(expr* e, bool sign, bool root) override; + void internalize(expr* e) override; + bool visit(expr* e) override; + bool visited(expr* e) override; + bool post_visit(expr* e, bool sign, bool root) override; + + euf::theory_var mk_var(euf::enode* n) override; + void apply_sort_cnstr(euf::enode* n, sort* s) override {} + bool is_shared(theory_var v) const override { return false; } + lbool get_phase(bool_var v) override { return l_true; } + bool enable_self_propagate() const override { return true; } + + void merge_eh(theory_var, theory_var, theory_var v1, theory_var v2); + void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) {} + void unmerge_eh(theory_var v1, theory_var v2) {} + }; +} diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index df3010295..14b51f822 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -39,6 +39,8 @@ add_executable(test-z3 doc.cpp egraph.cpp escaped.cpp + euf_bv_plugin.cpp + euf_arith_plugin.cpp ex.cpp expr_rand.cpp expr_substitution.cpp diff --git a/src/test/euf_arith_plugin.cpp b/src/test/euf_arith_plugin.cpp new file mode 100644 index 000000000..41d629ad5 --- /dev/null +++ b/src/test/euf_arith_plugin.cpp @@ -0,0 +1,106 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +--*/ + +#include "util/util.h" +#include "util/timer.h" +#include "ast/euf/euf_egraph.h" +#include "ast/euf/euf_arith_plugin.h" +#include "ast/reg_decl_plugins.h" +#include "ast/ast_pp.h" +#include + +unsigned s_var = 0; + +static euf::enode* get_node(euf::egraph& g, arith_util& a, expr* 
e) { + auto* n = g.find(e); + if (n) + return n; + euf::enode_vector args; + for (expr* arg : *to_app(e)) + args.push_back(get_node(g, a, arg)); + n = g.mk(e, 0, args.size(), args.data()); + g.add_th_var(n, s_var++, a.get_family_id()); + return n; +} + +// +static void test1() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::arith_plugin, g)); + arith_util a(m); + sort_ref I(a.mk_int(), m); + + expr_ref x(m.mk_const("x", I), m); + expr_ref y(m.mk_const("y", I), m); + auto* nx = get_node(g, a, a.mk_add(a.mk_add(y, y), a.mk_add(x, x))); + auto* ny = get_node(g, a, a.mk_add(a.mk_add(y, x), x)); + TRACE("plugin", tout << "before merge\n" << g << "\n"); + g.merge(nx, ny, nullptr); + + TRACE("plugin", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("plugin", tout << "after propagate\n" << g << "\n"); + g.merge(get_node(g, a, a.mk_add(x, a.mk_add(y, y))), get_node(g, a, a.mk_add(y, x)), nullptr); + g.propagate(); + std::cout << g << "\n"; +} + +static void test2() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::arith_plugin, g)); + arith_util a(m); + sort_ref I(a.mk_int(), m); + + expr_ref x(m.mk_const("x", I), m); + expr_ref y(m.mk_const("y", I), m); + auto* nxy = get_node(g, a, a.mk_add(x, y)); + auto* nyx = get_node(g, a, a.mk_add(y, x)); + auto* nx = get_node(g, a, x); + auto* ny = get_node(g, a, y); + + TRACE("plugin", tout << "before merge\n" << g << "\n"); + g.merge(nxy, nx, nullptr); + g.merge(nyx, ny, nullptr); + TRACE("plugin", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("plugin", tout << "after propagate\n" << g << "\n"); + SASSERT(nx->get_root() == ny->get_root()); + g.merge(get_node(g, a, a.mk_add(x, a.mk_add(y, y))), get_node(g, a, a.mk_add(y, x)), nullptr); + g.propagate(); + std::cout << g << "\n"; +} + +static void test3() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::arith_plugin, g)); + 
arith_util a(m); + sort_ref I(a.mk_int(), m); + + expr_ref x(m.mk_const("x", I), m); + expr_ref y(m.mk_const("y", I), m); + auto* nxyy = get_node(g, a, a.mk_add(a.mk_add(x, y), y)); + auto* nyxx = get_node(g, a, a.mk_add(a.mk_add(y, x), x)); + auto* nx = get_node(g, a, x); + auto* ny = get_node(g, a, y); + g.merge(nxyy, nx, nullptr); + g.merge(nyxx, ny, nullptr); + TRACE("plugin", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("plugin", tout << "after propagate\n" << g << "\n"); + std::cout << g << "\n"; +} + +void tst_euf_arith_plugin() { + enable_trace("plugin"); + test1(); + test2(); + test3(); +} diff --git a/src/test/euf_bv_plugin.cpp b/src/test/euf_bv_plugin.cpp new file mode 100644 index 000000000..501bd7b14 --- /dev/null +++ b/src/test/euf_bv_plugin.cpp @@ -0,0 +1,183 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +--*/ + +#include "util/util.h" +#include "util/timer.h" +#include "ast/euf/euf_egraph.h" +#include "ast/euf/euf_bv_plugin.h" +#include "ast/reg_decl_plugins.h" +#include "ast/ast_pp.h" +#include + +static unsigned s_var = 0; +static euf::enode* get_node(euf::egraph& g, bv_util& b, expr* e) { + auto* n = g.find(e); + if (n) + return n; + euf::enode_vector args; + for (expr* arg : *to_app(e)) + args.push_back(get_node(g, b, arg)); + n = g.mk(e, 0, args.size(), args.data()); + g.add_th_var(n, s_var++, b.get_family_id()); + return n; +} + +// align slices, and propagate extensionality +static void test1() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::bv_plugin, g)); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref y(m.mk_const("y", u32), m); + expr_ref x3(bv.mk_extract(31, 16, x), m); + expr_ref x2(bv.mk_extract(15, 8, x), m); + expr_ref x1(bv.mk_extract(7, 0, x), m); + expr_ref y3(bv.mk_extract(31, 24, y), m); + expr_ref y2(bv.mk_extract(23, 8, y), m); + expr_ref y1(bv.mk_extract(7, 0, y), m); + expr_ref xx(bv.mk_concat(x1, 
bv.mk_concat(x2, x3)), m); + expr_ref yy(bv.mk_concat(y1, bv.mk_concat(y2, y3)), m); + auto* nx = get_node(g, bv, xx); + auto* ny = get_node(g, bv, yy); + TRACE("bv", tout << "before merge\n" << g << "\n"); + g.merge(nx, ny, nullptr); + TRACE("bv", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("bv", tout << "after propagate\n" << g << "\n"); + std::cout << g << "\n"; + SASSERT(nx->get_root() == ny->get_root()); +} + +// propagate values down +static void test2() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::bv_plugin, g)); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref x3(bv.mk_extract(31, 16, x), m); + expr_ref x2(bv.mk_extract(15, 8, x), m); + expr_ref x1(bv.mk_extract(7, 0, x), m); + expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m); + g.merge(get_node(g, bv, xx), get_node(g, bv, bv.mk_numeral((1 << 27) + (1 << 17) + (1 << 3), 32)), nullptr); + g.propagate(); + SASSERT(get_node(g, bv, x1)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x2)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x3)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x)->get_root()->interpreted()); +} + + +// propagate values up +static void test3() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::bv_plugin, g)); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref x3(bv.mk_extract(31, 16, x), m); + expr_ref x2(bv.mk_extract(15, 8, x), m); + expr_ref x1(bv.mk_extract(7, 0, x), m); + expr_ref xx(bv.mk_concat(bv.mk_concat(x1, x2), x3), m); + expr_ref y(m.mk_const("y", u32), m); + g.merge(get_node(g, bv, xx), get_node(g, bv, y), nullptr); + g.merge(get_node(g, bv, x1), get_node(g, bv, bv.mk_numeral(2, 8)), nullptr); + g.merge(get_node(g, bv, x2), get_node(g, bv, bv.mk_numeral(8, 8)), nullptr); + g.propagate(); + SASSERT(get_node(g, bv, bv.mk_concat(x1, 
x2))->get_root()->interpreted()); + SASSERT(get_node(g, bv, x1)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x2)->get_root()->interpreted()); +} + +// propagate extract up +static void test4() { + // concat(a, x[J]), a = x[I] => x[IJ] = concat(x[I],x[J]) + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::bv_plugin, g)); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + sort_ref u8(bv.mk_sort(8), m); + sort_ref u16(bv.mk_sort(16), m); + expr_ref a(m.mk_const("a", u8), m); + expr_ref x(m.mk_const("x", u32), m); + expr_ref y(m.mk_const("y", u16), m); + expr_ref x1(bv.mk_extract(15, 8, x), m); + expr_ref x2(bv.mk_extract(23, 16, x), m); + g.merge(get_node(g, bv, bv.mk_concat(a, x2)), get_node(g, bv, y), nullptr); + g.merge(get_node(g, bv, x1), get_node(g, bv, a), nullptr); + g.propagate(); + TRACE("bv", tout << g << "\n"); + SASSERT(get_node(g, bv, bv.mk_extract(23, 8, x))->get_root() == get_node(g, bv, y)->get_root()); +} + +// iterative slicing +static void test5() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::bv_plugin, g)); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref x1(bv.mk_extract(31, 4, x), m); + expr_ref x2(bv.mk_extract(27, 0, x), m); + auto* nx = get_node(g, bv, x1); + auto* ny = get_node(g, bv, x2); + TRACE("bv", tout << "before merge\n" << g << "\n"); + g.merge(nx, ny, nullptr); + TRACE("bv", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("bv", tout << "after propagate\n" << g << "\n"); + std::cout << g << "\n"; +} + +// iterative slicing +static void test6() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::bv_plugin, g)); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref x1(bv.mk_extract(31, 3, x), m); + expr_ref x2(bv.mk_extract(28, 0, x), m); + auto* nx = get_node(g, bv, x1); + auto* ny = 
get_node(g, bv, x2); + TRACE("bv", tout << "before merge\n" << g << "\n"); + g.merge(nx, ny, nullptr); + TRACE("bv", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("bv", tout << "after propagate\n" << g << "\n"); + std::cout << g << "\n"; +} + + +void tst_euf_bv_plugin() { + enable_trace("bv"); + enable_trace("plugin"); + test6(); + return; + test1(); + test2(); + test3(); + test4(); + test5(); + test6(); +} diff --git a/src/test/main.cpp b/src/test/main.cpp index 7cd4b6cf9..3f073abf2 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -265,4 +265,6 @@ int main(int argc, char ** argv) { TST(finder); TST(totalizer); TST(distribution); + TST(euf_bv_plugin); + TST(euf_arith_plugin); } diff --git a/src/util/dependency.h b/src/util/dependency.h index 57057460c..a76d43f88 100644 --- a/src/util/dependency.h +++ b/src/util/dependency.h @@ -44,8 +44,39 @@ public: public: unsigned get_ref_count() const { return m_ref_count; } bool is_leaf() const { return m_leaf == 1; } + value const& leaf_value() const { SASSERT(is_leaf()); return static_cast(this)->m_value; } }; + static void linearize_todo(ptr_vector& todo, vector& vs) { + unsigned qhead = 0; + while (qhead < todo.size()) { + dependency* d = todo[qhead]; + qhead++; + if (d->is_leaf()) { + vs.push_back(to_leaf(d)->m_value); + } + else { + for (unsigned i = 0; i < 2; i++) { + dependency* child = to_join(d)->m_children[i]; + if (!child->is_marked()) { + todo.push_back(child); + child->mark(); + } + } + } + } + for (auto* d : todo) + d->unmark(); + } + + static void s_linearize(dependency* d, vector& vs) { + if (!d) + return; + ptr_vector todo; + todo.push_back(d); + linearize_todo(todo, vs); + } + private: struct join : public dependency { dependency * m_children[2]; @@ -69,7 +100,7 @@ private: value_manager & m_vmanager; allocator & m_allocator; - mutable ptr_vector m_todo; + ptr_vector m_todo; void inc_ref(value const & v) { if (C::ref_count) @@ -83,6 +114,7 @@ private: void del(dependency * d) 
{ SASSERT(d); + SASSERT(m_todo.empty()); m_todo.push_back(d); while (!m_todo.empty()) { d = m_todo.back(); @@ -106,8 +138,8 @@ private: } } - void unmark_todo() const { - for (auto* d : m_todo) + void unmark_todo() { + for (auto* d : m_todo) d->unmark(); m_todo.reset(); } @@ -190,30 +222,30 @@ public: return false; } - void linearize(dependency * d, vector & vs) const { - if (d) { - m_todo.reset(); - d->mark(); - m_todo.push_back(d); - unsigned qhead = 0; - while (qhead < m_todo.size()) { - d = m_todo[qhead]; - qhead++; - if (d->is_leaf()) { - vs.push_back(to_leaf(d)->m_value); - } - else { - for (unsigned i = 0; i < 2; i++) { - dependency * child = to_join(d)->m_children[i]; - if (!child->is_marked()) { - m_todo.push_back(child); - child->mark(); - } - } - } + + + void linearize(dependency * d, vector & vs) { + if (!d) + return; + SASSERT(m_todo.empty()); + d->mark(); + m_todo.push_back(d); + linearize_todo(m_todo, vs); + m_todo.reset(); + } + + void linearize(ptr_vector& deps, vector & vs) { + if (deps.empty()) + return; + SASSERT(m_todo.empty()); + for (auto* d : deps) { + if (d && !d->is_marked()) { + d->mark(); + m_todo.push_back(d); } - unmark_todo(); } + linearize_todo(m_todo, vs); + m_todo.reset(); } }; @@ -297,7 +329,16 @@ public: return m_dep_manager.contains(d, v); } - void linearize(dependency * d, vector & vs) const { + void linearize(dependency * d, vector & vs) { + return m_dep_manager.linearize(d, vs); + } + + static vector const& s_linearize(dependency* d, vector& vs) { + dep_manager::s_linearize(d, vs); + return vs; + } + + void linearize(ptr_vector& d, vector & vs) { return m_dep_manager.linearize(d, vs); } @@ -320,4 +361,83 @@ typedef scoped_dependency_manager::dependency v_dependency; typedef scoped_dependency_manager u_dependency_manager; typedef scoped_dependency_manager::dependency u_dependency; +/** + \brief Version of the scoped-depenendcy-manager where region scopes are handled externally. 
+*/ +template +class stacked_dependency_manager { + class config { + public: + static const bool ref_count = true; + + typedef Value value; + + class value_manager { + public: + void inc_ref(value const& v) { + } + + void dec_ref(value const& v) { + } + }; + + class allocator { + region& m_region; + public: + allocator(region& r) : m_region(r) {} + + void* allocate(size_t sz) { + return m_region.allocate(sz); + } + + void deallocate(size_t sz, void* mem) { + } + }; + }; + + typedef dependency_manager dep_manager; +public: + typedef typename dep_manager::dependency dependency; + typedef Value value; + +private: + typename config::value_manager m_vmanager; + typename config::allocator m_allocator; + dep_manager m_dep_manager; + +public: + stacked_dependency_manager(region& r) : + m_allocator(r), + m_dep_manager(m_vmanager, m_allocator) { + } + + dependency* mk_empty() { + return m_dep_manager.mk_empty(); + } + + dependency* mk_leaf(value const& v) { + return m_dep_manager.mk_leaf(v); + } + + dependency* mk_join(dependency* d1, dependency* d2) { + return m_dep_manager.mk_join(d1, d2); + } + + bool contains(dependency* d, value const& v) { + return m_dep_manager.contains(d, v); + } + + void linearize(dependency* d, vector& vs) { + return m_dep_manager.linearize(d, vs); + } + + static vector const& s_linearize(dependency* d, vector& vs) { + dep_manager::s_linearize(d, vs); + return vs; + } + + void linearize(ptr_vector& d, vector& vs) { + return m_dep_manager.linearize(d, vs); + } +}; \ No newline at end of file