From 7ca94e8fef6d5f635fce88f1488b50b5b8382b9f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 10 May 2025 16:15:04 -0700 Subject: [PATCH] add E-matching to EUF completion --- src/ast/simplifiers/euf_completion.cpp | 91 ++++++++++++++++++++------ src/ast/simplifiers/euf_completion.h | 19 +++++- 2 files changed, 88 insertions(+), 22 deletions(-) diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index 280b5e6bf..61ff7f644 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -42,6 +42,7 @@ Algorithm for extracting canonical form from an E-graph: #include "ast/ast_pp.h" #include "ast/ast_util.h" #include "ast/euf/euf_egraph.h" +#include "ast/rewriter/var_subst.h" #include "ast/simplifiers/euf_completion.h" #include "ast/shared_occs.h" @@ -50,6 +51,7 @@ namespace euf { completion::completion(ast_manager& m, dependent_expr_state& fmls): dependent_expr_simplifier(m, fmls), m_egraph(m), + m_mam(mam::mk(*this, *this)), m_canonical(m), m_eargs(m), m_deps(m), @@ -58,6 +60,19 @@ namespace euf { m_ff = m_egraph.mk(m.mk_false(), 0, 0, nullptr); m_rewriter.set_order_eq(true); m_rewriter.set_flat_and_or(false); + + std::function _on_merge = + [&](euf::enode* root, euf::enode* other) { + m_mam->on_merge(root, other); + }; + + std::function _on_make = + [&](euf::enode* n) { + m_mam->add_node(n, false); + }; + + m_egraph.set_on_merge(_on_merge); + m_egraph.set_on_make(_on_make); } void completion::reduce() { @@ -75,33 +90,67 @@ namespace euf { void completion::add_egraph() { m_nodes_to_canonize.reset(); unsigned sz = qtail(); + + for (unsigned i = qhead(); i < sz; ++i) { + auto [f, p, d] = m_fmls[i](); + add_constraint(f, d); + } + m_should_propagate = true; + while (m_should_propagate) { + m_should_propagate = false; + m_egraph.propagate(); + m_mam->propagate(); + } + } + + void completion::add_constraint(expr* f, expr_dependency* d) { auto add_children = [&](enode* n) { for (auto* ch : enode_args(n)) m_nodes_to_canonize.push_back(ch); }; - - for (unsigned i = qhead(); i < sz; ++i) { - expr* x, * y; - auto [f, p, d] = m_fmls[i](); - if (m.is_eq(f, x, y)) { - enode* a = mk_enode(x); - enode* b = mk_enode(y); - m_egraph.merge(a, b, d); - add_children(a); - add_children(b); - } - else if (m.is_not(f, f)) { - enode* n = mk_enode(f); - m_egraph.merge(n, m_ff, d); - add_children(n); - } - else { - enode* n = mk_enode(f); - m_egraph.merge(n, m_tt, d); - add_children(n); + expr* x, * y; + if (m.is_eq(f, x, y)) { + enode* a = mk_enode(x); + enode* b = mk_enode(y); + m_egraph.merge(a, b, d); + add_children(a); + add_children(b); + } + else if (m.is_not(f, f)) { + enode* n = mk_enode(f); + m_egraph.merge(n, m_ff, d); + add_children(n); + } + else { + enode* n = mk_enode(f); + m_egraph.merge(n, m_tt, d); + add_children(n); + if (is_forall(f)) { + quantifier* q = to_quantifier(f); + ptr_vector ground; + for (unsigned i = 0; i < q->get_num_patterns(); ++i) { + auto p = to_app(q->get_pattern(i)); + mam::ground_subterms(p, ground); + for (expr* g : ground) + mk_enode(g); + m_mam->add_pattern(q, p); + } + if (!get_dependency(q)) { + m_q2dep.insert(q, d); + get_trail().push(insert_obj_map(m_q2dep, q)); + } } } - m_egraph.propagate(); + } + + void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned mg, unsigned ming, unsigned mx) { + var_subst subst(m); + expr_ref_vector _binding(m); + for (unsigned i = 0; i < q->get_num_decls(); ++i) + _binding.push_back(binding[i]->get_expr()); + expr_ref r = subst(q->get_expr(), _binding); + add_constraint(r, get_dependency(q)); + m_should_propagate = true; } void completion::read_egraph() { diff --git a/src/ast/simplifiers/euf_completion.h b/src/ast/simplifiers/euf_completion.h index da0fb7276..29d3b709a 100644 --- a/src/ast/simplifiers/euf_completion.h +++ b/src/ast/simplifiers/euf_completion.h @@ -19,11 +19,12 @@ Author: #include "ast/simplifiers/dependent_expr_state.h" #include "ast/euf/euf_egraph.h" +#include "ast/euf/euf_mam.h" #include "ast/rewriter/th_rewriter.h" namespace euf { - class completion : public dependent_expr_simplifier { + class completion : public dependent_expr_simplifier, public on_binding_callback, public mam_solver { struct stats { unsigned m_num_rewrites = 0; @@ -31,16 +32,20 @@ namespace euf { }; egraph m_egraph; + scoped_ptr m_mam; enode* m_tt, *m_ff; ptr_vector m_todo; enode_vector m_args, m_reps, m_nodes_to_canonize; expr_ref_vector m_canonical, m_eargs; expr_dependency_ref_vector m_deps; + obj_map m_q2dep; unsigned m_epoch = 0; unsigned_vector m_epochs; th_rewriter m_rewriter; stats m_stats; bool m_has_new_eq = false; + bool m_should_propagate = false; + enode* mk_enode(expr* e); bool is_new_eq(expr* a, expr* b); @@ -54,8 +59,10 @@ namespace euf { expr* get_canonical(expr* f, expr_dependency_ref& d); expr* get_canonical(enode* n); void set_canonical(enode* n, expr* e); + void add_constraint(expr*f, expr_dependency* d); expr_dependency* explain_eq(enode* a, enode* b); expr_dependency* explain_conflict(); + expr_dependency* get_dependency(quantifier* q) { return m_q2dep.contains(q) ? m_q2dep[q] : nullptr; } public: completion(ast_manager& m, dependent_expr_state& fmls); char const* name() const override { return "euf-reduce"; } @@ -64,5 +71,15 @@ namespace euf { void reduce() override; void collect_statistics(statistics& st) const override; void reset_statistics() override { m_stats.reset(); } + + trail_stack& get_trail() override { return m_trail;} + region& get_region() override { return m_trail.get_region(); } + egraph& get_egraph() override { return m_egraph; } + bool is_relevant(enode* n) const override { return true; } + bool resource_limits_exceeded() const override { return false; } + ast_manager& get_manager() override { return m; } + + void on_binding(quantifier* q, app* pat, enode* const* binding, unsigned mg, unsigned ming, unsigned mx) override; + }; }