3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-24 17:45:32 +00:00

slicing checkpoint

This commit is contained in:
Jakob Rath 2023-07-08 20:08:45 +02:00
parent 28810e55a0
commit b4edc4d20c
3 changed files with 316 additions and 30 deletions

View file

@ -11,6 +11,35 @@ Author:
--*/
/*
(x=y)
x <=========== y
/ \ / \
x[7:4] x[3:0] y[3:0]
<==========
(by x=y)
Try later:
Congruence closure with "virtual concat" terms
x = x[7:4] ++ x[3:0]
y = y[7:4] ++ y[3:0]
x[7:4] = y[7:4]
x[3:0] = y[3:0]
=> x = y
Recycle the z3 egraph?
- x = x[7:4] ++ x[3:0]
- Add instance euf_egraph.h
- What do we need from the egraph?
- backtracking trail to check for new equalities
*/
#include "math/polysat/slicing.h"
#include "math/polysat/solver.h"
#include "math/polysat/log.h"
@ -30,10 +59,11 @@ namespace polysat {
unsigned const target_size = m_scopes[target_lvl];
while (m_trail.size() > target_size) {
switch (m_trail.back()) {
case trail_item::add_var: undo_add_var(); break;
case trail_item::alloc_slice: undo_alloc_slice(); break;
case trail_item::split_slice: undo_split_slice(); break;
case trail_item::merge_base: undo_merge_base(); break;
case trail_item::add_var: undo_add_var(); break;
case trail_item::alloc_slice: undo_alloc_slice(); break;
case trail_item::split_slice: undo_split_slice(); break;
case trail_item::merge_base: undo_merge_base(); break;
case trail_item::mk_value_slice: undo_mk_value_slice(); break;
default: UNREACHABLE();
}
m_trail.pop_back();
@ -65,6 +95,8 @@ namespace polysat {
m_proof_parent.push_back(null_slice);
m_proof_reason.push_back(null_dep);
m_mark.push_back(0);
// m_value_root.push_back(null_slice);
m_slice2val.push_back(rational(-1));
m_trail.push_back(trail_item::alloc_slice);
return s;
}
@ -80,6 +112,8 @@ namespace polysat {
m_proof_parent.pop_back();
m_proof_reason.pop_back();
m_mark.pop_back();
// m_value_root.pop_back();
m_slice2val.pop_back();
}
slicing::slice slicing::sub_hi(slice parent) const {
@ -112,6 +146,11 @@ namespace polysat {
m_slice_width[sub_lo] = cut + 1;
m_trail.push_back(trail_item::split_slice);
m_split_trail.push_back(s);
if (has_value(s)) {
rational const& val = get_value(s);
// set_value(sub_lo, mod2k(val, cut + 1));
// set_value(sub_hi, machine_div2k(val, cut + 1));
}
}
void slicing::undo_split_slice() {
@ -121,6 +160,27 @@ namespace polysat {
m_slice_sub[s] = null_slice;
}
slicing::slice slicing::mk_value_slice(rational const& val, unsigned bit_width) {
SASSERT(0 <= val && val < rational::power_of_two(bit_width));
val2slice_key key(val, bit_width);
auto it = m_val2slice.find_iterator(key);
if (it != m_val2slice.end())
return it->m_value;
slice s = alloc_slice();
m_slice_width[s] = bit_width;
m_slice2val[s] = val;
// m_value_root[s] = s;
m_val2slice.insert(key, s);
m_val2slice_trail.push_back(std::move(key));
m_trail.push_back(trail_item::mk_value_slice);
return s;
}
void slicing::undo_mk_value_slice() {
m_val2slice.remove(m_val2slice_trail.back());
m_val2slice_trail.pop_back();
}
slicing::slice slicing::find(slice s) const {
while (true) {
SASSERT(s < m_find.size());
@ -143,6 +203,16 @@ namespace polysat {
std::swap(r1, r2);
std::swap(s1, s2);
}
if (has_value(r1)) {
if (has_value(r2)) {
if (get_value(r1) != get_value(r2)) {
NOT_IMPLEMENTED_YET(); // TODO: conflict
return false;
}
}
// else
// set_value(r2, get_value(r1));
}
// r2 becomes the representative of the merged class
m_find[r1] = r2;
m_size[r2] += m_size[r1];
@ -206,6 +276,39 @@ namespace polysat {
}
}
bool slicing::merge_value(slice s0, rational val0, dep_t dep) {
vector<std::pair<slice, rational>> todo;
todo.push_back({find(s0), std::move(val0)});
// check compatibility for sub-slices
for (unsigned i = 0; i < todo.size(); ++i) {
auto const& [s, val] = todo[i];
if (has_value(s)) {
if (get_value(s) != val) {
// TODO: conflict
NOT_IMPLEMENTED_YET();
return false;
}
SASSERT_EQ(get_value(s), val);
continue;
}
if (has_sub(s)) {
// s is split into [s.width-1, cut+1] and [cut, 0]
unsigned const cut = m_slice_cut[s];
todo.push_back({find_sub_lo(s), mod2k(val, cut + 1)});
todo.push_back({find_sub_hi(s), machine_div2k(val, cut + 1)});
}
}
// all succeeded, so apply the values
for (auto const& [s, val] : todo) {
if (has_value(s)) {
SASSERT_EQ(get_value(s), val);
continue;
}
// set_value(s, val);
}
return true;
}
void slicing::push_reason(slice s, dep_vector& out_deps) {
dep_t reason = m_proof_reason[s];
if (reason == null_dep)
@ -295,6 +398,10 @@ namespace polysat {
slice y = ys.back();
xs.pop_back();
ys.pop_back();
if (x == y) {
// merge upper level?
// but continue loop
}
if (has_sub(x)) {
find_base(x, xs);
x = xs.back();
@ -338,6 +445,7 @@ namespace polysat {
}
bool slicing::merge(slice x, slice y, dep_t dep) {
SASSERT_EQ(width(x), width(y));
if (!has_sub(x) && !has_sub(y))
return merge_base(x, y, dep);
slice_vector& xs = m_tmp2;
@ -350,6 +458,7 @@ namespace polysat {
}
bool slicing::is_equal(slice x, slice y) {
SASSERT_EQ(width(x), width(y));
x = find(x);
y = find(y);
if (x == y)
@ -480,6 +589,13 @@ namespace polysat {
}
pdd slicing::mk_extract(pdd const& p, unsigned hi, unsigned lo) {
SASSERT(hi >= lo);
SASSERT(p.power_of_2() > hi);
if (p.is_val()) {
// p[hi:lo] = (p >> lo) % 2^(hi - lo + 1)
rational q = mod2k(machine_div2k(p.val(), lo), hi - lo + 1);
return p.manager().mk_val(q);
}
if (!lo) {
// TODO: we could push the extract down into variables of the term instead of introducing a name.
}
@ -498,6 +614,14 @@ namespace polysat {
unsigned const p_sz = p.power_of_2();
unsigned const q_sz = q.power_of_2();
unsigned const v_sz = p_sz + q_sz;
if (p.is_val() && q.is_val()) {
rational const val = p.val() * rational::power_of_two(q_sz) + q.val();
return m_solver.sz2pdd(v_sz).mk_val(val);
}
if (p.is_val()) {
}
if (q.is_val()) {
}
pvar const v = m_solver.add_var(v_sz);
slice_vector tmp;
tmp.push_back(pdd2slice(p));
@ -506,7 +630,48 @@ namespace polysat {
return m_solver.var(v);
}
void slicing::propagate(signed_constraint c) {
// TODO: evaluate under current assignment?
if (!c->is_eq())
return;
pdd const& p = c->to_eq();
auto& m = p.manager();
for (auto& [a, x] : p.linear_monomials()) {
if (a != 1 && a != m.max_value())
continue;
pdd body = a.is_one() ? (m.mk_var(x) - p) : (m.mk_var(x) + p);
// c is either x = body or x != body, depending on polarity
LOG("Equation from constraint " << c << ": v" << x << " = " << body);
slice const sx = var2slice(x);
if (body.is_val()) {
// Simple assignment x = value
// TODO: set fixed bits
continue;
}
pvar const y = m_solver.m_names.get_name(body);
if (y == null_var) {
// TODO: register name trigger (if a name for value 'body' is created later, then merge x=y at that time)
continue;
}
slice const sy = var2slice(y);
if (c.is_positive()) {
if (!merge(sx, sy, c.blit()))
return;
}
else {
SASSERT(c.is_negative());
if (is_equal(sx, sy)) {
// TODO: conflict
NOT_IMPLEMENTED_YET();
return;
}
}
}
}
void slicing::propagate(pvar v) {
// go through all existing nodes, and evaluate v?
// can do that externally
}
std::ostream& slicing::display(std::ostream& out) const {
@ -514,7 +679,8 @@ namespace polysat {
for (pvar v = 0; v < m_var2slice.size(); ++v) {
out << "v" << v << ":";
base.reset();
find_base(var2slice(v), base);
slice const vs = var2slice(v);
find_base(vs, base);
// unsigned hi = width(var2slice(v)) - 1;
for (slice s : base) {
// unsigned w = width(s);
@ -523,13 +689,49 @@ namespace polysat {
// hi -= w;
display(out << " ", s);
}
if (has_value(vs)) {
out << " -- (val:" << get_value(vs) << ")";
}
out << "\n";
}
for (pvar v = 0; v < m_var2slice.size(); ++v) {
out << "v" << v << ":";
slice const s = m_var2slice[v];
}
return out;
}
std::ostream& slicing::display_tree(std::ostream& out, char const* name, slice s) const {
// TODO
}
std::ostream& slicing::display(std::ostream& out, slice s) const {
return out << "{id:" << s << ",w:" << width(s) << "}";
out << "{id:" << s << ",w:" << width(s);
if (has_value(s)) {
out << ",val:" << get_value(s);
}
out << "}";
return out;
}
bool slicing::invariant() const {
VERIFY(m_tmp1.empty());
VERIFY(m_tmp2.empty());
VERIFY(m_tmp3.empty());
for (slice s = 0; s < m_slice_cut.size(); ++s) {
// if the slice is equivalent to a variable, then the variable's slice is in the equivalence class
pvar const v = slice2var(s);
SASSERT_EQ(v != null_var, find(var2slice(v)) == find(s));
// properties below only matter for representatives
if (s != find(s))
continue;
// if slice has a value, it should be propagated to its sub-slices
if (has_value(s)) {
VERIFY(has_value(find_sub_hi(s)));
VERIFY(has_value(find_sub_lo(s)));
}
}
return true;
}
}

View file

@ -25,6 +25,7 @@ Notation:
--*/
#pragma once
#include "math/polysat/types.h"
#include "math/polysat/constraint.h"
#include <optional>
namespace polysat {
@ -37,28 +38,6 @@ namespace polysat {
solver& m_solver;
#if 0
struct extract_key {
pvar src;
unsigned hi;
unsigned lo;
bool operator==(extract_key const& other) const {
return src == other.src && hi == other.hi && lo == other.lo;
}
unsigned hash() const {
return mk_mix(src, hi, lo);
}
};
using extract_hash = obj_hash<extract_key>;
using extract_eq = default_eq<extract_key>;
using extract_map = map<extract_key, pvar, extract_hash, extract_eq>;
extract_map m_extracted; ///< src, hi, lo -> v
// need src -> [v] and v -> [src] for propagation?
#endif
using dep_t = sat::literal;
using dep_vector = sat::literal_vector;
static constexpr sat::literal null_dep = sat::null_literal;
@ -69,11 +48,30 @@ namespace polysat {
static constexpr unsigned null_cut = std::numeric_limits<unsigned>::max();
struct val2slice_key {
rational value;
unsigned bit_width;
val2slice_key() {}
val2slice_key(rational value, unsigned bit_width): value(std::move(value)), bit_width(bit_width) {}
bool operator==(val2slice_key const& other) const {
return bit_width == other.bit_width && value == other.value;
}
unsigned hash() const {
return combine_hash(value.hash(), bit_width);
}
};
using val2slice_hash = obj_hash<val2slice_key>;
using val2slice_eq = default_eq<val2slice_key>;
using val2slice_map = map<val2slice_key, slice, val2slice_hash, val2slice_eq>;
// number of bits in the slice
unsigned_vector m_slice_width;
// Cut point: if slice represents bit-vector x, then x has been sliced into x[|x|-1:cut+1] and x[cut:0].
// The cut point is relative to the parent slice (rather than a root variable, which might not be unique)
// (UINT_MAX for leaf slices)
// (null_cut for leaf slices)
unsigned_vector m_slice_cut;
// The sub-slices are at indices sub and sub+1 (or null_slice if there is no subdivision)
slice_vector m_slice_sub;
@ -84,6 +82,15 @@ namespace polysat {
slice_vector m_proof_parent; // the proof forest
dep_vector m_proof_reason; // justification for merge of an element with its parent (in the proof forest)
// unsigned_vector m_value_id; // slice -> value id
// vector<rational> m_value; // id -> value
// slice_vector m_value_root; // the slice representing the associated value, if any. NOTE: subslices will inherit this from their parents.
// TODO: value_root probably not necessary.
// but we will need to create value slices for the sub-slices.
// then the "reason" for that equality must be a marker "equality between parent and its value". explanation at that point must go up recursively.
vector<rational> m_slice2val; // slice -> value (-1 if none)
val2slice_map m_val2slice; // (value, bit-width) -> slice
pvar_vector m_slice2var; // slice -> pvar, or null_var if slice is not equivalent to a variable
slice_vector m_var2slice; // pvar -> slice
@ -110,6 +117,14 @@ namespace polysat {
unsigned width(slice s) const { return m_slice_width[s]; }
bool has_sub(slice s) const { return m_slice_sub[s] != null_slice; }
// slice val2slice(rational const& val, unsigned bit_width) const;
// Retrieve (or create) a slice representing the given value.
slice mk_value_slice(rational const& val, unsigned bit_width);
bool has_value(slice s) const { SASSERT_EQ(s, find(s)); return m_slice2val[s].is_nonneg(); }
rational const& get_value(slice s) const { SASSERT(has_value(s)); return m_slice2val[s]; }
// reverse all edges on the path from s to the root of its tree in the proof forest
void make_proof_root(slice s);
@ -138,6 +153,10 @@ namespace polysat {
// Returns true if merge succeeded without conflict.
[[nodiscard]] bool merge_base(slice s1, slice s2, dep_t dep);
// Merge equality s == val and propagate the value downward into sub-slices.
// Returns true if merge succeeded without conflict.
[[nodiscard]] bool merge_value(slice s, rational val, dep_t dep);
void push_reason(slice s, dep_vector& out_deps);
// Extract reason why slices x and y are in the same equivalence class
@ -168,17 +187,19 @@ namespace polysat {
alloc_slice,
split_slice,
merge_base,
mk_value_slice,
};
svector<trail_item> m_trail;
slice_vector m_split_trail;
svector<std::pair<slice, slice>> m_merge_trail; // pair of (representative, element)
vector<val2slice_key> m_val2slice_trail;
unsigned_vector m_scopes;
void undo_add_var();
void undo_alloc_slice();
void undo_split_slice();
void undo_merge_base();
void undo_mk_value_slice();
mutable slice_vector m_tmp1;
mutable slice_vector m_tmp2;
@ -217,8 +238,10 @@ namespace polysat {
// - fixed bits
// - intervals ????? -- that will also need changes in the viable algorithm
void propagate(pvar v);
void propagate(signed_constraint c);
std::ostream& display(std::ostream& out) const;
std::ostream& display_tree(std::ostream& out, char const* name, slice s) const;
std::ostream& display(std::ostream& out, slice s) const;
};

View file

@ -118,6 +118,65 @@ namespace polysat {
std::cout << " Reason: " << reason << "\n";
}
// 1. a = b
// 2. d = c[1:0]
// 3. c = b[3:0]
// 4. e = a[1:0]
//
// Explain(d = e) should be {1, 2, 3, 4}
static void test4() {
std::cout << __func__ << "\n";
scoped_solver_slicing s;
slicing& sl = s.sl();
pvar a = s.add_var(8);
pvar b = s.add_var(8);
pvar c = s.add_var(4);
pvar d = s.add_var(2);
pvar e = s.add_var(2);
VERIFY(sl.merge(sl.var2slice(a), sl.var2slice(b), sat::literal(101)));
VERIFY(sl.merge(sl.var2slice(d), sl.var2slice(sl.mk_extract_var(c, 1, 0)), sat::literal(102)));
VERIFY(sl.merge(sl.var2slice(c), sl.var2slice(sl.mk_extract_var(b, 3, 0)), sat::literal(103)));
VERIFY(sl.merge(sl.var2slice(e), sl.var2slice(sl.mk_extract_var(a, 1, 0)), sat::literal(104)));
std::cout << "v" << d << " = v" << e << "? " << sl.is_equal(sl.var2slice(d), sl.var2slice(e))
<< " find(v" << d << ") = " << sl.find(sl.var2slice(d))
<< " find(v" << e << ") = " << sl.find(sl.var2slice(e))
<< " slice(v" << d << ") = " << sl.var2slice(d)
<< " slice(v" << e << ") = " << sl.var2slice(e)
<< " slice(v" << d << ") = " << sl.m_var2slice[d]
<< " slice(v" << e << ") = " << sl.m_var2slice[e]
<< "\n";
sat::literal_vector reason;
sl.explain_equal(sl.var2slice(d), sl.var2slice(e), reason);
std::cout << " Reason: " << reason << "\n";
reason.reset();
sl.explain_equal(sl.m_var2slice[d], sl.m_var2slice[e], reason);
std::cout << " Reason: " << reason << "\n";
}
// x[5:2] = y
// x[3:0] = z
// y = 0b1001
// z = 0b0111
static void test5() {
std::cout << __func__ << "\n";
scoped_solver_slicing s;
slicing& sl = s.sl();
pvar x = s.add_var(6);
std::cout << sl << "\n";
pvar y = sl.mk_extract_var(x, 5, 2);
std::cout << "v" << y << " := v" << x << "[5:2]\n" << sl << "\n";
pvar z = sl.mk_extract_var(x, 3, 0);
std::cout << "v" << z << " := v" << x << "[3:0]\n" << sl << "\n";
// VERIFY(sl.merge_value(sl.var2slice(y), rational(9)));
// std::cout << "v" << y << " = 9\n" << sl << "\n";
// VERIFY(sl.merge_value(sl.var2slice(z), rational(7)));
// std::cout << "v" << z << " = 7\n" << sl << "\n";
}
};
}
@ -128,5 +187,7 @@ void tst_slicing() {
test_slicing::test1();
test_slicing::test2();
test_slicing::test3();
test_slicing::test4();
// test_slicing::test5();
std::cout << "ok\n";
}