diff --git a/src/sat/smt/arith_proof_checker.h b/src/sat/smt/arith_proof_checker.h index b6721300a..e55f2635a 100644 --- a/src/sat/smt/arith_proof_checker.h +++ b/src/sat/smt/arith_proof_checker.h @@ -29,7 +29,7 @@ namespace arith { rational m_coeff; void reset() { m_coeffs.reset(); - m_coeff.reset(); + m_coeff = 0; } }; @@ -103,28 +103,29 @@ namespace arith { void linearize(row& r, rational const& mul, expr* e) { SASSERT(m_todo.empty()); m_todo.push_back({ mul, e }); - rational mul1; + rational coeff1; expr* e1, *e2; for (unsigned i = 0; i < m_todo.size(); ++i) { auto [coeff, e] = m_todo[i]; - if (a.is_mul(e, e1, e2) && a.is_numeral(e1, mul1)) - m_todo.push_back({mul*mul1, e2}); - else if (a.is_mul(e, e1, e2) && a.is_numeral(e2, mul1)) - m_todo.push_back({mul*mul1, e1}); + if (a.is_mul(e, e1, e2) && a.is_numeral(e1, coeff1)) + m_todo.push_back({coeff*coeff1, e2}); + else if (a.is_mul(e, e1, e2) && a.is_numeral(e2, coeff1)) + m_todo.push_back({coeff*coeff1, e1}); else if (a.is_add(e)) for (expr* arg : *to_app(e)) - m_todo.push_back({mul, arg}); + m_todo.push_back({coeff, arg}); else if (a.is_uminus(e, e1)) - m_todo.push_back({-mul, e1}); + m_todo.push_back({-coeff, e1}); else if (a.is_sub(e, e1, e2)) { - m_todo.push_back({mul, e1}); - m_todo.push_back({-mul, e2}); + m_todo.push_back({coeff, e1}); + m_todo.push_back({-coeff, e2}); } - else if (a.is_numeral(e, mul1)) - r.m_coeff += mul*mul1; + else if (a.is_numeral(e, coeff1)) + r.m_coeff += coeff*coeff1; else - add(r, e, mul); - } + add(r, e, coeff); + } + m_todo.reset(); } bool check_ineq(row& r) { @@ -154,11 +155,11 @@ namespace arith { expr* e1, *e2 = nullptr; if ((a.is_le(e, e1, e2) || a.is_ge(e, e2, e1)) && !sign) { linearize(r, coeff, e1); - linearize(r, -coeff, e2); + linearize(r, -coeff, e2); } - else if ((a.is_lt(e, e1, e2) || a.is_lt(e, e2, e1)) && sign) { + else if ((a.is_lt(e, e1, e2) || a.is_gt(e, e2, e1)) && sign) { linearize(r, coeff, e2); - linearize(r, -coeff, e1); + linearize(r, -coeff, e1); } else if ((a.is_le(e, e1, e2) || a.is_ge(e, e2, e1)) && sign) { linearize(r, coeff, e2); @@ -168,7 +169,7 @@ namespace arith { else m_strict = true; } - else if ((a.is_lt(e, e1, e2) || a.is_lt(e, e2, e1)) && !sign) { + else if ((a.is_lt(e, e1, e2) || a.is_gt(e, e2, e1)) && !sign) { linearize(r, coeff, e1); linearize(r, -coeff, e2); if (a.is_int(e1)) @@ -178,6 +179,7 @@ namespace arith { } else return false; + // display_row(std::cout << coeff << " * " << (sign?"~":"") << mk_pp(e, m) << "\n", r) << "\n"; return true; } @@ -213,21 +215,19 @@ namespace arith { return false; } - void display_row(std::ostream& out, row const& r) { + std::ostream& display_row(std::ostream& out, row const& r) { bool first = true; for (auto const& [v, coeff] : r.m_coeffs) { - if (!first && coeff > 0) + if (!first) out << " + "; if (coeff != 1) out << coeff << " * "; - out << mk_pp(v, m) << " "; + out << mk_pp(v, m); first = false; } - if (r.m_coeff != 0) { - if (r.m_coeff > 0) - out << "+ "; - out << r.m_coeff; - } + if (r.m_coeff != 0) + out << " + " << r.m_coeff; + return out; } @@ -238,7 +238,10 @@ namespace arith { void display_ineq(std::ostream& out, row const& r) { display_row(out, r); - out << " <= 0\n"; + if (m_strict) + out << " < 0\n"; + else + out << " <= 0\n"; } @@ -253,11 +256,11 @@ namespace arith { } bool add_ineq(rational const& coeff, expr* e, bool sign) { - return add_literal(m_ineq, coeff, e, sign); + return add_literal(m_ineq, abs(coeff), e, sign); } bool add_conseq(rational const& coeff, expr* e, bool sign) { - return add_literal(m_conseq, coeff, e, sign); + return add_literal(m_conseq, abs(coeff), e, sign); } void add_eq(expr* a, expr* b) { diff --git a/src/shell/drat_frontend.cpp b/src/shell/drat_frontend.cpp index 53260be54..aad4ad2ba 100644 --- a/src/shell/drat_frontend.cpp +++ b/src/shell/drat_frontend.cpp @@ -166,6 +166,7 @@ public: // return; // remove when testing this arith_util autil(m); arith::proof_checker achecker(m); + proof_checker pc(m); switch (hint.m_ty) { case sat::hint_type::null_h: break; @@ -178,6 +179,7 @@ public: expr* y = exprs[b]; achecker.add_eq(x, y); } + unsigned sz = hint.m_literals.size(); for (unsigned i = 0; i < sz; ++i) { auto const& [coeff, lit] = hint.m_literals[i]; @@ -195,9 +197,29 @@ public: } } + // achecker.display(std::cout << "checking\n"); bool ok = achecker.check(); if (!ok) { + rational lc(1); + for (auto const& [coeff, lit] : hint.m_literals) + lc = lcm(lc, denominator(coeff)); + bool is_strict = false; + expr_ref sum(m); + for (auto const& [coeff, lit] : hint.m_literals) { + app_ref e(to_app(m_b2e[lit.var()]), m); + VERIFY(pc.check_arith_literal(!lit.sign(), e, coeff*lc, sum, is_strict)); + std::cout << "sum: " << sum << "\n"; + } + sort* s = sum->get_sort(); + if (is_strict) + sum = autil.mk_lt(sum, autil.mk_numeral(rational(0), s)); + else + sum = autil.mk_le(sum, autil.mk_numeral(rational(0), s)); + th_rewriter rw(m); + rw(sum); + std::cout << "sum: " << sum << "\n"; + for (auto const& [coeff, a, b]: hint.m_eqs) { expr* x = exprs[a]; expr* y = exprs[b]; @@ -209,6 +231,8 @@ public: if (lit.sign()) e = m.mk_not(e); std::cout << e << "\n"; } + achecker.display(std::cout); + std::cout << "p hint not verified\n"; return false; }