diff --git a/src/ast/euf/ho_matcher.cpp b/src/ast/euf/ho_matcher.cpp index f86d82b64..20b913b17 100644 --- a/src/ast/euf/ho_matcher.cpp +++ b/src/ast/euf/ho_matcher.cpp @@ -242,17 +242,72 @@ namespace euf { // We assume that m_rewriter should produce // something amounting to weak-head normal form WHNF + // Unfold a lambda-def application f(args) to the corresponding lambda expression. + // For a func_decl f with arity n and lambda-def quantifier (lambda (x1..xk) body), + // f(a1,...,an) is unfolded to (lambda (x1..xk) body[params := a1..an]). + // For a constant f (arity 0) that is a lambda-def, returns the lambda directly. + expr_ref ho_matcher::unfold_lambda_def(expr* e) const { + if (!is_app(e)) + return expr_ref(e, m); + app* a = to_app(e); + func_decl* f = a->get_decl(); + quantifier* lam = m.is_lambda_def(f); + if (!lam) + return expr_ref(e, m); + + unsigned arity = f->get_arity(); + SASSERT(is_lambda(lam)); + + if (arity == 0) { + // Constant lambda-def: just return the lambda expression + return expr_ref(lam, m); + } + + // f(a1,...,an) where lam = lambda (x1..xk) body + // The lambda body uses var(0)..var(k-1) for the lambda-bound vars + // and var(k)..var(k+n-1) for the function parameters. + // We substitute the function parameters with the actual arguments. + unsigned num_lambda_decls = lam->get_num_decls(); + expr_ref body(lam->get_expr(), m); + + // Build substitution: var(num_lambda_decls + i) -> a->get_arg(arity - 1 - i) for i in [0, arity) + // var_subst replaces var(i) with subst[i] + expr_ref_vector subst(m); + subst.resize(num_lambda_decls + arity); + for (unsigned i = 0; i < num_lambda_decls; ++i) + subst[i] = m.mk_var(i, lam->get_decl_sort(num_lambda_decls - 1 - i)); + for (unsigned i = 0; i < arity; ++i) + subst[num_lambda_decls + i] = a->get_arg(arity - 1 - i); + + var_subst vs(m, false); + body = vs(lam->get_expr(), subst); + + // Rebuild the lambda with the new body + ptr_buffer sorts; + vector names; + for (unsigned i = 0; i < num_lambda_decls; ++i) { + sorts.push_back(lam->get_decl_sort(i)); + names.push_back(lam->get_decl_name(i)); + } + return expr_ref(m.mk_lambda(num_lambda_decls, sorts.data(), names.data(), body), m); + } + void ho_matcher::reduce(match_goal& wi) { while (true) { expr_ref r = whnf(wi.pat, wi.pat_offset()); + if (r == wi.pat) + r = unfold_lambda_def(wi.pat); if (r == wi.pat) break; IF_VERBOSE(3, verbose_stream() << "ho_matcher::reduce: " << wi.pat << " -> " << r << "\n";); wi.pat = r; } + while (true) { expr_ref r = whnf(wi.t, wi.term_offset()); + if (r == wi.t) + r = unfold_lambda_def(wi.t); if (r == wi.t) break; IF_VERBOSE(3, verbose_stream() << "ho_matcher::reduce: " << wi.t << " -> " << r << "\n";); @@ -656,16 +711,17 @@ namespace euf { todo.pop_back(); continue; } - if (m_unitary.is_flex(0, t)) { + if (m_unitary.is_flex(0, t) || (is_app(t) && m.is_lambda_def(to_app(t)->get_decl())) || is_lambda(t)) { m_pat2abs.insert_if_not_there(p, svector>()).push_back({ nb, t }); auto v = m.mk_var(nb++, t->get_sort()); bound.push_back(v); cache.setx(t->get_id(), v); todo.pop_back(); continue; - } + } if (is_app(t)) { auto a = to_app(t); + unsigned sz = a->get_num_args(); ptr_buffer args; for (auto arg : *a) { @@ -737,7 +793,7 @@ namespace euf { trail().push(insert_map(m_pat2abs, p)); trail().push(insert_map(m_q2hoq, q)); trail().push(insert_map(m_hoq2q, q1)); - trail().push(insert_map(m_hopat2free_vars, p1)); + trail().push(insert_map(m_hopat2free_vars, p1)); return { q1, p1 }; } diff --git a/src/ast/euf/ho_matcher.h b/src/ast/euf/ho_matcher.h index 65477078c..434900fa2 100644 --- a/src/ast/euf/ho_matcher.h +++ b/src/ast/euf/ho_matcher.h @@ -350,6 +350,8 @@ namespace euf { void reduce(match_goal& wi); + expr_ref unfold_lambda_def(expr* e) const; + trail_stack& trail() { return m_trail; } std::ostream& display(std::ostream& out) const; @@ -395,5 +397,13 @@ namespace euf { quantifier* hoq2q(quantifier* q) const { return m_hoq2q[q]; } + + svector> const* get_flex_subterms(app* p) const { + auto orig_p = m_hopat2pat.find_core(p); + if (!orig_p) return nullptr; + auto abs = m_pat2abs.find_core(orig_p->get_data().get_value()); + return abs ? &abs->get_data().get_value() : nullptr; + } + }; } diff --git a/src/ast/pattern/pattern_inference.cpp b/src/ast/pattern/pattern_inference.cpp index 32b0546e8..8ec88096a 100644 --- a/src/ast/pattern/pattern_inference.cpp +++ b/src/ast/pattern/pattern_inference.cpp @@ -546,13 +546,13 @@ void pattern_inference_cfg::reset_pre_patterns() { bool pattern_inference_cfg::is_forbidden(app * n) const { - func_decl const * decl = n->get_decl(); + func_decl * decl = n->get_decl(); if (is_ground(n)) return false; // Remark: skolem constants should not be used in patterns, since they do not // occur outside of the quantifier. That is, Z3 will never match this kind of // pattern. - if (m_params.m_pi_avoid_skolems && decl->is_skolem()) { + if (m_params.m_pi_avoid_skolems && decl->is_skolem() && !m.is_lambda_def(decl)) { CTRACE(pattern_inference_skolem, decl->is_skolem(), tout << "ignoring: " << mk_pp(n, m) << "\n";); return true; } diff --git a/src/params/smt_params.cpp b/src/params/smt_params.cpp index a80483d0f..a64075204 100644 --- a/src/params/smt_params.cpp +++ b/src/params/smt_params.cpp @@ -27,6 +27,7 @@ void smt_params::updt_local_params(params_ref const & _p) { m_random_seed = p.random_seed(); m_relevancy_lvl = p.relevancy(); m_ematching = p.ematching(); + m_ho_matching = p.ho_matching(); m_induction = p.induction(); m_clause_proof = p.clause_proof(); m_phase_selection = static_cast(p.phase_selection()); diff --git a/src/params/smt_params.h b/src/params/smt_params.h index 68ab50ffe..7b5efa991 100644 --- a/src/params/smt_params.h +++ b/src/params/smt_params.h @@ -109,6 +109,7 @@ struct smt_params : public preprocessor_params, bool m_display_features = false; bool m_new_core2th_eq = true; bool m_ematching = true; + bool m_ho_matching = false; bool m_induction = false; bool m_clause_proof = false; symbol m_proof_log; diff --git a/src/params/smt_params_helper.pyg b/src/params/smt_params_helper.pyg index 6c8ce796d..cd52b989c 100644 --- a/src/params/smt_params_helper.pyg +++ b/src/params/smt_params_helper.pyg @@ -10,6 +10,7 @@ def_module_params(module_name='smt', ('quasi_macros', BOOL, False, 'try to find universally quantified formulas that are quasi-macros'), ('restricted_quasi_macros', BOOL, False, 'try to find universally quantified formulas that are restricted quasi-macros'), ('ematching', BOOL, True, 'E-Matching based quantifier instantiation'), + ('ho_matching', BOOL, False, 'higher-order matching for quantifier instantiation'), ('phase_selection', UINT, 3, 'phase selection heuristic: 0 - always false, 1 - always true, 2 - phase caching, 3 - phase caching conservative, 4 - phase caching conservative 2, 5 - random, 6 - number of occurrences, 7 - theory'), ('phase_caching_on', UINT, 400, 'number of conflicts while phase caching is on'), ('phase_caching_off', UINT, 100, 'number of conflicts while phase caching is off'), diff --git a/src/util/trace_tags.def b/src/util/trace_tags.def index 67adb62c9..d343b2304 100644 --- a/src/util/trace_tags.def +++ b/src/util/trace_tags.def @@ -56,6 +56,7 @@ X(ctx_propagate_assertions, assert_eq_bug, "assert eq bug") X(ctx_solver_simplify_tactic, ctx_solver_simplify_tactic, "ctx solver simplify tactic") X(default_qm_plugin, default_qm_plugin, "default qm plugin") +X(default_qm_plugin, ho_matching, "ho matching") X(default_qm_plugin, mam_stats, "mam stats") X(default_qm_plugin, quantifier, "quantifier")