diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index b5eddceb1..6e2f2e6af 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -2386,29 +2386,19 @@ app * ast_manager::mk_pattern(unsigned num_exprs, app * const * exprs) { } bool ast_manager::is_pattern(expr const * n) const { - if (!is_app_of(n, pattern_family_id, OP_PATTERN)) { - return false; - } - for (unsigned i = 0; i < to_app(n)->get_num_args(); ++i) { - if (!is_app(to_app(n)->get_arg(i))) { - return false; - } - } - return true; + if (!is_app_of(n, pattern_family_id, OP_PATTERN)) + return false; + return all_of(*to_app(n), [](expr* arg) { return is_app(arg); }); } -bool ast_manager::is_pattern(expr const * n, ptr_vector &args) { - if (!is_app_of(n, pattern_family_id, OP_PATTERN)) { +bool ast_manager::is_pattern(expr const * n, ptr_vector &args) { + if (!is_pattern(n)) return false; - } - for (unsigned i = 0; i < to_app(n)->get_num_args(); ++i) { - expr *arg = to_app(n)->get_arg(i); - if (!is_app(arg)) { - return false; - } - args.push_back(arg); - } + + for (auto arg : *to_app(n)) + args.push_back(to_app(arg)); + return true; } diff --git a/src/ast/ast.h b/src/ast/ast.h index 870b55c5c..9dd564206 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -2015,7 +2015,7 @@ public: bool is_pattern(expr const * n) const; - bool is_pattern(expr const *n, ptr_vector &args); + bool is_pattern(expr const *n, ptr_vector &args); public: diff --git a/src/ast/euf/.#euf_justification.h b/src/ast/euf/.#euf_justification.h deleted file mode 100644 index aab6b7ce2..000000000 --- a/src/ast/euf/.#euf_justification.h +++ /dev/null @@ -1 +0,0 @@ -nbjorner@LAPTOP-04AEAFKH.38072:1751392111 \ No newline at end of file diff --git a/src/ast/euf/euf_mam.cpp b/src/ast/euf/euf_mam.cpp index ea928415c..5bc449e35 100644 --- a/src/ast/euf/euf_mam.cpp +++ b/src/ast/euf/euf_mam.cpp @@ -2994,6 +2994,8 @@ namespace euf { SASSERT(m.is_pattern(mp)); SASSERT(first_idx < mp->get_num_args()); app * p = to_app(mp->get_arg(first_idx)); + if (is_ground(p)) + return; func_decl * lbl = p->get_decl(); unsigned lbl_id = lbl->get_small_id(); m_trees.reserve(lbl_id+1, nullptr); @@ -3879,9 +3881,10 @@ namespace euf { // Ground patterns are discarded. // However, the simplifier may turn a non-ground pattern into a ground one. // So, we should check it again here. - for (expr* arg : *mp) - if (is_ground(arg) || has_quantifiers(arg)) - return; // ignore multi-pattern containing ground pattern. + if (all_of(*mp, [](expr* arg) { return is_ground(arg); })) + return; // ignore multi-pattern containing only ground pattern. + if (any_of(*mp, [](expr* arg) { return has_quantifiers(arg); })) + return; // patterns with quantifiers are not handled. update_filters(qa, mp); m_new_patterns.push_back(qp_pair(qa, mp)); ctx.get_trail().push(push_back_trail(m_new_patterns)); diff --git a/src/ast/euf/ho_matcher.cpp b/src/ast/euf/ho_matcher.cpp index 16a1cf434..7be2a8943 100644 --- a/src/ast/euf/ho_matcher.cpp +++ b/src/ast/euf/ho_matcher.cpp @@ -55,6 +55,7 @@ namespace euf { void ho_matcher::operator()(expr* pat, expr* t, unsigned num_bound, unsigned num_vars) { m_trail.push_scope(); + m_subst.resize(0); m_subst.resize(num_vars); m_goals.reset(); m_goals.push(0, num_bound, pat, t); @@ -93,7 +94,7 @@ namespace euf { bool st = consume_work(wi); IF_VERBOSE(3, display(verbose_stream() << "ho_matcher::consume_work: " << wi.pat << " =?= " << wi.t << " -> " << (st?"true":"false") << "\n");); if (st) { - if (m_goals.empty()) + if (m_goals.empty()) m_on_match(m_subst); break; } @@ -635,14 +636,17 @@ namespace euf { } - app* ho_matcher::compile_ho_pattern(quantifier* q, app* p) { + quantifier* ho_matcher::compile_ho_pattern(quantifier* q, app*& p) { app* p1 = nullptr; - if (m_pat2hopat.find(p, p1)) - return p1; + if (m_pat2hopat.find(p, p)) { + q = m_q2hoq[q]; + return q; + } auto is_ho = any_of(subterms::all(expr_ref(p, m)), [&](expr* t) { return m_unitary.is_flex(0, t); }); if (!is_ho) - return p; + return q; ptr_vector todo; + ptr_buffer bound; expr_ref_vector cache(m); unsigned nb = q->get_num_decls(); todo.push_back(p); @@ -655,7 +659,9 @@ namespace euf { } if (m_unitary.is_flex(0, t)) { m_pat2abs.insert_if_not_there(p, svector>()).push_back({ nb, t }); - cache.setx(t->get_id(), m.mk_var(nb++, t->get_sort())); + auto v = m.mk_var(nb++, t->get_sort()); + bound.push_back(v); + cache.setx(t->get_id(), v); todo.pop_back(); continue; } @@ -678,41 +684,91 @@ namespace euf { } if (is_quantifier(t)) { m_pat2abs.remove(p); - return p; + return q; } } - p1 = to_app(cache.get(p->get_id())); + expr_free_vars free_vars; + free_vars(p1); + app_ref_vector new_ground(m); + app_ref_vector new_patterns(m); + + ptr_buffer sorts; + vector names; + for (unsigned i = bound.size(); i-- > 0; ) { + sorts.push_back(bound[i]->get_sort()); + names.push_back(symbol(bound[i]->get_idx())); + } + unsigned sz = q->get_num_decls(); + for (unsigned i = 0; i < sz; ++i) { + unsigned idx = sz - i - 1; + auto s = q->get_decl_sort(i); + sorts.push_back(s); + names.push_back(q->get_decl_name(i)); + if (!free_vars.contains(idx)) { + auto p = m.mk_fresh_func_decl("p", 1, &s, m.mk_bool_sort()); + new_patterns.push_back(m.mk_app(p, m.mk_var(idx, s))); + new_ground.push_back(m.mk_app(p, m.mk_fresh_const(symbol("c"), s))); + } + } + auto body = q->get_expr(); + if (!new_patterns.empty()) { + ptr_vector pats; + VERIFY(m.is_pattern(p1, pats)); + for (auto p : new_patterns) // patterns for variables that are not free in new pattern + pats.push_back(p); + for (auto g : new_ground) // ensure ground terms are in pattern so they have enodes + pats.push_back(g); + p1 = m.mk_pattern(pats.size(), pats.data()); + } + + quantifier* q1 = m.mk_forall(sorts.size(), sorts.data(), names.data(), body); + m_pat2hopat.insert(p, p1); m_hopat2pat.insert(p1, p); + m_q2hoq.insert(q, q1); + m_hoq2q.insert(q1, q); + m_hopat2free_vars.insert(p1, free_vars); m_ho_patterns.push_back(p1); + m_ho_qs.push_back(q1); trail().push(push_back_vector(m_ho_patterns)); + trail().push(push_back_vector(m_ho_qs)); trail().push(insert_map(m_pat2hopat, p)); trail().push(insert_map(m_hopat2pat, p1)); trail().push(insert_map(m_pat2abs, p)); - return p1; + trail().push(insert_map(m_q2hoq, q)); + trail().push(insert_map(m_hoq2q, q1)); + trail().push(insert_map(m_hopat2free_vars, p1)); + p = p1; + return q1; } bool ho_matcher::is_ho_pattern(app* p) { return m_hopat2pat.contains(p); } - void ho_matcher::refine_ho_match(app* p, expr_ref_vector const& s) { + void ho_matcher::refine_ho_match(app* p, expr_ref_vector& s) { auto fo_pat = m_hopat2pat[p]; m_trail.push_scope(); + m_subst.resize(0); m_subst.resize(s.size()); m_goals.reset(); for (unsigned i = 0; i < s.size(); ++i) { - if (s[i]) - m_subst.set(i, s[i]); + auto idx = s.size() - i - 1; + if (!m_hopat2free_vars[p].contains(idx)) + s[i] = m.mk_var(idx, s[i]->get_sort()); + else if (s.get(i)) + m_subst.set(i, s.get(i)); } + IF_VERBOSE(1, verbose_stream() << "refine " << mk_pp(p, m) << "\n" << s << "\n"); + unsigned num_bound = 0, level = 0; for (auto [v, pat] : m_pat2abs[fo_pat]) { - var_subst sub(m, false); + var_subst sub(m, true); auto pat_refined = sub(pat, s); IF_VERBOSE(1, verbose_stream() << mk_pp(pat, m) << " -> " << pat_refined << "\n"); - m_goals.push(level, num_bound, pat_refined, s[v]); + m_goals.push(level, num_bound, pat_refined, s.get(s.size() - v - 1)); } search(); diff --git a/src/ast/euf/ho_matcher.h b/src/ast/euf/ho_matcher.h index 5c53d3a16..3bd5b9d2f 100644 --- a/src/ast/euf/ho_matcher.h +++ b/src/ast/euf/ho_matcher.h @@ -316,8 +316,10 @@ namespace euf { mutable array_rewriter m_rewriter; array_util m_array; obj_map m_pat2hopat, m_hopat2pat; + obj_map m_q2hoq, m_hoq2q; + obj_map m_hopat2free_vars; obj_map>> m_pat2abs; - expr_ref_vector m_ho_patterns; + expr_ref_vector m_ho_patterns, m_ho_qs; void resume(); @@ -373,7 +375,8 @@ namespace euf { m_unitary(m), m_rewriter(m), m_array(m), - m_ho_patterns(m) + m_ho_patterns(m), + m_ho_qs(m) { } @@ -383,11 +386,15 @@ namespace euf { void operator()(expr* pat, expr* t, unsigned num_bound, unsigned num_vars); - app* compile_ho_pattern(quantifier* q, app* p); + quantifier* compile_ho_pattern(quantifier* q, app*& p); bool is_ho_pattern(app* p); - void refine_ho_match(app* p, expr_ref_vector const& s); + void refine_ho_match(app* p, expr_ref_vector& s); + + bool is_free(app* p, unsigned i) const { return m_hopat2free_vars[p].contains(i); } + + quantifier* hoq2q(quantifier* q) const { return m_hoq2q[q]; } }; } diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index b96ac22cc..03410d95a 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -70,7 +70,8 @@ namespace euf { m_canonical_proofs(m), // m_infer_patterns(m, m_smt_params), m_deps(m), - m_rewriter(m) { + m_rewriter(m), + m_matcher(m, m_trail) { m_tt = m_egraph.mk(m.mk_true(), 0, 0, nullptr); m_ff = m_egraph.mk(m.mk_false(), 0, 0, nullptr); m_rewriter.set_order_eq(true); @@ -92,6 +93,39 @@ namespace euf { m_egraph.add_plugin(alloc(arith_plugin, m_egraph)); m_egraph.add_plugin(alloc(bv_plugin, m_egraph)); + + std::function on_match = + [&](ho_subst& s) { + IF_VERBOSE(1, s.display(verbose_stream() << "on-match\n") << "\n"); + auto& b = *m_ho_binding; + auto* hopat = b.m_pattern; + auto* hoq = b.m_q; + auto* q = m_matcher.hoq2q(hoq); + // shrink binding + expr_ref_vector binding(m); + for (unsigned i = 0; i < s.size(); ++i) + binding.push_back(s.get(i)); + binding.reverse(); + if (binding.size() > q->get_num_decls()) { + bool change = true; + while (change) { + change = false; + for (unsigned i = binding.size(); i-- > 0;) { + var_subst sub(m, false); + auto r = sub(binding.get(i), binding); + change |= r != binding.get(i); + binding[i] = r; + } + } + } + binding.shrink(q->get_num_decls()); + binding.reverse(); + + IF_VERBOSE(1, verbose_stream() << binding << "\n"); + apply_binding(b, q, binding); + }; + + m_matcher.set_on_match(on_match); } completion::~completion() { @@ -108,6 +142,7 @@ namespace euf { void completion::updt_params(params_ref const& p) { smt_params_helper sp(p); m_max_instantiations = sp.qi_max_instances(); + // m_max_generation = sp.qi_max_generation(); } struct completion::push_watch_rule : public trail { @@ -222,6 +257,7 @@ namespace euf { void completion::add_constraint(expr* f, proof* pr, expr_dependency* d) { if (m_egraph.inconsistent()) return; + TRACE(euf_completion, tout << mk_pp(f, m) << "\n"); auto add_children = [&](enode* n) { for (auto* ch : enode_args(n)) m_nodes_to_canonize.push_back(ch); @@ -234,12 +270,14 @@ namespace euf { m_egraph.merge(a, b, to_ptr(push_pr_dep(pr, d))); add_children(a); add_children(b); + m_should_propagate = true; } else if (m.is_not(f, f)) { enode* n = mk_enode(f); auto j = to_ptr(push_pr_dep(pr, d)); m_egraph.new_diseq(n, j); add_children(n); + m_should_propagate = true; } else { enode* n = mk_enode(f); @@ -255,13 +293,15 @@ namespace euf { q = to_quantifier(tmp); } #endif - ptr_vector ground; + for (unsigned i = 0; i < q->get_num_patterns(); ++i) { auto p = to_app(q->get_pattern(i)); + auto q1 = m_matcher.compile_ho_pattern(q, p); + ptr_vector ground; mam::ground_subterms(p, ground); for (expr* g : ground) mk_enode(g); - m_mam->add_pattern(q, p); + m_mam->add_pattern(q1, p); } m_q2dep.insert(q, { pr, d}); get_trail().push(insert_obj_map(m_q2dep, q)); @@ -295,7 +335,7 @@ namespace euf { if (m.is_true(n->get_root()->get_expr())) return l_false; } - if (m_side_condition_solver) { + if (m_side_condition_solver && m_propagate_with_solver) { expr_dependency* sd = nullptr; if (m_side_condition_solver->is_true(f, pr, sd)) { add_constraint(f, pr, sd); @@ -363,6 +403,7 @@ namespace euf { } void completion::propagate_all_rules() { + flet _propagate_with_solver(m_propagate_with_solver, true); for (auto* r : m_rules) if (!r->m_in_queue) r->m_in_queue = true, @@ -456,6 +497,8 @@ namespace euf { void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned max_global, unsigned min_top, unsigned max_top) { if (should_stop()) return; + if (max_top >= m_max_generation) + return; auto* b = alloc_binding(q, pat, binding, max_global, min_top, max_top); if (!b) return; @@ -487,23 +530,21 @@ namespace euf { void completion::apply_binding(binding& b) { if (should_stop()) return; -#if 0 - if (is_ho_binding(b)) - apply_ho_binding(b); - else -#endif - { - expr_ref_vector _binding(m); - quantifier* q = b.m_q; - for (unsigned i = 0; i < q->get_num_decls(); ++i) - _binding.push_back(b.m_nodes[i]->get_expr()); - apply_binding(b, _binding); + expr_ref_vector _binding(m); + quantifier* q = b.m_q; + for (unsigned i = 0; i < q->get_num_decls(); ++i) + _binding.push_back(b.m_nodes[i]->get_expr()); + if (m_matcher.is_ho_pattern(b.m_pattern)) { + flet set_binding(m_ho_binding, &b); + m_matcher.refine_ho_match(b.m_pattern, _binding); } + else + apply_binding(b, q, _binding); + } - void completion::apply_binding(binding& b, expr_ref_vector const& s) { + void completion::apply_binding(binding& b, quantifier* q, expr_ref_vector const& s) { var_subst subst(m); - quantifier* q = b.m_q; expr_ref r = subst(q->get_expr(), s); scoped_generation sg(*this, b.m_max_top_generation + 1); auto [pr, d] = get_dependency(q); @@ -512,10 +553,8 @@ namespace euf { add_constraint(r, pr, d); propagate_rules(); m_egraph.propagate(); - m_should_propagate = true; } - void completion::read_egraph() { if (m_egraph.inconsistent()) { auto* d = explain_conflict(); diff --git a/src/ast/simplifiers/euf_completion.h b/src/ast/simplifiers/euf_completion.h index edd71b160..cda12aebf 100644 --- a/src/ast/simplifiers/euf_completion.h +++ b/src/ast/simplifiers/euf_completion.h @@ -25,6 +25,7 @@ Author: #include "ast/simplifiers/dependent_expr_state.h" #include "ast/euf/euf_egraph.h" #include "ast/euf/euf_mam.h" +#include "ast/euf/ho_matcher.h" #include "ast/rewriter/th_rewriter.h" // include "ast/pattern/pattern_inference.h" #include "params/smt_params.h" @@ -133,18 +134,22 @@ namespace euf { bindings m_bindings; scoped_ptr m_tmp_binding; unsigned m_tmp_binding_capacity = 0; + binding* m_ho_binding = nullptr; expr_dependency_ref_vector m_deps; obj_map> m_q2dep; vector> m_pr_dep; unsigned m_epoch = 0; unsigned_vector m_epochs; th_rewriter m_rewriter; + ho_matcher m_matcher; stats m_stats; scoped_ptr m_side_condition_solver; ptr_vector m_rules; bool m_has_new_eq = false; bool m_should_propagate = false; + bool m_propagate_with_solver = false; unsigned m_max_instantiations = std::numeric_limits::max(); + unsigned m_max_generation = 10; unsigned m_generation = 0; vector> m_rule_watch; @@ -176,7 +181,7 @@ namespace euf { binding* alloc_binding(quantifier* q, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top); void insert_binding(binding* b); void apply_binding(binding& b); - void apply_binding(binding& b, expr_ref_vector const& s); + void apply_binding(binding& b, quantifier* q, expr_ref_vector const& s); void flush_binding_queue(); vector> m_queue; diff --git a/src/tactic/portfolio/euf_completion_tactic.cpp b/src/tactic/portfolio/euf_completion_tactic.cpp index a418c43c6..a2d00ff74 100644 --- a/src/tactic/portfolio/euf_completion_tactic.cpp +++ b/src/tactic/portfolio/euf_completion_tactic.cpp @@ -88,6 +88,8 @@ public: expr_ref_vector core(m); m_solver->get_unsat_core(core); for (auto c : core) { + if (c == nf) + continue; auto [pr, dep] = m_e2d[c]; d = m.mk_join(d, dep); } @@ -96,6 +98,8 @@ public: SASSERT(pr); expr_safe_replace rep(m); for (auto c : core) { + if (c == nf) + continue; auto [p, dep] = m_e2d[c]; rep.insert(m.mk_asserted(c), p); } diff --git a/src/test/ho_matcher.cpp b/src/test/ho_matcher.cpp index 5d2329af4..a7f275812 100644 --- a/src/test/ho_matcher.cpp +++ b/src/test/ho_matcher.cpp @@ -23,7 +23,7 @@ namespace euf { m_f = m.mk_func_decl(symbol("f"), m_int, m_int, m_int); std::function on_match = [&](ho_subst& s) { - s.display(verbose_stream() << "match\n"); + s.display(verbose_stream() << "match\n"); }; m_matcher.set_on_match(on_match);