3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

add E(T) functionality for bv and ac functions

Add an option to register EUF modulo theories,
The current theory with a unit test is BV.
The arithmetic theory plugs into an AC completion. It is partially finished, pending setting up testing and implementing handling of shared terms.
This commit is contained in:
Nikolaj Bjorner 2023-11-12 15:39:45 -08:00
parent 9ce47ab460
commit 65a8c162f5
19 changed files with 1830 additions and 21 deletions

View file

@ -1,8 +1,12 @@
z3_add_component(euf
SOURCES
euf_ac_plugin.cpp
euf_arith_plugin.cpp
euf_bv_plugin.cpp
euf_egraph.cpp
euf_enode.cpp
euf_etable.cpp
euf_egraph.cpp
euf_plugin.cpp
COMPONENT_DEPENDENCIES
ast
util

51
src/ast/euf/ac_plugin.h Normal file
View file

@ -0,0 +1,51 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_ac_plugin.h
Abstract:
plugin structure for ac functions
Author:
Nikolaj Bjorner (nbjorner) 2023-11-11
Jakob Rath 2023-11-11
--*/
#pragma once
#include "ast/euf/euf_plugin.h"
namespace euf {
class ac_plugin : public plugin {
struct eq {
enode_vector l, r;
};
vector<eq> m_eqs;
vector<enode_vector> m_use;
unsigned m_fid;
unsigned m_op;
void push_eq(enode* l, enode* r);
public:
ac_plugin(egraph& g, unsigned fid, unsigned op);
unsigned get_id() const override { return m_fid; }
void register_node(enode* n) override;
void merge_eh(enode* n1, enode* n2, justification j) override;
void diseq_eh(enode* n1, enode* n2) override;
void undo() override;
std::ostream& display(std::ostream& out) const override;
};

View file

@ -0,0 +1,591 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_ac_plugin.cpp
Abstract:
plugin structure for ac functions
Author:
Nikolaj Bjorner (nbjorner) 2023-11-11
Completion modulo AC
E set of eqs
pick critical pair xy = z by j1 xu = v by j2 in E
Add new equation zu = xyu = vy by j1, j2
Notes:
- Some equalities come from shared terms, so do not.
More notes:
Justifications for new equations are joined (requires extension to egraph/justification)
Process new merges so use list is updated
Justifications for processed merges are recorded
Updated equations are recorded for restoration on backtracking
Keep track of foreign / shared occurrences of AC functions.
- use register_shared to accumulate shared occurrences.
Shared occurrences are rewritten modulo completion.
When equal to a different shared occurrence, propagate equality.
--*/
#pragma once
#include "ast/euf/euf_ac_plugin.h"
#include "ast/euf/euf_egraph.h"
namespace euf {
ac_plugin::ac_plugin(egraph& g, unsigned fid, unsigned op):
plugin(g), m_fid(fid), m_op(op)
{}
void ac_plugin::register_node(enode* n) {
}
void ac_plugin::register_shared(enode* n) {
auto m = to_monomial(n);
auto const& ns = monomial(m);
for (auto arg : ns) {
arg->shared.push_back(m);
m_node_trail.push_back(arg);
push_undo(is_add_shared);
}
m_shared_trail.push_back(m);
push_undo(is_register_shared);
}
void ac_plugin::undo() {
auto k = m_undo.back();
m_undo.pop_back();
switch (k) {
case is_add_eq: {
auto const& eq = m_eqs.back();
for (auto* n : monomial(eq.l))
n->lhs.pop_back();
for (auto* n : monomial(eq.r))
n->rhs.pop_back();
m_eqs.pop_back();
break;
}
case is_add_node: {
auto* n = m_node_trail.back();
m_node_trail.pop_back();
m_nodes[n->n->get_id()] = nullptr;
n->~node();
break;
}
case is_add_monomial: {
m_monomials.pop_back();
m_monomial_enodes.pop_back();
break;
}
case is_merge_node: {
auto [other, old_shared, old_lhs, old_rhs] = m_merge_trail.back();
auto* root = other->root;
std::swap(other->next, root->next);
root->shared.shrink(old_shared);
root->lhs.shrink(old_lhs);
root->rhs.shrink(old_rhs);
m_merge_trail.pop_back();
break;
}
case is_update_eq: {
auto const & [idx, eq] = m_update_eq_trail.back();
m_eqs[idx] = eq;
m_update_eq_trail.pop_back();
break;
}
case is_add_shared: {
auto n = m_node_trail.back();
m_node_trail.pop_back();
n->shared.pop_back();
break;
}
case is_register_shared: {
m_shared_trail.pop_back();
break;
}
case is_join_justification: {
m_dep_manager.pop_scope(1);
break;
}
default:
UNREACHABLE();
}
}
std::ostream& ac_plugin::display(std::ostream& out) const {
unsigned i = 0;
for (auto const& eq : m_eqs) {
out << i << ": " << eq.l << " == " << eq.r << ": ";
for (auto n : monomial(eq.l))
out << g.bpp(n->n) << " ";
out << "== ";
for (auto n : monomial(eq.r))
out << g.bpp(n->n) << " ";
out << "\n";
++i;
}
i = 0;
for (auto m : m_monomials) {
out << i << ": ";
for (auto n : m)
out << g.bpp(n->n) << " ";
out << "\n";
++i;
}
for (auto n : m_nodes) {
out << g.bpp(n->n) << " r: " << n->root_id() << "\n";
out << "lhs ";
for (auto l : n->lhs)
out << l << " ";
out << "rhs ";
for (auto r : n->rhs)
out << r << " ";
out << "shared ";
for (auto s : n->shared)
out << s << " ";
out << "\n";
}
return out;
}
void ac_plugin::merge_eh(enode* l, enode* r, justification j) {
if (l == r)
return;
if (!is_op(l) && !is_op(r))
merge(mk_node(l), mk_node(r), j);
else
init_equation({ to_monomial(l), to_monomial(r), false, j });
}
void ac_plugin::init_equation(eq const& e) {
m_eqs.push_back(e);
auto& eq = m_eqs.back();
if (orient_equation(eq)) {
push_undo(is_add_eq);
unsigned eq_id = m_eqs.size() - 1;
for (auto n : monomial(eq.l))
n->lhs.push_back(eq_id);
for (auto n : monomial(eq.r))
n->rhs.push_back(eq_id);
}
else
m_eqs.pop_back();
}
bool ac_plugin::orient_equation(eq& e) {
auto& ml = monomial(e.l);
auto& mr = monomial(e.r);
if (ml.size() > mr.size())
return true;
if (ml.size() < mr.size()) {
std::swap(e.l, e.r);
return true;
}
else {
std::sort(ml.begin(), ml.end(), [&](node* a, node* b) { return a->root_id() < b->root_id(); });
std::sort(mr.begin(), mr.end(), [&](node* a, node* b) { return a->root_id() < b->root_id(); });
for (unsigned i = ml.size(); i-- > 0;) {
if (ml[i] == mr[i])
continue;
if (ml[i]->root_id() < mr[i]->root_id())
std::swap(e.l, e.r);
return true;
}
return false;
}
}
void ac_plugin::merge(node* root, node* other, justification j) {
for (auto n : equiv(other))
n->root = root;
m_merge_trail.push_back({ other, root->shared.size(), root->lhs.size(), root->rhs.size()});
for (auto eq_id : other->lhs)
set_processed(eq_id, false);
for (auto eq_id : other->rhs)
set_processed(eq_id, false);
root->shared.append(other->shared);
root->lhs.append(other->lhs);
root->rhs.append(other->rhs);
std::swap(root->next, other->next);
push_undo(is_merge_node);
}
void ac_plugin::push_undo(undo_kind k) {
m_undo.push_back(k);
push_plugin_undo(get_id());
m_undo_notify(); // tell main plugin to dispatch undo to this module.
}
unsigned ac_plugin::to_monomial(enode* n) {
enode_vector& ns = m_todo;
ns.reset();
ptr_vector<node> ms;
ns.push_back(n);
for (unsigned i = 0; i < ns.size(); ++i) {
n = ns[i];
if (is_op(n)) {
ns.append(n->num_args(), n->args());
ns[i] = ns.back();
ns.pop_back();
--i;
}
else {
ms.push_back(mk_node(n));
}
}
return to_monomial(n, ms);
}
unsigned ac_plugin::to_monomial(enode* e, ptr_vector<node> const& ms) {
unsigned id = m_monomials.size();
m_monomials.push_back(ms);
m_monomial_enodes.push_back(e);
push_undo(is_add_monomial);
return id;
}
ac_plugin::node* ac_plugin::node::mk(region& r, enode* n) {
auto* mem = r.allocate(sizeof(node));
node* res = new (mem) node();
res->n = n;
res->root = res;
res->next = res;
return res;
}
ac_plugin::node* ac_plugin::mk_node(enode* n) {
unsigned id = n->get_id();
if (m_nodes.size() > id && m_nodes[id])
return m_nodes[id];
auto* r = node::mk(get_region(), n);
push_undo(is_add_node);
m_nodes.set(id, r);
m_node_trail.push_back(r);
return r;
}
void ac_plugin::propagate() {
while (true) {
unsigned eq_id = pick_next_eq();
if (eq_id == UINT_MAX)
break;
// simplify eq using processed
for (auto other_eq : backward_iterator(eq_id))
if (is_processed(other_eq))
backward_simplify(eq_id, other_eq);
if (m_backward_simplified)
continue;
// simplify processed using eq
for (auto other_eq : forward_iterator(eq_id))
if (is_processed(other_eq))
forward_simplify(other_eq, eq_id);
// superpose, create new equations
for (auto other_eq : superpose_iterator(eq_id))
if (is_processed(other_eq))
superpose(eq_id, other_eq);
// simplify to_simplify using eq
for (auto other_eq : forward_iterator(eq_id))
if (is_to_simplify(other_eq))
forward_simplify(other_eq, eq_id);
set_processed(eq_id, true);
}
propagate_shared();
}
unsigned ac_plugin::pick_next_eq() {
for (unsigned i = 0, n = m_eqs.size(); i < n; ++i) {
unsigned id = (i + m_next_eq_index) % n;
auto const& eq = m_eqs[id];
if (eq.is_processed)
continue;
++m_next_eq_index;
return id;
}
return UINT_MAX;
}
void ac_plugin::set_processed(unsigned id, bool f) {
auto& eq = m_eqs[id];
if (eq.is_processed == f)
return;
m_update_eq_trail.push_back({ id, eq });
eq.is_processed = f;
push_undo(is_update_eq);
}
//
// superpose iterator enumerates all equations where lhs of eq have element in common.
//
unsigned_vector const& ac_plugin::superpose_iterator(unsigned eq_id) {
auto const& eq = m_eqs[eq_id];
m_src_r.reset();
m_src_r.append(monomial(eq.r));
init_ids_counts(eq_id, eq.l, m_src_ids, m_src_count);
init_overlap_iterator(eq_id, eq.l);
return m_lhs_eqs;
}
//
// backward iterator allows simplification of eq
// The rhs of eq is a super-set of lhs of other eq.
//
unsigned_vector const& ac_plugin::backward_iterator(unsigned eq_id) {
auto const& eq = m_eqs[eq_id];
init_ids_counts(eq_id, eq.r, m_dst_ids, m_dst_count);
init_overlap_iterator(eq_id, eq.r);
m_backward_simplified = false;
return m_lhs_eqs;
}
void ac_plugin::init_overlap_iterator(unsigned eq_id, unsigned monomial_id) {
m_lhs_eqs.reset();
for (auto n : monomial(monomial_id))
m_lhs_eqs.append(n->root->lhs);
// prune m_lhs_eqs to single occurrences
unsigned j = 0;
for (unsigned i = 0; i < m_lhs_eqs.size(); ++i) {
unsigned id = m_lhs_eqs[i];
m_eq_seen.reserve(id + 1, false);
if (m_eq_seen[id])
continue;
if (id == eq_id)
continue;
m_lhs_eqs[j++] = id;
m_eq_seen[id] = true;
}
m_lhs_eqs.shrink(j);
for (auto id : m_lhs_eqs)
m_eq_seen[id] = false;
}
//
// forward iterator simplifies other eqs where their rhs is a superset of lhs of eq
//
unsigned_vector const& ac_plugin::forward_iterator(unsigned eq_id) {
auto& eq = m_eqs[eq_id];
m_src_r.reset();
m_src_r.append(monomial(eq.r));
init_ids_counts(eq_id, eq.l, m_src_ids, m_src_count);
unsigned min_r = UINT_MAX;
node* min_n = nullptr;
for (auto n : monomial(eq.l))
if (n->root->rhs.size() < min_r)
min_n = n, min_r = n->root->rhs.size();
// found node that occurs in fewest rhs
VERIFY(min_n);
return min_n->rhs;
}
void ac_plugin::init_ids_counts(unsigned eq_id, unsigned monomial_id, unsigned_vector& ids, unsigned_vector& counts) {
auto& eq = m_eqs[eq_id];
reset_ids_counts(ids, counts);
for (auto n : monomial(monomial_id)) {
unsigned id = n->root_id();
counts.setx(id, counts.get(id, 0) + 1, 0);
ids.push_back(id);
}
}
void ac_plugin::reset_ids_counts(unsigned_vector& ids, unsigned_vector& counts) {
for (auto id : ids)
counts[id] = 0;
ids.reset();
}
void ac_plugin::forward_simplify(unsigned dst_eq, unsigned src_eq) {
if (src_eq == dst_eq)
return;
// check that left src.l is a subset of dst.r
// dst = A -> BC
// src = B -> D
// post(dst) := A -> CD
auto& src = m_eqs[src_eq];
auto& dst = m_eqs[dst_eq];
reset_ids_counts(m_dst_ids, m_dst_count);
unsigned src_l_size = monomial(src.l).size();
unsigned src_r_size = m_src_r.size();
// subtract src.l from dst.r if src.l is a subset of dst.r
// new_rhs := old_rhs - src_lhs + src_rhs
unsigned num_overlap = 0;
for (auto n : monomial(dst.r)) {
unsigned id = n->root_id();
unsigned count = m_src_count.get(id, 0);
if (count == 0)
m_src_r.push_back(n);
else {
unsigned dst_count = m_dst_count.get(id, 0);
if (dst_count >= count)
m_src_r.push_back(n);
else
m_dst_count.set(id, dst_count + 1), m_dst_ids.push_back(id), ++num_overlap;
}
}
// The dst.r has to be a superset of src.l, otherwise simplification does not apply
if (num_overlap == src_l_size) {
auto new_r = to_monomial(nullptr, m_src_r);
m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] });
m_eqs[dst_eq].r = new_r;
m_eqs[dst_eq].j = justify_rewrite(src_eq, dst_eq);
push_undo(is_update_eq);
}
m_src_r.shrink(src_r_size);
}
void ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) {
if (src_eq == dst_eq)
return;
auto& src = m_eqs[src_eq];
auto& dst = m_eqs[dst_eq];
//
// dst_ids, dst_count contain rhs of dst_eq
//
// check that src.l is a subset of dst.r
reset_ids_counts(m_src_ids, m_src_count);
bool is_subset = true;
for (auto n : monomial(src.l)) {
unsigned id = n->root_id();
unsigned dst_count = m_dst_count.get(id, 0);
if (dst_count == 0) {
is_subset = false;
break;
}
else {
unsigned src_count = m_src_count.get(id, 0);
if (src_count >= dst_count) {
is_subset = false;
break;
}
else
m_src_count.set(id, src_count + 1), m_src_ids.push_back(id);
}
}
if (is_subset) {
// dst_rhs := dst_rhs - src_lhs + src_rhs
m_src_r.reset();
m_src_r.append(monomial(src.r));
// add to m_src_r elements of dst.r that are not in src.l
for (auto n : monomial(dst.r)) {
unsigned id = n->root_id();
unsigned count = m_src_count.get(id, 0);
if (count == 0)
m_src_r.push_back(n);
else
--m_src_count[id];
}
auto new_r = to_monomial(nullptr, m_src_r);
m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] });
m_eqs[dst_eq].r = new_r;
m_eqs[dst_eq].j = justify_rewrite(src_eq, dst_eq);
push_undo(is_update_eq);
m_backward_simplified = true;
}
}
void ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) {
if (src_eq == dst_eq)
return;
auto& src = m_eqs[src_eq];
auto& dst = m_eqs[dst_eq];
// AB -> C, AD -> E => BE ~ CD
// m_src_ids, m_src_counts contains information about src (call it AD -> E)
reset_ids_counts(m_dst_ids, m_dst_count);
m_dst_r.reset();
m_dst_r.append(monomial(dst.r));
unsigned src_r_size = m_src_r.size();
SASSERT(src_r_size == monomial(src.r).size());
// dst_r contains C
// src_r contains E
// compute BE, initialize dst_ids, dst_counts
for (auto n : monomial(dst.l)) {
unsigned id = n->root_id();
unsigned src_count = m_src_count.get(id, 0);
unsigned dst_count = m_dst_count.get(id, 0);
m_dst_count.set(id, dst_count + 1);
m_dst_ids.push_back(id);
if (src_count < dst_count)
m_src_r.push_back(n);
}
// compute CD
for (auto n : monomial(src.l)) {
unsigned id = n->root_id();
unsigned dst_count = m_dst_count.get(id, 0);
if (dst_count > 0)
--m_dst_count[id];
else
m_dst_r.push_back(n);
}
justification j = justify_rewrite(src_eq, dst_eq);
if (m_src_r.size() == 1 && m_dst_r.size() == 1)
push_merge(m_src_r[0]->n, m_dst_r[0]->n, j);
else
init_equation({ to_monomial(nullptr, m_src_r), to_monomial(nullptr, m_dst_r), false, j });
m_src_r.shrink(src_r_size);
}
void ac_plugin::propagate_shared() {
for (auto m : m_shared_trail)
simplify_shared(m);
// check for collisions, push_merge when there is a collision.
}
void ac_plugin::simplify_shared(unsigned monomial_id) {
// apply processed as a set of rewrites
}
justification ac_plugin::justify_rewrite(unsigned eq1, unsigned eq2) {
auto const& e1 = m_eqs[eq1];
auto const& e2 = m_eqs[eq2];
auto* j = m_dep_manager.mk_join(m_dep_manager.mk_leaf(e1.j), m_dep_manager.mk_leaf(e2.j));
j = justify_monomial(j, monomial(e1.l));
j = justify_monomial(j, monomial(e1.r));
j = justify_monomial(j, monomial(e2.l));
j = justify_monomial(j, monomial(e2.r));
m_dep_manager.push_scope();
push_undo(is_join_justification);
return justification::dependent(j);
}
justification::dependency* ac_plugin::justify_monomial(justification::dependency* j, ptr_vector<node> const& m) {
for (auto n : m)
if (n->root->n != n->n)
j = m_dep_manager.mk_join(j, m_dep_manager.mk_leaf(justification::equality(n->root->n, n->n)));
return j;
}
}

174
src/ast/euf/euf_ac_plugin.h Normal file
View file

@ -0,0 +1,174 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_ac_plugin.h
Abstract:
plugin structure for ac functions
Author:
Nikolaj Bjorner (nbjorner) 2023-11-11
ex:
xyz -> xy, then xyzz -> xy by repeated rewriting
monomials = [0 |-> xyz, 1 |-> xy, 2 |-> xyzz]
parents(x) = [0, 1, 2]
parents(z) = [0, 1]
for p in parents(xyzz):
p != xyzz
p' := simplify_using(xyzz, p)
if p != p':
repeat reduction using p := p'
--*/
#pragma once
#include "ast/euf/euf_plugin.h"
namespace euf {
class ac_plugin : public plugin {
// enode structure for AC equivalenes
struct node {
enode* n; // associated enode
node* root; // path compressed root
node* next; // next in equaivalence class
justification j; // justification for equality
node* target = nullptr; // justified next
unsigned_vector shared; // shared occurrences
unsigned_vector lhs; // left hand side of equalities
unsigned_vector rhs; // left side of equalities
unsigned root_id() const { return root->n->get_id(); }
~node() {}
static node* mk(region& r, enode* n);
};
class equiv {
node& n;
public:
class iterator {
node* m_first;
node* m_last;
public:
iterator(node* n, node* m) : m_first(n), m_last(m) {}
node* operator*() { return m_first; }
iterator& operator++() { if (!m_last) m_last = m_first; m_first = m_first->next; return *this; }
iterator operator++(int) { iterator tmp = *this; ++*this; return tmp; }
bool operator==(iterator const& other) const { return m_last == other.m_last && m_first == other.m_first; }
bool operator!=(iterator const& other) const { return !(*this == other); }
};
equiv(node& _n) :n(_n) {}
equiv(node* _n) :n(*_n) {}
iterator begin() const { return iterator(&n, nullptr); }
iterator end() const { return iterator(&n, &n); }
};
struct eq {
unsigned l, r; // refer to monomials
bool is_processed = false;
justification j;
};
unsigned m_fid;
unsigned m_op;
vector<eq> m_eqs;
ptr_vector<node> m_nodes;
vector<ptr_vector<node>> m_monomials;
enode_vector m_monomial_enodes;
justification::dependency_manager m_dep_manager;
// backtrackable state
enum undo_kind {
is_add_eq,
is_add_monomial,
is_add_node,
is_merge_node,
is_update_eq,
is_add_shared,
is_register_shared,
is_join_justification
};
svector<undo_kind> m_undo;
ptr_vector<node> m_node_trail;
unsigned_vector m_monomial_trail, m_shared_trail;
svector<std::tuple<node*, unsigned, unsigned, unsigned>> m_merge_trail;
svector<std::pair<unsigned, eq>> m_update_eq_trail;
node* mk_node(enode* n);
void merge(node* r1, node* r2, justification j);
bool is_op(enode* n) const { auto d = n->get_decl(); return d && m_fid == d->get_family_id() && m_op == d->get_kind(); }
std::function<void(void)> m_undo_notify;
void push_undo(undo_kind k);
enode_vector m_todo;
unsigned to_monomial(enode* n);
unsigned to_monomial(enode* n, ptr_vector<node> const& ms);
ptr_vector<node> const& monomial(unsigned i) const { return m_monomials[i]; }
ptr_vector<node>& monomial(unsigned i) { return m_monomials[i]; }
void init_equation(eq const& e);
bool orient_equation(eq & e);
void set_processed(unsigned eq_id, bool f);
unsigned pick_next_eq();
bool is_trivial(unsigned eq_id) const { throw default_exception("NYI"); }
void forward_simplify(unsigned eq_id, unsigned using_eq);
void backward_simplify(unsigned eq_id, unsigned using_eq);
void superpose(unsigned src_eq, unsigned dst_eq);
ptr_vector<node> m_src_r, m_src_l, m_dst_r;
unsigned_vector m_src_ids, m_src_count, m_dst_ids, m_dst_count;
unsigned_vector m_lhs_eqs;
bool_vector m_eq_seen;
bool m_backward_simplified = false;
unsigned m_next_eq_index = 0;
unsigned_vector const& forward_iterator(unsigned eq);
unsigned_vector const& superpose_iterator(unsigned eq);
unsigned_vector const& backward_iterator(unsigned eq);
void init_ids_counts(unsigned eq, unsigned monomial_id, unsigned_vector& ids, unsigned_vector& counts);
void reset_ids_counts(unsigned_vector& ids, unsigned_vector& counts);
void init_overlap_iterator(unsigned eq, unsigned monomial_id);
bool is_to_simplify(unsigned eq) const { return !m_eqs[eq].is_processed; }
bool is_processed(unsigned eq) const { return m_eqs[eq].is_processed; }
justification justify_rewrite(unsigned eq1, unsigned eq2);
justification justify_superpose(justification j1, justification j2);
justification::dependency* justify_monomial(justification::dependency* d, ptr_vector<node> const& m);
void propagate_shared();
void simplify_shared(unsigned monomial_id);
public:
ac_plugin(egraph& g, unsigned fid, unsigned op);
unsigned get_id() const override { return m_fid; }
void register_node(enode* n) override;
void register_shared(enode* n) override;
void merge_eh(enode* n1, enode* n2, justification j) override;
void diseq_eh(enode* n1, enode* n2) override {}
void undo() override;
void propagate() override;
std::ostream& display(std::ostream& out) const override;
void set_undo(std::function<void(void)> u) { m_undo_notify = u; }
};
}

View file

@ -0,0 +1,77 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_arith_plugin.cpp
Abstract:
plugin structure for arithetic
Author:
Nikolaj Bjorner (nbjorner) 2023-11-11
--*/
#include "ast/euf/euf_arith_plugin.h"
#include "ast/euf/euf_egraph.h"
#include <algorithm>
namespace euf {
arith_plugin::arith_plugin(egraph& g) :
plugin(g),
a(g.get_manager()),
m_add(g, get_id(), OP_ADD),
m_mul(g, get_id(), OP_MUL) {
std::function<void(void)> uadd = [&]() { m_undo.push_back(undo_t::undo_add); };
m_add.set_undo(uadd);
std::function<void(void)> umul = [&]() { m_undo.push_back(undo_t::undo_mul); };
m_mul.set_undo(umul);
}
void arith_plugin::register_node(enode* n) {
// no-op
}
void arith_plugin::register_shared(enode* n) {
if (a.is_add(n->get_expr()))
m_add.register_shared(n);
if (a.is_mul(n->get_expr()))
m_mul.register_shared(n);
}
void arith_plugin::merge_eh(enode* n1, enode* n2, justification j) {
m_add.merge_eh(n1, n2, j);
m_mul.merge_eh(n1, n2, j);
}
void arith_plugin::diseq_eh(enode* n1, enode* n2) {
// no-op
}
void arith_plugin::undo() {
auto k = m_undo.back();
m_undo.pop_back();
switch (k) {
case undo_t::undo_add:
m_add.undo();
break;
case undo_t::undo_mul:
m_mul.undo();
break;
default:
UNREACHABLE();
}
}
std::ostream& arith_plugin::display(std::ostream& out) const {
out << "add\n";
m_add.display(out);
out << "mul\n";
m_mul.display(out);
return out;
}
}

