diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index e5e2d5408..17cdad789 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -90,6 +90,16 @@ namespace polysat { fixed_bits(unsigned hi, unsigned lo, rational value) : hi(hi), lo(lo), value(value) {} }; + struct justified_slice { + pvar v; + unsigned offset; + dependency dep; + }; + + inline std::ostream& operator<<(std::ostream& out, justified_slice const& js) { + return out << "v" << js.v << "[" << js.offset << "[ @" << js.dep; + } + using justified_fixed_bits = vector>; using dependency_vector = vector; @@ -100,7 +110,7 @@ namespace polysat { using core_vector = std::initializer_list; using constraint_id_vector = svector; using constraint_id_list = std::initializer_list; - using justified_slices = vector>; + using justified_slices = vector; using eq_justification = svector>; // @@ -118,6 +128,8 @@ namespace polysat { virtual trail_stack& trail() = 0; virtual bool inconsistent() const = 0; virtual void get_bitvector_suffixes(pvar v, justified_slices& out) = 0; + virtual void get_bitvector_sub_slices(pvar v, justified_slices& out) = 0; + virtual void get_bitvector_super_slices(pvar v, justified_slices& out) = 0; virtual void get_fixed_bits(pvar v, justified_fixed_bits& fixed_bits) = 0; }; diff --git a/src/sat/smt/polysat/viable.cpp b/src/sat/smt/polysat/viable.cpp index 6d813d211..b0fc95fe5 100644 --- a/src/sat/smt/polysat/viable.cpp +++ b/src/sat/smt/polysat/viable.cpp @@ -105,13 +105,13 @@ namespace polysat { justified_slices overlaps; c.get_bitvector_suffixes(v, overlaps); - std::sort(overlaps.begin(), overlaps.end(), [&](auto const& x, auto const& y) { return c.size(x.first) > c.size(y.first); }); + std::sort(overlaps.begin(), overlaps.end(), [&](auto const& x, auto const& y) { return c.size(x.v) > c.size(y.v); }); uint_set widths_set; // max size should always be present, regardless of whether we have intervals there (to make sure all fixed bits are considered) widths_set.insert(c.size(v)); - for (auto const& [v, j] : overlaps) + for (auto const& [v, offset, j] : overlaps) for (layer const& l : m_units[v].get_layers()) widths_set.insert(l.bit_width); @@ -176,7 +176,7 @@ namespace polysat { // however, we probably should rotate to avoid getting stuck in refinement loop on a 'bad' constraint bool refined = false; for (unsigned i = overlaps.size(); i-- > 0; ) { - pvar x = overlaps[i].first; + pvar x = overlaps[i].v; rational const& mod_value = c.var2pdd(x).two_to_N(); rational x_val = mod(val, mod_value); if (!refine_viable(x, x_val)) { @@ -240,7 +240,7 @@ namespace polysat { // find relevant interval lists svector ecs; - for (auto const& [x, j] : overlaps) { + for (auto const& [x, offset, j] : overlaps) { if (c.size(x) < w) // note that overlaps are sorted by variable size descending break; if (entry* e = m_units[x].get_entries(w)) { diff --git a/src/sat/smt/polysat_egraph.cpp b/src/sat/smt/polysat_egraph.cpp index ecb215127..1a676e778 100644 --- a/src/sat/smt/polysat_egraph.cpp +++ b/src/sat/smt/polysat_egraph.cpp @@ -36,7 +36,7 @@ namespace polysat { } }; - void solver::get_subslices(pvar pv, subslice_infos& slices) { + void solver::get_sub_slices(pvar pv, slice_infos& slices) { theory_var v = m_pddvar2var[pv]; unsigned lo, hi; expr* e = nullptr; @@ -60,11 +60,11 @@ namespace polysat { } } for (auto p : euf::enode_parents(n->get_root())) { - if (p->is_marked1()) + if (p->get_root()->is_marked1()) continue; if (bv.is_extract(p->get_expr(), lo, hi, e)) { auto child = expr2enode(e); - SASSERT(n == child->get_root()); + SASSERT(n->get_root() == child->get_root()); scoped_eq_justification sp(*this, just, child, n); slices.push_back({ p, offset + lo, just }); } @@ -74,41 +74,136 @@ namespace polysat { n->get_root()->unmark1(); } + void solver::get_super_slices(pvar pv, slice_infos& slices) { + theory_var v = m_pddvar2var[pv]; + unsigned lo, hi; + expr* e = nullptr; + euf::enode* n = var2enode(v); + slices.push_back({ n, 0, {} }); - // walk the egraph starting with pvar for overlaps. + for (unsigned i = 0; i < slices.size(); ++i) { + auto [n, offset, just] = slices[i]; + if (n->get_root()->is_marked1()) + continue; + n->get_root()->mark1(); + for (auto sib : euf::enode_class(n)) { + if (bv.is_extract(sib->get_expr(), lo, hi, e)) { + auto child = expr2enode(e); + SASSERT(n->get_root() == child->get_root()); + scoped_eq_justification sp(*this, just, child, n); + slices.push_back({ sib, offset + lo, just }); + } + } + for (auto p : euf::enode_parents(n->get_root())) { + if (p->get_root()->is_marked1()) + continue; + if (bv.is_concat(p->get_expr())) { + unsigned delta = 0; + for (unsigned j = p->num_args(); j-- > 0; ) { + auto arg = p->get_arg(j); + if (arg->get_root() == n->get_root()) { + scoped_eq_justification sp(*this, just, arg, n); + slices.push_back({ p, offset + delta, just }); + } + delta += bv.get_bv_size(arg->get_expr()); + } + } + } + } + for (auto const& [n, offset, d] : slices) + n->get_root()->unmark1(); + } + + + // walk the egraph starting with pvar for suffix overlaps. void solver::get_bitvector_suffixes(pvar pv, justified_slices& out) { - subslice_infos slices; - get_subslices(pv, slices); + slice_infos slices; + get_sub_slices(pv, slices); + uint_set seen; for (auto& [n, offset, just] : slices) { if (offset != 0) continue; - auto w = n->get_th_var(get_id()); - if (w == euf::null_theory_var) - continue; - auto const& p = m_var2pdd[w]; - if (p.is_var()) - out.push_back({ p.var(), dependency(just, s().scope_lvl())}); // approximate to current scope + for (auto sib : euf::enode_class(n)) { + auto w = sib->get_th_var(get_id()); + if (w == euf::null_theory_var) + continue; + if (seen.contains(w)) + continue; + seen.insert(w); + auto const& p = m_var2pdd[w]; + if (!p.is_var()) + continue; + scoped_eq_justification sp(*this, just, sib, n); + out.push_back({ p.var(), offset, dependency(just, s().scope_lvl()) }); // approximate to current scope + } + } + } + + // walk the egraph starting with pvar for any overlaps. + void solver::get_bitvector_sub_slices(pvar pv, justified_slices& out) { + slice_infos slices; + get_sub_slices(pv, slices); + uint_set seen; + + for (auto& [n, offset, just] : slices) { + for (auto sib : euf::enode_class(n)) { + auto w = sib->get_th_var(get_id()); + if (w == euf::null_theory_var) + continue; + if (seen.contains(w)) + continue; + seen.insert(w); + auto const& p = m_var2pdd[w]; + if (!p.is_var()) + continue; + scoped_eq_justification sp(*this, just, sib, n); + out.push_back({ p.var(), offset, dependency(just, s().scope_lvl()) }); // approximate to current scope + } + } + } + + // walk the egraph for bit-vectors that contain pv. + void solver::get_bitvector_super_slices(pvar pv, justified_slices& out) { + slice_infos slices; + get_super_slices(pv, slices); + uint_set seen; + + for (auto& [n, offset, just] : slices) { + for (auto sib : euf::enode_class(n)) { + auto w = sib->get_th_var(get_id()); + if (w == euf::null_theory_var) + continue; + if (seen.contains(w)) + continue; + seen.insert(w); + auto const& p = m_var2pdd[w]; + if (!p.is_var()) + continue; + scoped_eq_justification sp(*this, just, sib, n); + out.push_back({ p.var(), offset, dependency(just, s().scope_lvl()) }); // approximate to current scope + } } } // walk the e-graph to retrieve fixed overlaps void solver::get_fixed_bits(pvar pv, justified_fixed_bits& out) { - subslice_infos slices; - get_subslices(pv, slices); + slice_infos slices; + get_sub_slices(pv, slices); for (auto& [n, offset, just] : slices) { if (offset != 0) continue; - n = n->get_root(); - if (!n->interpreted()) + if (!n->get_root()->interpreted()) continue; - auto w = n->get_th_var(get_id()); + + auto w = n->get_root()->get_th_var(get_id()); if (w == euf::null_theory_var) continue; auto const& p = m_var2pdd[w]; if (!p.is_var()) continue; + scoped_eq_justification sp(*this, just, n, n->get_root()); unsigned lo = offset, hi = bv.get_bv_size(n->get_expr()); rational value; VERIFY(bv.is_numeral(n->get_expr(), value)); diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index f133bd05c..192de891a 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -34,7 +34,7 @@ namespace polysat { typedef sat::literal literal; typedef sat::bool_var bool_var; typedef sat::literal_vector literal_vector; - using subslice_infos = vector>; + using slice_infos = vector>; using pdd = dd::pdd; struct stats { @@ -75,7 +75,8 @@ namespace polysat { sat::check_result intblast(); - void get_subslices(pvar v, subslice_infos& slices); + void get_sub_slices(pvar v, slice_infos& slices); + void get_super_slices(pvar v, slice_infos& slices); // internalize bool visit(expr* e) override; @@ -169,6 +170,8 @@ namespace polysat { void propagate(dependency const& d, bool sign, constraint_id_vector const& deps) override; trail_stack& trail() override; bool inconsistent() const override; + void get_bitvector_sub_slices(pvar v, justified_slices& out) override; + void get_bitvector_super_slices(pvar v, justified_slices& out) override; void get_bitvector_suffixes(pvar v, justified_slices& out) override; void get_fixed_bits(pvar v, justified_fixed_bits& fixed_bits) override;