3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-22 13:53:39 +00:00

work on proof checking

- add outline of trim routine
- streamline how proof terms are checked and how residue units are extracted.
This commit is contained in:
Nikolaj Bjorner 2022-09-30 13:04:19 -04:00
parent ccda49bad5
commit b9cba82531
10 changed files with 384 additions and 109 deletions

View file

@ -731,6 +731,8 @@ public:
unsigned get_num_args() const { return m_num_args; } unsigned get_num_args() const { return m_num_args; }
expr * get_arg(unsigned idx) const { SASSERT(idx < m_num_args); return m_args[idx]; } expr * get_arg(unsigned idx) const { SASSERT(idx < m_num_args); return m_args[idx]; }
expr * const * get_args() const { return m_args; } expr * const * get_args() const { return m_args; }
std::tuple<expr*,expr*> args2() const { SASSERT(m_num_args == 2); return {get_arg(0), get_arg(1)}; }
std::tuple<expr*,expr*,expr*> args3() const { SASSERT(m_num_args == 3); return {get_arg(0), get_arg(1), get_arg(2)}; }
unsigned get_size() const { return get_obj_size(get_num_args()); } unsigned get_size() const { return get_obj_size(get_num_args()); }
expr * const * begin() const { return m_args; } expr * const * begin() const { return m_args; }
expr * const * end() const { return m_args + m_num_args; } expr * const * end() const { return m_args + m_num_args; }

View file

