mirror of
https://github.com/Z3Prover/z3
synced 2026-05-25 11:26:21 +00:00
enable higher-order matching in mam/smt_quantifier
This commit is contained in:
parent
98d0e7f27c
commit
ea0964d195
3 changed files with 162 additions and 6 deletions
|
|
@ -19,6 +19,8 @@ Revision History:
|
|||
#include "ast/ast_pp.h"
|
||||
#include "ast/ast_ll_pp.h"
|
||||
#include "ast/quantifier_stat.h"
|
||||
#include "ast/euf/ho_matcher.h"
|
||||
#include "ast/rewriter/var_subst.h"
|
||||
#include "smt/smt_quantifier.h"
|
||||
#include "smt/smt_context.h"
|
||||
#include "smt/smt_model_finder.h"
|
||||
|
|
@ -154,7 +156,8 @@ namespace smt {
|
|||
}
|
||||
|
||||
unsigned get_generation(quantifier * q) const {
|
||||
return get_stat(q)->get_generation();
|
||||
auto* s = m_quantifier_stat.find_core(q);
|
||||
return s ? s->get_data().get_value()->get_generation() : 0;
|
||||
}
|
||||
|
||||
void add(quantifier * q, unsigned generation) {
|
||||
|
|
@ -295,6 +298,15 @@ namespace smt {
|
|||
unsigned max_top_generation,
|
||||
vector<std::tuple<enode *, enode *>> & used_enodes) {
|
||||
|
||||
// Try higher-order refinement first
|
||||
if (pat && m_plugin->refine_instance(q, pat, num_bindings, bindings, max_generation, min_top_generation, max_top_generation, used_enodes))
|
||||
return true;
|
||||
|
||||
if (!m_quantifier_stat.contains(q)) {
|
||||
IF_VERBOSE(2, verbose_stream() << "add_instance: quantifier not in stat map: " << mk_pp(q, m()) << "\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
max_generation = std::max(max_generation, get_generation(q));
|
||||
|
||||
get_stat(q)->update_max_generation(max_generation);
|
||||
|
|
@ -599,9 +611,24 @@ namespace smt {
|
|||
scoped_ptr<mam> m_lazy_mam;
|
||||
scoped_ptr<model_finder> m_model_finder;
|
||||
scoped_ptr<model_checker> m_model_checker;
|
||||
scoped_ptr<euf::ho_matcher> m_ho_matcher;
|
||||
trail_stack m_ho_trail;
|
||||
unsigned m_new_enode_qhead;
|
||||
unsigned m_lazy_matching_idx;
|
||||
bool m_active;
|
||||
|
||||
// State for higher-order match refinement callback
|
||||
struct ho_match_state {
|
||||
quantifier* m_q = nullptr;
|
||||
app* m_pat = nullptr;
|
||||
unsigned m_num_bindings = 0;
|
||||
enode* const* m_bindings = nullptr;
|
||||
unsigned m_max_generation = 0;
|
||||
unsigned m_min_top_generation = 0;
|
||||
unsigned m_max_top_generation = 0;
|
||||
vector<std::tuple<enode*, enode*>>* m_used_enodes = nullptr;
|
||||
};
|
||||
ho_match_state m_ho_state;
|
||||
public:
|
||||
default_qm_plugin():
|
||||
m_qm(nullptr),
|
||||
|
|
@ -625,10 +652,104 @@ namespace smt {
|
|||
|
||||
m_model_finder->set_context(m_context);
|
||||
m_model_checker->set_qm(qm);
|
||||
|
||||
if (m_fparams->m_ho_matching) {
|
||||
m_ho_matcher = alloc(euf::ho_matcher, m, m_ho_trail);
|
||||
std::function<void(euf::ho_subst&)> on_match = [&](euf::ho_subst& s) {
|
||||
on_ho_match(s);
|
||||
};
|
||||
m_ho_matcher->set_on_match(on_match);
|
||||
}
|
||||
}
|
||||
|
||||
quantifier_manager_plugin * mk_fresh() override { return alloc(default_qm_plugin); }
|
||||
|
||||
void on_ho_match(euf::ho_subst& s) {
|
||||
ast_manager& m = m_context->get_manager();
|
||||
auto& st = m_ho_state;
|
||||
auto* hoq = st.m_q;
|
||||
auto* q = m_ho_matcher->hoq2q(hoq);
|
||||
|
||||
expr_ref_vector binding(m);
|
||||
for (unsigned i = 0; i < s.size(); ++i)
|
||||
binding.push_back(s.get(i));
|
||||
|
||||
// Shrink binding to original quantifier's num_decls
|
||||
// The HO quantifier has extra vars at higher indices; drop them.
|
||||
// Binding is indexed by var index: binding[i] = value for var i.
|
||||
// First substitute any remaining vars, then keep only original vars.
|
||||
if (binding.size() > q->get_num_decls()) {
|
||||
var_subst sub(m);
|
||||
bool change = true;
|
||||
while (change) {
|
||||
change = false;
|
||||
for (unsigned i = 0; i < binding.size(); ++i) {
|
||||
if (!binding.get(i)) continue;
|
||||
auto r = sub(binding.get(i), binding);
|
||||
change |= r != binding.get(i);
|
||||
binding[i] = r;
|
||||
}
|
||||
}
|
||||
binding.shrink(q->get_num_decls());
|
||||
}
|
||||
|
||||
// Create enodes for the refined bindings and add instance
|
||||
ptr_buffer<enode> new_bindings;
|
||||
unsigned max_gen = st.m_max_generation;
|
||||
for (unsigned i = 0; i < q->get_num_decls(); ++i) {
|
||||
expr* e = binding.get(i);
|
||||
if (!e)
|
||||
return; // incomplete binding
|
||||
if (!m_context->e_internalized(e)) {
|
||||
m_context->internalize(e, false);
|
||||
}
|
||||
enode* n = m_context->get_enode(e);
|
||||
new_bindings.push_back(n);
|
||||
if (n->get_generation() > max_gen)
|
||||
max_gen = n->get_generation();
|
||||
}
|
||||
|
||||
TRACE(ho_matching,
|
||||
tout << "ho_match refined for " << mk_pp(q, m) << "\n";
|
||||
for (unsigned i = 0; i < new_bindings.size(); ++i)
|
||||
tout << " binding[" << i << "] = " << mk_pp(new_bindings[i]->get_expr(), m) << "\n";);
|
||||
|
||||
vector<std::tuple<enode*, enode*>> used_enodes;
|
||||
m_context->add_instance(q, nullptr, new_bindings.size(), new_bindings.data(),
|
||||
nullptr, max_gen, st.m_min_top_generation, st.m_max_top_generation, used_enodes);
|
||||
}
|
||||
|
||||
bool try_ho_refine(quantifier* qa, app* pat, unsigned num_bindings, enode* const* bindings,
|
||||
unsigned max_generation, unsigned min_top_gen, unsigned max_top_gen,
|
||||
vector<std::tuple<enode*, enode*>>& used_enodes) {
|
||||
if (!m_ho_matcher || !m_ho_matcher->is_ho_pattern(pat))
|
||||
return false;
|
||||
|
||||
ast_manager& m = m_context->get_manager();
|
||||
expr_ref_vector s(m);
|
||||
// With var_subst(std_order=true): var idx maps to s[s.size()-idx-1]
|
||||
// SMT MAM bindings: bindings[i] = var at index (num_bindings-1-i)
|
||||
// So bindings[i] corresponds to s[i] with std_order
|
||||
for (unsigned i = 0; i < num_bindings; ++i)
|
||||
s.push_back(bindings[i]->get_expr());
|
||||
|
||||
m_ho_state.m_q = qa;
|
||||
m_ho_state.m_pat = pat;
|
||||
m_ho_state.m_num_bindings = num_bindings;
|
||||
m_ho_state.m_bindings = bindings;
|
||||
m_ho_state.m_max_generation = max_generation;
|
||||
m_ho_state.m_min_top_generation = min_top_gen;
|
||||
m_ho_state.m_max_top_generation = max_top_gen;
|
||||
m_ho_state.m_used_enodes = &used_enodes;
|
||||
|
||||
IF_VERBOSE(10, verbose_stream() << "try_ho_refine: q=" << mk_pp(qa, m) << "\n pat=" << mk_pp(pat, m) << "\n";
|
||||
for (unsigned i = 0; i < num_bindings; ++i)
|
||||
verbose_stream() << " s[" << i << "] = " << mk_pp(s.get(i), m) << " sort=" << mk_pp(s.get(i)->get_sort(), m) << "\n";);
|
||||
|
||||
m_ho_matcher->refine_ho_match(pat, s);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool model_based() const override { return m_fparams->m_mbqi; }
|
||||
|
||||
bool mbqi_enabled(quantifier *q) const override {
|
||||
|
|
@ -704,6 +825,19 @@ namespace smt {
|
|||
TRACE(quantifier, tout << "adding:\n" << expr_ref(mp, m) << "\n";);
|
||||
m_mam->add_pattern(q, mp);
|
||||
}
|
||||
// Compile HO pattern and also register the compiled version with MAM
|
||||
if (m_ho_matcher) {
|
||||
auto [q1, p1] = m_ho_matcher->compile_ho_pattern(q, mp);
|
||||
IF_VERBOSE(10, verbose_stream() << "ho_matching: q=" << q->get_qid()
|
||||
<< " compiled=" << (p1 != mp)
|
||||
<< " p1=" << mk_pp(p1, m) << "\n");
|
||||
if (p1 != mp) {
|
||||
if (!unary && j >= num_eager_multi_patterns)
|
||||
m_lazy_mam->add_pattern(q1, p1);
|
||||
else
|
||||
m_mam->add_pattern(q1, p1);
|
||||
}
|
||||
}
|
||||
if (!unary)
|
||||
j++;
|
||||
}
|
||||
|
|
@ -713,6 +847,13 @@ namespace smt {
|
|||
return m_fparams->m_ematching && !m_qm->empty();
|
||||
}
|
||||
|
||||
|
||||
bool refine_instance(quantifier* q, app* pat, unsigned num_bindings, enode* const* bindings,
|
||||
unsigned max_generation, unsigned min_top_generation, unsigned max_top_generation,
|
||||
vector<std::tuple<enode*, enode*>>& used_enodes) override {
|
||||
return try_ho_refine(q, pat, num_bindings, bindings, max_generation, min_top_generation, max_top_generation, used_enodes);
|
||||
}
|
||||
|
||||
void add_eq_eh(enode * e1, enode * e2) override {
|
||||
if (use_ematching())
|
||||
m_mam->add_eq_eh(e1, e2);
|
||||
|
|
@ -726,7 +867,9 @@ namespace smt {
|
|||
}
|
||||
|
||||
bool can_propagate() const override {
|
||||
return m_active && m_mam->has_work();
|
||||
bool r = m_active && m_mam->has_work();
|
||||
IF_VERBOSE(11, if (r) verbose_stream() << "ho_matching: can_propagate=true\n");
|
||||
return r;
|
||||
}
|
||||
|
||||
void restart_eh() override {
|
||||
|
|
@ -750,6 +893,7 @@ namespace smt {
|
|||
void propagate() override {
|
||||
if (!m_active)
|
||||
return;
|
||||
IF_VERBOSE(10, verbose_stream() << "ho_matching: propagate(), mam.has_work=" << m_mam->has_work() << "\n");
|
||||
m_mam->match();
|
||||
if (!m_context->relevancy() && use_ematching()) {
|
||||
ptr_vector<enode>::const_iterator it = m_context->begin_enodes();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue