From 30c0771d24e7f55ff39034abc94f90b394860ff9 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 6 Jan 2024 16:12:01 -0800 Subject: [PATCH] redo fixed bits, add simplifications to intblast solver --- src/sat/smt/intblast_solver.cpp | 59 +++++++++++--- src/sat/smt/intblast_solver.h | 2 + src/sat/smt/polysat/fixed_bits.cpp | 99 +++++------------------ src/sat/smt/polysat/fixed_bits.h | 13 +-- src/sat/smt/polysat/forbidden_intervals.h | 2 + src/sat/smt/polysat/types.h | 2 +- src/sat/smt/polysat/viable.cpp | 47 +++++++---- src/sat/smt/polysat/viable.h | 6 +- src/sat/smt/polysat_model.cpp | 3 + src/sat/smt/polysat_solver.cpp | 12 ++- src/sat/smt/polysat_solver.h | 3 + 11 files changed, 127 insertions(+), 121 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 3e750c8fd..2c373f6b9 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -471,18 +471,25 @@ namespace intblast { }); } + bool solver::is_non_negative(expr* bv_expr, expr* e) { + auto N = rational::power_of_two(bv.get_bv_size(bv_expr)); + rational r; + if (a.is_numeral(e, r)) + return r >= 0; + if (is_bounded(e, N)) + return true; + expr* x, * y; + if (a.is_mul(e, x, y)) + return is_non_negative(bv_expr, x) && is_non_negative(bv_expr, y); + if (a.is_add(e, x, y)) + return is_non_negative(bv_expr, x) && is_non_negative(bv_expr, y); + return false; + } + expr* solver::umod(expr* bv_expr, unsigned i) { expr* x = arg(i); - rational r; rational N = bv_size(bv_expr); - if (a.is_numeral(x, r)) { - if (0 <= r && r < N) - return x; - return a.mk_int(mod(r, N)); - } - if (is_bounded(x, N)) - return x; - return a.mk_mod(x, a.mk_int(N)); + return amod(bv_expr, x, N); } expr* solver::smod(expr* bv_expr, unsigned i) { @@ -492,7 +499,7 @@ namespace intblast { rational r; if (a.is_numeral(x, r)) return a.mk_int(mod(r + shift, N)); - return a.mk_mod(add(x, a.mk_int(shift)), a.mk_int(N)); + return amod(bv_expr, add(x, a.mk_int(shift)), N); } expr_ref solver::mul(expr* x, expr* y) { @@ -505,6 +512,9 @@ namespace intblast { return _y; if (a.is_one(y)) return _x; + rational v1, v2; + if (a.is_numeral(x, v1) && a.is_numeral(y, v2)) + return expr_ref(a.mk_int(v1 * v2), m); _x = a.mk_mul(x, y); return _x; } @@ -515,10 +525,37 @@ namespace intblast { return _y; if (a.is_zero(y)) return _x; + rational v1, v2; + if (a.is_numeral(x, v1) && a.is_numeral(y, v2)) + return expr_ref(a.mk_int(v1 + v2), m); _x = a.mk_add(x, y); return _x; } + /* + * Perform simplifications that are claimed sound when the bit-vector interpretations of + * mod/div always guard the mod and dividend to be non-zero. + * Potentially shady area is for arithmetic expressions created by int2bv. + * They will be guarded by a modulus which dose not disappear. + */ + expr* solver::amod(expr* bv_expr, expr* x, rational const& N) { + rational v; + expr* r, *c, * t, * e; + if (m.is_ite(x, c, t, e)) + r = m.mk_ite(c, amod(bv_expr, t, N), amod(bv_expr, e, N)); + else if (a.is_idiv(x, t, e) && a.is_numeral(t, v) && 0 <= v && v < N && is_non_negative(bv_expr, e)) + r = x; + else if (a.is_mod(x, t, e) && a.is_numeral(t, v) && 0 <= v && v < N) + r = x; + else if (a.is_numeral(x, v)) + r = a.mk_int(mod(v, N)); + else if (is_bounded(x, N)) + r = x; + else + r = a.mk_mod(x, a.mk_int(N)); + return r; + } + rational solver::bv_size(expr* bv_expr) { return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); } @@ -649,7 +686,7 @@ namespace intblast { auto A = rational::power_of_two(sz - n); auto B = rational::power_of_two(n); auto hi = mul(r, a.mk_int(A)); - auto lo = a.mk_mod(a.mk_idiv(umod(e, 0), a.mk_int(B)), a.mk_int(A)); + auto lo = amod(e, a.mk_idiv(umod(e, 0), a.mk_int(B)), A); r = add(hi, lo); } return r; diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index 3b155d027..0aceb8b2b 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -74,8 +74,10 @@ namespace intblast { expr* umod(expr* bv_expr, unsigned i); expr* smod(expr* bv_expr, unsigned i); bool is_bounded(expr* v, rational const& N); + bool is_non_negative(expr* bv_expr, expr* e); expr_ref mul(expr* x, expr* y); expr_ref add(expr* x, expr* y); + expr* amod(expr* bv_expr, expr* x, rational const& N); rational bv_size(expr* bv_expr); void translate_expr(expr* e); diff --git a/src/sat/smt/polysat/fixed_bits.cpp b/src/sat/smt/polysat/fixed_bits.cpp index 52cb71480..c275f6a1b 100644 --- a/src/sat/smt/polysat/fixed_bits.cpp +++ b/src/sat/smt/polysat/fixed_bits.cpp @@ -20,96 +20,41 @@ namespace polysat { void fixed_bits::reset() { m_fixed_slices.reset(); m_var = null_var; - m_fixed.reset(); - m_bits.reset(); } // reset with fixed bits information for variable v - void fixed_bits::reset(pvar v) { + void fixed_bits::init(pvar v) { m_fixed_slices.reset(); m_var = v; - m_fixed.reset(); - m_fixed.resize(c.size(v), l_undef); - m_bits.reserve(c.size(v)); - fixed_bits_vector fbs; - c.get_fixed_bits(v, fbs); - for (auto const& fb : fbs) - for (unsigned i = fb.lo; i <= fb.hi; ++i) - m_fixed[i] = to_lbool(fb.value.get_bit(i - fb.lo)); + c.get_fixed_bits(v, m_fixed_slices); } - // find then next value >= val that agrees with fixed bits, or false if none exists within the maximal value for val. - // examples - // fixed bits: 1?0 (least significant bit is last) - // val: 101 - // next: 110 - - // fixed bits ?1?0 - // val 1011 - // next 1100 - - // algorithm: Let i be the most significant index where fixed bits disagree with val. - // Set non-fixed values below i to 0. - // If m_fixed[i] == l_true; then updating val to mask by fixed bits sufficies. - // Otherwise, the range above the disagreement has to be incremented. - // Increment the non-fixed bits by 1 - // The first non-fixed 0 position is set to 1, non-fixed positions below are set to 0. - // If there are none, then the value is maximal and we return false. - - bool fixed_bits::next(rational& val) { - if (m_fixed_slices.empty()) - return true; - unsigned sz = c.size(m_var); - for (unsigned i = 0; i < sz; ++i) - m_bits[i] = val.get_bit(i); - unsigned i = sz; - for (; i-- > 0; ) - if (m_fixed[i] != l_undef && m_fixed[i] != to_lbool(m_bits[i])) - break; - if (i == 0) - return true; - - for (unsigned j = 0; j < sz; ++j) { - if (m_fixed[j] != l_undef) - m_bits[j] = m_fixed[j] == l_true; - else if (j < i) - m_bits[j] = false; - } - - if (m_fixed[i] == l_false) { - for (; i < sz; ++i) { - if (m_fixed[i] != l_undef) - continue; - if (m_bits[i]) - m_bits[i] = false; - else { - m_bits[i] = true; - break; - } - } - - - CTRACE("bv", i == sz, display(tout << "overflow\n")); - // overflow - if (i == sz) + // if x[hi:lo] = value, then + // 2^(w-hi+1)* x >= + bool fixed_bits::check(rational const& val, fi_record& fi) { + for (auto const& s : m_fixed_slices) { + rational bw = rational::power_of_two(s.hi - s.lo + 1); + if (s.value != mod(machine_div2k(val, s.lo + 1), bw)) { + rational hi_val = s.value; + rational lo_val = mod(s.value + 1, bw); + unsigned sz = c.size(m_var); + pdd lo = c.value(rational::power_of_two(sz - s.hi - 1) * lo_val, c.size(m_var)); + pdd hi = c.value(rational::power_of_two(sz - s.hi - 1) * hi_val, c.size(m_var)); + fi.reset(); + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + fi.deps.push_back(dependency({ m_var, s })); + fi.bit_width = s.hi - s.lo + 1; + fi.coeff = 1; return false; + } } - val = 0; - for (unsigned i = sz; i-- > 0;) - val = val * 2 + rational(m_bits[i]); return true; } - // explain the fixed bits ranges. - dependency_vector fixed_bits::explain() { - dependency_vector result; - for (auto const& slice : m_fixed_slices) - result.push_back(dependency({ m_var, slice })); - return result; - } - std::ostream& fixed_bits::display(std::ostream& out) const { - return out << "fixed bits: v" << m_var << " " << m_fixed << "\n"; + for (auto const& s : m_fixed_slices) + out << s.hi << " " << s.lo << " " << s.value << "\n"; + return out; } /** diff --git a/src/sat/smt/polysat/fixed_bits.h b/src/sat/smt/polysat/fixed_bits.h index 551557e7e..0759dbc88 100644 --- a/src/sat/smt/polysat/fixed_bits.h +++ b/src/sat/smt/polysat/fixed_bits.h @@ -13,6 +13,7 @@ Author: #pragma once #include "sat/smt/polysat/types.h" #include "sat/smt/polysat/constraints.h" +#include "sat/smt/polysat/forbidden_intervals.h" #include "util/vector.h" namespace polysat { @@ -31,9 +32,7 @@ namespace polysat { class fixed_bits { core& c; pvar m_var = null_var; - vector m_fixed_slices; - svector m_fixed; - bool_vector m_bits; + fixed_bits_vector m_fixed_slices; public: fixed_bits(core& c) : c(c) {} @@ -41,13 +40,9 @@ namespace polysat { void reset(); // reset with fixed bits information for variable v - void reset(pvar v); + void init(pvar v); - // find then next value >= val that agrees with fixed bits, or false if none exists within the maximal value for val. - bool next(rational& val); - - // explain the fixed bits ranges. - dependency_vector explain(); + bool check(rational const& val, fi_record& fi); std::ostream& display(std::ostream& out) const; }; diff --git a/src/sat/smt/polysat/forbidden_intervals.h b/src/sat/smt/polysat/forbidden_intervals.h index b790da1c8..cab3f737f 100644 --- a/src/sat/smt/polysat/forbidden_intervals.h +++ b/src/sat/smt/polysat/forbidden_intervals.h @@ -26,6 +26,7 @@ namespace polysat { eval_interval interval; vector side_cond; vector src; // only units may have multiple src (as they can consist of contracted bit constraints) + vector deps; rational coeff; unsigned bit_width = 0; // number of lower bits; TODO: should move this to viable::entry; where the coeff/bit-width is adapted accordingly @@ -37,6 +38,7 @@ namespace polysat { side_cond.reset(); src.reset(); coeff.reset(); + deps.reset(); bit_width = 0; } diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index af83b2839..d02d01fb1 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -116,7 +116,7 @@ namespace polysat { return out << "v" << js.v << " at offset " << js.offset; } - using fixed_bits_vector = svector; + using fixed_bits_vector = vector; using dependency_vector = vector; using constraint_or_dependency = std::variant; diff --git a/src/sat/smt/polysat/viable.cpp b/src/sat/smt/polysat/viable.cpp index 340f69c29..f6cfdaccf 100644 --- a/src/sat/smt/polysat/viable.cpp +++ b/src/sat/smt/polysat/viable.cpp @@ -87,7 +87,7 @@ namespace polysat { m_explain.reset(); m_var = v; m_num_bits = c.size(v); - m_fixed_bits.reset(v); + m_fixed_bits.init(v); init_overlaps(v); m_conflict = false; @@ -101,8 +101,9 @@ namespace polysat { return find_t::empty; if (!n) { - if (refine_disequal_lin(v, lo) && - refine_equal_lin(v, lo)) + if (check_fixed_bits(v, lo) && + check_disequal_lin(v, lo) && + check_equal_lin(v, lo)) return find_t::multiple; ++rounds; } @@ -128,11 +129,6 @@ namespace polysat { // viable::entry* viable::find_overlap(rational& val) { - // disable fixed-bits until added to explanation trail. - if (false && !m_fixed_bits.next(val)) { - val = 0; - VERIFY(m_fixed_bits.next(val)); - } entry* last = nullptr; for (auto const& [w, offset] : m_overlaps) { @@ -217,7 +213,7 @@ namespace polysat { return nullptr; } - bool viable::refine_equal_lin(pvar v, rational const& val) { + bool viable::check_equal_lin(pvar v, rational const& val) { // LOG_H2("refine-equal-lin with v" << v << ", val = " << val); entry const* e = m_equal_lin[v]; if (!e) @@ -357,7 +353,22 @@ namespace polysat { return true; } - bool viable::refine_disequal_lin(pvar v, rational const& val) { + bool viable::check_fixed_bits(pvar v, rational const& val) { + // disable fixed bits for now + return true; + + auto e = alloc_entry(v, constraint_id::null()); + if (m_fixed_bits.check(val, *e)) { + m_alloc.push_back(e); + return true; + } + else { + intersect(v, e); + return false; + } + } + + bool viable::check_disequal_lin(pvar v, rational const& val) { // LOG_H2("refine-disequal-lin with v" << v << ", val = " << val); entry const* e = m_diseq_lin[v]; if (!e) @@ -490,14 +501,13 @@ namespace polysat { TRACE("bv", display_explain(tout)); - result.append(m_fixed_bits.explain()); - if (last.e->interval.is_full()) { if (m_var != last.e->var) result.push_back(offset_claim(m_var, { last.e->var, 0 })); for (auto const& sc : last.e->side_cond) result.push_back(c.propagate(sc, c.explain_weak_eval(sc))); - result.push_back(c.get_dependency(last.e->constraint_index)); + if (!last.e->constraint_index.is_null()) + result.push_back(c.get_dependency(last.e->constraint_index)); SASSERT(m_explain.size() == 1); } @@ -506,14 +516,17 @@ namespace polysat { auto index = e.e->constraint_index; explain_overlap(e, after, result); after = e; - if (seen.contains(index.id)) + if (!index.is_null() && seen.contains(index.id)) continue; - seen.insert(index.id); + if (!index.is_null()) + seen.insert(index.id); if (m_var != e.e->var) - result.push_back(offset_claim(m_var, { e.e->var, 0 })); + result.push_back(offset_claim(m_var, { e.e->var, 0 })); for (auto const& sc : e.e->side_cond) result.push_back(c.propagate(sc, c.explain_weak_eval(sc))); - result.push_back(c.get_dependency(index)); + result.append(e.e->deps); + if (!index.is_null()) + result.push_back(c.get_dependency(index)); if (e.e == last.e) break; } diff --git a/src/sat/smt/polysat/viable.h b/src/sat/smt/polysat/viable.h index b933a2743..1a0fd1706 100644 --- a/src/sat/smt/polysat/viable.h +++ b/src/sat/smt/polysat/viable.h @@ -128,9 +128,11 @@ namespace polysat { viable::entry* find_overlap(rational const& val, entry* entries); - bool refine_disequal_lin(pvar v, rational const& val); + bool check_disequal_lin(pvar v, rational const& val); - bool refine_equal_lin(pvar v, rational const& val); + bool check_equal_lin(pvar v, rational const& val); + + bool check_fixed_bits(pvar v, rational const& val); diff --git a/src/sat/smt/polysat_model.cpp b/src/sat/smt/polysat_model.cpp index 23bea0045..f630a47cc 100644 --- a/src/sat/smt/polysat_model.cpp +++ b/src/sat/smt/polysat_model.cpp @@ -75,6 +75,9 @@ namespace polysat { void solver::collect_statistics(statistics& st) const { m_intblast.collect_statistics(st); m_core.collect_statistics(st); + st.update("polysat-conflicts", m_stats.m_num_conflicts); + st.update("polysat-axioms", m_stats.m_num_axioms); + st.update("polysat-propagations", m_stats.m_num_propagations); } std::ostream& solver::display_justification(std::ostream& out, sat::ext_justification_idx idx) const { diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index c43ea8e8c..9f3873d42 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -119,6 +119,7 @@ namespace polysat { } void solver::set_conflict(dependency_vector const& deps, char const* hint_info) { + ++m_stats.m_num_conflicts; if (inconsistent()) return; auto [lits, eqs] = explain_deps(deps); @@ -127,8 +128,8 @@ namespace polysat { hint = mk_proof_hint(hint_info, lits, eqs); auto ex = euf::th_explain::conflict(*this, lits, eqs, hint); TRACE("bv", tout << "conflict: " << lits << " "; - for (auto [a, b] : eqs) tout << ctx.bpp(a) << " == " << ctx.bpp(b) << " "; - tout << "\n"; s().display(tout)); + for (auto [a, b] : eqs) tout << ctx.bpp(a) << " == " << ctx.bpp(b) << " "; + tout << "\n"; s().display(tout)); validate_conflict(lits, eqs); ctx.set_conflict(ex); } @@ -181,8 +182,7 @@ namespace polysat { for (auto const& [n1, n2] : eqs) SASSERT(n1->get_root() == n2->get_root()); }); - - + return { core, eqs }; } @@ -248,6 +248,7 @@ namespace polysat { // Everything goes over expressions/literals. polysat::core is not responsible for replaying expressions. dependency solver::propagate(signed_constraint sc, dependency_vector const& deps, char const* hint_info) { + ++m_stats.m_num_propagations; TRACE("bv", sc.display(tout << "propagate ") << "\n"); sat::literal lit = ctx.mk_literal(constraint2expr(sc)); if (s().value(lit) == l_true) @@ -280,6 +281,7 @@ namespace polysat { } void solver::propagate(dependency const& d, bool sign, dependency_vector const& deps, char const* hint_info) { + ++m_stats.m_num_propagations; TRACE("bv", tout << "propagate " << d << " " << sign << "\n"); auto [core, eqs] = explain_deps(deps); SASSERT(d.is_bool_var() || d.is_eq()); @@ -322,6 +324,7 @@ namespace polysat { } bool solver::add_axiom(char const* name, constraint_or_dependency const* begin, constraint_or_dependency const* end, bool is_redundant) { + ++m_stats.m_num_axioms; if (inconsistent()) return false; TRACE("bv", tout << "add " << name << "\n"); @@ -357,6 +360,7 @@ namespace polysat { } void solver::add_axiom(char const* name, std::initializer_list const& clause) { + ++m_stats.m_num_axioms; bool is_redundant = false; sat::literal_vector lits; proof_hint* hint = nullptr; diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index adbd03ac9..2df1a11ce 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -74,6 +74,9 @@ namespace polysat { struct stats { void reset() { memset(this, 0, sizeof(stats)); } + unsigned m_num_conflicts; + unsigned m_num_propagations; + unsigned m_num_axioms; stats() { reset(); } };