3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-05-25 11:26:21 +00:00

enable higher-order matching in mam/smt_quantifier

This commit is contained in:
Nikolaj Bjorner 2026-05-22 17:06:37 -07:00
parent 98d0e7f27c
commit ea0964d195
3 changed files with 162 additions and 6 deletions

View file

@ -19,6 +19,8 @@ Revision History:
#include "ast/ast_pp.h"
#include "ast/ast_ll_pp.h"
#include "ast/quantifier_stat.h"
#include "ast/euf/ho_matcher.h"
#include "ast/rewriter/var_subst.h"
#include "smt/smt_quantifier.h"
#include "smt/smt_context.h"
#include "smt/smt_model_finder.h"
@ -154,7 +156,8 @@ namespace smt {
}
unsigned get_generation(quantifier * q) const {
return get_stat(q)->get_generation();
auto* s = m_quantifier_stat.find_core(q);
return s ? s->get_data().get_value()->get_generation() : 0;
}
void add(quantifier * q, unsigned generation) {
@ -295,6 +298,15 @@ namespace smt {
unsigned max_top_generation,
vector<std::tuple<enode *, enode *>> & used_enodes) {
// Try higher-order refinement first
if (pat && m_plugin->refine_instance(q, pat, num_bindings, bindings, max_generation, min_top_generation, max_top_generation, used_enodes))
return true;
if (!m_quantifier_stat.contains(q)) {
IF_VERBOSE(2, verbose_stream() << "add_instance: quantifier not in stat map: " << mk_pp(q, m()) << "\n");
return false;
}
max_generation = std::max(max_generation, get_generation(q));
get_stat(q)->update_max_generation(max_generation);
@ -599,9 +611,24 @@ namespace smt {
scoped_ptr<mam> m_lazy_mam;
scoped_ptr<model_finder> m_model_finder;
scoped_ptr<model_checker> m_model_checker;
scoped_ptr<euf::ho_matcher> m_ho_matcher;
trail_stack m_ho_trail;
unsigned m_new_enode_qhead;
unsigned m_lazy_matching_idx;
bool m_active;
// State for higher-order match refinement callback
struct ho_match_state {
quantifier* m_q = nullptr;
app* m_pat = nullptr;
unsigned m_num_bindings = 0;
enode* const* m_bindings = nullptr;
unsigned m_max_generation = 0;
unsigned m_min_top_generation = 0;
unsigned m_max_top_generation = 0;
vector<std::tuple<enode*, enode*>>* m_used_enodes = nullptr;
};
ho_match_state m_ho_state;
public:
default_qm_plugin():
m_qm(nullptr),
@ -625,10 +652,104 @@ namespace smt {
m_model_finder->set_context(m_context);
m_model_checker->set_qm(qm);
if (m_fparams->m_ho_matching) {
m_ho_matcher = alloc(euf::ho_matcher, m, m_ho_trail);
std::function<void(euf::ho_subst&)> on_match = [&](euf::ho_subst& s) {
on_ho_match(s);
};
m_ho_matcher->set_on_match(on_match);
}
}
quantifier_manager_plugin * mk_fresh() override { return alloc(default_qm_plugin); }
void on_ho_match(euf::ho_subst& s) {
ast_manager& m = m_context->get_manager();
auto& st = m_ho_state;
auto* hoq = st.m_q;
auto* q = m_ho_matcher->hoq2q(hoq);
expr_ref_vector binding(m);
for (unsigned i = 0; i < s.size(); ++i)
binding.push_back(s.get(i));
// Shrink binding to original quantifier's num_decls
// The HO quantifier has extra vars at higher indices; drop them.
// Binding is indexed by var index: binding[i] = value for var i.
// First substitute any remaining vars, then keep only original vars.
if (binding.size() > q->get_num_decls()) {
var_subst sub(m);
bool change = true;
while (change) {
change = false;
for (unsigned i = 0; i < binding.size(); ++i) {
if (!binding.get(i)) continue;
auto r = sub(binding.get(i), binding);
change |= r != binding.get(i);
binding[i] = r;
}
}
binding.shrink(q->get_num_decls());
}
// Create enodes for the refined bindings and add instance
ptr_buffer<enode> new_bindings;
unsigned max_gen = st.m_max_generation;
for (unsigned i = 0; i < q->get_num_decls(); ++i) {
expr* e = binding.get(i);
if (!e)
return; // incomplete binding
if (!m_context->e_internalized(e)) {
m_context->internalize(e, false);
}
enode* n = m_context->get_enode(e);
new_bindings.push_back(n);
if (n->get_generation() > max_gen)
max_gen = n->get_generation();
}
TRACE(ho_matching,
tout << "ho_match refined for " << mk_pp(q, m) << "\n";
for (unsigned i = 0; i < new_bindings.size(); ++i)
tout << " binding[" << i << "] = " << mk_pp(new_bindings[i]->get_expr(), m) << "\n";);
vector<std::tuple<enode*, enode*>> used_enodes;
m_context->add_instance(q, nullptr, new_bindings.size(), new_bindings.data(),
nullptr, max_gen, st.m_min_top_generation, st.m_max_top_generation, used_enodes);
}
bool try_ho_refine(quantifier* qa, app* pat, unsigned num_bindings, enode* const* bindings,
unsigned max_generation, unsigned min_top_gen, unsigned max_top_gen,
vector<std::tuple<enode*, enode*>>& used_enodes) {
if (!m_ho_matcher || !m_ho_matcher->is_ho_pattern(pat))
return false;
ast_manager& m = m_context->get_manager();
expr_ref_vector s(m);
// With var_subst(std_order=true): var idx maps to s[s.size()-idx-1]
// SMT MAM bindings: bindings[i] = var at index (num_bindings-1-i)
// So bindings[i] corresponds to s[i] with std_order
for (unsigned i = 0; i < num_bindings; ++i)
s.push_back(bindings[i]->get_expr());
m_ho_state.m_q = qa;
m_ho_state.m_pat = pat;
m_ho_state.m_num_bindings = num_bindings;
m_ho_state.m_bindings = bindings;
m_ho_state.m_max_generation = max_generation;
m_ho_state.m_min_top_generation = min_top_gen;
m_ho_state.m_max_top_generation = max_top_gen;
m_ho_state.m_used_enodes = &used_enodes;
IF_VERBOSE(10, verbose_stream() << "try_ho_refine: q=" << mk_pp(qa, m) << "\n pat=" << mk_pp(pat, m) << "\n";
for (unsigned i = 0; i < num_bindings; ++i)
verbose_stream() << " s[" << i << "] = " << mk_pp(s.get(i), m) << " sort=" << mk_pp(s.get(i)->get_sort(), m) << "\n";);
m_ho_matcher->refine_ho_match(pat, s);
return true;
}
bool model_based() const override { return m_fparams->m_mbqi; }
bool mbqi_enabled(quantifier *q) const override {
@ -704,6 +825,19 @@ namespace smt {
TRACE(quantifier, tout << "adding:\n" << expr_ref(mp, m) << "\n";);
m_mam->add_pattern(q, mp);
}
// Compile HO pattern and also register the compiled version with MAM
if (m_ho_matcher) {
auto [q1, p1] = m_ho_matcher->compile_ho_pattern(q, mp);
IF_VERBOSE(10, verbose_stream() << "ho_matching: q=" << q->get_qid()
<< " compiled=" << (p1 != mp)
<< " p1=" << mk_pp(p1, m) << "\n");
if (p1 != mp) {
if (!unary && j >= num_eager_multi_patterns)
m_lazy_mam->add_pattern(q1, p1);
else
m_mam->add_pattern(q1, p1);
}
}
if (!unary)
j++;
}
@ -713,6 +847,13 @@ namespace smt {
return m_fparams->m_ematching && !m_qm->empty();
}
bool refine_instance(quantifier* q, app* pat, unsigned num_bindings, enode* const* bindings,
unsigned max_generation, unsigned min_top_generation, unsigned max_top_generation,
vector<std::tuple<enode*, enode*>>& used_enodes) override {
return try_ho_refine(q, pat, num_bindings, bindings, max_generation, min_top_generation, max_top_generation, used_enodes);
}
void add_eq_eh(enode * e1, enode * e2) override {
if (use_ematching())
m_mam->add_eq_eh(e1, e2);
@ -726,7 +867,9 @@ namespace smt {
}
bool can_propagate() const override {
return m_active && m_mam->has_work();
bool r = m_active && m_mam->has_work();
IF_VERBOSE(11, if (r) verbose_stream() << "ho_matching: can_propagate=true\n");
return r;
}
void restart_eh() override {
@ -750,6 +893,7 @@ namespace smt {
void propagate() override {
if (!m_active)
return;
IF_VERBOSE(10, verbose_stream() << "ho_matching: propagate(), mam.has_work=" << m_mam->has_work() << "\n");
m_mam->match();
if (!m_context->relevancy() && use_ematching()) {
ptr_vector<enode>::const_iterator it = m_context->begin_enodes();