diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index a87d982f4..24c34a0a4 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -39,9 +39,12 @@ z3_add_component(sat_smt polysat_constraints.cpp polysat_core.cpp polysat_internalize.cpp + polysat_fi.cpp polysat_model.cpp polysat_solver.cpp polysat_ule.cpp + polysat_umul_ovfl.cpp + polysat_viable.cpp q_clause.cpp q_ematch.cpp q_eval.cpp diff --git a/src/sat/smt/polysat_constraints.cpp b/src/sat/smt/polysat_constraints.cpp index 101d60e62..a03b4f5f5 100644 --- a/src/sat/smt/polysat_constraints.cpp +++ b/src/sat/smt/polysat_constraints.cpp @@ -28,4 +28,14 @@ namespace polysat { auto sc = signed_constraint(ckind_t::ule_t, c); return is_positive ? sc : ~sc; } + + lbool signed_constraint::eval(assignment& a) const { + lbool r = m_constraint->eval(a); + return m_sign ? ~r : r; + } + + std::ostream& signed_constraint::display(std::ostream& out) const { + if (m_sign) out << "~"; + return out << *m_constraint; + } } diff --git a/src/sat/smt/polysat_constraints.h b/src/sat/smt/polysat_constraints.h index da82431c4..121fc2da6 100644 --- a/src/sat/smt/polysat_constraints.h +++ b/src/sat/smt/polysat_constraints.h @@ -21,6 +21,7 @@ namespace polysat { class core; class ule_constraint; + class umul_ovfl_constraint; class assignment; using pdd = dd::pdd; @@ -42,14 +43,8 @@ namespace polysat { virtual lbool eval(assignment const& a) const = 0; }; + inline std::ostream& operator<<(std::ostream& out, constraint const& c) { return c.display(out); } - class umul_ovfl_constraint : public constraint { - pdd m_lhs, m_rhs; - public: - umul_ovfl_constraint(pdd const& lhs, pdd const& rhs) : m_lhs(lhs), m_rhs(rhs) {} - pdd const& lhs() const { return m_lhs; } - pdd const& rhs() const { return m_rhs; } - }; class signed_constraint { bool m_sign = false; @@ -60,10 +55,13 @@ namespace polysat { signed_constraint(ckind_t c, constraint* p) : m_op(c), m_constraint(p) {} signed_constraint operator~() const { signed_constraint r(*this); r.m_sign = !r.m_sign; return r; } bool sign() const { return m_sign; } + bool is_positive() const { return !m_sign; } + bool is_negative() const { return m_sign; } unsigned_vector& vars() { return m_constraint->vars(); } unsigned_vector const& vars() const { return m_constraint->vars(); } unsigned var(unsigned idx) const { return m_constraint->var(idx); } bool contains_var(pvar v) const { return m_constraint->contains_var(v); } + lbool eval(assignment& a) const; ckind_t op() const { return m_op; } bool is_ule() const { return m_op == ule_t; } bool is_umul_ovfl() const { return m_op == umul_ovfl_t; } @@ -71,8 +69,11 @@ namespace polysat { ule_constraint const& to_ule() const { return *reinterpret_cast(m_constraint); } umul_ovfl_constraint const& to_umul_ovfl() const { return *reinterpret_cast(m_constraint); } bool is_eq(pvar& v, rational& val) { throw default_exception("nyi"); } + std::ostream& display(std::ostream& out) const; }; + inline std::ostream& operator<<(std::ostream& out, signed_constraint const& c) { return c.display(out); } + class constraints { trail_stack& m_trail; public: diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp index 73e9c5a6b..3de88d93b 100644 --- a/src/sat/smt/polysat_core.cpp +++ b/src/sat/smt/polysat_core.cpp @@ -262,8 +262,10 @@ namespace polysat { } // if sc is v == value, then check the watch list for v to propagate truth assignments if (sc.is_eq(m_var, m_value)) { - for (auto idx : m_watch[m_var]) { - auto [sc, d] = m_constraint_trail[idx]; + for (auto idx1 : m_watch[m_var]) { + if (idx == idx1) + continue; + auto [sc, d] = m_constraint_trail[idx1]; switch (eval(sc)) { case l_false: s.propagate(d, true, explain_eval(sc)); @@ -299,7 +301,11 @@ namespace polysat { } lbool core::eval(signed_constraint const& sc) { - throw default_exception("nyi"); + return sc.eval(m_assignment); + } + + pdd core::subst(pdd const& p) { + return m_assignment.apply_to(p); } } diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h index 6944c39d8..3c8a79bd6 100644 --- a/src/sat/smt/polysat_core.h +++ b/src/sat/smt/polysat_core.h @@ -70,7 +70,7 @@ namespace polysat { dd::pdd_manager& sz2pdd(unsigned sz) const; dd::pdd_manager& var2pdd(pvar v) const; - unsigned size(pvar v) const { return var2pdd(v).power_of_2(); } + void del_var(); bool is_assigned(pvar v) { return !m_justification[v].is_null(); } @@ -96,6 +96,7 @@ namespace polysat { void assign_eh(unsigned idx, bool sign, dependency const& d); pdd value(rational const& v, unsigned sz); + pdd subst(pdd const&); signed_constraint eq(pdd const& p) { return m_constraints.eq(p); } signed_constraint eq(pdd const& p, pdd const& q) { return m_constraints.eq(p - q); } @@ -124,6 +125,9 @@ namespace polysat { pdd concat(unsigned n, pdd const* args) { throw default_exception("nyi"); } pvar add_var(unsigned sz); pdd var(pvar p) { return m_vars[p]; } + unsigned size(pvar v) const { return var2pdd(v).power_of_2(); } + + constraints& cs() { return m_constraints; } std::ostream& display(std::ostream& out) const { throw default_exception("nyi"); } }; diff --git a/src/sat/smt/polysat_fi.cpp b/src/sat/smt/polysat_fi.cpp new file mode 100644 index 000000000..349243ed8 --- /dev/null +++ b/src/sat/smt/polysat_fi.cpp @@ -0,0 +1,588 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + Conflict explanation using forbidden intervals as described in + "Solving bitvectors with MCSAT: explanations from bits and pieces" + by S. Graham-Lengrand, D. Jovanovic, B. Dutertre. + +Author: + + Jakob Rath 2021-04-06 + Nikolaj Bjorner (nbjorner) 2021-03-19 + +--*/ +#include "sat/smt/polysat_fi.h" +#include "sat/smt/polysat_interval.h" +#include "sat/smt/polysat_umul_ovfl.h" +#include "sat/smt/polysat_ule.h" +#include "sat/smt/polysat_core.h" + +namespace polysat { + + /** + * + * \param[in] c Original constraint + * \param[in] v Variable that is bounded by constraint + * \param[out] fi "forbidden interval" record that captures values not allowed for v + * \returns True iff a forbidden interval exists and the output parameters were set. + */ + bool forbidden_intervals::get_interval(signed_constraint const& c, pvar v, fi_record& fi) { + // verbose_stream() << "get_interval for v" << v << " " << c << "\n"; + SASSERT(fi.side_cond.empty()); + SASSERT(fi.src.empty()); + fi.bit_width = s.size(v); // TODO: preliminary + if (c.is_ule()) + return get_interval_ule(c, v, fi); + if (c.is_umul_ovfl()) + return get_interval_umul_ovfl(c, v, fi); + return false; + } + + bool forbidden_intervals::get_interval_umul_ovfl(signed_constraint const& c, pvar v, fi_record& fi) { + using std::swap; + + backtrack _backtrack(fi.side_cond); + + fi.coeff = 1; + fi.src.push_back(c); + + // eval(lhs) = a1*v + eval(e1) = a1*v + b1 + // eval(rhs) = a2*v + eval(e2) = a2*v + b2 + // We keep the e1, e2 around in case we need side conditions such as e1=b1, e2=b2. + auto [ok1, a1, e1, b1] = linear_decompose(v, c.to_umul_ovfl().p(), fi.side_cond); + auto [ok2, a2, e2, b2] = linear_decompose(v, c.to_umul_ovfl().q(), fi.side_cond); + + auto& m = e1.manager(); + rational bound = m.max_value(); + + if (ok2 && !ok1) { + swap(a1, a2); + swap(e1, e2); + swap(b1, b2); + swap(ok1, ok2); + } + if (ok1 && !ok2 && a1.is_one() && b1.is_zero()) { + if (c.is_positive()) { + _backtrack.released = true; + rational lo_val(0); + rational hi_val(2); + pdd lo = m.mk_val(lo_val); + pdd hi = m.mk_val(hi_val); + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + return true; + } + } + + if (!ok1 || !ok2) + return false; + + + if (a2.is_one() && a1.is_zero()) { + swap(a1, a2); + swap(e1, e2); + swap(b1, b2); + } + + if (!a1.is_one() || !a2.is_zero()) + return false; + + if (!b1.is_zero()) + return false; + + _backtrack.released = true; + + // Ovfl(v, e2) + + + if (c.is_positive()) { + if (b2.val() <= 1) { + fi.interval = eval_interval::full(); + fi.side_cond.push_back(s.cs().ule(e2, 1)); + } + else { + // [0, div(bound, b2.val()) + 1[ + rational lo_val(0); + rational hi_val(div(bound, b2.val()) + 1); + pdd lo = m.mk_val(lo_val); + pdd hi = m.mk_val(hi_val); + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + fi.side_cond.push_back(s.cs().ule(e2, b2.val())); + } + + } + else { + if (b2.val() <= 1) { + _backtrack.released = false; + return false; + } + else { + // [div(bound, b2.val()) + 1, 0[ + rational lo_val(div(bound, b2.val()) + 1); + rational hi_val(0); + pdd lo = m.mk_val(lo_val); + pdd hi = m.mk_val(hi_val); + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + fi.side_cond.push_back(s.cs().ule(b2.val(), e2)); + } + } + + // LOG("overflow interval " << fi.interval); + + return true; + } + + static char const* _last_function = ""; + + bool forbidden_intervals::get_interval_ule(signed_constraint const& c, pvar v, fi_record& fi) { + + backtrack _backtrack(fi.side_cond); + + fi.coeff = 1; + fi.src.push_back(c); + + struct show { + forbidden_intervals& f; + signed_constraint const& c; + pvar v; + fi_record& fi; + backtrack& _backtrack; + show(forbidden_intervals& f, + signed_constraint const& c, + pvar v, + fi_record& fi, + backtrack& _backtrack):f(f), c(c), v(v), fi(fi), _backtrack(_backtrack) {} + ~show() { + if (!_backtrack.released) + return; + IF_VERBOSE(0, verbose_stream() << _last_function << " " << v << " " << c << " " << fi.interval << " " << fi.side_cond << "\n"); + } + }; + // uncomment to trace intervals + // show _show(*this, c, v, fi, _backtrack); + + // eval(lhs) = a1*v + eval(e1) = a1*v + b1 + // eval(rhs) = a2*v + eval(e2) = a2*v + b2 + // We keep the e1, e2 around in case we need side conditions such as e1=b1, e2=b2. + auto [ok1, a1, e1, b1] = linear_decompose(v, c.to_ule().lhs(), fi.side_cond); + auto [ok2, a2, e2, b2] = linear_decompose(v, c.to_ule().rhs(), fi.side_cond); + + _backtrack.released = true; + + // v > q + if (false && ok1 && !ok2 && match_non_zero(c, a1, b1, e1, c.to_ule().rhs(), fi)) + return true; + + // p > v + if (false && !ok1 && ok2 && match_non_max(c, c.to_ule().lhs(), a2, b2, e2, fi)) + return true; + + if (!ok1 || !ok2 || (a1.is_zero() && a2.is_zero())) { + _backtrack.released = false; + return false; + } + SASSERT(b1.is_val()); + SASSERT(b2.is_val()); + + // a*v + b <= 0, a odd + // a*v + b > 0, a odd + if (match_zero(c, a1, b1, e1, a2, b2, e2, fi)) + return true; + + // -1 <= a*v + b, a odd + // -1 > a*v + b, a odd + if (match_max(c, a1, b1, e1, a2, b2, e2, fi)) + return true; + + if (match_linear1(c, a1, b1, e1, a2, b2, e2, fi)) + return true; + if (match_linear2(c, a1, b1, e1, a2, b2, e2, fi)) + return true; + if (match_linear3(c, a1, b1, e1, a2, b2, e2, fi)) + return true; + if (match_linear4(c, a1, b1, e1, a2, b2, e2, fi)) + return true; + + _backtrack.released = false; + return false; + } + + void forbidden_intervals::push_eq(bool is_zero, pdd const& p, vector& side_cond) { + SASSERT(!p.is_val() || (is_zero == p.is_zero())); + if (p.is_val()) + return; + else if (is_zero) + side_cond.push_back(s.eq(p)); + else + side_cond.push_back(~s.eq(p)); + } + + std::tuple forbidden_intervals::linear_decompose(pvar v, pdd const& p, vector& out_side_cond) { + auto& m = p.manager(); + pdd q = m.zero(); + pdd e = m.zero(); + unsigned const deg = p.degree(v); + if (deg == 0) + // p = 0*v + e + e = p; + else if (deg == 1) + // p = q*v + e + p.factor(v, 1, q, e); + else + return std::tuple(false, rational(0), q, e); + + // r := eval(q) + // Add side constraint q = r. + if (!q.is_val()) { + pdd r = s.subst(q); + + + if (!r.is_val()) + return std::tuple(false, rational(0), q, e); + out_side_cond.push_back(s.eq(q, r)); + q = r; + } + auto b = s.subst(e); + return std::tuple(b.is_val(), q.val(), e, b); + }; + + eval_interval forbidden_intervals::to_interval( + signed_constraint const& c, bool is_trivial, rational & coeff, + rational & lo_val, pdd & lo, + rational & hi_val, pdd & hi) { + + dd::pdd_manager& m = lo.manager(); + + if (is_trivial) { + if (c.is_positive()) + // TODO: we cannot use empty intervals for interpolation. So we + // can remove the empty case (make it represent 'full' instead), + // and return 'false' here. Then we do not need the proper/full + // tag on intervals. + return eval_interval::empty(m); + else + return eval_interval::full(); + } + + rational pow2 = m.two_to_N(); + + if (coeff > pow2/2) { + // TODO: if coeff != pow2 - 1, isn't this counterproductive now? considering the gap condition on refine-equal-lin acceleration. + + coeff = pow2 - coeff; + SASSERT(coeff > 0); + // Transform according to: y \in [l;u[ <=> -y \in [1-u;1-l[ + // -y \in [1-u;1-l[ + // <=> -y - (1 - u) < (1 - l) - (1 - u) { by: y \in [l;u[ <=> y - l < u - l } + // <=> u - y - 1 < u - l { simplified } + // <=> (u-l) - (u-y-1) - 1 < u-l { by: a < b <=> b - a - 1 < b } + // <=> y - l < u - l { simplified } + // <=> y \in [l;u[. + lo = 1 - lo; + hi = 1 - hi; + swap(lo, hi); + lo_val = mod(1 - lo_val, pow2); + hi_val = mod(1 - hi_val, pow2); + lo_val.swap(hi_val); + } + + if (c.is_positive()) + return eval_interval::proper(lo, lo_val, hi, hi_val); + else + return eval_interval::proper(hi, hi_val, lo, lo_val); + } + + /** + * Match e1 + t <= e2, with t = a1*y + * condition for empty/full: e2 == -1 + */ + bool forbidden_intervals::match_linear1(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + if (a2.is_zero() && !a1.is_zero()) { + SASSERT(!a1.is_zero()); + bool is_trivial = (b2 + 1).is_zero(); + push_eq(is_trivial, e2 + 1, fi.side_cond); + auto lo = e2 - e1 + 1; + rational lo_val = (b2 - b1 + 1).val(); + auto hi = -e1; + rational hi_val = (-b1).val(); + fi.coeff = a1; + fi.interval = to_interval(c, is_trivial, fi.coeff, lo_val, lo, hi_val, hi); + add_non_unit_side_conds(fi, b1, e1, b2, e2); + return true; + } + return false; + } + + /** + * e1 <= e2 + t, with t = a2*y + * condition for empty/full: e1 == 0 + */ + bool forbidden_intervals::match_linear2(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + if (a1.is_zero() && !a2.is_zero()) { + SASSERT(!a2.is_zero()); + bool is_trivial = b1.is_zero(); + push_eq(is_trivial, e1, fi.side_cond); + auto lo = -e2; + rational lo_val = (-b2).val(); + auto hi = e1 - e2; + rational hi_val = (b1 - b2).val(); + fi.coeff = a2; + fi.interval = to_interval(c, is_trivial, fi.coeff, lo_val, lo, hi_val, hi); + add_non_unit_side_conds(fi, b1, e1, b2, e2); + return true; + } + return false; + } + + /** + * e1 + t <= e2 + t, with t = a1*y = a2*y + * condition for empty/full: e1 == e2 + */ + bool forbidden_intervals::match_linear3(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + if (a1 == a2 && !a1.is_zero()) { + bool is_trivial = b1.val() == b2.val(); + push_eq(is_trivial, e1 - e2, fi.side_cond); + auto lo = -e2; + rational lo_val = (-b2).val(); + auto hi = -e1; + rational hi_val = (-b1).val(); + fi.coeff = a1; + fi.interval = to_interval(c, is_trivial, fi.coeff, lo_val, lo, hi_val, hi); + add_non_unit_side_conds(fi, b1, e1, b2, e2); + return true; + } + return false; + } + + /** + * e1 + t <= e2 + t', with t = a1*y, t' = a2*y, a1 != a2, a1, a2 non-zero + */ + bool forbidden_intervals::match_linear4(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + if (a1 != a2 && !a1.is_zero() && !a2.is_zero()) { + // NOTE: we don't have an interval here in the same sense as in the other cases. + // We use the interval to smuggle out the values a1,b1,a2,b2 without adding additional fields. + // to_interval flips a1,b1 with a2,b2 for negative constraints, which we also need for this case. + auto lo = b1; + rational lo_val = a1; + auto hi = b2; + rational hi_val = a2; + // We use fi.coeff = -1 to tell the caller to treat it as a diseq_lin. + fi.coeff = -1; + fi.interval = to_interval(c, false, fi.coeff, lo_val, lo, hi_val, hi); + add_non_unit_side_conds(fi, b1, e1, b2, e2); + SASSERT(!fi.interval.is_currently_empty()); + return true; + } + return false; + } + + /** + * a*v <= 0, a odd + * forbidden interval for v is [1;0[ + * + * a*v + b <= 0, a odd + * forbidden interval for v is [n+1;n[ where n = -b * a^-1 + * + * TODO: extend to + * 2^k*a*v <= 0, a odd + * (using intervals for the lower bits of v) + */ + bool forbidden_intervals::match_zero( + signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + if (a1.is_odd() && a2.is_zero() && b2.is_zero()) { + auto& m = e1.manager(); + rational const& mod_value = m.two_to_N(); + rational a1_inv; + VERIFY(a1.mult_inverse(m.power_of_2(), a1_inv)); + + // interval for a*v + b > 0 is [n;n+1[ where n = -b * a^-1 + rational lo_val = mod(-b1.val() * a1_inv, mod_value); + pdd lo = -e1 * a1_inv; + rational hi_val = mod(lo_val + 1, mod_value); + pdd hi = lo + 1; + + // interval for a*v + b <= 0 is the complement + if (c.is_positive()) { + std::swap(lo_val, hi_val); + std::swap(lo, hi); + } + + fi.coeff = 1; + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + // RHS == 0 is a precondition because we can only multiply with a^-1 in equations, not inequalities + if (b2 != e2) + fi.side_cond.push_back(s.eq(b2, e2)); + return true; + } + return false; + } + + /** + * -1 <= a*v + b, a odd + * forbidden interval for v is [n+1;n[ where n = (-b-1) * a^-1 + */ + bool forbidden_intervals::match_max( + signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + if (a1.is_zero() && b1.is_max() && a2.is_odd()) { + auto& m = e2.manager(); + rational const& mod_value = m.two_to_N(); + rational a2_inv; + VERIFY(a2.mult_inverse(m.power_of_2(), a2_inv)); + + // interval for -1 > a*v + b is [n;n+1[ where n = (-b-1) * a^-1 + rational lo_val = mod((-1 - b2.val()) * a2_inv, mod_value); + pdd lo = (-1 - e2) * a2_inv; + rational hi_val = mod(lo_val + 1, mod_value); + pdd hi = lo + 1; + + // interval for -1 <= a*v + b is the complement + if (c.is_positive()) { + std::swap(lo_val, hi_val); + std::swap(lo, hi); + } + + fi.coeff = 1; + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + // LHS == -1 is a precondition because we can only multiply with a^-1 in equations, not inequalities + if (b1 != e1) + fi.side_cond.push_back(s.eq(b1, e1)); + return true; + } + return false; + } + + /** + * v > q + * forbidden interval for v is [0,1[ + * + * v - k > q + * forbidden interval for v is [k,k+1[ + * + * v > q + * forbidden interval for v is [0;q+1[ but at least [0;1[ + * + * The following cases are implemented, and subsume the simple ones above. + * + * v - k > q + * forbidden interval for v is [k;k+q+1[ but at least [k;k+1[ + * + * a*v - k > q, a odd + * forbidden interval for v is [a^-1*k, a^-1*k + 1[ + */ + bool forbidden_intervals::match_non_zero( + signed_constraint const& c, + rational const& a1, pdd const& b1, pdd const& e1, + pdd const& q, + fi_record& fi) { + _last_function = __func__; + SASSERT(b1.is_val()); + if (a1.is_one() && c.is_negative()) { + // v - k > q + auto& m = e1.manager(); + rational const& mod_value = m.two_to_N(); + rational lo_val = (-b1).val(); + pdd lo = -e1; + rational hi_val = mod(lo_val + 1, mod_value); + pdd hi = lo + q + 1; + fi.coeff = 1; + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + return true; + } + if (a1.is_odd() && c.is_negative()) { + // a*v - k > q, a odd + auto& m = e1.manager(); + rational const& mod_value = m.two_to_N(); + rational a1_inv; + VERIFY(a1.mult_inverse(m.power_of_2(), a1_inv)); + rational lo_val(mod(-b1.val() * a1_inv, mod_value)); + auto lo = -e1 * a1_inv; + rational hi_val(mod(lo_val + 1, mod_value)); + auto hi = lo + 1; + fi.coeff = 1; + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + return true; + } + return false; + } + + /** + * p > v + * forbidden interval for v is [p;0[ but at least [-1,0[ + * + * p > v + k + * forbidden interval for v is [p-k;-k[ but at least [-1-k,-k[ + * + * p > a*v + k, a odd + * forbidden interval for v is [ a^-1*(-1-k) ; a^-1*(-1-k) + 1 [ + */ + bool forbidden_intervals::match_non_max( + signed_constraint const& c, + pdd const& p, + rational const& a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + SASSERT(b2.is_val()); + if (a2.is_one() && c.is_negative()) { + // p > v + k + auto& m = e2.manager(); + rational const& mod_value = m.two_to_N(); + rational hi_val = (-b2).val(); + pdd hi = -e2; + rational lo_val = mod(hi_val - 1, mod_value); + pdd lo = p - e2; + fi.coeff = 1; + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + return true; + } + if (a2.is_odd() && c.is_negative()) { + // p > a*v + k, a odd + auto& m = e2.manager(); + rational const& mod_value = m.two_to_N(); + rational a2_inv; + VERIFY(a2.mult_inverse(m.power_of_2(), a2_inv)); + rational lo_val = mod(a2_inv * (-1 - b2.val()), mod_value); + pdd lo = a2_inv * (-1 - e2); + rational hi_val = mod(lo_val + 1, mod_value); + pdd hi = lo + 1; + fi.coeff = 1; + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + return true; + } + return false; + } + + + void forbidden_intervals::add_non_unit_side_conds(fi_record& fi, pdd const& b1, pdd const& e1, pdd const& b2, pdd const& e2) { + if (fi.coeff == 1) + return; + if (b1 != e1) + fi.side_cond.push_back(s.eq(b1, e1)); + if (b2 != e2) + fi.side_cond.push_back(s.eq(b2, e2)); + } +} diff --git a/src/sat/smt/polysat_fi.h b/src/sat/smt/polysat_fi.h new file mode 100644 index 000000000..7782deb4a --- /dev/null +++ b/src/sat/smt/polysat_fi.h @@ -0,0 +1,122 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + Conflict explanation using forbidden intervals as described in + "Solving bitvectors with MCSAT: explanations from bits and pieces" + by S. Graham-Lengrand, D. Jovanovic, B. Dutertre. + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once +#include "sat/smt/polysat_types.h" +#include "sat/smt/polysat_interval.h" +#include "sat/smt/polysat_constraints.h" + +namespace polysat { + + class core; + + struct fi_record { + eval_interval interval; + vector side_cond; + vector src; // only units may have multiple src (as they can consist of contracted bit constraints) + rational coeff; + unsigned bit_width = 0; // number of lower bits; TODO: should move this to viable::entry; where the coeff/bit-width is adapted accordingly + + /** Create invalid fi_record */ + fi_record(): interval(eval_interval::full()) {} + + void reset() { + interval = eval_interval::full(); + side_cond.reset(); + src.reset(); + coeff.reset(); + bit_width = 0; + } + + struct less { + bool operator()(fi_record const& a, fi_record const& b) const { + return a.interval.lo_val() < b.interval.lo_val(); + } + }; + }; + + class forbidden_intervals { + + void push_eq(bool is_trivial, pdd const& p, vector& side_cond); + eval_interval to_interval(signed_constraint const& c, bool is_trivial, rational& coeff, + rational & lo_val, pdd & lo, rational & hi_val, pdd & hi); + + + std::tuple linear_decompose(pvar v, pdd const& p, vector& out_side_cond); + + bool match_linear1(signed_constraint const& c, + rational const& a1, pdd const& b1, pdd const& e1, + rational const& a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + bool match_linear2(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + bool match_linear3(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + bool match_linear4(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + void add_non_unit_side_conds(fi_record& fi, pdd const& b1, pdd const& e1, pdd const& b2, pdd const& e2); + + bool match_zero(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + bool match_max(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + bool match_non_zero(signed_constraint const& c, + rational const& a1, pdd const& b1, pdd const& e1, + pdd const& q, + fi_record& fi); + + bool match_non_max(signed_constraint const& c, + pdd const& p, + rational const& a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + bool get_interval_ule(signed_constraint const& c, pvar v, fi_record& fi); + + bool get_interval_umul_ovfl(signed_constraint const& c, pvar v, fi_record& fi); + + struct backtrack { + bool released = false; + vector& side_cond; + unsigned sz; + backtrack(vector& s):side_cond(s), sz(s.size()) {} + ~backtrack() { + if (!released) + side_cond.shrink(sz); + } + }; + + core& s; + + public: + forbidden_intervals(core& s): s(s) {} + bool get_interval(signed_constraint const& c, pvar v, fi_record& fi); + }; +} diff --git a/src/sat/smt/polysat_interval.h b/src/sat/smt/polysat_interval.h new file mode 100644 index 000000000..9965dbab1 --- /dev/null +++ b/src/sat/smt/polysat_interval.h @@ -0,0 +1,224 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat intervals + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-6 + +--*/ +#pragma once +#include "sat/smt/polysat_types.h" +#include + +namespace polysat { + + struct pdd_bounds { + pdd lo; ///< lower bound, inclusive + pdd hi; ///< upper bound, exclusive + }; + + /** + * An interval is either [lo; hi[ (excl. upper bound) or the full domain Z_{2^w}. + * If lo > hi, the interval wraps around, i.e., represents the union of [lo; 2^w[ and [0; hi[. + * Membership test t \in [lo; hi[ is equivalent to t - lo < hi - lo. + */ + class interval { + std::optional m_bounds = std::nullopt; + + interval() = default; + interval(pdd const& lo, pdd const& hi): m_bounds({lo, hi}) {} + public: + static interval empty(dd::pdd_manager& m) { return proper(m.zero(), m.zero()); } + static interval full() { return {}; } + static interval proper(pdd const& lo, pdd const& hi) { return {lo, hi}; } + + interval(interval const&) = default; + interval(interval&&) = default; + interval& operator=(interval const& other) { + m_bounds = std::nullopt; // clear pdds first to allow changing pdd_manager (probably should change the PDD assignment operator; but for now I want to be able to detect manager confusions) + m_bounds = other.m_bounds; + return *this; + } + interval& operator=(interval&& other) { + m_bounds = std::nullopt; // clear pdds first to allow changing pdd_manager + m_bounds = std::move(other.m_bounds); + return *this; + } + ~interval() = default; + + bool is_full() const { return !m_bounds; } + bool is_proper() const { return !!m_bounds; } + bool is_always_empty() const { return is_proper() && lo() == hi(); } + pdd const& lo() const { SASSERT(is_proper()); return m_bounds->lo; } + pdd const& hi() const { SASSERT(is_proper()); return m_bounds->hi; } + }; + + inline std::ostream& operator<<(std::ostream& os, interval const& i) { + if (i.is_full()) + return os << "full"; + else + return os << "[" << i.lo() << " ; " << i.hi() << "["; + } + + // distance from a to b, wrapping around at mod_value. + // basically mod(b - a, mod_value), but distance(0, mod_value, mod_value) = mod_value. + inline rational distance(rational const& a, rational const& b, rational const& mod_value) { + SASSERT(mod_value.is_power_of_two()); + SASSERT(0 <= a && a < mod_value); + SASSERT(0 <= b && b <= mod_value); + rational x = b - a; + if (x.is_neg()) + x += mod_value; + return x; + } + + class r_interval { + rational m_lo; + rational m_hi; + + r_interval(rational lo, rational hi) + : m_lo(std::move(lo)), m_hi(std::move(hi)) + {} + + public: + + static r_interval empty() { + return {rational::zero(), rational::zero()}; + } + + static r_interval full() { + return {rational(-1), rational::zero()}; + } + + static r_interval proper(rational lo, rational hi) { + SASSERT(0 <= lo); + SASSERT(0 <= hi); + return {std::move(lo), std::move(hi)}; + } + + bool is_full() const { return m_lo.is_neg(); } + bool is_proper() const { return !is_full(); } + bool is_empty() const { return is_proper() && lo() == hi(); } + rational const& lo() const { SASSERT(is_proper()); return m_lo; } + rational const& hi() const { SASSERT(is_proper()); return m_hi; } + + // this one also supports representing full intervals as [lo;mod_value[ + static rational len(rational const& lo, rational const& hi, rational const& mod_value) { + SASSERT(mod_value.is_power_of_two()); + SASSERT(0 <= lo && lo < mod_value); + SASSERT(0 <= hi && hi <= mod_value); + SASSERT(hi != mod_value || lo == 0); // hi == mod_value only allowed when lo == 0 + rational len = hi - lo; + if (len.is_neg()) + len += mod_value; + return len; + } + + rational len(rational const& mod_value) const { + SASSERT(is_proper()); + return len(lo(), hi(), mod_value); + } + + // deals only with proper intervals + // but works with full intervals represented as [0;mod_value[ -- maybe we should just change representation of full intervals to this always + static bool contains(rational const& lo, rational const& hi, rational const& val) { + if (lo <= hi) + return lo <= val && val < hi; + else + return val < hi || val >= lo; + } + + bool contains(rational const& val) const { + if (is_full()) + return true; + else + return contains(lo(), hi(), val); + } + + }; + + class eval_interval { + interval m_symbolic; + rational m_concrete_lo; + rational m_concrete_hi; + + eval_interval(interval&& i, rational const& lo_val, rational const& hi_val): + m_symbolic(std::move(i)), m_concrete_lo(lo_val), m_concrete_hi(hi_val) {} + public: + static eval_interval empty(dd::pdd_manager& m) { + return {interval::empty(m), rational::zero(), rational::zero()}; + } + + static eval_interval full() { + return {interval::full(), rational::zero(), rational::zero()}; + } + + static eval_interval proper(pdd const& lo, rational const& lo_val, pdd const& hi, rational const& hi_val) { + SASSERT(0 <= lo_val && lo_val <= lo.manager().max_value()); + SASSERT(0 <= hi_val && hi_val <= hi.manager().max_value()); + return {interval::proper(lo, hi), lo_val, hi_val}; + } + + bool is_full() const { return m_symbolic.is_full(); } + bool is_proper() const { return m_symbolic.is_proper(); } + bool is_always_empty() const { return m_symbolic.is_always_empty(); } + bool is_currently_empty() const { return is_proper() && lo_val() == hi_val(); } + interval const& symbolic() const { return m_symbolic; } + pdd const& lo() const { return m_symbolic.lo(); } + pdd const& hi() const { return m_symbolic.hi(); } + rational const& lo_val() const { SASSERT(is_proper()); return m_concrete_lo; } + rational const& hi_val() const { SASSERT(is_proper()); return m_concrete_hi; } + + rational current_len() const { + SASSERT(is_proper()); + return mod(hi_val() - lo_val(), lo().manager().two_to_N()); + } + + bool currently_contains(rational const& val) const { + if (is_full()) + return true; + else if (lo_val() <= hi_val()) + return lo_val() <= val && val < hi_val(); + else + return val < hi_val() || val >= lo_val(); + } + + bool currently_contains(eval_interval const& other) const { + if (is_full()) + return true; + if (other.is_full()) + return false; + // lo <= lo' <= hi' <= hi' + if (lo_val() <= other.lo_val() && other.lo_val() <= other.hi_val() && other.hi_val() <= hi_val()) + return true; + if (lo_val() <= hi_val()) + return false; + // hi < lo <= lo' <= hi' + if (lo_val() <= other.lo_val() && other.lo_val() <= other.hi_val()) + return true; + // lo' <= hi' <= hi < lo + if (other.lo_val() <= other.hi_val() && other.hi_val() <= hi_val()) + return true; + // hi' <= hi < lo <= lo' + if (other.hi_val() <= hi_val() && lo_val() <= other.lo_val()) + return true; + return false; + } + + }; // class eval_interval + + inline std::ostream& operator<<(std::ostream& os, eval_interval const& i) { + if (i.is_full()) + return os << "full"; + else { + auto& m = i.hi().manager(); + return os << i.symbolic() << " := [" << m.normalize(i.lo_val()) << ";" << m.normalize(i.hi_val()) << "["; + } + } + +} diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 57098c447..3e01ff391 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -28,6 +28,7 @@ The result of polysat::core::check is one of: #include "sat/smt/polysat_solver.h" #include "sat/smt/euf_solver.h" #include "sat/smt/polysat_ule.h" +#include "sat/smt/polysat_umul_ovfl.h" namespace polysat { @@ -221,8 +222,8 @@ namespace polysat { return expr_ref(bv.mk_ule(l, h), m); } case ckind_t::umul_ovfl_t: { - auto l = pdd2expr(sc.to_umul_ovfl().lhs()); - auto r = pdd2expr(sc.to_umul_ovfl().rhs()); + auto l = pdd2expr(sc.to_umul_ovfl().p()); + auto r = pdd2expr(sc.to_umul_ovfl().q()); return expr_ref(bv.mk_bvumul_ovfl(l, r), m); } case ckind_t::smul_fl_t: diff --git a/src/sat/smt/polysat_umul_ovfl.cpp b/src/sat/smt/polysat_umul_ovfl.cpp new file mode 100644 index 000000000..5c448bc0a --- /dev/null +++ b/src/sat/smt/polysat_umul_ovfl.cpp @@ -0,0 +1,73 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat multiplication overflow constraint + +Author: + + Jakob Rath, Nikolaj Bjorner (nbjorner) 2021-12-09 + +--*/ +#include "sat/smt/polysat_constraints.h" +#include "sat/smt/polysat_assignment.h" +#include "sat/smt/polysat_umul_ovfl.h" + + +namespace polysat { + + umul_ovfl_constraint::umul_ovfl_constraint(pdd const& p, pdd const& q): + m_p(p), m_q(q) { + simplify(); + vars().append(m_p.free_vars()); + for (auto v : m_q.free_vars()) + if (!vars().contains(v)) + vars().push_back(v); + + } + void umul_ovfl_constraint::simplify() { + if (m_p.is_zero() || m_q.is_zero() || m_p.is_one() || m_q.is_one()) { + m_q = 0; + m_p = 0; + return; + } + if (m_p.index() > m_q.index()) + swap(m_p, m_q); + } + + std::ostream& umul_ovfl_constraint::display(std::ostream& out, lbool status) const { + switch (status) { + case l_true: return display(out); + case l_false: return display(out << "~"); + case l_undef: return display(out << "?"); + } + return out; + } + + std::ostream& umul_ovfl_constraint::display(std::ostream& out) const { + return out << "ovfl*(" << m_p << ", " << m_q << ")"; + } + + lbool umul_ovfl_constraint::eval(pdd const& p, pdd const& q) { + if (p.is_zero() || q.is_zero() || p.is_one() || q.is_one()) + return l_false; + + if (p.is_val() && q.is_val()) { + if (p.val() * q.val() > p.manager().max_value()) + return l_true; + else + return l_false; + } + return l_undef; + } + + lbool umul_ovfl_constraint::eval() const { + return eval(p(), q()); + } + + lbool umul_ovfl_constraint::eval(assignment const& a) const { + return eval(a.apply_to(p()), a.apply_to(q())); + } + +} diff --git a/src/sat/smt/polysat_umul_ovfl.h b/src/sat/smt/polysat_umul_ovfl.h new file mode 100644 index 000000000..502ed4bbf --- /dev/null +++ b/src/sat/smt/polysat_umul_ovfl.h @@ -0,0 +1,39 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat multiplication overflow constraint + +Author: + + Jakob Rath, Nikolaj Bjorner (nbjorner) 2021-12-09 + +--*/ +#pragma once +#include "sat/smt/polysat_constraints.h" + +namespace polysat { + + class umul_ovfl_constraint final : public constraint { + + pdd m_p; + pdd m_q; + + void simplify(); + static bool is_always_true(bool is_positive, pdd const& p, pdd const& q) { return eval(p, q) == to_lbool(is_positive); } + static bool is_always_false(bool is_positive, pdd const& p, pdd const& q) { return is_always_true(!is_positive, p, q); } + static lbool eval(pdd const& p, pdd const& q); + + public: + umul_ovfl_constraint(pdd const& p, pdd const& q); + ~umul_ovfl_constraint() override {} + pdd const& p() const { return m_p; } + pdd const& q() const { return m_q; } + std::ostream& display(std::ostream& out, lbool status) const override; + std::ostream& display(std::ostream& out) const override; + lbool eval() const override; + lbool eval(assignment const& a) const override; + }; + +} diff --git a/src/sat/smt/polysat_viable.cpp b/src/sat/smt/polysat_viable.cpp new file mode 100644 index 000000000..79689d01f --- /dev/null +++ b/src/sat/smt/polysat_viable.cpp @@ -0,0 +1,36 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + maintain viable domains + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +Notes: + + +--*/ + + +#include "util/debug.h" +#include "sat/smt/polysat_viable.h" +#include "sat/smt/polysat_core.h" + +namespace polysat { + + std::ostream& operator<<(std::ostream& out, find_t f) { + switch (f) { + case find_t::empty: return out << "empty"; + case find_t::singleton: return out << "singleton"; + case find_t::multiple: return out << "multiple"; + case find_t::resource_out: return out << "resource-out"; + default: return out << ""; + } + } + + +} diff --git a/src/sat/smt/polysat_viable.h b/src/sat/smt/polysat_viable.h index def069652..2f87e79cc 100644 --- a/src/sat/smt/polysat_viable.h +++ b/src/sat/smt/polysat_viable.h @@ -30,6 +30,8 @@ namespace polysat { class core; + std::ostream& operator<<(std::ostream& out, find_t x); + class viable { core& c; public: