From f0184c3fde1eb8f28fdfdc140afc8286fd7c87dc Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 27 Jul 2023 13:21:45 -0700 Subject: [PATCH] update format and checker for implied-eq Signed-off-by: Nikolaj Bjorner --- src/math/lp/lp_bound_propagator.h | 3 +- src/sat/smt/arith_axioms.cpp | 8 +- src/sat/smt/arith_diagnostics.cpp | 33 +++-- src/sat/smt/arith_solver.cpp | 9 +- src/sat/smt/arith_solver.h | 7 +- src/sat/smt/arith_theory_checker.h | 221 +++++++++++++++++++---------- src/sat/smt/euf_proof.cpp | 2 + src/sat/smt/sat_th.cpp | 3 + 8 files changed, 192 insertions(+), 94 deletions(-) diff --git a/src/math/lp/lp_bound_propagator.h b/src/math/lp/lp_bound_propagator.h index da2e4488d..6056444e8 100644 --- a/src/math/lp/lp_bound_propagator.h +++ b/src/math/lp/lp_bound_propagator.h @@ -612,7 +612,8 @@ class lp_bound_propagator { constraint_index lc, uc; lp().get_bound_constraint_witnesses_for_column(j, lc, uc); ex.push_back(lc); - ex.push_back(uc); + if (lc != uc) + ex.push_back(uc); } vector connect_in_tree(const vertex* u, const vertex* v) const { diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 93917042e..c2a96177d 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -539,10 +539,12 @@ namespace arith { if (x->get_root() == y->get_root()) return; reset_evidence(); - set_evidence(ci1); - set_evidence(ci2); + m_explanation.clear(); + consume(rational::one(), ci1); + consume(rational::one(), ci2); ++m_stats.m_fixed_eqs; - auto* jst = euf::th_explain::propagate(*this, m_core, m_eqs, x, y); + auto* hint = explain_implied_eq(m_explanation, x, y); + auto* jst = euf::th_explain::propagate(*this, m_core, m_eqs, x, y, hint); ctx.propagate(x, y, jst->to_index()); } diff --git a/src/sat/smt/arith_diagnostics.cpp b/src/sat/smt/arith_diagnostics.cpp index 77c51d87c..9fe7540e0 100644 --- a/src/sat/smt/arith_diagnostics.cpp +++ b/src/sat/smt/arith_diagnostics.cpp @@ -32,7 +32,7 @@ namespace arith { } arith_proof_hint* arith_proof_hint_builder::mk(euf::solver& s) { - return new (s.get_region()) arith_proof_hint(m_ty, m_num_le, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail); + return new (s.get_region()) arith_proof_hint(m_ty, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail); } std::ostream& solver::display(std::ostream& out) const { @@ -164,7 +164,6 @@ namespace arith { return nullptr; m_arith_hint.set_type(ctx, hint_type::implied_eq_h); explain_assumptions(e); - m_arith_hint.set_num_le(1); // TODO m_arith_hint.add_diseq(a, b); return m_arith_hint.mk(ctx); } @@ -173,13 +172,19 @@ namespace arith { if (!ctx.use_drat()) return nullptr; m_arith_hint.set_type(ctx, hint_type::implied_eq_h); - m_arith_hint.set_num_le(1); m_arith_hint.add_lit(rational(1), le); m_arith_hint.add_lit(rational(1), ge); m_arith_hint.add_lit(rational(1), ~eq); return m_arith_hint.mk(ctx); } + /** + * The expected format is: + * 1. all equalities + * 2. all inequalities + * 3. optional disequalities (used for the steps that propagate equalities) + */ + expr* arith_proof_hint::get_hint(euf::solver& s) const { ast_manager& m = s.get_manager(); family_id fid = m.get_family_id("arith"); @@ -200,29 +205,39 @@ namespace arith { break; case hint_type::implied_eq_h: name = "implied-eq"; - args.push_back(arith.mk_int(m_num_le)); break; default: name = "unknown-arithmetic"; break; } - rational lc(1); - for (unsigned i = m_lit_head; i < m_lit_tail; ++i) - lc = lcm(lc, denominator(a.m_arith_hint.lit(i).first)); - for (unsigned i = m_eq_head; i < m_eq_tail; ++i) { - auto [x, y, is_eq] = a.m_arith_hint.eq(i); + + auto push_eq = [&](bool is_eq, enode* x, enode* y) { if (x->get_id() > y->get_id()) std::swap(x, y); expr_ref eq(m.mk_eq(x->get_expr(), y->get_expr()), m); if (!is_eq) eq = m.mk_not(eq); args.push_back(arith.mk_int(1)); args.push_back(eq); + }; + rational lc(1); + for (unsigned i = m_lit_head; i < m_lit_tail; ++i) + lc = lcm(lc, denominator(a.m_arith_hint.lit(i).first)); + for (unsigned i = m_eq_head; i < m_eq_tail; ++i) { + auto [x, y, is_eq] = a.m_arith_hint.eq(i); + if (is_eq) + push_eq(is_eq, x, y); } for (unsigned i = m_lit_head; i < m_lit_tail; ++i) { auto const& [coeff, lit] = a.m_arith_hint.lit(i); args.push_back(arith.mk_int(abs(coeff*lc))); args.push_back(s.literal2expr(lit)); } + for (unsigned i = m_eq_head; i < m_eq_tail; ++i) { + auto [x, y, is_eq] = a.m_arith_hint.eq(i); + if (!is_eq) + push_eq(is_eq, x, y); + } + return m.mk_app(symbol(name), args.size(), args.data(), m.mk_proof_sort()); } } diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 89bccf76b..f30c29872 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -713,10 +713,11 @@ namespace arith { ++m_stats.m_fixed_eqs; reset_evidence(); - set_evidence(ci1); - set_evidence(ci2); - set_evidence(ci3); - set_evidence(ci4); + m_explanation.clear(); + consume(rational::one(), ci1); + consume(rational::one(), ci2); + consume(rational::one(), ci3); + consume(rational::one(), ci4); enode* x = var2enode(v1); enode* y = var2enode(v2); auto* ex = explain_implied_eq(m_explanation, x, y); diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 8922673c0..4ff46ba13 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -57,10 +57,9 @@ namespace arith { struct arith_proof_hint : public euf::th_proof_hint { hint_type m_ty; - unsigned m_num_le; unsigned m_lit_head, m_lit_tail, m_eq_head, m_eq_tail; - arith_proof_hint(hint_type t, unsigned num_le, unsigned lh, unsigned lt, unsigned eh, unsigned et): - m_ty(t), m_num_le(num_le), m_lit_head(lh), m_lit_tail(lt), m_eq_head(eh), m_eq_tail(et) {} + arith_proof_hint(hint_type t, unsigned lh, unsigned lt, unsigned eh, unsigned et): + m_ty(t), m_lit_head(lh), m_lit_tail(lt), m_eq_head(eh), m_eq_tail(et) {} expr* get_hint(euf::solver& s) const override; }; @@ -68,7 +67,6 @@ namespace arith { vector> m_literals; svector> m_eqs; hint_type m_ty; - unsigned m_num_le = 0; unsigned m_lit_head = 0, m_lit_tail = 0, m_eq_head = 0, m_eq_tail = 0; void reset() { m_lit_head = m_lit_tail; m_eq_head = m_eq_tail; } void add(euf::enode* a, euf::enode* b, bool is_eq) { @@ -80,7 +78,6 @@ namespace arith { } public: void set_type(euf::solver& ctx, hint_type ty); - void set_num_le(unsigned n) { m_num_le = n; } void add_eq(euf::enode* a, euf::enode* b) { add(a, b, true); } void add_diseq(euf::enode* a, euf::enode* b) { add(a, b, false); } void add_lit(rational const& coeff, literal lit) { diff --git a/src/sat/smt/arith_theory_checker.h b/src/sat/smt/arith_theory_checker.h index 65c647658..87868d940 100644 --- a/src/sat/smt/arith_theory_checker.h +++ b/src/sat/smt/arith_theory_checker.h @@ -35,6 +35,15 @@ The module assumes a limited repertoire of arithmetic proof rules. namespace arith { class theory_checker : public euf::theory_checker_plugin { + + enum rule_type_t { + cut_t, + farkas_t, + implied_eq_t, + bound_t, + none_t + }; + struct row { obj_map m_coeffs; rational m_coeff; @@ -42,6 +51,9 @@ namespace arith { m_coeffs.reset(); m_coeff = 0; } + bool is_zero() const { + return m_coeffs.empty() && m_coeff == 0; + } }; ast_manager& m; @@ -50,10 +62,24 @@ namespace arith { bool m_strict = false; row m_ineq; row m_conseq; - vector m_eqs; - symbol m_farkas; - symbol m_implied_eq; - symbol m_bound; + vector m_eqs, m_ineqs; + symbol m_farkas = symbol("farkas"); + symbol m_implied_eq = symbol("implied-eq"); + symbol m_bound = symbol("bound"); + symbol m_cut = symbol("cut"); + + rule_type_t rule_type(app* jst) const { + if (jst->get_name() == m_cut) + return cut_t; + if (jst->get_name() == m_bound) + return bound_t; + if (jst->get_name() == m_implied_eq) + return implied_eq_t; + if (jst->get_name() == m_farkas) + return farkas_t; + return none_t; + } + void add(row& r, expr* v, rational const& coeff) { rational coeff1; @@ -90,10 +116,10 @@ namespace arith { // X = lcm(a,b)/b, Y = -lcm(a,b)/a if v is integer // X = 1/b, Y = -1/a if v is real // - void resolve(expr* v, row& dst, rational const& A, row const& src) { + bool resolve(expr* v, row& dst, rational const& A, row const& src) { rational B, x, y; if (!dst.m_coeffs.find(v, B)) - return; + return false; if (a.is_int(v)) { rational lc = lcm(abs(A), abs(B)); x = lc / abs(B); @@ -109,6 +135,7 @@ namespace arith { y.neg(); mul(dst, x); add(dst, src, y); + return true; } void cut(row& r) { @@ -197,6 +224,8 @@ namespace arith { resolve(v, m_eqs[j], coeff, r); resolve(v, m_ineq, coeff, r); resolve(v, m_conseq, coeff, r); + for (auto& ineq : m_ineqs) + resolve(v, ineq, coeff, r); } return true; } @@ -269,6 +298,81 @@ namespace arith { return false; } + /** + Check implied equality lemma: + + inequalities & equalities => equality + + + We may assume the set of inequality assumptions we are given are all tight, non-strict and imply equalities. + In other words, given a set of inequalities a1x + b1 <= 0, ..., anx + bn <= 0 + the equalities a1x + b1 = 0, ..., anx + bn = 0 are all consequences. + + We use a weaker property: We derive implied equalities by applying exhaustive Fourier-Motzkin + elimination and then collect the tight 0 <= 0 inequalities that are derived. + + Claim: the set of inequalities used to derive 0 <= 0 are all tight equalities. + */ + + svector> m_deps; + unsigned_vector m_tight_inequalities; + uint_set m_ineqs_that_are_eqs; + + bool check_implied_eq() { + if (!reduce_eq()) + return true; + if (m_conseq.is_zero()) + return true; + + m_eqs.reset(); + m_deps.reset(); + unsigned orig_size = m_ineqs.size(); + m_deps.reserve(orig_size); + for (unsigned i = 0; i < m_ineqs.size(); ++i) { + row& r = m_ineqs[i]; + if (r.is_zero()) { + m_tight_inequalities.push_back(i); + continue; + } + auto const& [v, coeff] = *r.m_coeffs.begin(); + unsigned sz = m_ineqs.size(); + + for (unsigned j = i + 1; j < sz; ++j) { + rational B; + row& r2 = m_ineqs[j]; + if (!r2.m_coeffs.find(v, B) || (coeff > 0 && B > 0) || (coeff < 0 && B < 0)) + continue; + row& r3 = fresh(m_ineqs); + add(r3, m_ineqs[j], rational::one()); + resolve(v, r3, coeff, m_ineqs[i]); + m_deps.push_back({i, j}); + } + SASSERT(m_deps.size() == m_ineqs.size()); + } + + m_ineqs_that_are_eqs.reset(); + while (!m_tight_inequalities.empty()) { + unsigned j = m_tight_inequalities.back(); + m_tight_inequalities.pop_back(); + if (m_ineqs_that_are_eqs.contains(j)) + continue; + m_ineqs_that_are_eqs.insert(j); + if (j < orig_size) { + m_eqs.push_back(m_ineqs[j]); + } + else { + auto [a, b] = m_deps[j]; + m_tight_inequalities.push_back(a); + m_tight_inequalities.push_back(b); + } + } + m_ineqs.reset(); + + VERIFY (reduce_eq()); + + return m_conseq.is_zero(); + } + std::ostream& display_row(std::ostream& out, row const& r) { bool first = true; for (auto const& [v, coeff] : r.m_coeffs) { @@ -306,22 +410,21 @@ namespace arith { public: theory_checker(ast_manager& m): m(m), - a(m), - m_farkas("farkas"), - m_implied_eq("implied-eq"), - m_bound("bound") {} + a(m) {} void reset() { m_ineq.reset(); m_conseq.reset(); m_eqs.reset(); + m_ineqs.reset(); m_strict = false; } - bool add_ineq(rational const& coeff, expr* e, bool sign) { - return add_literal(m_ineq, abs(coeff), e, sign); + bool add_ineq(rule_type_t rt, rational const& coeff, expr* e, bool sign) { + row& r = rt == implied_eq_t ? fresh(m_ineqs) : m_ineq; + return add_literal(r, abs(coeff), e, sign); } - + bool add_conseq(rational const& coeff, expr* e, bool sign) { return add_literal(m_conseq, abs(coeff), e, sign); } @@ -332,11 +435,17 @@ namespace arith { linearize(r, rational(-1), b); } - bool check() { - if (m_conseq.m_coeffs.empty()) + bool check(rule_type_t rt) { + switch (rt) { + case farkas_t: return check_farkas(); - else + case bound_t: return check_bound(); + case implied_eq_t: + return check_implied_eq(); + default: + return check_bound(); + } } std::ostream& display(std::ostream& out) { @@ -359,7 +468,7 @@ namespace arith { /** Add implied equality as an inequality */ - bool add_implied_ineq(bool sign, app* jst) { + bool add_implied_diseq(bool sign, app* jst) { unsigned n = jst->get_num_args(); if (n < 2) return false; @@ -374,90 +483,57 @@ namespace arith { return false; if (!sign) coeff.neg(); - auto& r = m_ineq; + auto& r = m_conseq; linearize(r, coeff, arg1); linearize(r, -coeff, arg2); - m_strict = true; return true; } bool check(app* jst) override { reset(); - bool is_bound = jst->get_name() == m_bound; - bool is_implied_eq = jst->get_name() == m_implied_eq; - bool is_farkas = jst->get_name() == m_farkas; - if (!is_farkas && !is_bound && !is_implied_eq) { + + auto rt = rule_type(jst); + switch (rt) { + case cut_t: + return false; + case none_t: IF_VERBOSE(0, verbose_stream() << "unhandled inference " << mk_pp(jst, m) << "\n"); return false; + default: + break; } bool even = true; rational coeff; expr* x, * y; - unsigned j = 0, num_le = 0; - + unsigned j = 0; for (expr* arg : *jst) { + if (even) { if (!a.is_numeral(arg, coeff)) { IF_VERBOSE(0, verbose_stream() << "not numeral " << mk_pp(jst, m) << "\n"); return false; } - if (is_implied_eq) { - is_implied_eq = false; - if (!coeff.is_unsigned()) { - IF_VERBOSE(0, verbose_stream() << "not unsigned " << mk_pp(jst, m) << "\n"); - return false; - } - num_le = coeff.get_unsigned(); - if (!add_implied_ineq(false, jst)) { - IF_VERBOSE(0, display(verbose_stream() << "did not add implied eq")); - return false; - } - ++j; - continue; - } } else { bool sign = m.is_not(arg, arg); if (a.is_le(arg) || a.is_lt(arg) || a.is_ge(arg) || a.is_gt(arg)) { - if (is_bound && j + 1 == jst->get_num_args()) + if (rt == bound_t && j + 1 == jst->get_num_args()) add_conseq(coeff, arg, sign); - else if (num_le > 0) { - add_ineq(coeff, arg, sign); - --num_le; - if (num_le == 0) { - // we processed all the first inequalities, - // check that they imply one half of the implied equality. - if (!check()) { - // we might have added the wrong direction of the implied equality. - // so try the opposite inequality. - add_implied_ineq(true, jst); - add_implied_ineq(true, jst); - if (check()) { - reset(); - add_implied_ineq(false, jst); - } - else { - IF_VERBOSE(0, display(verbose_stream() << "failed to check implied eq ")); - return false; - } - } - else { - reset(); - VERIFY(add_implied_ineq(true, jst)); - } - } - } else - add_ineq(coeff, arg, sign); + add_ineq(rt, coeff, arg, sign); } else if (m.is_eq(arg, x, y)) { - if (is_bound && j + 1 == jst->get_num_args()) + if (rt == bound_t && j + 1 == jst->get_num_args()) add_conseq(coeff, arg, sign); - else if (sign) - return check(); // it should be an implied equality - else + else if (rt == implied_eq_t && j + 1 == jst->get_num_args()) + return add_implied_diseq(sign, jst) && check(rt); + else if (!sign) add_eq(x, y); + else { + IF_VERBOSE(0, verbose_stream() << "unexpected disequality in justification " << mk_pp(arg, m) << "\n"); + return false; + } } else { IF_VERBOSE(0, verbose_stream() << "not a recognized arithmetical relation " << mk_pp(arg, m) << "\n"); @@ -467,13 +543,14 @@ namespace arith { even = !even; ++j; } - return check(); + return check(rt); } void register_plugins(euf::theory_checker& pc) override { pc.register_plugin(m_farkas, this); pc.register_plugin(m_bound, this); pc.register_plugin(m_implied_eq, this); + pc.register_plugin(m_cut, this); } }; diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp index 39c9879a6..e55726af6 100644 --- a/src/sat/smt/euf_proof.cpp +++ b/src/sat/smt/euf_proof.cpp @@ -465,6 +465,8 @@ namespace euf { void solver::display_inferred(std::ostream& out, unsigned n, literal const* lits, expr* proof_hint) { expr_ref hint(proof_hint, m); + if (!proof_hint) + verbose_stream() << hint << "\n"; if (!hint) hint = m.mk_const(m_smt, m.mk_proof_sort()); visit_expr(out, hint); diff --git a/src/sat/smt/sat_th.cpp b/src/sat/smt/sat_th.cpp index 21e3883e8..51783f0a5 100644 --- a/src/sat/smt/sat_th.cpp +++ b/src/sat/smt/sat_th.cpp @@ -240,6 +240,9 @@ namespace euf { m_literals[i] = lits[i]; base_ptr += sizeof(literal) * n_lits; m_eqs = reinterpret_cast(base_ptr); + if (!pma) { + verbose_stream() << "null\n"; + } for (i = 0; i < n_eqs; ++i) { m_eqs[i] = eqs[i]; if (m_eqs[i].first->get_id() > m_eqs[i].second->get_id())