3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-05-31 22:27:48 +00:00

prepare for lambda unfolding in ho-matcher and selectively enable ho matching

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2026-05-22 13:24:19 -07:00
parent f40e4759e4
commit 19166bd0b5
7 changed files with 75 additions and 5 deletions

View file

@ -242,17 +242,72 @@ namespace euf {
// We assume that m_rewriter should produce
// something amounting to weak-head normal form WHNF
// Unfold a lambda-def application f(args) to the corresponding lambda expression.
// For a func_decl f with arity n and lambda-def quantifier (lambda (x1..xk) body),
// f(a1,...,an) is unfolded to (lambda (x1..xk) body[params := a1..an]).
// For a constant f (arity 0) that is a lambda-def, returns the lambda directly.
expr_ref ho_matcher::unfold_lambda_def(expr* e) const {
if (!is_app(e))
return expr_ref(e, m);
app* a = to_app(e);
func_decl* f = a->get_decl();
quantifier* lam = m.is_lambda_def(f);
if (!lam)
return expr_ref(e, m);
unsigned arity = f->get_arity();
SASSERT(is_lambda(lam));
if (arity == 0) {
// Constant lambda-def: just return the lambda expression
return expr_ref(lam, m);
}
// f(a1,...,an) where lam = lambda (x1..xk) body
// The lambda body uses var(0)..var(k-1) for the lambda-bound vars
// and var(k)..var(k+n-1) for the function parameters.
// We substitute the function parameters with the actual arguments.
unsigned num_lambda_decls = lam->get_num_decls();
expr_ref body(lam->get_expr(), m);
// Build substitution: var(num_lambda_decls + i) -> a->get_arg(arity - 1 - i) for i in [0, arity)
// var_subst replaces var(i) with subst[i]
expr_ref_vector subst(m);
subst.resize(num_lambda_decls + arity);
for (unsigned i = 0; i < num_lambda_decls; ++i)
subst[i] = m.mk_var(i, lam->get_decl_sort(num_lambda_decls - 1 - i));
for (unsigned i = 0; i < arity; ++i)
subst[num_lambda_decls + i] = a->get_arg(arity - 1 - i);
var_subst vs(m, false);
body = vs(lam->get_expr(), subst);
// Rebuild the lambda with the new body
ptr_buffer<sort> sorts;
vector<symbol> names;
for (unsigned i = 0; i < num_lambda_decls; ++i) {
sorts.push_back(lam->get_decl_sort(i));
names.push_back(lam->get_decl_name(i));
}
return expr_ref(m.mk_lambda(num_lambda_decls, sorts.data(), names.data(), body), m);
}
void ho_matcher::reduce(match_goal& wi) {
while (true) {
expr_ref r = whnf(wi.pat, wi.pat_offset());
if (r == wi.pat)
r = unfold_lambda_def(wi.pat);
if (r == wi.pat)
break;
IF_VERBOSE(3, verbose_stream() << "ho_matcher::reduce: " << wi.pat << " -> " << r << "\n";);
wi.pat = r;
}
while (true) {
expr_ref r = whnf(wi.t, wi.term_offset());
if (r == wi.t)
r = unfold_lambda_def(wi.t);
if (r == wi.t)
break;
IF_VERBOSE(3, verbose_stream() << "ho_matcher::reduce: " << wi.t << " -> " << r << "\n";);
@ -656,16 +711,17 @@ namespace euf {
todo.pop_back();
continue;
}
if (m_unitary.is_flex(0, t)) {
if (m_unitary.is_flex(0, t) || (is_app(t) && m.is_lambda_def(to_app(t)->get_decl())) || is_lambda(t)) {
m_pat2abs.insert_if_not_there(p, svector<std::pair<unsigned, expr*>>()).push_back({ nb, t });
auto v = m.mk_var(nb++, t->get_sort());
bound.push_back(v);
cache.setx(t->get_id(), v);
todo.pop_back();
continue;
}
}
if (is_app(t)) {
auto a = to_app(t);
unsigned sz = a->get_num_args();
ptr_buffer<expr> args;
for (auto arg : *a) {
@ -737,7 +793,7 @@ namespace euf {
trail().push(insert_map(m_pat2abs, p));
trail().push(insert_map(m_q2hoq, q));
trail().push(insert_map(m_hoq2q, q1));
trail().push(insert_map(m_hopat2free_vars, p1));
trail().push(insert_map(m_hopat2free_vars, p1));
return { q1, p1 };
}

View file

@ -350,6 +350,8 @@ namespace euf {
void reduce(match_goal& wi);
expr_ref unfold_lambda_def(expr* e) const;
trail_stack& trail() { return m_trail; }
std::ostream& display(std::ostream& out) const;
@ -395,5 +397,13 @@ namespace euf {
quantifier* hoq2q(quantifier* q) const { return m_hoq2q[q]; }
svector<std::pair<unsigned, expr*>> const* get_flex_subterms(app* p) const {
auto orig_p = m_hopat2pat.find_core(p);
if (!orig_p) return nullptr;
auto abs = m_pat2abs.find_core(orig_p->get_data().get_value());
return abs ? &abs->get_data().get_value() : nullptr;
}
};
}