From b629960afb41cd40dd59ab33207445015a89267d Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 6 Jun 2022 07:18:33 -0700 Subject: [PATCH] proof format Signed-off-by: Nikolaj Bjorner --- src/math/lp/explanation.h | 4 +- src/sat/sat_drat.cpp | 46 +++++++++++++-------- src/sat/sat_types.h | 11 ++--- src/sat/smt/arith_diagnostics.cpp | 68 ++++++++++++++++++------------- src/sat/smt/arith_proof_checker.h | 44 ++++++++++++++++++-- src/sat/smt/arith_solver.cpp | 8 ++-- src/sat/smt/arith_solver.h | 4 +- src/shell/drat_frontend.cpp | 21 +++++++--- 8 files changed, 144 insertions(+), 62 deletions(-) diff --git a/src/math/lp/explanation.h b/src/math/lp/explanation.h index 700a30acc..d2e7edc33 100644 --- a/src/math/lp/explanation.h +++ b/src/math/lp/explanation.h @@ -53,7 +53,8 @@ public: if (e.m_vector.empty()) { for (constraint_index j : e.m_set) push_back(j); - } else { + } + else { for (const auto & p : e.m_vector) { add_pair(p.first, p.second); } @@ -71,6 +72,7 @@ public: constraint_index ci() const { return m_var; } const mpq &coeff() const { return m_coeff; } }; + class iterator { bool m_run_on_vector; mpq m_one = one_of_type(); diff --git a/src/sat/sat_drat.cpp b/src/sat/sat_drat.cpp index e8594b048..12a0b2600 100644 --- a/src/sat/sat_drat.cpp +++ b/src/sat/sat_drat.cpp @@ -921,14 +921,19 @@ namespace sat { case hint_type::bound_h: ous << "bound "; break; - case hint_type::cut_h: - ous << "cut "; + case hint_type::implied_eq_h: + ous << "implied_eq "; + break; + default: + UNREACHABLE(); break; } for (auto const& [q, l] : m_literals) ous << rational(q) << " * " << l << " "; - for (auto const& [q, a, b] : m_eqs) - ous << rational(q) << " = " << a << " " << b << " "; + for (auto const& [a, b] : m_eqs) + ous << " = " << a << " " << b << " "; + for (auto const& [a, b] : m_diseqs) + ous << " != " << a << " " << b << " "; return ous.str(); } @@ -954,9 +959,9 @@ namespace sat { s += 5; return true; } - if (0 == strncmp(s, "cut", 3)) { - h.m_ty = hint_type::cut_h; - s += 3; + if (0 == strncmp(s, "implied_eq", 10)) { + h.m_ty = hint_type::implied_eq_h; + s += 10; return true; } return false; @@ -982,6 +987,24 @@ namespace sat { return sat::literal(r.get_unsigned(), false); }; auto parse_coeff_literal = [&]() { + if (*s == '=') { + ++s; + ws(); + unsigned a = parse_coeff().get_unsigned(); + ws(); + unsigned b = parse_coeff().get_unsigned(); + h.m_eqs.push_back(std::make_pair(a, b)); + return true; + } + if (*s == '!' && *(s + 1) == '=') { + s += 2; + ws(); + unsigned a = parse_coeff().get_unsigned(); + ws(); + unsigned b = parse_coeff().get_unsigned(); + h.m_diseqs.push_back(std::make_pair(a, b)); + return true; + } rational coeff = parse_coeff(); ws(); if (*s == '*') { @@ -991,15 +1014,6 @@ namespace sat { h.m_literals.push_back(std::make_pair(coeff, lit)); return true; } - if (*s == '=') { - ++s; - ws(); - unsigned a = parse_coeff().get_unsigned(); - ws(); - unsigned b = parse_coeff().get_unsigned(); - h.m_eqs.push_back(std::make_tuple(coeff, a, b)); - return true; - } return false; }; diff --git a/src/sat/sat_types.h b/src/sat/sat_types.h index bc7ea6ef9..fa42f0712 100644 --- a/src/sat/sat_types.h +++ b/src/sat/sat_types.h @@ -98,14 +98,15 @@ namespace sat { null_h, farkas_h, bound_h, - cut_h + implied_eq_h, }; struct proof_hint { - hint_type m_ty = hint_type::null_h; - vector> m_literals; - vector> m_eqs; - void reset() { m_ty = hint_type::null_h; m_literals.reset(); m_eqs.reset(); } + hint_type m_ty = hint_type::null_h; + vector> m_literals; + vector> m_eqs; + vector> m_diseqs; + void reset() { m_ty = hint_type::null_h; m_literals.reset(); m_eqs.reset(); m_diseqs.reset(); } std::string to_string() const; void from_string(char const* s); void from_string(std::string const& s) { from_string(s.c_str()); } diff --git a/src/sat/smt/arith_diagnostics.cpp b/src/sat/smt/arith_diagnostics.cpp index 3dd4f58de..1208dca11 100644 --- a/src/sat/smt/arith_diagnostics.cpp +++ b/src/sat/smt/arith_diagnostics.cpp @@ -80,6 +80,33 @@ namespace arith { if (m_nla) m_nla->collect_statistics(st); } + void solver::add_assumptions() { + m_arith_hint.reset(); + unsigned i = 0; + for (auto const & ev : m_explanation) { + ++i; + auto idx = ev.ci(); + if (UINT_MAX == idx) + continue; + switch (m_constraint_sources[idx]) { + case inequality_source: { + literal lit = m_inequalities[idx]; + m_arith_hint.m_literals.push_back({ev.coeff(), lit}); + break; + } + case equality_source: { + auto [u, v] = m_equalities[idx]; + ctx.drat_log_expr(u->get_expr()); + ctx.drat_log_expr(v->get_expr()); + m_arith_hint.m_eqs.push_back({u->get_expr_id(), v->get_expr_id()}); + break; + } + default: + break; + } + } + } + /** * It may be necessary to use the following assumption when checking Farkas claims * generated from bounds propagation: @@ -91,34 +118,19 @@ namespace arith { sat::proof_hint const* solver::explain(sat::hint_type ty, sat::literal lit) { if (!ctx.use_drat()) return nullptr; - m_bounds_pragma.m_ty = ty; - m_bounds_pragma.m_literals.reset(); - m_bounds_pragma.m_eqs.reset(); - unsigned i = 0; - for (auto const & ev : m_explanation) { - ++i; - auto idx = ev.ci(); - if (UINT_MAX == idx) - continue; - switch (m_constraint_sources[idx]) { - case inequality_source: { - literal lit = m_inequalities[idx]; - m_bounds_pragma.m_literals.push_back({ev.coeff(), lit}); - break; - } - case equality_source: { - auto [u, v] = m_equalities[idx]; - ctx.drat_log_expr(u->get_expr()); - ctx.drat_log_expr(v->get_expr()); - m_bounds_pragma.m_eqs.push_back({ev.coeff(), u->get_expr_id(), v->get_expr_id()}); - break; - } - default: - break; - } - } + m_arith_hint.m_ty = ty; + add_assumptions(); if (lit != sat::null_literal) - m_bounds_pragma.m_literals.push_back({rational(1), ~lit}); - return &m_bounds_pragma; + m_arith_hint.m_literals.push_back({rational(1), ~lit}); + return &m_arith_hint; + } + + sat::proof_hint const* solver::explain_implied_eq(euf::enode* a, euf::enode* b) { + if (!ctx.use_drat()) + return nullptr; + m_arith_hint.m_ty = sat::hint_type::implied_eq_h; + add_assumptions(); + m_arith_hint.m_diseqs.push_back({a->get_expr_id(), b->get_expr_id()}); + return &m_arith_hint; } } diff --git a/src/sat/smt/arith_proof_checker.h b/src/sat/smt/arith_proof_checker.h index 755f9a3bd..c37a4f7c2 100644 --- a/src/sat/smt/arith_proof_checker.h +++ b/src/sat/smt/arith_proof_checker.h @@ -40,6 +40,8 @@ namespace arith { row m_ineq; row m_conseq; vector m_eqs; + vector m_ineqs; + vector m_diseqs; void add(row& r, expr* v, rational const& coeff) { rational coeff1; @@ -241,6 +243,26 @@ namespace arith { return false; } + // + // checking disequalities is TBD. + // it has to select only a subset of bounds to justify each inequality. + // example + // c <= x <= c, c <= y <= c => x = y + // for the proof of x <= y use the inequalities x <= c <= y + // for the proof of y <= x use the inequalities y <= c <= x + // example + // x <= y, y <= z, z <= u, u <= x => x = z + // for the proof of x <= z use the inequalities x <= y, y <= z + // for the proof of z <= x use the inequalities z <= u, u <= x + // + // so when m_diseqs is non-empty we can't just add inequalities with Farkas coefficients + // into m_ineq, since coefficients of the usable subset vanish. + // + + bool check_diseq() { + return false; + } + std::ostream& display_row(std::ostream& out, row const& r) { bool first = true; for (auto const& [v, coeff] : r.m_coeffs) { @@ -270,6 +292,11 @@ namespace arith { out << " <= 0\n"; } + row& fresh(vector& rows) { + rows.push_back(row()); + return rows.back(); + } + public: proof_checker(ast_manager& m): m(m), a(m) {} @@ -278,10 +305,14 @@ namespace arith { m_ineq.reset(); m_conseq.reset(); m_eqs.reset(); + m_ineqs.reset(); + m_diseqs.reset(); m_strict = false; } bool add_ineq(rational const& coeff, expr* e, bool sign) { + if (!m_diseqs.empty()) + return add_literal(fresh(m_ineqs), abs(coeff), e, sign); return add_literal(m_ineq, abs(coeff), e, sign); } @@ -290,14 +321,21 @@ namespace arith { } void add_eq(expr* a, expr* b) { - m_eqs.push_back(row()); - row& r = m_eqs.back(); + row& r = fresh(m_eqs); linearize(r, rational(1), a); linearize(r, rational(-1), b); } + + void add_diseq(expr* a, expr* b) { + row& r = fresh(m_diseqs); + linearize(r, rational(1), a); + linearize(r, rational(-1), b); + } bool check() { - if (!m_conseq.m_coeffs.empty()) + if (!m_diseqs.empty()) + return check_diseq(); + else if (!m_conseq.m_coeffs.empty()) return check_bound(); else return check_farkas(); diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 66a4c2a65..fc3677cb9 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -321,7 +321,8 @@ namespace arith { reset_evidence(); for (auto ev : e) set_evidence(ev.ci()); - auto* jst = euf::th_explain::propagate(*this, m_core, m_eqs, n1, n2); // TODO add equality explanation + auto* ex = explain_implied_eq(n1, n2); + auto* jst = euf::th_explain::propagate(*this, m_core, m_eqs, n1, n2, ex); ctx.propagate(n1, n2, jst->to_index()); return true; } @@ -756,7 +757,8 @@ namespace arith { set_evidence(ci4); enode* x = var2enode(v1); enode* y = var2enode(v2); - auto* jst = euf::th_explain::propagate(*this, m_core, m_eqs, x, y); // TODO add equality explanation + auto* ex = explain_implied_eq(x, y); + auto* jst = euf::th_explain::propagate(*this, m_core, m_eqs, x, y, ex); ctx.propagate(x, y, jst->to_index()); } @@ -1176,7 +1178,7 @@ namespace arith { app_ref b = mk_bound(m_lia->get_term(), m_lia->get_offset(), !m_lia->is_upper()); IF_VERBOSE(4, verbose_stream() << "cut " << b << "\n"); literal lit = expr2literal(b); - assign(lit, m_core, m_eqs, explain(sat::hint_type::cut_h, lit)); + assign(lit, m_core, m_eqs, explain(sat::hint_type::bound_h, lit)); lia_check = l_false; break; } diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 8df8b7ec2..858d19ab8 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -419,9 +419,11 @@ namespace arith { void false_case_of_check_nla(const nla::lemma& l); void dbg_finalize_model(model& mdl); - sat::proof_hint m_bounds_pragma; + sat::proof_hint m_arith_hint; sat::proof_hint m_farkas2; sat::proof_hint const* explain(sat::hint_type ty, sat::literal lit = sat::null_literal); + sat::proof_hint const* explain_implied_eq(euf::enode* a, euf::enode* b); + void add_assumptions(); public: diff --git a/src/shell/drat_frontend.cpp b/src/shell/drat_frontend.cpp index 1886c6dfc..8c486cca8 100644 --- a/src/shell/drat_frontend.cpp +++ b/src/shell/drat_frontend.cpp @@ -170,21 +170,26 @@ public: switch (hint.m_ty) { case sat::hint_type::null_h: break; - case sat::hint_type::cut_h: case sat::hint_type::bound_h: - case sat::hint_type::farkas_h: { + case sat::hint_type::farkas_h: + case sat::hint_type::implied_eq_h: { achecker.reset(); - for (auto const& [coeff, a, b]: hint.m_eqs) { + for (auto const& [a, b]: hint.m_eqs) { expr* x = exprs[a]; expr* y = exprs[b]; achecker.add_eq(x, y); } + for (auto const& [a, b]: hint.m_diseqs) { + expr* x = exprs[a]; + expr* y = exprs[b]; + achecker.add_diseq(x, y); + } unsigned sz = hint.m_literals.size(); for (unsigned i = 0; i < sz; ++i) { auto const& [coeff, lit] = hint.m_literals[i]; app_ref e(to_app(m_b2e[lit.var()]), m); - if (i + 1 == sz && (sat::hint_type::bound_h == hint.m_ty || sat::hint_type::cut_h == hint.m_ty)) { + if (i + 1 == sz && sat::hint_type::bound_h == hint.m_ty) { if (!achecker.add_conseq(coeff, e, lit.sign())) { std::cout << "p failed checking hint " << e << "\n"; return false; @@ -220,12 +225,18 @@ public: rw(sum); std::cout << "sum: " << sum << "\n"; - for (auto const& [coeff, a, b]: hint.m_eqs) { + for (auto const& [a, b]: hint.m_eqs) { expr* x = exprs[a]; expr* y = exprs[b]; app_ref e(m.mk_eq(x, y), m); std::cout << e << "\n"; } + for (auto const& [a, b]: hint.m_diseqs) { + expr* x = exprs[a]; + expr* y = exprs[b]; + app_ref e(m.mk_not(m.mk_eq(x, y)), m); + std::cout << e << "\n"; + } for (auto const& [coeff, lit] : hint.m_literals) { app_ref e(to_app(m_b2e[lit.var()]), m); if (lit.sign()) e = m.mk_not(e);