diff --git a/src/smt/mam.cpp b/src/smt/mam.cpp index 94f7ae22a..97a7a899c 100644 --- a/src/smt/mam.cpp +++ b/src/smt/mam.cpp @@ -3802,8 +3802,14 @@ namespace { for (unsigned i = 0; i < num_patterns; ++i) { app * pat = to_app(mp->get_arg(i)); TRACE(mam_pat, tout << mk_ismt2_pp(qa, m) << "\npat:\n" << mk_ismt2_pp(pat, m) << "\n";); - SASSERT(!pat->is_ground()); - todo.push_back(pat); + if (pat->is_ground()) { + enode * e = mk_enode(m_context, qa, pat); + m_context.mark_as_relevant(e); + m_context.push_trail(add_shared_enode_trail(*this, e)); + m_shared_enodes.insert(e); + } + else + todo.push_back(pat); } while (!todo.empty()) { app * n = todo.back(); diff --git a/src/smt/smt_quantifier.cpp b/src/smt/smt_quantifier.cpp index 9cd270f1d..4db888a2e 100644 --- a/src/smt/smt_quantifier.cpp +++ b/src/smt/smt_quantifier.cpp @@ -19,6 +19,8 @@ Revision History: #include "ast/ast_pp.h" #include "ast/ast_ll_pp.h" #include "ast/quantifier_stat.h" +#include "ast/euf/ho_matcher.h" +#include "ast/rewriter/var_subst.h" #include "smt/smt_quantifier.h" #include "smt/smt_context.h" #include "smt/smt_model_finder.h" @@ -154,7 +156,8 @@ namespace smt { } unsigned get_generation(quantifier * q) const { - return get_stat(q)->get_generation(); + auto* s = m_quantifier_stat.find_core(q); + return s ? s->get_data().get_value()->get_generation() : 0; } void add(quantifier * q, unsigned generation) { @@ -295,6 +298,15 @@ namespace smt { unsigned max_top_generation, vector> & used_enodes) { + // Try higher-order refinement first + if (pat && m_plugin->refine_instance(q, pat, num_bindings, bindings, max_generation, min_top_generation, max_top_generation, used_enodes)) + return true; + + if (!m_quantifier_stat.contains(q)) { + IF_VERBOSE(2, verbose_stream() << "add_instance: quantifier not in stat map: " << mk_pp(q, m()) << "\n"); + return false; + } + max_generation = std::max(max_generation, get_generation(q)); get_stat(q)->update_max_generation(max_generation); @@ -599,9 +611,24 @@ namespace smt { scoped_ptr m_lazy_mam; scoped_ptr m_model_finder; scoped_ptr m_model_checker; + scoped_ptr m_ho_matcher; + trail_stack m_ho_trail; unsigned m_new_enode_qhead; unsigned m_lazy_matching_idx; bool m_active; + + // State for higher-order match refinement callback + struct ho_match_state { + quantifier* m_q = nullptr; + app* m_pat = nullptr; + unsigned m_num_bindings = 0; + enode* const* m_bindings = nullptr; + unsigned m_max_generation = 0; + unsigned m_min_top_generation = 0; + unsigned m_max_top_generation = 0; + vector>* m_used_enodes = nullptr; + }; + ho_match_state m_ho_state; public: default_qm_plugin(): m_qm(nullptr), @@ -625,10 +652,104 @@ namespace smt { m_model_finder->set_context(m_context); m_model_checker->set_qm(qm); + + if (m_fparams->m_ho_matching) { + m_ho_matcher = alloc(euf::ho_matcher, m, m_ho_trail); + std::function on_match = [&](euf::ho_subst& s) { + on_ho_match(s); + }; + m_ho_matcher->set_on_match(on_match); + } } quantifier_manager_plugin * mk_fresh() override { return alloc(default_qm_plugin); } + void on_ho_match(euf::ho_subst& s) { + ast_manager& m = m_context->get_manager(); + auto& st = m_ho_state; + auto* hoq = st.m_q; + auto* q = m_ho_matcher->hoq2q(hoq); + + expr_ref_vector binding(m); + for (unsigned i = 0; i < s.size(); ++i) + binding.push_back(s.get(i)); + + // Shrink binding to original quantifier's num_decls + // The HO quantifier has extra vars at higher indices; drop them. + // Binding is indexed by var index: binding[i] = value for var i. + // First substitute any remaining vars, then keep only original vars. + if (binding.size() > q->get_num_decls()) { + var_subst sub(m); + bool change = true; + while (change) { + change = false; + for (unsigned i = 0; i < binding.size(); ++i) { + if (!binding.get(i)) continue; + auto r = sub(binding.get(i), binding); + change |= r != binding.get(i); + binding[i] = r; + } + } + binding.shrink(q->get_num_decls()); + } + + // Create enodes for the refined bindings and add instance + ptr_buffer new_bindings; + unsigned max_gen = st.m_max_generation; + for (unsigned i = 0; i < q->get_num_decls(); ++i) { + expr* e = binding.get(i); + if (!e) + return; // incomplete binding + if (!m_context->e_internalized(e)) { + m_context->internalize(e, false); + } + enode* n = m_context->get_enode(e); + new_bindings.push_back(n); + if (n->get_generation() > max_gen) + max_gen = n->get_generation(); + } + + TRACE(ho_matching, + tout << "ho_match refined for " << mk_pp(q, m) << "\n"; + for (unsigned i = 0; i < new_bindings.size(); ++i) + tout << " binding[" << i << "] = " << mk_pp(new_bindings[i]->get_expr(), m) << "\n";); + + vector> used_enodes; + m_context->add_instance(q, nullptr, new_bindings.size(), new_bindings.data(), + nullptr, max_gen, st.m_min_top_generation, st.m_max_top_generation, used_enodes); + } + + bool try_ho_refine(quantifier* qa, app* pat, unsigned num_bindings, enode* const* bindings, + unsigned max_generation, unsigned min_top_gen, unsigned max_top_gen, + vector>& used_enodes) { + if (!m_ho_matcher || !m_ho_matcher->is_ho_pattern(pat)) + return false; + + ast_manager& m = m_context->get_manager(); + expr_ref_vector s(m); + // With var_subst(std_order=true): var idx maps to s[s.size()-idx-1] + // SMT MAM bindings: bindings[i] = var at index (num_bindings-1-i) + // So bindings[i] corresponds to s[i] with std_order + for (unsigned i = 0; i < num_bindings; ++i) + s.push_back(bindings[i]->get_expr()); + + m_ho_state.m_q = qa; + m_ho_state.m_pat = pat; + m_ho_state.m_num_bindings = num_bindings; + m_ho_state.m_bindings = bindings; + m_ho_state.m_max_generation = max_generation; + m_ho_state.m_min_top_generation = min_top_gen; + m_ho_state.m_max_top_generation = max_top_gen; + m_ho_state.m_used_enodes = &used_enodes; + + IF_VERBOSE(10, verbose_stream() << "try_ho_refine: q=" << mk_pp(qa, m) << "\n pat=" << mk_pp(pat, m) << "\n"; + for (unsigned i = 0; i < num_bindings; ++i) + verbose_stream() << " s[" << i << "] = " << mk_pp(s.get(i), m) << " sort=" << mk_pp(s.get(i)->get_sort(), m) << "\n";); + + m_ho_matcher->refine_ho_match(pat, s); + return true; + } + bool model_based() const override { return m_fparams->m_mbqi; } bool mbqi_enabled(quantifier *q) const override { @@ -704,6 +825,19 @@ namespace smt { TRACE(quantifier, tout << "adding:\n" << expr_ref(mp, m) << "\n";); m_mam->add_pattern(q, mp); } + // Compile HO pattern and also register the compiled version with MAM + if (m_ho_matcher) { + auto [q1, p1] = m_ho_matcher->compile_ho_pattern(q, mp); + IF_VERBOSE(10, verbose_stream() << "ho_matching: q=" << q->get_qid() + << " compiled=" << (p1 != mp) + << " p1=" << mk_pp(p1, m) << "\n"); + if (p1 != mp) { + if (!unary && j >= num_eager_multi_patterns) + m_lazy_mam->add_pattern(q1, p1); + else + m_mam->add_pattern(q1, p1); + } + } if (!unary) j++; } @@ -713,6 +847,13 @@ namespace smt { return m_fparams->m_ematching && !m_qm->empty(); } + + bool refine_instance(quantifier* q, app* pat, unsigned num_bindings, enode* const* bindings, + unsigned max_generation, unsigned min_top_generation, unsigned max_top_generation, + vector>& used_enodes) override { + return try_ho_refine(q, pat, num_bindings, bindings, max_generation, min_top_generation, max_top_generation, used_enodes); + } + void add_eq_eh(enode * e1, enode * e2) override { if (use_ematching()) m_mam->add_eq_eh(e1, e2); @@ -726,7 +867,9 @@ namespace smt { } bool can_propagate() const override { - return m_active && m_mam->has_work(); + bool r = m_active && m_mam->has_work(); + IF_VERBOSE(11, if (r) verbose_stream() << "ho_matching: can_propagate=true\n"); + return r; } void restart_eh() override { @@ -750,6 +893,7 @@ namespace smt { void propagate() override { if (!m_active) return; + IF_VERBOSE(10, verbose_stream() << "ho_matching: propagate(), mam.has_work=" << m_mam->has_work() << "\n"); m_mam->match(); if (!m_context->relevancy() && use_ematching()) { ptr_vector::const_iterator it = m_context->begin_enodes(); diff --git a/src/smt/smt_quantifier.h b/src/smt/smt_quantifier.h index 981647606..4d8e6da5a 100644 --- a/src/smt/smt_quantifier.h +++ b/src/smt/smt_quantifier.h @@ -178,8 +178,14 @@ namespace smt { virtual void push() = 0; virtual void pop(unsigned num_scopes) = 0; - + /** + \brief Try to refine a match using higher-order matching. + Returns true if the pattern was an HO pattern and refinement was attempted. + In that case, the plugin handles adding instances via the refined bindings. + */ + virtual bool refine_instance(quantifier* q, app* pat, unsigned num_bindings, enode* const* bindings, + unsigned max_generation, unsigned min_top_generation, unsigned max_top_generation, + vector>& used_enodes) { return false; } }; }; -