View file

@ -0,0 +1,51 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_arith_plugin.h
Abstract:
plugin structure for arithetic
Author:
Nikolaj Bjorner (nbjorner) 2023-11-11
--*/
#pragma once
#include "ast/arith_decl_plugin.h"
#include "ast/euf/euf_plugin.h"
#include "ast/euf/euf_ac_plugin.h"
namespace euf {
class egraph;
class arith_plugin : public plugin {
enum undo_t { undo_add, undo_mul };
arith_util a;
svector<undo_t> m_undo;
ac_plugin m_add, m_mul;
public:
arith_plugin(egraph& g);
unsigned get_id() const override { return a.get_family_id(); }
void register_node(enode* n) override;
void register_shared(enode* n) override;
void merge_eh(enode* n1, enode* n2, justification j) override;
void diseq_eh(enode* n1, enode* n2) override;
void undo() override;
std::ostream& display(std::ostream& out) const override;
};
}

View file

@ -0,0 +1,347 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_bv_plugin.cpp
Abstract:
plugin structure for bit-vectors
Author:
Nikolaj Bjorner (nbjorner) 2023-11-08
Jakob Rath 2023-11-08
Objective:
satisfies extract/concat axioms.
- concat(n{I],n[J]) = n[IJ] for I, J consecutive.
- concat(v1, v2) = 2^width(v1)*v2 + v1
- concat(n[width(n)-1:0]) = n
- concat(a, b)[I] = concat(a[I1], b[I2])
- concat(a, concat(b, c)) = concat(concat(a, b), c)
E-graph:
The E-graph contains node definitions of the form
n := f(n1,n2,..)
and congruences:
n ~ n' means root(n) = root(n')
Saturated state:
1. n := n1[I], n' := n2[J], n1 ~ n2 => root(n1) contains tree refining both I, J from smaller intervals
2. n := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => n ~ n3[IJ]
3. n := concat(n1[I], n2[J]), I and J are consecutive & n1 ~ n2, n1[I] ~ v1, n1[J] ~ v2 => n ~ 2^width(v1)*v2 + v1
4. n := concat(n1[I], n2[J], I, J are consecutive, n1 ~ n2, n ~ v => n1[I] ~ v mod 2^width(n1[I]), n2[J] ~ v div 2^width(n1[I])
5. n' := n[I] => n ~ n[width(n)-1:0]
6. n := concat(a, concat(b, c)) => n ~ concat(concat(a, b), c)
- handled by rewriter pre-processing for inputs
- terms created internally are not equated modulo associativity
7, n := concat(n1, n2)[I] => n ~ concat(n1[I1],n2[I2]) or n[I1] or n[I2]
- handled by rewriter pre-processing
Example:
x == (x1 x2) x3
y == y1 (y2 y3)
x1 == y1, x2 == y2, x3 == y3
=>
x = y
by x2 == y2, x3 == y3 => (x2 x3) = (y2 y3)
by (2) => x[I23] = (x2 x3)
by (2) => x[I123] = (x1 (x2 x3))
by (5) => x = x[I123]
--*/
#include "ast/euf/euf_bv_plugin.h"
#include "ast/euf/euf_egraph.h"
namespace euf {
bv_plugin::bv_plugin(egraph& g):
plugin(g),
bv(g.get_manager())
{}
enode* bv_plugin::mk_value_concat(enode* a, enode* b) {
auto v1 = get_value(a);
auto v2 = get_value(b);
auto v3 = v1 + v2 * power(rational(2), width(a));
return mk_value(v3, width(a) + width(b));
}
enode* bv_plugin::mk_value(rational const& v, unsigned sz) {
auto e = bv.mk_numeral(v, sz);
return mk(e, 0, nullptr);
}
void bv_plugin::merge_eh(enode* x, enode* y, justification j) {
SASSERT(x == x->get_root());
SASSERT(x == y->get_root());
TRACE("bv", tout << "merge_eh " << g.bpp(x) << " == " << g.bpp(y) << "\n");
SASSERT(!m_internal);
flet<bool> _internal(m_internal, true);
propagate_values(x);
// ensure slices align
if (has_sub(x) || has_sub(y)) {
enode_vector& xs = m_xs, & ys = m_ys;
xs.reset();
ys.reset();
xs.push_back(x);
ys.push_back(y);
merge(xs, ys, j);
}
// ensure p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ]
for (auto* n : enode_class(x))
propagate_extract(n);
}
// enforce concat(v1, v2) = v2*2^|v1| + v1
void bv_plugin::propagate_values(enode* x) {
if (!is_value(x))
return;
enode* a, * b;
for (enode* p : enode_parents(x))
if (is_concat(p, a, b) && is_value(a) && is_value(b) && !is_value(p))
push_merge(mk_concat(a->get_interpreted(), b->get_interpreted()), mk_value_concat(a, b));
for (enode* sib : enode_class(x)) {
if (is_concat(sib, a, b)) {
if (!is_value(a) || !is_value(b)) {
auto val = get_value(x);
auto v1 = mod2k(val, width(a));
auto v2 = machine_div2k(val, width(a));
push_merge(mk_concat(mk_value(v1, width(a)), mk_value(v2, width(b))), x->get_interpreted());
}
}
}
}
//
// p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ]
//
// n is of form arg[I]
// p is of form concat(n, b) or concat(a, n)
// b is congruent to arg[J], I is consecutive with J => ensure that arg[IJ] = p
// a is congruent to arg[J], J is consecutive with I => ensure that arg[JI] = p
//
void bv_plugin::propagate_extract(enode* n) {
unsigned lo1, hi1, lo2, hi2;
enode* a, * b;
if (!is_extract(n, lo1, hi1))
return;
enode* arg = n->get_arg(0);
enode* arg_r = arg->get_root();
enode* n_r = n->get_root();
auto ensure_concat = [&](unsigned lo, unsigned mid, unsigned hi) {
TRACE("bv", tout << "ensure-concat " << lo << " " << mid << " " << hi << "\n");
unsigned lo_, hi_;
for (enode* p1 : enode_parents(n))
if (is_extract(p1, lo_, hi_) && lo_ == lo && hi_ == hi && p1->get_arg(0)->get_root() == arg_r)
return;
// add the axiom instead of merge(p, mk_extract(arg, lo, hi)), which would require tracking justifications
push_merge(mk_concat(mk_extract(arg, lo, mid), mk_extract(arg, mid + 1, hi)), mk_extract(arg, lo, hi));
};
auto propagate_left = [&](enode* b) {
TRACE("bv", tout << "propagate-left " << g.bpp(b) << "\n");
for (enode* sib : enode_class(b))
if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi1 + 1 == lo2)
ensure_concat(lo1, hi1, hi2);
};
auto propagate_right = [&](enode* a) {
TRACE("bv", tout << "propagate-right " << g.bpp(a) << "\n");
for (enode* sib : enode_class(a))
if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi2 + 1 == lo1)
ensure_concat(lo2, hi2, hi1);
};
for (enode* p : enode_parents(n)) {
if (is_concat(p, a, b)) {
if (a->get_root() == n_r)
propagate_left(b);
if (b->get_root() == n_r)
propagate_right(a);
}
}
}
void bv_plugin::push_undo_split(enode* n) {
m_undo_split.push_back(n);
push_plugin_undo(bv.get_family_id());
}
void bv_plugin::undo() {
enode* n = m_undo_split.back();
m_undo_split.pop_back();
auto& i = info(n);
i.lo = nullptr;
i.hi = nullptr;
i.cut = null_cut;
}
void bv_plugin::register_node(enode* n) {
TRACE("bv", tout << "register " << g.bpp(n) << "\n");
auto& i = info(n);
i.value = n;
enode* a, * b;
if (is_concat(n, a, b)) {
i.lo = a;
i.hi = b;
i.cut = width(a);
push_undo_split(n);
}
unsigned lo, hi;
if (is_extract(n, lo, hi) && (lo != 0 || hi + 1 != width(n->get_arg(0)))) {
enode* arg = n->get_arg(0);
unsigned w = width(arg);
if (all_of(enode_parents(arg), [&](enode* p) { unsigned _lo, _hi; return !is_extract(p, _lo, _hi) || _lo != 0 || _hi + 1 != w; }))
push_merge(mk_extract(arg, 0, w - 1), arg);
ensure_slice(arg, lo, hi);
}
}
//
// Ensure that there are slices at boundaries of n[hi:lo]
//
void bv_plugin::ensure_slice(enode* n, unsigned lo, unsigned hi) {
enode* r = n;
unsigned lb = 0, ub = width(n) - 1;
while (true) {
TRACE("bv", tout << "ensure slice " << g.bpp(n) << " " << lb << " [" << lo << ", " << hi << "] " << ub << "\n");
SASSERT(lb <= lo && hi <= ub);
SASSERT(ub - lb + 1 == width(r));
if (lb == lo && ub == hi)
return;
slice_info& i = info(r);
if (!i.lo) {
if (lo > lb) {
split(r, lo - lb);
if (hi < ub) // or split(info(r).hi, ...)
ensure_slice(n, lo, hi);
}
else if (hi < ub)
split(r, ub - hi);
break;
}
auto cut = i.cut;
if (cut + lb <= lo) {
lb += cut;
r = i.hi;
continue;
}
if (cut + lb > hi) {
ub = cut + lb - 1;
r = i.lo;
continue;
}
SASSERT(lo < cut + lb && cut + lb <= hi);
ensure_slice(n, lo, cut + lb - 1);
ensure_slice(n, cut + lb, hi);
break;
}
}
enode* bv_plugin::mk_extract(enode* n, unsigned lo, unsigned hi) {
SASSERT(lo <= hi && width(n) > hi - lo);
unsigned lo1, hi1;
while (is_extract(n, lo1, hi1)) {
lo += lo1;
hi += lo1;
n = n->get_arg(0);
}
return mk(bv.mk_extract(hi, lo, n->get_expr()), 1, &n);
}
enode* bv_plugin::mk_concat(enode* lo, enode* hi) {
enode* args[2] = { lo, hi };
return mk(bv.mk_concat(lo->get_expr(), hi->get_expr()), 2, args);
}
void bv_plugin::merge(enode_vector& xs, enode_vector& ys, justification dep) {
while (!xs.empty()) {
SASSERT(!ys.empty());
auto x = xs.back();
auto y = ys.back();
if (unfold_sub(x, xs))
continue;
else if (unfold_sub(y, ys))
continue;
else if (unfold_width(x, xs, y, ys))
continue;
else if (unfold_width(y, ys, x, xs))
continue;
else if (x->get_root() != y->get_root())
push_merge(x, y, dep);
xs.pop_back();
ys.pop_back();
}
SASSERT(ys.empty());
}
bool bv_plugin::unfold_sub(enode* x, enode_vector& xs) {
if (!has_sub(x))
return false;
xs.pop_back();
xs.push_back(sub_hi(x));
xs.push_back(sub_lo(x));
return true;
}
bool bv_plugin::unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys) {
if (width(x) <= width(y))
return false;
split(x, width(y));
xs.pop_back();
xs.push_back(sub_hi(x));
xs.push_back(sub_lo(x));
return true;
}
void bv_plugin::split(enode* n, unsigned cut) {
TRACE("bv", tout << "split: " << g.bpp(n) << " " << cut << "\n");
unsigned w = width(n);
SASSERT(!info(n).hi);
SASSERT(0 < cut && cut < w);
enode* hi = mk_extract(n, cut, w - 1);
enode* lo = mk_extract(n, 0, cut - 1);
auto& i = info(n);
SASSERT(i.value);
i.hi = hi;
i.lo = lo;
i.cut = cut;
push_undo_split(n);
push_merge(mk_concat(lo, hi), n);
}
std::ostream& bv_plugin::display(std::ostream& out) const {
out << "bv\n";
for (auto const& i : m_info)
if (i.lo)
out << g.bpp(i.value) << " cut " << i.cut << " lo " << g.bpp(i.lo) << " hi " << g.bpp(i.hi) << "\n";
return out;
}
}

