From e018b024c52d65a43689dfc2e19e6dfbfefd76b3 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 12 Jun 2025 11:31:50 -0700 Subject: [PATCH] adding proofs to euf-completion --- src/ast/simplifiers/euf_completion.cpp | 246 ++++++++++++++++++------- src/ast/simplifiers/euf_completion.h | 26 ++- src/tactic/goal.cpp | 1 + 3 files changed, 195 insertions(+), 78 deletions(-) diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index 75abb8c6f..3096d89ec 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -54,6 +54,7 @@ Mam optimization? #include "ast/rewriter/var_subst.h" #include "ast/simplifiers/euf_completion.h" #include "ast/shared_occs.h" +#include "ast/scoped_proof.h" #include "params/smt_params_helper.hpp" namespace euf { @@ -64,6 +65,7 @@ namespace euf { m_mam(mam::mk(*this, *this)), m_canonical(m), m_eargs(m), + m_canonical_proofs(m), m_deps(m), m_rewriter(m) { m_tt = m_egraph.mk(m.mk_true(), 0, 0, nullptr); @@ -176,7 +178,7 @@ namespace euf { add_egraph(); map_canonical(); read_egraph(); - IF_VERBOSE(11, verbose_stream() << "(euf.completion :rounds " << rounds << ")\n"); + IF_VERBOSE(1, verbose_stream() << "(euf.completion :rounds " << rounds << " :instances " << m_stats.m_num_instances << " :stop " << should_stop() << ")\n"); } } @@ -186,6 +188,7 @@ namespace euf { for (unsigned i = qhead(); i < sz; ++i) { auto [f, p, d] = m_fmls[i](); + add_constraint(f, p, d); } m_should_propagate = true; @@ -200,6 +203,14 @@ namespace euf { } } + unsigned completion::push_pr_dep(proof* pr, expr_dependency* d) { + unsigned sz = m_pr_dep.size(); + SASSERT(!m.proofs_enabled() || pr); + m_pr_dep.push_back({ proof_ref(pr, m), d }); + get_trail().push(push_back_vector(m_pr_dep)); + return sz; + } + void completion::add_constraint(expr* f, proof* pr, expr_dependency* d) { if (m_egraph.inconsistent()) return; @@ -211,18 +222,19 @@ namespace euf { if (m.is_eq(f, x, y)) { enode* a = mk_enode(x); enode* b = mk_enode(y); - m_egraph.merge(a, b, d); + + m_egraph.merge(a, b, to_ptr(push_pr_dep(pr, d))); add_children(a); add_children(b); } else if (m.is_not(f, f)) { enode* n = mk_enode(f); - m_egraph.merge(n, m_ff, d); + m_egraph.merge(n, m_ff, to_ptr(push_pr_dep(pr, d))); add_children(n); } else { enode* n = mk_enode(f); - m_egraph.merge(n, m_tt, d); + m_egraph.merge(n, m_tt, to_ptr(push_pr_dep(pr, d))); add_children(n); if (is_forall(f)) { quantifier* q = to_quantifier(f); @@ -234,8 +246,7 @@ namespace euf { mk_enode(g); m_mam->add_pattern(q, p); } - auto pq = get_dependency(q); - m_q2dep.insert(q, pq); + m_q2dep.insert(q, { pr, d}); get_trail().push(insert_obj_map(m_q2dep, q)); } add_rule(f, pr, d); @@ -248,7 +259,8 @@ namespace euf { auto n = mk_enode(f); if (m.is_true(n->get_root()->get_expr())) { d = m.mk_join(d, explain_eq(n, n->get_root())); - // TODO update pr + if (m.proofs_enabled()) + pr = prove_eq(n, n->get_root()); return l_true; } if (m.is_false(n->get_root()->get_expr())) @@ -259,7 +271,8 @@ namespace euf { n = mk_enode(g); if (m.is_false(n->get_root()->get_expr())) { d = m.mk_join(d, explain_eq(n, n->get_root())); - // TODO update pr + if (m.proofs_enabled()) + pr = prove_eq(n, n->get_root()); return l_true; } if (m.is_true(n->get_root()->get_expr())) @@ -282,11 +295,12 @@ namespace euf { return; expr_ref_vector body(m); proof_ref pr_i(m), pr0(m); - proof_ref_vector prs(m); + expr_ref_vector prs(m); expr_ref head(y, m); body.push_back(x); flatten_and(body); unsigned j = 0; + for (auto f : body) { switch (eval_cond(f, pr_i, d)) { case l_true: @@ -302,15 +316,19 @@ namespace euf { } body.shrink(j); if (m.proofs_enabled()) { - // TODO + prs.push_back(pr); + if (body.empty()) { + prs.push_back(head); + pr0 = m.mk_app(symbol("rup"), prs.size(), prs.data(), m.mk_proof_sort()); + } } if (body.empty()) - add_constraint(head, pr0, d); + add_constraint(head, pr0, d); else { euf::enode_vector _body; for (auto* f : body) _body.push_back(m_egraph.find(f)->get_root()); - auto r = alloc(conditional_rule, _body, head, pr0, d); + auto r = alloc(conditional_rule, _body, head, prs, d); m_rules.push_back(r); get_trail().push(new_obj_trail(r)); get_trail().push(push_back_vector(m_rules)); @@ -347,14 +365,15 @@ namespace euf { void completion::propagate_rule(conditional_rule& r) { if (!r.m_active) return; + proof_ref pr(m); for (unsigned i = r.m_watch_index; i < r.m_body.size(); ++i) { auto* f = r.m_body.get(i); - proof_ref pr(m); switch (eval_cond(f->get_expr(), pr, r.m_dep)) { case l_true: get_trail().push(value_trail(r.m_watch_index)); + get_trail().push(push_back_vector(r.m_proofs)); ++r.m_watch_index; - // TODO accumulate proof in r? + r.m_proofs.push_back(pr); break; case l_false: get_trail().push(value_trail(r.m_active)); @@ -366,7 +385,12 @@ namespace euf { } } if (r.m_body.empty()) { - add_constraint(r.m_head, r.m_proof, r.m_dep); + if (m.proofs_enabled()) { + get_trail().push(push_back_vector(r.m_proofs)); + r.m_proofs.push_back(r.m_head); + pr = m.mk_app(symbol("rup"), r.m_proofs.size(), r.m_proofs.data(), m.mk_proof_sort()); + } + add_constraint(r.m_head, pr, r.m_dep); get_trail().push(value_trail(r.m_active)); r.m_active = false; } @@ -374,7 +398,7 @@ namespace euf { // callback when mam finds a binding void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned mg, unsigned ming, unsigned mx) { - if (m_egraph.inconsistent()) + if (should_stop()) return; var_subst subst(m); expr_ref_vector _binding(m); @@ -385,9 +409,11 @@ namespace euf { } expr_ref r = subst(q->get_expr(), _binding); IF_VERBOSE(12, verbose_stream() << "add " << r << "\n"); - IF_VERBOSE(1, verbose_stream() << max_generation << "\n"); + IF_VERBOSE(10, verbose_stream() << max_generation << "\n"); scoped_generation sg(*this, max_generation + 1); auto [pr, d] = get_dependency(q); + if (pr) + pr = m.mk_quant_inst(m.mk_or(m.mk_not(q), r), _binding.size(), _binding.data()); add_constraint(r, pr, d); propagate_rules(); m_should_propagate = true; @@ -395,9 +421,15 @@ namespace euf { } void completion::read_egraph() { + //m_egraph.display(verbose_stream()); + //exit(0); if (m_egraph.inconsistent()) { auto* d = explain_conflict(); - dependent_expr de(m, m.mk_false(), nullptr, d); + proof_ref pr(m); + if (m.proofs_enabled()) + pr = prove_conflict(); + + dependent_expr de(m, m.mk_false(), pr.get(), d); m_fmls.update(0, de); return; } @@ -405,11 +437,12 @@ namespace euf { for (unsigned i = qhead(); i < sz; ++i) { auto [f, p, d] = m_fmls[i](); expr_dependency_ref dep(d, m); - expr_ref g = canonize_fml(f, dep); + proof_ref pr(p, m); + expr_ref g = canonize_fml(f, pr, dep); if (g != f) { - m_fmls.update(i, dependent_expr(m, g, nullptr, dep)); + m_fmls.update(i, dependent_expr(m, g, pr, dep)); m_stats.m_num_rewrites++; - IF_VERBOSE(11, verbose_stream() << mk_bounded_pp(f, m, 3) << " -> " << mk_bounded_pp(g, m, 3) << "\n"); + IF_VERBOSE(0, verbose_stream() << mk_bounded_pp(f, m, 3) << " -> " << mk_bounded_pp(g, m, 3) << "\n"); update_has_new_eq(g); } CTRACE(euf_completion, g != f, tout << mk_bounded_pp(f, m) << " -> " << mk_bounded_pp(g, m) << "\n"); @@ -475,65 +508,79 @@ namespace euf { return m_egraph.find(e); } - expr_ref completion::canonize_fml(expr* f, expr_dependency_ref& d) { + + expr_ref completion::canonize_fml(expr* f, proof_ref& pr, expr_dependency_ref& d) { auto is_nullary = [&](expr* e) { return is_app(e) && to_app(e)->get_num_args() == 0; }; expr* x, * y; + proof_ref pr1(m), pr2(m), pr3(m); if (m.is_eq(f, x, y)) { - expr_ref x1 = canonize(x, d); - expr_ref y1 = canonize(y, d); + expr_ref x1 = canonize(x, pr1, d); + expr_ref y1 = canonize(y, pr2, d); if (is_nullary(x)) { SASSERT(x1 == x); - x1 = get_canonical(x, d); + x1 = get_canonical(x, pr1, d); } if (is_nullary(y)) { SASSERT(y1 == y); - y1 = get_canonical(y, d); + y1 = get_canonical(y, pr2, d); } + expr_ref r(m); + if (x == y) - return expr_ref(m.mk_true(), m); - - if (x == x1 && y == y1) - return m_rewriter.mk_eq(x, y); - - if (is_nullary(x) && is_nullary(y)) - return mk_and(m_rewriter.mk_eq(x, x1), m_rewriter.mk_eq(y, x1)); - - if (x == x1 && is_nullary(x)) - return m_rewriter.mk_eq(y1, x1); - - if (y == y1 && is_nullary(y)) - return m_rewriter.mk_eq(x1, y1); - - if (is_nullary(x)) - return mk_and(m_rewriter.mk_eq(x, x1), m_rewriter.mk_eq(y1, x1)); - - if (is_nullary(y)) - return mk_and(m_rewriter.mk_eq(y, y1), m_rewriter.mk_eq(x1, y1)); - + r = expr_ref(m.mk_true(), m); + else if (x == x1 && y == y1) + r = m_rewriter.mk_eq(x, y); + else if (is_nullary(x) && is_nullary(y)) + r = mk_and(m_rewriter.mk_eq(x, x1), m_rewriter.mk_eq(y, x1)); + else if (x == x1 && is_nullary(x)) + r = m_rewriter.mk_eq(y1, x1); + else if (y == y1 && is_nullary(y)) + r = m_rewriter.mk_eq(x1, y1); + else if (is_nullary(x)) + r = mk_and(m_rewriter.mk_eq(x, x1), m_rewriter.mk_eq(y1, x1)); + else if (is_nullary(y)) + r = mk_and(m_rewriter.mk_eq(y, y1), m_rewriter.mk_eq(x1, y1)); if (x1 == y1) - return expr_ref(m.mk_true(), m); + r = expr_ref(m.mk_true(), m); else { - expr* c = get_canonical(x, d); + expr* c = get_canonical(x, pr3, d); if (c == x1) - return m_rewriter.mk_eq(y1, c); + r = m_rewriter.mk_eq(y1, c); else if (c == y1) - return m_rewriter.mk_eq(x1, c); + r = m_rewriter.mk_eq(x1, c); else - return mk_and(m_rewriter.mk_eq(x1, c), m_rewriter.mk_eq(y1, c)); + r = mk_and(m_rewriter.mk_eq(x1, c), m_rewriter.mk_eq(y1, c)); } + + if (m.proofs_enabled()) { + expr_ref_vector prs(m); + prs.push_back(pr); + if (pr1) prs.push_back(pr1); + if (pr2) prs.push_back(pr2); + if (pr3) prs.push_back(pr3); + prs.push_back(r); + pr = m.mk_app(symbol("euf"), prs.size(), prs.data(), m.mk_proof_sort()); + } + + return r; } if (m.is_not(f, x)) { - expr_ref x1 = canonize(x, d); - return expr_ref(mk_not(m, x1), m); + expr_ref x1 = canonize(x, pr1, d); + expr_ref r(mk_not(m, x1), m); + if (m.proofs_enabled()) { + expr* prs[3] = { pr, pr1, r }; + pr = m.mk_app(symbol("euf"), 3, prs, m.mk_proof_sort()); + } + return r; } - return canonize(f, d); + return canonize(f, pr, d); } expr_ref completion::mk_and(expr* a, expr* b) { @@ -544,29 +591,44 @@ namespace euf { return expr_ref(m.mk_and(a, b), m); } - expr_ref completion::canonize(expr* f, expr_dependency_ref& d) { + expr_ref completion::canonize(expr* f, proof_ref& pr, expr_dependency_ref& d) { if (!is_app(f)) return expr_ref(f, m); // todo could normalize ground expressions under quantifiers m_eargs.reset(); bool change = false; + expr_ref_vector prs(m); for (expr* arg : *to_app(f)) { - m_eargs.push_back(get_canonical(arg, d)); + proof_ref pr1(m); + m_eargs.push_back(get_canonical(arg, pr1, d)); change |= arg != m_eargs.back(); + if (arg != m_eargs.back() && pr1) + prs.push_back(pr1); } + expr_ref r(m); if (m.is_eq(f)) - return m_rewriter.mk_eq(m_eargs.get(0), m_eargs.get(1)); - if (!change) + r = m_rewriter.mk_eq(m_eargs.get(0), m_eargs.get(1)); + else if (!change) return expr_ref(f, m); else - return expr_ref(m_rewriter.mk_app(to_app(f)->get_decl(), m_eargs.size(), m_eargs.data()), m); + r = expr_ref(m_rewriter.mk_app(to_app(f)->get_decl(), m_eargs.size(), m_eargs.data()), m); + if (m.proofs_enabled()) { + prs.push_back(r); + pr = m.mk_app(symbol("euf"), prs.size(), prs.data(), m.mk_proof_sort()); + } + return r; } - expr* completion::get_canonical(expr* f, expr_dependency_ref& d) { + expr* completion::get_canonical(expr* f, proof_ref& pr, expr_dependency_ref& d) { enode* n = m_egraph.find(f); enode* r = n->get_root(); d = m.mk_join(d, explain_eq(n, r)); d = m.mk_join(d, m_deps.get(r->get_id(), nullptr)); + if (m.proofs_enabled()) { + pr = prove_eq(n, r); + if (get_canonical_proof(r)) + pr = m.mk_transitivity(pr, get_canonical_proof(r)); + } SASSERT(m_canonical.get(r->get_id())); return m_canonical.get(r->get_id()); } @@ -578,7 +640,14 @@ namespace euf { return nullptr; } - void completion::set_canonical(enode* n, expr* e) { + proof* completion::get_canonical_proof(enode* n) { + if (m_epochs.get(n->get_id(), 0) == m_epoch && n->get_id() < m_canonical_proofs.size()) + return m_canonical_proofs.get(n->get_id()); + else + return nullptr; + } + + void completion::set_canonical(enode* n, expr* e, proof* pr) { class vtrail : public trail { expr_ref_vector& c; unsigned idx; @@ -597,33 +666,63 @@ namespace euf { if (num_scopes() > 0 && m_canonical.size() > n->get_id()) m_trail.push(vtrail(m_canonical, n->get_id())); m_canonical.setx(n->get_id(), e); + if (pr) + m_canonical_proofs.setx(n->get_id(), pr); m_epochs.setx(n->get_id(), m_epoch, 0); } expr_dependency* completion::explain_eq(enode* a, enode* b) { if (a == b) return nullptr; - ptr_vector just; + ptr_vector just; m_egraph.begin_explain(); m_egraph.explain_eq(just, nullptr, a, b); m_egraph.end_explain(); expr_dependency* d = nullptr; - for (expr_dependency* d2 : just) - d = m.mk_join(d, d2); + for (size_t* j : just) + d = m.mk_join(d, m_pr_dep[from_ptr(j)].second); return d; } expr_dependency* completion::explain_conflict() { - ptr_vector just; + ptr_vector just; m_egraph.begin_explain(); m_egraph.explain(just, nullptr); m_egraph.end_explain(); expr_dependency* d = nullptr; - for (expr_dependency* d2 : just) - d = m.mk_join(d, d2); + for (size_t* j : just) + d = m.mk_join(d, m_pr_dep[from_ptr(j)].second); return d; } + proof_ref completion::prove_eq(enode* a, enode* b) { + expr_ref_vector prs(m); + proof_ref pr(m); + ptr_vector just; + m_egraph.begin_explain(); + m_egraph.explain_eq(just, nullptr, a, b); + m_egraph.end_explain(); + for (size_t* j : just) + prs.push_back(m_pr_dep[from_ptr(j)].first); + prs.push_back(m.mk_eq(a->get_expr(), b->get_expr())); + pr = m.mk_app(symbol("euf"), prs.size(), prs.data(), m.mk_proof_sort()); + return pr; + } + + proof_ref completion::prove_conflict() { + expr_ref_vector prs(m); + proof_ref pr(m); + ptr_vector just; + m_egraph.begin_explain(); + m_egraph.explain(just, nullptr); + m_egraph.end_explain(); + for (size_t* j : just) + prs.push_back(m_pr_dep[from_ptr(j)].first); + prs.push_back(m.mk_false()); + pr = m.mk_app(symbol("euf"), prs.size(), prs.data(), m.mk_proof_sort()); + return pr; + } + void completion::collect_statistics(statistics& st) const { st.update("euf-completion-rewrites", m_stats.m_num_rewrites); st.update("euf-completion-instances", m_stats.m_num_instances); @@ -715,6 +814,7 @@ namespace euf { m_deps.setx(r->get_id(), d); } expr_ref new_expr(m); + expr_ref_vector prs(m); while (!m_todo.empty()) { expr* e = m_todo.back(); enode* n = m_egraph.find(e); @@ -723,7 +823,7 @@ namespace euf { if (get_canonical(n)) m_todo.pop_back(); else if (get_depth(rep->get_expr()) == 0 || !is_app(rep->get_expr())) { - set_canonical(n, rep->get_expr()); + set_canonical(n, rep->get_expr(), nullptr); m_todo.pop_back(); } else { @@ -731,6 +831,8 @@ namespace euf { unsigned sz = m_todo.size(); bool new_arg = false; expr_dependency* d = m_deps.get(n->get_id(), nullptr); + proof_ref pr(m); + prs.reset(); for (enode* arg : enode_args(rep)) { enode* rarg = arg->get_root(); expr* c = get_canonical(rarg); @@ -738,6 +840,8 @@ namespace euf { m_eargs.push_back(c); new_arg |= c != arg->get_expr(); d = m.mk_join(d, m_deps.get(rarg->get_id(), nullptr)); + if (m.proofs_enabled() && c != arg->get_expr() && get_canonical_proof(rarg)) + prs.push_back(get_canonical_proof(rarg)); } else m_todo.push_back(rarg->get_expr()); @@ -748,7 +852,11 @@ namespace euf { new_expr = m_rewriter.mk_app(to_app(rep->get_expr())->get_decl(), m_eargs.size(), m_eargs.data()); else new_expr = rep->get_expr(); - set_canonical(n, new_expr); + if (m.proofs_enabled() && new_arg) { + prs.push_back(m.mk_eq(n->get_expr(), new_expr)); + pr = m.mk_app(symbol("euf"), prs.size(), prs.data(), m.mk_proof_sort()); + } + set_canonical(n, new_expr, pr); m_deps.setx(n->get_id(), d); } } diff --git a/src/ast/simplifiers/euf_completion.h b/src/ast/simplifiers/euf_completion.h index d991fa7e4..c9c92f948 100644 --- a/src/ast/simplifiers/euf_completion.h +++ b/src/ast/simplifiers/euf_completion.h @@ -54,13 +54,13 @@ namespace euf { struct conditional_rule { euf::enode_vector m_body; expr_ref m_head; - proof_ref m_proof; + expr_ref_vector m_proofs; expr_dependency* m_dep; unsigned m_watch_index = 0; bool m_active = true; bool m_in_queue = false; - conditional_rule(euf::enode_vector& b, expr_ref& h, proof* pr, expr_dependency* d) : - m_body(b), m_head(h), m_proof(pr, h.get_manager()), m_dep(d) {} + conditional_rule(euf::enode_vector& b, expr_ref& h, expr_ref_vector& prs, expr_dependency* d) : + m_body(b), m_head(h), m_proofs(prs), m_dep(d) {} }; egraph m_egraph; @@ -69,8 +69,10 @@ namespace euf { ptr_vector m_todo; enode_vector m_args, m_reps, m_nodes_to_canonize; expr_ref_vector m_canonical, m_eargs; + proof_ref_vector m_canonical_proofs; expr_dependency_ref_vector m_deps; obj_map> m_q2dep; + vector> m_pr_dep; unsigned m_epoch = 0; unsigned_vector m_epochs; th_rewriter m_rewriter; @@ -82,6 +84,10 @@ namespace euf { unsigned m_max_instantiations = std::numeric_limits::max(); unsigned m_generation = 0; vector> m_rule_watch; + + size_t* to_ptr(size_t i) const { return reinterpret_cast(i); } + unsigned from_ptr(size_t* s) const { return (unsigned)reinterpret_cast(s); } + unsigned push_pr_dep(proof* pr, expr_dependency* d); enode* mk_enode(expr* e); bool is_new_eq(expr* a, expr* b); @@ -90,20 +96,21 @@ namespace euf { void add_egraph(); void map_canonical(); void read_egraph(); - expr_ref canonize(expr* f, expr_dependency_ref& dep); - expr_ref canonize_fml(expr* f, expr_dependency_ref& dep); - expr* get_canonical(expr* f, expr_dependency_ref& d); + expr_ref canonize(expr* f, proof_ref& pr, expr_dependency_ref& dep); + expr_ref canonize_fml(expr* f, proof_ref& pr, expr_dependency_ref& dep); + expr* get_canonical(expr* f, proof_ref& pr, expr_dependency_ref& d); expr* get_canonical(enode* n); - void set_canonical(enode* n, expr* e); + proof* get_canonical_proof(enode* n); + void set_canonical(enode* n, expr* e, proof* pr); void add_constraint(expr*f, proof* pr, expr_dependency* d); expr_dependency* explain_eq(enode* a, enode* b); - void prove_eq(enode* a, enode* b, proof_ref& pr); + proof_ref prove_eq(enode* a, enode* b); + proof_ref prove_conflict(); expr_dependency* explain_conflict(); std::pair get_dependency(quantifier* q) { return m_q2dep.contains(q) ? m_q2dep[q] : std::pair(nullptr, nullptr); } lbool eval_cond(expr* f, proof_ref& pr, expr_dependency*& d); - bool should_stop(); void add_rule(expr* f, proof* pr, expr_dependency* d); @@ -129,6 +136,7 @@ namespace euf { void collect_statistics(statistics& st) const override; void reset_statistics() override { m_stats.reset(); } void updt_params(params_ref const& p) override; + bool supports_proofs() const override { return true; } trail_stack& get_trail() override { return m_trail;} region& get_region() override { return m_trail.get_region(); } diff --git a/src/tactic/goal.cpp b/src/tactic/goal.cpp index f8af6a36e..ba7676a67 100644 --- a/src/tactic/goal.cpp +++ b/src/tactic/goal.cpp @@ -296,6 +296,7 @@ void goal::update(unsigned i, expr * f, proof * pr, expr_dependency * d) { if (!m_inconsistent) { if (m().is_false(out_f)) { push_back(out_f, out_pr, d); + m_inconsistent = true; } else { m().set(m_forms, i, out_f);