3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-06 17:44:08 +00:00

add EUF plugin framework.

plugin setting allows adding equality saturation within the E-graph propagation without involving externalizing theory solver dispatch. It makes equality saturation independent of SAT integration.
Add a special relation operator to support ad-hoc AC symbols.
This commit is contained in:
Nikolaj Bjorner 2023-11-30 13:58:24 -08:00
parent 5784c2da79
commit b52fd8d954
28 changed files with 3063 additions and 68 deletions

View file

@ -1,8 +1,14 @@
z3_add_component(euf
SOURCES
euf_ac_plugin.cpp
euf_arith_plugin.cpp
euf_bv_plugin.cpp
euf_egraph.cpp
euf_enode.cpp
euf_etable.cpp
euf_egraph.cpp
euf_justification.cpp
euf_plugin.cpp
euf_specrel_plugin.cpp
COMPONENT_DEPENDENCIES
ast
util

File diff suppressed because it is too large Load diff

309
src/ast/euf/euf_ac_plugin.h Normal file
View file

@ -0,0 +1,309 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_ac_plugin.h
Abstract:
plugin structure for ac functions
Author:
Nikolaj Bjorner (nbjorner) 2023-11-11
ex:
xyz -> xy, then xyzz -> xy by repeated rewriting
monomials = [0 |-> xyz, 1 |-> xy, 2 |-> xyzz]
parents(x) = [0, 1, 2]
parents(z) = [0, 1]
for p in parents(xyzz):
p != xyzz
p' := simplify_using(xyzz, p)
if p != p':
repeat reduction using p := p'
--*/
#pragma once
#include <iostream>
#include "ast/euf/euf_plugin.h"
namespace euf {
class ac_plugin : public plugin {
// enode structure for AC equivalences
struct node {
enode* n; // associated enode
node* root; // path compressed root
node* next; // next in equaivalence class
justification j; // justification for equality
node* target = nullptr; // justified next
unsigned_vector shared; // shared occurrences
unsigned_vector eqs; // equality occurrences
unsigned root_id() const { return root->n->get_id(); }
~node() {}
static node* mk(region& r, enode* n);
};
class equiv {
node& n;
public:
class iterator {
node* m_first;
node* m_last;
public:
iterator(node* n, node* m) : m_first(n), m_last(m) {}
node* operator*() { return m_first; }
iterator& operator++() { if (!m_last) m_last = m_first; m_first = m_first->next; return *this; }
iterator operator++(int) { iterator tmp = *this; ++*this; return tmp; }
bool operator==(iterator const& other) const { return m_last == other.m_last && m_first == other.m_first; }
bool operator!=(iterator const& other) const { return !(*this == other); }
};
equiv(node& _n) :n(_n) {}
equiv(node* _n) :n(*_n) {}
iterator begin() const { return iterator(&n, nullptr); }
iterator end() const { return iterator(&n, &n); }
};
struct bloom {
uint64_t m_tick = 0;
uint64_t m_filter = 0;
};
enum eq_status {
processed, to_simplify, is_dead
};
// represent equalities added by merge_eh and by superposition
struct eq {
eq(unsigned l, unsigned r, justification j):
l(l), r(r), j(j) {}
unsigned l, r; // refer to monomials
eq_status status = to_simplify;
justification j; // justification for equality
};
// represent shared enodes that use the AC symbol.
struct shared {
enode* n; // original shared enode
unsigned m; // monomial index
justification j; // justification for current simplification of monomial
};
struct monomial_t {
ptr_vector<node> m_nodes;
bloom m_bloom;
node* operator[](unsigned i) const { return m_nodes[i]; }
unsigned size() const { return m_nodes.size(); }
void set(ptr_vector<node> const& ns) { m_nodes.reset(); m_nodes.append(ns); m_bloom.m_tick = 0; }
node* const* begin() const { return m_nodes.begin(); }
node* const* end() const { return m_nodes.end(); }
node* * begin() { return m_nodes.begin(); }
node* * end() { return m_nodes.end(); }
};
struct monomial_hash {
ac_plugin& p;
monomial_hash(ac_plugin& p) :p(p) {}
unsigned operator()(unsigned i) const {
unsigned h = 0;
auto& m = p.monomial(i);
if (!p.is_sorted(m))
p.sort(m);
for (auto* n : m)
h = combine_hash(h, n->root_id());
return h;
}
};
struct monomial_eq {
ac_plugin& p;
monomial_eq(ac_plugin& p) :p(p) {}
bool operator()(unsigned i, unsigned j) const {
auto const& m1 = p.monomial(i);
auto const& m2 = p.monomial(j);
if (m1.size() != m2.size()) return false;
for (unsigned k = 0; k < m1.size(); ++k)
if (m1[k]->root_id() != m2[k]->root_id())
return false;
return true;
}
};
unsigned m_fid = 0;
unsigned m_op = null_decl_kind;
func_decl* m_decl = nullptr;
vector<eq> m_eqs;
ptr_vector<node> m_nodes;
bool_vector m_shared_nodes;
vector<monomial_t> m_monomials;
svector<shared> m_shared;
justification::dependency_manager m_dep_manager;
tracked_uint_set m_to_simplify_todo;
tracked_uint_set m_shared_todo;
uint64_t m_tick = 1;
monomial_hash m_hash;
monomial_eq m_eq;
map<unsigned, shared, monomial_hash, monomial_eq> m_monomial_table;
// backtrackable state
enum undo_kind {
is_add_eq,
is_add_monomial,
is_add_node,
is_merge_node,
is_update_eq,
is_add_shared_index,
is_add_eq_index,
is_register_shared,
is_update_shared
};
svector<undo_kind> m_undo;
ptr_vector<node> m_node_trail;
svector<std::pair<unsigned, shared>> m_update_shared_trail;
svector<std::tuple<node*, unsigned, unsigned>> m_merge_trail;
svector<std::pair<unsigned, eq>> m_update_eq_trail;
node* mk_node(enode* n);
void merge(node* r1, node* r2, justification j);
bool is_op(enode* n) const { auto d = n->get_decl(); return d && (d == m_decl || (m_fid == d->get_family_id() && m_op == d->get_decl_kind())); }
std::function<void(void)> m_undo_notify;
void push_undo(undo_kind k);
enode_vector m_todo;
unsigned to_monomial(enode* n);
unsigned to_monomial(enode* n, ptr_vector<node> const& ms);
unsigned to_monomial(ptr_vector<node> const& ms) { return to_monomial(nullptr, ms); }
monomial_t const& monomial(unsigned i) const { return m_monomials[i]; }
monomial_t& monomial(unsigned i) { return m_monomials[i]; }
void sort(monomial_t& monomial);
bool is_sorted(monomial_t const& monomial) const;
uint64_t filter(monomial_t& m);
bool can_be_subset(monomial_t& subset, monomial_t& superset);
bool can_be_subset(monomial_t& subset, ptr_vector<node> const& m, bloom& b);
bool are_equal(ptr_vector<node> const& a, ptr_vector<node> const& b);
bool are_equal(monomial_t& a, monomial_t& b);
bool backward_subsumes(unsigned src_eq, unsigned dst_eq);
bool forward_subsumes(unsigned src_eq, unsigned dst_eq);
void init_equation(eq const& e);
bool orient_equation(eq& e);
void set_status(unsigned eq_id, eq_status s);
unsigned pick_next_eq();
void forward_simplify(unsigned eq_id, unsigned using_eq);
bool backward_simplify(unsigned eq_id, unsigned using_eq);
void superpose(unsigned src_eq, unsigned dst_eq);
ptr_vector<node> m_src_r, m_src_l, m_dst_r, m_dst_l;
struct ref_counts {
unsigned_vector ids;
unsigned_vector counts;
void reset() { for (auto idx : ids) counts[idx] = 0; ids.reset(); }
unsigned operator[](unsigned idx) const { return counts.get(idx, 0); }
void inc(unsigned idx, unsigned amount) { counts.reserve(idx + 1, 0); ids.push_back(idx); counts[idx] += amount; }
void dec(unsigned idx, unsigned amount) { counts.reserve(idx + 1, 0); ids.push_back(idx); counts[idx] -= amount; }
unsigned const* begin() const { return ids.begin(); }
unsigned const* end() const { return ids.end(); }
};
ref_counts m_src_l_counts, m_dst_l_counts, m_src_r_counts, m_dst_r_counts, m_eq_counts, m_m_counts;
unsigned_vector m_eq_occurs;
bool_vector m_eq_seen;
unsigned_vector const& forward_iterator(unsigned eq);
unsigned_vector const& superpose_iterator(unsigned eq);
unsigned_vector const& backward_iterator(unsigned eq);
void init_ref_counts(monomial_t const& monomial, ref_counts& counts) const;
void init_ref_counts(ptr_vector<node> const& monomial, ref_counts& counts) const;
void init_overlap_iterator(unsigned eq, monomial_t const& m);
void init_subset_iterator(unsigned eq, monomial_t const& m);
void compress_eq_occurs(unsigned eq_id);
// check that src is a subset of dst, where dst_counts are precomputed
bool is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src);
// check that dst is a superset of dst, where src_counts are precomputed
bool is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst);
void rewrite1(ref_counts const& src_l, monomial_t const& src_r, ref_counts& dst_r_counts, ptr_vector<node>& dst_r);
bool reduce(ptr_vector<node>& m, justification& j);
void index_new_r(unsigned eq, monomial_t const& old_r, monomial_t const& new_r);
bool is_to_simplify(unsigned eq) const { return m_eqs[eq].status == eq_status::to_simplify; }
bool is_processed(unsigned eq) const { return m_eqs[eq].status == eq_status::processed; }
bool is_alive(unsigned eq) const { return m_eqs[eq].status != eq_status::is_dead; }
justification justify_rewrite(unsigned eq1, unsigned eq2);
justification::dependency* justify_equation(unsigned eq);
justification::dependency* justify_monomial(justification::dependency* d, monomial_t const& m);
justification join(justification j1, unsigned eq);
bool is_correct_ref_count(monomial_t const& m, ref_counts const& counts) const;
bool is_correct_ref_count(ptr_vector<node> const& m, ref_counts const& counts) const;
void register_shared(enode* n);
void propagate_shared();
void simplify_shared(unsigned idx, shared s);
std::ostream& display_monomial(std::ostream& out, monomial_t const& m) const { return display_monomial(out, m.m_nodes); }
std::ostream& display_monomial(std::ostream& out, ptr_vector<node> const& m) const;
std::ostream& display_equation(std::ostream& out, eq const& e) const;
std::ostream& display_status(std::ostream& out, eq_status s) const;
public:
ac_plugin(egraph& g, unsigned fid, unsigned op);
ac_plugin(egraph& g, func_decl* f);
~ac_plugin() override {}
unsigned get_id() const override { return m_fid; }
void register_node(enode* n) override;
void merge_eh(enode* n1, enode* n2) override;
void diseq_eh(enode* eq) override;
void undo() override;
void propagate() override;
std::ostream& display(std::ostream& out) const override;
void set_undo(std::function<void(void)> u) { m_undo_notify = u; }
struct eq_pp {
ac_plugin& p; eq const& e;
eq_pp(ac_plugin& p, eq const& e) : p(p), e(e) {};
eq_pp(ac_plugin& p, unsigned eq_id): p(p), e(p.m_eqs[eq_id]) {}
std::ostream& display(std::ostream& out) const { return p.display_equation(out, e); }
};
struct m_pp {
ac_plugin& p; ptr_vector<node> const& m;
m_pp(ac_plugin& p, monomial_t const& m) : p(p), m(m.m_nodes) {}
m_pp(ac_plugin& p, ptr_vector<node> const& m) : p(p), m(m) {}
std::ostream& display(std::ostream& out) const { return p.display_monomial(out, m); }
};
};
inline std::ostream& operator<<(std::ostream& out, ac_plugin::eq_pp const& d) { return d.display(out); }
inline std::ostream& operator<<(std::ostream& out, ac_plugin::m_pp const& d) { return d.display(out); }
}

