From 947335e147db9120509bb6806ac7d5856a05859a Mon Sep 17 00:00:00 2001 From: Jakob Rath Date: Wed, 28 Jun 2023 09:59:04 +0200 Subject: [PATCH] slicing: prepare for explain() --- src/math/polysat/slicing.cpp | 103 +++++++++++++++++++++++++++-------- src/math/polysat/slicing.h | 39 ++++++++++--- src/util/sat_literal.h | 15 +++-- 3 files changed, 118 insertions(+), 39 deletions(-) diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index 40c1ea9af..d0eb214b3 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -62,6 +62,8 @@ namespace polysat { m_size.push_back(1); m_next.push_back(s); m_slice2var.push_back(null_var); + m_proof_parent.push_back(null_slice); + m_proof_reason.push_back(null_dep); m_trail.push_back(trail_item::alloc_slice); return s; } @@ -74,6 +76,8 @@ namespace polysat { m_size.pop_back(); m_next.pop_back(); m_slice2var.pop_back(); + m_proof_parent.pop_back(); + m_proof_reason.pop_back(); } slicing::slice slicing::find_sub_hi(slice parent) const { @@ -89,13 +93,13 @@ namespace polysat { void slicing::split(slice s, unsigned cut) { SASSERT(!has_sub(s)); SASSERT(width(s) - 1 >= cut + 1); - slice const sub1 = alloc_slice(); - slice const sub2 = alloc_slice(); + slice const sub_hi = alloc_slice(); + slice const sub_lo = alloc_slice(); m_slice_cut[s] = cut; - m_slice_sub[s] = sub1; - SASSERT_EQ(sub2, sub1 + 1); - m_slice_width[sub1] = width(s) - cut - 1; - m_slice_width[sub2] = cut + 1; + m_slice_sub[s] = sub_hi; + SASSERT_EQ(sub_lo, sub_hi + 1); + m_slice_width[sub_hi] = width(s) - cut - 1; + m_slice_width[sub_lo] = cut + 1; m_trail.push_back(trail_item::split_slice); m_split_trail.push_back(s); } @@ -117,16 +121,18 @@ namespace polysat { } } - void slicing::merge_base(slice s1, slice s2) { + bool slicing::merge_base(slice s1, slice s2, dep_t dep) { SASSERT_EQ(width(s1), width(s2)); SASSERT(!has_sub(s1)); SASSERT(!has_sub(s2)); slice r1 = find(s1); slice r2 = find(s2); if (r1 == r2) - return; - if (m_size[r1] > m_size[r2]) + return true; + if (m_size[r1] > m_size[r2]) { std::swap(r1, r2); + std::swap(s1, s2); + } // r2 becomes the representative of the merged class m_find[r1] = r2; m_size[r2] += m_size[r1]; @@ -137,8 +143,18 @@ namespace polysat { // otherwise the classes should have been merged already SASSERT(m_slice2var[r2] != m_slice2var[r1]); } + // Add justification 'dep' for s1 = s2 + // NOTE: invariant: root of the proof tree is the representative + SASSERT(m_proof_parent[r1] == null_slice); + SASSERT(m_proof_parent[r2] == null_slice); + make_proof_root(s1); + SASSERT(m_proof_parent[s1] == null_slice); + m_proof_parent[s1] = s2; + m_proof_reason[s1] = dep; + SASSERT(m_proof_parent[r2] == null_slice); m_trail.push_back(trail_item::merge_base); - m_merge_trail.push_back(r1); + m_merge_trail.push_back({r1, s1}); + return true; } void slicing::undo_merge_base() { @@ -151,9 +167,41 @@ namespace polysat { std::swap(m_next[r1], m_next[r2]); if (m_slice2var[r2] == m_slice2var[r1]) m_slice2var[r2] = null_var; + SASSERT(m_proof_parent[s1] == null_slice); + SASSERT(m_proof_parent[r2] == null_slice); + m_proof_parent[s1] = null_slice; + m_proof_reason[s1] = null_dep; + SASSERT(m_proof_parent[r1] == null_slice); + SASSERT(m_proof_parent[r2] == null_slice); + make_proof_root(r1); } - void slicing::merge(slice_vector& xs, slice_vector& ys) { + void slicing::make_proof_root(slice s) { + // s1 -> s2 -> s3 -> s4 + // r1 r2 r3 + // => + // s1 <- s2 <- s3 <- s4 + // r1 r2 r3 + slice curr = s; + slice prev = null_slice; + dep_t prev_reason = null_dep; + while (curr != null_slice) { + slice const curr_parent = m_proof_parent[curr]; + dep_t const curr_reason = m_proof_reason[curr]; + m_proof_parent[curr] = prev; + m_proof_reason[curr] = prev_reason; + prev = curr; + prev_reason = curr_reason; + curr = curr_parent; + } + } + + void slicing::explain(slice x, slice y, dep_vector& out_deps) { + SASSERT_EQ(find(x), find(y)); + NOT_IMPLEMENTED_YET(); + } + + bool slicing::merge(slice_vector& xs, slice_vector& ys, dep_t dep) { // LOG_H2("Merging " << xs << " with " << ys); while (!xs.empty()) { SASSERT(!ys.empty()); @@ -175,7 +223,8 @@ namespace polysat { SASSERT(!has_sub(y)); if (width(x) == width(y)) { // LOG("Match " << x << " and " << y); - merge_base(x, y); + if (!merge_base(x, y, dep)) + return false; } else if (width(x) > width(y)) { // need to split x according to y @@ -192,21 +241,27 @@ namespace polysat { } } SASSERT(ys.empty()); + return true; } - void slicing::merge(slice_vector& xs, slice y) { - slice_vector tmp; - tmp.push_back(y); - merge(xs, tmp); + bool slicing::merge(slice_vector& xs, slice y, dep_t dep) { + slice_vector& ys = m_tmp2; + SASSERT(ys.empty()); + ys.push_back(y); + return merge(xs, ys, dep); // will clear xs and ys } - void slicing::merge(slice x, slice y) { + bool slicing::merge(slice x, slice y, dep_t dep) { if (!has_sub(x) && !has_sub(y)) - return merge_base(x, y); - slice_vector tmpx, tmpy; - tmpx.push_back(x); - tmpy.push_back(y); - merge(tmpx, tmpy); + return merge_base(x, y, dep); + slice_vector& xs = m_tmp2; + slice_vector& ys = m_tmp3; + SASSERT(xs.empty()); + SASSERT(ys.empty()); + xs.push_back(x); + ys.push_back(y); + return merge(xs, ys, dep); // will clear xs and ys + } } void slicing::find_base(slice src, slice_vector& out_base) const { @@ -303,7 +358,7 @@ namespace polysat { return slice2var(s); } pvar v = m_solver.add_var(hi - lo + 1); - merge(slices, var2slice(v)); + VERIFY(merge(slices, var2slice(v), null_dep)); return v; } @@ -338,7 +393,7 @@ namespace polysat { slice_vector tmp; tmp.push_back(pdd2slice(p)); tmp.push_back(pdd2slice(q)); - merge(tmp, var2slice(v)); + VERIFY(merge(tmp, var2slice(v), null_dep)); return m_solver.var(v); } diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index e5e45c729..e5f24bc1a 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -59,6 +59,10 @@ namespace polysat { // need src -> [v] and v -> [src] for propagation? #endif + using dep_t = sat::literal; + using dep_vector = sat::literal_vector; + static constexpr sat::literal null_dep = sat::null_literal; + using slice = unsigned; using slice_vector = unsigned_vector; static constexpr slice null_slice = std::numeric_limits::max(); @@ -71,13 +75,16 @@ namespace polysat { // The cut point is relative to the parent slice (rather than a root variable, which might not be unique) // (UINT_MAX for leaf slices) unsigned_vector m_slice_cut; - // The sub-slices are at indices sub and sub+1 (null_slice if no subdivision) + // The sub-slices are at indices sub and sub+1 (or null_slice if there is no subdivision) slice_vector m_slice_sub; slice_vector m_find; // representative of equivalence class slice_vector m_size; // number of elements in equivalence class slice_vector m_next; // next element of the equivalence class - unsigned_vector m_slice2var; // slice -> pvar, or null_var if slice is not equivalent to a variable + slice_vector m_proof_parent; // the proof forest + dep_vector m_proof_reason; // justification for merge of an element with its parent (in the proof forest) + + pvar_vector m_slice2var; // slice -> pvar, or null_var if slice is not equivalent to a variable slice_vector m_var2slice; // pvar -> slice slice alloc_slice(); @@ -87,6 +94,9 @@ namespace polysat { unsigned width(slice s) const { return m_slice_width[s]; } bool has_sub(slice s) const { return m_slice_sub[s] != null_slice; } + // reverse all edges on the path from s to the root of its tree in the proof forest + void make_proof_root(slice s); + /// Split slice s into s[|s|-1:cut+1] and s[cut:0] void split(slice s, unsigned cut); /// Retrieve base slices s_1,...,s_n such that src == s_1 ++ ... ++ s_n @@ -103,17 +113,24 @@ namespace polysat { /// Find representative of lower subslice slice find_sub_lo(slice s) const; - // Merge equivalence classes of two base slices - void merge_base(slice s1, slice s2); + // Merge equivalence classes of two base slices. + // Returns true if merge succeeded without conflict. + [[nodiscard]] bool merge_base(slice s1, slice s2, dep_t dep); + + void explain(slice x, slice y, dep_vector& out_deps); // Merge equality x_1 ++ ... ++ x_n == y_1 ++ ... ++ y_k // // Precondition: // - sequence of slices with equal total width // - ordered from msb to lsb - void merge(slice_vector& xs, slice_vector& ys); - void merge(slice_vector& xs, slice y); - void merge(slice x, slice y); + // + // The argument vectors will be cleared. + // + // Returns true if merge succeeded without conflict. + [[nodiscard]] bool merge(slice_vector& xs, slice_vector& ys, dep_t dep); + [[nodiscard]] bool merge(slice_vector& xs, slice y, dep_t dep); + [[nodiscard]] bool merge(slice x, slice y, dep_t dep); enum class trail_item { @@ -124,7 +141,7 @@ namespace polysat { }; svector m_trail; slice_vector m_split_trail; - slice_vector m_merge_trail; + svector> m_merge_trail; // pair of (representative, element) unsigned_vector m_scopes; void undo_add_var(); @@ -134,13 +151,17 @@ namespace polysat { mutable slice_vector m_tmp1; + mutable slice_vector m_tmp2; + mutable slice_vector m_tmp3; - // get slice equivalent to the given pdd (may introduce new variable) + // get a slice that is equivalent to the given pdd (may introduce new variable) slice pdd2slice(pdd const& p); /** Get variable representing src[hi:lo] */ pvar mk_slice_extract(slice src, unsigned hi, unsigned lo); + bool invariant() const; + public: slicing(solver& s): m_solver(s) {} diff --git a/src/util/sat_literal.h b/src/util/sat_literal.h index aeb23bddd..eb4c16705 100644 --- a/src/util/sat_literal.h +++ b/src/util/sat_literal.h @@ -30,7 +30,7 @@ namespace sat { typedef svector bool_var_vector; - const bool_var null_bool_var = UINT_MAX >> 1; + inline constexpr bool_var null_bool_var = UINT_MAX >> 1; /** \brief The literal b is represented by the value 2*b, and @@ -39,8 +39,11 @@ namespace sat { class literal { unsigned m_val; public: - literal():m_val(null_bool_var << 1) { - SASSERT(var() == null_bool_var && !sign()); + constexpr literal(): m_val(null_bool_var << 1) { +#ifdef Z3DEBUG + assert(var() == null_bool_var); + assert(!sign()); +#endif } explicit literal(bool_var v, bool _sign = false): @@ -49,11 +52,11 @@ namespace sat { SASSERT(sign() == _sign); } - bool_var var() const { + constexpr bool_var var() const { return m_val >> 1; } - bool sign() const { + constexpr bool sign() const { return m_val & 1ul; } @@ -86,7 +89,7 @@ namespace sat { friend bool operator!=(literal const & l1, literal const & l2); }; - const literal null_literal; + inline constexpr literal null_literal; using literal_hash = obj_hash; inline literal to_literal(unsigned x) { literal l; l.m_val = x; return l; }