diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index 1e3ab7135..674804145 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -17,132 +17,233 @@ Author: namespace polysat { - void slicing::push_var() { - m_stack.push_scope(); // TODO: we don't need a scope for each variable -#if 0 - m_src.push_back(null_var); - m_hi.push_back(0); - m_lo.push_back(0); -#endif - pvar const v = m_var_slices.size(); - slice_idx const s = alloc_slice(); - m_var_slices.push_back(s); + void slicing::push_scope() { + m_scopes.push_back(m_trail.size()); } - - void slicing::pop_var() { -#if 0 - if (m_src != null_var) { - extract_key key{m_src.back(), m_hi.back(), m_lo.back()}; - m_extracted.remove(key); + + void slicing::pop_scope(unsigned num_scopes) { + if (num_scopes == 0) + return; + unsigned const lvl = m_scopes.size(); + SASSERT(num_scopes <= lvl); + unsigned const target_lvl = lvl - num_scopes; + unsigned const target_size = m_scopes[target_lvl]; + while (m_trail.size() > target_size) { + switch (m_trail.back()) { + case trail_item::add_var: undo_add_var(); break; + case trail_item::alloc_slice: undo_alloc_slice(); break; + case trail_item::split_slice: undo_split_slice(); break; + case trail_item::merge_class: undo_merge_class(); break; + default: UNREACHABLE(); + } + m_trail.pop_back(); } - m_src.pop_back(); - m_hi.pop_back(); - m_lo.pop_back(); -#endif - m_var_slices.pop_back(); - m_stack.pop_scope(1); + m_scopes.shrink(target_lvl); + } + + void slicing::add_var(unsigned bit_width) { + pvar const v = m_var2slice.size(); + slice_idx const s = alloc_slice(); + m_slice_width[s] = bit_width; + m_var2slice.push_back(s); + } + + void slicing::undo_add_var() { + m_var2slice.pop_back(); } slicing::slice_idx slicing::alloc_slice() { - slice_idx const s = m_slices_uf.mk_var(); - SASSERT_EQ(s, m_slices.size()); - m_slices.push_back({}); - m_stack.push_ptr(&m_alloc_slice_trail); + slice_idx const s = m_slice_cut.size(); + m_slice_width.push_back(0); + m_slice_cut.push_back(null_cut); + m_slice_sub.push_back(null_slice_idx); + m_find.push_back(s); + m_size.push_back(1); + m_next.push_back(s); + m_trail.push_back(trail_item::alloc_slice); return s; } - void slicing::alloc_slice_trail::undo() { - m_owner.m_slices.pop_back(); + void slicing::undo_alloc_slice() { + m_slice_width.pop_back(); + m_slice_cut.pop_back(); + m_slice_sub.pop_back(); + m_find.pop_back(); + m_size.pop_back(); + m_next.pop_back(); } - void slicing::set_extract(pvar v, pvar src, unsigned hi, unsigned lo) { -#if 0 - SASSERT(!is_extract(v)); - SASSERT(lo < hi && hi <= s.size(src)); - SASSERT_EQ(hi - lo, s.size(v)); - SASSERT(src < v); - SASSERT(!m_extracted.contains(extract_key{src, hi, lo})); -#if 0 // try without this first - if (is_extract(src)) { - // y = (x[k:m])[h:l] = x[h+m:l+m] - unsigned const offset = m_lo[src]; - set_extract(m_src[src], hi + offset, lo + offset); - return; - } -#endif - m_extracted.insert({src, hi, lo}, v); - m_src[v] = src; - m_hi[v] = hi; - m_lo[v] = lo; -#endif + slicing::slice slicing::var2slice(pvar v) const { + slice_idx const idx = find(m_var2slice[v]); + slice s; + s.idx = idx; + s.hi = m_slice_width[idx] - 1; + // s.hi = m_solver.size(v) - 1; + s.lo = 0; + return s; } - slicing::slice_info slicing::var2slice(pvar v) const { - slice_info si; - si.idx = m_var_slices[v]; - si.hi = s.size(v) - 1; - si.lo = 0; - return si; - } - - slicing::slice_info slicing::sub_hi(slice_info const& parent) const { + slicing::slice slicing::sub_hi(slice const& parent) const { SASSERT(has_sub(parent)); - slice const& parent_slice = m_slices[parent.idx]; - slice_info si; - si.idx = parent_slice.sub; - si.hi = parent.hi; - si.lo = parent_slice.cut + 1; - SASSERT(si.hi >= si.lo); - return si; + SASSERT(parent.hi >= parent.lo); + slice s; + s.idx = find(m_slice_sub[parent.idx]); + // |parent|-1 ... cut+1 and cut ............ 0 + // hi ........... lo+cut+1 lo+cut ........ lo + s.hi = parent.hi; + s.lo = parent.lo + m_slice_cut[parent.idx] + 1; + SASSERT(s.hi >= s.lo); + SASSERT_EQ(m_slice_width[s.idx], s.hi - s.lo + 1); + return s; } - slicing::slice_info slicing::sub_lo(slice_info const& parent) const { + slicing::slice slicing::sub_lo(slice const& parent) const { SASSERT(has_sub(parent)); - slice const& parent_slice = m_slices[parent.idx]; - slice_info si; - si.idx = parent_slice.sub + 1; - si.hi = parent_slice.cut; - si.lo = parent.lo; - SASSERT(si.hi >= si.lo); - return si; + slice s; + s.idx = find(m_slice_sub[parent.idx] + 1); + // |parent|-1 ... cut+1 and cut ............ 0 + // hi ........... lo+cut+1 lo+cut ........ lo + s.hi = parent.lo + m_slice_cut[parent.idx]; + s.lo = parent.lo; + SASSERT(s.hi >= s.lo); + SASSERT_EQ(m_slice_width[s.idx], s.hi - s.lo + 1); + return s; } - unsigned slicing::get_cut(slice_info const& si) const { - SASSERT(has_sub(si)); - return m_slices[si.idx].cut; - } - - void slicing::split(slice_info const& si, unsigned const cut) { - SASSERT(!has_sub(si)); - SASSERT(si.hi > cut); SASSERT(cut >= si.lo); + void slicing::split(slice_idx s, unsigned cut) { + SASSERT(!has_sub(s)); slice_idx const sub1 = alloc_slice(); slice_idx const sub2 = alloc_slice(); - slice& s = m_slices[si.idx]; - s.cut = cut; - s.sub = sub1; + m_slice_cut[s] = cut; + m_slice_sub[s] = sub1; SASSERT_EQ(sub2, sub1 + 1); + m_slice_width[sub1] = m_slice_width[s] - cut - 1; + m_slice_width[sub2] = cut + 1; + + m_trail.push_back(trail_item::split_slice); + m_split_trail.push_back(s); } - void slicing::mk_slice(slice_info const& src, unsigned const hi, unsigned const lo, vector& out) - { + void slicing::split(slice const& s, unsigned const cut) { + SASSERT(s.hi > cut); SASSERT(cut >= s.lo); + split(s.idx, cut - s.lo); + } + + void slicing::undo_split_slice() { + slice_idx i = m_split_trail.back(); + m_split_trail.pop_back(); + m_slice_cut[i] = null_cut; + m_slice_sub[i] = null_slice_idx; + } + + slicing::slice_idx slicing::find(slice_idx i) const { + while (true) { + SASSERT(i < m_find.size()); + slice_idx const new_i = m_find[i]; + if (new_i == i) + return i; + i = new_i; + } + } + + void slicing::merge(slice_idx s1, slice_idx s2) { + SASSERT(!has_sub(s1)); + SASSERT(!has_sub(s2)); + slice_idx r1 = find(s1); + slice_idx r2 = find(s2); + if (r1 == r2) + return; + if (m_size[r1] > m_size[r2]) + std::swap(r1, r2); + // r2 becomes the representative of the merged class + m_find[r1] = r2; + m_size[r2] += m_size[r1]; + std::swap(m_next[r1], m_next[r2]); + m_trail.push_back(trail_item::merge_class); + m_merge_trail.push_back(r1); + } + + void slicing::undo_merge_class() { + slice_idx r1 = m_merge_trail.back(); + m_merge_trail.pop_back(); + slice_idx r2 = m_find[r1]; + SASSERT(find(r2) == r2); + m_find[r1] = r1; + m_size[r2] -= m_size[r1]; + std::swap(m_next[r1], m_next[r2]); + } + + void slicing::merge(slice_vector& xs, slice_vector& ys) { + while (!xs.empty()) { + SASSERT(!ys.empty()); + slice x = xs.back(); + slice y = ys.back(); + xs.pop_back(); + ys.pop_back(); + SASSERT_EQ(x.lo, y.lo); + SASSERT(!has_sub(x)); + SASSERT(!has_sub(y)); + if (x.hi == y.hi) { + merge(x.idx, y.idx); + } + else if (x.hi > y.hi) { + // need to split x according to y + mk_slice(x, y.hi, y.lo, xs); + ys.push_back(y); + } + else { + SASSERT(y.hi > x.hi); + // need to split y according to x + mk_slice(y, x.hi, x.lo, ys); + xs.push_back(x); + } + } + } + + void slicing::find_base(slice src, slice_vector& out_base) const { + // splits are only stored for the representative + SASSERT_EQ(src.idx, find(src.idx)); + if (!has_sub(src)) { + out_base.push_back(src); + return; + } + slice_vector& todo = m_tmp1; + SASSERT(todo.empty()); + todo.push_back(src); + while (!todo.empty()) { + slice s = todo.back(); + todo.pop_back(); + if (!has_sub(s)) + out_base.push_back(s); + else { + todo.push_back(sub_lo(s)); + todo.push_back(sub_hi(s)); + } + } + SASSERT(todo.empty()); + } + + void slicing::mk_slice(slice src, unsigned const hi, unsigned const lo, slice_vector& out_base) { + // splits are only stored for the representative + SASSERT_EQ(src.idx, find(src.idx)); // extracted range must be fully contained inside the src slice SASSERT(src.hi >= hi); SASSERT(hi >= lo); SASSERT(lo >= src.lo); if (src.hi == hi && src.lo == lo) { - out.push_back(src); + find_base(src, out_base); return; } if (has_sub(src)) { // src is split into [src.hi, cut+1] and [cut, src.lo] - unsigned const cut = get_cut(src); + unsigned const cut = m_slice_cut[src.idx] + src.lo; // adjust cut to current bounds if (lo >= cut + 1) - return mk_slice(sub_hi(src), hi, lo, out); + return mk_slice(sub_hi(src), hi, lo, out_base); else if (cut >= hi) - return mk_slice(sub_lo(src), hi, lo, out); + return mk_slice(sub_lo(src), hi, lo, out_base); else { SASSERT(hi > cut && cut >= lo); // desired range spans over the cutpoint, so we get multiple slices in the result - mk_slice(sub_hi(src), hi, cut + 1, out); - mk_slice(sub_lo(src), cut, lo, out); + mk_slice(sub_hi(src), hi, cut + 1, out_base); + mk_slice(sub_lo(src), cut, lo, out_base); return; } } @@ -150,16 +251,17 @@ namespace polysat { // [src.hi, src.lo] has no subdivision yet if (src.hi > hi) { split(src, hi); - mk_slice(sub_lo(src), hi, lo, out); + mk_slice(sub_lo(src), hi, lo, out_base); return; } else { SASSERT(src.hi == hi); SASSERT(lo > src.lo); split(src, lo - 1); - slice_info si = sub_hi(src); - SASSERT_EQ(si.hi, hi); SASSERT_EQ(si.lo, lo); - out.push_back(si); + slice s = sub_hi(src); + SASSERT_EQ(s.hi, hi); + SASSERT_EQ(s.lo, lo); + out_base.push_back(s); return; } } @@ -167,7 +269,7 @@ namespace polysat { } pvar slicing::mk_extract_var(pvar src, unsigned hi, unsigned lo) { - vector slices; + slice_vector slices; mk_slice(var2slice(src), hi, lo, slices); // src[hi:lo] is the concatenation of the returned slices // TODO: for each slice, set_extract @@ -183,15 +285,16 @@ namespace polysat { #endif } +#if 0 pdd slicing::mk_extract(pvar src, unsigned hi, unsigned lo) { - return s.var(mk_extract_var(src, hi, lo)); + return m_solver.var(mk_extract_var(src, hi, lo)); } pdd slicing::mk_extract(pdd const& p, unsigned hi, unsigned lo) { if (!lo) { // TODO: we could push the extract down into variables of the term instead of introducing a name. } - pvar const v = s.m_names.mk_name(p); + pvar const v = m_solver.m_names.mk_name(p); return mk_extract(v, hi, lo); } @@ -218,6 +321,29 @@ namespace polysat { #endif NOT_IMPLEMENTED_YET(); } +#endif + + void slicing::set_extract(pvar v, pvar src, unsigned hi, unsigned lo) { +#if 0 + SASSERT(!is_extract(v)); + SASSERT(lo < hi && hi <= s.size(src)); + SASSERT_EQ(hi - lo + 1, s.size(v)); + SASSERT(src < v); + SASSERT(!m_extracted.contains(extract_key{src, hi, lo})); +#if 0 // try without this first + if (is_extract(src)) { + // y = (x[k:m])[h:l] = x[h+m:l+m] + unsigned const offset = m_lo[src]; + set_extract(m_src[src], hi + offset, lo + offset); + return; + } +#endif + m_extracted.insert({src, hi, lo}, v); + m_src[v] = src; + m_hi[v] = hi; + m_lo[v] = lo; +#endif + } void slicing::propagate(pvar v) { } diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index 5e6bd24ea..03320f2ba 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -25,8 +25,6 @@ Notation: --*/ #pragma once #include "math/polysat/types.h" -#include "util/trail.h" -#include "util/union_find.h" namespace polysat { @@ -34,8 +32,9 @@ namespace polysat { class slicing final { - solver& s; + // solver& m_solver; +#if 0 /// If y := x[h:l], then m_src[y] = x, m_hi[y] = h, m_lo[y] = l. /// Otherwise m_src[y] = null_var. /// @@ -46,7 +45,6 @@ namespace polysat { unsigned_vector m_hi; unsigned_vector m_lo; -#if 0 struct extract_key { pvar src; unsigned hi; @@ -68,81 +66,116 @@ namespace polysat { // need src -> [v] and v -> [src] for propagation? #endif - - - - trail_stack m_stack; - using slice_idx = unsigned; - static constexpr slice_idx null_slice_idx = UINT_MAX; + using slice_idx_vector = unsigned_vector; + static constexpr slice_idx null_slice_idx = std::numeric_limits::max(); - struct slice { - // If sub != null_slice_idx, the bit-vector x has been sliced into x[|x|-1:cut+1] and x[cut:0] - unsigned cut = UINT_MAX; - // If sub != null_slice_idx, the sub-slices are at indices sub and sub+1 - slice_idx sub = null_slice_idx; + static constexpr unsigned null_cut = std::numeric_limits::max(); - bool has_sub() const { return cut != 0; } - slice_idx sub_hi() const { return sub; } - slice_idx sub_lo() const { return sub + 1; } - }; - svector m_slices; // slice_idx -> slice - svector m_var_slices; // pvar -> slice_idx + // number of bits in the slice + // TODO: slice width is useful for debugging but we can probably drop it in release mode? + unsigned_vector m_slice_width; + // Cut point: if slice represents bit-vector x, then x has been sliced into x[|x|-1:cut+1] and x[cut:0]. + // The cut point is relative to the parent slice (rather than a root variable, which might not be unique) + // (UINT_MAX for leaf slices) + unsigned_vector m_slice_cut; + // The sub-slices are at indices sub and sub+1 (null_slice_idx if no subdivision) + slice_idx_vector m_slice_sub; + slice_idx_vector m_find; // representative of equivalence class + slice_idx_vector m_size; // number of elements in equivalence class + slice_idx_vector m_next; // next element of the equivalence class - // union_find over slices (union_find vars are indices into m_slices, i.e., slice_idx) - union_find m_slices_uf; + slice_idx_vector m_var2slice; // pvar -> slice_idx slice_idx alloc_slice(); - friend class alloc_slice_trail; - class alloc_slice_trail : public trail { - slicing& m_owner; - public: - alloc_slice_trail(slicing& o): m_owner(o) {} - void undo() override; + // track slice range while traversing sub-slices + // (reference point of hi/lo is user-defined, e.g., relative to entry point of traversal) + struct slice { + slice_idx idx = null_slice_idx; + unsigned hi = UINT_MAX; + unsigned lo = UINT_MAX; }; - alloc_slice_trail m_alloc_slice_trail; + using slice_vector = svector; + slice var2slice(pvar v) const; + bool has_sub(slice_idx i) const { return m_slice_sub[i] != null_slice_idx; } + bool has_sub(slice const& s) const { return has_sub(s.idx); } + slice sub_hi(slice const& s) const; + slice sub_lo(slice const& s) const; + // Split a slice into two; the cut is relative to |s|...0 + void split(slice_idx s, unsigned cut); + // Split a slice into two; NOTE: the cut point here is relative to hi/lo in s + void split(slice const& s, unsigned cut); + // Retrieve base slices s_1,...,s_n such that src == s_1 ++ ... + s_n + void find_base(slice src, slice_vector& out_base) const; + // Retrieve (or create) base slices s_1,...,s_n such that src[hi:lo] == s_1 ++ ... ++ s_n + void mk_slice(slice src, unsigned hi, unsigned lo, slice_vector& out_base); + + // find representative + slice_idx find(slice_idx i) const; + + // merge equivalence classes of two base slices + void merge(slice_idx s1, slice_idx s2); + + // Equality x_1 ++ ... ++ x_n == y_1 ++ ... ++ y_k + // + // Precondition: + // - sequence of base slices without holes (TODO: condition on holes probably not necessary? total widths have to match of course) + // - ordered from msb to lsb + // - slices have the same reference point + void merge(slice_vector& xs, slice_vector& ys); void set_extract(pvar v, pvar src, unsigned hi_bit, unsigned lo_bit); - struct slice_info { - slice_idx idx; - unsigned hi; - unsigned lo; + + enum class trail_item { + add_var, + alloc_slice, + split_slice, + merge_class, }; - slice_info var2slice(pvar v) const; - bool has_sub(slice_info const& si) const { return m_slices[si.idx].has_sub(); } - slice_info sub_hi(slice_info const& si) const; - slice_info sub_lo(slice_info const& si) const; - unsigned get_cut(slice_info const& si) const; - void split(slice_info const& si, unsigned cut); - void mk_slice(slice_info const& src, unsigned hi, unsigned lo, vector& out); + svector m_trail; + slice_idx_vector m_split_trail; + slice_idx_vector m_merge_trail; + unsigned_vector m_scopes; + + void undo_add_var(); + void undo_alloc_slice(); + void undo_split_slice(); + void undo_merge_class(); + + + mutable slice_vector m_tmp1; + public: - slicing(solver& s): - s(s), - m_slices_uf(*this), - m_alloc_slice_trail(*this) - {} + // slicing(solver& s): m_solver(s) {} - trail_stack& get_trail_stack() { return m_stack; } + void push_scope(); + void pop_scope(unsigned num_scopes = 1); - void push_var(); - void pop_var(); + void add_var(unsigned bit_width); - bool is_extract(pvar v) const { return m_src[v] != null_var; } + + + + + + + + // bool is_extract(pvar v) const { return m_src[v] != null_var; } /** Get variable representing x[hi:lo] */ pvar mk_extract_var(pvar x, unsigned hi, unsigned lo); - /** Create expression for x[hi:lo] */ - pdd mk_extract(pvar x, unsigned hi, unsigned lo); + // /** Create expression for x[hi:lo] */ + // pdd mk_extract(pvar x, unsigned hi, unsigned lo); - /** Create expression for p[hi:lo] */ - pdd mk_extract(pdd const& p, unsigned hi, unsigned lo); + // /** Create expression for p[hi:lo] */ + // pdd mk_extract(pdd const& p, unsigned hi, unsigned lo); - /** Create expression for p ++ q */ - pdd mk_concat(pdd const& p, pdd const& q); + // /** Create expression for p ++ q */ + // pdd mk_concat(pdd const& p, pdd const& q); // propagate: // - value assignments