From bdc40b1f5f83cca22dc1d6c5808e935a3b50176c Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 16 Dec 2023 16:10:06 -0800 Subject: [PATCH] na --- src/sat/smt/intblast_solver.cpp | 98 ++++----- src/sat/smt/polysat/CMakeLists.txt | 1 + src/sat/smt/polysat/constraints.cpp | 8 + src/sat/smt/polysat/constraints.h | 2 + src/sat/smt/polysat/core.cpp | 27 ++- src/sat/smt/polysat/core.h | 19 +- src/sat/smt/polysat/inequality.cpp | 131 ++++++++++++ src/sat/smt/polysat/inequality.h | 164 +++++++++++++++ src/sat/smt/polysat/saturation.cpp | 296 ++++++---------------------- src/sat/smt/polysat/saturation.h | 126 ++---------- src/sat/smt/polysat/types.h | 4 +- src/sat/smt/polysat/viable.cpp | 2 +- src/sat/smt/polysat_solver.cpp | 10 +- src/sat/smt/polysat_solver.h | 4 +- 14 files changed, 478 insertions(+), 414 deletions(-) create mode 100644 src/sat/smt/polysat/inequality.cpp create mode 100644 src/sat/smt/polysat/inequality.h diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 9d03d0ad0..27f85525d 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -104,10 +104,10 @@ namespace intblast { ctx.push(push_back_vector(m_preds)); } - void solver::set_translated(expr* e, expr* r) { + void solver::set_translated(expr* e, expr* r) { SASSERT(r); - SASSERT(!is_translated(e)); - m_translate.setx(e->get_id(), r); + SASSERT(!is_translated(e)); + m_translate.setx(e->get_id(), r); ctx.push(set_vector_idx_trail(m_translate, e->get_id())); } @@ -148,7 +148,7 @@ namespace intblast { auto a = expr2literal(e); auto b = mk_literal(r); ctx.mark_relevant(b); -// verbose_stream() << "add-predicate-axiom: " << mk_pp(e, m) << " == " << r << "\n"; + // verbose_stream() << "add-predicate-axiom: " << mk_pp(e, m) << " == " << r << "\n"; add_equiv(a, b); } return true; @@ -157,7 +157,7 @@ namespace intblast { bool solver::unit_propagate() { return add_bound_axioms() || add_predicate_axioms(); } - + void solver::ensure_translated(expr* e) { if (m_translate.get(e->get_id(), nullptr)) return; @@ -179,7 +179,7 @@ namespace intblast { } } std::stable_sort(todo.begin(), todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); - for (expr* e : todo) + for (expr* e : todo) translate_expr(e); } @@ -305,28 +305,6 @@ namespace intblast { sorted.push_back(arg); } } - - // - // Add ground equalities to ensure the model is valid with respect to the current case splits. - // This may cause more conflicts than necessary. Instead could use intblast on the base level, but using literal - // assignment from complete level. - // E.g., force the solver to completely backtrack, check satisfiability using the assignment obtained under a complete assignment. - // If intblast is SAT, then force the model and literal assignment on the rest of the literals. - // - if (!is_ground(e)) - continue; - euf::enode* n = ctx.get_enode(e); - if (!n) - continue; - if (n == n->get_root()) - continue; - expr* r = n->get_root()->get_expr(); - es.push_back(m.mk_eq(e, r)); - r = es.back(); - if (!visited.is_marked(r) && !is_translated(r)) { - visited.mark(r); - sorted.push_back(r); - } } else if (is_quantifier(e)) { quantifier* q = to_quantifier(e); @@ -357,7 +335,7 @@ namespace intblast { es[i] = translated(es.get(i)); } - sat::check_result solver::check() { + sat::check_result solver::check() { // ensure that bv2int is injective for (auto e : m_bv2int) { euf::enode* n = expr2enode(e); @@ -369,10 +347,10 @@ namespace intblast { continue; if (sib->get_arg(0)->get_root() == r1) continue; - auto a = eq_internalize(n, sib); - auto b = eq_internalize(sib->get_arg(0), n->get_arg(0)); - ctx.mark_relevant(a); - ctx.mark_relevant(b); + auto a = eq_internalize(n, sib); + auto b = eq_internalize(sib->get_arg(0), n->get_arg(0)); + ctx.mark_relevant(a); + ctx.mark_relevant(b); add_clause(~a, b, nullptr); return sat::check_result::CR_CONTINUE; } @@ -390,13 +368,13 @@ namespace intblast { auto nBv2int = ctx.get_enode(bv2int); auto nxModN = ctx.get_enode(xModN); if (nBv2int->get_root() != nxModN->get_root()) { - auto a = eq_internalize(nBv2int, nxModN); - ctx.mark_relevant(a); + auto a = eq_internalize(nBv2int, nxModN); + ctx.mark_relevant(a); add_unit(a); return sat::check_result::CR_CONTINUE; } } - return sat::check_result::CR_DONE; + return sat::check_result::CR_DONE; } expr* solver::umod(expr* bv_expr, unsigned i) { @@ -504,8 +482,8 @@ namespace intblast { m_args[i] = bv.mk_int2bv(bv.get_bv_size(e->get_arg(i)), m_args.get(i)); if (has_bv_sort) - m_vars.push_back(e); - + m_vars.push_back(e); + if (m_is_plugin) { expr* r = m.mk_app(f, m_args); if (has_bv_sort) { @@ -526,7 +504,7 @@ namespace intblast { f = g; m_pinned.push_back(f); } - set_translated(e, m.mk_app(f, m_args)); + set_translated(e, m.mk_app(f, m_args)); } void solver::translate_bv(app* e) { @@ -558,7 +536,7 @@ namespace intblast { r = a.mk_add(hi, lo); } return r; - }; + }; expr* bv_expr = e; expr* r = nullptr; @@ -659,7 +637,7 @@ namespace intblast { expr* x = arg(0), * y = umod(e, 1); r = a.mk_int(0); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); break; } case OP_BNOT: @@ -671,7 +649,7 @@ namespace intblast { for (unsigned i = 0; i < bv.get_bv_size(e); ++i) r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); break; - } + } case OP_BOR: { // p | q := (p + q) - band(p, q) r = arg(0); @@ -706,13 +684,13 @@ namespace intblast { // unsigned sz = bv.get_bv_size(e); rational N = bv_size(e); - expr* x = umod(e, 0), *y = umod(e, 1); + expr* x = umod(e, 0), * y = umod(e, 1); expr* signx = a.mk_ge(x, a.mk_int(N / 2)); - r = m.mk_ite(signx, a.mk_int(- 1), a.mk_int(0)); + r = m.mk_ite(signx, a.mk_int(-1), a.mk_int(0)); for (unsigned i = 0; i < sz; ++i) { - expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); + expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), - m.mk_ite(signx, a.mk_add(d, a.mk_int(- rational::power_of_two(sz-i))), d), + m.mk_ite(signx, a.mk_add(d, a.mk_int(-rational::power_of_two(sz - i))), d), r); } break; @@ -749,11 +727,11 @@ namespace intblast { r = m.mk_ite(m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(1), a.mk_int(0)); break; case OP_BSMOD_I: - case OP_BSMOD: { - expr* x = umod(e, 0), *y = umod(e, 1); - rational N = bv_size(e); - expr* signx = a.mk_ge(x, a.mk_int(N/2)); - expr* signy = a.mk_ge(y, a.mk_int(N/2)); + case OP_BSMOD: { + expr* x = umod(e, 0), * y = umod(e, 1); + rational N = bv_size(e); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + expr* signy = a.mk_ge(y, a.mk_int(N / 2)); expr* u = a.mk_mod(x, y); // u = 0 -> 0 // y = 0 -> x @@ -761,14 +739,14 @@ namespace intblast { // x < 0, y >= 0 -> y - u // x >= 0, y < 0 -> y + u // x >= 0, y >= 0 -> u - r = a.mk_uminus(u); + r = a.mk_uminus(u); r = m.mk_ite(m.mk_and(m.mk_not(signx), signy), a.mk_add(u, y), r); r = m.mk_ite(m.mk_and(signx, m.mk_not(signy)), a.mk_sub(y, u), r); r = m.mk_ite(m.mk_and(m.mk_not(signx), m.mk_not(signy)), u, r); r = m.mk_ite(m.mk_eq(u, a.mk_int(0)), a.mk_int(0), r); r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); break; - } + } case OP_BSDIV_I: case OP_BSDIV: { // d = udiv(abs(x), abs(y)) @@ -804,7 +782,7 @@ namespace intblast { d = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); r = a.mk_sub(x, a.mk_mul(d, y)); r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); - break; + break; } case OP_ROTATE_LEFT: { auto n = e->get_parameter(0).get_int(); @@ -817,11 +795,11 @@ namespace intblast { r = rotate_left(sz - n); break; } - case OP_EXT_ROTATE_LEFT: { + case OP_EXT_ROTATE_LEFT: { unsigned sz = bv.get_bv_size(e); expr* y = umod(e, 1); r = a.mk_int(0); - for (unsigned i = 0; i < sz; ++i) + for (unsigned i = 0; i < sz; ++i) r = m.mk_ite(m.mk_eq(a.mk_int(i), y), rotate_left(i), r); break; } @@ -829,7 +807,7 @@ namespace intblast { unsigned sz = bv.get_bv_size(e); expr* y = umod(e, 1); r = a.mk_int(0); - for (unsigned i = 0; i < sz; ++i) + for (unsigned i = 0; i < sz; ++i) r = m.mk_ite(m.mk_eq(a.mk_int(i), y), rotate_left(sz - i), r); break; } @@ -842,7 +820,7 @@ namespace intblast { for (unsigned i = 1; i < n; ++i) r = a.mk_add(a.mk_mul(a.mk_int(N), x), r), N *= N0; break; - } + } case OP_BREDOR: { r = umod(e->get_arg(0), 0); r = m.mk_not(m.mk_eq(r, a.mk_int(0))); @@ -902,7 +880,7 @@ namespace intblast { } bool solver::add_dep(euf::enode* n, top_sort& dep) { - if (!is_app(n->get_expr())) + if (!is_app(n->get_expr())) return false; app* e = to_app(n->get_expr()); if (n->num_args() == 0) { @@ -921,7 +899,7 @@ namespace intblast { void solver::add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values) { expr* e = n->get_expr(); SASSERT(bv.is_bv(e)); - + if (bv.is_numeral(e)) { values.setx(n->get_root_id(), e); return; diff --git a/src/sat/smt/polysat/CMakeLists.txt b/src/sat/smt/polysat/CMakeLists.txt index 72d919b94..70e0f9592 100644 --- a/src/sat/smt/polysat/CMakeLists.txt +++ b/src/sat/smt/polysat/CMakeLists.txt @@ -5,6 +5,7 @@ z3_add_component(polysat core.cpp fixed_bits.cpp forbidden_intervals.cpp + inequality.cpp op_constraint.cpp saturation.cpp ule_constraint.cpp diff --git a/src/sat/smt/polysat/constraints.cpp b/src/sat/smt/polysat/constraints.cpp index 83476160a..c96b43d6a 100644 --- a/src/sat/smt/polysat/constraints.cpp +++ b/src/sat/smt/polysat/constraints.cpp @@ -89,4 +89,12 @@ namespace polysat { return out << *m_constraint; } + bool signed_constraint::is_currently_true(core& c) const { + return eval(c.get_assignment()) == l_true; + } + + bool signed_constraint::is_currently_false(core& c) const { + return eval(c.get_assignment()) == l_false; + } + } diff --git a/src/sat/smt/polysat/constraints.h b/src/sat/smt/polysat/constraints.h index fa2b62c11..28d4b0529 100644 --- a/src/sat/smt/polysat/constraints.h +++ b/src/sat/smt/polysat/constraints.h @@ -72,6 +72,8 @@ namespace polysat { void propagate(core& c, lbool value, dependency const& d) { m_constraint->propagate(c, value, d); } bool is_always_true() const { return eval() == l_true; } bool is_always_false() const { return eval() == l_false; } + bool is_currently_true(core& c) const; + bool is_currently_false(core& c) const; lbool eval(assignment& a) const; lbool eval() const { return m_sign ? ~m_constraint->eval() : m_constraint->eval();} ckind_t op() const { return m_op; } diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index 6cf1db764..79f3f6bcb 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -314,6 +314,16 @@ namespace polysat { return result; } + dependency_vector core::get_dependencies(std::initializer_list const& cc) { + dependency_vector result; + for (auto idx : cc) { + auto [sc, d, value] = m_constraint_index[idx.id]; + SASSERT(value != l_undef); + result.push_back(value == l_false ? ~d : d); + } + return result; + } + void core::propagate(constraint_id id, signed_constraint& sc, lbool value, dependency const& d) { lbool eval_value = eval(sc); if (eval_value == l_undef) @@ -327,8 +337,8 @@ namespace polysat { } } - void core::get_bitvector_prefixes(pvar v, pvar_vector& out) { - s.get_bitvector_prefixes(v, out); + void core::get_bitvector_suffixes(pvar v, pvar_vector& out) { + s.get_bitvector_suffixes(v, out); } void core::get_fixed_bits(pvar v, svector& fixed_bits) { @@ -415,4 +425,17 @@ namespace polysat { s.add_polysat_clause(name, cs, is_redundant); } + signed_constraint core::get_constraint(constraint_id idx) { + auto [sc, d, value] = m_constraint_index[idx.id]; + SASSERT(value != l_undef); + if (value == l_false) + sc = ~sc; + return sc; + } + + lbool core::eval(constraint_id id) { + auto sc = get_constraint(id); + return sc.eval(m_assignment); + } + } diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index 109f0ac0e..37cc348f7 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -83,7 +83,7 @@ namespace polysat { void propagate_unsat_core(); void propagate(constraint_id id, signed_constraint& sc, lbool value, dependency const& d); - void get_bitvector_prefixes(pvar v, pvar_vector& out); + void get_bitvector_suffixes(pvar v, pvar_vector& out); void get_fixed_bits(pvar v, svector& fixed_bits); bool inconsistent() const; @@ -92,6 +92,9 @@ namespace polysat { lbool eval(signed_constraint const& sc); constraint_id_vector explain_eval(signed_constraint const& sc); dependency_vector get_dependencies(constraint_id_vector const& cc); + dependency_vector get_dependencies(std::initializer_list const& cc); + + void add_axiom(signed_constraint sc); @@ -143,6 +146,20 @@ namespace polysat { trail_stack& trail(); std::ostream& display(std::ostream& out) const; + + /* + * Saturation + */ + signed_constraint get_constraint(constraint_id id); + constraint_id_vector const& unsat_core() const { return m_unsat_core; } + lbool eval(constraint_id id); + bool propagate(signed_constraint const& sc, constraint_id_vector const& ids) { return s.propagate(sc, get_dependencies(ids)); } + bool propagate(signed_constraint const& sc, std::initializer_list const& ids) { return s.propagate(sc, get_dependencies(ids)); } + + /* + * Constraints + */ + assignment& get_assignment() { return m_assignment; } }; } diff --git a/src/sat/smt/polysat/inequality.cpp b/src/sat/smt/polysat/inequality.cpp new file mode 100644 index 000000000..e2233e5e1 --- /dev/null +++ b/src/sat/smt/polysat/inequality.cpp @@ -0,0 +1,131 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + Polysat inequalities + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-6 + + + +--*/ +#include "sat/smt/polysat/core.h" +#include "sat/smt/polysat/inequality.h" +#include "sat/smt/polysat/ule_constraint.h" + + +namespace polysat { + + + inequality inequality::from_ule(core& c, constraint_id id) { + auto src = c.get_constraint(id); + ule_constraint const& ule = src.to_ule(); + if (src.is_positive()) + return inequality(c, id, ule.lhs(), ule.rhs(), src); + else + return inequality(c, id, ule.rhs(), ule.lhs(), src); + } + + +#if 0 + + + bool saturation::verify_Y_l_AxB(pvar x, inequality const& i, pdd const& y, pdd const& a, pdd& b) { + return i.lhs() == y && i.rhs() == a * c.var(x) + b; + } + + + /** + * Match [x] a*x + b <= y, val(y) = 0 + */ + bool saturation::is_AxB_eq_0(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { + y.reset(i.rhs().manager()); + y = i.rhs(); + rational y_val; + if (!c.try_eval(y, y_val) || y_val != 0) + return false; + return i.lhs().degree(x) == 1 && (i.lhs().factor(x, 1, a, b), true); + } + + bool saturation::verify_AxB_eq_0(pvar x, inequality const& i, pdd const& a, pdd const& b, pdd const& y) { + return y.is_val() && y.val() == 0 && i.rhs() == y && i.lhs() == a * c.var(x) + b; + } + + bool saturation::is_AxB_diseq_0(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { + if (!i.is_strict()) + return false; + y.reset(i.lhs().manager()); + y = i.lhs(); + if (i.rhs().is_val() && i.rhs().val() == 1) + return false; + rational y_val; + if (!c.try_eval(y, y_val) || y_val != 0) + return false; + a.reset(i.rhs().manager()); + b.reset(i.rhs().manager()); + return i.rhs().degree(x) == 1 && (i.rhs().factor(x, 1, a, b), true); + } + + /** + * Match [coeff*x] coeff*x*Y where x is a variable + */ + bool saturation::is_coeffxY(pdd const& x, pdd const& p, pdd& y) { + pdd xy = x.manager().zero(); + return x.is_unary() && p.try_div(x.hi().val(), xy) && xy.factor(x.var(), 1, y); + } + + /** + * Match [v] v*x <= z*x with x a variable + */ + bool saturation::is_Xy_l_XZ(pvar v, inequality const& i, pdd& x, pdd& z) { + return is_xY(v, i.lhs(), x) && is_coeffxY(x, i.rhs(), z); + } + + bool saturation::verify_Xy_l_XZ(pvar v, inequality const& i, pdd const& x, pdd const& z) { + return i.lhs() == c.var(v) * x && i.rhs() == z * x; + } + + + /** + * Determine whether values of x * y is non-overflowing. + */ + bool saturation::is_non_overflow(pdd const& x, pdd const& y) { + rational x_val, y_val; + rational bound = x.manager().two_to_N(); + return c.try_eval(x, x_val) && c.try_eval(y, y_val) && x_val * y_val < bound; + } + + + /** + * Match [z] yx <= zx with x a variable + */ + bool saturation::is_YX_l_zX(pvar z, inequality const& c, pdd& x, pdd& y) { + return is_xY(z, c.rhs(), x) && is_coeffxY(x, c.lhs(), y); + } + + bool saturation::verify_YX_l_zX(pvar z, inequality const& i, pdd const& x, pdd const& y) { + return i.lhs() == y * x && i.rhs() == c.var(z) * x; + } + + /** + * Match [x] xY <= xZ + */ + bool saturation::is_xY_l_xZ(pvar x, inequality const& c, pdd& y, pdd& z) { + return is_xY(x, c.lhs(), y) && is_xY(x, c.rhs(), z); + } + + /** + * Match xy = x * Y + */ + bool saturation::is_xY(pvar x, pdd const& xy, pdd& y) { + return xy.degree(x) == 1 && xy.factor(x, 1, y); + } + +#endif + + +} diff --git a/src/sat/smt/polysat/inequality.h b/src/sat/smt/polysat/inequality.h new file mode 100644 index 000000000..03a45bfc1 --- /dev/null +++ b/src/sat/smt/polysat/inequality.h @@ -0,0 +1,164 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + Polysat core saturation + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-6 + +--*/ +#pragma once + +#include "sat/smt/polysat/constraints.h" + +namespace polysat { + + /// Normalized inequality: + /// lhs <= rhs, if !is_strict + /// lhs < rhs, otherwise + class inequality { + core& c; + constraint_id m_id; + pdd m_lhs; + pdd m_rhs; + signed_constraint m_src; + + inequality(core& c, constraint_id id, pdd lhs, pdd rhs, signed_constraint src) : + c(c), m_id(id), m_lhs(std::move(lhs)), m_rhs(std::move(rhs)), m_src(std::move(src)) {} + + void set(pdd& dst, pdd const& src) const { + dst.reset(src.manager()); + dst = src; + } + + public: + static inequality from_ule(core& c, constraint_id id); + pdd const& lhs() const { return m_lhs; } + pdd const& rhs() const { return m_rhs; } + bool is_strict() const { return m_src.is_negative(); } + constraint_id id() const { return m_id; } + signed_constraint as_signed_constraint() const { return m_src; } + operator signed_constraint() const { return m_src; } + + // c := lhs ~ v + // where ~ is < or <= + bool is_l_v(pvar v) const { return rhs() == c.var(v); } + + // c := v ~ rhs + bool is_g_v(pvar v) const { return lhs() == c.var(v); } + + // c := x ~ Y + bool is_x_l_Y(pvar x, pdd& y) const { return is_g_v(x) && (set(y, rhs()), true); } + + // c := Y ~ x + bool is_Y_l_x(pvar x, pdd& y) const { return is_l_v(x) && (set(y, lhs()), true); } + + // c := Y ~ Ax + bool is_Y_l_Ax(pvar x, pdd& a, pdd& y) const { return is_xY(x, rhs(), a) && (set(y, lhs()), true); } + bool verify_Y_l_Ax(pvar x, pdd const& a, pdd const& y) const { return lhs() == y && rhs() == a * c.var(x); } + + // c := X*y ~ X*Z + bool is_Xy_l_XZ(pvar y, pdd& x, pdd& z) const { return is_xY(y, lhs(), x) && (false); } + bool verify_Xy_l_XZ(pvar y, pdd const& x, pdd const& z) const { lhs() == c.var(y) * x && rhs() == z * x; } + + // c := Ax ~ Y + bool is_Ax_l_Y(pvar x, pdd& a, pdd& y) const; + bool verify_Ax_l_Y(pvar x, pdd const& a, pdd const& y) const; + + // c := Ax + B ~ Y + bool is_AxB_l_Y(pvar x, pdd& a, pdd& b, pdd& y) const { + return lhs().degree(x) == 1 && (set(y, rhs()), lhs().factor(x, 1, a, b), true); + } + bool verify_AxB_l_Y(pvar x, pdd const& a, pdd const& b, pdd const& y) const { return rhs() == y && lhs() == a * c.var(x) + b; } + + // c := Y ~ Ax + B + bool is_Y_l_AxB(pvar x, pdd& y, pdd& a, pdd& b) const { return rhs().degree(x) == 1 && (set(y, lhs()), rhs().factor(x, 1, a, b), true); } + bool verify_Y_l_AxB(pvar x, pdd const& y, pdd const& a, pdd& b) const; + + // c := Ax + B ~ Y, val(Y) = 0 + bool is_AxB_eq_0(pvar x, pdd& a, pdd& b, pdd& y) const; + bool verify_AxB_eq_0(pvar x, pdd const& a, pdd const& b, pdd const& y) const; + + // c := Ax + B != Y, val(Y) = 0 + bool is_AxB_diseq_0(pvar x, pdd& a, pdd& b, pdd& y) const; + + // c := Y*X ~ z*X + bool is_YX_l_zX(pvar z, pdd& x, pdd& y) const; + bool verify_YX_l_zX(pvar z, pdd const& x, pdd const& y) const; + + // c := xY <= xZ + bool is_xY_l_xZ(pvar x, pdd& y, pdd& z) const; + + /** + * Match xy = x * Y + */ + static bool is_xY(pvar x, pdd const& xy, pdd& y) { return xy.degree(x) == 1 && xy.factor(x, 1, y); } + + /** + * Rewrite to one of six equivalent forms: + * + * i=0 p <= q (unchanged) + * i=1 p <= p - q - 1 + * i=2 q - p <= q + * i=3 q - p <= -p - 1 + * i=4 -q - 1 <= -p - 1 + * i=5 -q - 1 <= p - q - 1 + */ + //inequality rewrite_equiv(int i) const; + + //std::ostream& display(std::ostream& out) const; + }; + + + struct bilinear { + rational a, b, c, d; + + rational eval(rational const& x, rational const& y) const { + return a*x*y + b*x + c*y + d; + } + + bilinear operator-() const { + bilinear r(*this); + r.a = -r.a; + r.b = -r.b; + r.c = -r.c; + r.d = -r.d; + return r; + } + + bilinear operator-(bilinear const& other) const { + bilinear r(*this); + r.a -= other.a; + r.b -= other.b; + r.c -= other.c; + r.d -= other.d; + return r; + } + + bilinear operator+(rational const& d) const { + bilinear r(*this); + r.d += d; + return r; + } + + bilinear operator-(rational const& d) const { + bilinear r(*this); + r.d -= d; + return r; + } + + bilinear operator-(int d) const { + bilinear r(*this); + r.d -= d; + return r; + } +}; + + inline std::ostream& operator<<(std::ostream& out, bilinear const& b) { + return out << b.a << "*x*y + " << b.b << "*x + " << b.c << "*y + " << b.d; + } +} diff --git a/src/sat/smt/polysat/saturation.cpp b/src/sat/smt/polysat/saturation.cpp index c2a961e7f..204a44de3 100644 --- a/src/sat/smt/polysat/saturation.cpp +++ b/src/sat/smt/polysat/saturation.cpp @@ -34,29 +34,70 @@ namespace polysat { saturation::saturation(core& c) : c(c), C(c.cs()) {} -#if 0 - void saturation::perform(pvar v) { - for (signed_constraint c : core) - perform(v, sc, core); + void saturation::propagate(pvar v) { + for (auto id : c.unsat_core()) + propagate(v, id); } - bool saturation::perform(pvar v, signed_constraint sc) { - if (sc.is_currently_true(c)) + bool saturation::propagate(pvar v, constraint_id id) { + if (c.eval(id) == l_true) return false; + auto sc = c.get_constraint(id); + m_propagated = false; + if (sc.is_ule()) + propagate(v, inequality::from_ule(c, id)); + else + ; - if (sc.is_ule()) { - auto i = inequality::from_ule(sc); - return try_inequality(v, i); - } - if (sc.is_umul_ovfl()) - return try_umul_ovfl(v, sc); - - if (sc.is_op()) - return try_op(v, sc); - - return false; + return m_propagated; } + void saturation::propagate(pvar v, inequality const& i) { + if (c.size(v) != i.lhs().power_of_2()) + return; + propagate_infer_equality(v, i); + return; + + } + + void saturation::propagate(signed_constraint const& sc, std::initializer_list const& premises) { + if (c.propagate(sc, premises)) + m_propagated = true; + } + + /** + * p <= q, q <= p => p - q = 0 + */ + void saturation::propagate_infer_equality(pvar x, inequality const& a_l_b) { + set_rule("[x] p <= q, q <= p => p - q = 0"); + if (a_l_b.is_strict()) + return; + if (a_l_b.lhs().degree(x) == 0 && a_l_b.rhs().degree(x) == 0) + return; + for (auto id : c.unsat_core()) { + auto sc = c.get_constraint(id); + if (!sc.is_ule()) + continue; + auto i = inequality::from_ule(c, id); + if (i.lhs() == a_l_b.rhs() && i.rhs() == a_l_b.lhs() && !i.is_strict()) { + c.propagate(c.eq(i.lhs(), i.rhs()), { id, a_l_b.id() }); + return; + } + } + } + + /** + * Determine whether values of x * y is non-overflowing. + */ + bool saturation::is_non_overflow(pdd const& x, pdd const& y) { + rational x_val, y_val; + rational bound = x.manager().two_to_N(); + return c.try_eval(x, x_val) && c.try_eval(y, y_val) && x_val * y_val < bound; + } + +#if 0 + + bool saturation::try_inequality(pvar v, inequality const& i, conflict& core) { bool prop = false; if (s.size(v) != i.lhs().power_of_2()) @@ -98,35 +139,6 @@ namespace polysat { return prop; } - bool saturation::try_congruence(pvar x, conflict& core, inequality const& i) { - set_rule("egraph(x == y) & C(x,y) ==> C(y,y)"); - // TODO: generalize to other constraint types? - // if (!i.is_strict()) - // return false; - // if (!core.vars().contains(x)) - // return false; - if (!i.as_signed_constraint().contains_var(x)) - return false; - for (pvar y : s.m_slicing.equivalent_vars(x)) { - if (x == y) - continue; - if (!s.is_assigned(y)) - continue; - if (!core.vars().contains(y)) - continue; - if (!i.as_signed_constraint().contains_var(y)) - continue; - SASSERT(s.m_search.get_pvar_index(y) < s.m_search.get_pvar_index(x)); // y was the earlier one since we are currently resolving x - pdd const lhs = i.lhs().subst_pdd(x, s.var(y)); - pdd const rhs = i.rhs().subst_pdd(x, s.var(y)); - signed_constraint c = ineq(true, lhs, rhs); - m_lemma.reset(); - s.m_slicing.explain_equal(x, y, [&](sat::literal lit) { m_lemma.insert(~lit); }); - if (propagate(x, core, i, c)) - return true; - } - return false; - } bool saturation::try_nonzero_upper_extract(pvar y, conflict& core, inequality const& i) { set_rule("y = x[h:l] & y != 0 ==> x >= 2^l"); @@ -402,172 +414,7 @@ namespace polysat { return false; } - /* - * Match [v] .. <= v - */ - bool saturation::is_l_v(pvar v, inequality const& i) { - return i.rhs() == s.var(v); - } - - /* - * Match [v] v <= ... - */ - bool saturation::is_g_v(pvar v, inequality const& i) { - return i.lhs() == s.var(v); - } - - /* - * Match [x] x <= y - */ - bool saturation::is_x_l_Y(pvar x, inequality const& i, pdd& y) { - y.reset(i.rhs().manager()); - y = i.rhs(); - return is_g_v(x, i); - } - - /* - * Match [x] y <= x - */ - bool saturation::is_Y_l_x(pvar x, inequality const& i, pdd& y) { - y.reset(i.lhs().manager()); - y = i.lhs(); - return is_l_v(x, i); - } - - /* - * Match [x] y <= a*x - */ - bool saturation::is_Y_l_Ax(pvar x, inequality const& i, pdd& a, pdd& y) { - y.reset(i.lhs().manager()); - y = i.lhs(); - return is_xY(x, i.rhs(), a); - } - - bool saturation::verify_Y_l_Ax(pvar x, inequality const& i, pdd const& a, pdd const& y) { - return i.lhs() == y && i.rhs() == a * s.var(x); - } - - /** - * Match [x] a*x <= y - */ - bool saturation::is_Ax_l_Y(pvar x, inequality const& i, pdd& a, pdd& y) { - y.reset(i.rhs().manager()); - y = i.rhs(); - return is_xY(x, i.lhs(), a); - } - - bool saturation::verify_Ax_l_Y(pvar x, inequality const& i, pdd const& a, pdd const& y) { - return i.rhs() == y && i.lhs() == a * s.var(x); - } - - /** - * Match [x] a*x + b <= y - */ - bool saturation::is_AxB_l_Y(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { - y.reset(i.rhs().manager()); - y = i.rhs(); - return i.lhs().degree(x) == 1 && (i.lhs().factor(x, 1, a, b), true); - } - - bool saturation::verify_AxB_l_Y(pvar x, inequality const& i, pdd const& a, pdd const& b, pdd const& y) { - return i.rhs() == y && i.lhs() == a * s.var(x) + b; - } - - - bool saturation::is_Y_l_AxB(pvar x, inequality const& i, pdd& y, pdd& a, pdd& b) { - y.reset(i.lhs().manager()); - y = i.lhs(); - return i.rhs().degree(x) == 1 && (i.rhs().factor(x, 1, a, b), true); - } - - bool saturation::verify_Y_l_AxB(pvar x, inequality const& i, pdd const& y, pdd const& a, pdd& b) { - return i.lhs() == y && i.rhs() == a * s.var(x) + b; - } - - - /** - * Match [x] a*x + b <= y, val(y) = 0 - */ - bool saturation::is_AxB_eq_0(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { - y.reset(i.rhs().manager()); - y = i.rhs(); - rational y_val; - if (!s.try_eval(y, y_val) || y_val != 0) - return false; - return i.lhs().degree(x) == 1 && (i.lhs().factor(x, 1, a, b), true); - } - - bool saturation::verify_AxB_eq_0(pvar x, inequality const& i, pdd const& a, pdd const& b, pdd const& y) { - return y.is_val() && y.val() == 0 && i.rhs() == y && i.lhs() == a * s.var(x) + b; - } - - bool saturation::is_AxB_diseq_0(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { - if (!i.is_strict()) - return false; - y.reset(i.lhs().manager()); - y = i.lhs(); - if (i.rhs().is_val() && i.rhs().val() == 1) - return false; - rational y_val; - if (!s.try_eval(y, y_val) || y_val != 0) - return false; - a.reset(i.rhs().manager()); - b.reset(i.rhs().manager()); - return i.rhs().degree(x) == 1 && (i.rhs().factor(x, 1, a, b), true); - } - - /** - * Match [coeff*x] coeff*x*Y where x is a variable - */ - bool saturation::is_coeffxY(pdd const& x, pdd const& p, pdd& y) { - pdd xy = x.manager().zero(); - return x.is_unary() && p.try_div(x.hi().val(), xy) && xy.factor(x.var(), 1, y); - } - - /** - * Determine whether values of x * y is non-overflowing. - */ - bool saturation::is_non_overflow(pdd const& x, pdd const& y) { - rational x_val, y_val; - rational bound = x.manager().two_to_N(); - return s.try_eval(x, x_val) && s.try_eval(y, y_val) && x_val * y_val < bound; - } - - /** - * Match [v] v*x <= z*x with x a variable - */ - bool saturation::is_Xy_l_XZ(pvar v, inequality const& i, pdd& x, pdd& z) { - return is_xY(v, i.lhs(), x) && is_coeffxY(x, i.rhs(), z); - } - - bool saturation::verify_Xy_l_XZ(pvar v, inequality const& i, pdd const& x, pdd const& z) { - return i.lhs() == s.var(v) * x && i.rhs() == z * x; - } - - /** - * Match [z] yx <= zx with x a variable - */ - bool saturation::is_YX_l_zX(pvar z, inequality const& c, pdd& x, pdd& y) { - return is_xY(z, c.rhs(), x) && is_coeffxY(x, c.lhs(), y); - } - - bool saturation::verify_YX_l_zX(pvar z, inequality const& c, pdd const& x, pdd const& y) { - return c.lhs() == y * x && c.rhs() == s.var(z) * x; - } - - /** - * Match [x] xY <= xZ - */ - bool saturation::is_xY_l_xZ(pvar x, inequality const& c, pdd& y, pdd& z) { - return is_xY(x, c.lhs(), y) && is_xY(x, c.rhs(), z); - } - - /** - * Match xy = x * Y - */ - bool saturation::is_xY(pvar x, pdd const& xy, pdd& y) { - return xy.degree(x) == 1 && xy.factor(x, 1, y); - } + // // overall comment: we use value propagation to check if p is val @@ -1336,30 +1183,7 @@ namespace polysat { return false; } - /** - * p <= q, q <= p => p - q = 0 - */ - bool saturation::try_infer_equality(pvar x, conflict& core, inequality const& a_l_b) { - set_rule("[x] p <= q, q <= p => p - q = 0"); - if (a_l_b.is_strict()) - return false; - if (a_l_b.lhs().degree(x) == 0 && a_l_b.rhs().degree(x) == 0) - return false; - for (auto c : core) { - if (!c->is_ule()) - continue; - auto i = inequality::from_ule(c); - if (i.lhs() == a_l_b.rhs() && i.rhs() == a_l_b.lhs() && !i.is_strict()) { - m_lemma.reset(); - m_lemma.insert(~c); - if (propagate(x, core, a_l_b, s.eq(i.lhs() - i.rhs()))) { - IF_VERBOSE(1, verbose_stream() << "infer equality " << s.eq(i.lhs() - i.rhs()) << "\n"); - return true; - } - } - } - return false; - } + lbool saturation::get_multiple(const pdd& p1, const pdd& p2, pdd& out) { LOG("Check if " << p2 << " can be multiplied with something to get " << p1); diff --git a/src/sat/smt/polysat/saturation.h b/src/sat/smt/polysat/saturation.h index 2d81f52fa..58c3aca62 100644 --- a/src/sat/smt/polysat/saturation.h +++ b/src/sat/smt/polysat/saturation.h @@ -14,57 +14,10 @@ Author: #pragma once #include "sat/smt/polysat/constraints.h" +#include "sat/smt/polysat/inequality.h" namespace polysat { - struct bilinear { - rational a, b, c, d; - - rational eval(rational const& x, rational const& y) const { - return a*x*y + b*x + c*y + d; - } - - bilinear operator-() const { - bilinear r(*this); - r.a = -r.a; - r.b = -r.b; - r.c = -r.c; - r.d = -r.d; - return r; - } - - bilinear operator-(bilinear const& other) const { - bilinear r(*this); - r.a -= other.a; - r.b -= other.b; - r.c -= other.c; - r.d -= other.d; - return r; - } - - bilinear operator+(rational const& d) const { - bilinear r(*this); - r.d += d; - return r; - } - - bilinear operator-(rational const& d) const { - bilinear r(*this); - r.d -= d; - return r; - } - - bilinear operator-(int d) const { - bilinear r(*this); - r.d -= d; - return r; - } -}; - - inline std::ostream& operator<<(std::ostream& out, bilinear const& b) { - return out << b.a << "*x*y + " << b.b << "*x + " << b.c << "*y + " << b.d; - } - /** * Introduce lemmas that derive new (simpler) constraints from the current conflict and partial model. */ @@ -73,13 +26,26 @@ namespace polysat { core& c; constraints& C; char const* m_rule = nullptr; + bool m_propagated = false; + void set_rule(char const* r) { m_rule = r; } + + void propagate(signed_constraint const& sc, std::initializer_list const& premises); + + + // a * b does not overflow + bool is_non_overflow(pdd const& a, pdd const& b); + + // p := coeff*x*y where coeff_x = coeff*x, x a variable + bool is_coeffxY(pdd const& coeff_x, pdd const& p, pdd& y); + + void propagate_infer_equality(pvar x, inequality const& a_l_b); #if 0 parity_tracker m_parity_tracker; unsigned_vector m_occ; unsigned_vector m_occ_cnt; - void set_rule(char const* r) { m_rule = r; } + bool is_non_overflow(pdd const& x, pdd const& y, signed_constraint& c); signed_constraint ineq(bool strict, pdd const& lhs, pdd const& rhs); @@ -137,61 +103,7 @@ namespace polysat { void fix_values(pvar x, pvar y, pdd const& p); void fix_values(pvar y, pdd const& p); - // c := lhs ~ v - // where ~ is < or <= - bool is_l_v(pvar v, inequality const& c); - // c := v ~ rhs - bool is_g_v(pvar v, inequality const& c); - - // c := x ~ Y - bool is_x_l_Y(pvar x, inequality const& i, pdd& y); - - // c := Y ~ x - bool is_Y_l_x(pvar x, inequality const& i, pdd& y); - - // c := X*y ~ X*Z - bool is_Xy_l_XZ(pvar y, inequality const& c, pdd& x, pdd& z); - bool verify_Xy_l_XZ(pvar y, inequality const& c, pdd const& x, pdd const& z); - - // c := Y ~ Ax - bool is_Y_l_Ax(pvar x, inequality const& c, pdd& a, pdd& y); - bool verify_Y_l_Ax(pvar x, inequality const& c, pdd const& a, pdd const& y); - - // c := Ax ~ Y - bool is_Ax_l_Y(pvar x, inequality const& c, pdd& a, pdd& y); - bool verify_Ax_l_Y(pvar x, inequality const& c, pdd const& a, pdd const& y); - - // c := Ax + B ~ Y - bool is_AxB_l_Y(pvar x, inequality const& c, pdd& a, pdd& b, pdd& y); - bool verify_AxB_l_Y(pvar x, inequality const& c, pdd const& a, pdd const& b, pdd const& y); - - // c := Y ~ Ax + B - bool is_Y_l_AxB(pvar x, inequality const& c, pdd& y, pdd& a, pdd& b); - bool verify_Y_l_AxB(pvar x, inequality const& c, pdd const& y, pdd const& a, pdd& b); - - // c := Ax + B ~ Y, val(Y) = 0 - bool is_AxB_eq_0(pvar x, inequality const& c, pdd& a, pdd& b, pdd& y); - bool verify_AxB_eq_0(pvar x, inequality const& c, pdd const& a, pdd const& b, pdd const& y); - - // c := Ax + B != Y, val(Y) = 0 - bool is_AxB_diseq_0(pvar x, inequality const& c, pdd& a, pdd& b, pdd& y); - - // c := Y*X ~ z*X - bool is_YX_l_zX(pvar z, inequality const& c, pdd& x, pdd& y); - bool verify_YX_l_zX(pvar z, inequality const& c, pdd const& x, pdd const& y); - - // c := xY <= xZ - bool is_xY_l_xZ(pvar x, inequality const& c, pdd& y, pdd& z); - - // xy := x * Y - bool is_xY(pvar x, pdd const& xy, pdd& y); - - // a * b does not overflow - bool is_non_overflow(pdd const& a, pdd const& b); - - // p := coeff*x*y where coeff_x = coeff*x, x a variable - bool is_coeffxY(pdd const& coeff_x, pdd const& p, pdd& y); bool is_add_overflow(pvar x, inequality const& i, pdd& y, bool& is_minus); @@ -222,7 +134,7 @@ namespace polysat { bool is_forced_true(signed_constraint const& sc); - bool try_inequality(pvar v, inequality const& i); + bool try_umul_ovfl(pvar v, signed_constraint c); bool try_umul_ovfl_noovfl(pvar v, signed_constraint c); @@ -233,9 +145,11 @@ namespace polysat { bool try_op(pvar v, signed_constraint c); #endif + void propagate(pvar v); + bool propagate(pvar v, constraint_id cid); + void propagate(pvar v, inequality const& i); + public: saturation(core& c); - void perform(pvar v); - bool perform(pvar v, signed_constraint sc); }; } diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index e0aefb6a9..1c88620fe 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -97,11 +97,11 @@ namespace polysat { virtual void set_conflict(dependency_vector const& core) = 0; virtual void set_lemma(core_vector const& aux_core, dependency_vector const& core) = 0; virtual void add_polysat_clause(char const* name, core_vector cs, bool redundant) = 0; - virtual dependency propagate(signed_constraint sc, dependency_vector const& deps) = 0; + virtual bool propagate(signed_constraint sc, dependency_vector const& deps) = 0; virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0; virtual trail_stack& trail() = 0; virtual bool inconsistent() const = 0; - virtual void get_bitvector_prefixes(pvar v, pvar_vector& out) = 0; + virtual void get_bitvector_suffixes(pvar v, pvar_vector& out) = 0; virtual void get_fixed_bits(pvar v, svector& fixed_bits) = 0; }; diff --git a/src/sat/smt/polysat/viable.cpp b/src/sat/smt/polysat/viable.cpp index 3f217d1ce..6442f6739 100644 --- a/src/sat/smt/polysat/viable.cpp +++ b/src/sat/smt/polysat/viable.cpp @@ -104,7 +104,7 @@ namespace polysat { #endif pvar_vector overlaps; - c.get_bitvector_prefixes(v, overlaps); + c.get_bitvector_suffixes(v, overlaps); std::sort(overlaps.begin(), overlaps.end(), [&](pvar x, pvar y) { return c.size(x) > c.size(y); }); uint_set widths_set; diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 219b9017a..28a040869 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -238,12 +238,14 @@ namespace polysat { // Core uses the propagate callback to add unit propagations to the trail. // The polysat::solver takes care of translating signed constraints into expressions, which translate into literals. // Everything goes over expressions/literals. polysat::core is not responsible for replaying expressions. - dependency solver::propagate(signed_constraint sc, dependency_vector const& deps) { + bool solver::propagate(signed_constraint sc, dependency_vector const& deps) { sat::literal lit = ctx.mk_literal(constraint2expr(sc)); + if (s().value(lit) == l_true) + return false; auto [core, eqs] = explain_deps(deps); auto ex = euf::th_explain::propagate(*this, core, eqs, lit, nullptr); ctx.propagate(lit, ex); - return dependency(lit, s().lvl(lit)); + return true; } void solver::propagate(dependency const& d, bool sign, dependency_vector const& deps) { @@ -338,7 +340,7 @@ namespace polysat { } // walk the egraph starting with pvar for overlaps. - void solver::get_bitvector_prefixes(pvar pv, pvar_vector& out) { + void solver::get_bitvector_suffixes(pvar pv, pvar_vector& out) { theory_var v = m_pddvar2var[pv]; euf::enode_vector todo; uint_set seen; @@ -355,7 +357,7 @@ namespace polysat { // identify prefixes if (bv.is_concat(sib->get_expr())) - todo.push_back(sib->get_arg(0)); + todo.push_back(sib->get_arg(sib->num_args() - 1)); if (w == euf::null_theory_var) continue; if (seen.contains(w)) diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 60535207b..74989c1ac 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -162,11 +162,11 @@ namespace polysat { void add_eq_literal(pvar v, rational const& val) override; void set_conflict(dependency_vector const& core) override; void set_lemma(core_vector const& aux_core, dependency_vector const& core) override; - dependency propagate(signed_constraint sc, dependency_vector const& deps) override; + bool propagate(signed_constraint sc, dependency_vector const& deps) override; void propagate(dependency const& d, bool sign, dependency_vector const& deps) override; trail_stack& trail() override; bool inconsistent() const override; - void get_bitvector_prefixes(pvar v, pvar_vector& out) override; + void get_bitvector_suffixes(pvar v, pvar_vector& out) override; void get_fixed_bits(pvar v, svector& fixed_bits) override; void add_polysat_clause(char const* name, core_vector cs, bool redundant) override;