diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index fab000059..51abc6d68 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -28,7 +28,6 @@ Example: TODO: -- replay mk_extract/mk_concat in pop_scope. (easiest solution until we have proper garbage collection / reinitialization in the solver) - notify solver about equalities discovered by congruence - variable equalities x = y will be handled on-demand by the viable component - but whenever we derive an equality between pvar and value we must propagate the value in the solver @@ -36,20 +35,7 @@ TODO: - track fixed bits along with enodes - implement query functions - when solver assigns value of a variable v, add equations with v substituted by its value? - -TODO: better conflicts with pvar justification -- pvar justification is only introduced by add_value (when a variable is assigned in the model) -- so there can be at most two pvar justifications in a single conflict -- when explaining a conflict that contains pvars: - - single pvar x: the egraph has derived that x must have a different value c, learn literal x = c (instead of x != value(x) as is done now by the naive integration) - - two pvars x, y: learn literal x = y - Actually: it is an equality over slices x[h1:l1] = y[h2:l2], i.e., those slices that failed to merge. - -> how to get slice from egraph-explain? could store pointer to slice alongside pvar-dependencies. - -> we don't need to create a new slice since the equality will be over existing slices, - but (in general) we have to create a new variable for it. - - (this is basically what Algorithm 1 of "Solving Bitvectors with MCSAT" does) - -- then check Algorithm 2 of "Solving Bitvectors with MCSAT"; what is the difference to what we are doing now? +- check Algorithm 2 of "Solving Bitvectors with MCSAT"; what is the difference to what we are doing now? */ @@ -77,7 +63,7 @@ namespace polysat { return UINT_MAX; else if constexpr (std::is_same_v) return (arg.to_uint() << 1); - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) return (arg << 1) + 1; else static_assert(always_false_v, "non-exhaustive visitor!"); @@ -90,22 +76,22 @@ namespace polysat { else if ((x & 1) == 0) return dep_t(sat::to_literal(x >> 1)); else - return dep_t(static_cast(x >> 1)); + return dep_t(static_cast(x >> 1)); } - std::ostream& slicing::dep_t::display(std::ostream& out) { - if (is_null()) + std::ostream& slicing::display(std::ostream& out, dep_t d) { + if (d.is_null()) out << "null"; - else if (is_var()) - out << "v" << var(); - else if (is_lit()) - out << "lit(" << lit() << ")"; + else if (d.is_var_idx()) + out << "var(v" << get_dep_var(d) << " on slice " << get_dep_slice(d)->get_id() << ")"; + else if (d.is_lit()) + out << "lit(" << d.lit() << ")"; return out; } void* slicing::encode_dep(dep_t d) { void* p = box(d.to_uint()); - SASSERT_EQ(d, decode_dep(p)); + SASSERT(d == decode_dep(p)); return p; } @@ -113,8 +99,12 @@ namespace polysat { return dep_t::from_uint(unbox(p)); } - void slicing::display_dep(std::ostream& out, void* d) { - out << decode_dep(d); + slicing::dep_t slicing::mk_var_dep(pvar v, enode* s) { + SASSERT_EQ(m_dep_var.size(), m_dep_slice.size()); + unsigned const idx = m_dep_var.size(); + m_dep_var.push_back(v); + m_dep_slice.push_back(s); + return dep_t(idx); } slicing::slicing(solver& s): @@ -123,7 +113,8 @@ namespace polysat { { reg_decl_plugins(m_ast); m_bv = alloc(bv_util, m_ast); - m_egraph.set_display_justification(display_dep); + m_egraph.set_display_justification([&](std::ostream& out, void* d) { display(out, decode_dep(d)); }); + m_egraph.set_on_merge([&](enode* root, enode* other) { egraph_on_merge(root, other); }); m_egraph.set_on_propagate([&](enode* lit, enode* ante) { egraph_on_propagate(lit, ante); }); } @@ -175,6 +166,7 @@ namespace polysat { propagate(); m_scopes.push_back(m_trail.size()); m_egraph.push(); + m_dep_size_trail.push_back(m_dep_var.size()); SASSERT(m_needs_congruence.empty()); } @@ -213,6 +205,9 @@ namespace polysat { m_egraph.pop(num_scopes); m_needs_congruence.reset(); m_disequality_conflict = nullptr; + m_dep_var.shrink(m_dep_size_trail[target_lvl]); + m_dep_slice.shrink(m_dep_size_trail[target_lvl]); + m_dep_size_trail.shrink(target_lvl); // replay add_var/mk_extract/mk_concat in the same order // (only until polysat::solver supports proper garbage collection of variables) unsigned add_var_idx = replay_add_var.size(); @@ -401,10 +396,19 @@ namespace polysat { enode* target = n->get_target(); if (!target) continue; - euf::justification j = n->get_justification(); + euf::justification const j = n->get_justification(); SASSERT(j.is_external()); // cannot be a congruence since the slice wasn't split before. - m_egraph.merge(sub_hi(n), sub_hi(target), j.ext()); - m_egraph.merge(sub_lo(n), sub_lo(target), j.ext()); + void* j_hi = j.ext(); + void* j_lo = j.ext(); + dep_t d = decode_dep(j.ext()); + if (d.is_var_idx()) { + enode* ds = get_dep_slice(d); + SASSERT(ds == n || ds == target); + j_hi = encode_dep(mk_var_dep(get_dep_var(d), sub_hi(ds))); + j_lo = encode_dep(mk_var_dep(get_dep_var(d), sub_lo(ds))); + } + m_egraph.merge(sub_hi(n), sub_hi(target), j_hi); + m_egraph.merge(sub_lo(n), sub_lo(target), j_lo); } } @@ -490,49 +494,14 @@ namespace polysat { return m_bv->is_numeral(s->get_expr(), val); } - void slicing::begin_explain() { - SASSERT(m_marked_lits.empty()); - SASSERT(m_marked_vars.empty()); - } - - void slicing::end_explain() { - m_marked_lits.reset(); - m_marked_vars.reset(); - } - - void slicing::push_dep(void* dp, sat::literal_vector& out_lits, unsigned_vector& out_vars) { - dep_t d = decode_dep(dp); - if (d.is_var()) { - pvar v = d.var(); - if (m_marked_vars.contains(v)) - return; - m_marked_vars.insert(v); - out_vars.push_back(v); - } - else if (d.is_lit()) { - sat::literal lit = d.lit(); - if (m_marked_lits.contains(lit)) - return; - m_marked_lits.insert(lit); - out_lits.push_back(lit); - } - else { - SASSERT(d.is_null()); - } - } - - void slicing::explain_class(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars) { + void slicing::explain_class(enode* x, enode* y, ptr_vector& out_deps) { SASSERT_EQ(x->get_root(), y->get_root()); - SASSERT(m_tmp_justifications.empty()); m_egraph.begin_explain(); - m_egraph.explain_eq(m_tmp_justifications, nullptr, x, y); + m_egraph.explain_eq(out_deps, nullptr, x, y); m_egraph.end_explain(); - for (void* dp : m_tmp_justifications) - push_dep(dp, out_lits, out_vars); - m_tmp_justifications.reset(); } - void slicing::explain_equal(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars) { + void slicing::explain_equal(enode* x, enode* y, ptr_vector& out_deps) { SASSERT(is_equal(x, y)); enode_vector& xs = m_tmp2; enode_vector& ys = m_tmp3; @@ -550,7 +519,7 @@ namespace polysat { enode* const rx = x->get_root(); enode* const ry = y->get_root(); if (rx == ry) - explain_class(x, y, out_lits, out_vars); + explain_class(x, y, out_deps); else { xs.push_back(sub_hi(rx)); xs.push_back(sub_lo(rx)); @@ -575,16 +544,12 @@ namespace polysat { SASSERT(ys.empty()); } - void slicing::explain_equal(pvar x, pvar y, sat::literal_vector& out_lits, unsigned_vector& out_vars) { - begin_explain(); - explain_equal(var2slice(x), var2slice(y), out_lits, out_vars); - end_explain(); + void slicing::explain_equal(pvar x, pvar y, ptr_vector& out_deps) { + explain_equal(var2slice(x), var2slice(y), out_deps); } - void slicing::explain(sat::literal_vector& out_lits, unsigned_vector& out_vars) { + void slicing::explain(ptr_vector& out_deps) { SASSERT(is_conflict()); - begin_explain(); - SASSERT(m_tmp_justifications.empty()); m_egraph.begin_explain(); if (m_disequality_conflict) { enode* eqn = m_disequality_conflict; @@ -593,18 +558,14 @@ namespace polysat { SASSERT(eqn->get_lit_justification().is_external()); SASSERT(m_ast.is_eq(eqn->get_expr())); SASSERT_EQ(eqn->get_arg(0)->get_root(), eqn->get_arg(1)->get_root()); - m_egraph.explain_eq(m_tmp_justifications, nullptr, eqn->get_arg(0), eqn->get_arg(1)); - push_dep(eqn->get_lit_justification().ext(), out_lits, out_vars); + m_egraph.explain_eq(out_deps, nullptr, eqn->get_arg(0), eqn->get_arg(1)); + out_deps.push_back(eqn->get_lit_justification().ext()); } else { SASSERT(m_egraph.inconsistent()); - m_egraph.explain(m_tmp_justifications, nullptr); + m_egraph.explain(out_deps, nullptr); } m_egraph.end_explain(); - for (void* dp : m_tmp_justifications) - push_dep(dp, out_lits, out_vars); - m_tmp_justifications.reset(); - end_explain(); } clause_ref slicing::conflict_clause() { @@ -641,6 +602,10 @@ namespace polysat { SASSERT_EQ(width(s1), width(s2)); SASSERT(!has_sub(s1)); SASSERT(!has_sub(s2)); + if (dep.is_var_idx()) { + SASSERT(is_value(s2)); + dep = mk_var_dep(get_dep_var(dep), s1); + } m_egraph.merge(s1, s2, encode_dep(dep)); return !is_conflict(); } @@ -894,6 +859,7 @@ namespace polysat { } void slicing::add_constraint(signed_constraint c) { + LOG(c); SASSERT(!is_conflict()); if (!c->is_eq()) return; @@ -947,10 +913,11 @@ namespace polysat { } void slicing::add_value(pvar v, rational const& val) { + LOG("v" << v << " := " << val); SASSERT(!is_conflict()); enode* const sv = var2slice(v); enode* const sval = mk_value_slice(val, width(sv)); - (void)merge(sv, sval, v); + (void)merge(sv, sval, mk_var_dep(v, sv)); } void slicing::collect_overlaps(pvar v, var_overlap_vector& out) { diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index 320fbdbdd..ebe351308 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -38,28 +38,38 @@ namespace polysat { friend class test_slicing; + using enode = euf::enode; + using enode_vector = euf::enode_vector; + class dep_t { - std::variant m_data; + std::variant m_data; public: dep_t() { SASSERT(is_null()); } dep_t(sat::literal l): m_data(l) { SASSERT(l != sat::null_literal); SASSERT_EQ(l, lit()); } - dep_t(pvar v): m_data(v) { SASSERT(v != null_var); SASSERT_EQ(v, var()); } + explicit dep_t(unsigned vi): m_data(vi) { SASSERT_EQ(vi, var_idx()); } bool is_null() const { return std::holds_alternative(m_data); } bool is_lit() const { return std::holds_alternative(m_data); } - bool is_var() const { return std::holds_alternative(m_data); } + bool is_var_idx() const { return std::holds_alternative(m_data); } sat::literal lit() const { SASSERT(is_lit()); return *std::get_if(&m_data); } - pvar var() const { SASSERT(is_var()); return *std::get_if(&m_data); } + unsigned var_idx() const { SASSERT(is_var_idx()); return *std::get_if(&m_data); } bool operator==(dep_t other) const { return m_data == other.m_data; } bool operator!=(dep_t other) const { return !operator==(other); } - std::ostream& display(std::ostream& out); unsigned to_uint() const; static dep_t from_uint(unsigned x); }; - friend std::ostream& operator<<(std::ostream&, slicing::dep_t); + using dep_vector = svector; - using enode = euf::enode; - using enode_vector = euf::enode_vector; + std::ostream& display(std::ostream& out, dep_t d); + + dep_t mk_var_dep(pvar v, enode* s); + + pvar_vector m_dep_var; + ptr_vector m_dep_slice; + unsigned_vector m_dep_size_trail; + + pvar get_dep_var(dep_t d) const { return m_dep_var[d.var_idx()]; } + enode* get_dep_slice(dep_t d) const { return m_dep_slice[d.var_idx()]; } static constexpr unsigned null_cut = std::numeric_limits::max(); @@ -116,7 +126,6 @@ namespace polysat { static void* encode_dep(dep_t d); static dep_t decode_dep(void* d); - static void display_dep(std::ostream& out, void* d); slice_info& info(euf::enode* n); slice_info const& info(euf::enode* n) const; @@ -164,17 +173,20 @@ namespace polysat { /// If output_base is false, return coarsest intermediate slices instead of only base slices. void mk_slice(enode* src, unsigned hi, unsigned lo, enode_vector& out, bool output_full_src = false, bool output_base = true); - void begin_explain(); - void end_explain(); - void push_dep(void* dp, sat::literal_vector& out_lits, unsigned_vector& out_vars); - // Extract reason why slices x and y are in the same equivalence class - void explain_class(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars); + void explain_class(enode* x, enode* y, ptr_vector& out_deps); // Extract reason why slices x and y are equal // (i.e., x and y have the same base, but are not necessarily in the same equivalence class) - void explain_equal(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars); + void explain_equal(enode* x, enode* y, ptr_vector& out_deps); + /** Extract reason for conflict */ + void explain(ptr_vector& out_deps); + + /** Extract reason for x == y */ + void explain_equal(pvar x, pvar y, ptr_vector& out_deps); + + void egraph_on_merge(enode* root, enode* other); void egraph_on_propagate(enode* lit, enode* ante); // Merge equivalence classes of two base slices. @@ -237,9 +249,8 @@ namespace polysat { mutable enode_vector m_tmp1; mutable enode_vector m_tmp2; mutable enode_vector m_tmp3; - ptr_vector m_tmp_justifications; + ptr_vector m_tmp_deps; sat::literal_set m_marked_lits; - uint_set m_marked_vars; /** Get variable representing src[hi:lo] */ pvar mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var = null_var); @@ -284,12 +295,8 @@ namespace polysat { bool is_conflict() const { return m_disequality_conflict || m_egraph.inconsistent(); } - /** Extract reason for conflict */ - void explain(sat::literal_vector& out_lits, unsigned_vector& out_vars); /** Extract conflict clause */ - clause_ref conflict_clause(); - /** Extract reason for x == y */ - void explain_equal(pvar x, pvar y, sat::literal_vector& out_lits, unsigned_vector& out_vars); + clause_ref build_conflict_clause(); /// Example: /// - assume query_var has segments 11122233 and var has segments 2224 @@ -318,5 +325,4 @@ namespace polysat { inline std::ostream& operator<<(std::ostream& out, slicing const& s) { return s.display(out); } - inline std::ostream& operator<<(std::ostream& out, slicing::dep_t d) { return d.display(out); } }