diff --git a/src/sat/smt/arith_proof_checker.h b/src/sat/smt/arith_proof_checker.h
new file mode 100644
index 000000000..b6721300a
--- /dev/null
+++ b/src/sat/smt/arith_proof_checker.h
@@ -0,0 +1,322 @@
+/*++
+Copyright (c) 2020 Microsoft Corporation
+
+Module Name:
+
+    arith_proof_checker.h
+
+Abstract:
+
+    Plugin for checking arithmetic lemmas
+
+Author:
+
+    Nikolaj Bjorner (nbjorner) 2020-09-08
+
+--*/
+#pragma once
+
+#include "util/obj_pair_set.h"
+#include "ast/ast_trail.h"
+#include "ast/arith_decl_plugin.h"
+
+
+namespace arith {
+
+    /**
+     * \brief Checker for arithmetic proof hints.
+     *
+     * Literals passed to add_ineq are summed into m_ineq, the literal passed to
+     * add_conseq into m_conseq, and add_eq appends to m_eqs. An inequality row
+     * stands for "sum of coeff*term + m_coeff <= 0" ("< 0" once m_strict is set).
+     */
+    class proof_checker {
+        // linear expression: sum of m_coeffs[v]*v plus the constant m_coeff.
+        struct row {
+            obj_map<expr, rational> m_coeffs;
+            rational                m_coeff;
+            void reset() {
+                m_coeffs.reset();
+                m_coeff.reset();
+            }
+        };
+
+        ast_manager&                       m;
+        arith_util                         a;
+        vector<std::pair<rational, expr*>> m_todo;    // work list used by linearize
+        bool                               m_strict = false;
+        row                                m_ineq;
+        row                                m_conseq;
+        vector<row>                        m_eqs;
+
+        // r <- r + coeff*v, dropping v from the row when its coefficient cancels.
+        void add(row& r, expr* v, rational const& coeff) {
+            rational coeff1;
+            if (coeff.is_zero())
+                return;
+            if (r.m_coeffs.find(v, coeff1)) {
+                coeff1 += coeff;
+                if (coeff1.is_zero())
+                    r.m_coeffs.erase(v);
+                else
+                    r.m_coeffs[v] = coeff1;
+            }
+            else
+                r.m_coeffs.insert(v, coeff);
+        }
+
+        // r <- coeff*r
+        void mul(row& r, rational const& coeff) {
+            if (coeff == 1)
+                return;
+            for (auto & [v, c] : r.m_coeffs)
+                c *= coeff;
+            r.m_coeff *= coeff;
+        }
+
+        // dst <- dst + mul*src
+        void add(row& dst, row const& src, rational const& mul) {
+            for (auto const& [v, c] : src.m_coeffs)
+                add(dst, v, c*mul);
+            dst.m_coeff += mul*src.m_coeff;
+        }
+
+        // dst <- X*dst + Y*src
+        // where A is the coefficient of v in src, B the coefficient of v in dst,
+        // X = lcm(|A|,|B|)/|B|, Y = +-lcm(|A|,|B|)/|A| if v is integer
+        // X = 1/|B|,            Y = +-1/|A|            if v is real
+        // Y is negated when A and B have the same sign, so that v cancels.
+        // dst is left unchanged when it does not contain v.
+        void resolve(expr* v, row& dst, rational const& A, row const& src) {
+            rational B, x, y;
+            if (!dst.m_coeffs.find(v, B))
+                return;
+            if (a.is_int(v)) {
+                rational lc = lcm(abs(A), abs(B));
+                x = lc / abs(B);
+                y = lc / abs(A);
+            }
+            else {
+                x = rational(1) / abs(B);
+                y = rational(1) / abs(A);
+            }
+            if (A < 0 && B < 0)
+                y.neg();
+            if (A > 0 && B > 0)
+                y.neg();
+            mul(dst, x);
+            add(dst, src, y);
+        }
+
+        /**
+         * \brief add mul*e to r: numerals go to m_coeff, other leaves to m_coeffs.
+         *
+         * Every work-list entry carries the multiplier accumulated on the path from
+         * the root to that sub-term, so a nested product such as 2*(3*x) is scaled
+         * by the product of all enclosing numerals rather than by mul alone.
+         */
+        void linearize(row& r, rational const& mul, expr* e) {
+            SASSERT(m_todo.empty());
+            m_todo.push_back({ mul, e });
+            rational mul1;
+            expr* e1, *e2;
+            for (unsigned i = 0; i < m_todo.size(); ++i) {
+                // copy the entry: the push_back calls below may reallocate m_todo.
+                auto [coeff, t] = m_todo[i];
+                if (a.is_mul(t, e1, e2) && a.is_numeral(e1, mul1))
+                    m_todo.push_back({coeff*mul1, e2});
+                else if (a.is_mul(t, e1, e2) && a.is_numeral(e2, mul1))
+                    m_todo.push_back({coeff*mul1, e1});
+                else if (a.is_add(t))
+                    for (expr* arg : *to_app(t))
+                        m_todo.push_back({coeff, arg});
+                else if (a.is_uminus(t, e1))
+                    m_todo.push_back({-coeff, e1});
+                else if (a.is_sub(t, e1, e2)) {
+                    m_todo.push_back({coeff, e1});
+                    m_todo.push_back({-coeff, e2});
+                }
+                else if (a.is_numeral(t, mul1))
+                    r.m_coeff += coeff*mul1;
+                else
+                    add(r, t, coeff);
+            }
+            // leave the work list empty so the next call starts from scratch.
+            m_todo.reset();
+        }
+
+        // true if r is a contradiction: it has no variables and its constant is
+        // positive, or its constant is zero while m_strict is set.
+        bool check_ineq(row& r) {
+            if (r.m_coeffs.empty() && r.m_coeff > 0)
+                return true;
+            if (r.m_coeffs.empty() && m_strict && r.m_coeff == 0)
+                return true;
+            return false;
+        }
+
+        // triangulate equalities, substitute results into m_ineq, m_conseq.
+        void reduce_eq() {
+            for (unsigned i = 0; i < m_eqs.size(); ++i) {
+                auto& r = m_eqs[i];
+                if (r.m_coeffs.empty())
+                    continue;
+                auto [v, coeff] = *r.m_coeffs.begin();
+                for (unsigned j = i + 1; j < m_eqs.size(); ++j)
+                    resolve(v, m_eqs[j], coeff, r);
+                resolve(v, m_ineq, coeff, r);
+                resolve(v, m_conseq, coeff, r);
+            }
+        }
+
+
+        /**
+         * \brief add coeff times the literal (e, negated when sign is set) to r.
+         *
+         * The literal is brought into the form "t <= 0". A strict literal adds coeff
+         * to the constant when e1 is integer, and sets m_strict otherwise.
+         * Returns false when e is not one of <=, >=, <, >.
+         */
+        bool add_literal(row& r, rational const& coeff, expr* e, bool sign) {
+            expr* e1, *e2 = nullptr;
+            if ((a.is_le(e, e1, e2) || a.is_ge(e, e2, e1)) && !sign) {
+                linearize(r, coeff, e1);
+                linearize(r, -coeff, e2);
+            }
+            else if ((a.is_lt(e, e1, e2) || a.is_gt(e, e2, e1)) && sign) {
+                linearize(r, coeff, e2);
+                linearize(r, -coeff, e1);
+            }
+            else if ((a.is_le(e, e1, e2) || a.is_ge(e, e2, e1)) && sign) {
+                linearize(r, coeff, e2);
+                linearize(r, -coeff, e1);
+                if (a.is_int(e1))
+                    r.m_coeff += coeff;
+                else
+                    m_strict = true;
+            }
+            else if ((a.is_lt(e, e1, e2) || a.is_gt(e, e2, e1)) && !sign) {
+                linearize(r, coeff, e1);
+                linearize(r, -coeff, e2);
+                if (a.is_int(e1))
+                    r.m_coeff += coeff;
+                else
+                    m_strict = true;
+            }
+            else
+                return false;
+            return true;
+        }
+
+        // true if m_ineq is a contradiction, before or after eliminating equalities.
+        bool check_farkas() {
+            if (check_ineq(m_ineq))
+                return true;
+            reduce_eq();
+            if (check_ineq(m_ineq))
+                return true;
+
+            // convert to expression, maybe follows from a cut.
+            return false;
+        }
+
+        //
+        // farkas coefficient is computed for m_conseq
+        // after all inequalities in ineq have been added up
+        //
+        bool check_bound() {
+            reduce_eq();
+            if (check_ineq(m_conseq))
+                return true;
+            if (m_ineq.m_coeffs.empty() ||
+                m_conseq.m_coeffs.empty())
+                return false;
+            auto const& [v, coeff1] = *m_ineq.m_coeffs.begin();
+            rational coeff2;
+            if (!m_conseq.m_coeffs.find(v, coeff2))
+                return false;
+            add(m_conseq, m_ineq, abs(coeff2/coeff1));
+            if (check_ineq(m_conseq))
+                return true;
+            return false;
+        }
+
+        void display_row(std::ostream& out, row const& r) {
+            bool first = true;
+            for (auto const& [v, coeff] : r.m_coeffs) {
+                if (!first && coeff > 0)
+                    out << " + ";
+                if (coeff != 1)
+                    out << coeff << " * ";
+                out << mk_pp(v, m) << " ";
+                first = false;
+            }
+            if (r.m_coeff != 0) {
+                if (r.m_coeff > 0)
+                    out << "+ ";
+                out << r.m_coeff;
+            }
+        }
+
+
+        void display_eq(std::ostream& out, row const& r) {
+            display_row(out, r);
+            out << " = 0\n";
+        }
+
+        void display_ineq(std::ostream& out, row const& r) {
+            display_row(out, r);
+            out << " <= 0\n";
+        }
+
+
+    public:
+        proof_checker(ast_manager& m): m(m), a(m) {}
+
+        // clear all accumulated rows and the strictness flag.
+        void reset() {
+            m_ineq.reset();
+            m_conseq.reset();
+            m_eqs.reset();
+            m_strict = false;
+        }
+
+        // add coeff*literal to m_ineq; false if e is not a supported inequality.
+        bool add_ineq(rational const& coeff, expr* e, bool sign) {
+            return add_literal(m_ineq, coeff, e, sign);
+        }
+
+        // add coeff*literal to m_conseq; false if e is not a supported inequality.
+        bool add_conseq(rational const& coeff, expr* e, bool sign) {
+            return add_literal(m_conseq, coeff, e, sign);
+        }
+
+        // record the equality a = b as the row a - b = 0.
+        void add_eq(expr* a, expr* b) {
+            m_eqs.push_back(row());
+            row& r = m_eqs.back();
+            linearize(r, rational(1), a);
+            linearize(r, rational(-1), b);
+        }
+
+        // runs check_bound when m_conseq has variables, check_farkas otherwise.
+        bool check() {
+            if (!m_conseq.m_coeffs.empty())
+                return check_bound();
+            else
+                return check_farkas();
+        }
+
+        std::ostream& display(std::ostream& out) {
+            for (auto & r : m_eqs)
+                display_eq(out, r);
+            display_ineq(out, m_ineq);
+            if (!m_conseq.m_coeffs.empty())
+                display_ineq(out, m_conseq);
+            return out;
+        }
+
+
+    };
+
+}
diff --git
a/src/shell/drat_frontend.cpp b/src/shell/drat_frontend.cpp index 8bd629845..53260be54 100644 --- a/src/shell/drat_frontend.cpp +++ b/src/shell/drat_frontend.cpp @@ -17,6 +17,8 @@ Copyright (c) 2020 Microsoft Corporation #include "cmd_context/cmd_context.h" #include "ast/proofs/proof_checker.h" #include "ast/rewriter/th_rewriter.h" +#include "sat/smt/arith_proof_checker.h" + class smt_checker { ast_manager& m; @@ -162,92 +164,40 @@ public: bool validate_hint(expr_ref_vector const& exprs, sat::literal_vector const& lits, sat::proof_hint const& hint) { // return; // remove when testing this - proof_checker pc(m); arith_util autil(m); + arith::proof_checker achecker(m); switch (hint.m_ty) { case sat::hint_type::null_h: break; case sat::hint_type::cut_h: case sat::hint_type::bound_h: case sat::hint_type::farkas_h: { - expr_ref sum(m), last_sum(m); - bool is_strict = false; - vector coeffs; - rational lc(1); - for (auto const& [coeff, a, b]: hint.m_eqs) { - coeffs.push_back(coeff); - lc = lcm(lc, denominator(coeff)); - } - - for (auto const& [coeff, lit] : hint.m_literals) { - coeffs.push_back(coeff); - lc = lcm(lc, denominator(coeff)); - } - if (!lc.is_one()) - for (auto& coeff : coeffs) - coeff *= lc; - - unsigned i = 0; + achecker.reset(); for (auto const& [coeff, a, b]: hint.m_eqs) { expr* x = exprs[a]; expr* y = exprs[b]; - coeffs.push_back(coeff); - app_ref e(m.mk_eq(x, y), m); - if (!pc.check_arith_literal(true, e, coeffs[i], sum, is_strict)) { - std::cout << "p failed checking hint " << e << "\n"; - return false; - } - ++i; + achecker.add_eq(x, y); } - - for (auto const& [coeff, lit] : hint.m_literals) { - last_sum = sum; + unsigned sz = hint.m_literals.size(); + for (unsigned i = 0; i < sz; ++i) { + auto const& [coeff, lit] = hint.m_literals[i]; app_ref e(to_app(m_b2e[lit.var()]), m); - if (!pc.check_arith_literal(!lit.sign(), e, coeffs[i], sum, is_strict)) { - std::cout << "p failed checking hint " << e << "\n"; - return false; - } - ++i; - } - - if 
(!sum.get()) { - std::cout << "p no summation\n"; - return false; - } - - th_rewriter rw(m); - if (sat::hint_type::bound_h == hint.m_ty) { - rw(last_sum); - sum = last_sum; - auto const& [coeff, lit] = hint.m_literals.back(); - rational last_coeff = coeff, r; - expr* x, *y, *z; - if (autil.is_add(sum)) { - x = to_app(sum)->get_arg(1); - if (autil.is_mul(x, y, z) && autil.is_numeral(y, r)) { - last_coeff = r; + if (i + 1 == sz && sat::hint_type::bound_h == hint.m_ty) { + if (!achecker.add_conseq(coeff, e, lit.sign())) { + std::cout << "p failed checking hint " << e << "\n"; + return false; } + + } + else if (!achecker.add_ineq(coeff, e, lit.sign())) { + std::cout << "p failed checking hint " << e << "\n"; + return false; } - app_ref e(to_app(m_b2e[lit.var()]), m); - VERIFY(pc.check_arith_literal(!lit.sign(), e, last_coeff, sum, is_strict)); } - if (is_strict) - sum = autil.mk_lt(sum, autil.mk_numeral(rational(0), sum->get_sort())); - else - sum = autil.mk_le(sum, autil.mk_numeral(rational(0), sum->get_sort())); - - rw(sum); - if (!m.is_false(sum)) { - // check hint: - std::cout << "p hint not verified " << sum << "\n"; - auto const& [coeff, lit] = hint.m_literals.back(); - expr_ref sum1(m); - bool is_strict1 = false; - app_ref e(to_app(m_b2e[lit.var()]), m); - rational coeffb = coeffs.back(); - VERIFY(pc.check_arith_literal(!lit.sign(), e, coeffb, sum1, is_strict1)); - std::cout << last_sum << " => ~" << sum1 << "\n"; + bool ok = achecker.check(); + + if (!ok) { for (auto const& [coeff, a, b]: hint.m_eqs) { expr* x = exprs[a]; expr* y = exprs[b]; @@ -261,6 +211,8 @@ public: } return false; } + + std::cout << "p hint verified\n"; break; }