3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-07-24 21:26:59 +00:00

turn on ho-matcher for completion

This commit is contained in:
Nikolaj Bjorner 2025-07-07 14:08:51 +02:00
parent 1b3c3c2716
commit 0c5b0c3724
10 changed files with 166 additions and 63 deletions

View file

@ -2386,29 +2386,19 @@ app * ast_manager::mk_pattern(unsigned num_exprs, app * const * exprs) {
}
bool ast_manager::is_pattern(expr const * n) const {
if (!is_app_of(n, pattern_family_id, OP_PATTERN)) {
return false;
}
for (unsigned i = 0; i < to_app(n)->get_num_args(); ++i) {
if (!is_app(to_app(n)->get_arg(i))) {
return false;
}
}
return true;
if (!is_app_of(n, pattern_family_id, OP_PATTERN))
return false;
return all_of(*to_app(n), [](expr* arg) { return is_app(arg); });
}
bool ast_manager::is_pattern(expr const * n, ptr_vector<expr> &args) {
if (!is_app_of(n, pattern_family_id, OP_PATTERN)) {
bool ast_manager::is_pattern(expr const * n, ptr_vector<app> &args) {
if (!is_pattern(n))
return false;
}
for (unsigned i = 0; i < to_app(n)->get_num_args(); ++i) {
expr *arg = to_app(n)->get_arg(i);
if (!is_app(arg)) {
return false;
}
args.push_back(arg);
}
for (auto arg : *to_app(n))
args.push_back(to_app(arg));
return true;
}

View file

@ -2015,7 +2015,7 @@ public:
bool is_pattern(expr const * n) const;
bool is_pattern(expr const *n, ptr_vector<expr> &args);
bool is_pattern(expr const *n, ptr_vector<app> &args);
public:

View file

@ -1 +0,0 @@
nbjorner@LAPTOP-04AEAFKH.38072:1751392111

View file

@ -2994,6 +2994,8 @@ namespace euf {
SASSERT(m.is_pattern(mp));
SASSERT(first_idx < mp->get_num_args());
app * p = to_app(mp->get_arg(first_idx));
if (is_ground(p))
return;
func_decl * lbl = p->get_decl();
unsigned lbl_id = lbl->get_small_id();
m_trees.reserve(lbl_id+1, nullptr);
@ -3879,9 +3881,10 @@ namespace euf {
// Ground patterns are discarded.
// However, the simplifier may turn a non-ground pattern into a ground one.
// So, we should check it again here.
for (expr* arg : *mp)
if (is_ground(arg) || has_quantifiers(arg))
return; // ignore multi-pattern containing ground pattern.
if (all_of(*mp, [](expr* arg) { return is_ground(arg); }))
return; // ignore multi-pattern containing only ground pattern.
if (any_of(*mp, [](expr* arg) { return has_quantifiers(arg); }))
return; // patterns with quantifiers are not handled.
update_filters(qa, mp);
m_new_patterns.push_back(qp_pair(qa, mp));
ctx.get_trail().push(push_back_trail<qp_pair, false>(m_new_patterns));

View file

@ -55,6 +55,7 @@ namespace euf {
void ho_matcher::operator()(expr* pat, expr* t, unsigned num_bound, unsigned num_vars) {
m_trail.push_scope();
m_subst.resize(0);
m_subst.resize(num_vars);
m_goals.reset();
m_goals.push(0, num_bound, pat, t);
@ -93,7 +94,7 @@ namespace euf {
bool st = consume_work(wi);
IF_VERBOSE(3, display(verbose_stream() << "ho_matcher::consume_work: " << wi.pat << " =?= " << wi.t << " -> " << (st?"true":"false") << "\n"););
if (st) {
if (m_goals.empty())
if (m_goals.empty())
m_on_match(m_subst);
break;
}
@ -635,14 +636,17 @@ namespace euf {
}
app* ho_matcher::compile_ho_pattern(quantifier* q, app* p) {
quantifier* ho_matcher::compile_ho_pattern(quantifier* q, app*& p) {
app* p1 = nullptr;
if (m_pat2hopat.find(p, p1))
return p1;
if (m_pat2hopat.find(p, p)) {
q = m_q2hoq[q];
return q;
}
auto is_ho = any_of(subterms::all(expr_ref(p, m)), [&](expr* t) { return m_unitary.is_flex(0, t); });
if (!is_ho)
return p;
return q;
ptr_vector<expr> todo;
ptr_buffer<var> bound;
expr_ref_vector cache(m);
unsigned nb = q->get_num_decls();
todo.push_back(p);
@ -655,7 +659,9 @@ namespace euf {
}
if (m_unitary.is_flex(0, t)) {
m_pat2abs.insert_if_not_there(p, svector<std::pair<unsigned, expr*>>()).push_back({ nb, t });
cache.setx(t->get_id(), m.mk_var(nb++, t->get_sort()));
auto v = m.mk_var(nb++, t->get_sort());
bound.push_back(v);
cache.setx(t->get_id(), v);
todo.pop_back();
continue;
}
@ -678,41 +684,91 @@ namespace euf {
}
if (is_quantifier(t)) {
m_pat2abs.remove(p);
return p;
return q;
}
}
p1 = to_app(cache.get(p->get_id()));
expr_free_vars free_vars;
free_vars(p1);
app_ref_vector new_ground(m);
app_ref_vector new_patterns(m);
ptr_buffer<sort> sorts;
vector<symbol> names;
for (unsigned i = bound.size(); i-- > 0; ) {
sorts.push_back(bound[i]->get_sort());
names.push_back(symbol(bound[i]->get_idx()));
}
unsigned sz = q->get_num_decls();
for (unsigned i = 0; i < sz; ++i) {
unsigned idx = sz - i - 1;
auto s = q->get_decl_sort(i);
sorts.push_back(s);
names.push_back(q->get_decl_name(i));
if (!free_vars.contains(idx)) {
auto p = m.mk_fresh_func_decl("p", 1, &s, m.mk_bool_sort());
new_patterns.push_back(m.mk_app(p, m.mk_var(idx, s)));
new_ground.push_back(m.mk_app(p, m.mk_fresh_const(symbol("c"), s)));
}
}
auto body = q->get_expr();
if (!new_patterns.empty()) {
ptr_vector<app> pats;
VERIFY(m.is_pattern(p1, pats));
for (auto p : new_patterns) // patterns for variables that are not free in new pattern
pats.push_back(p);
for (auto g : new_ground) // ensure ground terms are in pattern so they have enodes
pats.push_back(g);
p1 = m.mk_pattern(pats.size(), pats.data());
}
quantifier* q1 = m.mk_forall(sorts.size(), sorts.data(), names.data(), body);
m_pat2hopat.insert(p, p1);
m_hopat2pat.insert(p1, p);
m_q2hoq.insert(q, q1);
m_hoq2q.insert(q1, q);
m_hopat2free_vars.insert(p1, free_vars);
m_ho_patterns.push_back(p1);
m_ho_qs.push_back(q1);
trail().push(push_back_vector(m_ho_patterns));
trail().push(push_back_vector(m_ho_qs));
trail().push(insert_map(m_pat2hopat, p));
trail().push(insert_map(m_hopat2pat, p1));
trail().push(insert_map(m_pat2abs, p));
return p1;
trail().push(insert_map(m_q2hoq, q));
trail().push(insert_map(m_hoq2q, q1));
trail().push(insert_map(m_hopat2free_vars, p1));
p = p1;
return q1;
}
bool ho_matcher::is_ho_pattern(app* p) {
return m_hopat2pat.contains(p);
}
void ho_matcher::refine_ho_match(app* p, expr_ref_vector const& s) {
void ho_matcher::refine_ho_match(app* p, expr_ref_vector& s) {
auto fo_pat = m_hopat2pat[p];
m_trail.push_scope();
m_subst.resize(0);
m_subst.resize(s.size());
m_goals.reset();
for (unsigned i = 0; i < s.size(); ++i) {
if (s[i])
m_subst.set(i, s[i]);
auto idx = s.size() - i - 1;
if (!m_hopat2free_vars[p].contains(idx))
s[i] = m.mk_var(idx, s[i]->get_sort());
else if (s.get(i))
m_subst.set(i, s.get(i));
}
IF_VERBOSE(1, verbose_stream() << "refine " << mk_pp(p, m) << "\n" << s << "\n");
unsigned num_bound = 0, level = 0;
for (auto [v, pat] : m_pat2abs[fo_pat]) {
var_subst sub(m, false);
var_subst sub(m, true);
auto pat_refined = sub(pat, s);
IF_VERBOSE(1, verbose_stream() << mk_pp(pat, m) << " -> " << pat_refined << "\n");
m_goals.push(level, num_bound, pat_refined, s[v]);
m_goals.push(level, num_bound, pat_refined, s.get(s.size() - v - 1));
}
search();

View file

@ -316,8 +316,10 @@ namespace euf {
mutable array_rewriter m_rewriter;
array_util m_array;
obj_map<app, app*> m_pat2hopat, m_hopat2pat;
obj_map<quantifier, quantifier*> m_q2hoq, m_hoq2q;
obj_map<app, expr_free_vars> m_hopat2free_vars;
obj_map<app, svector<std::pair<unsigned, expr*>>> m_pat2abs;
expr_ref_vector m_ho_patterns;
expr_ref_vector m_ho_patterns, m_ho_qs;
void resume();
@ -373,7 +375,8 @@ namespace euf {
m_unitary(m),
m_rewriter(m),
m_array(m),
m_ho_patterns(m)
m_ho_patterns(m),
m_ho_qs(m)
{
}
@ -383,11 +386,15 @@ namespace euf {
void operator()(expr* pat, expr* t, unsigned num_bound, unsigned num_vars);
app* compile_ho_pattern(quantifier* q, app* p);
quantifier* compile_ho_pattern(quantifier* q, app*& p);
bool is_ho_pattern(app* p);
void refine_ho_match(app* p, expr_ref_vector const& s);
void refine_ho_match(app* p, expr_ref_vector& s);
bool is_free(app* p, unsigned i) const { return m_hopat2free_vars[p].contains(i); }
quantifier* hoq2q(quantifier* q) const { return m_hoq2q[q]; }
};
}

View file

@ -70,7 +70,8 @@ namespace euf {
m_canonical_proofs(m),
// m_infer_patterns(m, m_smt_params),
m_deps(m),
m_rewriter(m) {
m_rewriter(m),
m_matcher(m, m_trail) {
m_tt = m_egraph.mk(m.mk_true(), 0, 0, nullptr);
m_ff = m_egraph.mk(m.mk_false(), 0, 0, nullptr);
m_rewriter.set_order_eq(true);
@ -92,6 +93,39 @@ namespace euf {
m_egraph.add_plugin(alloc(arith_plugin, m_egraph));
m_egraph.add_plugin(alloc(bv_plugin, m_egraph));
std::function<void(ho_subst&)> on_match =
[&](ho_subst& s) {
IF_VERBOSE(1, s.display(verbose_stream() << "on-match\n") << "\n");
auto& b = *m_ho_binding;
auto* hopat = b.m_pattern;
auto* hoq = b.m_q;
auto* q = m_matcher.hoq2q(hoq);
// shrink binding
expr_ref_vector binding(m);
for (unsigned i = 0; i < s.size(); ++i)
binding.push_back(s.get(i));
binding.reverse();
if (binding.size() > q->get_num_decls()) {
bool change = true;
while (change) {
change = false;
for (unsigned i = binding.size(); i-- > 0;) {
var_subst sub(m, false);
auto r = sub(binding.get(i), binding);
change |= r != binding.get(i);
binding[i] = r;
}
}
}
binding.shrink(q->get_num_decls());
binding.reverse();
IF_VERBOSE(1, verbose_stream() << binding << "\n");
apply_binding(b, q, binding);
};
m_matcher.set_on_match(on_match);
}
completion::~completion() {
@ -108,6 +142,7 @@ namespace euf {
void completion::updt_params(params_ref const& p) {
smt_params_helper sp(p);
m_max_instantiations = sp.qi_max_instances();
// m_max_generation = sp.qi_max_generation();
}
struct completion::push_watch_rule : public trail {
@ -222,6 +257,7 @@ namespace euf {
void completion::add_constraint(expr* f, proof* pr, expr_dependency* d) {
if (m_egraph.inconsistent())
return;
TRACE(euf_completion, tout << mk_pp(f, m) << "\n");
auto add_children = [&](enode* n) {
for (auto* ch : enode_args(n))
m_nodes_to_canonize.push_back(ch);
@ -234,12 +270,14 @@ namespace euf {
m_egraph.merge(a, b, to_ptr(push_pr_dep(pr, d)));
add_children(a);
add_children(b);
m_should_propagate = true;
}
else if (m.is_not(f, f)) {
enode* n = mk_enode(f);
auto j = to_ptr(push_pr_dep(pr, d));
m_egraph.new_diseq(n, j);
add_children(n);
m_should_propagate = true;
}
else {
enode* n = mk_enode(f);
@ -255,13 +293,15 @@ namespace euf {
q = to_quantifier(tmp);
}
#endif
ptr_vector<app> ground;
for (unsigned i = 0; i < q->get_num_patterns(); ++i) {
auto p = to_app(q->get_pattern(i));
auto q1 = m_matcher.compile_ho_pattern(q, p);
ptr_vector<app> ground;
mam::ground_subterms(p, ground);
for (expr* g : ground)
mk_enode(g);
m_mam->add_pattern(q, p);
m_mam->add_pattern(q1, p);
}
m_q2dep.insert(q, { pr, d});
get_trail().push(insert_obj_map(m_q2dep, q));
@ -295,7 +335,7 @@ namespace euf {
if (m.is_true(n->get_root()->get_expr()))
return l_false;
}
if (m_side_condition_solver) {
if (m_side_condition_solver && m_propagate_with_solver) {
expr_dependency* sd = nullptr;
if (m_side_condition_solver->is_true(f, pr, sd)) {
add_constraint(f, pr, sd);
@ -363,6 +403,7 @@ namespace euf {
}
void completion::propagate_all_rules() {
flet<bool> _propagate_with_solver(m_propagate_with_solver, true);
for (auto* r : m_rules)
if (!r->m_in_queue)
r->m_in_queue = true,
@ -456,6 +497,8 @@ namespace euf {
void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned max_global, unsigned min_top, unsigned max_top) {
if (should_stop())
return;
if (max_top >= m_max_generation)
return;
auto* b = alloc_binding(q, pat, binding, max_global, min_top, max_top);
if (!b)
return;
@ -487,23 +530,21 @@ namespace euf {
void completion::apply_binding(binding& b) {
if (should_stop())
return;
#if 0
if (is_ho_binding(b))
apply_ho_binding(b);
else
#endif
{
expr_ref_vector _binding(m);
quantifier* q = b.m_q;
for (unsigned i = 0; i < q->get_num_decls(); ++i)
_binding.push_back(b.m_nodes[i]->get_expr());
apply_binding(b, _binding);
expr_ref_vector _binding(m);
quantifier* q = b.m_q;
for (unsigned i = 0; i < q->get_num_decls(); ++i)
_binding.push_back(b.m_nodes[i]->get_expr());
if (m_matcher.is_ho_pattern(b.m_pattern)) {
flet<binding*> set_binding(m_ho_binding, &b);
m_matcher.refine_ho_match(b.m_pattern, _binding);
}
else
apply_binding(b, q, _binding);
}
void completion::apply_binding(binding& b, expr_ref_vector const& s) {
void completion::apply_binding(binding& b, quantifier* q, expr_ref_vector const& s) {
var_subst subst(m);
quantifier* q = b.m_q;
expr_ref r = subst(q->get_expr(), s);
scoped_generation sg(*this, b.m_max_top_generation + 1);
auto [pr, d] = get_dependency(q);
@ -512,10 +553,8 @@ namespace euf {
add_constraint(r, pr, d);
propagate_rules();
m_egraph.propagate();
m_should_propagate = true;
}
void completion::read_egraph() {
if (m_egraph.inconsistent()) {
auto* d = explain_conflict();

View file

@ -25,6 +25,7 @@ Author:
#include "ast/simplifiers/dependent_expr_state.h"
#include "ast/euf/euf_egraph.h"
#include "ast/euf/euf_mam.h"
#include "ast/euf/ho_matcher.h"
#include "ast/rewriter/th_rewriter.h"
// include "ast/pattern/pattern_inference.h"
#include "params/smt_params.h"
@ -133,18 +134,22 @@ namespace euf {
bindings m_bindings;
scoped_ptr<binding> m_tmp_binding;
unsigned m_tmp_binding_capacity = 0;
binding* m_ho_binding = nullptr;
expr_dependency_ref_vector m_deps;
obj_map<quantifier, std::pair<proof*, expr_dependency*>> m_q2dep;
vector<std::pair<proof_ref, expr_dependency*>> m_pr_dep;
unsigned m_epoch = 0;
unsigned_vector m_epochs;
th_rewriter m_rewriter;
ho_matcher m_matcher;
stats m_stats;
scoped_ptr<side_condition_solver> m_side_condition_solver;
ptr_vector<conditional_rule> m_rules;
bool m_has_new_eq = false;
bool m_should_propagate = false;
bool m_propagate_with_solver = false;
unsigned m_max_instantiations = std::numeric_limits<unsigned>::max();
unsigned m_max_generation = 10;
unsigned m_generation = 0;
vector<ptr_vector<conditional_rule>> m_rule_watch;
@ -176,7 +181,7 @@ namespace euf {
binding* alloc_binding(quantifier* q, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top);
void insert_binding(binding* b);
void apply_binding(binding& b);
void apply_binding(binding& b, expr_ref_vector const& s);
void apply_binding(binding& b, quantifier* q, expr_ref_vector const& s);
void flush_binding_queue();
vector<ptr_vector<binding>> m_queue;

View file

@ -88,6 +88,8 @@ public:
expr_ref_vector core(m);
m_solver->get_unsat_core(core);
for (auto c : core) {
if (c == nf)
continue;
auto [pr, dep] = m_e2d[c];
d = m.mk_join(d, dep);
}
@ -96,6 +98,8 @@ public:
SASSERT(pr);
expr_safe_replace rep(m);
for (auto c : core) {
if (c == nf)
continue;
auto [p, dep] = m_e2d[c];
rep.insert(m.mk_asserted(c), p);
}

View file

@ -23,7 +23,7 @@ namespace euf {
m_f = m.mk_func_decl(symbol("f"), m_int, m_int, m_int);
std::function<void(ho_subst& s)> on_match = [&](ho_subst& s) {
s.display(verbose_stream() << "match\n");
s.display(verbose_stream() << "match\n");
};
m_matcher.set_on_match(on_match);