@ -181,6 +181,141 @@ public:
}; };
namespace sat {
/**
* Replay proof entierly, then walk backwards extracting reduced proof.
*/
class proof_trim {
cmd_context& ctx;
ast_manager& m;
solver s;
literal_vector m_clause;
struct hash {
unsigned operator()(literal_vector const& v) const {
return string_hash((char const*)v.begin(), v.size()*sizeof(literal), 3);
}
};
struct eq {
bool operator()(literal_vector const& a, literal_vector const& b) const {
return a == b;
}
};
map<literal_vector, clause_vector, hash, eq> m_clauses;
void mk_clause(expr_ref_vector const& clause) {
m_clause.reset();
for (expr* arg: clause)
add_literal(arg);
std::sort(m_clause.begin(), m_clause.end());
}
bool_var mk_var(expr* arg) {
while (arg->get_id() >= s.num_vars())
s.mk_var(true, true);
return arg->get_id();
}
void add_literal(expr* arg) {
bool sign = m.is_not(arg, arg);
m_clause.push_back(literal(mk_var(arg), sign));
}
/**
Pseudo-code from Gurfinkel, Vizel, FMCAD 2014
Input: trail (a0,d0), ..., (an,dn) = ({},bot)
Output: reduced trail - result
result = []
C = an
for i = n to 0 do
if s.is_deleted(ai) then s.revive(ai)
else
if s.isontrail(ai) then
s.undotrailcore(ai,C)
s.delete(ai)
if ai in C then
if ai is not initial then
s.savetrail()
s.enqueue(not ai)
c = s.propagate()
s.conflictanalysiscore(c, C)
s.restoretrail()
result += [ai]
reverse(result)
is_deleted(ai):
clause was detached
revive(ai):
attach clause ai
isontrail(ai):
some literal on the current trail in s is justified by ai
undotrailcore(ai, C):
pop the trail until dependencies on ai are gone
savetrail:
store current trail so it can be restored
enqueue(not ai):
assert negations of ai at a new decision level
conflictanalysiscore(c, C):
?
restoretrail:
restore the trail to the position before enqueue
*/
void trim() {
}
public:
proof_trim(cmd_context& ctx):
ctx(ctx),
m(ctx.m()),
s(gparams::get_module("sat"), m.limit()) {
}
void assume(expr_ref_vector const& _clause) {
mk_clause(_clause);
IF_VERBOSE(3, verbose_stream() << "add: " << m_clause << "\n");
auto* cl = s.mk_clause(m_clause, status::redundant());
s.propagate(false);
if (!cl)
return;
IF_VERBOSE(3, verbose_stream() << "add: " << *cl << "\n");
auto& v = m_clauses.insert_if_not_there(m_clause, clause_vector());
v.push_back(cl);
}
void del(expr_ref_vector const& _clause) {
mk_clause(_clause);
IF_VERBOSE(3, verbose_stream() << "del: " << m_clause << "\n");
if (m_clause.size() == 2) {
s.detach_bin_clause(m_clause[0], m_clause[1], true);
return;
}
auto* e = m_clauses.find_core(m_clause);
if (!e)
return;
auto& v = e->get_data().m_value;
if (!v.empty()) {
IF_VERBOSE(3, verbose_stream() << "del: " << *v.back() << "\n");
s.detach_clause(*v.back());
v.pop_back();
}
}
void infer(expr_ref_vector const& _clause, app*) {
assume(_clause);
}
void updt_params(params_ref const& p) {
s.updt_params(p);
}
};
}
class proof_saver { class proof_saver {
cmd_context& ctx; cmd_context& ctx;
@ -218,10 +353,11 @@ class proof_cmds_imp : public proof_cmds {
bool m_trim = false; bool m_trim = false;
scoped_ptr<smt_checker> m_checker; scoped_ptr<smt_checker> m_checker;
scoped_ptr<proof_saver> m_saver; scoped_ptr<proof_saver> m_saver;
scoped_ptr<sat::proof_trim> m_trimmer;
smt_checker& checker() { if (!m_checker) m_checker = alloc(smt_checker, m); return *m_checker; } smt_checker& checker() { if (!m_checker) m_checker = alloc(smt_checker, m); return *m_checker; }
proof_saver& saver() { if (!m_saver) m_saver = alloc(proof_saver, ctx); return *m_saver; } proof_saver& saver() { if (!m_saver) m_saver = alloc(proof_saver, ctx); return *m_saver; }
sat::proof_trim& trim() { if (!m_trimmer) m_trimmer = alloc(sat::proof_trim, ctx); return *m_trimmer; }
public: public:
proof_cmds_imp(cmd_context& ctx): ctx(ctx), m(ctx.m()), m_lits(m), m_proof_hint(m) { proof_cmds_imp(cmd_context& ctx): ctx(ctx), m(ctx.m()), m_lits(m), m_proof_hint(m) {
@ -240,6 +376,8 @@ public:
checker().assume(m_lits); checker().assume(m_lits);
if (m_save) if (m_save)
saver().assume(m_lits); saver().assume(m_lits);
if (m_trim)
trim().assume(m_lits);
m_lits.reset(); m_lits.reset();
m_proof_hint.reset(); m_proof_hint.reset();
} }
@ -249,6 +387,8 @@ public:
checker().check(m_lits, m_proof_hint); checker().check(m_lits, m_proof_hint);
if (m_save) if (m_save)
saver().infer(m_lits, m_proof_hint); saver().infer(m_lits, m_proof_hint);
if (m_trim)
trim().infer(m_lits, m_proof_hint);
m_lits.reset(); m_lits.reset();
m_proof_hint.reset(); m_proof_hint.reset();
} }
@ -258,6 +398,8 @@ public:
checker().del(m_lits); checker().del(m_lits);
if (m_save) if (m_save)
saver().del(m_lits); saver().del(m_lits);
if (m_trim)
trim().del(m_lits);
m_lits.reset(); m_lits.reset();
m_proof_hint.reset(); m_proof_hint.reset();
} }
@ -266,6 +408,9 @@ public:
solver_params sp(p); solver_params sp(p);
m_check = sp.proof_check(); m_check = sp.proof_check();
m_save = sp.proof_save(); m_save = sp.proof_save();
m_trim = sp.proof_trim();
if (m_trim)
trim().updt_params(p);
} }
}; };

View file

@ -10,5 +10,6 @@ def_module_params('solver',
('axioms2files', BOOL, False, 'print negated theory axioms to separate files during search'), ('axioms2files', BOOL, False, 'print negated theory axioms to separate files during search'),
('proof.check', BOOL, True, 'check proof logs'), ('proof.check', BOOL, True, 'check proof logs'),
('proof.save', BOOL, False, 'save proof log into a proof object that can be extracted using (get-proof)'), ('proof.save', BOOL, False, 'save proof log into a proof object that can be extracted using (get-proof)'),
('proof.trim', BOOL, False, 'trim and save proof into a proof object that an be extracted using (get-proof)'),
)) ))

View file

