From 48deb4d3e02bceeb1f5bb11b46e8b1798811dae7 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 23 Jul 2023 14:31:44 -0700 Subject: [PATCH] fix proof generation for euf-solver Signed-off-by: Nikolaj Bjorner --- src/sat/smt/arith_diagnostics.cpp | 4 +- src/sat/smt/euf_proof.cpp | 84 +++++++++++++++++++++++++++---- src/sat/smt/euf_solver.cpp | 56 ++++++++------------- src/sat/smt/euf_solver.h | 4 +- src/sat/smt/q_ematch.cpp | 2 +- src/sat/smt/sat_th.cpp | 7 ++- 6 files changed, 108 insertions(+), 49 deletions(-) diff --git a/src/sat/smt/arith_diagnostics.cpp b/src/sat/smt/arith_diagnostics.cpp index e621ee9d7..9fafbcc80 100644 --- a/src/sat/smt/arith_diagnostics.cpp +++ b/src/sat/smt/arith_diagnostics.cpp @@ -213,7 +213,9 @@ namespace arith { args.push_back(s.literal2expr(lit)); } for (unsigned i = m_eq_head; i < m_eq_tail; ++i) { - auto const& [x, y, is_eq] = a.m_arith_hint.eq(i); + auto [x, y, is_eq] = a.m_arith_hint.eq(i); + if (x->get_id() > y->get_id()) + std::swap(x, y); expr_ref eq(m.mk_eq(x->get_expr(), y->get_expr()), m); if (!is_eq) eq = m.mk_not(eq); args.push_back(arith.mk_int(1)); diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp index 62a9a11e8..39c9879a6 100644 --- a/src/sat/smt/euf_proof.cpp +++ b/src/sat/smt/euf_proof.cpp @@ -44,17 +44,73 @@ namespace euf { } /** - * \brief logs antecedents to a proof trail. - * - * NB with theories, this is not a pure EUF justification, - * It is true modulo EUF and previously logged certificates - * so it isn't necessarily an axiom over EUF, - * We will here leave it to the EUF checker to perform resolution steps. - */ + * Log justifications. + * is_euf - true if l is justified by congruence closure. In this case create a congruence closure proof. + * explain_size - the relevant portion of premises for the congruence closure proof. + * The EUF solver manages equality propagation. Each propagated equality is justified by a congruence closure. + */ + void solver::log_justifications(literal l, unsigned explain_size, bool is_euf) { + + unsigned nv = s().num_vars(); + expr_ref_vector eqs(m); + + auto add_hint_literals = [&](unsigned sz) { + eqs.reset(); + m_hint_lits.reset(); + nv = s().num_vars(); + for (unsigned i = 0; i < sz; ++i) { + size_t* e = m_explain[i]; + if (is_literal(e)) + m_hint_lits.push_back(get_literal(e)); + else { + auto [x, y] = th_explain::from_index(get_justification(e)).eq_consequent(); + eqs.push_back(m.mk_eq(x->get_expr(), y->get_expr())); + set_tmp_bool_var(nv, eqs.back()); + m_hint_lits.push_back(literal(nv, false)); + ++nv; + } + } + }; + + auto clear_hint_literals = [&]() { + for (unsigned v = s().num_vars(); v < nv; ++v) + set_tmp_bool_var(v, nullptr); + }; + + // log EUF justifications + if (is_euf) { + add_hint_literals(explain_size); + auto* hint = mk_hint(m_euf, l); + log_antecedents(l, m_hint_lits, hint); + clear_hint_literals(); + } + + // explain equalities + for (auto const& [a, b] : m_hint_eqs) { + m_egraph.begin_explain(); + m_explain.reset(); + m_egraph.explain_eq(m_explain, &m_explain_cc, a, b); + m_egraph.end_explain(); + // Detect shortcut if equality is explained directly by a theory + if (m_explain.size() == 1 && !is_literal(m_explain[0])) { + auto const& [x, y] = th_explain::from_index(get_justification(m_explain[0])).eq_consequent(); + if (x == a && y == b) + continue; + } + add_hint_literals(m_explain.size()); + eqs.push_back(m.mk_eq(a->get_expr(), b->get_expr())); + set_tmp_bool_var(nv, eqs.back()); + sat::literal eql = literal(nv, false); + ++nv; + auto* hint = mk_hint(m_euf, eql); + log_antecedents(eql, m_hint_lits, hint); + clear_hint_literals(); + } + } + void solver::log_antecedents(literal l, literal_vector const& r, th_proof_hint* hint) { + SASSERT(hint && use_drat()); TRACE("euf", log_antecedents(tout, l, r); tout << mk_pp(hint->get_hint(*this), m) << "\n"); - if (!use_drat()) - return; literal_vector lits; for (literal lit : r) lits.push_back(~lit); @@ -63,6 +119,15 @@ namespace euf { get_drat().add(lits, sat::status::th(true, get_id(), hint)); } + void solver::log_rup(literal l, literal_vector const& r) { + literal_vector lits; + for (literal lit : r) + lits.push_back(~lit); + if (l != sat::null_literal) + lits.push_back(l); + get_drat().add(lits, sat::status::redundant()); + } + void solver::log_antecedents(std::ostream& out, literal l, literal_vector const& r) { for (sat::literal l : r) { expr* n = m_bool_var2expr[l.var()]; @@ -159,6 +224,7 @@ namespace euf { }; for (unsigned i = m_lit_head; i < m_lit_tail; ++i) args.push_back(s.literal2expr(s.m_proof_literals[i])); + std::sort(s.m_explain_cc.data() + m_cc_head, s.m_explain_cc.data() + m_cc_tail, compare_ts); for (unsigned i = m_cc_head; i < m_cc_tail; ++i) { auto const& [a, b, ts, comm] = s.m_explain_cc[i]; diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 94ff9db38..8f0f29445 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -228,69 +228,52 @@ namespace euf { void solver::get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) { bool create_hint = use_drat() && !probing; - m_egraph.begin_explain(); - m_explain.reset(); if (create_hint) { push(restore_vector(m_explain_cc)); m_hint_eqs.reset(); - m_hint_lits.reset(); } auto* ext = sat::constraint_base::to_extension(idx); - th_proof_hint* hint = nullptr; + bool is_euf = ext == this; + bool multiple_theories = false; - if (ext == this) + m_egraph.begin_explain(); + m_explain.reset(); + if (is_euf) get_euf_antecedents(l, constraint::from_idx(idx), r, probing); else ext->get_antecedents(l, idx, r, probing); - if (create_hint && ext != this) - ext->get_antecedents(l, idx, m_hint_lits, probing); + unsigned ez = m_explain.size(); for (unsigned qhead = 0; qhead < m_explain.size(); ++qhead) { size_t* e = m_explain[qhead]; if (is_literal(e)) r.push_back(get_literal(e)); else { + multiple_theories = true; size_t idx = get_justification(e); auto* ext = sat::constraint_base::to_extension(idx); SASSERT(ext != this); sat::literal lit = sat::null_literal; ext->get_antecedents(lit, idx, r, probing); } - if (create_hint) { - if (is_literal(e)) - m_hint_lits.push_back(get_literal(e)); - else { - auto const& eq = th_explain::from_index(get_justification(e)).eq_consequent(); - TRACE("euf", tout << "consequent " << bpp(eq.first) << " " << bpp(eq.second) << "\n"; ); - m_hint_eqs.push_back(eq); - } - } } m_egraph.end_explain(); CTRACE("euf", probing, tout << "explain " << l << " <- " << r << "\n"); unsigned j = 0; - for (sat::literal lit : r) - if (s().lvl(lit) > 0) r[j++] = lit; + for (auto lit : r) + if (s().lvl(lit) > 0) + r[j++] = lit; + bool reduced = j < r.size(); r.shrink(j); - CTRACE("euf", create_hint, tout << "explain " << l << " <- " << m_hint_lits << "\n"); + DEBUG_CODE(for (auto lit : r) SASSERT(s().value(lit) == l_true);); - if (create_hint) { - unsigned nv = s().num_vars(); - expr_ref_vector eqs(m); - // add equalities to hint. - for (auto const& [a,b] : m_hint_eqs) { - eqs.push_back(m.mk_eq(a->get_expr(), b->get_expr())); - set_tmp_bool_var(nv, eqs.back()); - m_hint_lits.push_back(literal(nv, false)); - ++nv; - } - hint = mk_hint(m_euf, l); - log_antecedents(l, r, hint); - for (unsigned v = s().num_vars(); v < nv; ++v) - set_tmp_bool_var(v, nullptr); + if (create_hint) { + log_justifications(l, ez, is_euf); + if (reduced || multiple_theories) + log_rup(l, r); } } @@ -305,11 +288,12 @@ namespace euf { } void solver::add_eq_antecedent(bool probing, enode* a, enode* b) { - cc_justification* cc = (!probing && use_drat()) ? &m_explain_cc : nullptr; - m_egraph.explain_eq(m_explain, cc, a, b); + if (!probing && use_drat()) + m_hint_eqs.push_back({a, b}); + m_egraph.explain_eq(m_explain, nullptr, a, b); } - void solver::add_diseq_antecedent(ptr_vector& ex, cc_justification* cc, enode* a, enode* b) { + void solver::explain_diseq(ptr_vector& ex, cc_justification* cc, enode* a, enode* b) { sat::bool_var v = get_egraph().explain_diseq(ex, cc, a, b); SASSERT(v == sat::null_bool_var || s().value(v) == l_false); if (v != sat::null_bool_var) diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index f2add1e87..a9dc20e0e 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -228,6 +228,8 @@ namespace euf { void log_antecedents(std::ostream& out, literal l, literal_vector const& r); void log_antecedents(literal l, literal_vector const& r, th_proof_hint* hint); void log_justification(literal l, th_explain const& jst); + void log_justifications(literal l, unsigned explain_size, bool is_euf); + void log_rup(literal l, literal_vector const& r); eq_proof_hint* mk_hint(symbol const& th, literal lit); @@ -367,7 +369,7 @@ namespace euf { void get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) override; void get_th_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing); void add_eq_antecedent(bool probing, enode* a, enode* b); - void add_diseq_antecedent(ptr_vector& ex, cc_justification* cc, enode* a, enode* b); + void explain_diseq(ptr_vector& ex, cc_justification* cc, enode* a, enode* b); void add_explain(size_t* p) { m_explain.push_back(p); } void reset_explain() { m_explain.reset(); } void set_eliminated(bool_var v) override; diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp index df832a675..ec10426a7 100644 --- a/src/sat/smt/q_ematch.cpp +++ b/src/sat/smt/q_ematch.cpp @@ -125,7 +125,7 @@ namespace q { if (a->get_root() == b->get_root()) ctx.get_egraph().explain_eq(m_explain, cc, a, b); else - ctx.add_diseq_antecedent(m_explain, cc, a, b); + ctx.explain_diseq(m_explain, cc, a, b); } ctx.get_egraph().end_explain(); diff --git a/src/sat/smt/sat_th.cpp b/src/sat/smt/sat_th.cpp index 17d167829..21e3883e8 100644 --- a/src/sat/smt/sat_th.cpp +++ b/src/sat/smt/sat_th.cpp @@ -228,6 +228,8 @@ namespace euf { th_explain::th_explain(unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode_pair const& p, th_proof_hint const* pma) { m_consequent = c; m_eq = p; + if (m_eq.first && m_eq.first->get_id() > m_eq.second->get_id()) + std::swap(m_eq.first, m_eq.second); m_proof_hint = pma; m_num_literals = n_lits; m_num_eqs = n_eqs; @@ -238,8 +240,11 @@ namespace euf { m_literals[i] = lits[i]; base_ptr += sizeof(literal) * n_lits; m_eqs = reinterpret_cast(base_ptr); - for (i = 0; i < n_eqs; ++i) + for (i = 0; i < n_eqs; ++i) { m_eqs[i] = eqs[i]; + if (m_eqs[i].first->get_id() > m_eqs[i].second->get_id()) + std::swap(m_eqs[i].first, m_eqs[i].second); + } } th_explain* th_explain::mk(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode* x, enode* y, th_proof_hint const* pma) {