diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index 7143df032..f7cb96e90 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -123,6 +123,7 @@ namespace polysat { } slicing::slice_info const& slicing::info(euf::enode* n) const { + SASSERT(n); SASSERT(!n->is_equality()); SASSERT(m_bv->is_bv_sort(n->get_sort())); slice_info const& i = m_info[n->get_id()]; @@ -780,6 +781,7 @@ namespace polysat { } SASSERT(!has_sub(x)); SASSERT(!has_sub(y)); + // TODO: move this above the has_sub check to merge intermediate nodes too? if (width(x) == width(y)) { if (!merge_base(x, y, dep)) { xs.clear(); @@ -833,6 +835,7 @@ namespace polysat { enode_vector& ys = m_tmp3; SASSERT(xs.empty()); SASSERT(ys.empty()); + // TODO: we don't always have to collect the full base if intermediate nodes are already equal get_root_base(x, xs); get_root_base(y, ys); SASSERT(all_of(xs, [](enode* s) { return s->is_root(); })); @@ -1005,7 +1008,6 @@ namespace polysat { continue; pdd const body = a.is_one() ? (m.mk_var(x) - p) : (m.mk_var(x) + p); // c is either x = body or x != body, depending on polarity - LOG("Equation from lit(" << c.blit() << ") " << c << ": v" << x << (c.is_positive() ? " = " : " != ") << body); if (!add_equation(x, body, c.blit())) { SASSERT(is_conflict()); return; @@ -1017,6 +1019,7 @@ namespace polysat { } bool slicing::add_equation(pvar x, pdd const& body, sat::literal lit) { + LOG("Equation from lit(" << lit << "): v" << x << (lit.sign() ? " != " : " = ") << body); enode* const sx = var2slice(x); if (!lit.sign() && body.is_val()) { LOG(" simple assignment"); @@ -1055,14 +1058,112 @@ namespace polysat { (void)merge(sv, sval, mk_var_dep(v, sv)); } - void slicing::collect_overlaps(pvar v, var_overlap_vector& out) { - // - start at var2slice(v) - // - go into subslices, always starting at lowest ones - // - when we find multiple overlaps, we want to merge them: keep a map for this. - // (but note that there can be "holes" in the overlap. by starting at lsb, process overlaps in order low->high. so when we encounter a hole that should mean the overlap is "done" and replace it with the new one in the map.) - // - at each slice: iterate over equivalence class, check if it has a variable as parent (usually it would be the root of the parent-pointers. maybe we should cache a pointer to the next variable-enode, instead of keeping all the parents.) - // - use enode->mark1/2/3 to process each node only once - NOT_IMPLEMENTED_YET(); + void slicing::collect_simple_overlaps(pvar v, pvar_vector& out) { + unsigned const first_out = out.size(); + enode* const sv = var2slice(v); + unsigned const v_width = width(sv); + enode_vector& v_base = m_tmp2; + SASSERT(v_base.empty()); + get_base(var2slice(v), v_base); + + SASSERT(all_of(m_egraph.nodes(), [](enode* n) { return !n->is_marked1(); })); + + // Collect direct sub-slices of v and their equivalences + // (these don't need any extra checks) + for (enode* s = sv; s != nullptr; s = has_sub(s) ? sub_lo(s) : nullptr) { + for (enode* n : euf::enode_class(s)) { + if (!is_proper_slice(n)) + continue; + pvar const w = slice2var(n); + if (w == null_var) + continue; + SASSERT(!n->is_marked1()); + n->mark1(); + out.push_back(w); + } + } + + // lowermost base slice of v + enode* const v_base_lo = v_base.back(); + + svector> candidates; + // Collect all other candidate variables, + // i.e., those who share the lowermost base slice with v. + for (enode* n : euf::enode_class(v_base_lo)) { + if (!is_proper_slice(n)) + continue; + if (n == v_base_lo) + continue; + enode* const n0 = n; + pvar w2 = null_var; // the highest variable we care about from this equivalence class + // examine parents to find variables + SASSERT(!has_sub(n)); + while (true) { + pvar const w = slice2var(n); + if (w != null_var && !n->is_marked1()) + w2 = w; + enode* p = parent(n); + if (!p) + break; + if (sub_lo(p) != n) // we only care about lowermost slices of variables + break; + if (width(p) > v_width) + break; + n = p; + } + if (w2 != null_var) + candidates.push_back({n0, w2}); + } + + // Check candidates + for (auto const& [n0, w2] : candidates) { + // unsigned v_next = v_base.size(); + auto v_it = v_base.rbegin(); + enode* n = n0; + SASSERT_EQ(n->get_root(), (*v_it)->get_root()); + ++v_it; + while (true) { + // here: base of n is equivalent to lower portion of base of v + pvar const w = slice2var(n); + if (w != null_var && !n->is_marked1()) { + n->mark1(); + out.push_back(w); + } + if (w == w2) + break; + // + enode* const p = parent(n); + SASSERT(p); + SASSERT_EQ(sub_lo(p), n); // otherwise not a candidate + // check if base of sub_hi(p) matches the base of v + enode_vector& p_hi_base = m_tmp3; + get_base(sub_hi(p), p_hi_base); + auto p_it = p_hi_base.rbegin(); + bool ok = true; + while (ok && p_it != p_hi_base.rend()) { + if (v_it == v_base.rend()) + ok = false; + else if ((*p_it)->get_root() != (*v_it)->get_root()) + ok = false; + else { + ++p_it; + ++v_it; + } + } + p_hi_base.reset(); + if (!ok) + break; + n = p; + } + } + + v_base.reset(); + for (unsigned i = out.size(); i-- > first_out; ) { + enode* n = var2slice(out[i]); + SASSERT(n->is_marked1()); + n->unmark1(); + } + SASSERT(all_of(m_egraph.nodes(), [](enode* n) { return !n->is_marked1(); })); } std::ostream& slicing::display(std::ostream& out) const { @@ -1095,12 +1196,22 @@ namespace polysat { out << std::string(indent, ' ') << "[" << hi << ":" << lo << "]"; out << " id=" << s->get_id(); out << " w=" << width(s); - if (!s->is_root()) - out << " root=" << s->get_root_id(); + if (slice2var(s) != null_var) + out << " var=v" << slice2var(s); if (parent(s)) out << " parent=" << parent(s)->get_id(); + if (!s->is_root()) + out << " root=" << s->get_root_id(); if (is_value(s->get_root())) out << " root_value=" << get_value(s->get_root()); + // if (is_proper_slice(s)) + // out << " "; + if (is_value(s)) + out << " "; + if (is_concat(s)) + out << " "; + if (is_equality(s)) + out << " "; out << "\n"; if (has_sub(s)) { unsigned cut = info(s).cut; @@ -1111,7 +1222,15 @@ namespace polysat { } std::ostream& slicing::display(std::ostream& out, enode* s) const { - out << "{id:" << s->get_id() << ",w:" << width(s) << "}"; + out << "{id:" << s->get_id(); + out << ",w:" << width(s); + if (is_value(s)) + out << ","; + if (is_concat(s)) + out << ","; + if (is_equality(s)) + out << ","; + out << "}"; return out; } diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index 1aecf1133..64c30100b 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -303,23 +303,8 @@ namespace polysat { /** Extract conflict clause */ clause_ref build_conflict_clause(); - /// Example: - /// - assume query_var has segments 11122233 and var has segments 2224 - /// - then the overlapping region "222" is given by width = 3, offset_var = 1, offset_query = 2. - /// (First version will probably only support offset 0.) - struct var_overlap { - pvar var; - /// number of overlapping bits - unsigned width; - /// offset of overlapping region in var - unsigned offset_var; - /// offset of overlapping region in query variable - unsigned offset_query; - }; - using var_overlap_vector = svector; - - /** For a given variable v, find the set of variables that share at least one slice with v. */ - void collect_overlaps(pvar v, var_overlap_vector& out); + /** For a given variable v, find the set of variables w such that w = v[|w|:0]. */ + void collect_simple_overlaps(pvar v, pvar_vector& out); /** Collect fixed portions of the variable v */ void collect_fixed(pvar v, rational& mask, rational& value); diff --git a/src/math/polysat/viable.cpp b/src/math/polysat/viable.cpp index 056a60c01..6ef35c555 100644 --- a/src/math/polysat/viable.cpp +++ b/src/math/polysat/viable.cpp @@ -1554,8 +1554,8 @@ namespace { if (!collect_bit_information(v, true, fixed, justifications)) return l_false; // conflict already added - slicing::var_overlap_vector overlaps; - s.m_slicing.collect_overlaps(v, overlaps); + pvar_vector overlaps; + s.m_slicing.collect_simple_overlaps(v, overlaps); // TODO: (combining intervals across equivalence classes from slicing) // // When iterating over intervals: @@ -1567,10 +1567,12 @@ namespace { // - direct equivalences (x = y); could just point one interval set to the other and store them together (may be annoying for bookkeeping) // - lower bits extractions (x[h:0]) and equivalent slices; // (this is what Algorithm 3 in "Solving Bitvectors with MCSAT" does, and will also let us better handle even coefficients of inequalities). + // - intervals with coefficient 2^k*a to be treated as intervals over x[|x|-k:0] with coefficient a (with odd a) // // Problem: // - the conflict clause will involve relations between different bit-widths // - can we avoid introducing new extract-terms? (if not, can we at least avoid additional slices?) + // e.g., multiply other terms by 2^k instead of introducing extract? // - NOTE: currently our clauses survive across backtracking points, but the slicing will be reset. // It is currently unsafe to create extract/concat terms internally. // (to be fixed when we re-internalize conflict clauses after backtracking) diff --git a/src/test/slicing.cpp b/src/test/slicing.cpp index c641a5de6..9b2aa630a 100644 --- a/src/test/slicing.cpp +++ b/src/test/slicing.cpp @@ -38,6 +38,8 @@ namespace polysat { char const* delim = ""; for (void* dp : deps) { slicing::dep_t d = slicing::decode_dep(dp); + if (d.is_null()) + continue; s.sl().display(out << delim, d); delim = " "; } @@ -130,7 +132,7 @@ namespace polysat { pvar c = sl.mk_extract(x, 5, 0); std::cout << "v" << c << " := v" << x << "[5:0]\n" << sl << "\n"; pvar d = sl.mk_concat({sl.mk_extract(x, 5, 4), sl.mk_extract(y, 3, 0)}); - std::cout << d << " := v" << x << "[5:4] ++ v" << y << "[3:0]\n" << sl << "\n"; + std::cout << "v" << d << " := v" << x << "[5:4] ++ v" << y << "[3:0]\n" << sl << "\n"; std::cout << "v" << b << " = v" << c << "? " << sl.is_equal(sl.var2slice(b), sl.var2slice(c)) << "\n\n"; std::cout << "v" << b << " = v" << d << "? " << sl.is_equal(sl.var2slice(b), sl.var2slice(d)) << "\n\n"; @@ -160,6 +162,12 @@ namespace polysat { sl.propagate(); sl.display_tree(std::cout); VERIFY(sl.invariant()); + + for (pvar v : {x, y, a, b, c, d}) { + pvar_vector vars; + sl.collect_simple_overlaps(v, vars); + std::cout << "Simple overlaps for v" << v << ": " << vars << "\n"; + } } // 1. a = b