mirror of
https://github.com/Z3Prover/z3
synced 2025-04-23 00:55:31 +00:00
pvar deps also need to track the slice they're coming from
This commit is contained in:
parent
2f0d74fca8
commit
12e9356f0f
2 changed files with 80 additions and 107 deletions
|
@ -28,7 +28,6 @@ Example:
|
|||
|
||||
|
||||
TODO:
|
||||
- replay mk_extract/mk_concat in pop_scope. (easiest solution until we have proper garbage collection / reinitialization in the solver)
|
||||
- notify solver about equalities discovered by congruence
|
||||
- variable equalities x = y will be handled on-demand by the viable component
|
||||
- but whenever we derive an equality between pvar and value we must propagate the value in the solver
|
||||
|
@ -36,20 +35,7 @@ TODO:
|
|||
- track fixed bits along with enodes
|
||||
- implement query functions
|
||||
- when solver assigns value of a variable v, add equations with v substituted by its value?
|
||||
|
||||
TODO: better conflicts with pvar justification
|
||||
- pvar justification is only introduced by add_value (when a variable is assigned in the model)
|
||||
- so there can be at most two pvar justifications in a single conflict
|
||||
- when explaining a conflict that contains pvars:
|
||||
- single pvar x: the egraph has derived that x must have a different value c, learn literal x = c (instead of x != value(x) as is done now by the naive integration)
|
||||
- two pvars x, y: learn literal x = y
|
||||
Actually: it is an equality over slices x[h1:l1] = y[h2:l2], i.e., those slices that failed to merge.
|
||||
-> how to get slice from egraph-explain? could store pointer to slice alongside pvar-dependencies.
|
||||
-> we don't need to create a new slice since the equality will be over existing slices,
|
||||
but (in general) we have to create a new variable for it.
|
||||
- (this is basically what Algorithm 1 of "Solving Bitvectors with MCSAT" does)
|
||||
|
||||
- then check Algorithm 2 of "Solving Bitvectors with MCSAT"; what is the difference to what we are doing now?
|
||||
- check Algorithm 2 of "Solving Bitvectors with MCSAT"; what is the difference to what we are doing now?
|
||||
|
||||
*/
|
||||
|
||||
|
@ -77,7 +63,7 @@ namespace polysat {
|
|||
return UINT_MAX;
|
||||
else if constexpr (std::is_same_v<T, sat::literal>)
|
||||
return (arg.to_uint() << 1);
|
||||
else if constexpr (std::is_same_v<T, pvar>)
|
||||
else if constexpr (std::is_same_v<T, unsigned>)
|
||||
return (arg << 1) + 1;
|
||||
else
|
||||
static_assert(always_false_v<T>, "non-exhaustive visitor!");
|
||||
|
@ -90,22 +76,22 @@ namespace polysat {
|
|||
else if ((x & 1) == 0)
|
||||
return dep_t(sat::to_literal(x >> 1));
|
||||
else
|
||||
return dep_t(static_cast<pvar>(x >> 1));
|
||||
return dep_t(static_cast<unsigned>(x >> 1));
|
||||
}
|
||||
|
||||
std::ostream& slicing::dep_t::display(std::ostream& out) {
|
||||
if (is_null())
|
||||
std::ostream& slicing::display(std::ostream& out, dep_t d) {
|
||||
if (d.is_null())
|
||||
out << "null";
|
||||
else if (is_var())
|
||||
out << "v" << var();
|
||||
else if (is_lit())
|
||||
out << "lit(" << lit() << ")";
|
||||
else if (d.is_var_idx())
|
||||
out << "var(v" << get_dep_var(d) << " on slice " << get_dep_slice(d)->get_id() << ")";
|
||||
else if (d.is_lit())
|
||||
out << "lit(" << d.lit() << ")";
|
||||
return out;
|
||||
}
|
||||
|
||||
void* slicing::encode_dep(dep_t d) {
|
||||
void* p = box<void>(d.to_uint());
|
||||
SASSERT_EQ(d, decode_dep(p));
|
||||
SASSERT(d == decode_dep(p));
|
||||
return p;
|
||||
}
|
||||
|
||||
|
@ -113,8 +99,12 @@ namespace polysat {
|
|||
return dep_t::from_uint(unbox<unsigned>(p));
|
||||
}
|
||||
|
||||
void slicing::display_dep(std::ostream& out, void* d) {
|
||||
out << decode_dep(d);
|
||||
slicing::dep_t slicing::mk_var_dep(pvar v, enode* s) {
|
||||
SASSERT_EQ(m_dep_var.size(), m_dep_slice.size());
|
||||
unsigned const idx = m_dep_var.size();
|
||||
m_dep_var.push_back(v);
|
||||
m_dep_slice.push_back(s);
|
||||
return dep_t(idx);
|
||||
}
|
||||
|
||||
slicing::slicing(solver& s):
|
||||
|
@ -123,7 +113,8 @@ namespace polysat {
|
|||
{
|
||||
reg_decl_plugins(m_ast);
|
||||
m_bv = alloc(bv_util, m_ast);
|
||||
m_egraph.set_display_justification(display_dep);
|
||||
m_egraph.set_display_justification([&](std::ostream& out, void* d) { display(out, decode_dep(d)); });
|
||||
m_egraph.set_on_merge([&](enode* root, enode* other) { egraph_on_merge(root, other); });
|
||||
m_egraph.set_on_propagate([&](enode* lit, enode* ante) { egraph_on_propagate(lit, ante); });
|
||||
}
|
||||
|
||||
|
@ -175,6 +166,7 @@ namespace polysat {
|
|||
propagate();
|
||||
m_scopes.push_back(m_trail.size());
|
||||
m_egraph.push();
|
||||
m_dep_size_trail.push_back(m_dep_var.size());
|
||||
SASSERT(m_needs_congruence.empty());
|
||||
}
|
||||
|
||||
|
@ -213,6 +205,9 @@ namespace polysat {
|
|||
m_egraph.pop(num_scopes);
|
||||
m_needs_congruence.reset();
|
||||
m_disequality_conflict = nullptr;
|
||||
m_dep_var.shrink(m_dep_size_trail[target_lvl]);
|
||||
m_dep_slice.shrink(m_dep_size_trail[target_lvl]);
|
||||
m_dep_size_trail.shrink(target_lvl);
|
||||
// replay add_var/mk_extract/mk_concat in the same order
|
||||
// (only until polysat::solver supports proper garbage collection of variables)
|
||||
unsigned add_var_idx = replay_add_var.size();
|
||||
|
@ -401,10 +396,19 @@ namespace polysat {
|
|||
enode* target = n->get_target();
|
||||
if (!target)
|
||||
continue;
|
||||
euf::justification j = n->get_justification();
|
||||
euf::justification const j = n->get_justification();
|
||||
SASSERT(j.is_external()); // cannot be a congruence since the slice wasn't split before.
|
||||
m_egraph.merge(sub_hi(n), sub_hi(target), j.ext<void>());
|
||||
m_egraph.merge(sub_lo(n), sub_lo(target), j.ext<void>());
|
||||
void* j_hi = j.ext<void>();
|
||||
void* j_lo = j.ext<void>();
|
||||
dep_t d = decode_dep(j.ext<void>());
|
||||
if (d.is_var_idx()) {
|
||||
enode* ds = get_dep_slice(d);
|
||||
SASSERT(ds == n || ds == target);
|
||||
j_hi = encode_dep(mk_var_dep(get_dep_var(d), sub_hi(ds)));
|
||||
j_lo = encode_dep(mk_var_dep(get_dep_var(d), sub_lo(ds)));
|
||||
}
|
||||
m_egraph.merge(sub_hi(n), sub_hi(target), j_hi);
|
||||
m_egraph.merge(sub_lo(n), sub_lo(target), j_lo);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -490,49 +494,14 @@ namespace polysat {
|
|||
return m_bv->is_numeral(s->get_expr(), val);
|
||||
}
|
||||
|
||||
void slicing::begin_explain() {
|
||||
SASSERT(m_marked_lits.empty());
|
||||
SASSERT(m_marked_vars.empty());
|
||||
}
|
||||
|
||||
void slicing::end_explain() {
|
||||
m_marked_lits.reset();
|
||||
m_marked_vars.reset();
|
||||
}
|
||||
|
||||
void slicing::push_dep(void* dp, sat::literal_vector& out_lits, unsigned_vector& out_vars) {
|
||||
dep_t d = decode_dep(dp);
|
||||
if (d.is_var()) {
|
||||
pvar v = d.var();
|
||||
if (m_marked_vars.contains(v))
|
||||
return;
|
||||
m_marked_vars.insert(v);
|
||||
out_vars.push_back(v);
|
||||
}
|
||||
else if (d.is_lit()) {
|
||||
sat::literal lit = d.lit();
|
||||
if (m_marked_lits.contains(lit))
|
||||
return;
|
||||
m_marked_lits.insert(lit);
|
||||
out_lits.push_back(lit);
|
||||
}
|
||||
else {
|
||||
SASSERT(d.is_null());
|
||||
}
|
||||
}
|
||||
|
||||
void slicing::explain_class(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars) {
|
||||
void slicing::explain_class(enode* x, enode* y, ptr_vector<void>& out_deps) {
|
||||
SASSERT_EQ(x->get_root(), y->get_root());
|
||||
SASSERT(m_tmp_justifications.empty());
|
||||
m_egraph.begin_explain();
|
||||
m_egraph.explain_eq(m_tmp_justifications, nullptr, x, y);
|
||||
m_egraph.explain_eq(out_deps, nullptr, x, y);
|
||||
m_egraph.end_explain();
|
||||
for (void* dp : m_tmp_justifications)
|
||||
push_dep(dp, out_lits, out_vars);
|
||||
m_tmp_justifications.reset();
|
||||
}
|
||||
|
||||
void slicing::explain_equal(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars) {
|
||||
void slicing::explain_equal(enode* x, enode* y, ptr_vector<void>& out_deps) {
|
||||
SASSERT(is_equal(x, y));
|
||||
enode_vector& xs = m_tmp2;
|
||||
enode_vector& ys = m_tmp3;
|
||||
|
@ -550,7 +519,7 @@ namespace polysat {
|
|||
enode* const rx = x->get_root();
|
||||
enode* const ry = y->get_root();
|
||||
if (rx == ry)
|
||||
explain_class(x, y, out_lits, out_vars);
|
||||
explain_class(x, y, out_deps);
|
||||
else {
|
||||
xs.push_back(sub_hi(rx));
|
||||
xs.push_back(sub_lo(rx));
|
||||
|
@ -575,16 +544,12 @@ namespace polysat {
|
|||
SASSERT(ys.empty());
|
||||
}
|
||||
|
||||
void slicing::explain_equal(pvar x, pvar y, sat::literal_vector& out_lits, unsigned_vector& out_vars) {
|
||||
begin_explain();
|
||||
explain_equal(var2slice(x), var2slice(y), out_lits, out_vars);
|
||||
end_explain();
|
||||
void slicing::explain_equal(pvar x, pvar y, ptr_vector<void>& out_deps) {
|
||||
explain_equal(var2slice(x), var2slice(y), out_deps);
|
||||
}
|
||||
|
||||
void slicing::explain(sat::literal_vector& out_lits, unsigned_vector& out_vars) {
|
||||
void slicing::explain(ptr_vector<void>& out_deps) {
|
||||
SASSERT(is_conflict());
|
||||
begin_explain();
|
||||
SASSERT(m_tmp_justifications.empty());
|
||||
m_egraph.begin_explain();
|
||||
if (m_disequality_conflict) {
|
||||
enode* eqn = m_disequality_conflict;
|
||||
|
@ -593,18 +558,14 @@ namespace polysat {
|
|||
SASSERT(eqn->get_lit_justification().is_external());
|
||||
SASSERT(m_ast.is_eq(eqn->get_expr()));
|
||||
SASSERT_EQ(eqn->get_arg(0)->get_root(), eqn->get_arg(1)->get_root());
|
||||
m_egraph.explain_eq(m_tmp_justifications, nullptr, eqn->get_arg(0), eqn->get_arg(1));
|
||||
push_dep(eqn->get_lit_justification().ext<void>(), out_lits, out_vars);
|
||||
m_egraph.explain_eq(out_deps, nullptr, eqn->get_arg(0), eqn->get_arg(1));
|
||||
out_deps.push_back(eqn->get_lit_justification().ext<void>());
|
||||
}
|
||||
else {
|
||||
SASSERT(m_egraph.inconsistent());
|
||||
m_egraph.explain(m_tmp_justifications, nullptr);
|
||||
m_egraph.explain(out_deps, nullptr);
|
||||
}
|
||||
m_egraph.end_explain();
|
||||
for (void* dp : m_tmp_justifications)
|
||||
push_dep(dp, out_lits, out_vars);
|
||||
m_tmp_justifications.reset();
|
||||
end_explain();
|
||||
}
|
||||
|
||||
clause_ref slicing::conflict_clause() {
|
||||
|
@ -641,6 +602,10 @@ namespace polysat {
|
|||
SASSERT_EQ(width(s1), width(s2));
|
||||
SASSERT(!has_sub(s1));
|
||||
SASSERT(!has_sub(s2));
|
||||
if (dep.is_var_idx()) {
|
||||
SASSERT(is_value(s2));
|
||||
dep = mk_var_dep(get_dep_var(dep), s1);
|
||||
}
|
||||
m_egraph.merge(s1, s2, encode_dep(dep));
|
||||
return !is_conflict();
|
||||
}
|
||||
|
@ -894,6 +859,7 @@ namespace polysat {
|
|||
}
|
||||
|
||||
void slicing::add_constraint(signed_constraint c) {
|
||||
LOG(c);
|
||||
SASSERT(!is_conflict());
|
||||
if (!c->is_eq())
|
||||
return;
|
||||
|
@ -947,10 +913,11 @@ namespace polysat {
|
|||
}
|
||||
|
||||
void slicing::add_value(pvar v, rational const& val) {
|
||||
LOG("v" << v << " := " << val);
|
||||
SASSERT(!is_conflict());
|
||||
enode* const sv = var2slice(v);
|
||||
enode* const sval = mk_value_slice(val, width(sv));
|
||||
(void)merge(sv, sval, v);
|
||||
(void)merge(sv, sval, mk_var_dep(v, sv));
|
||||
}
|
||||
|
||||
void slicing::collect_overlaps(pvar v, var_overlap_vector& out) {
|
||||
|
|
|
@ -38,28 +38,38 @@ namespace polysat {
|
|||
|
||||
friend class test_slicing;
|
||||
|
||||
using enode = euf::enode;
|
||||
using enode_vector = euf::enode_vector;
|
||||
|
||||
class dep_t {
|
||||
std::variant<std::monostate, sat::literal, pvar> m_data;
|
||||
std::variant<std::monostate, sat::literal, unsigned> m_data;
|
||||
public:
|
||||
dep_t() { SASSERT(is_null()); }
|
||||
dep_t(sat::literal l): m_data(l) { SASSERT(l != sat::null_literal); SASSERT_EQ(l, lit()); }
|
||||
dep_t(pvar v): m_data(v) { SASSERT(v != null_var); SASSERT_EQ(v, var()); }
|
||||
explicit dep_t(unsigned vi): m_data(vi) { SASSERT_EQ(vi, var_idx()); }
|
||||
bool is_null() const { return std::holds_alternative<std::monostate>(m_data); }
|
||||
bool is_lit() const { return std::holds_alternative<sat::literal>(m_data); }
|
||||
bool is_var() const { return std::holds_alternative<pvar>(m_data); }
|
||||
bool is_var_idx() const { return std::holds_alternative<unsigned>(m_data); }
|
||||
sat::literal lit() const { SASSERT(is_lit()); return *std::get_if<sat::literal>(&m_data); }
|
||||
pvar var() const { SASSERT(is_var()); return *std::get_if<pvar>(&m_data); }
|
||||
unsigned var_idx() const { SASSERT(is_var_idx()); return *std::get_if<unsigned>(&m_data); }
|
||||
bool operator==(dep_t other) const { return m_data == other.m_data; }
|
||||
bool operator!=(dep_t other) const { return !operator==(other); }
|
||||
std::ostream& display(std::ostream& out);
|
||||
unsigned to_uint() const;
|
||||
static dep_t from_uint(unsigned x);
|
||||
};
|
||||
|
||||
friend std::ostream& operator<<(std::ostream&, slicing::dep_t);
|
||||
using dep_vector = svector<dep_t>;
|
||||
|
||||
using enode = euf::enode;
|
||||
using enode_vector = euf::enode_vector;
|
||||
std::ostream& display(std::ostream& out, dep_t d);
|
||||
|
||||
dep_t mk_var_dep(pvar v, enode* s);
|
||||
|
||||
pvar_vector m_dep_var;
|
||||
ptr_vector<enode> m_dep_slice;
|
||||
unsigned_vector m_dep_size_trail;
|
||||
|
||||
pvar get_dep_var(dep_t d) const { return m_dep_var[d.var_idx()]; }
|
||||
enode* get_dep_slice(dep_t d) const { return m_dep_slice[d.var_idx()]; }
|
||||
|
||||
static constexpr unsigned null_cut = std::numeric_limits<unsigned>::max();
|
||||
|
||||
|
@ -116,7 +126,6 @@ namespace polysat {
|
|||
|
||||
static void* encode_dep(dep_t d);
|
||||
static dep_t decode_dep(void* d);
|
||||
static void display_dep(std::ostream& out, void* d);
|
||||
|
||||
slice_info& info(euf::enode* n);
|
||||
slice_info const& info(euf::enode* n) const;
|
||||
|
@ -164,17 +173,20 @@ namespace polysat {
|
|||
/// If output_base is false, return coarsest intermediate slices instead of only base slices.
|
||||
void mk_slice(enode* src, unsigned hi, unsigned lo, enode_vector& out, bool output_full_src = false, bool output_base = true);
|
||||
|
||||
void begin_explain();
|
||||
void end_explain();
|
||||
void push_dep(void* dp, sat::literal_vector& out_lits, unsigned_vector& out_vars);
|
||||
|
||||
// Extract reason why slices x and y are in the same equivalence class
|
||||
void explain_class(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars);
|
||||
void explain_class(enode* x, enode* y, ptr_vector<void>& out_deps);
|
||||
|
||||
// Extract reason why slices x and y are equal
|
||||
// (i.e., x and y have the same base, but are not necessarily in the same equivalence class)
|
||||
void explain_equal(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars);
|
||||
void explain_equal(enode* x, enode* y, ptr_vector<void>& out_deps);
|
||||
|
||||
/** Extract reason for conflict */
|
||||
void explain(ptr_vector<void>& out_deps);
|
||||
|
||||
/** Extract reason for x == y */
|
||||
void explain_equal(pvar x, pvar y, ptr_vector<void>& out_deps);
|
||||
|
||||
void egraph_on_merge(enode* root, enode* other);
|
||||
void egraph_on_propagate(enode* lit, enode* ante);
|
||||
|
||||
// Merge equivalence classes of two base slices.
|
||||
|
@ -237,9 +249,8 @@ namespace polysat {
|
|||
mutable enode_vector m_tmp1;
|
||||
mutable enode_vector m_tmp2;
|
||||
mutable enode_vector m_tmp3;
|
||||
ptr_vector<void> m_tmp_justifications;
|
||||
ptr_vector<void> m_tmp_deps;
|
||||
sat::literal_set m_marked_lits;
|
||||
uint_set m_marked_vars;
|
||||
|
||||
/** Get variable representing src[hi:lo] */
|
||||
pvar mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var = null_var);
|
||||
|
@ -284,12 +295,8 @@ namespace polysat {
|
|||
|
||||
bool is_conflict() const { return m_disequality_conflict || m_egraph.inconsistent(); }
|
||||
|
||||
/** Extract reason for conflict */
|
||||
void explain(sat::literal_vector& out_lits, unsigned_vector& out_vars);
|
||||
/** Extract conflict clause */
|
||||
clause_ref conflict_clause();
|
||||
/** Extract reason for x == y */
|
||||
void explain_equal(pvar x, pvar y, sat::literal_vector& out_lits, unsigned_vector& out_vars);
|
||||
clause_ref build_conflict_clause();
|
||||
|
||||
/// Example:
|
||||
/// - assume query_var has segments 11122233 and var has segments 2224
|
||||
|
@ -318,5 +325,4 @@ namespace polysat {
|
|||
|
||||
inline std::ostream& operator<<(std::ostream& out, slicing const& s) { return s.display(out); }
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, slicing::dep_t d) { return d.display(out); }
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue