3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-07-30 16:03:16 +00:00

use propagation queues and hash-tables to schedule bindings

This commit is contained in:
Nikolaj Bjorner 2025-06-15 13:21:08 -07:00
parent 7b432ae608
commit f932d480a0
2 changed files with 158 additions and 13 deletions

View file

@ -66,6 +66,7 @@ namespace euf {
m_canonical(m),
m_eargs(m),
m_canonical_proofs(m),
m_infer_patterns(m, m_smt_params),
m_deps(m),
m_rewriter(m) {
m_tt = m_egraph.mk(m.mk_true(), 0, 0, nullptr);
@ -196,9 +197,10 @@ namespace euf {
m_should_propagate = false;
m_egraph.propagate();
m_mam->propagate();
flush_binding_queue();
propagate_rules();
IF_VERBOSE(11, verbose_stream() << "propagate " << m_stats.m_num_instances << "\n");
if (!m_should_propagate)
if (!m_should_propagate && !should_stop())
propagate_all_rules();
}
}
@ -229,7 +231,8 @@ namespace euf {
}
else if (m.is_not(f, f)) {
enode* n = mk_enode(f);
m_egraph.merge(n, m_ff, to_ptr(push_pr_dep(pr, d)));
auto j = to_ptr(push_pr_dep(pr, d));
m_egraph.new_diseq(n, j);
add_children(n);
}
else {
@ -238,6 +241,12 @@ namespace euf {
add_children(n);
if (is_forall(f)) {
quantifier* q = to_quantifier(f);
if (q->get_num_patterns() == 0) {
expr_ref tmp(m);
m_infer_patterns(q, tmp);
m_egraph.mk(tmp, 0, 0, nullptr); // ensure tmp is pinned within this scope.
q = to_quantifier(tmp);
}
ptr_vector<app> ground;
for (unsigned i = 0; i < q->get_num_patterns(); ++i) {
auto p = to_app(q->get_pattern(i));
@ -396,33 +405,100 @@ namespace euf {
}
}
binding* completion::tmp_binding(quantifier* q, app* pat, euf::enode* const* _binding) {
if (q->get_num_decls() > m_tmp_binding_capacity) {
void* mem = memory::allocate(sizeof(binding) + q->get_num_decls() * sizeof(euf::enode*));
m_tmp_binding = new (mem) binding(q, pat, 0, 0, 0);
m_tmp_binding_capacity = q->get_num_decls();
}
for (unsigned i = q->get_num_decls(); i-- > 0; )
m_tmp_binding->m_nodes[i] = _binding[i];
m_tmp_binding->m_pattern = pat;
m_tmp_binding->m_q = q;
return m_tmp_binding.get();
}
binding* completion::alloc_binding(quantifier* q, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top) {
binding* b = tmp_binding(q, pat, _binding);
if (m_bindings.contains(b))
return nullptr;
for (unsigned i = q->get_num_decls(); i-- > 0; )
b->m_nodes[i] = b->m_nodes[i]->get_root();
if (m_bindings.contains(b))
return nullptr;
unsigned n = q->get_num_decls();
unsigned sz = sizeof(binding) + sizeof(euf::enode* const*) * n;
void* mem = get_region().allocate(sz);
b = new (mem) binding(q, pat, max_generation, min_top, max_top);
b->init(b);
for (unsigned i = 0; i < n; ++i)
b->m_nodes[i] = _binding[i];
m_bindings.insert(b);
get_trail().push(insert_map<bindings, binding*>(m_bindings, b));
return b;
}
// callback when mam finds a binding
void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned mg, unsigned ming, unsigned mx) {
void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned max_global, unsigned min_top, unsigned max_top) {
if (should_stop())
return;
auto* b = alloc_binding(q, pat, binding, max_global, min_top, max_top);
if (!b)
return;
insert_binding(b);
}
void completion::insert_binding(binding* b) {
m_queue.reserve(b->m_max_top_generation + 1);
m_queue[b->m_max_top_generation].push_back(b);
}
void completion::flush_binding_queue() {
TRACE(euf_completion,
tout << "flush-queue\n";
for (unsigned i = 0; i < m_queue.size(); ++i)
tout << i << ": " << m_queue[i].size() << "\n";);
IF_VERBOSE(10,
verbose_stream() << "flush-queue\n";
for (unsigned i = 0; i < m_queue.size(); ++i)
verbose_stream() << i << ": " << m_queue[i].size() << "\n");
for (auto& g : m_queue) {
for (auto b : g)
apply_binding(*b);
g.reset();
}
}
void completion::apply_binding(binding& b) {
if (should_stop())
return;
var_subst subst(m);
expr_ref_vector _binding(m);
unsigned max_generation = 0;
for (unsigned i = 0; i < q->get_num_decls(); ++i) {
_binding.push_back(binding[i]->get_expr());
max_generation = std::max(max_generation, binding[i]->generation());
}
quantifier* q = b.m_q;
for (unsigned i = 0; i < q->get_num_decls(); ++i)
_binding.push_back(b.m_nodes[i]->get_expr());
expr_ref r = subst(q->get_expr(), _binding);
IF_VERBOSE(12, verbose_stream() << "add " << r << "\n");
IF_VERBOSE(10, verbose_stream() << max_generation << "\n");
scoped_generation sg(*this, max_generation + 1);
scoped_generation sg(*this, b.m_max_top_generation + 1);
auto [pr, d] = get_dependency(q);
if (pr)
pr = m.mk_quant_inst(m.mk_or(m.mk_not(q), r), _binding.size(), _binding.data());
add_constraint(r, pr, d);
propagate_rules();
m_egraph.propagate();
m_should_propagate = true;
++m_stats.m_num_instances;
}
void completion::read_egraph() {
//m_egraph.display(verbose_stream());
//exit(0);
if (m_egraph.inconsistent()) {
auto* d = explain_conflict();
proof_ref pr(m);

View file

@ -21,10 +21,13 @@ Author:
#pragma once
#include "util/scoped_vector.h"
#include "util/dlist.h"
#include "ast/simplifiers/dependent_expr_state.h"
#include "ast/euf/euf_egraph.h"
#include "ast/euf/euf_mam.h"
#include "ast/rewriter/th_rewriter.h"
#include "ast/pattern/pattern_inference.h"
#include "params/smt_params.h"
namespace euf {
@ -43,6 +46,60 @@ namespace euf {
virtual void solve_for(vector<solution>& sol) = 0;
};
struct binding : public dll_base<binding> {
quantifier* m_q;
app* m_pattern;
unsigned m_max_generation;
unsigned m_min_top_generation;
unsigned m_max_top_generation;
euf::enode* m_nodes[0];
binding(quantifier* q, app* pat, unsigned max_generation, unsigned min_top, unsigned max_top) :
m_q(q),
m_pattern(pat),
m_max_generation(max_generation),
m_min_top_generation(min_top),
m_max_top_generation(max_top) {
}
euf::enode* const* nodes() { return m_nodes; }
euf::enode* operator[](unsigned i) const { return m_nodes[i]; }
unsigned size() const { return m_q->get_num_decls(); }
quantifier* q() const { return m_q; }
bool eq(binding const& other) const {
if (q() != other.q())
return false;
for (unsigned i = size(); i-- > 0; )
if ((*this)[i] != other[i])
return false;
return true;
}
};
struct binding_khasher {
unsigned operator()(binding const* f) const { return f->q()->get_id(); }
};
struct binding_chasher {
unsigned operator()(binding const* f, unsigned idx) const { return f->m_nodes[idx]->hash(); }
};
struct binding_hash_proc {
unsigned operator()(binding const* f) const {
return get_composite_hash<binding*, binding_khasher, binding_chasher>(const_cast<binding*>(f), f->size());
}
};
struct binding_eq_proc {
bool operator()(binding const* a, binding const* b) const { return a->eq(*b); }
};
typedef ptr_hashtable<binding, binding_hash_proc, binding_eq_proc> bindings;
class completion : public dependent_expr_simplifier, public on_binding_callback, public mam_solver {
struct stats {
@ -63,6 +120,7 @@ namespace euf {
m_body(b), m_head(h), m_proofs(prs), m_dep(d) {}
};
smt_params m_smt_params;
egraph m_egraph;
scoped_ptr<mam> m_mam;
enode* m_tt, *m_ff;
@ -70,6 +128,10 @@ namespace euf {
enode_vector m_args, m_reps, m_nodes_to_canonize;
expr_ref_vector m_canonical, m_eargs;
proof_ref_vector m_canonical_proofs;
pattern_inference_rw m_infer_patterns;
bindings m_bindings;
scoped_ptr<binding> m_tmp_binding;
unsigned m_tmp_binding_capacity = 0;
expr_dependency_ref_vector m_deps;
obj_map<quantifier, std::pair<proof*, expr_dependency*>> m_q2dep;
vector<std::pair<proof_ref, expr_dependency*>> m_pr_dep;
@ -109,6 +171,13 @@ namespace euf {
expr_dependency* explain_conflict();
std::pair<proof*, expr_dependency*> get_dependency(quantifier* q) { return m_q2dep.contains(q) ? m_q2dep[q] : std::pair(nullptr, nullptr); }
binding* tmp_binding(quantifier* q, app* pat, euf::enode* const* _binding);
binding* alloc_binding(quantifier* q, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top);
void insert_binding(binding* b);
void apply_binding(binding& b);
void flush_binding_queue();
vector<ptr_vector<binding>> m_queue;
lbool eval_cond(expr* f, proof_ref& pr, expr_dependency*& d);
bool should_stop();