View file

@ -0,0 +1,71 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_arith_plugin.cpp
Abstract:
plugin structure for arithetic
Author:
Nikolaj Bjorner (nbjorner) 2023-11-11
--*/
#include "ast/euf/euf_arith_plugin.h"
#include "ast/euf/euf_egraph.h"
#include <algorithm>
namespace euf {
arith_plugin::arith_plugin(egraph& g) :
plugin(g),
a(g.get_manager()),
m_add(g, get_id(), OP_ADD),
m_mul(g, get_id(), OP_MUL) {
std::function<void(void)> uadd = [&]() { m_undo.push_back(undo_t::undo_add); };
m_add.set_undo(uadd);
std::function<void(void)> umul = [&]() { m_undo.push_back(undo_t::undo_mul); };
m_mul.set_undo(umul);
}
void arith_plugin::register_node(enode* n) {
// no-op
}
void arith_plugin::merge_eh(enode* n1, enode* n2) {
m_add.merge_eh(n1, n2);
m_mul.merge_eh(n1, n2);
}
void arith_plugin::propagate() {
m_add.propagate();
m_mul.propagate();
}
void arith_plugin::undo() {
auto k = m_undo.back();
m_undo.pop_back();
switch (k) {
case undo_t::undo_add:
m_add.undo();
break;
case undo_t::undo_mul:
m_mul.undo();
break;
default:
UNREACHABLE();
}
}
std::ostream& arith_plugin::display(std::ostream& out) const {
out << "add\n";
m_add.display(out);
out << "mul\n";
m_mul.display(out);
return out;
}
}

View file

@ -0,0 +1,53 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_arith_plugin.h
Abstract:
plugin structure for arithetic
Author:
Nikolaj Bjorner (nbjorner) 2023-11-11
--*/
#pragma once
#include "ast/arith_decl_plugin.h"
#include "ast/euf/euf_plugin.h"
#include "ast/euf/euf_ac_plugin.h"
namespace euf {
class egraph;
class arith_plugin : public plugin {
enum undo_t { undo_add, undo_mul };
arith_util a;
svector<undo_t> m_undo;
ac_plugin m_add, m_mul;
public:
arith_plugin(egraph& g);
~arith_plugin() override {}
unsigned get_id() const override { return a.get_family_id(); }
void register_node(enode* n) override;
void merge_eh(enode* n1, enode* n2) override;
void diseq_eh(enode* eq) override {}
void undo() override;
void propagate() override;
std::ostream& display(std::ostream& out) const override;
};
}

View file

