From f932d480a0900674e3942e3a2b34e41993af4d84 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 15 Jun 2025 13:21:08 -0700 Subject: [PATCH] use propagation queues and hash-tables to schedule bindings --- src/ast/simplifiers/euf_completion.cpp | 102 +++++++++++++++++++++---- src/ast/simplifiers/euf_completion.h | 69 +++++++++++++++++ 2 files changed, 158 insertions(+), 13 deletions(-) diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index 3096d89ec..14a4e431f 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -66,6 +66,7 @@ namespace euf { m_canonical(m), m_eargs(m), m_canonical_proofs(m), + m_infer_patterns(m, m_smt_params), m_deps(m), m_rewriter(m) { m_tt = m_egraph.mk(m.mk_true(), 0, 0, nullptr); @@ -196,9 +197,10 @@ namespace euf { m_should_propagate = false; m_egraph.propagate(); m_mam->propagate(); + flush_binding_queue(); propagate_rules(); IF_VERBOSE(11, verbose_stream() << "propagate " << m_stats.m_num_instances << "\n"); - if (!m_should_propagate) + if (!m_should_propagate && !should_stop()) propagate_all_rules(); } } @@ -229,7 +231,8 @@ namespace euf { } else if (m.is_not(f, f)) { enode* n = mk_enode(f); - m_egraph.merge(n, m_ff, to_ptr(push_pr_dep(pr, d))); + auto j = to_ptr(push_pr_dep(pr, d)); + m_egraph.new_diseq(n, j); add_children(n); } else { @@ -238,6 +241,12 @@ namespace euf { add_children(n); if (is_forall(f)) { quantifier* q = to_quantifier(f); + if (q->get_num_patterns() == 0) { + expr_ref tmp(m); + m_infer_patterns(q, tmp); + m_egraph.mk(tmp, 0, 0, nullptr); // ensure tmp is pinned within this scope. + q = to_quantifier(tmp); + } ptr_vector ground; for (unsigned i = 0; i < q->get_num_patterns(); ++i) { auto p = to_app(q->get_pattern(i)); @@ -396,33 +405,100 @@ namespace euf { } } + binding* completion::tmp_binding(quantifier* q, app* pat, euf::enode* const* _binding) { + if (q->get_num_decls() > m_tmp_binding_capacity) { + void* mem = memory::allocate(sizeof(binding) + q->get_num_decls() * sizeof(euf::enode*)); + m_tmp_binding = new (mem) binding(q, pat, 0, 0, 0); + m_tmp_binding_capacity = q->get_num_decls(); + } + + for (unsigned i = q->get_num_decls(); i-- > 0; ) + m_tmp_binding->m_nodes[i] = _binding[i]; + m_tmp_binding->m_pattern = pat; + m_tmp_binding->m_q = q; + return m_tmp_binding.get(); + } + + binding* completion::alloc_binding(quantifier* q, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top) { + binding* b = tmp_binding(q, pat, _binding); + + if (m_bindings.contains(b)) + return nullptr; + + for (unsigned i = q->get_num_decls(); i-- > 0; ) + b->m_nodes[i] = b->m_nodes[i]->get_root(); + + if (m_bindings.contains(b)) + return nullptr; + + unsigned n = q->get_num_decls(); + unsigned sz = sizeof(binding) + sizeof(euf::enode* const*) * n; + void* mem = get_region().allocate(sz); + b = new (mem) binding(q, pat, max_generation, min_top, max_top); + b->init(b); + for (unsigned i = 0; i < n; ++i) + b->m_nodes[i] = _binding[i]; + + m_bindings.insert(b); + get_trail().push(insert_map(m_bindings, b)); + return b; + } + // callback when mam finds a binding - void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned mg, unsigned ming, unsigned mx) { + void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned max_global, unsigned min_top, unsigned max_top) { + if (should_stop()) + return; + auto* b = alloc_binding(q, pat, binding, max_global, min_top, max_top); + if (!b) + return; + insert_binding(b); + } + + void completion::insert_binding(binding* b) { + m_queue.reserve(b->m_max_top_generation + 1); + m_queue[b->m_max_top_generation].push_back(b); + } + + void completion::flush_binding_queue() { + TRACE(euf_completion, + tout << "flush-queue\n"; + for (unsigned i = 0; i < m_queue.size(); ++i) + tout << i << ": " << m_queue[i].size() << "\n";); + IF_VERBOSE(10, + verbose_stream() << "flush-queue\n"; + for (unsigned i = 0; i < m_queue.size(); ++i) + verbose_stream() << i << ": " << m_queue[i].size() << "\n"); + + for (auto& g : m_queue) { + for (auto b : g) + apply_binding(*b); + g.reset(); + } + } + + void completion::apply_binding(binding& b) { if (should_stop()) return; var_subst subst(m); expr_ref_vector _binding(m); - unsigned max_generation = 0; - for (unsigned i = 0; i < q->get_num_decls(); ++i) { - _binding.push_back(binding[i]->get_expr()); - max_generation = std::max(max_generation, binding[i]->generation()); - } + quantifier* q = b.m_q; + for (unsigned i = 0; i < q->get_num_decls(); ++i) + _binding.push_back(b.m_nodes[i]->get_expr()); + expr_ref r = subst(q->get_expr(), _binding); - IF_VERBOSE(12, verbose_stream() << "add " << r << "\n"); - IF_VERBOSE(10, verbose_stream() << max_generation << "\n"); - scoped_generation sg(*this, max_generation + 1); + + scoped_generation sg(*this, b.m_max_top_generation + 1); auto [pr, d] = get_dependency(q); if (pr) pr = m.mk_quant_inst(m.mk_or(m.mk_not(q), r), _binding.size(), _binding.data()); add_constraint(r, pr, d); propagate_rules(); + m_egraph.propagate(); m_should_propagate = true; ++m_stats.m_num_instances; } void completion::read_egraph() { - //m_egraph.display(verbose_stream()); - //exit(0); if (m_egraph.inconsistent()) { auto* d = explain_conflict(); proof_ref pr(m); diff --git a/src/ast/simplifiers/euf_completion.h b/src/ast/simplifiers/euf_completion.h index c9c92f948..58c677f27 100644 --- a/src/ast/simplifiers/euf_completion.h +++ b/src/ast/simplifiers/euf_completion.h @@ -21,10 +21,13 @@ Author: #pragma once #include "util/scoped_vector.h" +#include "util/dlist.h" #include "ast/simplifiers/dependent_expr_state.h" #include "ast/euf/euf_egraph.h" #include "ast/euf/euf_mam.h" #include "ast/rewriter/th_rewriter.h" +#include "ast/pattern/pattern_inference.h" +#include "params/smt_params.h" namespace euf { @@ -43,6 +46,60 @@ namespace euf { virtual void solve_for(vector& sol) = 0; }; + struct binding : public dll_base { + quantifier* m_q; + app* m_pattern; + unsigned m_max_generation; + unsigned m_min_top_generation; + unsigned m_max_top_generation; + euf::enode* m_nodes[0]; + + binding(quantifier* q, app* pat, unsigned max_generation, unsigned min_top, unsigned max_top) : + m_q(q), + m_pattern(pat), + m_max_generation(max_generation), + m_min_top_generation(min_top), + m_max_top_generation(max_top) { + } + + euf::enode* const* nodes() { return m_nodes; } + + euf::enode* operator[](unsigned i) const { return m_nodes[i]; } + + unsigned size() const { return m_q->get_num_decls(); } + + quantifier* q() const { return m_q; } + + bool eq(binding const& other) const { + if (q() != other.q()) + return false; + for (unsigned i = size(); i-- > 0; ) + if ((*this)[i] != other[i]) + return false; + return true; + } + }; + + struct binding_khasher { + unsigned operator()(binding const* f) const { return f->q()->get_id(); } + }; + + struct binding_chasher { + unsigned operator()(binding const* f, unsigned idx) const { return f->m_nodes[idx]->hash(); } + }; + + struct binding_hash_proc { + unsigned operator()(binding const* f) const { + return get_composite_hash(const_cast(f), f->size()); + } + }; + + struct binding_eq_proc { + bool operator()(binding const* a, binding const* b) const { return a->eq(*b); } + }; + + typedef ptr_hashtable bindings; + class completion : public dependent_expr_simplifier, public on_binding_callback, public mam_solver { struct stats { @@ -63,6 +120,7 @@ namespace euf { m_body(b), m_head(h), m_proofs(prs), m_dep(d) {} }; + smt_params m_smt_params; egraph m_egraph; scoped_ptr m_mam; enode* m_tt, *m_ff; @@ -70,6 +128,10 @@ namespace euf { enode_vector m_args, m_reps, m_nodes_to_canonize; expr_ref_vector m_canonical, m_eargs; proof_ref_vector m_canonical_proofs; + pattern_inference_rw m_infer_patterns; + bindings m_bindings; + scoped_ptr m_tmp_binding; + unsigned m_tmp_binding_capacity = 0; expr_dependency_ref_vector m_deps; obj_map> m_q2dep; vector> m_pr_dep; @@ -109,6 +171,13 @@ namespace euf { expr_dependency* explain_conflict(); std::pair get_dependency(quantifier* q) { return m_q2dep.contains(q) ? m_q2dep[q] : std::pair(nullptr, nullptr); } + binding* tmp_binding(quantifier* q, app* pat, euf::enode* const* _binding); + binding* alloc_binding(quantifier* q, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top); + void insert_binding(binding* b); + void apply_binding(binding& b); + void flush_binding_queue(); + vector> m_queue; + lbool eval_cond(expr* f, proof_ref& pr, expr_dependency*& d); bool should_stop();