3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 00:55:31 +00:00

pvar deps also need to track the slice they're coming from

This commit is contained in:
Jakob Rath 2023-07-26 09:38:29 +02:00
parent 2f0d74fca8
commit 12e9356f0f
2 changed files with 80 additions and 107 deletions

View file

@ -28,7 +28,6 @@ Example:
TODO:
- replay mk_extract/mk_concat in pop_scope. (easiest solution until we have proper garbage collection / reinitialization in the solver)
- notify solver about equalities discovered by congruence
- variable equalities x = y will be handled on-demand by the viable component
- but whenever we derive an equality between pvar and value we must propagate the value in the solver
@ -36,20 +35,7 @@ TODO:
- track fixed bits along with enodes
- implement query functions
- when solver assigns value of a variable v, add equations with v substituted by its value?
TODO: better conflicts with pvar justification
- pvar justification is only introduced by add_value (when a variable is assigned in the model)
- so there can be at most two pvar justifications in a single conflict
- when explaining a conflict that contains pvars:
- single pvar x: the egraph has derived that x must have a different value c, learn literal x = c (instead of x != value(x) as is done now by the naive integration)
- two pvars x, y: learn literal x = y
Actually: it is an equality over slices x[h1:l1] = y[h2:l2], i.e., those slices that failed to merge.
-> how to get slice from egraph-explain? could store pointer to slice alongside pvar-dependencies.
-> we don't need to create a new slice since the equality will be over existing slices,
but (in general) we have to create a new variable for it.
- (this is basically what Algorithm 1 of "Solving Bitvectors with MCSAT" does)
- then check Algorithm 2 of "Solving Bitvectors with MCSAT"; what is the difference to what we are doing now?
- check Algorithm 2 of "Solving Bitvectors with MCSAT"; what is the difference to what we are doing now?
*/
@ -77,7 +63,7 @@ namespace polysat {
return UINT_MAX;
else if constexpr (std::is_same_v<T, sat::literal>)
return (arg.to_uint() << 1);
else if constexpr (std::is_same_v<T, pvar>)
else if constexpr (std::is_same_v<T, unsigned>)
return (arg << 1) + 1;
else
static_assert(always_false_v<T>, "non-exhaustive visitor!");
@ -90,22 +76,22 @@ namespace polysat {
else if ((x & 1) == 0)
return dep_t(sat::to_literal(x >> 1));
else
return dep_t(static_cast<pvar>(x >> 1));
return dep_t(static_cast<unsigned>(x >> 1));
}
std::ostream& slicing::dep_t::display(std::ostream& out) {
if (is_null())
std::ostream& slicing::display(std::ostream& out, dep_t d) {
if (d.is_null())
out << "null";
else if (is_var())
out << "v" << var();
else if (is_lit())
out << "lit(" << lit() << ")";
else if (d.is_var_idx())
out << "var(v" << get_dep_var(d) << " on slice " << get_dep_slice(d)->get_id() << ")";
else if (d.is_lit())
out << "lit(" << d.lit() << ")";
return out;
}
void* slicing::encode_dep(dep_t d) {
void* p = box<void>(d.to_uint());
SASSERT_EQ(d, decode_dep(p));
SASSERT(d == decode_dep(p));
return p;
}
@ -113,8 +99,12 @@ namespace polysat {
return dep_t::from_uint(unbox<unsigned>(p));
}
void slicing::display_dep(std::ostream& out, void* d) {
out << decode_dep(d);
slicing::dep_t slicing::mk_var_dep(pvar v, enode* s) {
SASSERT_EQ(m_dep_var.size(), m_dep_slice.size());
unsigned const idx = m_dep_var.size();
m_dep_var.push_back(v);
m_dep_slice.push_back(s);
return dep_t(idx);
}
slicing::slicing(solver& s):
@ -123,7 +113,8 @@ namespace polysat {
{
reg_decl_plugins(m_ast);
m_bv = alloc(bv_util, m_ast);
m_egraph.set_display_justification(display_dep);
m_egraph.set_display_justification([&](std::ostream& out, void* d) { display(out, decode_dep(d)); });
m_egraph.set_on_merge([&](enode* root, enode* other) { egraph_on_merge(root, other); });
m_egraph.set_on_propagate([&](enode* lit, enode* ante) { egraph_on_propagate(lit, ante); });
}
@ -175,6 +166,7 @@ namespace polysat {
propagate();
m_scopes.push_back(m_trail.size());
m_egraph.push();
m_dep_size_trail.push_back(m_dep_var.size());
SASSERT(m_needs_congruence.empty());
}
@ -213,6 +205,9 @@ namespace polysat {
m_egraph.pop(num_scopes);
m_needs_congruence.reset();
m_disequality_conflict = nullptr;
m_dep_var.shrink(m_dep_size_trail[target_lvl]);
m_dep_slice.shrink(m_dep_size_trail[target_lvl]);
m_dep_size_trail.shrink(target_lvl);
// replay add_var/mk_extract/mk_concat in the same order
// (only until polysat::solver supports proper garbage collection of variables)
unsigned add_var_idx = replay_add_var.size();
@ -401,10 +396,19 @@ namespace polysat {
enode* target = n->get_target();
if (!target)
continue;
euf::justification j = n->get_justification();
euf::justification const j = n->get_justification();
SASSERT(j.is_external()); // cannot be a congruence since the slice wasn't split before.
m_egraph.merge(sub_hi(n), sub_hi(target), j.ext<void>());
m_egraph.merge(sub_lo(n), sub_lo(target), j.ext<void>());
void* j_hi = j.ext<void>();
void* j_lo = j.ext<void>();
dep_t d = decode_dep(j.ext<void>());
if (d.is_var_idx()) {
enode* ds = get_dep_slice(d);
SASSERT(ds == n || ds == target);
j_hi = encode_dep(mk_var_dep(get_dep_var(d), sub_hi(ds)));
j_lo = encode_dep(mk_var_dep(get_dep_var(d), sub_lo(ds)));
}
m_egraph.merge(sub_hi(n), sub_hi(target), j_hi);
m_egraph.merge(sub_lo(n), sub_lo(target), j_lo);
}
}
@ -490,49 +494,14 @@ namespace polysat {
return m_bv->is_numeral(s->get_expr(), val);
}
void slicing::begin_explain() {
SASSERT(m_marked_lits.empty());
SASSERT(m_marked_vars.empty());
}
void slicing::end_explain() {
m_marked_lits.reset();
m_marked_vars.reset();
}
void slicing::push_dep(void* dp, sat::literal_vector& out_lits, unsigned_vector& out_vars) {
dep_t d = decode_dep(dp);
if (d.is_var()) {
pvar v = d.var();
if (m_marked_vars.contains(v))
return;
m_marked_vars.insert(v);
out_vars.push_back(v);
}
else if (d.is_lit()) {
sat::literal lit = d.lit();
if (m_marked_lits.contains(lit))
return;
m_marked_lits.insert(lit);
out_lits.push_back(lit);
}
else {
SASSERT(d.is_null());
}
}
void slicing::explain_class(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars) {
void slicing::explain_class(enode* x, enode* y, ptr_vector<void>& out_deps) {
SASSERT_EQ(x->get_root(), y->get_root());
SASSERT(m_tmp_justifications.empty());
m_egraph.begin_explain();
m_egraph.explain_eq(m_tmp_justifications, nullptr, x, y);
m_egraph.explain_eq(out_deps, nullptr, x, y);
m_egraph.end_explain();
for (void* dp : m_tmp_justifications)
push_dep(dp, out_lits, out_vars);
m_tmp_justifications.reset();
}
void slicing::explain_equal(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars) {
void slicing::explain_equal(enode* x, enode* y, ptr_vector<void>& out_deps) {
SASSERT(is_equal(x, y));
enode_vector& xs = m_tmp2;
enode_vector& ys = m_tmp3;
@ -550,7 +519,7 @@ namespace polysat {
enode* const rx = x->get_root();
enode* const ry = y->get_root();
if (rx == ry)
explain_class(x, y, out_lits, out_vars);
explain_class(x, y, out_deps);
else {
xs.push_back(sub_hi(rx));
xs.push_back(sub_lo(rx));
@ -575,16 +544,12 @@ namespace polysat {
SASSERT(ys.empty());
}
void slicing::explain_equal(pvar x, pvar y, sat::literal_vector& out_lits, unsigned_vector& out_vars) {
begin_explain();
explain_equal(var2slice(x), var2slice(y), out_lits, out_vars);
end_explain();
void slicing::explain_equal(pvar x, pvar y, ptr_vector<void>& out_deps) {
explain_equal(var2slice(x), var2slice(y), out_deps);
}
void slicing::explain(sat::literal_vector& out_lits, unsigned_vector& out_vars) {
void slicing::explain(ptr_vector<void>& out_deps) {
SASSERT(is_conflict());
begin_explain();
SASSERT(m_tmp_justifications.empty());
m_egraph.begin_explain();
if (m_disequality_conflict) {
enode* eqn = m_disequality_conflict;
@ -593,18 +558,14 @@ namespace polysat {
SASSERT(eqn->get_lit_justification().is_external());
SASSERT(m_ast.is_eq(eqn->get_expr()));
SASSERT_EQ(eqn->get_arg(0)->get_root(), eqn->get_arg(1)->get_root());
m_egraph.explain_eq(m_tmp_justifications, nullptr, eqn->get_arg(0), eqn->get_arg(1));
push_dep(eqn->get_lit_justification().ext<void>(), out_lits, out_vars);
m_egraph.explain_eq(out_deps, nullptr, eqn->get_arg(0), eqn->get_arg(1));
out_deps.push_back(eqn->get_lit_justification().ext<void>());
}
else {
SASSERT(m_egraph.inconsistent());
m_egraph.explain(m_tmp_justifications, nullptr);
m_egraph.explain(out_deps, nullptr);
}
m_egraph.end_explain();
for (void* dp : m_tmp_justifications)
push_dep(dp, out_lits, out_vars);
m_tmp_justifications.reset();
end_explain();
}
clause_ref slicing::conflict_clause() {
@ -641,6 +602,10 @@ namespace polysat {
SASSERT_EQ(width(s1), width(s2));
SASSERT(!has_sub(s1));
SASSERT(!has_sub(s2));
if (dep.is_var_idx()) {
SASSERT(is_value(s2));
dep = mk_var_dep(get_dep_var(dep), s1);
}
m_egraph.merge(s1, s2, encode_dep(dep));
return !is_conflict();
}
@ -894,6 +859,7 @@ namespace polysat {
}
void slicing::add_constraint(signed_constraint c) {
LOG(c);
SASSERT(!is_conflict());
if (!c->is_eq())
return;
@ -947,10 +913,11 @@ namespace polysat {
}
void slicing::add_value(pvar v, rational const& val) {
LOG("v" << v << " := " << val);
SASSERT(!is_conflict());
enode* const sv = var2slice(v);
enode* const sval = mk_value_slice(val, width(sv));
(void)merge(sv, sval, v);
(void)merge(sv, sval, mk_var_dep(v, sv));
}
void slicing::collect_overlaps(pvar v, var_overlap_vector& out) {

View file

@ -38,28 +38,38 @@ namespace polysat {
friend class test_slicing;
using enode = euf::enode;
using enode_vector = euf::enode_vector;
class dep_t {
std::variant<std::monostate, sat::literal, pvar> m_data;
std::variant<std::monostate, sat::literal, unsigned> m_data;
public:
dep_t() { SASSERT(is_null()); }
dep_t(sat::literal l): m_data(l) { SASSERT(l != sat::null_literal); SASSERT_EQ(l, lit()); }
dep_t(pvar v): m_data(v) { SASSERT(v != null_var); SASSERT_EQ(v, var()); }
explicit dep_t(unsigned vi): m_data(vi) { SASSERT_EQ(vi, var_idx()); }
bool is_null() const { return std::holds_alternative<std::monostate>(m_data); }
bool is_lit() const { return std::holds_alternative<sat::literal>(m_data); }
bool is_var() const { return std::holds_alternative<pvar>(m_data); }
bool is_var_idx() const { return std::holds_alternative<unsigned>(m_data); }
sat::literal lit() const { SASSERT(is_lit()); return *std::get_if<sat::literal>(&m_data); }
pvar var() const { SASSERT(is_var()); return *std::get_if<pvar>(&m_data); }
unsigned var_idx() const { SASSERT(is_var_idx()); return *std::get_if<unsigned>(&m_data); }
bool operator==(dep_t other) const { return m_data == other.m_data; }
bool operator!=(dep_t other) const { return !operator==(other); }
std::ostream& display(std::ostream& out);
unsigned to_uint() const;
static dep_t from_uint(unsigned x);
};
friend std::ostream& operator<<(std::ostream&, slicing::dep_t);
using dep_vector = svector<dep_t>;
using enode = euf::enode;
using enode_vector = euf::enode_vector;
std::ostream& display(std::ostream& out, dep_t d);
dep_t mk_var_dep(pvar v, enode* s);
pvar_vector m_dep_var;
ptr_vector<enode> m_dep_slice;
unsigned_vector m_dep_size_trail;
pvar get_dep_var(dep_t d) const { return m_dep_var[d.var_idx()]; }
enode* get_dep_slice(dep_t d) const { return m_dep_slice[d.var_idx()]; }
static constexpr unsigned null_cut = std::numeric_limits<unsigned>::max();
@ -116,7 +126,6 @@ namespace polysat {
static void* encode_dep(dep_t d);
static dep_t decode_dep(void* d);
static void display_dep(std::ostream& out, void* d);
slice_info& info(euf::enode* n);
slice_info const& info(euf::enode* n) const;
@ -164,17 +173,20 @@ namespace polysat {
/// If output_base is false, return coarsest intermediate slices instead of only base slices.
void mk_slice(enode* src, unsigned hi, unsigned lo, enode_vector& out, bool output_full_src = false, bool output_base = true);
void begin_explain();
void end_explain();
void push_dep(void* dp, sat::literal_vector& out_lits, unsigned_vector& out_vars);
// Extract reason why slices x and y are in the same equivalence class
void explain_class(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars);
void explain_class(enode* x, enode* y, ptr_vector<void>& out_deps);
// Extract reason why slices x and y are equal
// (i.e., x and y have the same base, but are not necessarily in the same equivalence class)
void explain_equal(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars);
void explain_equal(enode* x, enode* y, ptr_vector<void>& out_deps);
/** Extract reason for conflict */
void explain(ptr_vector<void>& out_deps);
/** Extract reason for x == y */
void explain_equal(pvar x, pvar y, ptr_vector<void>& out_deps);
void egraph_on_merge(enode* root, enode* other);
void egraph_on_propagate(enode* lit, enode* ante);
// Merge equivalence classes of two base slices.
@ -237,9 +249,8 @@ namespace polysat {
mutable enode_vector m_tmp1;
mutable enode_vector m_tmp2;
mutable enode_vector m_tmp3;
ptr_vector<void> m_tmp_justifications;
ptr_vector<void> m_tmp_deps;
sat::literal_set m_marked_lits;
uint_set m_marked_vars;
/** Get variable representing src[hi:lo] */
pvar mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var = null_var);
@ -284,12 +295,8 @@ namespace polysat {
bool is_conflict() const { return m_disequality_conflict || m_egraph.inconsistent(); }
/** Extract reason for conflict */
void explain(sat::literal_vector& out_lits, unsigned_vector& out_vars);
/** Extract conflict clause */
clause_ref conflict_clause();
/** Extract reason for x == y */
void explain_equal(pvar x, pvar y, sat::literal_vector& out_lits, unsigned_vector& out_vars);
clause_ref build_conflict_clause();
/// Example:
/// - assume query_var has segments 11122233 and var has segments 2224
@ -318,5 +325,4 @@ namespace polysat {
inline std::ostream& operator<<(std::ostream& out, slicing const& s) { return s.display(out); }
inline std::ostream& operator<<(std::ostream& out, slicing::dep_t d) { return d.display(out); }
}