From 65a8c162f58331d6d28e30daaa15387318f6668d Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 12 Nov 2023 15:39:45 -0800 Subject: [PATCH] add E(T) functionality for bv and ac functions Add an option to register EUF modulo theories, The current theory with a unit test is BV. The arithmetic theory plugs into an AC completion. It is partially finished, pending setting up testing and implementing handling of shared terms. --- src/ast/euf/CMakeLists.txt | 6 +- src/ast/euf/ac_plugin.h | 51 +++ src/ast/euf/euf_ac_plugin.cpp | 591 +++++++++++++++++++++++++++++++ src/ast/euf/euf_ac_plugin.h | 174 +++++++++ src/ast/euf/euf_arith_plugin.cpp | 77 ++++ src/ast/euf/euf_arith_plugin.h | 51 +++ src/ast/euf/euf_bv_plugin.cpp | 347 ++++++++++++++++++ src/ast/euf/euf_bv_plugin.h | 100 ++++++ src/ast/euf/euf_egraph.cpp | 77 +++- src/ast/euf/euf_egraph.h | 26 +- src/ast/euf/euf_enode.h | 1 + src/ast/euf/euf_justification.h | 48 ++- src/ast/euf/euf_plugin.cpp | 47 +++ src/ast/euf/euf_plugin.h | 58 +++ src/sat/smt/euf_solver.cpp | 2 - src/test/CMakeLists.txt | 1 + src/test/euf_bv_plugin.cpp | 180 ++++++++++ src/test/main.cpp | 1 + src/util/dependency.h | 13 + 19 files changed, 1830 insertions(+), 21 deletions(-) create mode 100644 src/ast/euf/ac_plugin.h create mode 100644 src/ast/euf/euf_ac_plugin.cpp create mode 100644 src/ast/euf/euf_ac_plugin.h create mode 100644 src/ast/euf/euf_arith_plugin.cpp create mode 100644 src/ast/euf/euf_arith_plugin.h create mode 100644 src/ast/euf/euf_bv_plugin.cpp create mode 100644 src/ast/euf/euf_bv_plugin.h create mode 100644 src/ast/euf/euf_plugin.cpp create mode 100644 src/ast/euf/euf_plugin.h create mode 100644 src/test/euf_bv_plugin.cpp diff --git a/src/ast/euf/CMakeLists.txt b/src/ast/euf/CMakeLists.txt index 8d3fa2e74..430ea2b08 100644 --- a/src/ast/euf/CMakeLists.txt +++ b/src/ast/euf/CMakeLists.txt @@ -1,8 +1,12 @@ z3_add_component(euf SOURCES + euf_ac_plugin.cpp + euf_arith_plugin.cpp + euf_bv_plugin.cpp + euf_egraph.cpp 
euf_enode.cpp euf_etable.cpp - euf_egraph.cpp + euf_plugin.cpp COMPONENT_DEPENDENCIES ast util diff --git a/src/ast/euf/ac_plugin.h b/src/ast/euf/ac_plugin.h new file mode 100644 index 000000000..7ce94e035 --- /dev/null +++ b/src/ast/euf/ac_plugin.h @@ -0,0 +1,51 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_ac_plugin.h + +Abstract: + + plugin structure for ac functions +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + Jakob Rath 2023-11-11 + + +--*/ + +#pragma once + +#include "ast/euf/euf_plugin.h" + +namespace euf { + + + class ac_plugin : public plugin { + struct eq { + enode_vector l, r; + }; + vector m_eqs; + vector m_use; + unsigned m_fid; + unsigned m_op; + + void push_eq(enode* l, enode* r); + public: + + ac_plugin(egraph& g, unsigned fid, unsigned op); + + unsigned get_id() const override { return m_fid; } + + void register_node(enode* n) override; + + void merge_eh(enode* n1, enode* n2, justification j) override; + + void diseq_eh(enode* n1, enode* n2) override; + + void undo() override; + + std::ostream& display(std::ostream& out) const override; + }; diff --git a/src/ast/euf/euf_ac_plugin.cpp b/src/ast/euf/euf_ac_plugin.cpp new file mode 100644 index 000000000..70283d6b4 --- /dev/null +++ b/src/ast/euf/euf_ac_plugin.cpp @@ -0,0 +1,591 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_ac_plugin.cpp + +Abstract: + + plugin structure for ac functions + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +Completion modulo AC + + E set of eqs + pick critical pair xy = z by j1 xu = v by j2 in E + Add new equation zu = xyu = vy by j1, j2 + + + Notes: + - Some equalities come from shared terms, so do not. 
+ + +More notes: + + Justifications for new equations are joined (requires extension to egraph/justification) + + Process new merges so use list is updated + Justifications for processed merges are recorded + + Updated equations are recorded for restoration on backtracking + + Keep track of foreign / shared occurrences of AC functions. + - use register_shared to accumulate shared occurrences. + + Shared occurrences are rewritten modulo completion. + When equal to a different shared occurrence, propagate equality. + +--*/ + +#pragma once + +#include "ast/euf/euf_ac_plugin.h" +#include "ast/euf/euf_egraph.h" + +namespace euf { + + ac_plugin::ac_plugin(egraph& g, unsigned fid, unsigned op): + plugin(g), m_fid(fid), m_op(op) + {} + + void ac_plugin::register_node(enode* n) { + + } + + void ac_plugin::register_shared(enode* n) { + auto m = to_monomial(n); + auto const& ns = monomial(m); + for (auto arg : ns) { + arg->shared.push_back(m); + m_node_trail.push_back(arg); + push_undo(is_add_shared); + } + m_shared_trail.push_back(m); + push_undo(is_register_shared); + } + + void ac_plugin::undo() { + auto k = m_undo.back(); + m_undo.pop_back(); + switch (k) { + case is_add_eq: { + auto const& eq = m_eqs.back(); + for (auto* n : monomial(eq.l)) + n->lhs.pop_back(); + for (auto* n : monomial(eq.r)) + n->rhs.pop_back(); + m_eqs.pop_back(); + break; + } + case is_add_node: { + auto* n = m_node_trail.back(); + m_node_trail.pop_back(); + m_nodes[n->n->get_id()] = nullptr; + n->~node(); + break; + } + case is_add_monomial: { + m_monomials.pop_back(); + m_monomial_enodes.pop_back(); + break; + } + case is_merge_node: { + auto [other, old_shared, old_lhs, old_rhs] = m_merge_trail.back(); + auto* root = other->root; + std::swap(other->next, root->next); + root->shared.shrink(old_shared); + root->lhs.shrink(old_lhs); + root->rhs.shrink(old_rhs); + m_merge_trail.pop_back(); + break; + } + case is_update_eq: { + auto const & [idx, eq] = m_update_eq_trail.back(); + m_eqs[idx] = eq; + 
m_update_eq_trail.pop_back(); + break; + } + case is_add_shared: { + auto n = m_node_trail.back(); + m_node_trail.pop_back(); + n->shared.pop_back(); + break; + } + case is_register_shared: { + m_shared_trail.pop_back(); + break; + } + case is_join_justification: { + m_dep_manager.pop_scope(1); + break; + } + default: + UNREACHABLE(); + } + } + + std::ostream& ac_plugin::display(std::ostream& out) const { // print equations, then monomials, then the use lists of every node + unsigned i = 0; + for (auto const& eq : m_eqs) { + out << i << ": " << eq.l << " == " << eq.r << ": "; + for (auto n : monomial(eq.l)) + out << g.bpp(n->n) << " "; + out << "== "; + for (auto n : monomial(eq.r)) + out << g.bpp(n->n) << " "; + out << "\n"; + ++i; + } + i = 0; + for (auto const& m : m_monomials) { + out << i << ": "; + for (auto n : m) + out << g.bpp(n->n) << " "; + out << "\n"; + ++i; + } + for (auto n : m_nodes) { if (!n) continue; // m_nodes is indexed by enode id; slots are null when unused or cleared by undo() + out << g.bpp(n->n) << " r: " << n->root_id() << "\n"; + out << "lhs "; + for (auto l : n->lhs) + out << l << " "; + out << "rhs "; + for (auto r : n->rhs) + out << r << " "; + out << "shared "; + for (auto s : n->shared) + out << s << " "; + out << "\n"; + } + return out; + } + + void ac_plugin::merge_eh(enode* l, enode* r, justification j) { + if (l == r) + return; + if (!is_op(l) && !is_op(r)) + merge(mk_node(l), mk_node(r), j); + else + init_equation({ to_monomial(l), to_monomial(r), false, j }); + } + + void ac_plugin::init_equation(eq const& e) { + m_eqs.push_back(e); + auto& eq = m_eqs.back(); + if (orient_equation(eq)) { + push_undo(is_add_eq); + unsigned eq_id = m_eqs.size() - 1; + for (auto n : monomial(eq.l)) + n->lhs.push_back(eq_id); + for (auto n : monomial(eq.r)) + n->rhs.push_back(eq_id); + } + else + m_eqs.pop_back(); + } + + bool ac_plugin::orient_equation(eq& e) { + auto& ml = monomial(e.l); + auto& mr = monomial(e.r); + if (ml.size() > mr.size()) + return true; + if (ml.size() < mr.size()) { + std::swap(e.l, e.r); + return true; + } + else { + std::sort(ml.begin(), ml.end(), [&](node* a, node* b) { return a->root_id() <
b->root_id(); }); + std::sort(mr.begin(), mr.end(), [&](node* a, node* b) { return a->root_id() < b->root_id(); }); + for (unsigned i = ml.size(); i-- > 0;) { + if (ml[i] == mr[i]) + continue; + if (ml[i]->root_id() < mr[i]->root_id()) + std::swap(e.l, e.r); + return true; + } + return false; + } + } + + void ac_plugin::merge(node* root, node* other, justification j) { + for (auto n : equiv(other)) + n->root = root; + m_merge_trail.push_back({ other, root->shared.size(), root->lhs.size(), root->rhs.size()}); + for (auto eq_id : other->lhs) + set_processed(eq_id, false); + for (auto eq_id : other->rhs) + set_processed(eq_id, false); + root->shared.append(other->shared); + root->lhs.append(other->lhs); + root->rhs.append(other->rhs); + std::swap(root->next, other->next); + push_undo(is_merge_node); + } + + void ac_plugin::push_undo(undo_kind k) { + m_undo.push_back(k); + push_plugin_undo(get_id()); + m_undo_notify(); // tell main plugin to dispatch undo to this module. + } + + unsigned ac_plugin::to_monomial(enode* n) { // flatten nested applications of the AC operator under n into a monomial of leaf nodes + enode_vector& ns = m_todo; + ns.reset(); + ptr_vector ms; + ns.push_back(n); + for (unsigned i = 0; i < ns.size(); ++i) { + enode* k = ns[i]; // keep n intact: it is recorded below as the enode of the monomial + if (is_op(k)) { + ns.append(k->num_args(), k->args()); + ns[i] = ns.back(); + ns.pop_back(); + --i; + } + else { + ms.push_back(mk_node(k)); + } + } + return to_monomial(n, ms); + } + + unsigned ac_plugin::to_monomial(enode* e, ptr_vector const& ms) { + unsigned id = m_monomials.size(); + m_monomials.push_back(ms); + m_monomial_enodes.push_back(e); + push_undo(is_add_monomial); + return id; + } + + ac_plugin::node* ac_plugin::node::mk(region& r, enode* n) { + auto* mem = r.allocate(sizeof(node)); + node* res = new (mem) node(); + res->n = n; + res->root = res; + res->next = res; + return res; + } + + ac_plugin::node* ac_plugin::mk_node(enode* n) { + unsigned id = n->get_id(); + if (m_nodes.size() > id && m_nodes[id]) + return m_nodes[id]; + auto* r = node::mk(get_region(), n); + push_undo(is_add_node); +
m_nodes.set(id, r); + m_node_trail.push_back(r); + return r; + } + + void ac_plugin::propagate() { + while (true) { + unsigned eq_id = pick_next_eq(); + if (eq_id == UINT_MAX) + break; + + // simplify eq using processed + for (auto other_eq : backward_iterator(eq_id)) + if (is_processed(other_eq)) + backward_simplify(eq_id, other_eq); + if (m_backward_simplified) + continue; + + // simplify processed using eq + for (auto other_eq : forward_iterator(eq_id)) + if (is_processed(other_eq)) + forward_simplify(other_eq, eq_id); + + // superpose, create new equations + for (auto other_eq : superpose_iterator(eq_id)) + if (is_processed(other_eq)) + superpose(eq_id, other_eq); + + // simplify to_simplify using eq + for (auto other_eq : forward_iterator(eq_id)) + if (is_to_simplify(other_eq)) + forward_simplify(other_eq, eq_id); + + set_processed(eq_id, true); + } + propagate_shared(); + } + + unsigned ac_plugin::pick_next_eq() { + for (unsigned i = 0, n = m_eqs.size(); i < n; ++i) { + unsigned id = (i + m_next_eq_index) % n; + auto const& eq = m_eqs[id]; + if (eq.is_processed) + continue; + ++m_next_eq_index; + return id; + } + return UINT_MAX; + } + + void ac_plugin::set_processed(unsigned id, bool f) { + auto& eq = m_eqs[id]; + if (eq.is_processed == f) + return; + m_update_eq_trail.push_back({ id, eq }); + eq.is_processed = f; + push_undo(is_update_eq); + } + + // + // superpose iterator enumerates all equations where lhs of eq have element in common. + // + unsigned_vector const& ac_plugin::superpose_iterator(unsigned eq_id) { + auto const& eq = m_eqs[eq_id]; + m_src_r.reset(); + m_src_r.append(monomial(eq.r)); + init_ids_counts(eq_id, eq.l, m_src_ids, m_src_count); + init_overlap_iterator(eq_id, eq.l); + return m_lhs_eqs; + } + + // + // backward iterator allows simplification of eq + // The rhs of eq is a super-set of lhs of other eq. 
+ // + unsigned_vector const& ac_plugin::backward_iterator(unsigned eq_id) { + auto const& eq = m_eqs[eq_id]; + init_ids_counts(eq_id, eq.r, m_dst_ids, m_dst_count); + init_overlap_iterator(eq_id, eq.r); + m_backward_simplified = false; + return m_lhs_eqs; + } + + void ac_plugin::init_overlap_iterator(unsigned eq_id, unsigned monomial_id) { + m_lhs_eqs.reset(); + for (auto n : monomial(monomial_id)) + m_lhs_eqs.append(n->root->lhs); + + // prune m_lhs_eqs to single occurrences + unsigned j = 0; + for (unsigned i = 0; i < m_lhs_eqs.size(); ++i) { + unsigned id = m_lhs_eqs[i]; + m_eq_seen.reserve(id + 1, false); + if (m_eq_seen[id]) + continue; + if (id == eq_id) + continue; + m_lhs_eqs[j++] = id; + m_eq_seen[id] = true; + } + m_lhs_eqs.shrink(j); + for (auto id : m_lhs_eqs) + m_eq_seen[id] = false; + } + + // + // forward iterator simplifies other eqs where their rhs is a superset of lhs of eq + // + unsigned_vector const& ac_plugin::forward_iterator(unsigned eq_id) { + auto& eq = m_eqs[eq_id]; + m_src_r.reset(); + m_src_r.append(monomial(eq.r)); + init_ids_counts(eq_id, eq.l, m_src_ids, m_src_count); + unsigned min_r = UINT_MAX; + node* min_n = nullptr; + for (auto n : monomial(eq.l)) + if (n->root->rhs.size() < min_r) + min_n = n, min_r = n->root->rhs.size(); + // found node that occurs in fewest rhs + VERIFY(min_n); + return min_n->rhs; + } + + void ac_plugin::init_ids_counts(unsigned eq_id, unsigned monomial_id, unsigned_vector& ids, unsigned_vector& counts) { + auto& eq = m_eqs[eq_id]; + reset_ids_counts(ids, counts); + for (auto n : monomial(monomial_id)) { + unsigned id = n->root_id(); + counts.setx(id, counts.get(id, 0) + 1, 0); + ids.push_back(id); + } + } + + void ac_plugin::reset_ids_counts(unsigned_vector& ids, unsigned_vector& counts) { + for (auto id : ids) + counts[id] = 0; + ids.reset(); + } + + void ac_plugin::forward_simplify(unsigned dst_eq, unsigned src_eq) { + + if (src_eq == dst_eq) + return; + + // check that left src.l is a subset of dst.r + 
// dst = A -> BC + // src = B -> D + // post(dst) := A -> CD + auto& src = m_eqs[src_eq]; + auto& dst = m_eqs[dst_eq]; + + reset_ids_counts(m_dst_ids, m_dst_count); + + unsigned src_l_size = monomial(src.l).size(); + unsigned src_r_size = m_src_r.size(); + + // subtract src.l from dst.r if src.l is a subset of dst.r + // new_rhs := old_rhs - src_lhs + src_rhs + unsigned num_overlap = 0; + for (auto n : monomial(dst.r)) { + unsigned id = n->root_id(); + unsigned count = m_src_count.get(id, 0); + if (count == 0) + m_src_r.push_back(n); + else { + unsigned dst_count = m_dst_count.get(id, 0); + if (dst_count >= count) + m_src_r.push_back(n); + else + m_dst_count.set(id, dst_count + 1), m_dst_ids.push_back(id), ++num_overlap; + } + } + // The dst.r has to be a superset of src.l, otherwise simplification does not apply + if (num_overlap == src_l_size) { + auto new_r = to_monomial(nullptr, m_src_r); + m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] }); + m_eqs[dst_eq].r = new_r; + m_eqs[dst_eq].j = justify_rewrite(src_eq, dst_eq); + push_undo(is_update_eq); + } + m_src_r.shrink(src_r_size); + } + + void ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) { + if (src_eq == dst_eq) + return; + + auto& src = m_eqs[src_eq]; + auto& dst = m_eqs[dst_eq]; + // + // dst_ids, dst_count contain rhs of dst_eq + // + + // check that src.l is a subset of dst.r + reset_ids_counts(m_src_ids, m_src_count); + + bool is_subset = true; + for (auto n : monomial(src.l)) { + unsigned id = n->root_id(); + unsigned dst_count = m_dst_count.get(id, 0); + if (dst_count == 0) { + is_subset = false; + break; + } + else { + unsigned src_count = m_src_count.get(id, 0); + if (src_count >= dst_count) { + is_subset = false; + break; + } + else + m_src_count.set(id, src_count + 1), m_src_ids.push_back(id); + } + } + + if (is_subset) { + // dst_rhs := dst_rhs - src_lhs + src_rhs + m_src_r.reset(); + m_src_r.append(monomial(src.r)); + // add to m_src_r elements of dst.r that are not in src.l + 
for (auto n : monomial(dst.r)) { + unsigned id = n->root_id(); + unsigned count = m_src_count.get(id, 0); + if (count == 0) + m_src_r.push_back(n); + else + --m_src_count[id]; + } + auto new_r = to_monomial(nullptr, m_src_r); + m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] }); + m_eqs[dst_eq].r = new_r; + m_eqs[dst_eq].j = justify_rewrite(src_eq, dst_eq); + push_undo(is_update_eq); + m_backward_simplified = true; + } + } + + void ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) { + if (src_eq == dst_eq) + return; + auto& src = m_eqs[src_eq]; + auto& dst = m_eqs[dst_eq]; + + // AB -> C, AD -> E => BE ~ CD + // m_src_ids, m_src_counts contains information about src (call it AD -> E) + reset_ids_counts(m_dst_ids, m_dst_count); + + m_dst_r.reset(); + m_dst_r.append(monomial(dst.r)); + unsigned src_r_size = m_src_r.size(); + SASSERT(src_r_size == monomial(src.r).size()); + // dst_r contains C + // src_r contains E + + // compute BE, initialize dst_ids, dst_counts + for (auto n : monomial(dst.l)) { + unsigned id = n->root_id(); + unsigned src_count = m_src_count.get(id, 0); + unsigned dst_count = m_dst_count.get(id, 0); + m_dst_count.set(id, dst_count + 1); + m_dst_ids.push_back(id); + if (src_count < dst_count) + m_src_r.push_back(n); + } + // compute CD + for (auto n : monomial(src.l)) { + unsigned id = n->root_id(); + unsigned dst_count = m_dst_count.get(id, 0); + if (dst_count > 0) + --m_dst_count[id]; + else + m_dst_r.push_back(n); + } + + justification j = justify_rewrite(src_eq, dst_eq); + if (m_src_r.size() == 1 && m_dst_r.size() == 1) + push_merge(m_src_r[0]->n, m_dst_r[0]->n, j); + else + init_equation({ to_monomial(nullptr, m_src_r), to_monomial(nullptr, m_dst_r), false, j }); + + m_src_r.shrink(src_r_size); + } + + + void ac_plugin::propagate_shared() { + for (auto m : m_shared_trail) + simplify_shared(m); + // check for collisions, push_merge when there is a collision. 
+ } + + void ac_plugin::simplify_shared(unsigned monomial_id) { + // apply processed as a set of rewrites + } + + justification ac_plugin::justify_rewrite(unsigned eq1, unsigned eq2) { + auto const& e1 = m_eqs[eq1]; + auto const& e2 = m_eqs[eq2]; + auto* j = m_dep_manager.mk_join(m_dep_manager.mk_leaf(e1.j), m_dep_manager.mk_leaf(e2.j)); + j = justify_monomial(j, monomial(e1.l)); + j = justify_monomial(j, monomial(e1.r)); + j = justify_monomial(j, monomial(e2.l)); + j = justify_monomial(j, monomial(e2.r)); + m_dep_manager.push_scope(); + push_undo(is_join_justification); + return justification::dependent(j); + } + + justification::dependency* ac_plugin::justify_monomial(justification::dependency* j, ptr_vector const& m) { + for (auto n : m) + if (n->root->n != n->n) + j = m_dep_manager.mk_join(j, m_dep_manager.mk_leaf(justification::equality(n->root->n, n->n))); + return j; + } +} diff --git a/src/ast/euf/euf_ac_plugin.h b/src/ast/euf/euf_ac_plugin.h new file mode 100644 index 000000000..5594f5e9d --- /dev/null +++ b/src/ast/euf/euf_ac_plugin.h @@ -0,0 +1,174 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_ac_plugin.h + +Abstract: + + plugin structure for ac functions + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +ex: +xyz -> xy, then xyzz -> xy by repeated rewriting + +monomials = [0 |-> xyz, 1 |-> xy, 2 |-> xyzz] +parents(x) = [0, 1, 2] +parents(z) = [0, 2] +for p in parents(xyzz): + p != xyzz + p' := simplify_using(xyzz, p) + if p != p': + repeat reduction using p := p' + +--*/ + +#pragma once + +#include "ast/euf/euf_plugin.h" + +namespace euf { + + class ac_plugin : public plugin { + + // enode structure for AC equivalences + struct node { + enode* n; // associated enode + node* root; // path compressed root + node* next; // next in equivalence class + justification j; // justification for equality + node* target = nullptr; // justified next + unsigned_vector shared; // shared occurrences + unsigned_vector lhs; // left hand
side of equalities + unsigned_vector rhs; // right hand side of equalities + + unsigned root_id() const { return root->n->get_id(); } + ~node() {} + static node* mk(region& r, enode* n); + }; + + class equiv { + node& n; + public: + class iterator { + node* m_first; + node* m_last; + public: + iterator(node* n, node* m) : m_first(n), m_last(m) {} + node* operator*() { return m_first; } + iterator& operator++() { if (!m_last) m_last = m_first; m_first = m_first->next; return *this; } + iterator operator++(int) { iterator tmp = *this; ++*this; return tmp; } + bool operator==(iterator const& other) const { return m_last == other.m_last && m_first == other.m_first; } + bool operator!=(iterator const& other) const { return !(*this == other); } + }; + equiv(node& _n) :n(_n) {} + equiv(node* _n) :n(*_n) {} + iterator begin() const { return iterator(&n, nullptr); } + iterator end() const { return iterator(&n, &n); } + }; + + struct eq { + unsigned l, r; // refer to monomials + bool is_processed = false; + justification j; + }; + + unsigned m_fid; + unsigned m_op; + vector m_eqs; + ptr_vector m_nodes; + vector> m_monomials; + enode_vector m_monomial_enodes; + justification::dependency_manager m_dep_manager; + + // backtrackable state + enum undo_kind { + is_add_eq, + is_add_monomial, + is_add_node, + is_merge_node, + is_update_eq, + is_add_shared, + is_register_shared, + is_join_justification + }; + svector m_undo; + ptr_vector m_node_trail; + unsigned_vector m_monomial_trail, m_shared_trail; + svector> m_merge_trail; + svector> m_update_eq_trail; + + node* mk_node(enode* n); + void merge(node* r1, node* r2, justification j); + + bool is_op(enode* n) const { auto d = n->get_decl(); return d && m_fid == d->get_family_id() && m_op == d->get_kind(); } + + std::function m_undo_notify; + void push_undo(undo_kind k); + enode_vector m_todo; + unsigned to_monomial(enode* n); + unsigned to_monomial(enode* n, ptr_vector const& ms); + ptr_vector const& monomial(unsigned i) const { return
m_monomials[i]; } + ptr_vector& monomial(unsigned i) { return m_monomials[i]; } + + void init_equation(eq const& e); + bool orient_equation(eq & e); + void set_processed(unsigned eq_id, bool f); + unsigned pick_next_eq(); + bool is_trivial(unsigned eq_id) const { throw default_exception("NYI"); } + + void forward_simplify(unsigned eq_id, unsigned using_eq); + void backward_simplify(unsigned eq_id, unsigned using_eq); + void superpose(unsigned src_eq, unsigned dst_eq); + + ptr_vector m_src_r, m_src_l, m_dst_r; + unsigned_vector m_src_ids, m_src_count, m_dst_ids, m_dst_count; + unsigned_vector m_lhs_eqs; + bool_vector m_eq_seen; + bool m_backward_simplified = false; + unsigned m_next_eq_index = 0; + + unsigned_vector const& forward_iterator(unsigned eq); + unsigned_vector const& superpose_iterator(unsigned eq); + unsigned_vector const& backward_iterator(unsigned eq); + void init_ids_counts(unsigned eq, unsigned monomial_id, unsigned_vector& ids, unsigned_vector& counts); + void reset_ids_counts(unsigned_vector& ids, unsigned_vector& counts); + void init_overlap_iterator(unsigned eq, unsigned monomial_id); + + bool is_to_simplify(unsigned eq) const { return !m_eqs[eq].is_processed; } + bool is_processed(unsigned eq) const { return m_eqs[eq].is_processed; } + + justification justify_rewrite(unsigned eq1, unsigned eq2); + justification justify_superpose(justification j1, justification j2); + justification::dependency* justify_monomial(justification::dependency* d, ptr_vector const& m); + + void propagate_shared(); + void simplify_shared(unsigned monomial_id); + + public: + + ac_plugin(egraph& g, unsigned fid, unsigned op); + + unsigned get_id() const override { return m_fid; } + + void register_node(enode* n) override; + + void register_shared(enode* n) override; + + void merge_eh(enode* n1, enode* n2, justification j) override; + + void diseq_eh(enode* n1, enode* n2) override {} + + void undo() override; + + void propagate() override; + + std::ostream& 
display(std::ostream& out) const override; + + void set_undo(std::function u) { m_undo_notify = u; } + }; +} diff --git a/src/ast/euf/euf_arith_plugin.cpp b/src/ast/euf/euf_arith_plugin.cpp new file mode 100644 index 000000000..e43dd1465 --- /dev/null +++ b/src/ast/euf/euf_arith_plugin.cpp @@ -0,0 +1,77 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_arith_plugin.cpp + +Abstract: + + plugin structure for arithetic + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +--*/ + +#include "ast/euf/euf_arith_plugin.h" +#include "ast/euf/euf_egraph.h" +#include + +namespace euf { + + arith_plugin::arith_plugin(egraph& g) : + plugin(g), + a(g.get_manager()), + m_add(g, get_id(), OP_ADD), + m_mul(g, get_id(), OP_MUL) { + std::function uadd = [&]() { m_undo.push_back(undo_t::undo_add); }; + m_add.set_undo(uadd); + std::function umul = [&]() { m_undo.push_back(undo_t::undo_mul); }; + m_mul.set_undo(umul); + } + + void arith_plugin::register_node(enode* n) { + // no-op + } + + void arith_plugin::register_shared(enode* n) { + if (a.is_add(n->get_expr())) + m_add.register_shared(n); + if (a.is_mul(n->get_expr())) + m_mul.register_shared(n); + } + + void arith_plugin::merge_eh(enode* n1, enode* n2, justification j) { + m_add.merge_eh(n1, n2, j); + m_mul.merge_eh(n1, n2, j); + } + + void arith_plugin::diseq_eh(enode* n1, enode* n2) { + // no-op + } + + void arith_plugin::undo() { + auto k = m_undo.back(); + m_undo.pop_back(); + switch (k) { + case undo_t::undo_add: + m_add.undo(); + break; + case undo_t::undo_mul: + m_mul.undo(); + break; + default: + UNREACHABLE(); + } + } + + std::ostream& arith_plugin::display(std::ostream& out) const { + out << "add\n"; + m_add.display(out); + out << "mul\n"; + m_mul.display(out); + return out; + } +} diff --git a/src/ast/euf/euf_arith_plugin.h b/src/ast/euf/euf_arith_plugin.h new file mode 100644 index 000000000..beadf7823 --- /dev/null +++ b/src/ast/euf/euf_arith_plugin.h @@ -0,0 +1,51 @@ +/*++ +Copyright (c) 
2023 Microsoft Corporation + +Module Name: + + euf_arith_plugin.h + +Abstract: + + plugin structure for arithmetic +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +--*/ + +#pragma once + +#include "ast/arith_decl_plugin.h" +#include "ast/euf/euf_plugin.h" +#include "ast/euf/euf_ac_plugin.h" + +namespace euf { + + class egraph; + + class arith_plugin : public plugin { + enum undo_t { undo_add, undo_mul }; + arith_util a; + svector m_undo; + ac_plugin m_add, m_mul; + + public: + arith_plugin(egraph& g); + + unsigned get_id() const override { return a.get_family_id(); } + + void register_node(enode* n) override; + + void register_shared(enode* n) override; + + void merge_eh(enode* n1, enode* n2, justification j) override; + + void diseq_eh(enode* n1, enode* n2) override; + + void undo() override; + + std::ostream& display(std::ostream& out) const override; + + }; +} diff --git a/src/ast/euf/euf_bv_plugin.cpp b/src/ast/euf/euf_bv_plugin.cpp new file mode 100644 index 000000000..691052d4a --- /dev/null +++ b/src/ast/euf/euf_bv_plugin.cpp @@ -0,0 +1,347 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_bv_plugin.cpp + +Abstract: + + plugin structure for bit-vectors + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-08 + Jakob Rath 2023-11-08 + +Objective: + +satisfies extract/concat axioms. + - concat(n[I],n[J]) = n[IJ] for I, J consecutive. + - concat(v1, v2) = 2^width(v1)*v2 + v1 + - concat(n[width(n)-1:0]) = n + - concat(a, b)[I] = concat(a[I1], b[I2]) + - concat(a, concat(b, c)) = concat(concat(a, b), c) + +E-graph: + +The E-graph contains node definitions of the form + + n := f(n1,n2,..) + +and congruences: + + n ~ n' means root(n) = root(n') + +Saturated state: + + 1. n := n1[I], n' := n2[J], n1 ~ n2 => root(n1) contains tree refining both I, J from smaller intervals + + 2. n := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => n ~ n3[IJ] + + 3.
n := concat(n1[I], n2[J]), I and J are consecutive & n1 ~ n2, n1[I] ~ v1, n1[J] ~ v2 => n ~ 2^width(v1)*v2 + v1 + + 4. n := concat(n1[I], n2[J], I, J are consecutive, n1 ~ n2, n ~ v => n1[I] ~ v mod 2^width(n1[I]), n2[J] ~ v div 2^width(n1[I]) + + 5. n' := n[I] => n ~ n[width(n)-1:0] + + 6. n := concat(a, concat(b, c)) => n ~ concat(concat(a, b), c) + - handled by rewriter pre-processing for inputs + - terms created internally are not equated modulo associativity + + 7, n := concat(n1, n2)[I] => n ~ concat(n1[I1],n2[I2]) or n[I1] or n[I2] + - handled by rewriter pre-processing + +Example: + x == (x1 x2) x3 + y == y1 (y2 y3) + x1 == y1, x2 == y2, x3 == y3 + => + x = y + + by x2 == y2, x3 == y3 => (x2 x3) = (y2 y3) + by (2) => x[I23] = (x2 x3) + by (2) => x[I123] = (x1 (x2 x3)) + by (5) => x = x[I123] + +--*/ + + +#include "ast/euf/euf_bv_plugin.h" +#include "ast/euf/euf_egraph.h" + +namespace euf { + + bv_plugin::bv_plugin(egraph& g): + plugin(g), + bv(g.get_manager()) + {} + + enode* bv_plugin::mk_value_concat(enode* a, enode* b) { + auto v1 = get_value(a); + auto v2 = get_value(b); + auto v3 = v1 + v2 * power(rational(2), width(a)); + return mk_value(v3, width(a) + width(b)); + } + + enode* bv_plugin::mk_value(rational const& v, unsigned sz) { + auto e = bv.mk_numeral(v, sz); + return mk(e, 0, nullptr); + } + + void bv_plugin::merge_eh(enode* x, enode* y, justification j) { + SASSERT(x == x->get_root()); + SASSERT(x == y->get_root()); + + TRACE("bv", tout << "merge_eh " << g.bpp(x) << " == " << g.bpp(y) << "\n"); + SASSERT(!m_internal); + flet _internal(m_internal, true); + + propagate_values(x); + + // ensure slices align + if (has_sub(x) || has_sub(y)) { + enode_vector& xs = m_xs, & ys = m_ys; + xs.reset(); + ys.reset(); + xs.push_back(x); + ys.push_back(y); + merge(xs, ys, j); + } + + // ensure p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ] + for (auto* n : enode_class(x)) + propagate_extract(n); + } + + // enforce concat(v1, v2) 
= v2*2^|v1| + v1 + void bv_plugin::propagate_values(enode* x) { + if (!is_value(x)) + return; + + enode* a, * b; + for (enode* p : enode_parents(x)) + if (is_concat(p, a, b) && is_value(a) && is_value(b) && !is_value(p)) + push_merge(mk_concat(a->get_interpreted(), b->get_interpreted()), mk_value_concat(a, b)); + + for (enode* sib : enode_class(x)) { + if (is_concat(sib, a, b)) { + if (!is_value(a) || !is_value(b)) { + auto val = get_value(x); + auto v1 = mod2k(val, width(a)); + auto v2 = machine_div2k(val, width(a)); + push_merge(mk_concat(mk_value(v1, width(a)), mk_value(v2, width(b))), x->get_interpreted()); + } + } + } + } + + // + // p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ] + // + // n is of form arg[I] + // p is of form concat(n, b) or concat(a, n) + // b is congruent to arg[J], I is consecutive with J => ensure that arg[IJ] = p + // a is congruent to arg[J], J is consecutive with I => ensure that arg[JI] = p + // + + void bv_plugin::propagate_extract(enode* n) { + unsigned lo1, hi1, lo2, hi2; + enode* a, * b; + if (!is_extract(n, lo1, hi1)) + return; + + enode* arg = n->get_arg(0); + enode* arg_r = arg->get_root(); + enode* n_r = n->get_root(); + + auto ensure_concat = [&](unsigned lo, unsigned mid, unsigned hi) { + TRACE("bv", tout << "ensure-concat " << lo << " " << mid << " " << hi << "\n"); + unsigned lo_, hi_; + for (enode* p1 : enode_parents(n)) + if (is_extract(p1, lo_, hi_) && lo_ == lo && hi_ == hi && p1->get_arg(0)->get_root() == arg_r) + return; + // add the axiom instead of merge(p, mk_extract(arg, lo, hi)), which would require tracking justifications + push_merge(mk_concat(mk_extract(arg, lo, mid), mk_extract(arg, mid + 1, hi)), mk_extract(arg, lo, hi)); + }; + + auto propagate_left = [&](enode* b) { + TRACE("bv", tout << "propagate-left " << g.bpp(b) << "\n"); + for (enode* sib : enode_class(b)) + if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi1 + 1 == lo2) + ensure_concat(lo1, 
hi1, hi2); + }; + + auto propagate_right = [&](enode* a) { + TRACE("bv", tout << "propagate-right " << g.bpp(a) << "\n"); + for (enode* sib : enode_class(a)) + if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi2 + 1 == lo1) + ensure_concat(lo2, hi2, hi1); + }; + + for (enode* p : enode_parents(n)) { + if (is_concat(p, a, b)) { + if (a->get_root() == n_r) + propagate_left(b); + if (b->get_root() == n_r) + propagate_right(a); + } + } + } + + void bv_plugin::push_undo_split(enode* n) { + m_undo_split.push_back(n); + push_plugin_undo(bv.get_family_id()); + } + + void bv_plugin::undo() { + enode* n = m_undo_split.back(); + m_undo_split.pop_back(); + auto& i = info(n); + i.lo = nullptr; + i.hi = nullptr; + i.cut = null_cut; + } + + void bv_plugin::register_node(enode* n) { + TRACE("bv", tout << "register " << g.bpp(n) << "\n"); + auto& i = info(n); + i.value = n; + enode* a, * b; + if (is_concat(n, a, b)) { + i.lo = a; + i.hi = b; + i.cut = width(a); + push_undo_split(n); + } + unsigned lo, hi; + if (is_extract(n, lo, hi) && (lo != 0 || hi + 1 != width(n->get_arg(0)))) { + enode* arg = n->get_arg(0); + unsigned w = width(arg); + if (all_of(enode_parents(arg), [&](enode* p) { unsigned _lo, _hi; return !is_extract(p, _lo, _hi) || _lo != 0 || _hi + 1 != w; })) + push_merge(mk_extract(arg, 0, w - 1), arg); + ensure_slice(arg, lo, hi); + } + } + + // + // Ensure that there are slices at boundaries of n[hi:lo] + // + void bv_plugin::ensure_slice(enode* n, unsigned lo, unsigned hi) { + enode* r = n; + unsigned lb = 0, ub = width(n) - 1; + while (true) { + TRACE("bv", tout << "ensure slice " << g.bpp(n) << " " << lb << " [" << lo << ", " << hi << "] " << ub << "\n"); + SASSERT(lb <= lo && hi <= ub); + SASSERT(ub - lb + 1 == width(r)); + if (lb == lo && ub == hi) + return; + slice_info& i = info(r); + if (!i.lo) { + if (lo > lb) { + split(r, lo - lb); + if (hi < ub) // or split(info(r).hi, ...) 
+ ensure_slice(n, lo, hi); + } + else if (hi < ub) + split(r, ub - hi); + break; + } + auto cut = i.cut; + if (cut + lb <= lo) { + lb += cut; + r = i.hi; + continue; + } + if (cut + lb > hi) { + ub = cut + lb - 1; + r = i.lo; + continue; + } + SASSERT(lo < cut + lb && cut + lb <= hi); + ensure_slice(n, lo, cut + lb - 1); + ensure_slice(n, cut + lb, hi); + break; + } + } + + enode* bv_plugin::mk_extract(enode* n, unsigned lo, unsigned hi) { + SASSERT(lo <= hi && width(n) > hi - lo); + unsigned lo1, hi1; + while (is_extract(n, lo1, hi1)) { + lo += lo1; + hi += lo1; + n = n->get_arg(0); + } + return mk(bv.mk_extract(hi, lo, n->get_expr()), 1, &n); + } + + enode* bv_plugin::mk_concat(enode* lo, enode* hi) { + enode* args[2] = { lo, hi }; + return mk(bv.mk_concat(lo->get_expr(), hi->get_expr()), 2, args); + } + + void bv_plugin::merge(enode_vector& xs, enode_vector& ys, justification dep) { + while (!xs.empty()) { + SASSERT(!ys.empty()); + auto x = xs.back(); + auto y = ys.back(); + if (unfold_sub(x, xs)) + continue; + else if (unfold_sub(y, ys)) + continue; + else if (unfold_width(x, xs, y, ys)) + continue; + else if (unfold_width(y, ys, x, xs)) + continue; + else if (x->get_root() != y->get_root()) + push_merge(x, y, dep); + xs.pop_back(); + ys.pop_back(); + } + SASSERT(ys.empty()); + } + + bool bv_plugin::unfold_sub(enode* x, enode_vector& xs) { + if (!has_sub(x)) + return false; + xs.pop_back(); + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + return true; + } + + bool bv_plugin::unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys) { + if (width(x) <= width(y)) + return false; + split(x, width(y)); + xs.pop_back(); + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + return true; + } + + void bv_plugin::split(enode* n, unsigned cut) { + TRACE("bv", tout << "split: " << g.bpp(n) << " " << cut << "\n"); + unsigned w = width(n); + SASSERT(!info(n).hi); + SASSERT(0 < cut && cut < w); + enode* hi = mk_extract(n, cut, w - 1); + enode* lo = 
mk_extract(n, 0, cut - 1); + auto& i = info(n); + SASSERT(i.value); + i.hi = hi; + i.lo = lo; + i.cut = cut; + push_undo_split(n); + push_merge(mk_concat(lo, hi), n); + } + + std::ostream& bv_plugin::display(std::ostream& out) const { + out << "bv\n"; + for (auto const& i : m_info) + if (i.lo) + out << g.bpp(i.value) << " cut " << i.cut << " lo " << g.bpp(i.lo) << " hi " << g.bpp(i.hi) << "\n"; + return out; + } +} diff --git a/src/ast/euf/euf_bv_plugin.h b/src/ast/euf/euf_bv_plugin.h new file mode 100644 index 000000000..4ab9d4618 --- /dev/null +++ b/src/ast/euf/euf_bv_plugin.h @@ -0,0 +1,100 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_bv_plugin.h + +Abstract: + + plugin structure for bit-vectors + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-08 + Jakob Rath 2023-11-08 + + +--*/ + +#pragma once + +#include "ast/bv_decl_plugin.h" +#include "ast/euf/euf_plugin.h" + +namespace euf { + + class egraph; + + class bv_plugin : public plugin { + static constexpr unsigned null_cut = std::numeric_limits::max(); + + struct slice_info { + unsigned cut = null_cut; // = bv.get_bv_size(lo) + enode* hi = nullptr; // + enode* lo = nullptr; // + enode* value = nullptr; + void reset() { *this = slice_info(); } + }; + using slice_info_vector = svector; + + bv_util bv; + slice_info_vector m_info; // indexed by enode::get_id() + + enode_vector m_xs, m_ys; + + bool is_concat(enode* n) const { return bv.is_concat(n->get_expr()); } + bool is_concat(enode* n, enode*& a, enode*& b) { return is_concat(n) && (a = n->get_arg(0), b = n->get_arg(1), true); } + bool is_extract(enode* n, unsigned& lo, unsigned& hi) { expr* body; return bv.is_extract(n->get_expr(), lo, hi, body); } + bool is_extract(enode* n) const { return bv.is_extract(n->get_expr()); } + unsigned width(enode* n) const { return bv.get_bv_size(n->get_expr()); } + + enode* mk_extract(enode* n, unsigned lo, unsigned hi); + enode* mk_concat(enode* lo, enode* hi); + enode* mk_value_concat(enode* lo, 
enode* hi); + enode* mk_value(rational const& v, unsigned sz); + unsigned width(enode* n) { return bv.get_bv_size(n->get_expr()); } + bool is_value(enode* n) { return n->get_root()->interpreted(); } + rational get_value(enode* n) { rational val; VERIFY(bv.is_numeral(n->get_interpreted()->get_expr(), val)); return val; } + slice_info& info(enode* n) { unsigned id = n->get_id(); m_info.reserve(id + 1); return m_info[id]; } + slice_info& root_info(enode* n) { unsigned id = n->get_root_id(); m_info.reserve(id + 1); return m_info[id]; } + bool has_sub(enode* n) { return !!info(n).lo; } + enode* sub_lo(enode* n) { return info(n).lo; } + enode* sub_hi(enode* n) { return info(n).hi; } + + bool m_internal = false; + void ensure_slice(enode* n, unsigned lo, unsigned hi); + + + void split(enode* n, unsigned cut); + + bool unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys); + bool unfold_sub(enode* x, enode_vector& xs); + void merge(enode_vector& xs, enode_vector& ys, justification j); + void propagate_extract(enode* n); + void propagate_values(enode* n); + + enode_vector m_undo_split; + void push_undo_split(enode* n); + + public: + bv_plugin(egraph& g); + + unsigned get_id() const override { return bv.get_family_id(); } + + void register_node(enode* n) override; + + void register_shared(enode* n) override {} + + void merge_eh(enode* n1, enode* n2, justification j) override; + + void diseq_eh(enode* n1, enode* n2) override {} + + void propagate() override {} + + void undo() override; + + std::ostream& display(std::ostream& out) const override; + + }; +} diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 8ab44d8cd..c9bc8ac1e 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -18,6 +18,7 @@ Notes: #include "ast/euf/euf_egraph.h" +#include "ast/euf/euf_bv_plugin.h" #include "ast/ast_pp.h" #include "ast/ast_translation.h" @@ -113,8 +114,11 @@ namespace euf { n->mark_interpreted(); if (m_on_make) m_on_make(n); - if 
(num_args == 0) + register_node(n); + + if (num_args == 0) return n; + if (m.is_eq(f) && !m.is_iff(f)) { n->set_is_equality(); reinsert_equality(n); @@ -123,11 +127,26 @@ namespace euf { if (n2 == n) update_children(n); else - merge(n, n2, justification::congruence(comm, m_congruence_timestamp++)); - + push_merge(n, n2, justification::congruence(comm, 0)); + // merge(n, n2, justification::congruence(comm, m_congruence_timestamp++)); + return n; } + void egraph::register_node(enode* n) { + if (m_plugins.empty()) + return; + auto* p = get_plugin(n); + if (p) + p->register_node(n); + for (auto* arg : enode_args(n)) { + auto* p_arg = get_plugin(arg); + if (p_arg && p != p_arg) // get_plugin yields nullptr when arg's sort has no plugin + p_arg->register_shared(arg); + } + + } + egraph::egraph(ast_manager& m) : m(m), m_table(m), m_tmp_app(2), m_exprs(m), m_eq_decls(m) { m_tmp_eq = enode::mk_tmp(m_region, 2); } @@ -139,6 +158,18 @@ namespace euf { memory::deallocate(m_tmp_node); } + void egraph::add_plugins() { + auto* plugin = alloc(bv_plugin, *this); + m_plugins.reserve(plugin->get_id() + 1); + m_plugins.set(plugin->get_id(), plugin); + } + + void egraph::propagate_plugins() { + for (auto* p : m_plugins) + if (p) + p->propagate(); + } + void egraph::add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) { TRACE("euf_verbose", tout << "eq: " << v1 << " == " << v2 << "\n";); m_new_th_eqs.push_back(th_eq(id, v1, v2, c, r)); @@ -422,6 +453,9 @@ namespace euf { p.r1->m_args[i]->get_root()->m_parents.pop_back(); } break; + case update_record::tag_t::is_plugin_undo: + m_plugins[p.m_th_id]->undo(); + break; default: UNREACHABLE(); break; @@ -442,7 +476,7 @@ namespace euf { if (!n1->cgc_enabled() && !n2->cgc_enabled()) return; - SASSERT(n1->get_sort() == n2->get_sort()); + enode* r1 = n1->get_root(); enode* r2 = n2->get_root(); if (r1 == r2) @@ -452,6 +486,7 @@ namespace euf { IF_VERBOSE(20, j.display(verbose_stream() << "merge: " << bpp(n1) << " == " << bpp(n2) << " ", m_display_justification) << "\n";); force_push();
SASSERT(m_num_scopes == 0); + SASSERT(n1->get_sort() == n2->get_sort()); ++m_stats.m_num_merge; if (r1->interpreted() && r2->interpreted()) { set_conflict(n1, n2, j); @@ -476,7 +511,7 @@ namespace euf { c->m_root = r2; std::swap(r1->m_next, r2->m_next); r2->inc_class_size(r1->class_size()); - merge_th_eq(r1, r2); + merge_th_eq(r1, r2, j); reinsert_parents(r1, r2); if (j.is_congruence() && (m.is_false(r2->get_expr()) || m.is_true(r2->get_expr()))) add_literal(n1, r2); @@ -487,6 +522,10 @@ namespace euf { for (auto& cb : m_on_merge) cb(r2, r1); + + auto* p = get_plugin(r1); + if (p) + p->merge_eh(r2, r1, j); } void egraph::remove_parents(enode* r) { @@ -532,7 +571,7 @@ namespace euf { } } - void egraph::merge_th_eq(enode* n, enode* root) { + void egraph::merge_th_eq(enode* n, enode* root, justification j) { SASSERT(n != root); for (auto const& iv : enode_th_vars(n)) { theory_id id = iv.get_id(); @@ -574,13 +613,17 @@ namespace euf { unmerge_justification(n1); } - - bool egraph::propagate() { - SASSERT(m_num_scopes == 0 || m_to_merge.empty()); + bool egraph::propagate() { force_push(); for (unsigned i = 0; i < m_to_merge.size() && m.limit().inc() && !inconsistent(); ++i) { auto const& w = m_to_merge[i]; - merge(w.a, w.b, justification::congruence(w.commutativity, m_congruence_timestamp++)); + if (w.j.is_congruence()) + merge(w.a, w.b, justification::congruence(w.j.is_commutative(), m_congruence_timestamp++)); + else + merge(w.a, w.b, w.j); + + if (i + 1 == m_to_merge.size()) + propagate_plugins(); } m_to_merge.reset(); return @@ -746,8 +789,15 @@ namespace euf { TRACE("euf_verbose", tout << "explain-eq: " << bpp(a) << " == " << bpp(b) << " jst: " << j << "\n";); if (j.is_external()) justifications.push_back(j.ext()); - else if (j.is_congruence()) + else if (j.is_congruence()) push_congruence(a, b, j.is_commutative()); + else if (j.is_dependent()) { + vector js; + for (auto const& j2 : justification::dependency_manager::s_linearize(j.get_dependency(), js)) + 
explain_eq(justifications, cc, a, b, j2); + } + else if (j.is_equality()) + explain_eq(justifications, cc, j.lhs(), j.rhs()); if (cc && j.is_congruence()) cc->push_back(std::tuple(a->get_app(), b->get_app(), j.timestamp(), j.is_commutative())); } @@ -867,7 +917,10 @@ namespace euf { for (enode* n : m_nodes) max_args = std::max(max_args, n->num_args()); for (enode* n : m_nodes) - display(out, max_args, n); + display(out, max_args, n); + for (auto* p : m_plugins) + if (p) + p->display(out); return out; } diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 5158f2fc9..44bee42ea 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -29,8 +29,10 @@ Notes: #include "util/statistics.h" #include "util/trail.h" #include "util/lbool.h" +#include "util/scoped_ptr_vector.h" #include "ast/euf/euf_enode.h" #include "ast/euf/euf_etable.h" +#include "ast/euf/euf_plugin.h" #include "ast/ast_ll_pp.h" #include @@ -82,12 +84,15 @@ namespace euf { class egraph { + friend class plugin; + typedef ptr_vector trail_stack; struct to_merge { enode* a, * b; - bool commutativity; - to_merge(enode* a, enode* b, bool c) : a(a), b(b), commutativity(c) {} + justification j; + to_merge(enode* a, enode* b, bool c) : a(a), b(b), j(justification::congruence(c, 0)) {} + to_merge(enode* a, enode* b, justification j) : a(a), b(b), j(j) {} }; struct stats { @@ -113,10 +118,12 @@ namespace euf { struct lbl_set {}; struct update_children {}; struct set_relevant {}; + struct plugin_undo {}; enum class tag_t { is_set_parent, is_add_node, is_toggle_cgc, is_toggle_merge_tf, is_update_children, is_add_th_var, is_replace_th_var, is_new_th_eq, is_lbl_hash, is_new_th_eq_qhead, - is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant }; + is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant, + is_plugin_undo }; tag_t tag; enode* r1; enode* n1; @@ -159,11 +166,14 @@ namespace euf { tag(tag_t::is_update_children), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) 
{} update_record(enode* n, set_relevant) : tag(tag_t::is_set_relevant), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} + update_record(unsigned th_id, plugin_undo) : + tag(tag_t::is_plugin_undo), r1(nullptr), n1(nullptr), m_th_id(th_id) {} }; ast_manager& m; svector m_to_merge; etable m_table; region m_region; + scoped_ptr_vector m_plugins; svector m_updates; unsigned_vector m_scopes; enode_vector m_expr2enode; @@ -202,6 +212,13 @@ namespace euf { } void push_node(enode* n) { m_updates.push_back(update_record(n)); } + // plugin related methods + void push_plugin_undo(unsigned th_id) { m_updates.push_back(update_record(th_id, update_record::plugin_undo())); } + void push_merge(enode* a, enode* b, justification j) { m_to_merge.push_back({ a, b, j }); } + plugin* get_plugin(enode* n) { return m_plugins.get(n->get_sort()->get_family_id(), nullptr); } + void register_node(enode* n); + void propagate_plugins(); + void add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r); void add_th_diseqs(theory_id id, theory_var v1, enode* r); @@ -213,7 +230,7 @@ namespace euf { void force_push(); void set_conflict(enode* n1, enode* n2, justification j); void merge(enode* n1, enode* n2, justification j); - void merge_th_eq(enode* n, enode* root); + void merge_th_eq(enode* n, enode* root, justification j); void merge_justification(enode* n1, enode* n2, justification j); void reinsert_parents(enode* r1, enode* r2); void remove_parents(enode* r); @@ -241,6 +258,7 @@ namespace euf { public: egraph(ast_manager& m); ~egraph(); + void add_plugins(); enode* find(expr* f) const { return m_expr2enode.get(f->get_id(), nullptr); } enode* find(expr* f, unsigned n, enode* const* args); enode* mk(expr* f, unsigned generation, unsigned n, enode *const* args); diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index 76c3611eb..4fc682f65 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -202,6 +202,7 @@ namespace euf { enode* get_root() const { 
return m_root; } expr* get_expr() const { return m_expr; } sort* get_sort() const { return m_expr->get_sort(); } + enode* get_interpreted() const { return get_root(); } app* get_app() const { return to_app(m_expr); } func_decl* get_decl() const { return is_app(m_expr) ? to_app(m_expr)->get_decl() : nullptr; } unsigned get_expr_id() const { return m_expr->get_id(); } diff --git a/src/ast/euf/euf_justification.h b/src/ast/euf/euf_justification.h index c98002396..1a4e76bf0 100644 --- a/src/ast/euf/euf_justification.h +++ b/src/ast/euf/euf_justification.h @@ -22,19 +22,34 @@ Notes: #pragma once +#include "util/dependency.h" + namespace euf { + class enode; + class justification { + public: + typedef scoped_dependency_manager dependency_manager; + typedef scoped_dependency_manager::dependency dependency; + private: enum class kind_t { axiom_t, congruence_t, - external_t + external_t, + dependent_t, + equality_t }; kind_t m_kind; - bool m_comm; + union { + bool m_comm; + enode* m_n1; + }; union { void* m_external; uint64_t m_timestamp; + dependency* m_dependency; + enode* m_n2; }; justification(bool comm, uint64_t ts): @@ -49,6 +64,18 @@ namespace euf { m_external(ext) {} + justification(dependency* dep, int): + m_kind(kind_t::dependent_t), + m_comm(false), + m_dependency(dep) + {} + + justification(enode* n1, enode* n2): + m_kind(kind_t::equality_t), + m_n1(n1), + m_n2(n2) + {} + public: justification(): m_kind(kind_t::axiom_t), @@ -59,10 +86,17 @@ namespace euf { static justification axiom() { return justification(); } static justification congruence(bool c, uint64_t ts) { return justification(c, ts); } static justification external(void* ext) { return justification(ext); } + static justification dependent(dependency* d) { return justification(d, 1); } + static justification equality(enode* a, enode* b) { return justification(a, b); } bool is_external() const { return m_kind == kind_t::external_t; } bool is_congruence() const { return m_kind == kind_t::congruence_t; } 
bool is_commutative() const { return m_comm; } + bool is_dependent() const { return m_kind == kind_t::dependent_t; } + bool is_equality() const { return m_kind == kind_t::equality_t; } + dependency* get_dependency() const { SASSERT(is_dependent()); return m_dependency; } + enode* lhs() const { SASSERT(is_equality()); return m_n1; } + enode* rhs() const { SASSERT(is_equality()); return m_n2; } uint64_t timestamp() const { SASSERT(is_congruence()); return m_timestamp; } template T* ext() const { SASSERT(is_external()); return static_cast(m_external); } @@ -75,6 +109,9 @@ namespace euf { return axiom(); case kind_t::congruence_t: return congruence(m_comm, m_timestamp); + case kind_t::dependent_t: + NOT_IMPLEMENTED_YET(); + return dependent(m_dependency); default: UNREACHABLE(); return axiom(); @@ -93,6 +130,13 @@ namespace euf { return out << "axiom"; case kind_t::congruence_t: return out << "congruence"; + case kind_t::dependent_t: { + vector js; + out << "dependent"; + for (auto const& j : dependency_manager::s_linearize(m_dependency, js)) + j.display(out << " ", ext); + return out; + } default: UNREACHABLE(); return out; diff --git a/src/ast/euf/euf_plugin.cpp b/src/ast/euf/euf_plugin.cpp new file mode 100644 index 000000000..57c1849fd --- /dev/null +++ b/src/ast/euf/euf_plugin.cpp @@ -0,0 +1,47 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_plugin.cpp + +Abstract: + + plugin structure for euf + + Plugins allow adding equality saturation for theories. 
+ +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-08 + +--*/ + +#include "ast/euf/euf_egraph.h" + +namespace euf { + + void plugin::push_plugin_undo(unsigned th_id) { + g.push_plugin_undo(th_id); + } + + void plugin::push_merge(enode* a, enode* b, justification j) { + g.push_merge(a, b, j); + } + + void plugin::push_merge(enode* a, enode* b) { + TRACE("plugin", tout << g.bpp(a) << " == " << g.bpp(b) << "\n"); + g.push_merge(a, b, justification::axiom()); + } + + enode* plugin::mk(expr* e, unsigned n, enode* const* args) { + enode* r = g.find(e); + if (!r) + r = g.mk(e, 0, n, args); + return r; + } + + region& plugin::get_region() { + return g.m_region; + } +} diff --git a/src/ast/euf/euf_plugin.h b/src/ast/euf/euf_plugin.h new file mode 100644 index 000000000..3e9d35771 --- /dev/null +++ b/src/ast/euf/euf_plugin.h @@ -0,0 +1,58 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_plugin.h + +Abstract: + + plugin structure for euf + + Plugins allow adding equality saturation for theories. 
+ +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-08 + +--*/ + +#pragma once + +#include "ast/euf/euf_enode.h" +#include "ast/euf/euf_justification.h" + +namespace euf { + + + class plugin { + protected: + egraph& g; + void push_plugin_undo(unsigned th_id); + void push_merge(enode* a, enode* b, justification j); + void push_merge(enode* a, enode* b); + enode* mk(expr* e, unsigned n, enode* const* args); + region& get_region(); + public: + plugin(egraph& g): + g(g) + {} + + virtual unsigned get_id() const = 0; + + virtual void register_node(enode* n) = 0; + + virtual void register_shared(enode* n) = 0; + + virtual void merge_eh(enode* n1, enode* n2, justification j) = 0; + + virtual void diseq_eh(enode* n1, enode* n2) = 0; + + virtual void propagate() = 0; + + virtual void undo() = 0; + + virtual std::ostream& display(std::ostream& out) const = 0; + + }; +} diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 0ae56beb3..d03f82aee 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -420,8 +420,6 @@ namespace euf { return *c; } - - bool solver::unit_propagate() { bool propagated = false; while (!s().inconsistent()) { diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 07a559365..dca916803 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -38,6 +38,7 @@ add_executable(test-z3 dl_util.cpp doc.cpp egraph.cpp + euf_bv_plugin.cpp escaped.cpp ex.cpp expr_rand.cpp diff --git a/src/test/euf_bv_plugin.cpp b/src/test/euf_bv_plugin.cpp new file mode 100644 index 000000000..bea98dfe6 --- /dev/null +++ b/src/test/euf_bv_plugin.cpp @@ -0,0 +1,180 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +--*/ + +#include "util/util.h" +#include "util/timer.h" +#include "ast/euf/euf_egraph.h" +#include "ast/euf/euf_bv_plugin.h" +#include "ast/reg_decl_plugins.h" +#include "ast/ast_pp.h" +#include + +euf::enode* get_node(euf::egraph& g, expr* e) { + auto* n = g.find(e); + if (n) + return n; + 
euf::enode_vector args; + for (expr* arg : *to_app(e)) + args.push_back(get_node(g, arg)); + return g.mk(e, 0, args.size(), args.data()); +} + +// align slices, and propagate extensionality +static void test1() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugins(); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref y(m.mk_const("y", u32), m); + expr_ref x3(bv.mk_extract(31, 16, x), m); + expr_ref x2(bv.mk_extract(15, 8, x), m); + expr_ref x1(bv.mk_extract(7, 0, x), m); + expr_ref y3(bv.mk_extract(31, 24, y), m); + expr_ref y2(bv.mk_extract(23, 8, y), m); + expr_ref y1(bv.mk_extract(7, 0, y), m); + expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m); + expr_ref yy(bv.mk_concat(y1, bv.mk_concat(y2, y3)), m); + auto* nx = get_node(g, xx); + auto* ny = get_node(g, yy); + TRACE("bv", tout << "before merge\n" << g << "\n"); + g.merge(nx, ny, nullptr); + TRACE("bv", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("bv", tout << "after propagate\n" << g << "\n"); + std::cout << g << "\n"; + SASSERT(nx->get_root() == ny->get_root()); +} + +// propagate values down +static void test2() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugins(); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref x3(bv.mk_extract(31, 16, x), m); + expr_ref x2(bv.mk_extract(15, 8, x), m); + expr_ref x1(bv.mk_extract(7, 0, x), m); + expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m); + g.merge(get_node(g, xx), get_node(g, bv.mk_numeral((1 << 27) + (1 << 17) + (1 << 3), 32)), nullptr); + g.propagate(); + SASSERT(get_node(g, x1)->get_root()->interpreted()); + SASSERT(get_node(g, x2)->get_root()->interpreted()); + SASSERT(get_node(g, x3)->get_root()->interpreted()); + SASSERT(get_node(g, x)->get_root()->interpreted()); +} + + +// propagate values up +static void test3() { + ast_manager m; + reg_decl_plugins(m); + 
euf::egraph g(m); + g.add_plugins(); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref x3(bv.mk_extract(31, 16, x), m); + expr_ref x2(bv.mk_extract(15, 8, x), m); + expr_ref x1(bv.mk_extract(7, 0, x), m); + expr_ref xx(bv.mk_concat(bv.mk_concat(x1, x2), x3), m); + expr_ref y(m.mk_const("y", u32), m); + g.merge(get_node(g, xx), get_node(g, y), nullptr); + g.merge(get_node(g, x1), get_node(g, bv.mk_numeral(2, 8)), nullptr); + g.merge(get_node(g, x2), get_node(g, bv.mk_numeral(8, 8)), nullptr); + g.propagate(); + SASSERT(get_node(g, bv.mk_concat(x1, x2))->get_root()->interpreted()); + SASSERT(get_node(g, x1)->get_root()->interpreted()); + SASSERT(get_node(g, x2)->get_root()->interpreted()); +} + +// propagate extract up +static void test4() { + // concat(a, x[J]), a = x[I] => x[IJ] = concat(x[I],x[J]) + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugins(); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + sort_ref u8(bv.mk_sort(8), m); + sort_ref u16(bv.mk_sort(16), m); + expr_ref a(m.mk_const("a", u8), m); + expr_ref x(m.mk_const("x", u32), m); + expr_ref y(m.mk_const("y", u16), m); + expr_ref x1(bv.mk_extract(15, 8, x), m); + expr_ref x2(bv.mk_extract(23, 16, x), m); + g.merge(get_node(g, bv.mk_concat(a, x2)), get_node(g, y), nullptr); + g.merge(get_node(g, x1), get_node(g, a), nullptr); + g.propagate(); + TRACE("bv", tout << g << "\n"); + SASSERT(get_node(g, bv.mk_extract(23, 8, x))->get_root() == get_node(g, y)->get_root()); +} + +// iterative slicing +static void test5() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugins(); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref x1(bv.mk_extract(31, 4, x), m); + expr_ref x2(bv.mk_extract(27, 0, x), m); + auto* nx = get_node(g, x1); + auto* ny = get_node(g, x2); + TRACE("bv", tout << "before merge\n" << g << "\n"); + g.merge(nx, ny, nullptr); + TRACE("bv", tout 
<< "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("bv", tout << "after propagate\n" << g << "\n"); + std::cout << g << "\n"; +} + +// iterative slicing +static void test6() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugins(); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref x1(bv.mk_extract(31, 3, x), m); + expr_ref x2(bv.mk_extract(28, 0, x), m); + auto* nx = get_node(g, x1); + auto* ny = get_node(g, x2); + TRACE("bv", tout << "before merge\n" << g << "\n"); + g.merge(nx, ny, nullptr); + TRACE("bv", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("bv", tout << "after propagate\n" << g << "\n"); + std::cout << g << "\n"; +} + + +void tst_euf_bv_plugin() { + enable_trace("bv"); + enable_trace("plugin"); + // Every scenario builds its own ast_manager and egraph, so run them all + // in order instead of returning after a single one. + test1(); + test2(); + test3(); + test4(); + test5(); + test6(); +} \ No newline at end of file diff --git a/src/test/main.cpp b/src/test/main.cpp index f222837a8..f7085cfbc 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -270,4 +270,5 @@ int main(int argc, char ** argv) { TST(slicing); TST(totalizer); TST(distribution); + TST(euf_bv_plugin); } diff --git a/src/util/dependency.h b/src/util/dependency.h index b7527cabc..ad6353320 100644 --- a/src/util/dependency.h +++ b/src/util/dependency.h @@ -69,6 +69,14 @@ public: d->unmark(); } + static void s_linearize(dependency* d, vector& vs) { + if (!d) + return; + ptr_vector todo; + todo.push_back(d); + linearize_todo(todo, vs); + } + private: struct join : public dependency { dependency * m_children[2]; @@ -325,6 +333,11 @@ public: return m_dep_manager.linearize(d, vs); } + static vector const& s_linearize(dependency* d, vector& vs) { + dep_manager::s_linearize(d, vs); + return vs; + } + void linearize(ptr_vector& d, vector & vs) { return m_dep_manager.linearize(d, vs); }