diff --git a/src/ast/ast.h b/src/ast/ast.h index 512501226..511c2cee0 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -731,6 +731,8 @@ public: unsigned get_num_args() const { return m_num_args; } expr * get_arg(unsigned idx) const { SASSERT(idx < m_num_args); return m_args[idx]; } expr * const * get_args() const { return m_args; } + std::tuple args2() const { SASSERT(m_num_args == 2); return {get_arg(0), get_arg(1)}; } + std::tuple args3() const { SASSERT(m_num_args == 3); return {get_arg(0), get_arg(1), get_arg(2)}; } unsigned get_size() const { return get_obj_size(get_num_args()); } expr * const * begin() const { return m_args; } expr * const * end() const { return m_args + m_num_args; } diff --git a/src/cmd_context/extra_cmds/proof_cmds.cpp b/src/cmd_context/extra_cmds/proof_cmds.cpp index 03cb3f95c..e429358fa 100644 --- a/src/cmd_context/extra_cmds/proof_cmds.cpp +++ b/src/cmd_context/extra_cmds/proof_cmds.cpp @@ -181,6 +181,141 @@ public: }; +namespace sat { + /** + * Replay proof entierly, then walk backwards extracting reduced proof. + */ + class proof_trim { + cmd_context& ctx; + ast_manager& m; + solver s; + literal_vector m_clause; + struct hash { + unsigned operator()(literal_vector const& v) const { + return string_hash((char const*)v.begin(), v.size()*sizeof(literal), 3); + } + }; + struct eq { + bool operator()(literal_vector const& a, literal_vector const& b) const { + return a == b; + } + }; + map m_clauses; + + void mk_clause(expr_ref_vector const& clause) { + m_clause.reset(); + for (expr* arg: clause) + add_literal(arg); + std::sort(m_clause.begin(), m_clause.end()); + } + + bool_var mk_var(expr* arg) { + while (arg->get_id() >= s.num_vars()) + s.mk_var(true, true); + return arg->get_id(); + } + + void add_literal(expr* arg) { + bool sign = m.is_not(arg, arg); + m_clause.push_back(literal(mk_var(arg), sign)); + } + + + /** + Pseudo-code from Gurfinkel, Vizel, FMCAD 2014 + Input: trail (a0,d0), ..., (an,dn) = ({},bot) + Output: reduced trail - result + result = [] + C = an + for i = n to 0 do + if s.is_deleted(ai) then s.revive(ai) + else + if s.isontrail(ai) then + s.undotrailcore(ai,C) + s.delete(ai) + if ai in C then + if ai is not initial then + s.savetrail() + s.enqueue(not ai) + c = s.propagate() + s.conflictanalysiscore(c, C) + s.restoretrail() + result += [ai] + reverse(result) + + is_deleted(ai): + clause was detached + revive(ai): + attach clause ai + isontrail(ai): + some literal on the current trail in s is justified by ai + undotrailcore(ai, C): + pop the trail until dependencies on ai are gone + savetrail: + store current trail so it can be restored + enqueue(not ai): + assert negations of ai at a new decision level + conflictanalysiscore(c, C): + ? + restoretrail: + restore the trail to the position before enqueue + + + + */ + void trim() { + + } + + public: + proof_trim(cmd_context& ctx): + ctx(ctx), + m(ctx.m()), + s(gparams::get_module("sat"), m.limit()) { + + } + + void assume(expr_ref_vector const& _clause) { + mk_clause(_clause); + IF_VERBOSE(3, verbose_stream() << "add: " << m_clause << "\n"); + auto* cl = s.mk_clause(m_clause, status::redundant()); + s.propagate(false); + if (!cl) + return; + IF_VERBOSE(3, verbose_stream() << "add: " << *cl << "\n"); + auto& v = m_clauses.insert_if_not_there(m_clause, clause_vector()); + v.push_back(cl); + } + + void del(expr_ref_vector const& _clause) { + mk_clause(_clause); + IF_VERBOSE(3, verbose_stream() << "del: " << m_clause << "\n"); + if (m_clause.size() == 2) { + s.detach_bin_clause(m_clause[0], m_clause[1], true); + return; + } + auto* e = m_clauses.find_core(m_clause); + if (!e) + return; + auto& v = e->get_data().m_value; + if (!v.empty()) { + IF_VERBOSE(3, verbose_stream() << "del: " << *v.back() << "\n"); + s.detach_clause(*v.back()); + v.pop_back(); + } + } + + void infer(expr_ref_vector const& _clause, app*) { + assume(_clause); + } + + void updt_params(params_ref const& p) { + s.updt_params(p); + } + + }; +} + class proof_saver { cmd_context& ctx; @@ -218,10 +353,11 @@ class proof_cmds_imp : public proof_cmds { bool m_trim = false; scoped_ptr m_checker; scoped_ptr m_saver; + scoped_ptr m_trimmer; smt_checker& checker() { if (!m_checker) m_checker = alloc(smt_checker, m); return *m_checker; } proof_saver& saver() { if (!m_saver) m_saver = alloc(proof_saver, ctx); return *m_saver; } - + sat::proof_trim& trim() { if (!m_trimmer) m_trimmer = alloc(sat::proof_trim, ctx); return *m_trimmer; } public: proof_cmds_imp(cmd_context& ctx): ctx(ctx), m(ctx.m()), m_lits(m), m_proof_hint(m) { @@ -240,6 +376,8 @@ public: checker().assume(m_lits); if (m_save) saver().assume(m_lits); + if (m_trim) + trim().assume(m_lits); m_lits.reset(); m_proof_hint.reset(); } @@ -249,6 +387,8 @@ public: checker().check(m_lits, m_proof_hint); if (m_save) saver().infer(m_lits, m_proof_hint); + if (m_trim) + trim().infer(m_lits, m_proof_hint); m_lits.reset(); m_proof_hint.reset(); } @@ -258,6 +398,8 @@ public: checker().del(m_lits); if (m_save) saver().del(m_lits); + if (m_trim) + trim().del(m_lits); m_lits.reset(); m_proof_hint.reset(); } @@ -266,6 +408,9 @@ public: solver_params sp(p); m_check = sp.proof_check(); m_save = sp.proof_save(); + m_trim = sp.proof_trim(); + if (m_trim) + trim().updt_params(p); } }; diff --git a/src/params/solver_params.pyg b/src/params/solver_params.pyg index 6e33ca6d7..9ff13864d 100644 --- a/src/params/solver_params.pyg +++ b/src/params/solver_params.pyg @@ -10,5 +10,6 @@ def_module_params('solver', ('axioms2files', BOOL, False, 'print negated theory axioms to separate files during search'), ('proof.check', BOOL, True, 'check proof logs'), ('proof.save', BOOL, False, 'save proof log into a proof object that can be extracted using (get-proof)'), + ('proof.trim', BOOL, False, 'trim and save proof into a proof object that an be extracted using (get-proof)'), )) diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 0b01b777c..5c413ce09 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -235,6 +235,7 @@ namespace sat { friend class aig_finder; friend class lut_finder; friend class npn3_finder; + friend class proof_trim; public: solver(params_ref const & p, reslimit& l); ~solver() override; diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 2b02015e0..93917042e 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -424,7 +424,7 @@ namespace arith { ++m_stats.m_assert_diseq; add_farkas_clause(~eq, le); add_farkas_clause(~eq, ge); - add_clause(~le, ~ge, eq, explain_triangle_eq(le, ge, eq)); + add_clause(~le, ~ge, eq, explain_trichotomy(le, ge, eq)); } diff --git a/src/sat/smt/arith_diagnostics.cpp b/src/sat/smt/arith_diagnostics.cpp index cba691319..ec9b11e76 100644 --- a/src/sat/smt/arith_diagnostics.cpp +++ b/src/sat/smt/arith_diagnostics.cpp @@ -129,14 +129,16 @@ namespace arith { return nullptr; m_arith_hint.set_type(ctx, hint_type::implied_eq_h); explain_assumptions(); + m_arith_hint.set_num_le(1); // TODO m_arith_hint.add_diseq(a, b); return m_arith_hint.mk(ctx); } - arith_proof_hint const* solver::explain_triangle_eq(sat::literal le, sat::literal ge, sat::literal eq) { + arith_proof_hint const* solver::explain_trichotomy(sat::literal le, sat::literal ge, sat::literal eq) { if (!ctx.use_drat()) return nullptr; m_arith_hint.set_type(ctx, hint_type::implied_eq_h); + m_arith_hint.set_num_le(1); m_arith_hint.add_lit(rational(1), le); m_arith_hint.add_lit(rational(1), ge); m_arith_hint.add_lit(rational(1), ~eq); @@ -149,6 +151,9 @@ namespace arith { arith_util arith(m); solver& a = dynamic_cast(*s.fid2solver(fid)); char const* name; + expr_ref_vector args(m); + sort_ref_vector sorts(m); + switch (m_ty) { case hint_type::farkas_h: name = "farkas"; @@ -158,15 +163,14 @@ namespace arith { break; case hint_type::implied_eq_h: name = "implied-eq"; + args.push_back(arith.mk_int(m_num_le)); break; } rational lc(1); for (unsigned i = m_lit_head; i < m_lit_tail; ++i) lc = lcm(lc, denominator(a.m_arith_hint.lit(i).first)); - - expr_ref_vector args(m); - sort_ref_vector sorts(m); - for (unsigned i = m_lit_head; i < m_lit_tail; ++i) { + + for (unsigned i = m_lit_head; i < m_lit_tail; ++i) { auto const& [coeff, lit] = a.m_arith_hint.lit(i); args.push_back(arith.mk_int(abs(coeff*lc))); args.push_back(s.literal2expr(lit)); diff --git a/src/sat/smt/arith_proof_checker.h b/src/sat/smt/arith_proof_checker.h index 56e4cf1f8..df1dc00f1 100644 --- a/src/sat/smt/arith_proof_checker.h +++ b/src/sat/smt/arith_proof_checker.h @@ -18,7 +18,7 @@ Notes: The module assumes a limited repertoire of arithmetic proof rules. - farkas - inequalities, equalities and disequalities with coefficients -- implied-eq - last literal is a disequality. The literals before imply the corresponding equality. +- implied-eq - last literal is a disequality. The literals before imply the complementary equality. - bound - last literal is a bound. It is implied by prior literals. --*/ @@ -26,8 +26,10 @@ The module assumes a limited repertoire of arithmetic proof rules. #include "util/obj_pair_set.h" #include "ast/ast_trail.h" +#include "ast/ast_util.h" #include "ast/arith_decl_plugin.h" #include "sat/smt/euf_proof_checker.h" +#include namespace arith { @@ -49,8 +51,6 @@ namespace arith { row m_ineq; row m_conseq; vector m_eqs; - vector m_ineqs; - vector m_diseqs; symbol m_farkas; symbol m_implied_eq; symbol m_bound; @@ -261,26 +261,6 @@ namespace arith { return false; } - // - // checking disequalities is TBD. - // it has to select only a subset of bounds to justify each inequality. - // example - // c <= x <= c, c <= y <= c => x = y - // for the proof of x <= y use the inequalities x <= c <= y - // for the proof of y <= x use the inequalities y <= c <= x - // example - // x <= y, y <= z, z <= u, u <= x => x = z - // for the proof of x <= z use the inequalities x <= y, y <= z - // for the proof of z <= x use the inequalities z <= u, u <= x - // - // so when m_diseqs is non-empty we can't just add inequalities with Farkas coefficients - // into m_ineq, since coefficients of the usable subset vanish. - // - - bool check_diseq() { - return false; - } - std::ostream& display_row(std::ostream& out, row const& r) { bool first = true; for (auto const& [v, coeff] : r.m_coeffs) { @@ -329,16 +309,11 @@ namespace arith { m_ineq.reset(); m_conseq.reset(); m_eqs.reset(); - m_ineqs.reset(); - m_diseqs.reset(); m_strict = false; } bool add_ineq(rational const& coeff, expr* e, bool sign) { - if (!m_diseqs.empty()) - return add_literal(fresh(m_ineqs), abs(coeff), e, sign); - else - return add_literal(m_ineq, abs(coeff), e, sign); + return add_literal(m_ineq, abs(coeff), e, sign); } bool add_conseq(rational const& coeff, expr* e, bool sign) { @@ -350,20 +325,12 @@ namespace arith { linearize(r, rational(1), a); linearize(r, rational(-1), b); } - - void add_diseq(expr* a, expr* b) { - row& r = fresh(m_diseqs); - linearize(r, rational(1), a); - linearize(r, rational(-1), b); - } bool check() { - if (!m_diseqs.empty()) - return check_diseq(); - else if (!m_conseq.m_coeffs.empty()) - return check_bound(); - else + if (m_conseq.m_coeffs.empty()) return check_farkas(); + else + return check_bound(); } std::ostream& display(std::ostream& out) { @@ -375,14 +342,41 @@ namespace arith { return out; } - bool check(expr_ref_vector const& clause, app* jst, expr_ref_vector& units) override { + expr_ref_vector clause(app* jst) override { + expr_ref_vector result(m); + for (expr* arg : *jst) + if (m.is_bool(arg)) + result.push_back(mk_not(m, arg)); + return result; + } + + /** + Add implied equality as an inequality + */ + bool add_implied_ineq(bool sign, app* jst) { + unsigned n = jst->get_num_args(); + if (n < 2) + return false; + expr* arg1 = jst->get_arg(n - 2); + expr* arg2 = jst->get_arg(n - 1); + rational coeff; + if (!a.is_numeral(arg1, coeff)) + return false; + if (!m.is_not(arg2, arg2)) + return false; + if (!m.is_eq(arg2, arg1, arg2)) + return false; + if (!sign) + coeff.neg(); + auto& r = m_ineq; + linearize(r, coeff, arg1); + linearize(r, -coeff, arg2); + m_strict = true; + return true; + } + + bool check(app* jst) override { reset(); - expr_mark pos, neg; - for (expr* e : clause) - if (m.is_not(e, e)) - neg.mark(e, true); - else - pos.mark(e, true); bool is_bound = jst->get_name() == m_bound; bool is_implied_eq = jst->get_name() == m_implied_eq; bool is_farkas = jst->get_name() == m_farkas; @@ -393,25 +387,51 @@ namespace arith { bool even = true; rational coeff; expr* x, * y; - unsigned j = 0; + unsigned j = 0, num_le = 0; + + for (expr* arg : *jst) { if (even) { if (!a.is_numeral(arg, coeff)) { IF_VERBOSE(0, verbose_stream() << "not numeral " << mk_pp(jst, m) << "\n"); return false; } + if (is_implied_eq) { + is_implied_eq = false; + if (!coeff.is_unsigned()) { + IF_VERBOSE(0, verbose_stream() << "not unsigned " << mk_pp(jst, m) << "\n"); + return false; + } + num_le = coeff.get_unsigned(); + if (!add_implied_ineq(false, jst)) + return false; + ++j; + continue; + } } else { bool sign = m.is_not(arg, arg); if (a.is_le(arg) || a.is_lt(arg) || a.is_ge(arg) || a.is_gt(arg)) { if (is_bound && j + 1 == jst->get_num_args()) add_conseq(coeff, arg, sign); + else if (num_le > 0) { + add_ineq(coeff, arg, sign); + --num_le; + if (num_le == 0) { + // we processed all the first inequalities, + // check that they imply one half of the implied equality. + if (!check()) + return false; + reset(); + VERIFY(add_implied_ineq(true, jst)); + } + } else add_ineq(coeff, arg, sign); } else if (m.is_eq(arg, x, y)) { - if (sign) - add_diseq(x, y); + if (sign) + return check(); // it should be an implied equality else add_eq(x, y); } @@ -419,23 +439,11 @@ namespace arith { IF_VERBOSE(0, verbose_stream() << "not a recognized arithmetical relation " << mk_pp(arg, m) << "\n"); return false; } - - if (sign && !pos.is_marked(arg)) { - units.push_back(m.mk_not(arg)); - pos.mark(arg, false); - } - else if (!sign && !neg.is_marked(arg)) { - units.push_back(arg); - neg.mark(arg, false); - } } even = !even; ++j; } - if (check()) - return true; - - return false; + return check(); } void register_plugins(euf::proof_checker& pc) override { diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 2142874d5..4fdb57386 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -51,14 +51,15 @@ namespace arith { enum class hint_type { farkas_h, bound_h, - implied_eq_h + implied_eq_h }; struct arith_proof_hint : public euf::th_proof_hint { - hint_type m_ty; - unsigned m_lit_head, m_lit_tail, m_eq_head, m_eq_tail; - arith_proof_hint(hint_type t, unsigned lh, unsigned lt, unsigned eh, unsigned et): - m_ty(t), m_lit_head(lh), m_lit_tail(lt), m_eq_head(eh), m_eq_tail(et) {} + hint_type m_ty; + unsigned m_num_le; + unsigned m_lit_head, m_lit_tail, m_eq_head, m_eq_tail; + arith_proof_hint(hint_type t, unsigned num_le, unsigned lh, unsigned lt, unsigned eh, unsigned et): + m_ty(t), m_num_le(num_le), m_lit_head(lh), m_lit_tail(lt), m_eq_head(eh), m_eq_tail(et) {} expr* get_hint(euf::solver& s) const override; }; @@ -66,6 +67,7 @@ namespace arith { vector> m_literals; svector> m_eqs; hint_type m_ty; + unsigned m_num_le = 0; unsigned m_lit_head = 0, m_lit_tail = 0, m_eq_head = 0, m_eq_tail = 0; void reset() { m_lit_head = m_lit_tail; m_eq_head = m_eq_tail; } void add(euf::enode* a, euf::enode* b, bool is_eq) { @@ -82,6 +84,7 @@ namespace arith { m_ty = ty; reset(); } + void set_num_le(unsigned n) { m_num_le = n; } void add_eq(euf::enode* a, euf::enode* b) { add(a, b, true); } void add_diseq(euf::enode* a, euf::enode* b) { add(a, b, false); } void add_lit(rational const& coeff, literal lit) { @@ -94,7 +97,7 @@ namespace arith { std::pair const& lit(unsigned i) const { return m_literals[i]; } std::tuple const& eq(unsigned i) const { return m_eqs[i]; } arith_proof_hint* mk(euf::solver& s) { - return new (s.get_region()) arith_proof_hint(m_ty, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail); + return new (s.get_region()) arith_proof_hint(m_ty, m_num_le, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail); } }; @@ -474,7 +477,7 @@ namespace arith { arith_proof_hint const* explain(hint_type ty, sat::literal lit = sat::null_literal); arith_proof_hint const* explain_implied_eq(euf::enode* a, euf::enode* b); - arith_proof_hint const* explain_triangle_eq(sat::literal le, sat::literal ge, sat::literal eq); + arith_proof_hint const* explain_trichotomy(sat::literal le, sat::literal ge, sat::literal eq); void explain_assumptions(); diff --git a/src/sat/smt/euf_proof_checker.cpp b/src/sat/smt/euf_proof_checker.cpp index d0a4e4ab1..4774e154f 100644 --- a/src/sat/smt/euf_proof_checker.cpp +++ b/src/sat/smt/euf_proof_checker.cpp @@ -17,6 +17,7 @@ Author: #include "util/union_find.h" #include "ast/ast_pp.h" +#include "ast/ast_util.h" #include "ast/ast_ll_pp.h" #include "sat/smt/euf_proof_checker.h" #include "sat/smt/arith_proof_checker.h" @@ -120,24 +121,23 @@ namespace euf { ~eq_proof_checker() override {} - bool check(expr_ref_vector const& clause, app* jst, expr_ref_vector& units) override { - IF_VERBOSE(10, verbose_stream() << clause << "\n" << mk_pp(jst, m) << "\n"); + expr_ref_vector clause(app* jst) override { + expr_ref_vector result(m); + for (expr* arg : *jst) + if (m.is_bool(arg)) + result.push_back(mk_not(m, arg)); + return result; + } + + bool check(app* jst) override { + IF_VERBOSE(10, verbose_stream() << mk_pp(jst, m) << "\n"); reset(); - expr_mark pos, neg; - expr* x, *y; - for (expr* e : clause) - if (m.is_not(e, e)) - neg.mark(e, true); - else - pos.mark(e, true); for (expr* arg : *jst) { - if (m.is_bool(arg)) { - bool sign = m.is_not(arg, arg); - if (sign && !pos.is_marked(arg)) - units.push_back(m.mk_not(arg)); - else if (!sign & !neg.is_marked(arg)) - units.push_back(arg); + expr* x, *y; + bool sign = m.is_not(arg, arg); + + if (m.is_bool(arg)) { if (m.is_eq(arg, x, y)) { if (sign) m_diseqs.push_back({x, y}); @@ -198,38 +198,144 @@ namespace euf { void register_plugins(proof_checker& pc) override { pc.register_plugin(symbol("euf"), this); } + }; + /** + A resolution proof term is of the form + (res pivot proof1 proof2) + The pivot occurs with opposite signs in proof1 and proof2 + */ + + class res_proof_checker : public proof_checker_plugin { + ast_manager& m; + proof_checker& pc; + + public: + res_proof_checker(ast_manager& m, proof_checker& pc): m(m), pc(pc) {} + + ~res_proof_checker() override {} + + bool check(app* jst) override { + if (jst->get_num_args() != 3) + return false; + auto [pivot, proof1, proof2] = jst->args3(); + if (!m.is_bool(pivot) || !m.is_proof(proof1) || !m.is_proof(proof2)) + return false; + expr* narg; + bool found1 = false, found2 = false, found3 = false, found4 = false; + for (expr* arg : pc.clause(proof1)) { + found1 |= arg == pivot; + found2 |= m.is_not(arg, narg) && narg == pivot; + } + if (found1 == found2) + return false; + + for (expr* arg : pc.clause(proof2)) { + found3 |= arg == pivot; + found4 |= m.is_not(arg, narg) && narg == pivot; + } + if (found3 == found4) + return false; + if (found3 == found1) + return false; + return pc.check(proof1) && pc.check(proof2); + } + + expr_ref_vector clause(app* jst) override { + expr_ref_vector result(m); + auto [pivot, proof1, proof2] = jst->args3(); + expr* narg; + auto is_pivot = [&](expr* arg) { + if (arg == pivot) + return true; + return m.is_not(arg, narg) && narg == pivot; + }; + for (expr* arg : pc.clause(proof1)) + if (!is_pivot(arg)) + result.push_back(arg); + for (expr* arg : pc.clause(proof2)) + if (!is_pivot(arg)) + result.push_back(arg); + return result; + } + + void register_plugins(proof_checker& pc) override { + pc.register_plugin(symbol("res"), this); + } }; proof_checker::proof_checker(ast_manager& m): m(m) { - arith::proof_checker* apc = alloc(arith::proof_checker, m); - eq_proof_checker* epc = alloc(eq_proof_checker, m); - m_plugins.push_back(apc); - m_plugins.push_back(epc); - apc->register_plugins(*this); - epc->register_plugins(*this); + add_plugin(alloc(arith::proof_checker, m)); + add_plugin(alloc(eq_proof_checker, m)); + add_plugin(alloc(res_proof_checker, m, *this)); } - proof_checker::~proof_checker() {} + proof_checker::~proof_checker() { + for (auto& [k, v] : m_checked_clauses) + dealloc(v); + } + + void proof_checker::add_plugin(proof_checker_plugin* p) { + m_plugins.push_back(p); + p->register_plugins(*this); + } void proof_checker::register_plugin(symbol const& rule, proof_checker_plugin* p) { m_map.insert(rule, p); } - bool proof_checker::check(expr_ref_vector const& clause, expr* e, expr_ref_vector& units) { + bool proof_checker::check(expr* e) { + if (m_checked_clauses.contains(e)) + return true; + if (!e || !is_app(e)) return false; - units.reset(); app* a = to_app(e); proof_checker_plugin* p = nullptr; if (!m_map.find(a->get_decl()->get_name(), p)) return false; - if (p->check(clause, a, units)) - return true; - - std::cout << "(missed-hint " << mk_pp(e, m) << ")\n"; - return false; + if (!p->check(a)) { + std::cout << "(missed-hint " << mk_pp(e, m) << ")\n"; + return false; + } + return true; + } + + expr_ref_vector proof_checker::clause(expr* e) { + expr_ref_vector* rr; + if (m_checked_clauses.find(e, rr)) + return *rr; + SASSERT(is_app(e) && m_map.contains(to_app(e)->get_decl()->get_name())); + auto& r = m_map[to_app(e)->get_decl()->get_name()]->clause(to_app(e)); + m_checked_clauses.insert(e, alloc(expr_ref_vector, r)); + return r; + } + + bool proof_checker::check(expr_ref_vector const& clause1, expr* e, expr_ref_vector & units) { + if (!check(e)) + return false; + units.reset(); + expr_mark literals; + auto clause2 = clause(e); + + // check that all literals in clause1 are in clause2 + for (expr* arg : clause2) + literals.mark(arg, true); + for (expr* arg : clause1) + if (!literals.is_marked(arg)) + return false; + + // extract negated units for literals in clause2 but not in clause1 + // the literals should be rup + literals.reset(); + for (expr* arg : clause1) + literals.mark(arg, true); + for (expr* arg : clause2) + if (!literals.is_marked(arg)) + units.push_back(mk_not(m, arg)); + + return true; } } diff --git a/src/sat/smt/euf_proof_checker.h b/src/sat/smt/euf_proof_checker.h index 464d90559..023bfae48 100644 --- a/src/sat/smt/euf_proof_checker.h +++ b/src/sat/smt/euf_proof_checker.h @@ -27,18 +27,23 @@ namespace euf { class proof_checker_plugin { public: virtual ~proof_checker_plugin() {} - virtual bool check(expr_ref_vector const& clause, app* jst, expr_ref_vector& units) = 0; + virtual bool check(app* jst) = 0; + virtual expr_ref_vector clause(app* jst) = 0; virtual void register_plugins(proof_checker& pc) = 0; }; class proof_checker { ast_manager& m; - scoped_ptr_vector m_plugins; - map m_map; + scoped_ptr_vector m_plugins; // plugins of proof checkers + map m_map; // symbol table of proof checkers + obj_map m_checked_clauses; // cache of previously checked proofs and their clauses. + void add_plugin(proof_checker_plugin* p); public: proof_checker(ast_manager& m); ~proof_checker(); void register_plugin(symbol const& rule, proof_checker_plugin*); + bool check(expr* jst); + expr_ref_vector clause(expr* jst); bool check(expr_ref_vector const& clause, expr* e, expr_ref_vector& units); };