From 4623117af8c6d70f3d84274738013f45ff71a164 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 8 Oct 2022 20:12:57 +0200 Subject: [PATCH] wip - proof hints --- src/ast/ast_pp_util.cpp | 11 +++++ src/ast/ast_pp_util.h | 3 +- src/sat/smt/dt_solver.cpp | 33 +++++++++------ src/sat/smt/euf_proof.cpp | 70 ++++++++++++++++++++++++++++--- src/sat/smt/euf_proof_checker.cpp | 24 +++++++++-- src/sat/smt/euf_proof_checker.h | 17 ++++++++ src/sat/smt/euf_solver.cpp | 4 +- src/sat/smt/euf_solver.h | 37 +++++++++++++++- src/sat/smt/sat_th.h | 2 +- src/smt/smt_internalizer.cpp | 4 +- 10 files changed, 176 insertions(+), 29 deletions(-) diff --git a/src/ast/ast_pp_util.cpp b/src/ast/ast_pp_util.cpp index a74566199..c0608522f 100644 --- a/src/ast/ast_pp_util.cpp +++ b/src/ast/ast_pp_util.cpp @@ -64,6 +64,17 @@ void ast_pp_util::display_decls(std::ostream& out) { m_rec_decls = n; } +void ast_pp_util::reset() { + coll.reset(); + m_removed.reset(); + m_sorts.clear(0u); + m_decls.clear(0u); + m_rec_decls.clear(0u); + m_is_defined.reset(); + m_defined.reset(); + m_defined_lim.reset(); +} + void ast_pp_util::display_skolem_decls(std::ostream& out) { ast_smt_pp pp(m); unsigned n = coll.get_num_decls(); diff --git a/src/ast/ast_pp_util.h b/src/ast/ast_pp_util.h index 9cec62267..9dbfec6af 100644 --- a/src/ast/ast_pp_util.h +++ b/src/ast/ast_pp_util.h @@ -40,8 +40,7 @@ class ast_pp_util { ast_pp_util(ast_manager& m): m(m), m_env(m), m_rec_decls(0), m_decls(0), m_sorts(0), m_defined(m), coll(m) {} - void reset() { coll.reset(); m_removed.reset(); m_sorts.clear(0u); m_decls.clear(0u); m_rec_decls.clear(0u); - m_is_defined.reset(); m_defined.reset(); m_defined_lim.reset(); } + void reset(); void collect(expr* e); diff --git a/src/sat/smt/dt_solver.cpp b/src/sat/smt/dt_solver.cpp index f2ce8803f..a87f8770b 100644 --- a/src/sat/smt/dt_solver.cpp +++ b/src/sat/smt/dt_solver.cpp @@ -103,10 +103,7 @@ namespace dt { */ void solver::assert_eq_axiom(enode* n1, expr* e2, literal antecedent) { expr* e1 = n1->get_expr(); - euf::th_proof_hint* ph = nullptr; - if (ctx.use_drat()) { - // todo - } + euf::th_proof_hint* ph = ctx.mk_smt_prop_hint(name(), antecedent, e1, e2); if (antecedent == sat::null_literal) add_unit(eq_internalize(e1, e2), ph); else if (s().value(antecedent) == l_true) { @@ -166,7 +163,8 @@ namespace dt { literal l = ctx.enode2literal(r); SASSERT(s().value(l) == l_false); clear_mark(); - ctx.set_conflict(euf::th_explain::conflict(*this, ~l, c, r->get_arg(0))); + auto* ph = ctx.mk_smt_hint(name(), ~l, c, r->get_arg(0)); + ctx.set_conflict(euf::th_explain::conflict(*this, ~l, c, r->get_arg(0), ph)); } /** @@ -204,7 +202,9 @@ namespace dt { // update_field is identity if 'n' is not created by a matching constructor. assert_eq_axiom(n, arg1, ~is_con); app_ref n_is_con(m.mk_app(rec, own), m); - add_clause(~is_con, mk_literal(n_is_con)); + literal _n_is_con = mk_literal(n_is_con); + auto* ph = ctx.mk_smt_hint(name(), is_con, ~_n_is_con); + add_clause(~is_con, _n_is_con, ph); } euf::theory_var solver::mk_var(enode* n) { @@ -313,7 +313,8 @@ namespace dt { } } } - ctx.set_conflict(euf::th_explain::conflict(*this, m_lits)); + auto* ph = ctx.mk_smt_hint(name(), m_lits); + ctx.set_conflict(euf::th_explain::conflict(*this, m_lits, ph)); } /** @@ -449,8 +450,10 @@ namespace dt { ++idx; } TRACE("dt", tout << "propagate " << num_unassigned << " eqs: " << eqs.size() << "\n";); - if (num_unassigned == 0) - ctx.set_conflict(euf::th_explain::conflict(*this, m_lits, eqs)); + if (num_unassigned == 0) { + auto* ph = ctx.mk_smt_hint(name(), m_lits, eqs); + ctx.set_conflict(euf::th_explain::conflict(*this, m_lits, eqs, ph)); + } else if (num_unassigned == 1) { // propagate remaining recognizer SASSERT(!m_lits.empty()); @@ -464,7 +467,13 @@ namespace dt { app_ref rec_app(m.mk_app(rec, n->get_expr()), m); consequent = mk_literal(rec_app); } - ctx.propagate(consequent, euf::th_explain::propagate(*this, m_lits, eqs, consequent)); + euf::th_proof_hint* ph = nullptr; + if (ctx.use_drat()) { + m_lits.push_back(~consequent); + ph = ctx.mk_smt_hint(name(), m_lits, eqs); + m_lits.pop_back(); + } + ctx.propagate(consequent, euf::th_explain::propagate(*this, m_lits, eqs, consequent, ph)); } else if (get_config().m_dt_lazy_splits == 0 || (!srt->is_infinite() && get_config().m_dt_lazy_splits == 1)) // there are more than 2 unassigned recognizers... @@ -481,7 +490,7 @@ namespace dt { auto* con2 = d2->m_constructor; TRACE("dt", tout << "merging v" << v1 << " v" << v2 << "\n" << ctx.bpp(var2enode(v1)) << " == " << ctx.bpp(var2enode(v2)) << " " << ctx.bpp(con1) << " " << ctx.bpp(con2) << "\n";); if (con1 && con2 && con1->get_decl() != con2->get_decl()) - ctx.set_conflict(euf::th_explain::conflict(*this, con1, con2)); + ctx.set_conflict(euf::th_explain::conflict(*this, con1, con2, ctx.mk_smt_hint(name(), con1, con2))); else if (con2 && !con1) { ctx.push(set_ptr_trail(d1->m_constructor)); // check whether there is a recognizer in d1 that conflicts with con2; @@ -706,7 +715,7 @@ namespace dt { if (res) { clear_mark(); - ctx.set_conflict(euf::th_explain::conflict(*this, m_used_eqs)); + ctx.set_conflict(euf::th_explain::conflict(*this, m_used_eqs, ctx.mk_smt_hint(name(), m_used_eqs))); TRACE("dt", tout << "occurs check conflict: " << ctx.bpp(n) << "\n";); } return res; diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp index 99bc99d48..73037fc8b 100644 --- a/src/sat/smt/euf_proof.cpp +++ b/src/sat/smt/euf_proof.cpp @@ -79,13 +79,13 @@ namespace euf { return nullptr; push(value_trail(m_lit_tail)); push(value_trail(m_cc_tail)); - push(restore_size_trail(m_eq_proof_literals)); + push(restore_size_trail(m_proof_literals)); if (lit != sat::null_literal) - m_eq_proof_literals.push_back(~lit); - m_eq_proof_literals.append(r); + m_proof_literals.push_back(~lit); + m_proof_literals.append(r); m_lit_head = m_lit_tail; m_cc_head = m_cc_tail; - m_lit_tail = m_eq_proof_literals.size(); + m_lit_tail = m_proof_literals.size(); m_cc_tail = m_explain_cc.size(); return new (get_region()) eq_proof_hint(m_lit_head, m_lit_tail, m_cc_head, m_cc_tail); } @@ -114,7 +114,7 @@ namespace euf { return ta < tb; }; for (unsigned i = m_lit_head; i < m_lit_tail; ++i) - args.push_back(s.literal2expr(s.m_eq_proof_literals[i])); + args.push_back(s.literal2expr(s.m_proof_literals[i])); std::sort(s.m_explain_cc.data() + m_cc_head, s.m_explain_cc.data() + m_cc_tail, compare_ts); for (unsigned i = m_cc_head; i < m_cc_tail; ++i) { auto const& [a, b, ts, comm] = s.m_explain_cc[i]; @@ -126,6 +126,66 @@ namespace euf { func_decl* f = m.mk_func_decl(symbol("euf"), sorts.size(), sorts.data(), proof); return m.mk_app(f, args); } + + smt_proof_hint* solver::mk_smt_hint(symbol const& n, unsigned nl, literal const* lits, unsigned ne, expr_pair const* eqs, unsigned nd, expr_pair const* deqs) { + if (!use_drat()) + return nullptr; + push(value_trail(m_lit_tail)); + push(restore_size_trail(m_proof_literals)); + + for (unsigned i = 0; i < nl; ++i) + if (sat::null_literal != lits[i]) + m_proof_literals.push_back(lits[i]); + + push(value_trail(m_eq_tail)); + push(restore_size_trail(m_proof_eqs)); + m_proof_eqs.append(ne, eqs); + + push(value_trail(m_deq_tail)); + push(restore_size_trail(m_proof_deqs)); + m_proof_deqs.append(nd, deqs); + + m_lit_head = m_lit_tail; + m_eq_head = m_eq_tail; + m_deq_head = m_deq_tail; + m_lit_tail = m_proof_literals.size(); + m_eq_tail = m_proof_eqs.size(); + m_deq_tail = m_proof_deqs.size(); + + return new (get_region()) smt_proof_hint(n, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail, m_deq_head, m_deq_tail); + } + + smt_proof_hint* solver::mk_smt_hint(symbol const& n, unsigned nl, literal const* lits, unsigned ne, enode_pair const* eqs) { + if (!use_drat()) + return nullptr; + m_expr_pairs.reset(); + for (unsigned i = 0; i < ne; ++i) + m_expr_pairs.push_back({ eqs[i].first->get_expr(), eqs[i].second->get_expr() }); + return mk_smt_hint(n, nl, lits, ne, m_expr_pairs.data()); + } + + + expr* smt_proof_hint::get_hint(euf::solver& s) const { + ast_manager& m = s.get_manager(); + sort* proof = m.mk_proof_sort(); + ptr_buffer sorts; + expr_ref_vector args(m); + + for (unsigned i = m_lit_head; i < m_lit_tail; ++i) + args.push_back(s.literal2expr(s.m_proof_literals[i])); + for (unsigned i = m_eq_head; i < m_eq_tail; ++i) { + auto const& [a, b] = s.m_proof_eqs[i]; + args.push_back(m.mk_eq(a, b)); + } + for (unsigned i = m_deq_head; i < m_deq_tail; ++i) { + auto const& [a, b] = s.m_proof_deqs[i]; + args.push_back(m.mk_not(m.mk_eq(a, b))); + } + for (auto * arg : args) + sorts.push_back(arg->get_sort()); + func_decl* f = m.mk_func_decl(m_name, sorts.size(), sorts.data(), proof); + return m.mk_app(f, args); + } void solver::set_tmp_bool_var(bool_var b, expr* e) { m_bool_var2expr.setx(b, e, nullptr); diff --git a/src/sat/smt/euf_proof_checker.cpp b/src/sat/smt/euf_proof_checker.cpp index fc9c80b75..befde07bf 100644 --- a/src/sat/smt/euf_proof_checker.cpp +++ b/src/sat/smt/euf_proof_checker.cpp @@ -145,8 +145,10 @@ namespace euf { else merge(x, y); } - else - IF_VERBOSE(0, verbose_stream() << "TODO " << mk_pp(arg, m) << " " << sign << "\n"); + else if (m.is_not(arg, arg)) + merge(arg, m.mk_false()); + else + merge(arg, m.mk_true()); } else if (m.is_proof(arg)) { if (!is_app(arg)) @@ -274,6 +276,7 @@ namespace euf { add_plugin(alloc(eq_proof_checker, m)); add_plugin(alloc(res_proof_checker, m, *this)); add_plugin(alloc(q::proof_checker, m)); + add_plugin(alloc(smt_proof_checker_plugin, m, symbol("datatype"))); // no-op datatype proof checker } proof_checker::~proof_checker() { @@ -317,8 +320,13 @@ namespace euf { } void proof_checker::vc(expr* e, expr_ref_vector& clause) { - SASSERT(is_app(e) && m_map.contains(to_app(e)->get_name())); - m_map[to_app(e)->get_name()]->vc(to_app(e), clause); + SASSERT(is_app(e)); + app* a = to_app(e); + proof_checker_plugin* p = nullptr; + if (m_map.find(a->get_name(), p)) + p->vc(a, clause); + else + IF_VERBOSE(0, verbose_stream() << "there is no proof plugin for " << mk_pp(e, m) << "\n"); } bool proof_checker::check(expr_ref_vector const& clause1, expr* e, expr_ref_vector & units) { @@ -347,5 +355,13 @@ namespace euf { return true; } + expr_ref_vector smt_proof_checker_plugin::clause(app* jst) { + expr_ref_vector result(m); + SASSERT(jst->get_name() == m_rule); + for (expr* arg : *jst) + result.push_back(mk_not(m, arg)); + return result; + } + } diff --git a/src/sat/smt/euf_proof_checker.h b/src/sat/smt/euf_proof_checker.h index 443d23186..530644488 100644 --- a/src/sat/smt/euf_proof_checker.h +++ b/src/sat/smt/euf_proof_checker.h @@ -49,5 +49,22 @@ namespace euf { bool check(expr_ref_vector const& clause, expr* e, expr_ref_vector& units); }; + /** + Base class for checking SMT proofs whose justifications are + provided as a set of literals and E-node equalities. + It provides shared implementations for clause and register_plugin. + It overrides check to always fail. + */ + class smt_proof_checker_plugin : public proof_checker_plugin { + ast_manager& m; + symbol m_rule; + public: + smt_proof_checker_plugin(ast_manager& m, symbol const& n): m(m), m_rule(n) {} + ~smt_proof_checker_plugin() override {} + bool check(app* jst) override { return false; } + expr_ref_vector clause(app* jst) override; + void register_plugins(proof_checker& pc) override { pc.register_plugin(m_rule, this); } + }; + } diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 2e6b07e51..21290b70a 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -305,11 +305,9 @@ namespace euf { } void solver::asserted(literal l) { - m_relevancy.asserted(l); if (!m_relevancy.is_relevant(l)) return; - expr* e = m_bool_var2expr.get(l.var(), nullptr); TRACE("euf", tout << "asserted: " << l << "@" << s().scope_lvl() << " := " << mk_bounded_pp(e, m) << "\n";); if (!e) @@ -334,7 +332,7 @@ namespace euf { m_egraph.merge(r, rb, to_ptr(rl)); SASSERT(m_egraph.inconsistent()); return; - } + } if (n->merge_tf()) { euf::enode* nb = sign ? mk_false() : mk_true(); m_egraph.merge(n, nb, c); diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index b5d65205c..1a8186eaa 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -68,10 +68,20 @@ namespace euf { expr* get_hint(euf::solver& s) const override; }; + class smt_proof_hint : public th_proof_hint { + symbol m_name; + unsigned m_lit_head, m_lit_tail, m_eq_head, m_eq_tail, m_deq_head, m_deq_tail; + public: + smt_proof_hint(symbol const& n, unsigned lh, unsigned lt, unsigned ch, unsigned ct, unsigned dh, unsigned dt): + m_name(n), m_lit_head(lh), m_lit_tail(lt), m_eq_head(ch), m_eq_tail(ct), m_deq_head(dh), m_deq_tail(dt) {} + expr* get_hint(euf::solver& s) const override; + }; + class solver : public sat::extension, public th_internalizer, public th_decompile, public sat::clause_eh { typedef top_sort deps_t; friend class ackerman; friend class eq_proof_hint; + friend class smt_proof_hint; class user_sort; struct stats { unsigned m_ackerman; @@ -130,6 +140,7 @@ namespace euf { constraint* m_eq = nullptr; constraint* m_lit = nullptr; + // internalization bool visit(expr* e) override; bool visited(expr* e) override; @@ -184,8 +195,12 @@ namespace euf { void log_antecedents(std::ostream& out, literal l, literal_vector const& r); void log_antecedents(literal l, literal_vector const& r, eq_proof_hint* hint); void log_justification(literal l, th_explain const& jst); - literal_vector m_eq_proof_literals; + + typedef std::pair expr_pair; + literal_vector m_proof_literals; + svector m_proof_eqs, m_proof_deqs, m_expr_pairs; unsigned m_lit_head = 0, m_lit_tail = 0, m_cc_head = 0, m_cc_tail = 0; + unsigned m_eq_head = 0, m_eq_tail = 0, m_deq_head = 0, m_deq_tail = 0; eq_proof_hint* mk_hint(literal lit, literal_vector const& r); bool m_proof_initialized = false; @@ -365,6 +380,26 @@ namespace euf { void visit_expr(std::ostream& out, expr* e); std::ostream& display_expr(std::ostream& out, expr* e); void on_instantiation(unsigned n, sat::literal const* lits, unsigned k, euf::enode* const* bindings); + smt_proof_hint* mk_smt_hint(symbol const& n, literal_vector const& lits, enode_pair_vector const& eqs) { + return mk_smt_hint(n, lits.size(), lits.data(), eqs.size(), eqs.data()); + } + smt_proof_hint* mk_smt_hint(symbol const& n, enode_pair_vector const& eqs) { + return mk_smt_hint(n, 0, nullptr, eqs.size(), eqs.data()); + } + smt_proof_hint* mk_smt_hint(symbol const& n, literal_vector const& lits) { + return mk_smt_hint(n, lits.size(), lits.data(), 0, (expr_pair const*) nullptr); + } + smt_proof_hint* mk_smt_hint(symbol const& n, unsigned nl, literal const* lits, unsigned ne, expr_pair const* eqs, unsigned nd = 0, expr_pair const* deqs = nullptr); + smt_proof_hint* mk_smt_hint(symbol const& n, unsigned nl, literal const* lits, unsigned ne, enode_pair const* eqs); + smt_proof_hint* mk_smt_hint(symbol const& n, literal lit, unsigned ne, expr_pair const* eqs) { return mk_smt_hint(n, 1, &lit, ne, eqs); } + smt_proof_hint* mk_smt_hint(symbol const& n, literal lit) { return mk_smt_hint(n, 1, &lit, 0, (expr_pair const*)nullptr); } + smt_proof_hint* mk_smt_hint(symbol const& n, literal l1, literal l2) { literal ls[2] = {l1,l2}; return mk_smt_hint(n, 2, ls, 0, (expr_pair const*)nullptr); } + smt_proof_hint* mk_smt_hint(symbol const& n, literal lit, expr* a, expr* b) { expr_pair e(a, b); return mk_smt_hint(n, 1, &lit, 1, &e); } + smt_proof_hint* mk_smt_hint(symbol const& n, literal lit, enode* a, enode* b) { expr_pair e(a->get_expr(), b->get_expr()); return mk_smt_hint(n, 1, &lit, 1, &e); } + smt_proof_hint* mk_smt_prop_hint(symbol const& n, literal lit, expr* a, expr* b) { expr_pair e(a, b); return mk_smt_hint(n, 1, &lit, 0, nullptr, 1, &e); } + smt_proof_hint* mk_smt_prop_hint(symbol const& n, literal lit, enode* a, enode* b) { return mk_smt_prop_hint(n, lit, a->get_expr(), b->get_expr()); } + smt_proof_hint* mk_smt_hint(symbol const& n, enode* a, enode* b) { expr_pair e(a->get_expr(), b->get_expr()); return mk_smt_hint(n, 0, nullptr, 1, &e); } + scoped_ptr m_proof_out; // decompile diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index cf645d21b..532d04a2e 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -238,7 +238,7 @@ namespace euf { public: static th_explain* conflict(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, th_proof_hint const* ph = nullptr); - static th_explain* conflict(th_euf_solver& th, sat::literal_vector const& lits) { return conflict(th, lits.size(), lits.data(), 0, nullptr); } + static th_explain* conflict(th_euf_solver& th, sat::literal_vector const& lits, th_proof_hint const* ph = nullptr) { return conflict(th, lits.size(), lits.data(), 0, nullptr, nullptr); } static th_explain* conflict(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, th_proof_hint const* ph = nullptr); static th_explain* conflict(th_euf_solver& th, enode_pair_vector const& eqs, th_proof_hint const* ph = nullptr); static th_explain* conflict(th_euf_solver& th, sat::literal lit, th_proof_hint const* ph = nullptr); diff --git a/src/smt/smt_internalizer.cpp b/src/smt/smt_internalizer.cpp index d87a4f971..ac433c602 100644 --- a/src/smt/smt_internalizer.cpp +++ b/src/smt/smt_internalizer.cpp @@ -1383,6 +1383,8 @@ namespace smt { Z3_fallthrough; case CLS_AUX: { literal_buffer simp_lits; + if (m_searching) + dump_lemma(num_lits, lits); if (!simplify_aux_clause_literals(num_lits, lits, simp_lits)) { if (j && !j->in_region()) { j->del_eh(m); @@ -1394,6 +1396,7 @@ namespace smt { if (!simp_lits.empty()) { j = mk_justification(unit_resolution_justification(*this, j, simp_lits.size(), simp_lits.data())); } + break; } case CLS_TH_LEMMA: @@ -1525,7 +1528,6 @@ namespace smt { } void context::dump_lemma(unsigned n, literal const* lits) { - if (m_fparams.m_lemmas2console) { expr_ref fml(m); expr_ref_vector fmls(m);