diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index da676aac9..4dac35d23 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -11,6 +11,35 @@ Author: --*/ + + + +/* + + (x=y) + x <=========== y + / \ / \ + x[7:4] x[3:0] y[3:0] + <========== + (by x=y) + +Try later: + Congruence closure with "virtual concat" terms + x = x[7:4] ++ x[3:0] + y = y[7:4] ++ y[3:0] + x[7:4] = y[7:4] + x[3:0] = y[3:0] + => x = y + +Recycle the z3 egraph? + - x = x[7:4] ++ x[3:0] + - Add instance euf_egraph.h + - What do we need from the egraph? + - backtracking trail to check for new equalities + +*/ + + #include "math/polysat/slicing.h" #include "math/polysat/solver.h" #include "math/polysat/log.h" @@ -30,10 +59,11 @@ namespace polysat { unsigned const target_size = m_scopes[target_lvl]; while (m_trail.size() > target_size) { switch (m_trail.back()) { - case trail_item::add_var: undo_add_var(); break; - case trail_item::alloc_slice: undo_alloc_slice(); break; - case trail_item::split_slice: undo_split_slice(); break; - case trail_item::merge_base: undo_merge_base(); break; + case trail_item::add_var: undo_add_var(); break; + case trail_item::alloc_slice: undo_alloc_slice(); break; + case trail_item::split_slice: undo_split_slice(); break; + case trail_item::merge_base: undo_merge_base(); break; + case trail_item::mk_value_slice: undo_mk_value_slice(); break; default: UNREACHABLE(); } m_trail.pop_back(); @@ -65,6 +95,8 @@ namespace polysat { m_proof_parent.push_back(null_slice); m_proof_reason.push_back(null_dep); m_mark.push_back(0); + // m_value_root.push_back(null_slice); + m_slice2val.push_back(rational(-1)); m_trail.push_back(trail_item::alloc_slice); return s; } @@ -80,6 +112,8 @@ namespace polysat { m_proof_parent.pop_back(); m_proof_reason.pop_back(); m_mark.pop_back(); + // m_value_root.pop_back(); + m_slice2val.pop_back(); } slicing::slice slicing::sub_hi(slice parent) const { @@ -112,6 +146,11 @@ namespace polysat { m_slice_width[sub_lo] = cut + 1; m_trail.push_back(trail_item::split_slice); m_split_trail.push_back(s); + if (has_value(s)) { + rational const& val = get_value(s); + // set_value(sub_lo, mod2k(val, cut + 1)); + // set_value(sub_hi, machine_div2k(val, cut + 1)); + } } void slicing::undo_split_slice() { @@ -121,6 +160,27 @@ namespace polysat { m_slice_sub[s] = null_slice; } + slicing::slice slicing::mk_value_slice(rational const& val, unsigned bit_width) { + SASSERT(0 <= val && val < rational::power_of_two(bit_width)); + val2slice_key key(val, bit_width); + auto it = m_val2slice.find_iterator(key); + if (it != m_val2slice.end()) + return it->m_value; + slice s = alloc_slice(); + m_slice_width[s] = bit_width; + m_slice2val[s] = val; + // m_value_root[s] = s; + m_val2slice.insert(key, s); + m_val2slice_trail.push_back(std::move(key)); + m_trail.push_back(trail_item::mk_value_slice); + return s; + } + + void slicing::undo_mk_value_slice() { + m_val2slice.remove(m_val2slice_trail.back()); + m_val2slice_trail.pop_back(); + } + slicing::slice slicing::find(slice s) const { while (true) { SASSERT(s < m_find.size()); @@ -143,6 +203,16 @@ namespace polysat { std::swap(r1, r2); std::swap(s1, s2); } + if (has_value(r1)) { + if (has_value(r2)) { + if (get_value(r1) != get_value(r2)) { + NOT_IMPLEMENTED_YET(); // TODO: conflict + return false; + } + } + // else + // set_value(r2, get_value(r1)); + } // r2 becomes the representative of the merged class m_find[r1] = r2; m_size[r2] += m_size[r1]; @@ -206,6 +276,39 @@ namespace polysat { } } + bool slicing::merge_value(slice s0, rational val0, dep_t dep) { + vector> todo; + todo.push_back({find(s0), std::move(val0)}); + // check compatibility for sub-slices + for (unsigned i = 0; i < todo.size(); ++i) { + auto const& [s, val] = todo[i]; + if (has_value(s)) { + if (get_value(s) != val) { + // TODO: conflict + NOT_IMPLEMENTED_YET(); + return false; + } + SASSERT_EQ(get_value(s), val); + continue; + } + if (has_sub(s)) { + // s is split into [s.width-1, cut+1] and [cut, 0] + unsigned const cut = m_slice_cut[s]; + todo.push_back({find_sub_lo(s), mod2k(val, cut + 1)}); + todo.push_back({find_sub_hi(s), machine_div2k(val, cut + 1)}); + } + } + // all succeeded, so apply the values + for (auto const& [s, val] : todo) { + if (has_value(s)) { + SASSERT_EQ(get_value(s), val); + continue; + } + // set_value(s, val); + } + return true; + } + void slicing::push_reason(slice s, dep_vector& out_deps) { dep_t reason = m_proof_reason[s]; if (reason == null_dep) @@ -295,6 +398,10 @@ namespace polysat { slice y = ys.back(); xs.pop_back(); ys.pop_back(); + if (x == y) { + // merge upper level? + // but continue loop + } if (has_sub(x)) { find_base(x, xs); x = xs.back(); @@ -338,6 +445,7 @@ namespace polysat { } bool slicing::merge(slice x, slice y, dep_t dep) { + SASSERT_EQ(width(x), width(y)); if (!has_sub(x) && !has_sub(y)) return merge_base(x, y, dep); slice_vector& xs = m_tmp2; @@ -350,6 +458,7 @@ namespace polysat { } bool slicing::is_equal(slice x, slice y) { + SASSERT_EQ(width(x), width(y)); x = find(x); y = find(y); if (x == y) @@ -480,6 +589,13 @@ namespace polysat { } pdd slicing::mk_extract(pdd const& p, unsigned hi, unsigned lo) { + SASSERT(hi >= lo); + SASSERT(p.power_of_2() > hi); + if (p.is_val()) { + // p[hi:lo] = (p >> lo) % 2^(hi - lo + 1) + rational q = mod2k(machine_div2k(p.val(), lo), hi - lo + 1); + return p.manager().mk_val(q); + } if (!lo) { // TODO: we could push the extract down into variables of the term instead of introducing a name. } @@ -498,6 +614,14 @@ namespace polysat { unsigned const p_sz = p.power_of_2(); unsigned const q_sz = q.power_of_2(); unsigned const v_sz = p_sz + q_sz; + if (p.is_val() && q.is_val()) { + rational const val = p.val() * rational::power_of_two(q_sz) + q.val(); + return m_solver.sz2pdd(v_sz).mk_val(val); + } + if (p.is_val()) { + } + if (q.is_val()) { + } pvar const v = m_solver.add_var(v_sz); slice_vector tmp; tmp.push_back(pdd2slice(p)); @@ -506,7 +630,48 @@ namespace polysat { return m_solver.var(v); } + void slicing::propagate(signed_constraint c) { + // TODO: evaluate under current assignment? + if (!c->is_eq()) + return; + pdd const& p = c->to_eq(); + auto& m = p.manager(); + for (auto& [a, x] : p.linear_monomials()) { + if (a != 1 && a != m.max_value()) + continue; + pdd body = a.is_one() ? (m.mk_var(x) - p) : (m.mk_var(x) + p); + // c is either x = body or x != body, depending on polarity + LOG("Equation from constraint " << c << ": v" << x << " = " << body); + slice const sx = var2slice(x); + if (body.is_val()) { + // Simple assignment x = value + // TODO: set fixed bits + continue; + } + pvar const y = m_solver.m_names.get_name(body); + if (y == null_var) { + // TODO: register name trigger (if a name for value 'body' is created later, then merge x=y at that time) + continue; + } + slice const sy = var2slice(y); + if (c.is_positive()) { + if (!merge(sx, sy, c.blit())) + return; + } + else { + SASSERT(c.is_negative()); + if (is_equal(sx, sy)) { + // TODO: conflict + NOT_IMPLEMENTED_YET(); + return; + } + } + } + } + void slicing::propagate(pvar v) { + // go through all existing nodes, and evaluate v? + // can do that externally } std::ostream& slicing::display(std::ostream& out) const { @@ -514,7 +679,8 @@ namespace polysat { for (pvar v = 0; v < m_var2slice.size(); ++v) { out << "v" << v << ":"; base.reset(); - find_base(var2slice(v), base); + slice const vs = var2slice(v); + find_base(vs, base); // unsigned hi = width(var2slice(v)) - 1; for (slice s : base) { // unsigned w = width(s); @@ -523,13 +689,49 @@ namespace polysat { // hi -= w; display(out << " ", s); } + if (has_value(vs)) { + out << " -- (val:" << get_value(vs) << ")"; + } out << "\n"; } + for (pvar v = 0; v < m_var2slice.size(); ++v) { + out << "v" << v << ":"; + slice const s = m_var2slice[v]; + } return out; } + std::ostream& slicing::display_tree(std::ostream& out, char const* name, slice s) const { + // TODO + } + std::ostream& slicing::display(std::ostream& out, slice s) const { - return out << "{id:" << s << ",w:" << width(s) << "}"; + out << "{id:" << s << ",w:" << width(s); + if (has_value(s)) { + out << ",val:" << get_value(s); + } + out << "}"; + return out; + } + + bool slicing::invariant() const { + VERIFY(m_tmp1.empty()); + VERIFY(m_tmp2.empty()); + VERIFY(m_tmp3.empty()); + for (slice s = 0; s < m_slice_cut.size(); ++s) { + // if the slice is equivalent to a variable, then the variable's slice is in the equivalence class + pvar const v = slice2var(s); + SASSERT_EQ(v != null_var, find(var2slice(v)) == find(s)); + // properties below only matter for representatives + if (s != find(s)) + continue; + // if slice has a value, it should be propagated to its sub-slices + if (has_value(s)) { + VERIFY(has_value(find_sub_hi(s))); + VERIFY(has_value(find_sub_lo(s))); + } + } + return true; } } diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index b8b96559f..9f104fc73 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -25,6 +25,7 @@ Notation: --*/ #pragma once #include "math/polysat/types.h" +#include "math/polysat/constraint.h" #include namespace polysat { @@ -37,28 +38,6 @@ namespace polysat { solver& m_solver; -#if 0 - struct extract_key { - pvar src; - unsigned hi; - unsigned lo; - - bool operator==(extract_key const& other) const { - return src == other.src && hi == other.hi && lo == other.lo; - } - - unsigned hash() const { - return mk_mix(src, hi, lo); - } - }; - using extract_hash = obj_hash; - using extract_eq = default_eq; - using extract_map = map; - - extract_map m_extracted; ///< src, hi, lo -> v - // need src -> [v] and v -> [src] for propagation? -#endif - using dep_t = sat::literal; using dep_vector = sat::literal_vector; static constexpr sat::literal null_dep = sat::null_literal; @@ -69,11 +48,30 @@ namespace polysat { static constexpr unsigned null_cut = std::numeric_limits::max(); + struct val2slice_key { + rational value; + unsigned bit_width; + + val2slice_key() {} + val2slice_key(rational value, unsigned bit_width): value(std::move(value)), bit_width(bit_width) {} + + bool operator==(val2slice_key const& other) const { + return bit_width == other.bit_width && value == other.value; + } + + unsigned hash() const { + return combine_hash(value.hash(), bit_width); + } + }; + using val2slice_hash = obj_hash; + using val2slice_eq = default_eq; + using val2slice_map = map; + // number of bits in the slice unsigned_vector m_slice_width; // Cut point: if slice represents bit-vector x, then x has been sliced into x[|x|-1:cut+1] and x[cut:0]. // The cut point is relative to the parent slice (rather than a root variable, which might not be unique) - // (UINT_MAX for leaf slices) + // (null_cut for leaf slices) unsigned_vector m_slice_cut; // The sub-slices are at indices sub and sub+1 (or null_slice if there is no subdivision) slice_vector m_slice_sub; @@ -84,6 +82,15 @@ namespace polysat { slice_vector m_proof_parent; // the proof forest dep_vector m_proof_reason; // justification for merge of an element with its parent (in the proof forest) + // unsigned_vector m_value_id; // slice -> value id + // vector m_value; // id -> value + // slice_vector m_value_root; // the slice representing the associated value, if any. NOTE: subslices will inherit this from their parents. + // TODO: value_root probably not necessary. + // but we will need to create value slices for the sub-slices. + // then the "reason" for that equality must be a marker "equality between parent and its value". explanation at that point must go up recursively. + vector m_slice2val; // slice -> value (-1 if none) + val2slice_map m_val2slice; // (value, bit-width) -> slice + pvar_vector m_slice2var; // slice -> pvar, or null_var if slice is not equivalent to a variable slice_vector m_var2slice; // pvar -> slice @@ -110,6 +117,14 @@ namespace polysat { unsigned width(slice s) const { return m_slice_width[s]; } bool has_sub(slice s) const { return m_slice_sub[s] != null_slice; } + // slice val2slice(rational const& val, unsigned bit_width) const; + + // Retrieve (or create) a slice representing the given value. + slice mk_value_slice(rational const& val, unsigned bit_width); + + bool has_value(slice s) const { SASSERT_EQ(s, find(s)); return m_slice2val[s].is_nonneg(); } + rational const& get_value(slice s) const { SASSERT(has_value(s)); return m_slice2val[s]; } + // reverse all edges on the path from s to the root of its tree in the proof forest void make_proof_root(slice s); @@ -138,6 +153,10 @@ namespace polysat { // Returns true if merge succeeded without conflict. [[nodiscard]] bool merge_base(slice s1, slice s2, dep_t dep); + // Merge equality s == val and propagate the value downward into sub-slices. + // Returns true if merge succeeded without conflict. + [[nodiscard]] bool merge_value(slice s, rational val, dep_t dep); + void push_reason(slice s, dep_vector& out_deps); // Extract reason why slices x and y are in the same equivalence class @@ -168,17 +187,19 @@ namespace polysat { alloc_slice, split_slice, merge_base, + mk_value_slice, }; svector m_trail; slice_vector m_split_trail; svector> m_merge_trail; // pair of (representative, element) + vector m_val2slice_trail; unsigned_vector m_scopes; void undo_add_var(); void undo_alloc_slice(); void undo_split_slice(); void undo_merge_base(); - + void undo_mk_value_slice(); mutable slice_vector m_tmp1; mutable slice_vector m_tmp2; @@ -217,8 +238,10 @@ namespace polysat { // - fixed bits // - intervals ????? -- that will also need changes in the viable algorithm void propagate(pvar v); + void propagate(signed_constraint c); std::ostream& display(std::ostream& out) const; + std::ostream& display_tree(std::ostream& out, char const* name, slice s) const; std::ostream& display(std::ostream& out, slice s) const; }; diff --git a/src/test/slicing.cpp b/src/test/slicing.cpp index 2a84bf5a9..2a9535226 100644 --- a/src/test/slicing.cpp +++ b/src/test/slicing.cpp @@ -118,6 +118,65 @@ namespace polysat { std::cout << " Reason: " << reason << "\n"; } + // 1. a = b + // 2. d = c[1:0] + // 3. c = b[3:0] + // 4. e = a[1:0] + // + // Explain(d = e) should be {1, 2, 3, 4} + static void test4() { + std::cout << __func__ << "\n"; + scoped_solver_slicing s; + slicing& sl = s.sl(); + pvar a = s.add_var(8); + pvar b = s.add_var(8); + pvar c = s.add_var(4); + pvar d = s.add_var(2); + pvar e = s.add_var(2); + + VERIFY(sl.merge(sl.var2slice(a), sl.var2slice(b), sat::literal(101))); + VERIFY(sl.merge(sl.var2slice(d), sl.var2slice(sl.mk_extract_var(c, 1, 0)), sat::literal(102))); + VERIFY(sl.merge(sl.var2slice(c), sl.var2slice(sl.mk_extract_var(b, 3, 0)), sat::literal(103))); + VERIFY(sl.merge(sl.var2slice(e), sl.var2slice(sl.mk_extract_var(a, 1, 0)), sat::literal(104))); + + std::cout << "v" << d << " = v" << e << "? " << sl.is_equal(sl.var2slice(d), sl.var2slice(e)) + << " find(v" << d << ") = " << sl.find(sl.var2slice(d)) + << " find(v" << e << ") = " << sl.find(sl.var2slice(e)) + << " slice(v" << d << ") = " << sl.var2slice(d) + << " slice(v" << e << ") = " << sl.var2slice(e) + << " slice(v" << d << ") = " << sl.m_var2slice[d] + << " slice(v" << e << ") = " << sl.m_var2slice[e] + << "\n"; + sat::literal_vector reason; + sl.explain_equal(sl.var2slice(d), sl.var2slice(e), reason); + std::cout << " Reason: " << reason << "\n"; + reason.reset(); + sl.explain_equal(sl.m_var2slice[d], sl.m_var2slice[e], reason); + std::cout << " Reason: " << reason << "\n"; + } + + // x[5:2] = y + // x[3:0] = z + // y = 0b1001 + // z = 0b0111 + static void test5() { + std::cout << __func__ << "\n"; + scoped_solver_slicing s; + slicing& sl = s.sl(); + pvar x = s.add_var(6); + std::cout << sl << "\n"; + + pvar y = sl.mk_extract_var(x, 5, 2); + std::cout << "v" << y << " := v" << x << "[5:2]\n" << sl << "\n"; + pvar z = sl.mk_extract_var(x, 3, 0); + std::cout << "v" << z << " := v" << x << "[3:0]\n" << sl << "\n"; + + // VERIFY(sl.merge_value(sl.var2slice(y), rational(9))); + // std::cout << "v" << y << " = 9\n" << sl << "\n"; + // VERIFY(sl.merge_value(sl.var2slice(z), rational(7))); + // std::cout << "v" << z << " = 7\n" << sl << "\n"; + } + }; } @@ -128,5 +187,7 @@ void tst_slicing() { test_slicing::test1(); test_slicing::test2(); test_slicing::test3(); + test_slicing::test4(); + // test_slicing::test5(); std::cout << "ok\n"; }