diff --git a/src/model/model_macro_solver.cpp b/src/model/model_macro_solver.cpp index 64881482e..0e1c44900 100644 --- a/src/model/model_macro_solver.cpp +++ b/src/model/model_macro_solver.cpp @@ -513,7 +513,7 @@ void non_auf_macro_solver::collect_candidates(ptr_vector const& qs, TRACE(model_finder, tout << "considering macro for: " << f->get_name() << "\n"; m->display(tout); tout << "\n";); if (m->is_unconditional() && (!qi->is_auf() || m->get_weight() >= m_mbqi_force_template)) { - full_macros.insert(f, std::make_pair(m, q)); + full_macros.insert(f, {m, q}); cond_macros.erase(f); } else if (!full_macros.contains(f) && !qi->is_auf()) @@ -524,10 +524,8 @@ void non_auf_macro_solver::collect_candidates(ptr_vector const& qs, } void non_auf_macro_solver::process_full_macros(obj_map const& full_macros, obj_hashtable& removed) { - for (auto const& kv : full_macros) { - func_decl* f = kv.m_key; - cond_macro* m = kv.m_value.first; - quantifier* q = kv.m_value.second; + for (auto const &[f, v] : full_macros) { + auto [m, q] = v; SASSERT(m->is_unconditional()); if (add_macro(f, m->get_def())) { get_qinfo(q)->set_the_one(f); diff --git a/src/smt/smt_model_finder.cpp b/src/smt/smt_model_finder.cpp index 1af371e8b..675a35200 100644 --- a/src/smt/smt_model_finder.cpp +++ b/src/smt/smt_model_finder.cpp @@ -31,6 +31,7 @@ Revision History: #include "ast/ast_ll_pp.h" #include "ast/well_sorted.h" #include "ast/ast_smt2_pp.h" +#include "ast/term_enumeration.h" #include "model/model_pp.h" #include "model/model_macro_solver.h" #include "smt/smt_model_finder.h" @@ -1369,6 +1370,88 @@ namespace smt { }; + class ho_var : public qinfo { + unsigned m_var_i; + public: + ho_var(ast_manager& m, unsigned i) : qinfo(m), m_var_i(i) { + } + + char const *get_kind() const override { + return "ho_var"; + } + + bool is_equal(qinfo const *qi) const override { + if (qi->get_kind() != get_kind()) + return false; + ho_var const *other = static_cast(qi); + return m_var_i == other->m_var_i; + } + + void display(std::ostream &out) const override { + out << "(" << "ho-var " << ":" << m_var_i << ")"; + } + + void process_auf(quantifier *q, auf_solver &s, context *ctx) override { + } + + void populate_inst_sets(quantifier *q, auf_solver &s, context *ctx) override { + + node *S = s.get_uvar(q, m_var_i); + sort *srt = S->get_sort(); + sort* range = get_array_range(srt); + unsigned arity = get_array_arity(srt); + IF_VERBOSE(0, verbose_stream() << "ho_var::populate_inst_sets: " << q->get_id() << " " << mk_pp(srt, m) << "\n";); + term_enumeration tn(m); + // Add ground terms of type S. + // Add productions for functions in E-graph + // add other possible relevant functions such as equality over srt, Boolean operators + // TODO: use term_enumerator to produce instances int the instantiation set of S. + expr_ref_vector vars(m); + ptr_vector sorts; + vector names; + for (unsigned i = 0; i < arity; ++i) { + vars.push_back(m.mk_var(i, get_array_domain(srt, i))); + auto v = vars.back(); + tn.add_production(v); + sorts.push_back(v->get_sort()); + names.push_back(symbol(i)); + } + auto mk_lambda = [&](expr* body) { + return m.mk_lambda(vars.size(), sorts.data(), names.data(), body); + }; + ast_mark visited; + for (enode *n : ctx->enodes()) { + if (false && !ctx->is_relevant(n)) + continue; + auto e = n->get_expr(); + if (srt == n->get_sort()) { + S->insert(e, n->get_generation()); + } + else if (is_uninterp_const(e)) { + IF_VERBOSE(0, verbose_stream() << "add production " << mk_pp(e, m) << "\n"); + tn.add_production(e); + } + else if (is_uninterp(e)) { + auto f = to_app(e)->get_decl(); + if (visited.is_marked(f)) + continue; + visited.mark(f, true); + IF_VERBOSE(0, verbose_stream() << "add function " << mk_pp(f, m) << "\n"); + tn.add_production(f); + } + } + + unsigned max_count = 20; + for (auto t : tn.enum_terms(srt)) { + auto lam = mk_lambda(t); + unsigned generation = 0; // todo - inherited from sub-term of t? + IF_VERBOSE(0, verbose_stream() << "ho_var: adding term " << mk_ismt2_pp(t, m) + << " to instantiation set of S" << std::endl;); + S->insert(lam, generation); + } + } + }; + /** \brief auf_arr is a term (pattern) of the form: @@ -2105,7 +2188,12 @@ namespace smt { process_app(to_app(curr)); } else if (is_var(curr)) { - m_info->m_is_auf = false; // unexpected occurrence of variable. + if (m_array_util.is_array(curr)) { + insert_qinfo(alloc(ho_var, m, to_var(curr)->get_idx())); + } + else { + m_info->m_is_auf = false; // unexpected occurrence of variable. + } } else { SASSERT(is_lambda(curr));