@ -235,6 +235,7 @@ namespace sat {
friend class aig_finder; friend class aig_finder;
friend class lut_finder; friend class lut_finder;
friend class npn3_finder; friend class npn3_finder;
friend class proof_trim;
public: public:
solver(params_ref const & p, reslimit& l); solver(params_ref const & p, reslimit& l);
~solver() override; ~solver() override;

View file

@ -424,7 +424,7 @@ namespace arith {
++m_stats.m_assert_diseq; ++m_stats.m_assert_diseq;
add_farkas_clause(~eq, le); add_farkas_clause(~eq, le);
add_farkas_clause(~eq, ge); add_farkas_clause(~eq, ge);
add_clause(~le, ~ge, eq, explain_triangle_eq(le, ge, eq)); add_clause(~le, ~ge, eq, explain_trichotomy(le, ge, eq));
} }

View file

@ -129,14 +129,16 @@ namespace arith {
return nullptr; return nullptr;
m_arith_hint.set_type(ctx, hint_type::implied_eq_h); m_arith_hint.set_type(ctx, hint_type::implied_eq_h);
explain_assumptions(); explain_assumptions();
m_arith_hint.set_num_le(1); // TODO
m_arith_hint.add_diseq(a, b); m_arith_hint.add_diseq(a, b);
return m_arith_hint.mk(ctx); return m_arith_hint.mk(ctx);
} }
arith_proof_hint const* solver::explain_triangle_eq(sat::literal le, sat::literal ge, sat::literal eq) { arith_proof_hint const* solver::explain_trichotomy(sat::literal le, sat::literal ge, sat::literal eq) {
if (!ctx.use_drat()) if (!ctx.use_drat())
return nullptr; return nullptr;
m_arith_hint.set_type(ctx, hint_type::implied_eq_h); m_arith_hint.set_type(ctx, hint_type::implied_eq_h);
m_arith_hint.set_num_le(1);
m_arith_hint.add_lit(rational(1), le); m_arith_hint.add_lit(rational(1), le);
m_arith_hint.add_lit(rational(1), ge); m_arith_hint.add_lit(rational(1), ge);
m_arith_hint.add_lit(rational(1), ~eq); m_arith_hint.add_lit(rational(1), ~eq);
@ -149,6 +151,9 @@ namespace arith {
arith_util arith(m); arith_util arith(m);
solver& a = dynamic_cast<solver&>(*s.fid2solver(fid)); solver& a = dynamic_cast<solver&>(*s.fid2solver(fid));
char const* name; char const* name;
expr_ref_vector args(m);
sort_ref_vector sorts(m);
switch (m_ty) { switch (m_ty) {
case hint_type::farkas_h: case hint_type::farkas_h:
name = "farkas"; name = "farkas";
@ -158,14 +163,13 @@ namespace arith {
break; break;
case hint_type::implied_eq_h: case hint_type::implied_eq_h:
name = "implied-eq"; name = "implied-eq";
args.push_back(arith.mk_int(m_num_le));
break; break;
} }
rational lc(1); rational lc(1);
for (unsigned i = m_lit_head; i < m_lit_tail; ++i) for (unsigned i = m_lit_head; i < m_lit_tail; ++i)
lc = lcm(lc, denominator(a.m_arith_hint.lit(i).first)); lc = lcm(lc, denominator(a.m_arith_hint.lit(i).first));
expr_ref_vector args(m);
sort_ref_vector sorts(m);
for (unsigned i = m_lit_head; i < m_lit_tail; ++i) { for (unsigned i = m_lit_head; i < m_lit_tail; ++i) {
auto const& [coeff, lit] = a.m_arith_hint.lit(i); auto const& [coeff, lit] = a.m_arith_hint.lit(i);
args.push_back(arith.mk_int(abs(coeff*lc))); args.push_back(arith.mk_int(abs(coeff*lc)));

View file

@ -18,7 +18,7 @@ Notes:
The module assumes a limited repertoire of arithmetic proof rules. The module assumes a limited repertoire of arithmetic proof rules.
- farkas - inequalities, equalities and disequalities with coefficients - farkas - inequalities, equalities and disequalities with coefficients
- implied-eq - last literal is a disequality. The literals before imply the corresponding equality. - implied-eq - last literal is a disequality. The literals before imply the complementary equality.
- bound - last literal is a bound. It is implied by prior literals. - bound - last literal is a bound. It is implied by prior literals.
--*/ --*/
@ -26,8 +26,10 @@ The module assumes a limited repertoire of arithmetic proof rules.
#include "util/obj_pair_set.h" #include "util/obj_pair_set.h"
#include "ast/ast_trail.h" #include "ast/ast_trail.h"
#include "ast/ast_util.h"
#include "ast/arith_decl_plugin.h" #include "ast/arith_decl_plugin.h"
#include "sat/smt/euf_proof_checker.h" #include "sat/smt/euf_proof_checker.h"
#include <iostream>
namespace arith { namespace arith {
@ -49,8 +51,6 @@ namespace arith {
row m_ineq; row m_ineq;
row m_conseq; row m_conseq;
vector<row> m_eqs; vector<row> m_eqs;
vector<row> m_ineqs;
vector<row> m_diseqs;
symbol m_farkas; symbol m_farkas;
symbol m_implied_eq; symbol m_implied_eq;
symbol m_bound; symbol m_bound;
@ -261,26 +261,6 @@ namespace arith {
return false; return false;
} }
//
// checking disequalities is TBD.
// it has to select only a subset of bounds to justify each inequality.
// example
// c <= x <= c, c <= y <= c => x = y
// for the proof of x <= y use the inequalities x <= c <= y
// for the proof of y <= x use the inequalities y <= c <= x
// example
// x <= y, y <= z, z <= u, u <= x => x = z
// for the proof of x <= z use the inequalities x <= y, y <= z
// for the proof of z <= x use the inequalities z <= u, u <= x
//
// so when m_diseqs is non-empty we can't just add inequalities with Farkas coefficients
// into m_ineq, since coefficients of the usable subset vanish.
//
bool check_diseq() {
return false;
}
std::ostream& display_row(std::ostream& out, row const& r) { std::ostream& display_row(std::ostream& out, row const& r) {
bool first = true; bool first = true;
for (auto const& [v, coeff] : r.m_coeffs) { for (auto const& [v, coeff] : r.m_coeffs) {
@ -329,16 +309,11 @@ namespace arith {
m_ineq.reset(); m_ineq.reset();
m_conseq.reset(); m_conseq.reset();
m_eqs.reset(); m_eqs.reset();
m_ineqs.reset();
m_diseqs.reset();
m_strict = false; m_strict = false;
} }
bool add_ineq(rational const& coeff, expr* e, bool sign) { bool add_ineq(rational const& coeff, expr* e, bool sign) {
if (!m_diseqs.empty()) return add_literal(m_ineq, abs(coeff), e, sign);
return add_literal(fresh(m_ineqs), abs(coeff), e, sign);
else
return add_literal(m_ineq, abs(coeff), e, sign);
} }
bool add_conseq(rational const& coeff, expr* e, bool sign) { bool add_conseq(rational const& coeff, expr* e, bool sign) {
@ -351,19 +326,11 @@ namespace arith {
linearize(r, rational(-1), b); linearize(r, rational(-1), b);
} }
void add_diseq(expr* a, expr* b) {
row& r = fresh(m_diseqs);
linearize(r, rational(1), a);
linearize(r, rational(-1), b);
}
bool check() { bool check() {
if (!m_diseqs.empty()) if (m_conseq.m_coeffs.empty())
return check_diseq();
else if (!m_conseq.m_coeffs.empty())
return check_bound();
else
return check_farkas(); return check_farkas();
else
return check_bound();
} }
std::ostream& display(std::ostream& out) { std::ostream& display(std::ostream& out) {
@ -375,14 +342,41 @@ namespace arith {
return out; return out;
} }
bool check(expr_ref_vector const& clause, app* jst, expr_ref_vector& units) override { expr_ref_vector clause(app* jst) override {
expr_ref_vector result(m);
for (expr* arg : *jst)
if (m.is_bool(arg))
result.push_back(mk_not(m, arg));
return result;
}
/**
Add implied equality as an inequality
*/
bool add_implied_ineq(bool sign, app* jst) {
unsigned n = jst->get_num_args();
if (n < 2)
return false;
expr* arg1 = jst->get_arg(n - 2);
expr* arg2 = jst->get_arg(n - 1);
rational coeff;
if (!a.is_numeral(arg1, coeff))
return false;
if (!m.is_not(arg2, arg2))
return false;
if (!m.is_eq(arg2, arg1, arg2))
return false;
if (!sign)
coeff.neg();
auto& r = m_ineq;
linearize(r, coeff, arg1);
linearize(r, -coeff, arg2);
m_strict = true;
return true;
}
bool check(app* jst) override {
reset(); reset();
expr_mark pos, neg;
for (expr* e : clause)
if (m.is_not(e, e))
neg.mark(e, true);
else
pos.mark(e, true);
bool is_bound = jst->get_name() == m_bound; bool is_bound = jst->get_name() == m_bound;
bool is_implied_eq = jst->get_name() == m_implied_eq; bool is_implied_eq = jst->get_name() == m_implied_eq;
bool is_farkas = jst->get_name() == m_farkas; bool is_farkas = jst->get_name() == m_farkas;
@ -393,25 +387,51 @@ namespace arith {
bool even = true; bool even = true;
rational coeff; rational coeff;
expr* x, * y; expr* x, * y;
unsigned j = 0; unsigned j = 0, num_le = 0;
for (expr* arg : *jst) { for (expr* arg : *jst) {
if (even) { if (even) {
if (!a.is_numeral(arg, coeff)) { if (!a.is_numeral(arg, coeff)) {
IF_VERBOSE(0, verbose_stream() << "not numeral " << mk_pp(jst, m) << "\n"); IF_VERBOSE(0, verbose_stream() << "not numeral " << mk_pp(jst, m) << "\n");
return false; return false;
} }
if (is_implied_eq) {
is_implied_eq = false;
if (!coeff.is_unsigned()) {
IF_VERBOSE(0, verbose_stream() << "not unsigned " << mk_pp(jst, m) << "\n");
return false;
}
num_le = coeff.get_unsigned();
if (!add_implied_ineq(false, jst))
return false;
++j;
continue;
}
} }
else { else {
bool sign = m.is_not(arg, arg); bool sign = m.is_not(arg, arg);
if (a.is_le(arg) || a.is_lt(arg) || a.is_ge(arg) || a.is_gt(arg)) { if (a.is_le(arg) || a.is_lt(arg) || a.is_ge(arg) || a.is_gt(arg)) {
if (is_bound && j + 1 == jst->get_num_args()) if (is_bound && j + 1 == jst->get_num_args())
add_conseq(coeff, arg, sign); add_conseq(coeff, arg, sign);
else if (num_le > 0) {
add_ineq(coeff, arg, sign);
--num_le;
if (num_le == 0) {
// we processed all the first inequalities,
// check that they imply one half of the implied equality.
if (!check())
return false;
reset();
VERIFY(add_implied_ineq(true, jst));
}
}
else else
add_ineq(coeff, arg, sign); add_ineq(coeff, arg, sign);
} }
else if (m.is_eq(arg, x, y)) { else if (m.is_eq(arg, x, y)) {
if (sign) if (sign)
add_diseq(x, y); return check(); // it should be an implied equality
else else
add_eq(x, y); add_eq(x, y);
} }
@ -419,23 +439,11 @@ namespace arith {
IF_VERBOSE(0, verbose_stream() << "not a recognized arithmetical relation " << mk_pp(arg, m) << "\n"); IF_VERBOSE(0, verbose_stream() << "not a recognized arithmetical relation " << mk_pp(arg, m) << "\n");
return false; return false;
} }
if (sign && !pos.is_marked(arg)) {
units.push_back(m.mk_not(arg));
pos.mark(arg, false);
}
else if (!sign && !neg.is_marked(arg)) {
units.push_back(arg);
neg.mark(arg, false);
}
} }
even = !even; even = !even;
++j; ++j;
} }
if (check()) return check();
return true;
return false;
} }
void register_plugins(euf::proof_checker& pc) override { void register_plugins(euf::proof_checker& pc) override {

View file

@ -55,10 +55,11 @@ namespace arith {
}; };
struct arith_proof_hint : public euf::th_proof_hint { struct arith_proof_hint : public euf::th_proof_hint {
hint_type m_ty; hint_type m_ty;
unsigned m_lit_head, m_lit_tail, m_eq_head, m_eq_tail; unsigned m_num_le;
arith_proof_hint(hint_type t, unsigned lh, unsigned lt, unsigned eh, unsigned et): unsigned m_lit_head, m_lit_tail, m_eq_head, m_eq_tail;
m_ty(t), m_lit_head(lh), m_lit_tail(lt), m_eq_head(eh), m_eq_tail(et) {} arith_proof_hint(hint_type t, unsigned num_le, unsigned lh, unsigned lt, unsigned eh, unsigned et):
m_ty(t), m_num_le(num_le), m_lit_head(lh), m_lit_tail(lt), m_eq_head(eh), m_eq_tail(et) {}
expr* get_hint(euf::solver& s) const override; expr* get_hint(euf::solver& s) const override;
}; };
@ -66,6 +67,7 @@ namespace arith {
vector<std::pair<rational, literal>> m_literals; vector<std::pair<rational, literal>> m_literals;
svector<std::tuple<euf::enode*,euf::enode*,bool>> m_eqs; svector<std::tuple<euf::enode*,euf::enode*,bool>> m_eqs;
hint_type m_ty; hint_type m_ty;
unsigned m_num_le = 0;
unsigned m_lit_head = 0, m_lit_tail = 0, m_eq_head = 0, m_eq_tail = 0; unsigned m_lit_head = 0, m_lit_tail = 0, m_eq_head = 0, m_eq_tail = 0;
void reset() { m_lit_head = m_lit_tail; m_eq_head = m_eq_tail; } void reset() { m_lit_head = m_lit_tail; m_eq_head = m_eq_tail; }
void add(euf::enode* a, euf::enode* b, bool is_eq) { void add(euf::enode* a, euf::enode* b, bool is_eq) {
@ -82,6 +84,7 @@ namespace arith {
m_ty = ty; m_ty = ty;
reset(); reset();
} }
void set_num_le(unsigned n) { m_num_le = n; }
void add_eq(euf::enode* a, euf::enode* b) { add(a, b, true); } void add_eq(euf::enode* a, euf::enode* b) { add(a, b, true); }
void add_diseq(euf::enode* a, euf::enode* b) { add(a, b, false); } void add_diseq(euf::enode* a, euf::enode* b) { add(a, b, false); }
void add_lit(rational const& coeff, literal lit) { void add_lit(rational const& coeff, literal lit) {
@ -94,7 +97,7 @@ namespace arith {
std::pair<rational, literal> const& lit(unsigned i) const { return m_literals[i]; } std::pair<rational, literal> const& lit(unsigned i) const { return m_literals[i]; }
std::tuple<enode*, enode*, bool> const& eq(unsigned i) const { return m_eqs[i]; } std::tuple<enode*, enode*, bool> const& eq(unsigned i) const { return m_eqs[i]; }
arith_proof_hint* mk(euf::solver& s) { arith_proof_hint* mk(euf::solver& s) {
return new (s.get_region()) arith_proof_hint(m_ty, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail); return new (s.get_region()) arith_proof_hint(m_ty, m_num_le, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail);
} }
}; };
@ -474,7 +477,7 @@ namespace arith {
arith_proof_hint const* explain(hint_type ty, sat::literal lit = sat::null_literal); arith_proof_hint const* explain(hint_type ty, sat::literal lit = sat::null_literal);
arith_proof_hint const* explain_implied_eq(euf::enode* a, euf::enode* b); arith_proof_hint const* explain_implied_eq(euf::enode* a, euf::enode* b);
arith_proof_hint const* explain_triangle_eq(sat::literal le, sat::literal ge, sat::literal eq); arith_proof_hint const* explain_trichotomy(sat::literal le, sat::literal ge, sat::literal eq);
void explain_assumptions(); void explain_assumptions();

View file

@ -17,6 +17,7 @@ Author:
#include "util/union_find.h" #include "util/union_find.h"
#include "ast/ast_pp.h" #include "ast/ast_pp.h"
#include "ast/ast_util.h"
#include "ast/ast_ll_pp.h" #include "ast/ast_ll_pp.h"
#include "sat/smt/euf_proof_checker.h" #include "sat/smt/euf_proof_checker.h"
#include "sat/smt/arith_proof_checker.h" #include "sat/smt/arith_proof_checker.h"
@ -120,24 +121,23 @@ namespace euf {
~eq_proof_checker() override {} ~eq_proof_checker() override {}
bool check(expr_ref_vector const& clause, app* jst, expr_ref_vector& units) override { expr_ref_vector clause(app* jst) override {
IF_VERBOSE(10, verbose_stream() << clause << "\n" << mk_pp(jst, m) << "\n"); expr_ref_vector result(m);
for (expr* arg : *jst)
if (m.is_bool(arg))
result.push_back(mk_not(m, arg));
return result;
}
bool check(app* jst) override {
IF_VERBOSE(10, verbose_stream() << mk_pp(jst, m) << "\n");
reset(); reset();
expr_mark pos, neg;
expr* x, *y;
for (expr* e : clause)
if (m.is_not(e, e))
neg.mark(e, true);
else
pos.mark(e, true);
for (expr* arg : *jst) { for (expr* arg : *jst) {
expr* x, *y;
bool sign = m.is_not(arg, arg);
if (m.is_bool(arg)) { if (m.is_bool(arg)) {
bool sign = m.is_not(arg, arg);
if (sign && !pos.is_marked(arg))
units.push_back(m.mk_not(arg));
else if (!sign & !neg.is_marked(arg))
units.push_back(arg);
if (m.is_eq(arg, x, y)) { if (m.is_eq(arg, x, y)) {
if (sign) if (sign)
m_diseqs.push_back({x, y}); m_diseqs.push_back({x, y});
@ -198,38 +198,144 @@ namespace euf {
void register_plugins(proof_checker& pc) override { void register_plugins(proof_checker& pc) override {
pc.register_plugin(symbol("euf"), this); pc.register_plugin(symbol("euf"), this);
} }
};
/**
A resolution proof term is of the form
(res pivot proof1 proof2)
The pivot occurs with opposite signs in proof1 and proof2
*/
class res_proof_checker : public proof_checker_plugin {
ast_manager& m;
proof_checker& pc;
public:
res_proof_checker(ast_manager& m, proof_checker& pc): m(m), pc(pc) {}
~res_proof_checker() override {}
bool check(app* jst) override {
if (jst->get_num_args() != 3)
return false;
auto [pivot, proof1, proof2] = jst->args3();
if (!m.is_bool(pivot) || !m.is_proof(proof1) || !m.is_proof(proof2))
return false;
expr* narg;
bool found1 = false, found2 = false, found3 = false, found4 = false;
for (expr* arg : pc.clause(proof1)) {
found1 |= arg == pivot;
found2 |= m.is_not(arg, narg) && narg == pivot;
}
if (found1 == found2)
return false;
for (expr* arg : pc.clause(proof2)) {
found3 |= arg == pivot;
found4 |= m.is_not(arg, narg) && narg == pivot;
}
if (found3 == found4)
return false;
if (found3 == found1)
return false;
return pc.check(proof1) && pc.check(proof2);
}
expr_ref_vector clause(app* jst) override {
expr_ref_vector result(m);
auto [pivot, proof1, proof2] = jst->args3();
expr* narg;
auto is_pivot = [&](expr* arg) {
if (arg == pivot)
return true;
return m.is_not(arg, narg) && narg == pivot;
};
for (expr* arg : pc.clause(proof1))
if (!is_pivot(arg))
result.push_back(arg);
for (expr* arg : pc.clause(proof2))
if (!is_pivot(arg))
result.push_back(arg);
return result;
}
void register_plugins(proof_checker& pc) override {
pc.register_plugin(symbol("res"), this);
}
}; };
proof_checker::proof_checker(ast_manager& m): proof_checker::proof_checker(ast_manager& m):
m(m) { m(m) {
arith::proof_checker* apc = alloc(arith::proof_checker, m); add_plugin(alloc(arith::proof_checker, m));
eq_proof_checker* epc = alloc(eq_proof_checker, m); add_plugin(alloc(eq_proof_checker, m));
m_plugins.push_back(apc); add_plugin(alloc(res_proof_checker, m, *this));
m_plugins.push_back(epc);
apc->register_plugins(*this);
epc->register_plugins(*this);
} }
proof_checker::~proof_checker() {} proof_checker::~proof_checker() {
for (auto& [k, v] : m_checked_clauses)
dealloc(v);
}
void proof_checker::add_plugin(proof_checker_plugin* p) {
m_plugins.push_back(p);
p->register_plugins(*this);
}
void proof_checker::register_plugin(symbol const& rule, proof_checker_plugin* p) { void proof_checker::register_plugin(symbol const& rule, proof_checker_plugin* p) {
m_map.insert(rule, p); m_map.insert(rule, p);
} }
bool proof_checker::check(expr_ref_vector const& clause, expr* e, expr_ref_vector& units) { bool proof_checker::check(expr* e) {
if (m_checked_clauses.contains(e))
return true;
if (!e || !is_app(e)) if (!e || !is_app(e))
return false; return false;
units.reset();
app* a = to_app(e); app* a = to_app(e);
proof_checker_plugin* p = nullptr; proof_checker_plugin* p = nullptr;
if (!m_map.find(a->get_decl()->get_name(), p)) if (!m_map.find(a->get_decl()->get_name(), p))
return false; return false;
if (p->check(clause, a, units)) if (!p->check(a)) {
return true; std::cout << "(missed-hint " << mk_pp(e, m) << ")\n";
return false;
}
return true;
}
std::cout << "(missed-hint " << mk_pp(e, m) << ")\n"; expr_ref_vector proof_checker::clause(expr* e) {
return false; expr_ref_vector* rr;
if (m_checked_clauses.find(e, rr))
return *rr;
SASSERT(is_app(e) && m_map.contains(to_app(e)->get_decl()->get_name()));
auto& r = m_map[to_app(e)->get_decl()->get_name()]->clause(to_app(e));
m_checked_clauses.insert(e, alloc(expr_ref_vector, r));
return r;
}
bool proof_checker::check(expr_ref_vector const& clause1, expr* e, expr_ref_vector & units) {
if (!check(e))
return false;
units.reset();
expr_mark literals;
auto clause2 = clause(e);
// check that all literals in clause1 are in clause2
for (expr* arg : clause2)
literals.mark(arg, true);
for (expr* arg : clause1)
if (!literals.is_marked(arg))
return false;
// extract negated units for literals in clause2 but not in clause1
// the literals should be rup
literals.reset();
for (expr* arg : clause1)
literals.mark(arg, true);
for (expr* arg : clause2)
if (!literals.is_marked(arg))
units.push_back(mk_not(m, arg));
return true;
} }
} }

View file

@ -27,18 +27,23 @@ namespace euf {
class proof_checker_plugin { class proof_checker_plugin {
public: public:
virtual ~proof_checker_plugin() {} virtual ~proof_checker_plugin() {}
virtual bool check(expr_ref_vector const& clause, app* jst, expr_ref_vector& units) = 0; virtual bool check(app* jst) = 0;
virtual expr_ref_vector clause(app* jst) = 0;
virtual void register_plugins(proof_checker& pc) = 0; virtual void register_plugins(proof_checker& pc) = 0;
}; };
class proof_checker { class proof_checker {
ast_manager& m; ast_manager& m;
scoped_ptr_vector<proof_checker_plugin> m_plugins; scoped_ptr_vector<proof_checker_plugin> m_plugins; // plugins of proof checkers
map<symbol, proof_checker_plugin*, symbol_hash_proc, symbol_eq_proc> m_map; map<symbol, proof_checker_plugin*, symbol_hash_proc, symbol_eq_proc> m_map; // symbol table of proof checkers
obj_map<expr, expr_ref_vector*> m_checked_clauses; // cache of previously checked proofs and their clauses.
void add_plugin(proof_checker_plugin* p);
public: public:
proof_checker(ast_manager& m); proof_checker(ast_manager& m);
~proof_checker(); ~proof_checker();
void register_plugin(symbol const& rule, proof_checker_plugin*); void register_plugin(symbol const& rule, proof_checker_plugin*);
bool check(expr* jst);
expr_ref_vector clause(expr* jst);
bool check(expr_ref_vector const& clause, expr* e, expr_ref_vector& units); bool check(expr_ref_vector const& clause, expr* e, expr_ref_vector& units);
}; };