3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-24 17:45:32 +00:00

use z3's egraph (wip)

This commit is contained in:
Jakob Rath 2023-07-12 16:21:38 +02:00
parent d8d8c67a3b
commit 0fb81fc437
2 changed files with 156 additions and 193 deletions

View file

@ -46,9 +46,27 @@ Recycle the z3 egraph?
namespace polysat {
void* slicing::encode_dep(dep_t d) {
if (d == null_dep)
return nullptr;
else
return reinterpret_cast<void*>(static_cast<std::uintptr_t>(d.to_uint()) + 1);
static_assert( sizeof(void*) >= sizeof(std::uintptr_t) );
static_assert( sizeof(std::uintptr_t) > sizeof(decltype(d.to_uint())) );
}
slicing::dep_t slicing::decode_dep(void* d) {
if (!d)
return null_dep;
else
return sat::to_literal(reinterpret_cast<std::uintptr_t>(d) - 1);
}
slicing::slicing(solver& s):
m_solver(s),
m_egraph(m_ast)
m_egraph(m_ast),
// m_slice2app(m_ast),
m_expr_storage(m_ast)
{
m_slice_sort = m_ast.mk_uninterpreted_sort(symbol("slice"));
}
@ -70,6 +88,8 @@ namespace polysat {
void slicing::push_scope() {
m_scopes.push_back(m_trail.size());
m_expr_scopes.push_back(m_expr_storage.size());
m_egraph.push();
}
void slicing::pop_scope(unsigned num_scopes) {
@ -79,24 +99,27 @@ namespace polysat {
SASSERT(num_scopes <= lvl);
unsigned const target_lvl = lvl - num_scopes;
unsigned const target_size = m_scopes[target_lvl];
unsigned const target_expr_size = m_expr_scopes[target_lvl];
m_scopes.shrink(target_lvl);
m_expr_scopes.shrink(target_lvl);
while (m_trail.size() > target_size) {
switch (m_trail.back()) {
case trail_item::add_var: undo_add_var(); break;
case trail_item::alloc_slice: undo_alloc_slice(); break;
case trail_item::split_slice: undo_split_slice(); break;
case trail_item::merge_base: undo_merge_base(); break;
case trail_item::mk_value_slice: undo_mk_value_slice(); break;
case trail_item::split_core: undo_split_core(); break;
// case trail_item::merge_base: undo_merge_base(); break;
// case trail_item::mk_value_slice: undo_mk_value_slice(); break;
default: UNREACHABLE();
}
m_trail.pop_back();
}
m_scopes.shrink(target_lvl);
m_egraph.pop(num_scopes);
m_expr_storage.shrink(target_expr_size);
}
void slicing::add_var(unsigned bit_width) {
pvar const v = m_var2slice.size();
slice const s = alloc_slice();
m_slice_width[s] = bit_width;
slice const s = alloc_slice(bit_width);
m_slice2var[s] = v;
m_var2slice.push_back(s);
}
@ -105,20 +128,20 @@ namespace polysat {
m_var2slice.pop_back();
}
slicing::slice slicing::alloc_slice() {
slicing::slice slicing::alloc_slice(unsigned width) {
SASSERT(width > 0);
slice const s = m_slice_cut.size();
m_slice_width.push_back(0);
m_slice_width.push_back(width);
m_slice_cut.push_back(null_cut);
m_slice_sub.push_back(null_slice);
m_find.push_back(s);
m_size.push_back(1);
m_next.push_back(s);
m_slice2var.push_back(null_var);
m_proof_parent.push_back(null_slice);
m_proof_reason.push_back(null_dep);
m_mark.push_back(0);
// m_value_root.push_back(null_slice);
m_slice2val.push_back(rational(-1));
// m_mark.push_back(0);
app* a = m_ast.mk_fresh_const("s", m_slice_sort, false); // TODO: what's the effect of "skolem = true"?
m_expr_storage.push_back(a);
euf::enode* n = m_egraph.mk(a, 0, 0, nullptr);
m_slice2enode.push_back(n);
SASSERT(!m_enode2slice.contains(n));
m_enode2slice.insert(n, s);
m_trail.push_back(trail_item::alloc_slice);
return s;
}
@ -127,15 +150,12 @@ namespace polysat {
m_slice_width.pop_back();
m_slice_cut.pop_back();
m_slice_sub.pop_back();
m_find.pop_back();
m_size.pop_back();
m_next.pop_back();
m_slice2var.pop_back();
m_proof_parent.pop_back();
m_proof_reason.pop_back();
m_mark.pop_back();
// m_value_root.pop_back();
m_slice2val.pop_back();
euf::enode* n = m_slice2enode.back();
SASSERT_EQ(m_enode2slice[n], m_slice_cut.size());
m_enode2slice.remove(n);
m_slice2enode.pop_back();
// m_mark.pop_back();
}
slicing::slice slicing::sub_hi(slice parent) const {
@ -148,6 +168,14 @@ namespace polysat {
return m_slice_sub[parent] + 1;
}
euf::enode* slicing::sub_hi(euf::enode* n) const {
return slice2enode(sub_hi(enode2slice(n)));
}
euf::enode* slicing::sub_lo(euf::enode* n) const {
return slice2enode(sub_lo(enode2slice(n)));
}
slicing::slice slicing::find_sub_hi(slice parent) const {
return find(sub_hi(parent));
}
@ -156,32 +184,61 @@ namespace polysat {
return find(sub_lo(parent));
}
void slicing::split(slice s, unsigned cut) {
// split a single slice without updating any equivalences
void slicing::split_core(slice s, unsigned cut) {
SASSERT(!has_sub(s));
SASSERT(width(s) - 1 >= cut + 1);
slice const sub_hi = alloc_slice();
slice const sub_lo = alloc_slice();
slice const sub_hi = alloc_slice(width(s) - cut - 1);
slice const sub_lo = alloc_slice(cut + 1);
m_slice_cut[s] = cut;
m_slice_sub[s] = sub_hi;
SASSERT_EQ(sub_lo, sub_hi + 1);
m_slice_width[sub_hi] = width(s) - cut - 1;
m_slice_width[sub_lo] = cut + 1;
m_trail.push_back(trail_item::split_slice);
m_trail.push_back(trail_item::split_core);
m_split_trail.push_back(s);
if (has_value(s)) {
rational const& val = get_value(s);
// set_value(sub_lo, mod2k(val, cut + 1));
// set_value(sub_hi, machine_div2k(val, cut + 1));
}
// if (has_value(s)) {
// rational const& val = get_value(s);
// // set_value(sub_lo, mod2k(val, cut + 1));
// // set_value(sub_hi, machine_div2k(val, cut + 1));
// }
// // s = hi ++ lo ... TODO: necessary??? probably not
// euf::enode* s_n = slice2enode(s);
// euf::enode* hi_n = slice2enode(sub_hi);
// euf::enode* lo_n = slice2enode(sub_lo);
// app* a = m_ast.mk_app(get_concat_decl(2), hi_n->get_expr(), lo_n->get_expr());
// auto args = {hi_n, lo_n};
// euf::enode* concat_n = m_egraph.mk(a, 0, args.size(), blup.begin());
// m_egraph.merge(s_n, concat_n, nullptr);
// SASSERT(!concat_n->is_root()); // else we have to register it in enode2slice
}
void slicing::undo_split_slice() {
void slicing::undo_split_core() {
slice s = m_split_trail.back();
m_split_trail.pop_back();
m_slice_cut[s] = null_cut;
m_slice_sub[s] = null_slice;
}
void slicing::split(slice s, unsigned cut) {
euf::enode* sn = slice2enode(s);
// split all slices in the equivalence class
for (euf::enode* n : euf::enode_class(sn)) {
split_core(enode2slice(n), cut);
}
// propagate the proper equivalences
for (euf::enode* n : euf::enode_class(sn)) {
euf::enode* target = n->get_target();
if (!target)
continue;
euf::justification j = n->get_justification();
SASSERT(j.is_external()); // cannot be a congruence since the slice wasn't split before.
m_egraph.merge(sub_hi(n), sub_hi(target), j.ext<void>());
m_egraph.merge(sub_lo(n), sub_lo(target), j.ext<void>());
}
m_egraph.propagate(); // TODO: could do this later
}
#if 0
slicing::slice slicing::mk_value_slice(rational const& val, unsigned bit_width) {
SASSERT(0 <= val && val < rational::power_of_two(bit_width));
val2slice_key key(val, bit_width);
@ -202,102 +259,23 @@ namespace polysat {
m_val2slice.remove(m_val2slice_trail.back());
m_val2slice_trail.pop_back();
}
#endif
slicing::slice slicing::find(slice s) const {
while (true) {
SASSERT(s < m_find.size());
slice const new_s = m_find[s];
if (new_s == s)
return s;
s = new_s;
}
return enode2slice(slice2enode(s)->get_root());
}
#if 1
bool slicing::merge_base(slice s1, slice s2, dep_t dep) {
SASSERT_EQ(width(s1), width(s2));
SASSERT(!has_sub(s1));
SASSERT(!has_sub(s2));
slice r1 = find(s1);
slice r2 = find(s2);
if (r1 == r2)
return true;
if (m_size[r1] > m_size[r2]) {
std::swap(r1, r2);
std::swap(s1, s2);
}
if (has_value(r1)) {
if (has_value(r2)) {
if (get_value(r1) != get_value(r2)) {
NOT_IMPLEMENTED_YET(); // TODO: conflict
return false;
}
}
// else
// set_value(r2, get_value(r1));
}
// r2 becomes the representative of the merged class
m_find[r1] = r2;
m_size[r2] += m_size[r1];
std::swap(m_next[r1], m_next[r2]);
if (m_slice2var[r2] == null_var)
m_slice2var[r2] = m_slice2var[r1];
else {
// otherwise the classes should have been merged already
SASSERT(m_slice2var[r2] != m_slice2var[r1]);
}
// Add justification 'dep' for s1 = s2
// NOTE: invariant: root of the proof tree is the representative
SASSERT(m_proof_parent[r1] == null_slice);
SASSERT(m_proof_parent[r2] == null_slice);
make_proof_root(s1);
SASSERT(m_proof_parent[s1] == null_slice);
m_proof_parent[s1] = s2;
m_proof_reason[s1] = dep;
SASSERT(m_proof_parent[r2] == null_slice);
m_trail.push_back(trail_item::merge_base);
m_merge_trail.push_back({r1, s1});
return true;
}
void slicing::undo_merge_base() {
auto const [r1, s1] = m_merge_trail.back();
m_merge_trail.pop_back();
slice const r2 = m_find[r1];
SASSERT(find(r2) == r2);
m_find[r1] = r1;
m_size[r2] -= m_size[r1];
std::swap(m_next[r1], m_next[r2]);
if (m_slice2var[r2] == m_slice2var[r1])
m_slice2var[r2] = null_var;
SASSERT(m_proof_parent[s1] == null_slice);
SASSERT(m_proof_parent[r2] == null_slice);
m_proof_parent[s1] = null_slice;
m_proof_reason[s1] = null_dep;
SASSERT(m_proof_parent[r1] == null_slice);
SASSERT(m_proof_parent[r2] == null_slice);
make_proof_root(r1);
}
void slicing::make_proof_root(slice s) {
// s1 -> s2 -> s3 -> s4
// r1 r2 r3
// =>
// s1 <- s2 <- s3 <- s4
// r1 r2 r3
slice curr = s;
slice prev = null_slice;
dep_t prev_reason = null_dep;
while (curr != null_slice) {
slice const curr_parent = m_proof_parent[curr];
dep_t const curr_reason = m_proof_reason[curr];
m_proof_parent[curr] = prev;
m_proof_reason[curr] = prev_reason;
prev = curr;
prev_reason = curr_reason;
curr = curr_parent;
}
m_egraph.merge(slice2enode(s1), slice2enode(s2), encode_dep(dep));
m_egraph.propagate(); // TODO: could do this later maybe
return !m_egraph.inconsistent();
}
#if 0
bool slicing::merge_value(slice s0, rational val0, dep_t dep) {
vector<std::pair<slice, rational>> todo;
todo.push_back({find(s0), std::move(val0)});
@ -330,7 +308,9 @@ namespace polysat {
}
return true;
}
#endif
#if 0
void slicing::push_reason(slice s, dep_vector& out_deps) {
dep_t reason = m_proof_reason[s];
if (reason == null_dep)
@ -338,36 +318,6 @@ namespace polysat {
out_deps.push_back(reason);
}
void slicing::explain_class(slice x, slice y, dep_vector& out_deps) {
SASSERT_EQ(find(x), find(y));
// /-> ...
// x -> x1 -> x2 -> lca <- y1 <- y
// r0 r1 r2 r4 r3
begin_mark();
// mark ancestors of x in the proof forest
slice s = x;
while (s != null_slice) {
mark(s);
s = m_proof_parent[s];
}
// find lowest common ancestor of x and y
// and collect deps from y to lca
slice lca = y;
while (!is_marked(lca)) {
push_reason(lca, out_deps);
lca = m_proof_parent[lca];
SASSERT(lca != null_slice);
}
// collect deps from x to lca
s = x;
while (s != lca) {
push_reason(s, out_deps);
s = m_proof_parent[s];
SASSERT(s != null_slice);
}
end_mark();
}
void slicing::explain_equal(slice x, slice y, dep_vector& out_deps) {
// TODO: we currently get duplicates in out_deps (if parents are merged, the subslices are all merged due to the same reason)
SASSERT(is_equal(x, y));
@ -411,6 +361,7 @@ namespace polysat {
}
SASSERT(ys.empty());
}
#endif
bool slicing::merge(slice_vector& xs, slice_vector& ys, dep_t dep) {
// LOG_H2("Merging " << xs << " with " << ys);
@ -496,11 +447,9 @@ namespace polysat {
bool result = (xs == ys);
xs.clear();
ys.clear();
#if 0
if (result) {
// TODO: merge equivalence class of x, y (on upper level)? but can we always combine the sub-trees?
}
#endif
return result;
}
@ -596,6 +545,7 @@ namespace polysat {
slice s = slices[0];
if (slice2var(s) != null_var)
return slice2var(s);
// TODO: optimization: could save a slice-tree by directly assigning slice2var(s) = v for new var v.
}
pvar v = m_solver.add_var(hi - lo + 1);
VERIFY(merge(slices, var2slice(v), null_dep));
@ -759,12 +709,13 @@ namespace polysat {
if (s != find(s))
continue;
// if slice has a value, it should be propagated to its sub-slices
if (has_value(s)) {
VERIFY(has_value(find_sub_hi(s)));
VERIFY(has_value(find_sub_lo(s)));
}
// if (has_value(s)) {
// VERIFY(has_value(find_sub_hi(s)));
// VERIFY(has_value(find_sub_lo(s)));
// }
}
return true;
}
#endif
}

View file

@ -50,13 +50,24 @@ namespace polysat {
using dep_t = sat::literal;
using dep_vector = sat::literal_vector;
static constexpr sat::literal null_dep = sat::null_literal;
void* encode_dep(dep_t d);
dep_t decode_dep(void* d);
using slice = unsigned;
using slice_vector = unsigned_vector;
static constexpr slice null_slice = std::numeric_limits<slice>::max();
static constexpr slice null_slice = std::numeric_limits<slice>::max();
static constexpr unsigned null_cut = std::numeric_limits<unsigned>::max();
struct slice_info {
unsigned width = 0;
unsigned cut = null_cut;
euf::enode* sub_hi = nullptr;
euf::enode* sub_lo = nullptr;
};
// using enode = euf::enode<slice_extra>;
struct val2slice_key {
rational value;
unsigned bit_width;
@ -76,33 +87,26 @@ namespace polysat {
using val2slice_eq = default_eq<val2slice_key>;
using val2slice_map = map<val2slice_key, slice, val2slice_hash, val2slice_eq>;
// number of bits in the slice
unsigned_vector m_slice_width;
unsigned_vector m_slice_width; // number of bits in the slice
// Cut point: if slice represents bit-vector x, then x has been sliced into x[|x|-1:cut+1] and x[cut:0].
// The cut point is relative to the parent slice (rather than a root variable, which might not be unique)
// (null_cut for leaf slices)
unsigned_vector m_slice_cut;
// The sub-slices are at indices sub and sub+1 (or null_slice if there is no subdivision)
slice_vector m_slice_sub;
slice_vector m_find; // representative of equivalence class
slice_vector m_size; // number of elements in equivalence class
slice_vector m_next; // next element of the equivalence class
slice_vector m_proof_parent; // the proof forest
dep_vector m_proof_reason; // justification for merge of an element with its parent (in the proof forest)
// unsigned_vector m_value_id; // slice -> value id
// vector<rational> m_value; // id -> value
// slice_vector m_value_root; // the slice representing the associated value, if any. NOTE: subslices will inherit this from their parents.
// TODO: value_root probably not necessary.
// but we will need to create value slices for the sub-slices.
// then the "reason" for that equality must be a marker "equality between parent and its value". explanation at that point must go up recursively.
vector<rational> m_slice2val; // slice -> value (-1 if none)
val2slice_map m_val2slice; // (value, bit-width) -> slice
pvar_vector m_slice2var; // slice -> pvar, or null_var if slice is not equivalent to a variable
slice_vector m_var2slice; // pvar -> slice
// app_ref_vector m_slice2app; // slice -> app*
// ptr_addr_map<app, slice> m_app2slice;
ptr_vector<euf::enode> m_slice2enode;
ptr_addr_map<euf::enode, slice> m_enode2slice;
expr_ref_vector m_expr_storage;
unsigned_vector m_expr_scopes;
#if 0
unsigned_vector m_mark;
unsigned m_mark_timestamp = 0;
#if Z3DEBUG
@ -118,27 +122,42 @@ namespace polysat {
void end_mark() { DEBUG_CODE({ SASSERT(m_mark_active); m_mark_active = false; }); }
bool is_marked(slice s) const { SASSERT(m_mark_active); return m_mark[s] == m_mark_timestamp; }
void mark(slice s) { SASSERT(m_mark_active); m_mark[s] = m_mark_timestamp; }
#endif
slice alloc_slice();
slice alloc_slice(unsigned width);
slice var2slice(pvar v) const { return m_var2slice[v]; }
pvar slice2var(slice s) const { return m_slice2var[s]; }
// slice app2slice(app* a) const { return m_app2slice[a]; }
// app* slice2app(slice s) const { return m_slice2app[s]; }
slice enode2slice(euf::enode* n) const { return m_enode2slice[n]; }
euf::enode* slice2enode(slice s) const { return m_slice2enode[s]; }
slice var2slice(pvar v) const { return find(m_var2slice[v]); }
pvar slice2var(slice s) const { return m_slice2var[find(s)]; }
unsigned width(slice s) const { return m_slice_width[s]; }
bool has_sub(slice s) const { return m_slice_sub[s] != null_slice; }
/// Upper subslice (direct child, not necessarily the representative)
slice sub_hi(slice s) const;
euf::enode* sub_hi(euf::enode* n) const;
/// Lower subslice (direct child, not necessarily the representative)
slice sub_lo(slice s) const;
euf::enode* sub_lo(euf::enode* n) const;
// slice val2slice(rational const& val, unsigned bit_width) const;
// Retrieve (or create) a slice representing the given value.
slice mk_value_slice(rational const& val, unsigned bit_width);
// slice mk_value_slice(rational const& val, unsigned bit_width);
bool has_value(slice s) const { SASSERT_EQ(s, find(s)); return m_slice2val[s].is_nonneg(); }
rational const& get_value(slice s) const { SASSERT(has_value(s)); return m_slice2val[s]; }
// bool has_value(slice s) const { SASSERT_EQ(s, find(s)); return m_slice2val[s].is_nonneg(); }
// rational const& get_value(slice s) const { SASSERT(has_value(s)); return m_slice2val[s]; }
// reverse all edges on the path from s to the root of its tree in the proof forest
void make_proof_root(slice s);
// void make_proof_root(slice s);
/// Split slice s into s[|s|-1:cut+1] and s[cut:0]
void split(slice s, unsigned cut);
void split_core(slice s, unsigned cut);
/// Retrieve base slices s_1,...,s_n such that src == s_1 ++ ... ++ s_n
void find_base(slice src, slice_vector& out_base) const;
/// Retrieve (or create) base slices s_1,...,s_n such that src[hi:lo] == s_1 ++ ... ++ s_n.
@ -146,11 +165,6 @@ namespace polysat {
/// If output_base is false, return coarsest intermediate slices instead of only base slices.
void mk_slice(slice src, unsigned hi, unsigned lo, slice_vector& out, bool output_full_src = false, bool output_base = true);
/// Upper subslice (direct child, not necessarily the representative)
slice sub_hi(slice s) const;
/// Lower subslice (direct child, not necessarily the representative)
slice sub_lo(slice s) const;
/// Find representative
slice find(slice s) const;
/// Find representative of upper subslice
@ -194,20 +208,18 @@ namespace polysat {
enum class trail_item {
add_var,
alloc_slice,
split_slice,
split_core,
merge_base,
mk_value_slice,
};
svector<trail_item> m_trail;
slice_vector m_split_trail;
svector<std::pair<slice, slice>> m_merge_trail; // pair of (representative, element)
vector<val2slice_key> m_val2slice_trail;
unsigned_vector m_scopes;
void undo_add_var();
void undo_alloc_slice();
void undo_split_slice();
void undo_merge_base();
void undo_split_core();
void undo_mk_value_slice();
mutable slice_vector m_tmp1;