From 9be8fc7857024c8888695c0ec3814b6750a83a36 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 25 Sep 2022 14:26:20 -0700 Subject: [PATCH] Add EUF (congruence closure) proof hints and checker to the new core EUF proofs are checked modulo union-find. Equalities are added to to union-find if they are assumptions or if they can be derived using congruence closure. The congruence closure assumptions are added as proof-hints. Note that this proof format does not track equality inferences, symmetry and transitivity. Instead they are handled by assuming a union-find based checker. --- src/ast/euf/euf_egraph.cpp | 47 +++++---- src/ast/euf/euf_egraph.h | 19 +++- src/ast/euf/euf_justification.h | 24 ++++- src/sat/smt/bv_solver.cpp | 6 +- src/sat/smt/euf_proof.cpp | 53 +++++++++- src/sat/smt/euf_proof_checker.cpp | 162 +++++++++++++++++++++++++++++- src/sat/smt/euf_solver.cpp | 33 +++--- src/sat/smt/euf_solver.h | 19 +++- src/sat/smt/q_ematch.cpp | 6 +- src/sat/smt/q_ematch.h | 1 + src/sat/smt/user_solver.cpp | 2 +- 11 files changed, 315 insertions(+), 57 deletions(-) diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 3820d2592..30dd1d720 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -126,7 +126,7 @@ namespace euf { if (n2 == n) update_children(n); else - merge(n, n2, justification::congruence(comm)); + merge(n, n2, justification::congruence(comm, m_congruence_timestamp++)); } return n; } @@ -554,7 +554,7 @@ namespace euf { force_push(); for (unsigned i = 0; i < m_to_merge.size() && m.limit().inc() && !inconsistent(); ++i) { auto const& w = m_to_merge[i]; - merge(w.a, w.b, justification::congruence(w.commutativity)); + merge(w.a, w.b, justification::congruence(w.commutativity, m_congruence_timestamp++)); } m_to_merge.reset(); return @@ -707,25 +707,28 @@ namespace euf { } template - void egraph::explain(ptr_vector& justifications) { + void egraph::explain(ptr_vector& justifications, cc_justification* cc) { SASSERT(m_inconsistent); push_todo(m_n1); push_todo(m_n2); - explain_eq(justifications, m_n1, m_n2, m_justification); - explain_todo(justifications); + explain_eq(justifications, cc, m_n1, m_n2, m_justification); + explain_todo(justifications, cc); } template - void egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b, justification const& j) { + void egraph::explain_eq(ptr_vector& justifications, cc_justification* cc, enode* a, enode* b, justification const& j) { + TRACE("euf_verbose", tout << "explain-eq: " << bpp(a) << " == " << bpp(b) << " jst: " << j << "\n";); if (j.is_external()) justifications.push_back(j.ext()); else if (j.is_congruence()) push_congruence(a, b, j.is_commutative()); + if (cc && j.is_congruence()) + cc->push_back(std::tuple(a, b, j.timestamp(), j.is_commutative())); } template - void egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b) { + void egraph::explain_eq(ptr_vector& justifications, cc_justification* cc, enode* a, enode* b) { SASSERT(a->get_root() == b->get_root()); enode* lca = find_lca(a, b); @@ -734,27 +737,27 @@ namespace euf { push_to_lca(b, lca); if (m_used_eq) m_used_eq(a->get_expr(), b->get_expr(), lca->get_expr()); - explain_todo(justifications); + explain_todo(justifications, cc); } template - unsigned egraph::explain_diseq(ptr_vector& justifications, enode* a, enode* b) { + unsigned egraph::explain_diseq(ptr_vector& justifications, cc_justification* cc, enode* a, enode* b) { enode* ra = a->get_root(), * rb = b->get_root(); SASSERT(ra != rb); if (ra->interpreted() && rb->interpreted()) { - explain_eq(justifications, a, ra); - explain_eq(justifications, b, rb); + explain_eq(justifications, cc, a, ra); + explain_eq(justifications, cc, b, rb); return sat::null_bool_var; } enode* r = tmp_eq(ra, rb); SASSERT(r && r->get_root()->value() == l_false); - explain_eq(justifications, r, r->get_root()); + explain_eq(justifications, cc, r, r->get_root()); return r->get_root()->bool_var(); } template - void egraph::explain_todo(ptr_vector& justifications) { + void egraph::explain_todo(ptr_vector& justifications, cc_justification* cc) { for (unsigned i = 0; i < m_todo.size(); ++i) { enode* n = m_todo[i]; if (n->is_marked1()) @@ -762,7 +765,7 @@ namespace euf { if (n->m_target) { n->mark1(); CTRACE("euf_verbose", m_display_justification, n->m_justification.display(tout << n->get_expr_id() << " = " << n->m_target->get_expr_id() << " ", m_display_justification) << "\n";); - explain_eq(justifications, n, n->m_target, n->m_justification); + explain_eq(justifications, cc, n, n->m_target, n->m_justification); } else if (!n->is_marked1() && n->value() != l_undef) { n->mark1(); @@ -890,15 +893,15 @@ namespace euf { } } -template void euf::egraph::explain(ptr_vector& justifications); -template void euf::egraph::explain_todo(ptr_vector& justifications); -template void euf::egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b); -template unsigned euf::egraph::explain_diseq(ptr_vector& justifications, enode* a, enode* b); +template void euf::egraph::explain(ptr_vector& justifications, cc_justification*); +template void euf::egraph::explain_todo(ptr_vector& justifications, cc_justification*); +template void euf::egraph::explain_eq(ptr_vector& justifications, cc_justification*, enode* a, enode* b); +template unsigned euf::egraph::explain_diseq(ptr_vector& justifications, cc_justification*, enode* a, enode* b); -template void euf::egraph::explain(ptr_vector& justifications); -template void euf::egraph::explain_todo(ptr_vector& justifications); -template void euf::egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b); -template unsigned euf::egraph::explain_diseq(ptr_vector& justifications, enode* a, enode* b); +template void euf::egraph::explain(ptr_vector& justifications, cc_justification*); +template void euf::egraph::explain_todo(ptr_vector& justifications, cc_justification*); +template void euf::egraph::explain_eq(ptr_vector& justifications, cc_justification*, enode* a, enode* b); +template unsigned euf::egraph::explain_diseq(ptr_vector& justifications, cc_justification*, enode* a, enode* b); diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 55f94f0f2..c0d7f03d8 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -72,6 +72,13 @@ namespace euf { th_eq(theory_id id, theory_var v1, theory_var v2, expr* eq) : m_id(id), m_v1(v1), m_v2(v2), m_eq(eq), m_root(nullptr) {} }; + + // cc_justification contains the uses of congruence closure + // It is the only information collected from justifications in order to + // reconstruct EUF proofs. Transitivity, Symmetry of equality are not + // tracked. + typedef std::tuple cc_justification_record; + typedef svector cc_justification; class egraph { @@ -186,6 +193,8 @@ namespace euf { stats m_stats; bool m_uses_congruence = false; bool m_default_relevant = true; + uint64_t m_congruence_timestamp = 0; + std::vector> m_on_merge; std::function m_on_make; std::function m_used_eq; @@ -226,10 +235,10 @@ namespace euf { void erase_from_table(enode* p); template - void explain_eq(ptr_vector& justifications, enode* a, enode* b, justification const& j); + void explain_eq(ptr_vector& justifications, cc_justification* cc, enode* a, enode* b, justification const& j); template - void explain_todo(ptr_vector& justifications); + void explain_todo(ptr_vector& justifications, cc_justification* cc); std::ostream& display(std::ostream& out, unsigned max_args, enode* n) const; @@ -306,11 +315,11 @@ namespace euf { void end_explain(); bool uses_congruence() const { return m_uses_congruence; } template - void explain(ptr_vector& justifications); + void explain(ptr_vector& justifications, cc_justification* cc); template - void explain_eq(ptr_vector& justifications, enode* a, enode* b); + void explain_eq(ptr_vector& justifications, cc_justification* cc, enode* a, enode* b); template - unsigned explain_diseq(ptr_vector& justifications, enode* a, enode* b); + unsigned explain_diseq(ptr_vector& justifications, cc_justification* cc, enode* a, enode* b); enode_vector const& nodes() const { return m_nodes; } ast_manager& get_manager() { return m; } diff --git a/src/ast/euf/euf_justification.h b/src/ast/euf/euf_justification.h index 2241ff0b6..57b532e3b 100644 --- a/src/ast/euf/euf_justification.h +++ b/src/ast/euf/euf_justification.h @@ -13,6 +13,11 @@ Author: Nikolaj Bjorner (nbjorner) 2020-08-23 +Notes: + +- congruence closure justifications are given a timestamp so it is easy to sort them. + See the longer descriptoin in euf_proof_checker.cpp + --*/ #pragma once @@ -27,11 +32,15 @@ namespace euf { }; kind_t m_kind; bool m_comm; - void* m_external; - justification(bool comm): + union { + void* m_external; + uint64_t m_timestamp; + }; + + justification(bool comm, uint64_t ts): m_kind(kind_t::congruence_t), m_comm(comm), - m_external(nullptr) + m_timestamp(ts) {} justification(void* ext): @@ -48,12 +57,13 @@ namespace euf { {} static justification axiom() { return justification(); } - static justification congruence(bool c) { return justification(c); } + static justification congruence(bool c, uint64_t ts) { return justification(c, ts); } static justification external(void* ext) { return justification(ext); } bool is_external() const { return m_kind == kind_t::external_t; } bool is_congruence() const { return m_kind == kind_t::congruence_t; } bool is_commutative() const { return m_comm; } + uint64_t timestamp() const { SASSERT(is_congruence()); return m_timestamp; } template T* ext() const { SASSERT(is_external()); return static_cast(m_external); } @@ -64,7 +74,7 @@ namespace euf { case kind_t::axiom_t: return axiom(); case kind_t::congruence_t: - return congruence(m_comm); + return congruence(m_comm, m_timestamp); default: UNREACHABLE(); return axiom(); @@ -90,4 +100,8 @@ namespace euf { return out; } }; + + inline std::ostream& operator<<(std::ostream& out, justification const& j) { + return j.display(out, nullptr); + } } diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index 4a14067d1..8ef2c5cd5 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -313,7 +313,7 @@ namespace bv { case bv_justification::kind_t::eq2bit: SASSERT(s().value(c.m_antecedent) == l_true); r.push_back(c.m_antecedent); - ctx.add_antecedent(var2enode(c.m_v1), var2enode(c.m_v2)); + ctx.add_antecedent(probing, var2enode(c.m_v1), var2enode(c.m_v2)); break; case bv_justification::kind_t::ne2bit: { r.push_back(c.m_antecedent); @@ -381,8 +381,8 @@ namespace bv { break; } case bv_justification::kind_t::bv2int: { - ctx.add_antecedent(c.a, c.b); - ctx.add_antecedent(c.a, c.c); + ctx.add_antecedent(probing, c.a, c.b); + ctx.add_antecedent(probing, c.a, c.c); break; } } diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp index bcbc9439b..bf8aad44b 100644 --- a/src/sat/smt/euf_proof.cpp +++ b/src/sat/smt/euf_proof.cpp @@ -46,7 +46,7 @@ namespace euf { * so it isn't necessarily an axiom over EUF, * We will here leave it to the EUF checker to perform resolution steps. */ - void solver::log_antecedents(literal l, literal_vector const& r) { + void solver::log_antecedents(literal l, literal_vector const& r, eq_proof_hint* hint) { TRACE("euf", log_antecedents(tout, l, r);); if (!use_drat()) return; @@ -55,7 +55,7 @@ namespace euf { lits.push_back(~lit); if (l != sat::null_literal) lits.push_back(l); - get_drat().add(lits, sat::status::th(true, get_id())); + get_drat().add(lits, sat::status::th(true, get_id(), hint)); } void solver::log_antecedents(std::ostream& out, literal l, literal_vector const& r) { @@ -74,6 +74,55 @@ namespace euf { } } + eq_proof_hint* solver::mk_hint(literal lit, literal_vector const& r) { + if (!use_drat()) + return nullptr; + push(value_trail(m_lit_tail)); + push(value_trail(m_cc_tail)); + push(restore_size_trail(m_eq_proof_literals)); + if (lit != sat::null_literal) + m_eq_proof_literals.push_back(~lit); + m_eq_proof_literals.append(r); + m_lit_head = m_lit_tail; + m_cc_head = m_cc_tail; + m_lit_tail = m_eq_proof_literals.size(); + m_cc_tail = m_explain_cc.size(); + return new (get_region()) eq_proof_hint(m_lit_head, m_lit_tail, m_cc_head, m_cc_tail); + } + + expr* eq_proof_hint::get_hint(euf::solver& s) const { + ast_manager& m = s.get_manager(); + func_decl_ref cc(m); + sort* proof = m.mk_proof_sort(); + ptr_buffer sorts; + expr_ref_vector args(m); + if (m_cc_head < m_cc_tail) { + sort* sorts[2] = { m.mk_bool_sort(), m.mk_bool_sort() }; + cc = m.mk_func_decl(symbol("cc"), 2, sorts, proof); + } + auto cc_proof = [&](bool comm, expr* eq) { + return m.mk_app(cc, m.mk_bool_val(comm), eq); + }; + auto compare_ts = [](cc_justification_record const& a, + cc_justification_record const& b) { + auto const& [_1, _2, ta, _3] = a; + auto const& [_4, _5, tb, _6] = b; + return ta < tb; + }; + for (unsigned i = m_lit_head; i < m_lit_tail; ++i) + args.push_back(s.literal2expr(s.m_eq_proof_literals[i])); + std::sort(s.m_explain_cc.data() + m_cc_head, s.m_explain_cc.data() + m_cc_tail, compare_ts); + for (unsigned i = m_cc_head; i < m_cc_tail; ++i) { + auto const& [a, b, ts, comm] = s.m_explain_cc[i]; + args.push_back(cc_proof(comm, m.mk_eq(a->get_expr(), b->get_expr()))); + } + for (auto * arg : args) + sorts.push_back(arg->get_sort()); + + func_decl* f = m.mk_func_decl(symbol("euf"), sorts.size(), sorts.data(), proof); + return m.mk_app(f, args); + } + void solver::set_tmp_bool_var(bool_var b, expr* e) { m_bool_var2expr.setx(b, e, nullptr); } diff --git a/src/sat/smt/euf_proof_checker.cpp b/src/sat/smt/euf_proof_checker.cpp index 41f627914..3ddafd00b 100644 --- a/src/sat/smt/euf_proof_checker.cpp +++ b/src/sat/smt/euf_proof_checker.cpp @@ -15,18 +15,178 @@ Author: --*/ +#include "util/union_find.h" #include "ast/ast_pp.h" +#include "ast/ast_ll_pp.h" #include "sat/smt/euf_proof_checker.h" #include "sat/smt/arith_proof_checker.h" namespace euf { + /** + * The equality proof checker checks congruence proofs. + * A congruence claim comprises + * - a set of equality and diseqality literals that are + * unsatisfiable modulo equality reasoning. + * - a list of congruence claims that are used for equality reasoning. + * Congruence claims are expressions of the form + * (cc uses_commutativity (= a b)) + * where uses_commutativity is true or false + * If uses commutativity is true, then a, b are (the same) binary functions + * a := f(x,y), b := f(z,u), such that x = u and y = z are consequences from + * the current equalities. + * If uses_commtativity is false, then a, b are the same n-ary expressions + * each argument position i, a_i == b_i follows from current equalities. + * If the arguments are equal according to the current equalities, then the equality + * a = b is added as a consequence. + * + * The congruence claims can be justified from the equalities in the literals. + * To be more precise, the congruence claims are justified in the they appear. + * The congruence closure algorithm (egraph) uses timestamps to record a timestamp + * when a congruence was inferred. Proof generation ensures that the congruence premises + * are sorted by the timestamp such that a congruence that depends on an earlier congruence + * appears later in the sorted order. + * + * Equality justifications are checked using union-find. + * We use union-find instead of fine-grained equality proofs (symmetry and transitivity + * of equality) assuming that it is both cheap and simple to establish a certified + * union-find checker. + */ + + class eq_proof_checker : public proof_checker_plugin { + ast_manager& m; + basic_union_find m_uf; + svector> m_expr2id; + svector> m_diseqs; + unsigned m_ts = 0; + + void merge(expr* x, expr* y) { + m_uf.merge(expr2id(x), expr2id(y)); + IF_VERBOSE(10, verbose_stream() << "merge " << mk_bounded_pp(x, m) << " == " << mk_bounded_pp(y, m) << "\n"); + } + + bool are_equal(expr* x, expr* y) { + return m_uf.find(expr2id(x)) == m_uf.find(expr2id(y)); + } + + bool congruence(bool comm, app* x, app* y) { + if (x->get_decl() != y->get_decl()) + return false; + if (x->get_num_args() != y->get_num_args()) + return false; + if (comm) { + if (x->get_num_args() != 2) + return false; + if (!are_equal(x->get_arg(0), y->get_arg(1))) + return false; + if (!are_equal(y->get_arg(0), x->get_arg(1))) + return false; + merge(x, y); + } + else { + for (unsigned i = 0; i < x->get_num_args(); ++i) + if (!are_equal(x->get_arg(i), y->get_arg(i))) + return false; + merge(x, y); + } + IF_VERBOSE(10, verbose_stream() << "cc " << mk_bounded_pp(x, m) << " == " << mk_bounded_pp(y, m) << "\n"); + return true; + } + + void reset() { + ++m_ts; + if (m_ts == 0) { + m_expr2id.reset(); + ++m_ts; + } + m_uf.reset(); + m_diseqs.reset(); + } + + unsigned expr2id(expr* e) { + auto [ts, id] = m_expr2id.get(e->get_id(), {0,0}); + if (ts != m_ts) { + id = m_uf.mk_var(); + m_expr2id.setx(e->get_id(), {m_ts, id}, {0,0}); + } + return id; + } + + + public: + eq_proof_checker(ast_manager& m): m(m) {} + + ~eq_proof_checker() override {} + + bool check(expr_ref_vector const& clause, app* jst, expr_ref_vector& units) override { + IF_VERBOSE(10, verbose_stream() << clause << "\n" << mk_pp(jst, m) << "\n"); + reset(); + expr_mark pos, neg; + expr* x, *y; + for (expr* e : clause) + if (m.is_not(e, e)) + neg.mark(e, true); + else + pos.mark(e, true); + + for (expr* arg : *jst) { + if (m.is_bool(arg)) { + bool sign = m.is_not(arg, arg); + if (sign && !pos.is_marked(arg)) + units.push_back(m.mk_not(arg)); + else if (!sign & !neg.is_marked(arg)) + units.push_back(arg); + if (m.is_eq(arg, x, y)) { + if (sign) + m_diseqs.push_back({x, y}); + else + merge(x, y); + } + else + IF_VERBOSE(0, verbose_stream() << "TODO " << mk_pp(arg, m) << " " << sign << "\n"); + } + else if (m.is_proof(arg)) { + if (!is_app(arg)) + return false; + app* a = to_app(arg); + if (a->get_num_args() != 2) + return false; + if (a->get_name() != symbol("cc")) + return false; + if (!m.is_eq(a->get_arg(1), x, y)) + return false; + if (!is_app(x) || !is_app(y)) + return false; + if (!congruence(m.is_true(a->get_arg(0)), to_app(x), to_app(y))) { + IF_VERBOSE(0, verbose_stream() << "not congruent " << mk_pp(a, m) << "\n"); + return false; + } + } + else { + IF_VERBOSE(0, verbose_stream() << "unrecognized argument " << mk_pp(arg, m) << "\n"); + return false; + } + } + for (auto const& [a, b] : m_diseqs) + if (are_equal(a, b)) + return true; + return false; + } + + void register_plugins(proof_checker& pc) override { + pc.register_plugin(symbol("euf"), this); + } + + }; + proof_checker::proof_checker(ast_manager& m): m(m) { arith::proof_checker* apc = alloc(arith::proof_checker, m); + eq_proof_checker* epc = alloc(eq_proof_checker, m); m_plugins.push_back(apc); + m_plugins.push_back(epc); apc->register_plugins(*this); - (void)m; + epc->register_plugins(*this); } proof_checker::~proof_checker() {} diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index c30fdbf89..2e6b07e51 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -202,6 +202,8 @@ namespace euf { void solver::get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) { m_egraph.begin_explain(); m_explain.reset(); + if (use_drat() && !probing) + push(restore_size_trail(m_explain_cc, m_explain_cc.size())); auto* ext = sat::constraint_base::to_extension(idx); if (ext == this) get_antecedents(l, constraint::from_idx(idx), r, probing); @@ -220,33 +222,35 @@ namespace euf { } } m_egraph.end_explain(); + eq_proof_hint* hint = (use_drat() && !probing) ? mk_hint(l, r) : nullptr; unsigned j = 0; for (sat::literal lit : r) if (s().lvl(lit) > 0) r[j++] = lit; r.shrink(j); - TRACE("euf", tout << "explain " << l << " <- " << r << " " << probing << "\n";); + CTRACE("euf", probing, tout << "explain " << l << " <- " << r << "\n"); DEBUG_CODE(for (auto lit : r) SASSERT(s().value(lit) == l_true);); if (!probing) - log_antecedents(l, r); + log_antecedents(l, r, hint); } void solver::get_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing) { for (auto lit : euf::th_explain::lits(jst)) r.push_back(lit); for (auto eq : euf::th_explain::eqs(jst)) - add_antecedent(eq.first, eq.second); - + add_antecedent(probing, eq.first, eq.second); + if (!probing && use_drat()) log_justification(l, jst); } - void solver::add_antecedent(enode* a, enode* b) { - m_egraph.explain_eq(m_explain, a, b); + void solver::add_antecedent(bool probing, enode* a, enode* b) { + cc_justification* cc = (!probing && use_drat()) ? &m_explain_cc : nullptr; + m_egraph.explain_eq(m_explain, cc, a, b); } - void solver::add_diseq_antecedent(ptr_vector& ex, enode* a, enode* b) { - sat::bool_var v = get_egraph().explain_diseq(ex, a, b); + void solver::add_diseq_antecedent(ptr_vector& ex, cc_justification* cc, enode* a, enode* b) { + sat::bool_var v = get_egraph().explain_diseq(ex, cc, a, b); SASSERT(v == sat::null_bool_var || s().value(v) == l_false); if (v != sat::null_bool_var) ex.push_back(to_ptr(sat::literal(v, true))); @@ -262,14 +266,17 @@ namespace euf { void solver::get_antecedents(literal l, constraint& j, literal_vector& r, bool probing) { expr* e = nullptr; euf::enode* n = nullptr; + cc_justification* cc = nullptr; if (!probing && !m_drating) init_ackerman(); - + if (!probing && use_drat()) + cc = &m_explain_cc; + switch (j.kind()) { case constraint::kind_t::conflict: SASSERT(m_egraph.inconsistent()); - m_egraph.explain(m_explain); + m_egraph.explain(m_explain, cc); break; case constraint::kind_t::eq: e = m_bool_var2expr[l.var()]; @@ -277,14 +284,14 @@ namespace euf { SASSERT(n); SASSERT(n->is_equality()); SASSERT(!l.sign()); - m_egraph.explain_eq(m_explain, n->get_arg(0), n->get_arg(1)); + m_egraph.explain_eq(m_explain, cc, n->get_arg(0), n->get_arg(1)); break; case constraint::kind_t::lit: e = m_bool_var2expr[l.var()]; n = m_egraph.find(e); SASSERT(n); SASSERT(m.is_bool(n->get_expr())); - m_egraph.explain_eq(m_explain, n, (l.sign() ? mk_false() : mk_true())); + m_egraph.explain_eq(m_explain, cc, n, (l.sign() ? mk_false() : mk_true())); break; default: IF_VERBOSE(0, verbose_stream() << (unsigned)j.kind() << "\n"); @@ -423,7 +430,7 @@ namespace euf { m_egraph.begin_explain(); m_explain.reset(); - m_egraph.explain_eq(m_explain, e.child(), e.root()); + m_egraph.explain_eq(m_explain, nullptr, e.child(), e.root()); m_egraph.end_explain(); if (m_egraph.uses_congruence()) return false; diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index d19e0acac..f4867d841 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -60,9 +60,18 @@ namespace euf { std::ostream& display(std::ostream& out) const; }; + class eq_proof_hint : public th_proof_hint { + unsigned m_lit_head, m_lit_tail, m_cc_head, m_cc_tail; + public: + eq_proof_hint(unsigned lh, unsigned lt, unsigned ch, unsigned ct): + m_lit_head(lh), m_lit_tail(lt), m_cc_head(ch), m_cc_tail(ct) {} + expr* get_hint(euf::solver& s) const override; + }; + class solver : public sat::extension, public th_internalizer, public th_decompile, public sat::clause_eh { typedef top_sort deps_t; friend class ackerman; + friend class eq_proof_hint; class user_sort; struct stats { unsigned m_ackerman; @@ -110,6 +119,7 @@ namespace euf { ptr_vector m_bool_var2expr; ptr_vector m_explain; + euf::cc_justification m_explain_cc; unsigned m_num_scopes = 0; unsigned_vector m_var_trail; svector m_scopes; @@ -172,8 +182,11 @@ namespace euf { // proofs void log_antecedents(std::ostream& out, literal l, literal_vector const& r); - void log_antecedents(literal l, literal_vector const& r); + void log_antecedents(literal l, literal_vector const& r, eq_proof_hint* hint); void log_justification(literal l, th_explain const& jst); + literal_vector m_eq_proof_literals; + unsigned m_lit_head = 0, m_lit_tail = 0, m_cc_head = 0, m_cc_tail = 0; + eq_proof_hint* mk_hint(literal lit, literal_vector const& r); bool m_proof_initialized = false; void init_proof(); @@ -307,8 +320,8 @@ namespace euf { void get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) override; void get_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing); - void add_antecedent(enode* a, enode* b); - void add_diseq_antecedent(ptr_vector& ex, enode* a, enode* b); + void add_antecedent(bool probing, enode* a, enode* b); + void add_diseq_antecedent(ptr_vector& ex, cc_justification* cc, enode* a, enode* b); void add_explain(size_t* p) { m_explain.push_back(p); } void reset_explain() { m_explain.reset(); } void set_eliminated(bool_var v) override; diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp index 490bce46e..a4093cdb7 100644 --- a/src/sat/smt/q_ematch.cpp +++ b/src/sat/smt/q_ematch.cpp @@ -113,14 +113,16 @@ namespace q { if (idx != UINT_MAX) lit = c[idx]; m_explain.reset(); + m_explain_cc.reset(); ctx.get_egraph().begin_explain(); ctx.reset_explain(); + euf::cc_justification* cc = ctx.use_drat() ? &m_explain_cc : nullptr; for (auto const& [a, b] : m_evidence) { SASSERT(a->get_root() == b->get_root() || ctx.get_egraph().are_diseq(a, b)); if (a->get_root() == b->get_root()) - ctx.get_egraph().explain_eq(m_explain, a, b); + ctx.get_egraph().explain_eq(m_explain, cc, a, b); else - ctx.add_diseq_antecedent(m_explain, a, b); + ctx.add_diseq_antecedent(m_explain, cc, a, b); } ctx.get_egraph().end_explain(); diff --git a/src/sat/smt/q_ematch.h b/src/sat/smt/q_ematch.h index ef933a3a8..834c8740d 100644 --- a/src/sat/smt/q_ematch.h +++ b/src/sat/smt/q_ematch.h @@ -96,6 +96,7 @@ namespace q { binding* alloc_binding(clause& c, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top); ptr_vector m_explain; + euf::cc_justification m_explain_cc; sat::ext_justification_idx mk_justification(unsigned idx, clause& c, euf::enode* const* b); void ensure_ground_enodes(expr* e); diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index 494e69e55..5c98a6fac 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -207,7 +207,7 @@ namespace user_solver { for (unsigned id : prop.m_ids) r.append(m_id2justification[id]); for (auto const& p : prop.m_eqs) - ctx.add_antecedent(expr2enode(p.first), expr2enode(p.second)); + ctx.add_antecedent(probing, expr2enode(p.first), expr2enode(p.second)); } /*