@ -0,0 +1,361 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_bv_plugin.cpp
Abstract:
plugin structure for bit-vectors
Author:
Nikolaj Bjorner (nbjorner) 2023-11-08
Jakob Rath 2023-11-08
Objective:
satisfies extract/concat axioms.
- concat(n{I],n[J]) = n[IJ] for I, J consecutive.
- concat(v1, v2) = 2^width(v1)*v2 + v1
- concat(n[width(n)-1:0]) = n
- concat(a, b)[I] = concat(a[I1], b[I2])
- concat(a, concat(b, c)) = concat(concat(a, b), c)
E-graph:
The E-graph contains node definitions of the form
n := f(n1,n2,..)
and congruences:
n ~ n' means root(n) = root(n')
Saturated state:
1. n := n1[I], n' := n2[J], n1 ~ n2 => root(n1) contains tree refining both I, J from smaller intervals
2. n := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => n ~ n3[IJ]
3. n := concat(n1[I], n2[J]), I and J are consecutive & n1 ~ n2, n1[I] ~ v1, n1[J] ~ v2 => n ~ 2^width(v1)*v2 + v1
4. n := concat(n1[I], n2[J], I, J are consecutive, n1 ~ n2, n ~ v => n1[I] ~ v mod 2^width(n1[I]), n2[J] ~ v div 2^width(n1[I])
5. n' := n[I] => n ~ n[width(n)-1:0]
6. n := concat(a, concat(b, c)) => n ~ concat(concat(a, b), c)
- handled by rewriter pre-processing for inputs
- terms created internally are not equated modulo associativity
7, n := concat(n1, n2)[I] => n ~ concat(n1[I1],n2[I2]) or n[I1] or n[I2]
- handled by rewriter pre-processing
Example:
x == (x1 x2) x3
y == y1 (y2 y3)
x1 == y1, x2 == y2, x3 == y3
=>
x = y
by x2 == y2, x3 == y3 => (x2 x3) = (y2 y3)
by (2) => x[I23] = (x2 x3)
by (2) => x[I123] = (x1 (x2 x3))
by (5) => x = x[I123]
The formal properties of saturation have to be established.
- Saturation does not complete with respect to associativity.
Instead the claim is along the lines that the resulting E-graph can be used as a canonizer.
If given a set of equations E that are saturated, and terms t1, t2 that are
both simplified with respect to left-associativity of concatentation, and t1, t2 belong to the E-graph,
then t1 = t2 iff t1 ~ t2 in the E-graph.
TODO: Is saturation for (7) overkill for the purpose of canonization?
TODO: revisit re-entrancy during register_node. It can be called when creating internal extract terms.
Instead of allowing re-entrancy we can accumulate nodes that are registered during recursive calls
and have the main call perform recursive slicing.
--*/
#include "ast/euf/euf_bv_plugin.h"
#include "ast/euf/euf_egraph.h"
namespace euf {
bv_plugin::bv_plugin(egraph& g):
plugin(g),
bv(g.get_manager())
{}
enode* bv_plugin::mk_value_concat(enode* a, enode* b) {
auto v1 = get_value(a);
auto v2 = get_value(b);
auto v3 = v1 + v2 * power(rational(2), width(a));
return mk_value(v3, width(a) + width(b));
}
enode* bv_plugin::mk_value(rational const& v, unsigned sz) {
auto e = bv.mk_numeral(v, sz);
return mk(e, 0, nullptr);
}
void bv_plugin::merge_eh(enode* x, enode* y) {
SASSERT(x == x->get_root());
SASSERT(x == y->get_root());
TRACE("bv", tout << "merge_eh " << g.bpp(x) << " == " << g.bpp(y) << "\n");
SASSERT(!m_internal);
flet<bool> _internal(m_internal, true);
propagate_values(x);
// ensure slices align
if (has_sub(x) || has_sub(y)) {
enode_vector& xs = m_xs, & ys = m_ys;
xs.reset();
ys.reset();
xs.push_back(x);
ys.push_back(y);
merge(xs, ys, justification::equality(x, y));
}
// ensure p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ]
for (auto* n : enode_class(x))
propagate_extract(n);
}
// enforce concat(v1, v2) = v2*2^|v1| + v1
void bv_plugin::propagate_values(enode* x) {
if (!is_value(x))
return;
enode* a, * b;
for (enode* p : enode_parents(x))
if (is_concat(p, a, b) && is_value(a) && is_value(b) && !is_value(p))
push_merge(mk_concat(a->get_interpreted(), b->get_interpreted()), mk_value_concat(a, b));
for (enode* sib : enode_class(x)) {
if (is_concat(sib, a, b)) {
if (!is_value(a) || !is_value(b)) {
auto val = get_value(x);
auto v1 = mod2k(val, width(a));
auto v2 = machine_div2k(val, width(a));
push_merge(mk_concat(mk_value(v1, width(a)), mk_value(v2, width(b))), x->get_interpreted());
}
}
}
}
//
// p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ]
//
// n is of form arg[I]
// p is of form concat(n, b) or concat(a, n)
// b is congruent to arg[J], I is consecutive with J => ensure that arg[IJ] = p
// a is congruent to arg[J], J is consecutive with I => ensure that arg[JI] = p
//
void bv_plugin::propagate_extract(enode* n) {
unsigned lo1, hi1, lo2, hi2;
enode* a, * b;
if (!is_extract(n, lo1, hi1))
return;
enode* arg = n->get_arg(0);
enode* arg_r = arg->get_root();
enode* n_r = n->get_root();
auto ensure_concat = [&](unsigned lo, unsigned mid, unsigned hi) {
TRACE("bv", tout << "ensure-concat " << lo << " " << mid << " " << hi << "\n");
unsigned lo_, hi_;
for (enode* p1 : enode_parents(n))
if (is_extract(p1, lo_, hi_) && lo_ == lo && hi_ == hi && p1->get_arg(0)->get_root() == arg_r)
return;
// add the axiom instead of merge(p, mk_extract(arg, lo, hi)), which would require tracking justifications
push_merge(mk_concat(mk_extract(arg, lo, mid), mk_extract(arg, mid + 1, hi)), mk_extract(arg, lo, hi));
};
auto propagate_left = [&](enode* b) {
TRACE("bv", tout << "propagate-left " << g.bpp(b) << "\n");
for (enode* sib : enode_class(b))
if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi1 + 1 == lo2)
ensure_concat(lo1, hi1, hi2);
};
auto propagate_right = [&](enode* a) {
TRACE("bv", tout << "propagate-right " << g.bpp(a) << "\n");
for (enode* sib : enode_class(a))
if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi2 + 1 == lo1)
ensure_concat(lo2, hi2, hi1);
};
for (enode* p : enode_parents(n)) {
if (is_concat(p, a, b)) {
if (a->get_root() == n_r)
propagate_left(b);
if (b->get_root() == n_r)
propagate_right(a);
}
}
}
void bv_plugin::push_undo_split(enode* n) {
m_undo_split.push_back(n);
push_plugin_undo(bv.get_family_id());
}
void bv_plugin::undo() {
enode* n = m_undo_split.back();
m_undo_split.pop_back();
auto& i = info(n);
i.lo = nullptr;
i.hi = nullptr;
i.cut = null_cut;
}
void bv_plugin::register_node(enode* n) {
TRACE("bv", tout << "register " << g.bpp(n) << "\n");
auto& i = info(n);
i.value = n;
enode* a, * b;
if (is_concat(n, a, b)) {
i.lo = a;
i.hi = b;
i.cut = width(a);
push_undo_split(n);
}
unsigned lo, hi;
if (is_extract(n, lo, hi) && (lo != 0 || hi + 1 != width(n->get_arg(0)))) {
enode* arg = n->get_arg(0);
unsigned w = width(arg);
if (all_of(enode_parents(arg), [&](enode* p) { unsigned _lo, _hi; return !is_extract(p, _lo, _hi) || _lo != 0 || _hi + 1 != w; }))
push_merge(mk_extract(arg, 0, w - 1), arg);
ensure_slice(arg, lo, hi);
}
}
//
// Ensure that there are slices at boundaries of n[hi:lo]
//
void bv_plugin::ensure_slice(enode* n, unsigned lo, unsigned hi) {
enode* r = n;
unsigned lb = 0, ub = width(n) - 1;
while (true) {
TRACE("bv", tout << "ensure slice " << g.bpp(n) << " " << lb << " [" << lo << ", " << hi << "] " << ub << "\n");
SASSERT(lb <= lo && hi <= ub);
SASSERT(ub - lb + 1 == width(r));
if (lb == lo && ub == hi)
return;
slice_info& i = info(r);
if (!i.lo) {
if (lo > lb) {
split(r, lo - lb);
if (hi < ub) // or split(info(r).hi, ...)
ensure_slice(n, lo, hi);
}
else if (hi < ub)
split(r, ub - hi);
break;
}
auto cut = i.cut;
if (cut + lb <= lo) {
lb += cut;
r = i.hi;
continue;
}
if (cut + lb > hi) {
ub = cut + lb - 1;
r = i.lo;
continue;
}
SASSERT(lo < cut + lb && cut + lb <= hi);
ensure_slice(n, lo, cut + lb - 1);
ensure_slice(n, cut + lb, hi);
break;
}
}
enode* bv_plugin::mk_extract(enode* n, unsigned lo, unsigned hi) {
SASSERT(lo <= hi && width(n) > hi - lo);
unsigned lo1, hi1;
while (is_extract(n, lo1, hi1)) {
lo += lo1;
hi += lo1;
n = n->get_arg(0);
}
return mk(bv.mk_extract(hi, lo, n->get_expr()), 1, &n);
}
enode* bv_plugin::mk_concat(enode* lo, enode* hi) {
enode* args[2] = { lo, hi };
return mk(bv.mk_concat(lo->get_expr(), hi->get_expr()), 2, args);
}
void bv_plugin::merge(enode_vector& xs, enode_vector& ys, justification dep) {
while (!xs.empty()) {
SASSERT(!ys.empty());
auto x = xs.back();
auto y = ys.back();
if (unfold_sub(x, xs))
continue;
else if (unfold_sub(y, ys))
continue;
else if (unfold_width(x, xs, y, ys))
continue;
else if (unfold_width(y, ys, x, xs))
continue;
else if (x->get_root() != y->get_root())
push_merge(x, y, dep);
xs.pop_back();
ys.pop_back();
}
SASSERT(ys.empty());
}
bool bv_plugin::unfold_sub(enode* x, enode_vector& xs) {
if (!has_sub(x))
return false;
xs.pop_back();
xs.push_back(sub_hi(x));
xs.push_back(sub_lo(x));
return true;
}
bool bv_plugin::unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys) {
if (width(x) <= width(y))
return false;
split(x, width(y));
xs.pop_back();
xs.push_back(sub_hi(x));
xs.push_back(sub_lo(x));
return true;
}
void bv_plugin::split(enode* n, unsigned cut) {
TRACE("bv", tout << "split: " << g.bpp(n) << " " << cut << "\n");
unsigned w = width(n);
SASSERT(!info(n).hi);
SASSERT(0 < cut && cut < w);
enode* hi = mk_extract(n, cut, w - 1);
enode* lo = mk_extract(n, 0, cut - 1);
auto& i = info(n);
SASSERT(i.value);
i.hi = hi;
i.lo = lo;
i.cut = cut;
push_undo_split(n);
push_merge(mk_concat(lo, hi), n);
}
std::ostream& bv_plugin::display(std::ostream& out) const {
out << "bv\n";
for (auto const& i : m_info)
if (i.lo)
out << g.bpp(i.value) << " cut " << i.cut << " lo " << g.bpp(i.lo) << " hi " << g.bpp(i.hi) << "\n";
return out;
}
}

100
src/ast/euf/euf_bv_plugin.h Normal file
View file

@ -0,0 +1,100 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_bv_plugin.h
Abstract:
plugin structure for bit-vectors
Author:
Nikolaj Bjorner (nbjorner) 2023-11-08
Jakob Rath 2023-11-08
--*/
#pragma once
#include "ast/bv_decl_plugin.h"
#include "ast/euf/euf_plugin.h"
namespace euf {
class egraph;
class bv_plugin : public plugin {
static constexpr unsigned null_cut = std::numeric_limits<unsigned>::max();
struct slice_info {
unsigned cut = null_cut; // = bv.get_bv_size(lo)
enode* hi = nullptr; //
enode* lo = nullptr; //
enode* value = nullptr;
void reset() { *this = slice_info(); }
};
using slice_info_vector = svector<slice_info>;
bv_util bv;
slice_info_vector m_info; // indexed by enode::get_id()
enode_vector m_xs, m_ys;
bool is_concat(enode* n) const { return bv.is_concat(n->get_expr()); }
bool is_concat(enode* n, enode*& a, enode*& b) { return is_concat(n) && (a = n->get_arg(0), b = n->get_arg(1), true); }
bool is_extract(enode* n, unsigned& lo, unsigned& hi) { expr* body; return bv.is_extract(n->get_expr(), lo, hi, body); }
bool is_extract(enode* n) const { return bv.is_extract(n->get_expr()); }
unsigned width(enode* n) const { return bv.get_bv_size(n->get_expr()); }
enode* mk_extract(enode* n, unsigned lo, unsigned hi);
enode* mk_concat(enode* lo, enode* hi);
enode* mk_value_concat(enode* lo, enode* hi);
enode* mk_value(rational const& v, unsigned sz);
unsigned width(enode* n) { return bv.get_bv_size(n->get_expr()); }
bool is_value(enode* n) { return n->get_root()->interpreted(); }
rational get_value(enode* n) { rational val; VERIFY(bv.is_numeral(n->get_interpreted()->get_expr(), val)); return val; }
slice_info& info(enode* n) { unsigned id = n->get_id(); m_info.reserve(id + 1); return m_info[id]; }
slice_info& root_info(enode* n) { unsigned id = n->get_root_id(); m_info.reserve(id + 1); return m_info[id]; }
bool has_sub(enode* n) { return !!info(n).lo; }
enode* sub_lo(enode* n) { return info(n).lo; }
enode* sub_hi(enode* n) { return info(n).hi; }
bool m_internal = false;
void ensure_slice(enode* n, unsigned lo, unsigned hi);
void split(enode* n, unsigned cut);
bool unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys);
bool unfold_sub(enode* x, enode_vector& xs);
void merge(enode_vector& xs, enode_vector& ys, justification j);
void propagate_extract(enode* n);
void propagate_values(enode* n);
enode_vector m_undo_split;
void push_undo_split(enode* n);
public:
bv_plugin(egraph& g);
~bv_plugin() override {}
unsigned get_id() const override { return bv.get_family_id(); }
void register_node(enode* n) override;
void merge_eh(enode* n1, enode* n2) override;
void diseq_eh(enode* eq) override {}
void propagate() override {}
void undo() override;
std::ostream& display(std::ostream& out) const override;
};
}

