diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index 7fc6478b4..40c1ea9af 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -33,7 +33,7 @@ namespace polysat { case trail_item::add_var: undo_add_var(); break; case trail_item::alloc_slice: undo_alloc_slice(); break; case trail_item::split_slice: undo_split_slice(); break; - case trail_item::merge_class: undo_merge_class(); break; + case trail_item::merge_base: undo_merge_base(); break; default: UNREACHABLE(); } m_trail.pop_back(); @@ -117,7 +117,7 @@ namespace polysat { } } - void slicing::merge(slice s1, slice s2) { + void slicing::merge_base(slice s1, slice s2) { SASSERT_EQ(width(s1), width(s2)); SASSERT(!has_sub(s1)); SASSERT(!has_sub(s2)); @@ -137,11 +137,11 @@ namespace polysat { // otherwise the classes should have been merged already SASSERT(m_slice2var[r2] != m_slice2var[r1]); } - m_trail.push_back(trail_item::merge_class); + m_trail.push_back(trail_item::merge_base); m_merge_trail.push_back(r1); } - void slicing::undo_merge_class() { + void slicing::undo_merge_base() { slice r1 = m_merge_trail.back(); m_merge_trail.pop_back(); slice r2 = m_find[r1]; @@ -161,11 +161,21 @@ namespace polysat { slice y = ys.back(); xs.pop_back(); ys.pop_back(); + if (has_sub(x)) { + find_base(x, xs); + x = xs.back(); + xs.pop_back(); + } + if (has_sub(y)) { + find_base(y, ys); + y = ys.back(); + ys.pop_back(); + } SASSERT(!has_sub(x)); SASSERT(!has_sub(y)); if (width(x) == width(y)) { // LOG("Match " << x << " and " << y); - merge(x, y); + merge_base(x, y); } else if (width(x) > width(y)) { // need to split x according to y @@ -190,6 +200,15 @@ namespace polysat { merge(xs, tmp); } + void slicing::merge(slice x, slice y) { + if (!has_sub(x) && !has_sub(y)) + return merge_base(x, y); + slice_vector tmpx, tmpy; + tmpx.push_back(x); + tmpy.push_back(y); + merge(tmpx, tmpy); + } + void slicing::find_base(slice src, slice_vector& out_base) const { // splits are only stored for the representative SASSERT_EQ(src, find(src)); @@ -213,12 +232,18 @@ namespace polysat { SASSERT(todo.empty()); } - void slicing::mk_slice(slice src, unsigned const hi, unsigned const lo, slice_vector& out_base, bool output_full_src) { + void slicing::mk_slice(slice src, unsigned const hi, unsigned const lo, slice_vector& out, bool output_full_src, bool output_base) { SASSERT(hi >= lo); SASSERT_EQ(src, find(src)); // splits are only stored for the representative SASSERT(width(src) > hi); // extracted range must be fully contained inside the src slice + auto output_slice = [this, output_base, &out](slice s) { + if (output_base) + find_base(s, out); + else + out.push_back(s); + }; if (lo == 0 && width(src) - 1 == hi) { - find_base(src, out_base); + output_slice(src); return; } if (has_sub(src)) { @@ -226,23 +251,23 @@ namespace polysat { unsigned const cut = m_slice_cut[src]; if (lo >= cut + 1) { // target slice falls into upper subslice - mk_slice(find_sub_hi(src), hi - cut - 1, lo - cut - 1, out_base); + mk_slice(find_sub_hi(src), hi - cut - 1, lo - cut - 1, out, output_full_src, output_base); if (output_full_src) - out_base.push_back(find_sub_lo(src)); + output_slice(find_sub_lo(src)); return; } else if (cut >= hi) { // target slice falls into lower subslice if (output_full_src) - out_base.push_back(find_sub_hi(src)); - mk_slice(find_sub_lo(src), hi, lo, out_base); + output_slice(find_sub_hi(src)); + mk_slice(find_sub_lo(src), hi, lo, out, output_full_src, output_base); return; } else { SASSERT(hi > cut && cut >= lo); // desired range spans over the cutpoint, so we get multiple slices in the result - mk_slice(find_sub_hi(src), hi - cut - 1, 0, out_base); - mk_slice(find_sub_lo(src), cut, lo, out_base); + mk_slice(find_sub_hi(src), hi - cut - 1, 0, out, output_full_src, output_base); + mk_slice(find_sub_lo(src), cut, lo, out, output_full_src, output_base); return; } } @@ -250,41 +275,42 @@ namespace polysat { // [src.width-1, 0] has no subdivision yet if (width(src) - 1 > hi) { split(src, hi); + SASSERT(!has_sub(find_sub_hi(src))); if (output_full_src) - out_base.push_back(find_sub_hi(src)); - mk_slice(find_sub_lo(src), hi, lo, out_base); // recursive call to take care of case lo > 0 + out.push_back(find_sub_hi(src)); + mk_slice(find_sub_lo(src), hi, lo, out, output_full_src, output_base); // recursive call to take care of case lo > 0 return; } else { SASSERT(lo > 0); split(src, lo - 1); - out_base.push_back(find_sub_hi(src)); + out.push_back(find_sub_hi(src)); + SASSERT(!has_sub(find_sub_lo(src))); if (output_full_src) - out_base.push_back(find_sub_lo(src)); + out.push_back(find_sub_lo(src)); return; } } UNREACHABLE(); } - pvar slicing::mk_extract_var(pvar src, unsigned hi, unsigned lo) { + pvar slicing::mk_slice_extract(slice src, unsigned hi, unsigned lo) { slice_vector slices; - mk_slice(var2slice(src), hi, lo, slices); - // src[hi:lo] is the concatenation of the returned slices - // TODO: for each slice, set_extract - -#if 0 - extract_key key{src, hi, lo}; - auto it = m_extracted.find_iterator(key); - if (it != m_extracted.end()) - return it->m_value; - pvar v = s.add_var(hi - lo); - set_extract(v, src, hi, lo); + mk_slice(src, hi, lo, slices, false, true); + if (slices.size() == 1) { + slice s = slices[0]; + if (slice2var(s) != null_var) + return slice2var(s); + } + pvar v = m_solver.add_var(hi - lo + 1); + merge(slices, var2slice(v)); return v; -#endif } -#if 0 + pvar slicing::mk_extract_var(pvar src, unsigned hi, unsigned lo) { + return mk_slice_extract(var2slice(src), hi, lo); + } + pdd slicing::mk_extract(pvar src, unsigned hi, unsigned lo) { return m_solver.var(mk_extract_var(src, hi, lo)); } @@ -293,55 +319,27 @@ namespace polysat { if (!lo) { // TODO: we could push the extract down into variables of the term instead of introducing a name. } + return m_solver.var(mk_slice_extract(pdd2slice(p), hi, lo)); + } + + slicing::slice slicing::pdd2slice(pdd const& p) { pvar const v = m_solver.m_names.mk_name(p); - return mk_extract(v, hi, lo); + return var2slice(v); } pdd slicing::mk_concat(pdd const& p, pdd const& q) { -#if 0 // v := p ++ q (new variable of size |p| + |q|) // v[:|q|] = p // v[|q|:] = q unsigned const p_sz = p.power_of_2(); unsigned const q_sz = q.power_of_2(); unsigned const v_sz = p_sz + q_sz; - // TODO: lookup to see if we can reuse a variable - // either: - // - table of concats - // - check for variable with v[:|q|] = p and v[|q|:] = q in extract table (probably better) - pvar const v = s.add_var(v_sz); - - // TODO: probably wrong to use names for p, q. - // we should rather check if there's already an extraction for v[...] and reuse that variable. - pvar const p_name = s.m_names.mk_name(p); - pvar const q_name = s.m_names.mk_name(q); - set_extract(p_name, v, v_sz, q_sz); - set_extract(q_name, v, q_sz, 0); -#endif - NOT_IMPLEMENTED_YET(); - } -#endif - - void slicing::set_extract(pvar v, pvar src, unsigned hi, unsigned lo) { -#if 0 - SASSERT(!is_extract(v)); - SASSERT(lo < hi && hi <= s.size(src)); - SASSERT_EQ(hi - lo + 1, s.size(v)); - SASSERT(src < v); - SASSERT(!m_extracted.contains(extract_key{src, hi, lo})); -#if 0 // try without this first - if (is_extract(src)) { - // y = (x[k:m])[h:l] = x[h+m:l+m] - unsigned const offset = m_lo[src]; - set_extract(m_src[src], hi + offset, lo + offset); - return; - } -#endif - m_extracted.insert({src, hi, lo}, v); - m_src[v] = src; - m_hi[v] = hi; - m_lo[v] = lo; -#endif + pvar const v = m_solver.add_var(v_sz); + slice_vector tmp; + tmp.push_back(pdd2slice(p)); + tmp.push_back(pdd2slice(q)); + merge(tmp, var2slice(v)); + return m_solver.var(v); } void slicing::propagate(pvar v) { @@ -353,8 +351,14 @@ namespace polysat { out << "v" << v << ":"; base.reset(); find_base(var2slice(v), base); - for (slice s : base) + // unsigned hi = width(var2slice(v)) - 1; + for (slice s : base) { + // unsigned w = width(s); + // unsigned lo = hi - w + 1; + // out << " s" << s << "_[" << hi << ":" << lo << "]"; + // hi -= w; display(out << " ", s); + } out << "\n"; } return out; diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index b81ee9579..e5e45c729 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -25,6 +25,7 @@ Notation: --*/ #pragma once #include "math/polysat/types.h" +#include namespace polysat { @@ -34,19 +35,9 @@ namespace polysat { friend class test_slicing; - // solver& m_solver; + solver& m_solver; #if 0 - /// If y := x[h:l], then m_src[y] = x, m_hi[y] = h, m_lo[y] = l. - /// Otherwise m_src[y] = null_var. - /// - /// Invariants: - /// m_src[y] != null_var ==> m_src[y] < y (at least as long as we always introduce new variables for extract terms.) - /// m_lo[y] <= m_hi[y] - unsigned_vector m_src; - unsigned_vector m_hi; - unsigned_vector m_lo; - struct extract_key { pvar src; unsigned hi; @@ -98,11 +89,12 @@ namespace polysat { /// Split slice s into s[|s|-1:cut+1] and s[cut:0] void split(slice s, unsigned cut); - /// Retrieve base slices s_1,...,s_n such that src == s_1 ++ ... + s_n + /// Retrieve base slices s_1,...,s_n such that src == s_1 ++ ... ++ s_n void find_base(slice src, slice_vector& out_base) const; - // Retrieve (or create) base slices s_1,...,s_n such that src[hi:lo] == s_1 ++ ... ++ s_n - // If output_full_src is true, returns the new base for src, i.e., src == s_1 ++ ... ++ s_n - void mk_slice(slice src, unsigned hi, unsigned lo, slice_vector& out_base, bool output_full_src = false); + /// Retrieve (or create) base slices s_1,...,s_n such that src[hi:lo] == s_1 ++ ... ++ s_n. + /// If output_full_src is true, return the new base for src, i.e., src == s_1 ++ ... ++ s_n. + /// If output_base is false, return coarsest intermediate slices instead of only base slices. + void mk_slice(slice src, unsigned hi, unsigned lo, slice_vector& out, bool output_full_src = false, bool output_base = true); /// Find representative slice find(slice s) const; @@ -112,67 +104,62 @@ namespace polysat { slice find_sub_lo(slice s) const; // Merge equivalence classes of two base slices - void merge(slice s1, slice s2); + void merge_base(slice s1, slice s2); // Merge equality x_1 ++ ... ++ x_n == y_1 ++ ... ++ y_k // // Precondition: - // - sequence of base slices (equal total width) + // - sequence of slices with equal total width // - ordered from msb to lsb void merge(slice_vector& xs, slice_vector& ys); void merge(slice_vector& xs, slice y); - - void set_extract(pvar v, pvar src, unsigned hi_bit, unsigned lo_bit); + void merge(slice x, slice y); enum class trail_item { add_var, alloc_slice, split_slice, - merge_class, + merge_base, }; svector m_trail; - slice_vector m_split_trail; - slice_vector m_merge_trail; + slice_vector m_split_trail; + slice_vector m_merge_trail; unsigned_vector m_scopes; void undo_add_var(); void undo_alloc_slice(); void undo_split_slice(); - void undo_merge_class(); + void undo_merge_base(); mutable slice_vector m_tmp1; + // get slice equivalent to the given pdd (may introduce new variable) + slice pdd2slice(pdd const& p); + + /** Get variable representing src[hi:lo] */ + pvar mk_slice_extract(slice src, unsigned hi, unsigned lo); public: - // slicing(solver& s): m_solver(s) {} + slicing(solver& s): m_solver(s) {} void push_scope(); void pop_scope(unsigned num_scopes = 1); void add_var(unsigned bit_width); - - - - - - - - // bool is_extract(pvar v) const { return m_src[v] != null_var; } - /** Get variable representing x[hi:lo] */ pvar mk_extract_var(pvar x, unsigned hi, unsigned lo); - // /** Create expression for x[hi:lo] */ - // pdd mk_extract(pvar x, unsigned hi, unsigned lo); + /** Create expression for x[hi:lo] */ + pdd mk_extract(pvar x, unsigned hi, unsigned lo); - // /** Create expression for p[hi:lo] */ - // pdd mk_extract(pdd const& p, unsigned hi, unsigned lo); + /** Create expression for p[hi:lo] */ + pdd mk_extract(pdd const& p, unsigned hi, unsigned lo); - // /** Create expression for p ++ q */ - // pdd mk_concat(pdd const& p, pdd const& q); + /** Create expression for p ++ q */ + pdd mk_concat(pdd const& p, pdd const& q); // propagate: // - value assignments diff --git a/src/math/polysat/solver.cpp b/src/math/polysat/solver.cpp index 1e14a7b8f..7f2895bd3 100644 --- a/src/math/polysat/solver.cpp +++ b/src/math/polysat/solver.cpp @@ -42,7 +42,7 @@ namespace polysat { m_free_pvars(m_activity), m_constraints(*this), m_names(*this), - // m_slicing(*this), + m_slicing(*this), m_search(*this) { } diff --git a/src/test/slicing.cpp b/src/test/slicing.cpp index ee5d3c3ca..aa8d0d6ce 100644 --- a/src/test/slicing.cpp +++ b/src/test/slicing.cpp @@ -60,29 +60,45 @@ namespace polysat { slicing& sl = s.sl(); pvar x = s.add_var(8); pvar y = s.add_var(8); - pvar a = s.add_var(5); - pvar b = s.add_var(6); - slicing::slice_vector x_7_3; - sl.mk_slice(sl.var2slice(x), 7, 3, x_7_3); - slicing::slice_vector a_4_0; - sl.mk_slice(sl.var2slice(a), 4, 0, a_4_0); - sl.merge(x_7_3, a_4_0); + pvar a = sl.mk_extract_var(x, 7, 3); std::cout << sl << "\n"; - slicing::slice_vector x_base; - sl.find_base(sl.var2slice(x), x_base); - slicing::slice_vector y_base; - sl.find_base(sl.var2slice(y), y_base); - sl.merge(x_base, y_base); + sl.merge(sl.var2slice(x), sl.var2slice(y)); std::cout << sl << "\n"; - slicing::slice_vector y_5_0; - sl.mk_slice(sl.var2slice(y), 5, 0, y_5_0); - sl.merge(y_5_0, sl.var2slice(b)); + pvar b = sl.mk_extract_var(y, 5, 0); std::cout << sl << "\n"; } + // x[7:3] = a + // y[5:0] = b + // x[5:0] = c + // x[5:4] ++ y[3:0] = d + // x = y + // + // How easily can we find b=c and b=d? + static void test3() { + std::cout << __func__ << "\n"; + scoped_solver_slicing s; + slicing& sl = s.sl(); + pvar x = s.add_var(8); + pvar y = s.add_var(8); + std::cout << sl << "\n"; + + pvar a = sl.mk_extract_var(x, 7, 3); + std::cout << "v" << a << " := v" << x << "[7:3]\n" << sl << "\n"; + pvar b = sl.mk_extract_var(y, 5, 0); + std::cout << "v" << b << " := v" << y << "[5:0]\n" << sl << "\n"; + pvar c = sl.mk_extract_var(x, 5, 0); + std::cout << "v" << c << " := v" << x << "[5:0]\n" << sl << "\n"; + pdd d = sl.mk_concat(sl.mk_extract(x, 5, 4), sl.mk_extract(y, 3, 0)); + std::cout << d << " := v" << x << "[5:4] ++ v" << y << "[3:0]\n" << sl << "\n"; + + sl.merge(sl.var2slice(x), sl.var2slice(y)); + std::cout << "v" << x << " = v" << y << "\n" << sl << "\n"; + } + }; } @@ -92,5 +108,6 @@ void tst_slicing() { using namespace polysat; test_slicing::test1(); test_slicing::test2(); + test_slicing::test3(); std::cout << "ok\n"; }