3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-07 18:05:21 +00:00

wip - proof hints

This commit is contained in:
Nikolaj Bjorner 2022-10-08 20:12:57 +02:00
parent 6796ea7e49
commit 4623117af8
10 changed files with 176 additions and 29 deletions

View file

@ -64,6 +64,17 @@ void ast_pp_util::display_decls(std::ostream& out) {
m_rec_decls = n;
}
void ast_pp_util::reset() {
coll.reset();
m_removed.reset();
m_sorts.clear(0u);
m_decls.clear(0u);
m_rec_decls.clear(0u);
m_is_defined.reset();
m_defined.reset();
m_defined_lim.reset();
}
void ast_pp_util::display_skolem_decls(std::ostream& out) {
ast_smt_pp pp(m);
unsigned n = coll.get_num_decls();

View file

@ -40,8 +40,7 @@ class ast_pp_util {
ast_pp_util(ast_manager& m): m(m), m_env(m), m_rec_decls(0), m_decls(0), m_sorts(0), m_defined(m), coll(m) {}
void reset() { coll.reset(); m_removed.reset(); m_sorts.clear(0u); m_decls.clear(0u); m_rec_decls.clear(0u);
m_is_defined.reset(); m_defined.reset(); m_defined_lim.reset(); }
void reset();
void collect(expr* e);

View file

@ -103,10 +103,7 @@ namespace dt {
*/
void solver::assert_eq_axiom(enode* n1, expr* e2, literal antecedent) {
expr* e1 = n1->get_expr();
euf::th_proof_hint* ph = nullptr;
if (ctx.use_drat()) {
// todo
}
euf::th_proof_hint* ph = ctx.mk_smt_prop_hint(name(), antecedent, e1, e2);
if (antecedent == sat::null_literal)
add_unit(eq_internalize(e1, e2), ph);
else if (s().value(antecedent) == l_true) {
@ -166,7 +163,8 @@ namespace dt {
literal l = ctx.enode2literal(r);
SASSERT(s().value(l) == l_false);
clear_mark();
ctx.set_conflict(euf::th_explain::conflict(*this, ~l, c, r->get_arg(0)));
auto* ph = ctx.mk_smt_hint(name(), ~l, c, r->get_arg(0));
ctx.set_conflict(euf::th_explain::conflict(*this, ~l, c, r->get_arg(0), ph));
}
/**
@ -204,7 +202,9 @@ namespace dt {
// update_field is identity if 'n' is not created by a matching constructor.
assert_eq_axiom(n, arg1, ~is_con);
app_ref n_is_con(m.mk_app(rec, own), m);
add_clause(~is_con, mk_literal(n_is_con));
literal _n_is_con = mk_literal(n_is_con);
auto* ph = ctx.mk_smt_hint(name(), is_con, ~_n_is_con);
add_clause(~is_con, _n_is_con, ph);
}
euf::theory_var solver::mk_var(enode* n) {
@ -313,7 +313,8 @@ namespace dt {
}
}
}
ctx.set_conflict(euf::th_explain::conflict(*this, m_lits));
auto* ph = ctx.mk_smt_hint(name(), m_lits);
ctx.set_conflict(euf::th_explain::conflict(*this, m_lits, ph));
}
/**
@ -449,8 +450,10 @@ namespace dt {
++idx;
}
TRACE("dt", tout << "propagate " << num_unassigned << " eqs: " << eqs.size() << "\n";);
if (num_unassigned == 0)
ctx.set_conflict(euf::th_explain::conflict(*this, m_lits, eqs));
if (num_unassigned == 0) {
auto* ph = ctx.mk_smt_hint(name(), m_lits, eqs);
ctx.set_conflict(euf::th_explain::conflict(*this, m_lits, eqs, ph));
}
else if (num_unassigned == 1) {
// propagate remaining recognizer
SASSERT(!m_lits.empty());
@ -464,7 +467,13 @@ namespace dt {
app_ref rec_app(m.mk_app(rec, n->get_expr()), m);
consequent = mk_literal(rec_app);
}
ctx.propagate(consequent, euf::th_explain::propagate(*this, m_lits, eqs, consequent));
euf::th_proof_hint* ph = nullptr;
if (ctx.use_drat()) {
m_lits.push_back(~consequent);
ph = ctx.mk_smt_hint(name(), m_lits, eqs);
m_lits.pop_back();
}
ctx.propagate(consequent, euf::th_explain::propagate(*this, m_lits, eqs, consequent, ph));
}
else if (get_config().m_dt_lazy_splits == 0 || (!srt->is_infinite() && get_config().m_dt_lazy_splits == 1))
// there are more than 2 unassigned recognizers...
@ -481,7 +490,7 @@ namespace dt {
auto* con2 = d2->m_constructor;
TRACE("dt", tout << "merging v" << v1 << " v" << v2 << "\n" << ctx.bpp(var2enode(v1)) << " == " << ctx.bpp(var2enode(v2)) << " " << ctx.bpp(con1) << " " << ctx.bpp(con2) << "\n";);
if (con1 && con2 && con1->get_decl() != con2->get_decl())
ctx.set_conflict(euf::th_explain::conflict(*this, con1, con2));
ctx.set_conflict(euf::th_explain::conflict(*this, con1, con2, ctx.mk_smt_hint(name(), con1, con2)));
else if (con2 && !con1) {
ctx.push(set_ptr_trail<enode>(d1->m_constructor));
// check whether there is a recognizer in d1 that conflicts with con2;
@ -706,7 +715,7 @@ namespace dt {
if (res) {
clear_mark();
ctx.set_conflict(euf::th_explain::conflict(*this, m_used_eqs));
ctx.set_conflict(euf::th_explain::conflict(*this, m_used_eqs, ctx.mk_smt_hint(name(), m_used_eqs)));
TRACE("dt", tout << "occurs check conflict: " << ctx.bpp(n) << "\n";);
}
return res;

View file

@ -79,13 +79,13 @@ namespace euf {
return nullptr;
push(value_trail(m_lit_tail));
push(value_trail(m_cc_tail));
push(restore_size_trail(m_eq_proof_literals));
push(restore_size_trail(m_proof_literals));
if (lit != sat::null_literal)
m_eq_proof_literals.push_back(~lit);
m_eq_proof_literals.append(r);
m_proof_literals.push_back(~lit);
m_proof_literals.append(r);
m_lit_head = m_lit_tail;
m_cc_head = m_cc_tail;
m_lit_tail = m_eq_proof_literals.size();
m_lit_tail = m_proof_literals.size();
m_cc_tail = m_explain_cc.size();
return new (get_region()) eq_proof_hint(m_lit_head, m_lit_tail, m_cc_head, m_cc_tail);
}
@ -114,7 +114,7 @@ namespace euf {
return ta < tb;
};
for (unsigned i = m_lit_head; i < m_lit_tail; ++i)
args.push_back(s.literal2expr(s.m_eq_proof_literals[i]));
args.push_back(s.literal2expr(s.m_proof_literals[i]));
std::sort(s.m_explain_cc.data() + m_cc_head, s.m_explain_cc.data() + m_cc_tail, compare_ts);
for (unsigned i = m_cc_head; i < m_cc_tail; ++i) {
auto const& [a, b, ts, comm] = s.m_explain_cc[i];
@ -126,6 +126,66 @@ namespace euf {
func_decl* f = m.mk_func_decl(symbol("euf"), sorts.size(), sorts.data(), proof);
return m.mk_app(f, args);
}
smt_proof_hint* solver::mk_smt_hint(symbol const& n, unsigned nl, literal const* lits, unsigned ne, expr_pair const* eqs, unsigned nd, expr_pair const* deqs) {
if (!use_drat())
return nullptr;
push(value_trail(m_lit_tail));
push(restore_size_trail(m_proof_literals));
for (unsigned i = 0; i < nl; ++i)
if (sat::null_literal != lits[i])
m_proof_literals.push_back(lits[i]);
push(value_trail(m_eq_tail));
push(restore_size_trail(m_proof_eqs));
m_proof_eqs.append(ne, eqs);
push(value_trail(m_deq_tail));
push(restore_size_trail(m_proof_deqs));
m_proof_deqs.append(nd, deqs);
m_lit_head = m_lit_tail;
m_eq_head = m_eq_tail;
m_deq_head = m_deq_tail;
m_lit_tail = m_proof_literals.size();
m_eq_tail = m_proof_eqs.size();
m_deq_tail = m_proof_deqs.size();
return new (get_region()) smt_proof_hint(n, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail, m_deq_head, m_deq_tail);
}
smt_proof_hint* solver::mk_smt_hint(symbol const& n, unsigned nl, literal const* lits, unsigned ne, enode_pair const* eqs) {
if (!use_drat())
return nullptr;
m_expr_pairs.reset();
for (unsigned i = 0; i < ne; ++i)
m_expr_pairs.push_back({ eqs[i].first->get_expr(), eqs[i].second->get_expr() });
return mk_smt_hint(n, nl, lits, ne, m_expr_pairs.data());
}
expr* smt_proof_hint::get_hint(euf::solver& s) const {
ast_manager& m = s.get_manager();
sort* proof = m.mk_proof_sort();
ptr_buffer<sort> sorts;
expr_ref_vector args(m);
for (unsigned i = m_lit_head; i < m_lit_tail; ++i)
args.push_back(s.literal2expr(s.m_proof_literals[i]));
for (unsigned i = m_eq_head; i < m_eq_tail; ++i) {
auto const& [a, b] = s.m_proof_eqs[i];
args.push_back(m.mk_eq(a, b));
}
for (unsigned i = m_deq_head; i < m_deq_tail; ++i) {
auto const& [a, b] = s.m_proof_deqs[i];
args.push_back(m.mk_not(m.mk_eq(a, b)));
}
for (auto * arg : args)
sorts.push_back(arg->get_sort());
func_decl* f = m.mk_func_decl(m_name, sorts.size(), sorts.data(), proof);
return m.mk_app(f, args);
}
void solver::set_tmp_bool_var(bool_var b, expr* e) {
m_bool_var2expr.setx(b, e, nullptr);

View file

@ -145,8 +145,10 @@ namespace euf {
else
merge(x, y);
}
else
IF_VERBOSE(0, verbose_stream() << "TODO " << mk_pp(arg, m) << " " << sign << "\n");
else if (m.is_not(arg, arg))
merge(arg, m.mk_false());
else
merge(arg, m.mk_true());
}
else if (m.is_proof(arg)) {
if (!is_app(arg))
@ -274,6 +276,7 @@ namespace euf {
add_plugin(alloc(eq_proof_checker, m));
add_plugin(alloc(res_proof_checker, m, *this));
add_plugin(alloc(q::proof_checker, m));
add_plugin(alloc(smt_proof_checker_plugin, m, symbol("datatype"))); // no-op datatype proof checker
}
proof_checker::~proof_checker() {
@ -317,8 +320,13 @@ namespace euf {
}
void proof_checker::vc(expr* e, expr_ref_vector& clause) {
SASSERT(is_app(e) && m_map.contains(to_app(e)->get_name()));
m_map[to_app(e)->get_name()]->vc(to_app(e), clause);
SASSERT(is_app(e));
app* a = to_app(e);
proof_checker_plugin* p = nullptr;
if (m_map.find(a->get_name(), p))
p->vc(a, clause);
else
IF_VERBOSE(0, verbose_stream() << "there is no proof plugin for " << mk_pp(e, m) << "\n");
}
bool proof_checker::check(expr_ref_vector const& clause1, expr* e, expr_ref_vector & units) {
@ -347,5 +355,13 @@ namespace euf {
return true;
}
expr_ref_vector smt_proof_checker_plugin::clause(app* jst) {
expr_ref_vector result(m);
SASSERT(jst->get_name() == m_rule);
for (expr* arg : *jst)
result.push_back(mk_not(m, arg));
return result;
}
}

View file

@ -49,5 +49,22 @@ namespace euf {
bool check(expr_ref_vector const& clause, expr* e, expr_ref_vector& units);
};
/**
Base class for checking SMT proofs whose justifications are
provided as a set of literals and E-node equalities.
It provides shared implementations for clause and register_plugin.
It overrides check to always fail.
*/
class smt_proof_checker_plugin : public proof_checker_plugin {
ast_manager& m;
symbol m_rule;
public:
smt_proof_checker_plugin(ast_manager& m, symbol const& n): m(m), m_rule(n) {}
~smt_proof_checker_plugin() override {}
bool check(app* jst) override { return false; }
expr_ref_vector clause(app* jst) override;
void register_plugins(proof_checker& pc) override { pc.register_plugin(m_rule, this); }
};
}

View file

@ -305,11 +305,9 @@ namespace euf {
}
void solver::asserted(literal l) {
m_relevancy.asserted(l);
if (!m_relevancy.is_relevant(l))
return;
expr* e = m_bool_var2expr.get(l.var(), nullptr);
TRACE("euf", tout << "asserted: " << l << "@" << s().scope_lvl() << " := " << mk_bounded_pp(e, m) << "\n";);
if (!e)
@ -334,7 +332,7 @@ namespace euf {
m_egraph.merge(r, rb, to_ptr(rl));
SASSERT(m_egraph.inconsistent());
return;
}
}
if (n->merge_tf()) {
euf::enode* nb = sign ? mk_false() : mk_true();
m_egraph.merge(n, nb, c);

View file

@ -68,10 +68,20 @@ namespace euf {
expr* get_hint(euf::solver& s) const override;
};
class smt_proof_hint : public th_proof_hint {
symbol m_name;
unsigned m_lit_head, m_lit_tail, m_eq_head, m_eq_tail, m_deq_head, m_deq_tail;
public:
smt_proof_hint(symbol const& n, unsigned lh, unsigned lt, unsigned ch, unsigned ct, unsigned dh, unsigned dt):
m_name(n), m_lit_head(lh), m_lit_tail(lt), m_eq_head(ch), m_eq_tail(ct), m_deq_head(dh), m_deq_tail(dt) {}
expr* get_hint(euf::solver& s) const override;
};
class solver : public sat::extension, public th_internalizer, public th_decompile, public sat::clause_eh {
typedef top_sort<euf::enode> deps_t;
friend class ackerman;
friend class eq_proof_hint;
friend class smt_proof_hint;
class user_sort;
struct stats {
unsigned m_ackerman;
@ -130,6 +140,7 @@ namespace euf {
constraint* m_eq = nullptr;
constraint* m_lit = nullptr;
// internalization
bool visit(expr* e) override;
bool visited(expr* e) override;
@ -184,8 +195,12 @@ namespace euf {
void log_antecedents(std::ostream& out, literal l, literal_vector const& r);
void log_antecedents(literal l, literal_vector const& r, eq_proof_hint* hint);
void log_justification(literal l, th_explain const& jst);
literal_vector m_eq_proof_literals;
typedef std::pair<expr*, expr*> expr_pair;
literal_vector m_proof_literals;
svector<expr_pair> m_proof_eqs, m_proof_deqs, m_expr_pairs;
unsigned m_lit_head = 0, m_lit_tail = 0, m_cc_head = 0, m_cc_tail = 0;
unsigned m_eq_head = 0, m_eq_tail = 0, m_deq_head = 0, m_deq_tail = 0;
eq_proof_hint* mk_hint(literal lit, literal_vector const& r);
bool m_proof_initialized = false;
@ -365,6 +380,26 @@ namespace euf {
void visit_expr(std::ostream& out, expr* e);
std::ostream& display_expr(std::ostream& out, expr* e);
void on_instantiation(unsigned n, sat::literal const* lits, unsigned k, euf::enode* const* bindings);
smt_proof_hint* mk_smt_hint(symbol const& n, literal_vector const& lits, enode_pair_vector const& eqs) {
return mk_smt_hint(n, lits.size(), lits.data(), eqs.size(), eqs.data());
}
smt_proof_hint* mk_smt_hint(symbol const& n, enode_pair_vector const& eqs) {
return mk_smt_hint(n, 0, nullptr, eqs.size(), eqs.data());
}
smt_proof_hint* mk_smt_hint(symbol const& n, literal_vector const& lits) {
return mk_smt_hint(n, lits.size(), lits.data(), 0, (expr_pair const*) nullptr);
}
smt_proof_hint* mk_smt_hint(symbol const& n, unsigned nl, literal const* lits, unsigned ne, expr_pair const* eqs, unsigned nd = 0, expr_pair const* deqs = nullptr);
smt_proof_hint* mk_smt_hint(symbol const& n, unsigned nl, literal const* lits, unsigned ne, enode_pair const* eqs);
smt_proof_hint* mk_smt_hint(symbol const& n, literal lit, unsigned ne, expr_pair const* eqs) { return mk_smt_hint(n, 1, &lit, ne, eqs); }
smt_proof_hint* mk_smt_hint(symbol const& n, literal lit) { return mk_smt_hint(n, 1, &lit, 0, (expr_pair const*)nullptr); }
smt_proof_hint* mk_smt_hint(symbol const& n, literal l1, literal l2) { literal ls[2] = {l1,l2}; return mk_smt_hint(n, 2, ls, 0, (expr_pair const*)nullptr); }
smt_proof_hint* mk_smt_hint(symbol const& n, literal lit, expr* a, expr* b) { expr_pair e(a, b); return mk_smt_hint(n, 1, &lit, 1, &e); }
smt_proof_hint* mk_smt_hint(symbol const& n, literal lit, enode* a, enode* b) { expr_pair e(a->get_expr(), b->get_expr()); return mk_smt_hint(n, 1, &lit, 1, &e); }
smt_proof_hint* mk_smt_prop_hint(symbol const& n, literal lit, expr* a, expr* b) { expr_pair e(a, b); return mk_smt_hint(n, 1, &lit, 0, nullptr, 1, &e); }
smt_proof_hint* mk_smt_prop_hint(symbol const& n, literal lit, enode* a, enode* b) { return mk_smt_prop_hint(n, lit, a->get_expr(), b->get_expr()); }
smt_proof_hint* mk_smt_hint(symbol const& n, enode* a, enode* b) { expr_pair e(a->get_expr(), b->get_expr()); return mk_smt_hint(n, 0, nullptr, 1, &e); }
scoped_ptr<std::ostream> m_proof_out;
// decompile

View file

@ -238,7 +238,7 @@ namespace euf {
public:
static th_explain* conflict(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, th_proof_hint const* ph = nullptr);
static th_explain* conflict(th_euf_solver& th, sat::literal_vector const& lits) { return conflict(th, lits.size(), lits.data(), 0, nullptr); }
static th_explain* conflict(th_euf_solver& th, sat::literal_vector const& lits, th_proof_hint const* ph = nullptr) { return conflict(th, lits.size(), lits.data(), 0, nullptr, nullptr); }
static th_explain* conflict(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, th_proof_hint const* ph = nullptr);
static th_explain* conflict(th_euf_solver& th, enode_pair_vector const& eqs, th_proof_hint const* ph = nullptr);
static th_explain* conflict(th_euf_solver& th, sat::literal lit, th_proof_hint const* ph = nullptr);

View file

@ -1383,6 +1383,8 @@ namespace smt {
Z3_fallthrough;
case CLS_AUX: {
literal_buffer simp_lits;
if (m_searching)
dump_lemma(num_lits, lits);
if (!simplify_aux_clause_literals(num_lits, lits, simp_lits)) {
if (j && !j->in_region()) {
j->del_eh(m);
@ -1394,6 +1396,7 @@ namespace smt {
if (!simp_lits.empty()) {
j = mk_justification(unit_resolution_justification(*this, j, simp_lits.size(), simp_lits.data()));
}
break;
}
case CLS_TH_LEMMA:
@ -1525,7 +1528,6 @@ namespace smt {
}
void context::dump_lemma(unsigned n, literal const* lits) {
if (m_fparams.m_lemmas2console) {
expr_ref fml(m);
expr_ref_vector fmls(m);