View file

@ -130,8 +130,8 @@ namespace euf {
if (n2 == n)
update_children(n);
else
merge(n, n2, justification::congruence(comm, m_congruence_timestamp++));
push_merge(n, n2, comm);
return n;
}
@ -146,19 +146,36 @@ namespace euf {
memory::deallocate(m_tmp_node);
}
void egraph::add_plugin(plugin* p) {
m_plugins.reserve(p->get_id() + 1);
m_plugins.set(p->get_id(), p);
}
void egraph::propagate_plugins() {
for (auto* p : m_plugins)
if (p)
p->propagate();
}
void egraph::add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) {
TRACE("euf_verbose", tout << "eq: " << v1 << " == " << v2 << "\n";);
m_new_th_eqs.push_back(th_eq(id, v1, v2, c, r));
m_updates.push_back(update_record(update_record::new_th_eq()));
++m_stats.m_num_th_eqs;
auto* p = get_plugin(id);
if (p)
p->merge_eh(c, r);
}
void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq) {
void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq) {
if (!th_propagates_diseqs(id))
return;
TRACE("euf_verbose", tout << "eq: " << v1 << " != " << v2 << "\n";);
m_new_th_eqs.push_back(th_eq(id, v1, v2, eq));
m_new_th_eqs.push_back(th_eq(id, v1, v2, eq->get_expr()));
m_updates.push_back(update_record(update_record::new_th_eq()));
auto* p = get_plugin(id);
if (p)
p->diseq_eh(eq);
++m_stats.m_num_th_diseqs;
}
@ -202,7 +219,7 @@ namespace euf {
return;
theory_var v1 = arg1->get_closest_th_var(id);
theory_var v2 = arg2->get_closest_th_var(id);
add_th_diseq(id, v1, v2, n->get_expr());
add_th_diseq(id, v1, v2, n);
return;
}
for (auto const& p : euf::enode_th_vars(r1)) {
@ -210,8 +227,8 @@ namespace euf {
continue;
for (auto const& q : euf::enode_th_vars(r2))
if (p.get_id() == q.get_id())
add_th_diseq(p.get_id(), p.get_var(), q.get_var(), n->get_expr());
}
add_th_diseq(p.get_id(), p.get_var(), q.get_var(), n);
}
}
@ -230,7 +247,7 @@ namespace euf {
n = n->get_root();
theory_var v2 = n->get_closest_th_var(id);
if (v2 != null_theory_var)
add_th_diseq(id, v1, v2, p->get_expr());
add_th_diseq(id, v1, v2, p);
}
}
}
@ -249,6 +266,10 @@ namespace euf {
theory_var w = n->get_th_var(id);
enode* r = n->get_root();
auto* p = get_plugin(id);
if (p)
p->register_node(n);
if (w == null_theory_var) {
n->add_th_var(v, id, m_region);
m_updates.push_back(update_record(n, id, update_record::add_th_var()));
@ -424,6 +445,9 @@ namespace euf {
p.r1->m_args[i]->get_root()->m_parents.pop_back();
}
break;
case update_record::tag_t::is_plugin_undo:
m_plugins[p.m_th_id]->undo();
break;
default:
UNREACHABLE();
break;
@ -589,6 +613,9 @@ namespace euf {
case to_merge_comm:
merge(w.a, w.b, justification::congruence(w.commutativity(), m_congruence_timestamp++));
break;
case to_justified:
merge(w.a, w.b, w.j);
break;
case to_add_literal:
add_literal(w.a, w.b);
break;
@ -760,6 +787,13 @@ namespace euf {
justifications.push_back(j.ext<T>());
else if (j.is_congruence())
push_congruence(a, b, j.is_commutative());
else if (j.is_dependent()) {
vector<justification, false> js;
for (auto const& j2 : justification::dependency_manager::s_linearize(j.get_dependency(), js))
explain_eq(justifications, cc, a, b, j2);
}
else if (j.is_equality())
explain_eq(justifications, cc, j.lhs(), j.rhs());
if (cc && j.is_congruence())
cc->push_back(std::tuple(a->get_app(), b->get_app(), j.timestamp(), j.is_commutative()));
}
@ -879,6 +913,9 @@ namespace euf {
max_args = std::max(max_args, n->num_args());
for (enode* n : m_nodes)
display(out, max_args, n);
for (auto* p : m_plugins)
if (p)
p->display(out);
return out;
}

View file

@ -29,8 +29,10 @@ Notes:
#include "util/statistics.h"
#include "util/trail.h"
#include "util/lbool.h"
#include "util/scoped_ptr_vector.h"
#include "ast/euf/euf_enode.h"
#include "ast/euf/euf_etable.h"
#include "ast/euf/euf_plugin.h"
#include "ast/ast_ll_pp.h"
#include <vector>
@ -82,14 +84,18 @@ namespace euf {
class egraph {
friend class plugin;
typedef ptr_vector<trail> trail_stack;
enum to_merge_t { to_merge_plain, to_merge_comm, to_add_literal };
enum to_merge_t { to_merge_plain, to_merge_comm, to_justified, to_add_literal };
struct to_merge {
enode* a, * b;
to_merge_t t;
justification j;
bool commutativity() const { return t == to_merge_comm; }
to_merge(enode* a, enode* b, bool c) : a(a), b(b), t(c ? to_merge_comm : to_merge_plain) {}
to_merge(enode* a, enode* b, justification j): a(a), b(b), t(to_justified), j(j) {}
to_merge(enode* p, enode* ante): a(p), b(ante), t(to_add_literal) {}
};
@ -116,10 +122,12 @@ namespace euf {
struct lbl_set {};
struct update_children {};
struct set_relevant {};
struct plugin_undo {};
enum class tag_t { is_set_parent, is_add_node, is_toggle_cgc, is_toggle_merge_tf, is_update_children,
is_add_th_var, is_replace_th_var, is_new_th_eq,
is_lbl_hash, is_new_th_eq_qhead,
is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant };
is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant,
is_plugin_undo };
tag_t tag;
enode* r1;
enode* n1;
@ -162,11 +170,14 @@ namespace euf {
tag(tag_t::is_update_children), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {}
update_record(enode* n, set_relevant) :
tag(tag_t::is_set_relevant), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {}
update_record(unsigned th_id, plugin_undo) :
tag(tag_t::is_plugin_undo), r1(nullptr), n1(nullptr), m_th_id(th_id) {}
};
ast_manager& m;
svector<to_merge> m_to_merge;
etable m_table;
region m_region;
scoped_ptr_vector<plugin> m_plugins;
svector<update_record> m_updates;
unsigned_vector m_scopes;
enode_vector m_expr2enode;
@ -205,6 +216,12 @@ namespace euf {
}
void push_node(enode* n) { m_updates.push_back(update_record(n)); }
// plugin related methods
void push_plugin_undo(unsigned th_id) { m_updates.push_back(update_record(th_id, update_record::plugin_undo())); }
void push_merge(enode* a, enode* b, justification j) { m_to_merge.push_back({ a, b, j }); }
void push_merge(enode* a, enode* b, bool comm) { m_to_merge.push_back({ a, b, comm }); }
void propagate_plugins();
void add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r);
void add_th_diseqs(theory_id id, theory_var v1, enode* r);
@ -245,11 +262,15 @@ namespace euf {
public:
egraph(ast_manager& m);
~egraph();
void add_plugin(plugin* p);
plugin* get_plugin(family_id fid) const { return m_plugins.get(fid, nullptr); }
enode* find(expr* f) const { return m_expr2enode.get(f->get_id(), nullptr); }
enode* find(expr* f, unsigned n, enode* const* args);
enode* mk(expr* f, unsigned generation, unsigned n, enode *const* args);
enode_vector const& enodes_of(func_decl* f);
void push() { if (!m_to_merge.empty()) propagate(); ++m_num_scopes; }
void push() { if (can_propagate()) propagate(); ++m_num_scopes; }
void pop(unsigned num_scopes);
/**
@ -269,6 +290,7 @@ namespace euf {
of new equalities.
*/
bool propagate();
bool can_propagate() const { return !m_to_merge.empty(); }
bool inconsistent() const { return m_inconsistent; }
/**
@ -286,7 +308,7 @@ namespace euf {
where \c n is an enode and \c is_eq indicates whether the enode
is an equality consequence.
*/
void add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq);
void add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq);
bool has_th_eq() const { return m_new_th_eqs_qhead < m_new_th_eqs.size(); }
th_eq get_th_eq() const { return m_new_th_eqs[m_new_th_eqs_qhead]; }
void next_th_eq() { force_push(); SASSERT(m_new_th_eqs_qhead < m_new_th_eqs.size()); m_new_th_eqs_qhead++; }

View file

@ -93,6 +93,17 @@ namespace euf {
return null_theory_var;
}
enode* enode::get_closest_th_node(theory_id id) {
enode* n = this;
while (n) {
theory_var v = n->get_th_var(id);
if (v != null_theory_var)
return n;
n = n->m_target;
}
return nullptr;
}
bool enode::acyclic() const {
enode const* n = this;
enode const* p = this;

View file

@ -207,6 +207,7 @@ namespace euf {
enode* get_root() const { return m_root; }
expr* get_expr() const { return m_expr; }
sort* get_sort() const { return m_expr->get_sort(); }
enode* get_interpreted() const { return get_root(); }
app* get_app() const { return to_app(m_expr); }
func_decl* get_decl() const { return is_app(m_expr) ? to_app(m_expr)->get_decl() : nullptr; }
unsigned get_expr_id() const { return m_expr->get_id(); }
@ -216,6 +217,10 @@ namespace euf {
bool children_are_roots() const;
enode* get_next() const { return m_next; }
enode* get_target() const { return m_target; }
justification get_justification() const { return m_justification; }
justification get_lit_justification() const { return m_lit_justification; }
bool has_lbl_hash() const { return m_lbl_hash >= 0; }
unsigned char get_lbl_hash() const {
SASSERT(m_lbl_hash >= 0 && static_cast<unsigned>(m_lbl_hash) < approx_set_traits<unsigned long long>::capacity);
@ -229,6 +234,7 @@ namespace euf {
theory_var get_th_var(theory_id id) const { return m_th_vars.find(id); }
theory_var get_closest_th_var(theory_id id) const;
enode* get_closest_th_node(theory_id id);
bool is_attached_to(theory_id id) const { return get_th_var(id) != null_theory_var; }
bool has_th_vars() const { return !m_th_vars.empty(); }
bool has_one_th_var() const { return !m_th_vars.empty() && !m_th_vars.get_next();}

View file

@ -0,0 +1,54 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
euf_justification.cpp
Abstract:
justification structure for euf
Author:
Nikolaj Bjorner (nbjorner) 2020-08-23
--*/
#include "ast/euf/euf_justification.h"
#include "ast/euf/euf_enode.h"
namespace euf {
std::ostream& justification::display(std::ostream& out, std::function<void (std::ostream&, void*)> const& ext) const {
switch (m_kind) {
case kind_t::external_t:
if (ext)
ext(out, m_external);
else
out << "external";
return out;
case kind_t::axiom_t:
return out << "axiom";
case kind_t::congruence_t:
return out << "congruence";
case kind_t::dependent_t: {
vector<justification, false> js;
out << "dependent";
for (auto const& j : dependency_manager::s_linearize(m_dependency, js))
j.display(out << " ", ext);
return out;
}
case kind_t::equality_t:
return out << "equality #" << m_n1->get_id() << " == #" << m_n2->get_id();
default:
UNREACHABLE();
return out;
}
return out;
}
}

View file

@ -22,19 +22,34 @@ Notes:
#pragma once
#include "util/dependency.h"
namespace euf {
class enode;
class justification {
public:
typedef stacked_dependency_manager<justification> dependency_manager;
typedef stacked_dependency_manager<justification>::dependency dependency;
private:
enum class kind_t {
axiom_t,
congruence_t,
external_t
external_t,
dependent_t,
equality_t
};
kind_t m_kind;
bool m_comm;
union {
bool m_comm;
enode* m_n1;
};
union {
void* m_external;
uint64_t m_timestamp;
dependency* m_dependency;
enode* m_n2;
};
justification(bool comm, uint64_t ts):
@ -49,6 +64,18 @@ namespace euf {
m_external(ext)
{}
justification(dependency* dep, int):
m_kind(kind_t::dependent_t),
m_comm(false),
m_dependency(dep)
{}
justification(enode* n1, enode* n2):
m_kind(kind_t::equality_t),
m_n1(n1),
m_n2(n2)
{}
public:
justification():
m_kind(kind_t::axiom_t),
@ -59,10 +86,17 @@ namespace euf {
static justification axiom() { return justification(); }
static justification congruence(bool c, uint64_t ts) { return justification(c, ts); }
static justification external(void* ext) { return justification(ext); }
static justification dependent(dependency* d) { return justification(d, 1); }
static justification equality(enode* a, enode* b) { return justification(a, b); }
bool is_external() const { return m_kind == kind_t::external_t; }
bool is_congruence() const { return m_kind == kind_t::congruence_t; }
bool is_commutative() const { return m_comm; }
bool is_dependent() const { return m_kind == kind_t::dependent_t; }
bool is_equality() const { return m_kind == kind_t::equality_t; }
dependency* get_dependency() const { SASSERT(is_dependent()); return m_dependency; }
enode* lhs() const { SASSERT(is_equality()); return m_n1; }
enode* rhs() const { SASSERT(is_equality()); return m_n2; }
uint64_t timestamp() const { SASSERT(is_congruence()); return m_timestamp; }
template <typename T>
T* ext() const { SASSERT(is_external()); return static_cast<T*>(m_external); }
@ -75,30 +109,17 @@ namespace euf {
return axiom();
case kind_t::congruence_t:
return congruence(m_comm, m_timestamp);
case kind_t::dependent_t:
NOT_IMPLEMENTED_YET();
return dependent(m_dependency);
default:
UNREACHABLE();
return axiom();
}
}
std::ostream& display(std::ostream& out, std::function<void (std::ostream&, void*)> const& ext) const {
switch (m_kind) {
case kind_t::external_t:
if (ext)
ext(out, m_external);
else
out << "external";
return out;
case kind_t::axiom_t:
return out << "axiom";
case kind_t::congruence_t:
return out << "congruence";
default:
UNREACHABLE();
return out;
}
return out;
}
std::ostream& display(std::ostream& out, std::function<void(std::ostream&, void*)> const& ext) const;
};
inline std::ostream& operator<<(std::ostream& out, justification const& j) {

View file

@ -0,0 +1,47 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_plugin.cpp
Abstract:
plugin structure for euf
Plugins allow adding equality saturation for theories.
Author:
Nikolaj Bjorner (nbjorner) 2023-11-08
--*/
#include "ast/euf/euf_egraph.h"
namespace euf {
void plugin::push_plugin_undo(unsigned th_id) {
g.push_plugin_undo(th_id);
}
void plugin::push_merge(enode* a, enode* b, justification j) {
g.push_merge(a, b, j);
}
void plugin::push_merge(enode* a, enode* b) {
TRACE("plugin", tout << g.bpp(a) << " == " << g.bpp(b) << "\n");
g.push_merge(a, b, justification::axiom());
}
enode* plugin::mk(expr* e, unsigned n, enode* const* args) {
enode* r = g.find(e);
if (!r)
r = g.mk(e, 0, n, args);
return r;
}
region& plugin::get_region() {
return g.m_region;
}
}

58
src/ast/euf/euf_plugin.h Normal file
View file

@ -0,0 +1,58 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_plugin.h
Abstract:
plugin structure for euf
Plugins allow adding equality saturation for theories.
Author:
Nikolaj Bjorner (nbjorner) 2023-11-08
--*/
#pragma once
#include "ast/euf/euf_enode.h"
#include "ast/euf/euf_justification.h"
namespace euf {
class plugin {
protected:
egraph& g;
void push_plugin_undo(unsigned th_id);
void push_merge(enode* a, enode* b, justification j);
void push_merge(enode* a, enode* b);
enode* mk(expr* e, unsigned n, enode* const* args);
region& get_region();
public:
plugin(egraph& g):
g(g)
{}
virtual ~plugin() {}
virtual unsigned get_id() const = 0;
virtual void register_node(enode* n) = 0;
virtual void merge_eh(enode* n1, enode* n2) = 0;
virtual void diseq_eh(enode* eq) {};
virtual void propagate() = 0;
virtual void undo() = 0;
virtual std::ostream& display(std::ostream& out) const = 0;
};
}

View file

@ -0,0 +1,71 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_specrel_plugin.cpp
Abstract:
plugin structure for specrel
Author:
Nikolaj Bjorner (nbjorner) 2023-11-11
--*/
#include "ast/euf/euf_specrel_plugin.h"
#include "ast/euf/euf_egraph.h"
#include <algorithm>
namespace euf {
specrel_plugin::specrel_plugin(egraph& g) :
plugin(g),
sp(g.get_manager()) {
}
void specrel_plugin::register_node(enode* n) {
func_decl* f = n->get_decl();
if (!f)
return;
if (!sp.is_ac(f))
return;
ac_plugin* p = nullptr;
if (!m_decl2plugin.find(f, p)) {
p = alloc(ac_plugin, g, f);
m_decl2plugin.insert(f, p);
m_plugins.push_back(p);
std::function<void(void)> undo_op = [&]() { m_undo.push_back(p); };
p->set_undo(undo_op);
}
}
void specrel_plugin::merge_eh(enode* n1, enode* n2) {
for (auto * p : m_plugins)
p->merge_eh(n1, n2);
}
void specrel_plugin::diseq_eh(enode* eq) {
for (auto* p : m_plugins)
p->diseq_eh(eq);
}
void specrel_plugin::propagate() {
for (auto * p : m_plugins)
p->propagate();
}
void specrel_plugin::undo() {
auto p = m_undo.back();
m_undo.pop_back();
p->undo();
}
std::ostream& specrel_plugin::display(std::ostream& out) const {
for (auto * p : m_plugins)
p->display(out);
return out;
}
}

View file

@ -0,0 +1,56 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_specrel_plugin.h
Abstract:
plugin structure for specrel functions
Author:
Nikolaj Bjorner (nbjorner) 2023-11-11
--*/
#pragma once
#include <iostream>
#include "util/scoped_ptr_vector.h"
#include "ast/special_relations_decl_plugin.h"
#include "ast/euf/euf_plugin.h"
#include "ast/euf/euf_ac_plugin.h"
namespace euf {
class specrel_plugin : public plugin {
scoped_ptr_vector<ac_plugin> m_plugins;
ptr_vector<ac_plugin> m_undo;
obj_map<func_decl, ac_plugin*> m_decl2plugin;
special_relations_util sp;
public:
specrel_plugin(egraph& g);
~specrel_plugin() override {}
unsigned get_id() const override { return sp.get_family_id(); }
void register_node(enode* n) override;
void merge_eh(enode* n1, enode* n2) override;
void diseq_eh(enode* eq) override;
void undo() override;
void propagate() override;
std::ostream& display(std::ostream& out) const override;
};
}

View file

@ -26,7 +26,8 @@ special_relations_decl_plugin::special_relations_decl_plugin():
m_po("partial-order"),
m_plo("piecewise-linear-order"),
m_to("tree-order"),
m_tc("transitive-closure")
m_tc("transitive-closure"),
m_ac("ac-op")
{}
func_decl * special_relations_decl_plugin::mk_func_decl(
@ -41,24 +42,53 @@ func_decl * special_relations_decl_plugin::mk_func_decl(
m_manager->raise_exception("argument sort missmatch. The two arguments should have the same sort");
return nullptr;
}
if (!range && k == OP_SPECIAL_RELATION_AC)
range = domain[0];
if (!range) {
range = m_manager->mk_bool_sort();
}
if (!m_manager->is_bool(range)) {
m_manager->raise_exception("range type is expected to be Boolean for special relations");
}
auto check_bool_range = [&]() {
if (!m_manager->is_bool(range))
m_manager->raise_exception("range type is expected to be Boolean for special relations");
};
m_has_special_relation = true;
func_decl_info info(m_family_id, k, num_parameters, parameters);
symbol name;
switch(k) {
case OP_SPECIAL_RELATION_PO: name = m_po; break;
case OP_SPECIAL_RELATION_LO: name = m_lo; break;
case OP_SPECIAL_RELATION_PLO: name = m_plo; break;
case OP_SPECIAL_RELATION_TO: name = m_to; break;
case OP_SPECIAL_RELATION_PO: check_bool_range(); name = m_po; break;
case OP_SPECIAL_RELATION_LO: check_bool_range(); name = m_lo; break;
case OP_SPECIAL_RELATION_PLO: check_bool_range(); name = m_plo; break;
case OP_SPECIAL_RELATION_TO: check_bool_range(); name = m_to; break;
case OP_SPECIAL_RELATION_AC: {
if (range != domain[0])
m_manager->raise_exception("AC operation should have the same range as domain type");
name = m_ac;
if (num_parameters != 1 || !parameters[0].is_ast() || !is_func_decl(parameters[0].get_ast()))
m_manager->raise_exception("parameter to transitive closure should be a function declaration");
func_decl* f = to_func_decl(parameters[0].get_ast());
if (f->get_arity() != 2)
m_manager->raise_exception("ac function should be binary");
if (f->get_domain(0) != f->get_domain(1))
m_manager->raise_exception("ac function should have same domain");
if (f->get_domain(0) != f->get_range())
m_manager->raise_exception("ac function should have same domain and range");
break;
}
case OP_SPECIAL_RELATION_TC:
check_bool_range();
name = m_tc;
if (num_parameters != 1 || !parameters[0].is_ast() || !is_func_decl(parameters[0].get_ast()))
m_manager->raise_exception("parameter to transitive closure should be a function declaration");
func_decl* f = to_func_decl(parameters[0].get_ast());
if (f->get_arity() != 2)
m_manager->raise_exception("tc relation should be binary");
if (f->get_domain(0) != f->get_domain(1))
m_manager->raise_exception("tc relation should have same domain");
if (!m_manager->is_bool(f->get_range()))
m_manager->raise_exception("tc relation should be Boolean");
break;
}
return m_manager->mk_func_decl(name, arity, domain, range, info);
@ -71,6 +101,7 @@ void special_relations_decl_plugin::get_op_names(svector<builtin_name> & op_name
op_names.push_back(builtin_name(m_plo.str(), OP_SPECIAL_RELATION_PLO));
op_names.push_back(builtin_name(m_to.str(), OP_SPECIAL_RELATION_TO));
op_names.push_back(builtin_name(m_tc.str(), OP_SPECIAL_RELATION_TC));
op_names.push_back(builtin_name(m_ac.str(), OP_SPECIAL_RELATION_AC));
}
}
@ -81,6 +112,7 @@ sr_property special_relations_util::get_property(func_decl* f) const {
case OP_SPECIAL_RELATION_PLO: return sr_plo;
case OP_SPECIAL_RELATION_TO: return sr_to;
case OP_SPECIAL_RELATION_TC: return sr_tc;
case OP_SPECIAL_RELATION_AC: return sr_none;
default:
UNREACHABLE();
return sr_po;

View file

@ -16,6 +16,8 @@ Author:
Revision History:
2023-11-27: Added ac-op for E-graph plugin
--*/
#pragma once
@ -28,6 +30,7 @@ enum special_relations_op_kind {
OP_SPECIAL_RELATION_PLO,
OP_SPECIAL_RELATION_TO,
OP_SPECIAL_RELATION_TC,
OP_SPECIAL_RELATION_AC,
LAST_SPECIAL_RELATIONS_OP
};
@ -37,6 +40,7 @@ class special_relations_decl_plugin : public decl_plugin {
symbol m_plo;
symbol m_to;
symbol m_tc;
symbol m_ac;
bool m_has_special_relation = false;
public:
special_relations_decl_plugin();
@ -86,13 +90,16 @@ class special_relations_util {
public:
special_relations_util(ast_manager& m) : m(m), m_fid(null_family_id) { }
family_id get_family_id() const { return fid(); }
bool has_special_relation() const { return static_cast<special_relations_decl_plugin*>(m.get_plugin(m.mk_family_id("specrels")))->has_special_relation(); }
bool is_special_relation(func_decl* f) const { return f->get_family_id() == fid(); }
bool is_special_relation(app* e) const { return is_special_relation(e->get_decl()); }
bool is_special_relation(expr* e) const { return is_app(e) && is_special_relation(to_app(e)->get_decl()); }
sr_property get_property(func_decl* f) const;
sr_property get_property(app* e) const { return get_property(e->get_decl()); }
func_decl* get_relation(func_decl* f) const { SASSERT(is_special_relation(f)); return to_func_decl(f->get_parameter(0).get_ast()); }
func_decl* get_relation(expr* e) const { SASSERT(is_special_relation(e)); return to_func_decl(to_app(e)->get_parameter(0).get_ast()); }
func_decl* mk_to_decl(func_decl* f) { return mk_rel_decl(f, OP_SPECIAL_RELATION_TO); }
func_decl* mk_po_decl(func_decl* f) { return mk_rel_decl(f, OP_SPECIAL_RELATION_PO); }
@ -105,12 +112,14 @@ public:
bool is_plo(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_PLO); }
bool is_to(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_TO); }
bool is_tc(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_TC); }
bool is_ac(expr const* e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_AC); }
bool is_lo(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_LO); }
bool is_po(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_PO); }
bool is_plo(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_PLO); }
bool is_to(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_TO); }
bool is_tc(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_TC); }
bool is_ac(func_decl const* e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_AC); }
app * mk_lo (expr * arg1, expr * arg2) { return m.mk_app( fid(), OP_SPECIAL_RELATION_LO, arg1, arg2); }
app * mk_po (expr * arg1, expr * arg2) { return m.mk_app( fid(), OP_SPECIAL_RELATION_PO, arg1, arg2); }

View file

@ -44,6 +44,7 @@ z3_add_component(sat_smt
q_solver.cpp
recfun_solver.cpp
sat_th.cpp
specrel_solver.cpp
tseitin_theory_checker.cpp
user_solver.cpp
COMPONENT_DEPENDENCIES

View file

@ -28,6 +28,7 @@ Author:
#include "sat/smt/fpa_solver.h"
#include "sat/smt/dt_solver.h"
#include "sat/smt/recfun_solver.h"
#include "sat/smt/specrel_solver.h"
namespace euf {
@ -130,6 +131,7 @@ namespace euf {
arith_util arith(m);
datatype_util dt(m);
recfun::util rf(m);
special_relations_util sp(m);
if (pb.get_family_id() == fid)
ext = alloc(pb::solver, *this, fid);
else if (bvu.get_family_id() == fid)
@ -144,6 +146,8 @@ namespace euf {
ext = alloc(dt::solver, *this, fid);
else if (rf.get_family_id() == fid)
ext = alloc(recfun::solver, *this);
else if (sp.get_family_id() == fid)
ext = alloc(specrel::solver, *this, fid);
if (ext)
add_solver(ext);

View file

@ -0,0 +1,120 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
specrel_solver.h
Abstract:
Theory plugin for special relations
Author:
Nikolaj Bjorner (nbjorner) 2020-09-08
--*/
#include "sat/smt/specrel_solver.h"
#include "sat/smt/euf_solver.h"
#include "ast/euf/euf_specrel_plugin.h"
namespace euf {
class solver;
}
namespace specrel {
solver::solver(euf::solver& ctx, theory_id id) :
th_euf_solver(ctx, ctx.get_manager().get_family_name(id), id),
sp(m)
{
ctx.get_egraph().add_plugin(alloc(euf::specrel_plugin, ctx.get_egraph()));
}
solver::~solver() {
}
void solver::asserted(sat::literal l) {
}
sat::check_result solver::check() {
return sat::check_result::CR_DONE;
}
std::ostream& solver::display(std::ostream& out) const {
return out;
}
void solver::collect_statistics(statistics& st) const {
}
euf::th_solver* solver::clone(euf::solver& ctx) {
return alloc(solver, ctx, get_id());
}
void solver::new_eq_eh(euf::th_eq const& eq) {
TRACE("specrel", tout << "new-eq\n");
if (eq.is_eq()) {
auto* p = ctx.get_egraph().get_plugin(sp.get_family_id());
p->merge_eh(var2enode(eq.v1()), var2enode(eq.v2()));
TRACE("specrel", tout << eq.v1() << " " << eq.v2() << "\n");
}
}
void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) {
}
bool solver::add_dep(euf::enode* n, top_sort<euf::enode>& dep) {
return false;
}
bool solver::include_func_interp(func_decl* f) const {
return false;
}
sat::literal solver::internalize(expr* e, bool sign, bool root) {
if (!visit_rec(m, e, sign, root))
return sat::null_literal;
auto lit = ctx.expr2literal(e);
if (sign)
lit.neg();
return lit;
}
void solver::internalize(expr* e) {
visit_rec(m, e, false, false);
}
bool solver::visit(expr* e) {
if (visited(e))
return true;
m_stack.push_back(sat::eframe(e));
return false;
}
bool solver::visited(expr* e) {
euf::enode* n = expr2enode(e);
return n && n->is_attached_to(get_id());
}
bool solver::post_visit(expr* term, bool sign, bool root) {
euf::enode* n = expr2enode(term);
SASSERT(!n || !n->is_attached_to(get_id()));
if (!n)
n = mk_enode(term);
SASSERT(!n->is_attached_to(get_id()));
mk_var(n);
TRACE("specrel", tout << ctx.bpp(n) << "\n");
return true;
}
euf::theory_var solver::mk_var(euf::enode* n) {
if (is_attached_to_var(n))
return n->get_th_var(get_id());
euf::theory_var r = th_euf_solver::mk_var(n);
ctx.attach_th_var(n, this, r);
return r;
}
}

View file

@ -0,0 +1,75 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
specrel_solver.h
Abstract:
Theory plugin for special relations
Author:
Nikolaj Bjorner (nbjorner) 2020-09-08
--*/
#pragma once
#include "sat/smt/sat_th.h"
#include "ast/special_relations_decl_plugin.h"
namespace euf {
class solver;
}
namespace specrel {
class solver : public euf::th_euf_solver {
typedef euf::theory_var theory_var;
typedef euf::theory_id theory_id;
typedef euf::enode enode;
typedef euf::enode_pair enode_pair;
typedef euf::enode_pair_vector enode_pair_vector;
typedef sat::bool_var bool_var;
typedef sat::literal literal;
typedef sat::literal_vector literal_vector;
special_relations_util sp;
public:
solver(euf::solver& ctx, theory_id id);
~solver() override;
bool is_external(bool_var v) override { return false; }
void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) override {}
void asserted(literal l) override;
sat::check_result check() override;
std::ostream& display(std::ostream& out) const override;
std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { return euf::th_explain::from_index(idx).display(out); }
std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { return display_justification(out, idx); }
void collect_statistics(statistics& st) const override;
euf::th_solver* clone(euf::solver& ctx) override;
void new_eq_eh(euf::th_eq const& eq) override;
bool unit_propagate() override { return false; }
void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override;
bool add_dep(euf::enode* n, top_sort<euf::enode>& dep) override;
bool include_func_interp(func_decl* f) const override;
sat::literal internalize(expr* e, bool sign, bool root) override;
void internalize(expr* e) override;
bool visit(expr* e) override;
bool visited(expr* e) override;
bool post_visit(expr* e, bool sign, bool root) override;
euf::theory_var mk_var(euf::enode* n) override;
void apply_sort_cnstr(euf::enode* n, sort* s) override {}
bool is_shared(theory_var v) const override { return false; }
lbool get_phase(bool_var v) override { return l_true; }
bool enable_self_propagate() const override { return true; }
void merge_eh(theory_var, theory_var, theory_var v1, theory_var v2);
void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) {}
void unmerge_eh(theory_var v1, theory_var v2) {}
};
}

View file

@ -39,6 +39,8 @@ add_executable(test-z3
doc.cpp
egraph.cpp
escaped.cpp
euf_bv_plugin.cpp
euf_arith_plugin.cpp
ex.cpp
expr_rand.cpp
expr_substitution.cpp

View file

@ -0,0 +1,106 @@
/*++
Copyright (c) 2023 Microsoft Corporation
--*/
#include "util/util.h"
#include "util/timer.h"
#include "ast/euf/euf_egraph.h"
#include "ast/euf/euf_arith_plugin.h"
#include "ast/reg_decl_plugins.h"
#include "ast/ast_pp.h"
#include <iostream>
unsigned s_var = 0;
static euf::enode* get_node(euf::egraph& g, arith_util& a, expr* e) {
auto* n = g.find(e);
if (n)
return n;
euf::enode_vector args;
for (expr* arg : *to_app(e))
args.push_back(get_node(g, a, arg));
n = g.mk(e, 0, args.size(), args.data());
g.add_th_var(n, s_var++, a.get_family_id());
return n;
}
//
static void test1() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::arith_plugin, g));
arith_util a(m);
sort_ref I(a.mk_int(), m);
expr_ref x(m.mk_const("x", I), m);
expr_ref y(m.mk_const("y", I), m);
auto* nx = get_node(g, a, a.mk_add(a.mk_add(y, y), a.mk_add(x, x)));
auto* ny = get_node(g, a, a.mk_add(a.mk_add(y, x), x));
TRACE("plugin", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("plugin", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("plugin", tout << "after propagate\n" << g << "\n");
g.merge(get_node(g, a, a.mk_add(x, a.mk_add(y, y))), get_node(g, a, a.mk_add(y, x)), nullptr);
g.propagate();
std::cout << g << "\n";
}
static void test2() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::arith_plugin, g));
arith_util a(m);
sort_ref I(a.mk_int(), m);
expr_ref x(m.mk_const("x", I), m);
expr_ref y(m.mk_const("y", I), m);
auto* nxy = get_node(g, a, a.mk_add(x, y));
auto* nyx = get_node(g, a, a.mk_add(y, x));
auto* nx = get_node(g, a, x);
auto* ny = get_node(g, a, y);
TRACE("plugin", tout << "before merge\n" << g << "\n");
g.merge(nxy, nx, nullptr);
g.merge(nyx, ny, nullptr);
TRACE("plugin", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("plugin", tout << "after propagate\n" << g << "\n");
SASSERT(nx->get_root() == ny->get_root());
g.merge(get_node(g, a, a.mk_add(x, a.mk_add(y, y))), get_node(g, a, a.mk_add(y, x)), nullptr);
g.propagate();
std::cout << g << "\n";
}
static void test3() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::arith_plugin, g));
arith_util a(m);
sort_ref I(a.mk_int(), m);
expr_ref x(m.mk_const("x", I), m);
expr_ref y(m.mk_const("y", I), m);
auto* nxyy = get_node(g, a, a.mk_add(a.mk_add(x, y), y));
auto* nyxx = get_node(g, a, a.mk_add(a.mk_add(y, x), x));
auto* nx = get_node(g, a, x);
auto* ny = get_node(g, a, y);
g.merge(nxyy, nx, nullptr);
g.merge(nyxx, ny, nullptr);
TRACE("plugin", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("plugin", tout << "after propagate\n" << g << "\n");
std::cout << g << "\n";
}
void tst_euf_arith_plugin() {
enable_trace("plugin");
test1();
test2();
test3();
}

183
src/test/euf_bv_plugin.cpp Normal file
View file

@ -0,0 +1,183 @@
/*++
Copyright (c) 2023 Microsoft Corporation
--*/
#include "util/util.h"
#include "util/timer.h"
#include "ast/euf/euf_egraph.h"
#include "ast/euf/euf_bv_plugin.h"
#include "ast/reg_decl_plugins.h"
#include "ast/ast_pp.h"
#include <iostream>
static unsigned s_var = 0;
static euf::enode* get_node(euf::egraph& g, bv_util& b, expr* e) {
auto* n = g.find(e);
if (n)
return n;
euf::enode_vector args;
for (expr* arg : *to_app(e))
args.push_back(get_node(g, b, arg));
n = g.mk(e, 0, args.size(), args.data());
g.add_th_var(n, s_var++, b.get_family_id());
return n;
}
// align slices, and propagate extensionality
static void test1() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::bv_plugin, g));
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref y(m.mk_const("y", u32), m);
expr_ref x3(bv.mk_extract(31, 16, x), m);
expr_ref x2(bv.mk_extract(15, 8, x), m);
expr_ref x1(bv.mk_extract(7, 0, x), m);
expr_ref y3(bv.mk_extract(31, 24, y), m);
expr_ref y2(bv.mk_extract(23, 8, y), m);
expr_ref y1(bv.mk_extract(7, 0, y), m);
expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m);
expr_ref yy(bv.mk_concat(y1, bv.mk_concat(y2, y3)), m);
auto* nx = get_node(g, bv, xx);
auto* ny = get_node(g, bv, yy);
TRACE("bv", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("bv", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("bv", tout << "after propagate\n" << g << "\n");
std::cout << g << "\n";
SASSERT(nx->get_root() == ny->get_root());
}
// propagate values down
static void test2() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::bv_plugin, g));
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref x3(bv.mk_extract(31, 16, x), m);
expr_ref x2(bv.mk_extract(15, 8, x), m);
expr_ref x1(bv.mk_extract(7, 0, x), m);
expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m);
g.merge(get_node(g, bv, xx), get_node(g, bv, bv.mk_numeral((1 << 27) + (1 << 17) + (1 << 3), 32)), nullptr);
g.propagate();
SASSERT(get_node(g, bv, x1)->get_root()->interpreted());
SASSERT(get_node(g, bv, x2)->get_root()->interpreted());
SASSERT(get_node(g, bv, x3)->get_root()->interpreted());
SASSERT(get_node(g, bv, x)->get_root()->interpreted());
}
// propagate values up
static void test3() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::bv_plugin, g));
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref x3(bv.mk_extract(31, 16, x), m);
expr_ref x2(bv.mk_extract(15, 8, x), m);
expr_ref x1(bv.mk_extract(7, 0, x), m);
expr_ref xx(bv.mk_concat(bv.mk_concat(x1, x2), x3), m);
expr_ref y(m.mk_const("y", u32), m);
g.merge(get_node(g, bv, xx), get_node(g, bv, y), nullptr);
g.merge(get_node(g, bv, x1), get_node(g, bv, bv.mk_numeral(2, 8)), nullptr);
g.merge(get_node(g, bv, x2), get_node(g, bv, bv.mk_numeral(8, 8)), nullptr);
g.propagate();
SASSERT(get_node(g, bv, bv.mk_concat(x1, x2))->get_root()->interpreted());
SASSERT(get_node(g, bv, x1)->get_root()->interpreted());
SASSERT(get_node(g, bv, x2)->get_root()->interpreted());
}
// propagate extract up
static void test4() {
// concat(a, x[J]), a = x[I] => x[IJ] = concat(x[I],x[J])
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::bv_plugin, g));
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
sort_ref u8(bv.mk_sort(8), m);
sort_ref u16(bv.mk_sort(16), m);
expr_ref a(m.mk_const("a", u8), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref y(m.mk_const("y", u16), m);
expr_ref x1(bv.mk_extract(15, 8, x), m);
expr_ref x2(bv.mk_extract(23, 16, x), m);
g.merge(get_node(g, bv, bv.mk_concat(a, x2)), get_node(g, bv, y), nullptr);
g.merge(get_node(g, bv, x1), get_node(g, bv, a), nullptr);
g.propagate();
TRACE("bv", tout << g << "\n");
SASSERT(get_node(g, bv, bv.mk_extract(23, 8, x))->get_root() == get_node(g, bv, y)->get_root());
}
// iterative slicing
static void test5() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::bv_plugin, g));
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref x1(bv.mk_extract(31, 4, x), m);
expr_ref x2(bv.mk_extract(27, 0, x), m);
auto* nx = get_node(g, bv, x1);
auto* ny = get_node(g, bv, x2);
TRACE("bv", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("bv", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("bv", tout << "after propagate\n" << g << "\n");
std::cout << g << "\n";
}
// iterative slicing
static void test6() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::bv_plugin, g));
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref x1(bv.mk_extract(31, 3, x), m);
expr_ref x2(bv.mk_extract(28, 0, x), m);
auto* nx = get_node(g, bv, x1);
auto* ny = get_node(g, bv, x2);
TRACE("bv", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("bv", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("bv", tout << "after propagate\n" << g << "\n");
std::cout << g << "\n";
}
void tst_euf_bv_plugin() {
enable_trace("bv");
enable_trace("plugin");
test6();
return;
test1();
test2();
test3();
test4();
test5();
test6();
}

View file

@ -265,4 +265,6 @@ int main(int argc, char ** argv) {
TST(finder);
TST(totalizer);
TST(distribution);
TST(euf_bv_plugin);
TST(euf_arith_plugin);
}

View file

@ -44,8 +44,39 @@ public:
public:
unsigned get_ref_count() const { return m_ref_count; }
bool is_leaf() const { return m_leaf == 1; }
value const& leaf_value() const { SASSERT(is_leaf()); return static_cast<leaf const*>(this)->m_value; }
};
static void linearize_todo(ptr_vector<dependency>& todo, vector<value, false>& vs) {
unsigned qhead = 0;
while (qhead < todo.size()) {
dependency* d = todo[qhead];
qhead++;
if (d->is_leaf()) {
vs.push_back(to_leaf(d)->m_value);
}
else {
for (unsigned i = 0; i < 2; i++) {
dependency* child = to_join(d)->m_children[i];
if (!child->is_marked()) {
todo.push_back(child);
child->mark();
}
}
}
}
for (auto* d : todo)
d->unmark();
}
static void s_linearize(dependency* d, vector<value, false>& vs) {
if (!d)
return;
ptr_vector<dependency> todo;
todo.push_back(d);
linearize_todo(todo, vs);
}
private:
struct join : public dependency {
dependency * m_children[2];
@ -69,7 +100,7 @@ private:
value_manager & m_vmanager;
allocator & m_allocator;
mutable ptr_vector<dependency> m_todo;
ptr_vector<dependency> m_todo;
void inc_ref(value const & v) {
if (C::ref_count)
@ -83,6 +114,7 @@ private:
void del(dependency * d) {
SASSERT(d);
SASSERT(m_todo.empty());
m_todo.push_back(d);
while (!m_todo.empty()) {
d = m_todo.back();
@ -106,8 +138,8 @@ private:
}
}
void unmark_todo() const {
for (auto* d : m_todo)
void unmark_todo() {
for (auto* d : m_todo)
d->unmark();
m_todo.reset();
}
@ -190,30 +222,30 @@ public:
return false;
}
void linearize(dependency * d, vector<value, false> & vs) const {
if (d) {
m_todo.reset();
d->mark();
m_todo.push_back(d);
unsigned qhead = 0;
while (qhead < m_todo.size()) {
d = m_todo[qhead];
qhead++;
if (d->is_leaf()) {
vs.push_back(to_leaf(d)->m_value);
}
else {
for (unsigned i = 0; i < 2; i++) {
dependency * child = to_join(d)->m_children[i];
if (!child->is_marked()) {
m_todo.push_back(child);
child->mark();
}
}
}
void linearize(dependency * d, vector<value, false> & vs) {
if (!d)
return;
SASSERT(m_todo.empty());
d->mark();
m_todo.push_back(d);
linearize_todo(m_todo, vs);
m_todo.reset();
}
void linearize(ptr_vector<dependency>& deps, vector<value, false> & vs) {
if (deps.empty())
return;
SASSERT(m_todo.empty());
for (auto* d : deps) {
if (d && !d->is_marked()) {
d->mark();
m_todo.push_back(d);
}
unmark_todo();
}
linearize_todo(m_todo, vs);
m_todo.reset();
}
};
@ -297,7 +329,16 @@ public:
return m_dep_manager.contains(d, v);
}
void linearize(dependency * d, vector<value, false> & vs) const {
void linearize(dependency * d, vector<value, false> & vs) {
return m_dep_manager.linearize(d, vs);
}
static vector<value, false> const& s_linearize(dependency* d, vector<value, false>& vs) {
dep_manager::s_linearize(d, vs);
return vs;
}
void linearize(ptr_vector<dependency>& d, vector<value, false> & vs) {
return m_dep_manager.linearize(d, vs);
}
@ -320,4 +361,83 @@ typedef scoped_dependency_manager<void*>::dependency v_dependency;
typedef scoped_dependency_manager<unsigned> u_dependency_manager;
typedef scoped_dependency_manager<unsigned>::dependency u_dependency;
/**
\brief Version of the scoped-depenendcy-manager where region scopes are handled externally.
*/
template<typename Value>
class stacked_dependency_manager {
class config {
public:
static const bool ref_count = true;
typedef Value value;
class value_manager {
public:
void inc_ref(value const& v) {
}
void dec_ref(value const& v) {
}
};
class allocator {
region& m_region;
public:
allocator(region& r) : m_region(r) {}
void* allocate(size_t sz) {
return m_region.allocate(sz);
}
void deallocate(size_t sz, void* mem) {
}
};
};
typedef dependency_manager<config> dep_manager;
public:
typedef typename dep_manager::dependency dependency;
typedef Value value;
private:
typename config::value_manager m_vmanager;
typename config::allocator m_allocator;
dep_manager m_dep_manager;
public:
stacked_dependency_manager(region& r) :
m_allocator(r),
m_dep_manager(m_vmanager, m_allocator) {
}
dependency* mk_empty() {
return m_dep_manager.mk_empty();
}
dependency* mk_leaf(value const& v) {
return m_dep_manager.mk_leaf(v);
}
dependency* mk_join(dependency* d1, dependency* d2) {
return m_dep_manager.mk_join(d1, d2);
}
bool contains(dependency* d, value const& v) {
return m_dep_manager.contains(d, v);
}
void linearize(dependency* d, vector<value, false>& vs) {
return m_dep_manager.linearize(d, vs);
}
static vector<value, false> const& s_linearize(dependency* d, vector<value, false>& vs) {
dep_manager::s_linearize(d, vs);
return vs;
}
void linearize(ptr_vector<dependency>& d, vector<value, false>& vs) {
return m_dep_manager.linearize(d, vs);
}
};