mirror of
https://github.com/Z3Prover/z3
synced 2025-07-30 16:03:16 +00:00
use propagation queues and hash-tables to schedule bindings
This commit is contained in:
parent
7b432ae608
commit
f932d480a0
2 changed files with 158 additions and 13 deletions
|
@ -66,6 +66,7 @@ namespace euf {
|
|||
m_canonical(m),
|
||||
m_eargs(m),
|
||||
m_canonical_proofs(m),
|
||||
m_infer_patterns(m, m_smt_params),
|
||||
m_deps(m),
|
||||
m_rewriter(m) {
|
||||
m_tt = m_egraph.mk(m.mk_true(), 0, 0, nullptr);
|
||||
|
@ -196,9 +197,10 @@ namespace euf {
|
|||
m_should_propagate = false;
|
||||
m_egraph.propagate();
|
||||
m_mam->propagate();
|
||||
flush_binding_queue();
|
||||
propagate_rules();
|
||||
IF_VERBOSE(11, verbose_stream() << "propagate " << m_stats.m_num_instances << "\n");
|
||||
if (!m_should_propagate)
|
||||
if (!m_should_propagate && !should_stop())
|
||||
propagate_all_rules();
|
||||
}
|
||||
}
|
||||
|
@ -229,7 +231,8 @@ namespace euf {
|
|||
}
|
||||
else if (m.is_not(f, f)) {
|
||||
enode* n = mk_enode(f);
|
||||
m_egraph.merge(n, m_ff, to_ptr(push_pr_dep(pr, d)));
|
||||
auto j = to_ptr(push_pr_dep(pr, d));
|
||||
m_egraph.new_diseq(n, j);
|
||||
add_children(n);
|
||||
}
|
||||
else {
|
||||
|
@ -238,6 +241,12 @@ namespace euf {
|
|||
add_children(n);
|
||||
if (is_forall(f)) {
|
||||
quantifier* q = to_quantifier(f);
|
||||
if (q->get_num_patterns() == 0) {
|
||||
expr_ref tmp(m);
|
||||
m_infer_patterns(q, tmp);
|
||||
m_egraph.mk(tmp, 0, 0, nullptr); // ensure tmp is pinned within this scope.
|
||||
q = to_quantifier(tmp);
|
||||
}
|
||||
ptr_vector<app> ground;
|
||||
for (unsigned i = 0; i < q->get_num_patterns(); ++i) {
|
||||
auto p = to_app(q->get_pattern(i));
|
||||
|
@ -396,33 +405,100 @@ namespace euf {
|
|||
}
|
||||
}
|
||||
|
||||
binding* completion::tmp_binding(quantifier* q, app* pat, euf::enode* const* _binding) {
|
||||
if (q->get_num_decls() > m_tmp_binding_capacity) {
|
||||
void* mem = memory::allocate(sizeof(binding) + q->get_num_decls() * sizeof(euf::enode*));
|
||||
m_tmp_binding = new (mem) binding(q, pat, 0, 0, 0);
|
||||
m_tmp_binding_capacity = q->get_num_decls();
|
||||
}
|
||||
|
||||
for (unsigned i = q->get_num_decls(); i-- > 0; )
|
||||
m_tmp_binding->m_nodes[i] = _binding[i];
|
||||
m_tmp_binding->m_pattern = pat;
|
||||
m_tmp_binding->m_q = q;
|
||||
return m_tmp_binding.get();
|
||||
}
|
||||
|
||||
binding* completion::alloc_binding(quantifier* q, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top) {
|
||||
binding* b = tmp_binding(q, pat, _binding);
|
||||
|
||||
if (m_bindings.contains(b))
|
||||
return nullptr;
|
||||
|
||||
for (unsigned i = q->get_num_decls(); i-- > 0; )
|
||||
b->m_nodes[i] = b->m_nodes[i]->get_root();
|
||||
|
||||
if (m_bindings.contains(b))
|
||||
return nullptr;
|
||||
|
||||
unsigned n = q->get_num_decls();
|
||||
unsigned sz = sizeof(binding) + sizeof(euf::enode* const*) * n;
|
||||
void* mem = get_region().allocate(sz);
|
||||
b = new (mem) binding(q, pat, max_generation, min_top, max_top);
|
||||
b->init(b);
|
||||
for (unsigned i = 0; i < n; ++i)
|
||||
b->m_nodes[i] = _binding[i];
|
||||
|
||||
m_bindings.insert(b);
|
||||
get_trail().push(insert_map<bindings, binding*>(m_bindings, b));
|
||||
return b;
|
||||
}
|
||||
|
||||
// callback when mam finds a binding
|
||||
void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned mg, unsigned ming, unsigned mx) {
|
||||
void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned max_global, unsigned min_top, unsigned max_top) {
|
||||
if (should_stop())
|
||||
return;
|
||||
auto* b = alloc_binding(q, pat, binding, max_global, min_top, max_top);
|
||||
if (!b)
|
||||
return;
|
||||
insert_binding(b);
|
||||
}
|
||||
|
||||
void completion::insert_binding(binding* b) {
|
||||
m_queue.reserve(b->m_max_top_generation + 1);
|
||||
m_queue[b->m_max_top_generation].push_back(b);
|
||||
}
|
||||
|
||||
void completion::flush_binding_queue() {
|
||||
TRACE(euf_completion,
|
||||
tout << "flush-queue\n";
|
||||
for (unsigned i = 0; i < m_queue.size(); ++i)
|
||||
tout << i << ": " << m_queue[i].size() << "\n";);
|
||||
IF_VERBOSE(10,
|
||||
verbose_stream() << "flush-queue\n";
|
||||
for (unsigned i = 0; i < m_queue.size(); ++i)
|
||||
verbose_stream() << i << ": " << m_queue[i].size() << "\n");
|
||||
|
||||
for (auto& g : m_queue) {
|
||||
for (auto b : g)
|
||||
apply_binding(*b);
|
||||
g.reset();
|
||||
}
|
||||
}
|
||||
|
||||
void completion::apply_binding(binding& b) {
|
||||
if (should_stop())
|
||||
return;
|
||||
var_subst subst(m);
|
||||
expr_ref_vector _binding(m);
|
||||
unsigned max_generation = 0;
|
||||
for (unsigned i = 0; i < q->get_num_decls(); ++i) {
|
||||
_binding.push_back(binding[i]->get_expr());
|
||||
max_generation = std::max(max_generation, binding[i]->generation());
|
||||
}
|
||||
quantifier* q = b.m_q;
|
||||
for (unsigned i = 0; i < q->get_num_decls(); ++i)
|
||||
_binding.push_back(b.m_nodes[i]->get_expr());
|
||||
|
||||
expr_ref r = subst(q->get_expr(), _binding);
|
||||
IF_VERBOSE(12, verbose_stream() << "add " << r << "\n");
|
||||
IF_VERBOSE(10, verbose_stream() << max_generation << "\n");
|
||||
scoped_generation sg(*this, max_generation + 1);
|
||||
|
||||
scoped_generation sg(*this, b.m_max_top_generation + 1);
|
||||
auto [pr, d] = get_dependency(q);
|
||||
if (pr)
|
||||
pr = m.mk_quant_inst(m.mk_or(m.mk_not(q), r), _binding.size(), _binding.data());
|
||||
add_constraint(r, pr, d);
|
||||
propagate_rules();
|
||||
m_egraph.propagate();
|
||||
m_should_propagate = true;
|
||||
++m_stats.m_num_instances;
|
||||
}
|
||||
|
||||
void completion::read_egraph() {
|
||||
//m_egraph.display(verbose_stream());
|
||||
//exit(0);
|
||||
if (m_egraph.inconsistent()) {
|
||||
auto* d = explain_conflict();
|
||||
proof_ref pr(m);
|
||||
|
|
|
@ -21,10 +21,13 @@ Author:
|
|||
#pragma once
|
||||
|
||||
#include "util/scoped_vector.h"
|
||||
#include "util/dlist.h"
|
||||
#include "ast/simplifiers/dependent_expr_state.h"
|
||||
#include "ast/euf/euf_egraph.h"
|
||||
#include "ast/euf/euf_mam.h"
|
||||
#include "ast/rewriter/th_rewriter.h"
|
||||
#include "ast/pattern/pattern_inference.h"
|
||||
#include "params/smt_params.h"
|
||||
|
||||
namespace euf {
|
||||
|
||||
|
@ -43,6 +46,60 @@ namespace euf {
|
|||
virtual void solve_for(vector<solution>& sol) = 0;
|
||||
};
|
||||
|
||||
struct binding : public dll_base<binding> {
|
||||
quantifier* m_q;
|
||||
app* m_pattern;
|
||||
unsigned m_max_generation;
|
||||
unsigned m_min_top_generation;
|
||||
unsigned m_max_top_generation;
|
||||
euf::enode* m_nodes[0];
|
||||
|
||||
binding(quantifier* q, app* pat, unsigned max_generation, unsigned min_top, unsigned max_top) :
|
||||
m_q(q),
|
||||
m_pattern(pat),
|
||||
m_max_generation(max_generation),
|
||||
m_min_top_generation(min_top),
|
||||
m_max_top_generation(max_top) {
|
||||
}
|
||||
|
||||
euf::enode* const* nodes() { return m_nodes; }
|
||||
|
||||
euf::enode* operator[](unsigned i) const { return m_nodes[i]; }
|
||||
|
||||
unsigned size() const { return m_q->get_num_decls(); }
|
||||
|
||||
quantifier* q() const { return m_q; }
|
||||
|
||||
bool eq(binding const& other) const {
|
||||
if (q() != other.q())
|
||||
return false;
|
||||
for (unsigned i = size(); i-- > 0; )
|
||||
if ((*this)[i] != other[i])
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
struct binding_khasher {
|
||||
unsigned operator()(binding const* f) const { return f->q()->get_id(); }
|
||||
};
|
||||
|
||||
struct binding_chasher {
|
||||
unsigned operator()(binding const* f, unsigned idx) const { return f->m_nodes[idx]->hash(); }
|
||||
};
|
||||
|
||||
struct binding_hash_proc {
|
||||
unsigned operator()(binding const* f) const {
|
||||
return get_composite_hash<binding*, binding_khasher, binding_chasher>(const_cast<binding*>(f), f->size());
|
||||
}
|
||||
};
|
||||
|
||||
struct binding_eq_proc {
|
||||
bool operator()(binding const* a, binding const* b) const { return a->eq(*b); }
|
||||
};
|
||||
|
||||
typedef ptr_hashtable<binding, binding_hash_proc, binding_eq_proc> bindings;
|
||||
|
||||
class completion : public dependent_expr_simplifier, public on_binding_callback, public mam_solver {
|
||||
|
||||
struct stats {
|
||||
|
@ -63,6 +120,7 @@ namespace euf {
|
|||
m_body(b), m_head(h), m_proofs(prs), m_dep(d) {}
|
||||
};
|
||||
|
||||
smt_params m_smt_params;
|
||||
egraph m_egraph;
|
||||
scoped_ptr<mam> m_mam;
|
||||
enode* m_tt, *m_ff;
|
||||
|
@ -70,6 +128,10 @@ namespace euf {
|
|||
enode_vector m_args, m_reps, m_nodes_to_canonize;
|
||||
expr_ref_vector m_canonical, m_eargs;
|
||||
proof_ref_vector m_canonical_proofs;
|
||||
pattern_inference_rw m_infer_patterns;
|
||||
bindings m_bindings;
|
||||
scoped_ptr<binding> m_tmp_binding;
|
||||
unsigned m_tmp_binding_capacity = 0;
|
||||
expr_dependency_ref_vector m_deps;
|
||||
obj_map<quantifier, std::pair<proof*, expr_dependency*>> m_q2dep;
|
||||
vector<std::pair<proof_ref, expr_dependency*>> m_pr_dep;
|
||||
|
@ -109,6 +171,13 @@ namespace euf {
|
|||
expr_dependency* explain_conflict();
|
||||
std::pair<proof*, expr_dependency*> get_dependency(quantifier* q) { return m_q2dep.contains(q) ? m_q2dep[q] : std::pair(nullptr, nullptr); }
|
||||
|
||||
binding* tmp_binding(quantifier* q, app* pat, euf::enode* const* _binding);
|
||||
binding* alloc_binding(quantifier* q, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top);
|
||||
void insert_binding(binding* b);
|
||||
void apply_binding(binding& b);
|
||||
void flush_binding_queue();
|
||||
vector<ptr_vector<binding>> m_queue;
|
||||
|
||||
lbool eval_cond(expr* f, proof_ref& pr, expr_dependency*& d);
|
||||
|
||||
bool should_stop();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue