diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index b462e2ac1..bf247d1f9 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -82,14 +82,22 @@ namespace polysat { m_mark.pop_back(); } - slicing::slice slicing::find_sub_hi(slice parent) const { + slicing::slice slicing::sub_hi(slice parent) const { SASSERT(has_sub(parent)); - return find(m_slice_sub[parent]); + return m_slice_sub[parent]; + } + + slicing::slice slicing::sub_lo(slice parent) const { + SASSERT(has_sub(parent)); + return m_slice_sub[parent] + 1; + } + + slicing::slice slicing::find_sub_hi(slice parent) const { + return find(sub_hi(parent)); } slicing::slice slicing::find_sub_lo(slice parent) const { - SASSERT(has_sub(parent)); - return find(m_slice_sub[parent] + 1); + return find(sub_lo(parent)); } void slicing::split(slice s, unsigned cut) { @@ -198,13 +206,18 @@ namespace polysat { } } - void slicing::explain_base(slice x, slice y, dep_vector& out_deps) { - SASSERT(!has_sub(x)); - SASSERT(!has_sub(y)); + void slicing::push_reason(slice s, dep_vector& out_deps) { + dep_t reason = m_proof_reason[s]; + if (reason == null_dep) + return; + out_deps.push_back(reason); + } + + void slicing::explain_class(slice x, slice y, dep_vector& out_deps) { + SASSERT_EQ(find(x), find(y)); // /-> ... // x -> x1 -> x2 -> lca <- y1 <- y // r0 r1 r2 r4 r3 - SASSERT_EQ(find(x), find(y)); begin_mark(); // mark ancestors of x in the proof forest slice s = x; @@ -216,20 +229,64 @@ namespace polysat { // and collect deps from y to lca slice lca = y; while (!is_marked(lca)) { - out_deps.push_back(m_proof_reason[lca]); + push_reason(lca, out_deps); lca = m_proof_parent[lca]; SASSERT(lca != null_slice); } // collect deps from x to lca s = x; while (s != lca) { - out_deps.push_back(m_proof_reason[s]); + push_reason(s, out_deps); s = m_proof_parent[s]; SASSERT(s != null_slice); } end_mark(); } + void slicing::explain_equal(slice x, slice y, dep_vector& out_deps) { + // TODO: we currently get duplicates in out_deps (if parents are merged, the subslices are all merged due to the same reason) + SASSERT(is_equal(x, y)); + slice_vector& xs = m_tmp2; + slice_vector& ys = m_tmp3; + SASSERT(xs.empty()); + SASSERT(ys.empty()); + xs.push_back(x); + ys.push_back(y); + while (!xs.empty()) { + SASSERT(!ys.empty()); + slice const x = xs.back(); xs.pop_back(); + slice const y = ys.back(); ys.pop_back(); + if (x == y) + continue; + if (width(x) == width(y)) { + slice const rx = find(x); + slice const ry = find(y); + if (rx == ry) + explain_class(x, y, out_deps); + else { + xs.push_back(sub_hi(rx)); + xs.push_back(sub_lo(rx)); + ys.push_back(sub_hi(ry)); + ys.push_back(sub_lo(ry)); + } + } + else if (width(x) > width(y)) { + slice const rx = find(x); + xs.push_back(sub_hi(rx)); + xs.push_back(sub_lo(rx)); + ys.push_back(y); + } + else { + SASSERT(width(x) < width(y)); + xs.push_back(x); + slice const ry = find(y); + ys.push_back(sub_hi(ry)); + ys.push_back(sub_lo(ry)); + } + } + SASSERT(ys.empty()); + } + bool slicing::merge(slice_vector& xs, slice_vector& ys, dep_t dep) { // LOG_H2("Merging " << xs << " with " << ys); while (!xs.empty()) { diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index 3c426d998..b8b96559f 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -99,7 +99,7 @@ namespace polysat { if (!m_mark_timestamp) m_mark_timestamp++; } - void end_mark() { DEBUG_CODE({ SASSERT(!m_mark_active); m_mark_active = false; }); } + void end_mark() { DEBUG_CODE({ SASSERT(m_mark_active); m_mark_active = false; }); } bool is_marked(slice s) const { SASSERT(m_mark_active); return m_mark[s] == m_mark_timestamp; } void mark(slice s) { SASSERT(m_mark_active); m_mark[s] = m_mark_timestamp; } @@ -122,6 +122,11 @@ namespace polysat { /// If output_base is false, return coarsest intermediate slices instead of only base slices. void mk_slice(slice src, unsigned hi, unsigned lo, slice_vector& out, bool output_full_src = false, bool output_base = true); + /// Upper subslice (direct child, not necessarily the representative) + slice sub_hi(slice s) const; + /// Lower subslice (direct child, not necessarily the representative) + slice sub_lo(slice s) const; + /// Find representative slice find(slice s) const; /// Find representative of upper subslice @@ -133,8 +138,14 @@ namespace polysat { // Returns true if merge succeeded without conflict. [[nodiscard]] bool merge_base(slice s1, slice s2, dep_t dep); - // Extract reason for equality of base slices - void explain_base(slice x, slice y, dep_vector& out_deps); + void push_reason(slice s, dep_vector& out_deps); + + // Extract reason why slices x and y are in the same equivalence class + void explain_class(slice x, slice y, dep_vector& out_deps); + + // Extract reason why slices x and y are equal + // (i.e., x and y have the same base, but are not necessarily in the same equivalence class) + void explain_equal(slice x, slice y, dep_vector& out_deps); // Merge equality x_1 ++ ... ++ x_n == y_1 ++ ... ++ y_k // diff --git a/src/test/slicing.cpp b/src/test/slicing.cpp index 93988ed8c..2a84bf5a9 100644 --- a/src/test/slicing.cpp +++ b/src/test/slicing.cpp @@ -98,17 +98,24 @@ namespace polysat { pdd d = sl.mk_concat(sl.mk_extract(x, 5, 4), sl.mk_extract(y, 3, 0)); std::cout << d << " := v" << x << "[5:4] ++ v" << y << "[3:0]\n" << sl << "\n"; - VERIFY(sl.merge(sl.var2slice(x), sl.var2slice(y), sat::literal(1))); + VERIFY(sl.merge(sl.var2slice(x), sl.var2slice(y), sat::literal(123))); std::cout << "v" << x << " = v" << y << "\n" << sl << "\n"; std::cout << "v" << b << " = v" << c << "? " << sl.is_equal(sl.var2slice(b), sl.var2slice(c)) << " find(v" << b << ") = " << sl.find(sl.var2slice(b)) << " find(v" << c << ") = " << sl.find(sl.var2slice(c)) << "\n"; + sat::literal_vector reason; + sl.explain_equal(sl.var2slice(b), sl.var2slice(c), reason); + std::cout << " Reason: " << reason << "\n"; + std::cout << "v" << b << " = " << d << "? " << sl.is_equal(sl.var2slice(b), sl.pdd2slice(d)) << " find(v" << b << ") = " << sl.find(sl.var2slice(b)) << " find(" << d << ") = " << sl.find(sl.pdd2slice(d)) << "\n"; + reason.reset(); + sl.explain_equal(sl.var2slice(b), sl.pdd2slice(d), reason); + std::cout << " Reason: " << reason << "\n"; } };