100
src/ast/euf/euf_bv_plugin.h Normal file
View file

@ -0,0 +1,100 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_bv_plugin.h
Abstract:
plugin structure for bit-vectors
Author:
Nikolaj Bjorner (nbjorner) 2023-11-08
Jakob Rath 2023-11-08
--*/
#pragma once
#include "ast/bv_decl_plugin.h"
#include "ast/euf/euf_plugin.h"
namespace euf {
class egraph;
class bv_plugin : public plugin {
static constexpr unsigned null_cut = std::numeric_limits<unsigned>::max();
struct slice_info {
unsigned cut = null_cut; // = bv.get_bv_size(lo)
enode* hi = nullptr; //
enode* lo = nullptr; //
enode* value = nullptr;
void reset() { *this = slice_info(); }
};
using slice_info_vector = svector<slice_info>;
bv_util bv;
slice_info_vector m_info; // indexed by enode::get_id()
enode_vector m_xs, m_ys;
bool is_concat(enode* n) const { return bv.is_concat(n->get_expr()); }
bool is_concat(enode* n, enode*& a, enode*& b) { return is_concat(n) && (a = n->get_arg(0), b = n->get_arg(1), true); }
bool is_extract(enode* n, unsigned& lo, unsigned& hi) { expr* body; return bv.is_extract(n->get_expr(), lo, hi, body); }
bool is_extract(enode* n) const { return bv.is_extract(n->get_expr()); }
unsigned width(enode* n) const { return bv.get_bv_size(n->get_expr()); }
enode* mk_extract(enode* n, unsigned lo, unsigned hi);
enode* mk_concat(enode* lo, enode* hi);
enode* mk_value_concat(enode* lo, enode* hi);
enode* mk_value(rational const& v, unsigned sz);
unsigned width(enode* n) { return bv.get_bv_size(n->get_expr()); }
bool is_value(enode* n) { return n->get_root()->interpreted(); }
rational get_value(enode* n) { rational val; VERIFY(bv.is_numeral(n->get_interpreted()->get_expr(), val)); return val; }
slice_info& info(enode* n) { unsigned id = n->get_id(); m_info.reserve(id + 1); return m_info[id]; }
slice_info& root_info(enode* n) { unsigned id = n->get_root_id(); m_info.reserve(id + 1); return m_info[id]; }
bool has_sub(enode* n) { return !!info(n).lo; }
enode* sub_lo(enode* n) { return info(n).lo; }
enode* sub_hi(enode* n) { return info(n).hi; }
bool m_internal = false;
void ensure_slice(enode* n, unsigned lo, unsigned hi);
void split(enode* n, unsigned cut);
bool unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys);
bool unfold_sub(enode* x, enode_vector& xs);
void merge(enode_vector& xs, enode_vector& ys, justification j);
void propagate_extract(enode* n);
void propagate_values(enode* n);
enode_vector m_undo_split;
void push_undo_split(enode* n);
public:
bv_plugin(egraph& g);
unsigned get_id() const override { return bv.get_family_id(); }
void register_node(enode* n) override;
void register_shared(enode* n) override {}
void merge_eh(enode* n1, enode* n2, justification j) override;
void diseq_eh(enode* n1, enode* n2) override {}
void propagate() override {}
void undo() override;
std::ostream& display(std::ostream& out) const override;
};
}

View file

@ -18,6 +18,7 @@ Notes:
#include "ast/euf/euf_egraph.h"
#include "ast/euf/euf_bv_plugin.h"
#include "ast/ast_pp.h"
#include "ast/ast_translation.h"
@ -113,8 +114,11 @@ namespace euf {
n->mark_interpreted();
if (m_on_make)
m_on_make(n);
if (num_args == 0)
register_node(n);
if (num_args == 0)
return n;
if (m.is_eq(f) && !m.is_iff(f)) {
n->set_is_equality();
reinsert_equality(n);
@ -123,11 +127,26 @@ namespace euf {
if (n2 == n)
update_children(n);
else
merge(n, n2, justification::congruence(comm, m_congruence_timestamp++));
push_merge(n, n2, justification::congruence(comm, 0));
// merge(n, n2, justification::congruence(comm, m_congruence_timestamp++));
return n;
}
void egraph::register_node(enode* n) {
if (m_plugins.empty())
return;
auto* p = get_plugin(n);
if (p)
p->register_node(n);
for (auto* arg : enode_args(n)) {
auto* p_arg = get_plugin(arg);
if (p != p_arg)
p_arg->register_shared(arg);
}
}
egraph::egraph(ast_manager& m) : m(m), m_table(m), m_tmp_app(2), m_exprs(m), m_eq_decls(m) {
m_tmp_eq = enode::mk_tmp(m_region, 2);
}
@ -139,6 +158,18 @@ namespace euf {
memory::deallocate(m_tmp_node);
}
void egraph::add_plugins() {
auto* plugin = alloc(bv_plugin, *this);
m_plugins.reserve(plugin->get_id() + 1);
m_plugins.set(plugin->get_id(), plugin);
}
void egraph::propagate_plugins() {
for (auto* p : m_plugins)
if (p)
p->propagate();
}
void egraph::add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) {
TRACE("euf_verbose", tout << "eq: " << v1 << " == " << v2 << "\n";);
m_new_th_eqs.push_back(th_eq(id, v1, v2, c, r));
@ -422,6 +453,9 @@ namespace euf {
p.r1->m_args[i]->get_root()->m_parents.pop_back();
}
break;
case update_record::tag_t::is_plugin_undo:
m_plugins[p.m_th_id]->undo();
break;
default:
UNREACHABLE();
break;
@ -442,7 +476,7 @@ namespace euf {
if (!n1->cgc_enabled() && !n2->cgc_enabled())
return;
SASSERT(n1->get_sort() == n2->get_sort());
enode* r1 = n1->get_root();
enode* r2 = n2->get_root();
if (r1 == r2)
@ -452,6 +486,7 @@ namespace euf {
IF_VERBOSE(20, j.display(verbose_stream() << "merge: " << bpp(n1) << " == " << bpp(n2) << " ", m_display_justification) << "\n";);
force_push();
SASSERT(m_num_scopes == 0);
SASSERT(n1->get_sort() == n2->get_sort());
++m_stats.m_num_merge;
if (r1->interpreted() && r2->interpreted()) {
set_conflict(n1, n2, j);
@ -476,7 +511,7 @@ namespace euf {
c->m_root = r2;
std::swap(r1->m_next, r2->m_next);
r2->inc_class_size(r1->class_size());
merge_th_eq(r1, r2);
merge_th_eq(r1, r2, j);
reinsert_parents(r1, r2);
if (j.is_congruence() && (m.is_false(r2->get_expr()) || m.is_true(r2->get_expr())))
add_literal(n1, r2);
@ -487,6 +522,10 @@ namespace euf {
for (auto& cb : m_on_merge)
cb(r2, r1);
auto* p = get_plugin(r1);
if (p)
p->merge_eh(r2, r1, j);
}
void egraph::remove_parents(enode* r) {
@ -532,7 +571,7 @@ namespace euf {
}
}
void egraph::merge_th_eq(enode* n, enode* root) {
void egraph::merge_th_eq(enode* n, enode* root, justification j) {
SASSERT(n != root);
for (auto const& iv : enode_th_vars(n)) {
theory_id id = iv.get_id();
@ -574,13 +613,17 @@ namespace euf {
unmerge_justification(n1);
}
bool egraph::propagate() {
SASSERT(m_num_scopes == 0 || m_to_merge.empty());
bool egraph::propagate() {
force_push();
for (unsigned i = 0; i < m_to_merge.size() && m.limit().inc() && !inconsistent(); ++i) {
auto const& w = m_to_merge[i];
merge(w.a, w.b, justification::congruence(w.commutativity, m_congruence_timestamp++));
if (w.j.is_congruence())
merge(w.a, w.b, justification::congruence(w.j.is_commutative(), m_congruence_timestamp++));
else
merge(w.a, w.b, w.j);
if (i + 1 == m_to_merge.size())
propagate_plugins();
}
m_to_merge.reset();
return
@ -746,8 +789,15 @@ namespace euf {
TRACE("euf_verbose", tout << "explain-eq: " << bpp(a) << " == " << bpp(b) << " jst: " << j << "\n";);
if (j.is_external())
justifications.push_back(j.ext<T>());
else if (j.is_congruence())
else if (j.is_congruence())
push_congruence(a, b, j.is_commutative());
else if (j.is_dependent()) {
vector<justification, false> js;
for (auto const& j2 : justification::dependency_manager::s_linearize(j.get_dependency(), js))
explain_eq(justifications, cc, a, b, j2);
}
else if (j.is_equality())
explain_eq(justifications, cc, j.lhs(), j.rhs());
if (cc && j.is_congruence())
cc->push_back(std::tuple(a->get_app(), b->get_app(), j.timestamp(), j.is_commutative()));
}
@ -867,7 +917,10 @@ namespace euf {
for (enode* n : m_nodes)
max_args = std::max(max_args, n->num_args());
for (enode* n : m_nodes)
display(out, max_args, n);
display(out, max_args, n);
for (auto* p : m_plugins)
if (p)
p->display(out);
return out;
}

View file

@ -29,8 +29,10 @@ Notes:
#include "util/statistics.h"
#include "util/trail.h"
#include "util/lbool.h"
#include "util/scoped_ptr_vector.h"
#include "ast/euf/euf_enode.h"
#include "ast/euf/euf_etable.h"
#include "ast/euf/euf_plugin.h"
#include "ast/ast_ll_pp.h"
#include <vector>
@ -82,12 +84,15 @@ namespace euf {
class egraph {
friend class plugin;
typedef ptr_vector<trail> trail_stack;
struct to_merge {
enode* a, * b;
bool commutativity;
to_merge(enode* a, enode* b, bool c) : a(a), b(b), commutativity(c) {}
justification j;
to_merge(enode* a, enode* b, bool c) : a(a), b(b), j(justification::congruence(c, 0)) {}
to_merge(enode* a, enode* b, justification j) : a(a), b(b), j(j) {}
};
struct stats {
@ -113,10 +118,12 @@ namespace euf {
struct lbl_set {};
struct update_children {};
struct set_relevant {};
struct plugin_undo {};
enum class tag_t { is_set_parent, is_add_node, is_toggle_cgc, is_toggle_merge_tf, is_update_children,
is_add_th_var, is_replace_th_var, is_new_th_eq,
is_lbl_hash, is_new_th_eq_qhead,
is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant };
is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant,
is_plugin_undo };
tag_t tag;
enode* r1;
enode* n1;
@ -159,11 +166,14 @@ namespace euf {
tag(tag_t::is_update_children), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {}
update_record(enode* n, set_relevant) :
tag(tag_t::is_set_relevant), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {}
update_record(unsigned th_id, plugin_undo) :
tag(tag_t::is_plugin_undo), r1(nullptr), n1(nullptr), m_th_id(th_id) {}
};
ast_manager& m;
svector<to_merge> m_to_merge;
etable m_table;
region m_region;
scoped_ptr_vector<plugin> m_plugins;
svector<update_record> m_updates;
unsigned_vector m_scopes;
enode_vector m_expr2enode;
@ -202,6 +212,13 @@ namespace euf {
}
void push_node(enode* n) { m_updates.push_back(update_record(n)); }
// plugin related methods
void push_plugin_undo(unsigned th_id) { m_updates.push_back(update_record(th_id, update_record::plugin_undo())); }
void push_merge(enode* a, enode* b, justification j) { m_to_merge.push_back({ a, b, j }); }
plugin* get_plugin(enode* n) { return m_plugins.get(n->get_sort()->get_family_id(), nullptr); }
void register_node(enode* n);
void propagate_plugins();
void add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r);
void add_th_diseqs(theory_id id, theory_var v1, enode* r);
@ -213,7 +230,7 @@ namespace euf {
void force_push();
void set_conflict(enode* n1, enode* n2, justification j);
void merge(enode* n1, enode* n2, justification j);
void merge_th_eq(enode* n, enode* root);
void merge_th_eq(enode* n, enode* root, justification j);
void merge_justification(enode* n1, enode* n2, justification j);
void reinsert_parents(enode* r1, enode* r2);
void remove_parents(enode* r);
@ -241,6 +258,7 @@ namespace euf {
public:
egraph(ast_manager& m);
~egraph();
void add_plugins();
enode* find(expr* f) const { return m_expr2enode.get(f->get_id(), nullptr); }
enode* find(expr* f, unsigned n, enode* const* args);
enode* mk(expr* f, unsigned generation, unsigned n, enode *const* args);

View file

@ -202,6 +202,7 @@ namespace euf {
enode* get_root() const { return m_root; }
expr* get_expr() const { return m_expr; }
sort* get_sort() const { return m_expr->get_sort(); }
enode* get_interpreted() const { return get_root(); }
app* get_app() const { return to_app(m_expr); }
func_decl* get_decl() const { return is_app(m_expr) ? to_app(m_expr)->get_decl() : nullptr; }
unsigned get_expr_id() const { return m_expr->get_id(); }

View file

@ -22,19 +22,34 @@ Notes:
#pragma once
#include "util/dependency.h"
namespace euf {
class enode;
class justification {
public:
typedef scoped_dependency_manager<justification> dependency_manager;
typedef scoped_dependency_manager<justification>::dependency dependency;
private:
enum class kind_t {
axiom_t,
congruence_t,
external_t
external_t,
dependent_t,
equality_t
};
kind_t m_kind;
bool m_comm;
union {
bool m_comm;
enode* m_n1;
};
union {
void* m_external;
uint64_t m_timestamp;
dependency* m_dependency;
enode* m_n2;
};
justification(bool comm, uint64_t ts):
@ -49,6 +64,18 @@ namespace euf {
m_external(ext)
{}
justification(dependency* dep, int):
m_kind(kind_t::dependent_t),
m_comm(false),
m_dependency(dep)
{}
justification(enode* n1, enode* n2):
m_kind(kind_t::equality_t),
m_n1(n1),
m_n2(n2)
{}
public:
justification():
m_kind(kind_t::axiom_t),
@ -59,10 +86,17 @@ namespace euf {
static justification axiom() { return justification(); }
static justification congruence(bool c, uint64_t ts) { return justification(c, ts); }
static justification external(void* ext) { return justification(ext); }
static justification dependent(dependency* d) { return justification(d, 1); }
static justification equality(enode* a, enode* b) { return justification(a, b); }
bool is_external() const { return m_kind == kind_t::external_t; }
bool is_congruence() const { return m_kind == kind_t::congruence_t; }
bool is_commutative() const { return m_comm; }
bool is_dependent() const { return m_kind == kind_t::dependent_t; }
bool is_equality() const { return m_kind == kind_t::equality_t; }
dependency* get_dependency() const { SASSERT(is_dependent()); return m_dependency; }
enode* lhs() const { SASSERT(is_equality()); return m_n1; }
enode* rhs() const { SASSERT(is_equality()); return m_n2; }
uint64_t timestamp() const { SASSERT(is_congruence()); return m_timestamp; }
template <typename T>
T* ext() const { SASSERT(is_external()); return static_cast<T*>(m_external); }
@ -75,6 +109,9 @@ namespace euf {
return axiom();
case kind_t::congruence_t:
return congruence(m_comm, m_timestamp);
case kind_t::dependent_t:
NOT_IMPLEMENTED_YET();
return dependent(m_dependency);
default:
UNREACHABLE();
return axiom();
@ -93,6 +130,13 @@ namespace euf {
return out << "axiom";
case kind_t::congruence_t:
return out << "congruence";
case kind_t::dependent_t: {
vector<justification, false> js;
out << "dependent";
for (auto const& j : dependency_manager::s_linearize(m_dependency, js))
j.display(out << " ", ext);
return out;
}
default:
UNREACHABLE();
return out;

View file

@ -0,0 +1,47 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_plugin.cpp
Abstract:
plugin structure for euf
Plugins allow adding equality saturation for theories.
Author:
Nikolaj Bjorner (nbjorner) 2023-11-08
--*/
#include "ast/euf/euf_egraph.h"
namespace euf {
void plugin::push_plugin_undo(unsigned th_id) {
g.push_plugin_undo(th_id);
}
void plugin::push_merge(enode* a, enode* b, justification j) {
g.push_merge(a, b, j);
}
void plugin::push_merge(enode* a, enode* b) {
TRACE("plugin", tout << g.bpp(a) << " == " << g.bpp(b) << "\n");
g.push_merge(a, b, justification::axiom());
}
enode* plugin::mk(expr* e, unsigned n, enode* const* args) {
enode* r = g.find(e);
if (!r)
r = g.mk(e, 0, n, args);
return r;
}
region& plugin::get_region() {
return g.m_region;
}
}

58
src/ast/euf/euf_plugin.h Normal file
View file

@ -0,0 +1,58 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_plugin.h
Abstract:
plugin structure for euf
Plugins allow adding equality saturation for theories.
Author:
Nikolaj Bjorner (nbjorner) 2023-11-08
--*/
#pragma once
#include "ast/euf/euf_enode.h"
#include "ast/euf/euf_justification.h"
namespace euf {
class plugin {
protected:
egraph& g;
void push_plugin_undo(unsigned th_id);
void push_merge(enode* a, enode* b, justification j);
void push_merge(enode* a, enode* b);
enode* mk(expr* e, unsigned n, enode* const* args);
region& get_region();
public:
plugin(egraph& g):
g(g)
{}
virtual unsigned get_id() const = 0;
virtual void register_node(enode* n) = 0;
virtual void register_shared(enode* n) = 0;
virtual void merge_eh(enode* n1, enode* n2, justification j) = 0;
virtual void diseq_eh(enode* n1, enode* n2) = 0;
virtual void propagate() = 0;
virtual void undo() = 0;
virtual std::ostream& display(std::ostream& out) const = 0;
};
}

View file

@ -420,8 +420,6 @@ namespace euf {
return *c;
}
bool solver::unit_propagate() {
bool propagated = false;
while (!s().inconsistent()) {

View file

@ -38,6 +38,7 @@ add_executable(test-z3
dl_util.cpp
doc.cpp
egraph.cpp
euf_bv_plugin.cpp
escaped.cpp
ex.cpp
expr_rand.cpp

180
src/test/euf_bv_plugin.cpp Normal file
View file

@ -0,0 +1,180 @@
/*++
Copyright (c) 2023 Microsoft Corporation
--*/
#include "util/util.h"
#include "util/timer.h"
#include "ast/euf/euf_egraph.h"
#include "ast/euf/euf_bv_plugin.h"
#include "ast/reg_decl_plugins.h"
#include "ast/ast_pp.h"
#include <iostream>
euf::enode* get_node(euf::egraph& g, expr* e) {
auto* n = g.find(e);
if (n)
return n;
euf::enode_vector args;
for (expr* arg : *to_app(e))
args.push_back(get_node(g, arg));
return g.mk(e, 0, args.size(), args.data());
}
// align slices, and propagate extensionality
static void test1() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugins();
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref y(m.mk_const("y", u32), m);
expr_ref x3(bv.mk_extract(31, 16, x), m);
expr_ref x2(bv.mk_extract(15, 8, x), m);
expr_ref x1(bv.mk_extract(7, 0, x), m);
expr_ref y3(bv.mk_extract(31, 24, y), m);
expr_ref y2(bv.mk_extract(23, 8, y), m);
expr_ref y1(bv.mk_extract(7, 0, y), m);
expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m);
expr_ref yy(bv.mk_concat(y1, bv.mk_concat(y2, y3)), m);
auto* nx = get_node(g, xx);
auto* ny = get_node(g, yy);
TRACE("bv", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("bv", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("bv", tout << "after propagate\n" << g << "\n");
std::cout << g << "\n";
SASSERT(nx->get_root() == ny->get_root());
}
// propagate values down
static void test2() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugins();
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref x3(bv.mk_extract(31, 16, x), m);
expr_ref x2(bv.mk_extract(15, 8, x), m);
expr_ref x1(bv.mk_extract(7, 0, x), m);
expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m);
g.merge(get_node(g, xx), get_node(g, bv.mk_numeral((1 << 27) + (1 << 17) + (1 << 3), 32)), nullptr);
g.propagate();
SASSERT(get_node(g, x1)->get_root()->interpreted());
SASSERT(get_node(g, x2)->get_root()->interpreted());
SASSERT(get_node(g, x3)->get_root()->interpreted());
SASSERT(get_node(g, x)->get_root()->interpreted());
}
// propagate values up
static void test3() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugins();
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref x3(bv.mk_extract(31, 16, x), m);
expr_ref x2(bv.mk_extract(15, 8, x), m);
expr_ref x1(bv.mk_extract(7, 0, x), m);
expr_ref xx(bv.mk_concat(bv.mk_concat(x1, x2), x3), m);
expr_ref y(m.mk_const("y", u32), m);
g.merge(get_node(g, xx), get_node(g, y), nullptr);
g.merge(get_node(g, x1), get_node(g, bv.mk_numeral(2, 8)), nullptr);
g.merge(get_node(g, x2), get_node(g, bv.mk_numeral(8, 8)), nullptr);
g.propagate();
SASSERT(get_node(g, bv.mk_concat(x1, x2))->get_root()->interpreted());
SASSERT(get_node(g, x1)->get_root()->interpreted());
SASSERT(get_node(g, x2)->get_root()->interpreted());
}
// propagate extract up
static void test4() {
// concat(a, x[J]), a = x[I] => x[IJ] = concat(x[I],x[J])
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugins();
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
sort_ref u8(bv.mk_sort(8), m);
sort_ref u16(bv.mk_sort(16), m);
expr_ref a(m.mk_const("a", u8), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref y(m.mk_const("y", u16), m);
expr_ref x1(bv.mk_extract(15, 8, x), m);
expr_ref x2(bv.mk_extract(23, 16, x), m);
g.merge(get_node(g, bv.mk_concat(a, x2)), get_node(g, y), nullptr);
g.merge(get_node(g, x1), get_node(g, a), nullptr);
g.propagate();
TRACE("bv", tout << g << "\n");
SASSERT(get_node(g, bv.mk_extract(23, 8, x))->get_root() == get_node(g, y)->get_root());
}
// iterative slicing
static void test5() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugins();
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref x1(bv.mk_extract(31, 4, x), m);
expr_ref x2(bv.mk_extract(27, 0, x), m);
auto* nx = get_node(g, x1);
auto* ny = get_node(g, x2);
TRACE("bv", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("bv", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("bv", tout << "after propagate\n" << g << "\n");
std::cout << g << "\n";
}
// iterative slicing
static void test6() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugins();
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref x1(bv.mk_extract(31, 3, x), m);
expr_ref x2(bv.mk_extract(28, 0, x), m);
auto* nx = get_node(g, x1);
auto* ny = get_node(g, x2);
TRACE("bv", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("bv", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("bv", tout << "after propagate\n" << g << "\n");
std::cout << g << "\n";
}
void tst_euf_bv_plugin() {
enable_trace("bv");
enable_trace("plugin");
test6();
return;
test1();
test2();
test3();
test4();
test5();
test6();
}

View file

@ -270,4 +270,5 @@ int main(int argc, char ** argv) {
TST(slicing);
TST(totalizer);
TST(distribution);
TST(euf_bv_plugin);
}

View file

@ -69,6 +69,14 @@ public:
d->unmark();
}
static void s_linearize(dependency* d, vector<value, false>& vs) {
if (!d)
return;
ptr_vector<dependency> todo;
todo.push_back(d);
linearize_todo(todo, vs);
}
private:
struct join : public dependency {
dependency * m_children[2];
@ -325,6 +333,11 @@ public:
return m_dep_manager.linearize(d, vs);
}
static vector<value, false> const& s_linearize(dependency* d, vector<value, false>& vs) {
dep_manager::s_linearize(d, vs);
return vs;
}
void linearize(ptr_vector<dependency>& d, vector<value, false> & vs) {
return m_dep_manager.linearize(d, vs);
}