diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index d73bfba01..9dd814a0e 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -330,6 +330,7 @@ namespace euf { euf::enode* mk_enode(expr* e, unsigned n, enode* const* args) { return m_egraph.mk(e, m_generation, n, args); } expr* bool_var2expr(sat::bool_var v) const { return m_bool_var2expr.get(v, nullptr); } expr_ref literal2expr(sat::literal lit) const { expr* e = bool_var2expr(lit.var()); return lit.sign() ? expr_ref(m.mk_not(e), m) : expr_ref(e, m); } + unsigned generation() const { return m_generation; } sat::literal attach_lit(sat::literal lit, expr* e); void unhandled_function(func_decl* f); diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp index bab662b04..648c94651 100644 --- a/src/sat/smt/q_ematch.cpp +++ b/src/sat/smt/q_ematch.cpp @@ -47,11 +47,23 @@ namespace q { ~scoped_mark_reset() { e.m_mark.reset(); } }; + unsigned ematch::fingerprint::hash() const { + NOT_IMPLEMENTED_YET(); + return 0; + } + + bool ematch::fingerprint::eq(fingerprint const& other) const { + NOT_IMPLEMENTED_YET(); + return false; + } + + ematch::ematch(euf::solver& ctx, solver& s): ctx(ctx), m_qs(s), m(ctx.get_manager()), - m_infer_patterns(m, ctx.get_config()) + m_infer_patterns(m, ctx.get_config()), + m_qstat_gen(m, ctx.get_region()) { std::function _on_merge = [&](euf::enode* root, euf::enode* other) { @@ -59,7 +71,7 @@ namespace q { }; std::function _on_make = [&](euf::enode* n) { - m_mam->relevant_eh(n, false); + m_mam->add_node(n, false); }; ctx.get_egraph().set_on_merge(_on_merge); ctx.get_egraph().set_on_make(_on_make); @@ -295,10 +307,10 @@ namespace q { } }; - ematch::binding* ematch::alloc_binding(unsigned n) { + ematch::binding* ematch::alloc_binding(unsigned n, unsigned max_generation, unsigned min_top, unsigned max_top) { unsigned sz = sizeof(binding) + sizeof(euf::enode* const*)*n; void* mem = ctx.get_region().allocate(sz); - return new (mem) binding(); + return new (mem) binding(max_generation, min_top, max_top); } std::ostream& ematch::lit::display(std::ostream& out) const { @@ -313,10 +325,9 @@ namespace q { << mk_bounded_pp(rhs, rhs.m(), 2); } - - void ematch::clause::add_binding(ematch& em, euf::enode* const* _binding) { + void ematch::clause::add_binding(ematch& em, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top) { unsigned n = num_decls(); - binding* b = em.alloc_binding(n); + binding* b = em.alloc_binding(n, max_generation, min_top, max_top); b->init(b); for (unsigned i = 0; i < n; ++i) b->m_nodes[i] = _binding[i]; @@ -324,11 +335,11 @@ namespace q { em.ctx.push(remove_binding(*this, b)); } - void ematch::on_binding(quantifier* q, app* pat, euf::enode* const* _binding) { + void ematch::on_binding(quantifier* q, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_gen, unsigned max_gen) { TRACE("q", tout << "on-binding " << mk_pp(q, m) << "\n";); clause& c = *m_clauses[m_q2clauses[q]]; if (!propagate(_binding, c)) - c.add_binding(*this, _binding); + c.add_binding(*this, _binding, max_generation, min_gen, max_gen); } std::ostream& ematch::clause::display(euf::solver& ctx, std::ostream& out) const { @@ -397,32 +408,49 @@ namespace q { } TRACE("q", tout << "instantiate " << (idx == UINT_MAX ? "clause is false":"unit propagate") << "\n";); -#if 1 auto j_idx = mk_justification(idx, c, binding); if (idx == UINT_MAX) ctx.set_conflict(j_idx); else ctx.propagate(instantiate(c, binding, c[idx]), j_idx); -#else - instantiate(c, binding); -#endif return true; } // vanilla instantiation method. - void ematch::instantiate(euf::enode* const* binding, clause& c) { - expr_ref_vector _binding(m); - quantifier* q = c.m_q; + void ematch::instantiate(binding& b, clause& c) { + expr_ref_vector _nodes(m); + quantifier* q = c.m_q; + if (m_stats.m_num_instantiations > ctx.get_config().m_qi_max_instances) + return; + unsigned max_generation = b.m_max_generation; + max_generation = std::max(max_generation, c.m_stat->get_generation()); + c.m_stat->update_max_generation(max_generation); +#if 0 + fingerprint * f = add_fingerprint(c, b, max_generation); + if (f) { + m_queue.insert(f, max_generation); + m_stats.m_num_instantiations++; + } + return; +#endif + + m_stats.m_num_instantiations++; + for (unsigned i = 0; i < c.num_decls(); ++i) - _binding.push_back(binding[i]->get_expr()); + _nodes.push_back(b.m_nodes[i]->get_expr()); var_subst subst(m); - expr_ref result = subst(q->get_expr(), _binding); + expr_ref result = subst(q->get_expr(), _nodes); sat::literal result_l = ctx.mk_literal(result); if (is_exists(q)) result_l.neg(); m_qs.add_clause(c.m_literal, result_l); } + ematch::fingerprint* ematch::add_fingerprint(clause& c, binding& b, unsigned max_generation) { + NOT_IMPLEMENTED_YET(); + return nullptr; + } + sat::literal ematch::instantiate(clause& c, euf::enode* const* binding, lit const& l) { expr_ref_vector _binding(m); quantifier* q = c.m_q; @@ -443,6 +471,11 @@ namespace q { } lbool ematch::compare(unsigned n, euf::enode* const* binding, expr* s, expr* t) { + if (s == t) + return l_true; + if (m.are_distinct(s, t)) + return l_false; + euf::enode* sn = eval(n, binding, s); euf::enode* tn = eval(n, binding, t); if (sn) sn = sn->get_root(); @@ -459,8 +492,6 @@ namespace q { return l_undef; if (!sn && !tn) return compare_rec(n, binding, s, t); - if (!tn && !sn) - return l_undef; if (!tn && sn) { std::swap(tn, sn); std::swap(t, s); @@ -628,6 +659,14 @@ namespace q { q = to_quantifier(tmp); } cl->m_q = q; + unsigned generation = ctx.generation(); +#if 0 + unsigned _generation; + if (!m_cached_generation.empty() && m_cached_generation.find(q, _generation)) { + generation = _generation; + } +#endif + cl->m_stat = m_qstat_gen(_q, generation); SASSERT(ctx.s().value(cl->m_literal) == l_true); return cl; } @@ -714,7 +753,7 @@ namespace q { continue; instantiated = true; do { - instantiate(b->m_nodes, *c); + instantiate(*b, *c); b = b->next(); } while (b != c->m_bindings); diff --git a/src/sat/smt/q_ematch.h b/src/sat/smt/q_ematch.h index b19362761..46e897477 100644 --- a/src/sat/smt/q_ematch.h +++ b/src/sat/smt/q_ematch.h @@ -18,6 +18,7 @@ Author: #include "util/nat_set.h" #include "util/dlist.h" +#include "ast/quantifier_stat.h" #include "ast/pattern/pattern_inference.h" #include "solver/solver.h" #include "sat/smt/sat_th.h" @@ -56,25 +57,32 @@ namespace q { struct insert_binding; struct binding : public dll_base { + unsigned m_max_generation; + unsigned m_min_top_generation; + unsigned m_max_top_generation; euf::enode* m_nodes[0]; - binding() {} + binding(unsigned max_generation, unsigned min_top, unsigned max_top): + m_max_generation(max_generation), + m_min_top_generation(min_top), + m_max_top_generation(max_top) {} euf::enode* const* nodes() { return m_nodes; } }; - binding* alloc_binding(unsigned n); + binding* alloc_binding(unsigned n, unsigned max_generation, unsigned min_top, unsigned max_top); struct clause { vector m_lits; quantifier_ref m_q; sat::literal m_literal; + q::quantifier_stat* m_stat { nullptr }; binding* m_bindings { nullptr }; clause(ast_manager& m): m_q(m) {} - void add_binding(ematch& em, euf::enode* const* b); + void add_binding(ematch& em, euf::enode* const* b, unsigned max_generation, unsigned min_top, unsigned max_top); std::ostream& display(euf::solver& ctx, std::ostream& out) const; lit const& operator[](unsigned i) const { return m_lits[i]; } lit& operator[](unsigned i) { return m_lits[i]; } @@ -82,6 +90,24 @@ namespace q { unsigned num_decls() const { return m_q->get_num_decls(); } }; + struct fingerprint { + clause& c; + binding& b; + unsigned max_generation; + fingerprint(clause& c, binding& b, unsigned max_generation): + c(c), b(b), max_generation(max_generation) {} + unsigned hash() const; + bool eq(fingerprint const& other) const; + }; + + struct fingerprint_hash_proc { + bool operator()(fingerprint const* f) const { return f->hash(); } + }; + struct fingerprint_eq_proc { + bool operator()(fingerprint const* a, fingerprint const* b) const { return a->eq(*b); } + }; + typedef ptr_hashtable fingerprints; + struct justification { expr* m_lhs, *m_rhs; bool m_sign; @@ -104,6 +130,8 @@ namespace q { euf::solver& ctx; solver& m_qs; ast_manager& m; + q::quantifier_stat_gen m_qstat_gen; + fingerprints m_fingerprints; pattern_inference_rw m_infer_patterns; scoped_ptr m_mam, m_lazy_mam; ptr_vector m_clauses; @@ -130,7 +158,7 @@ namespace q { euf::enode* eval(unsigned n, euf::enode* const* binding, expr* e); bool propagate(euf::enode* const* binding, clause& c); - void instantiate(euf::enode* const* binding, clause& c); + void instantiate(binding& b, clause& c); sat::literal instantiate(clause& c, euf::enode* const* binding, lit const& l); // register as callback into egraph. @@ -150,6 +178,8 @@ namespace q { void attach_ground_pattern_terms(expr* pat); clause* clausify(quantifier* q); + fingerprint* add_fingerprint(clause& c, binding& b, unsigned max_generation); + public: @@ -168,7 +198,7 @@ namespace q { void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing); // callback from mam - void on_binding(quantifier* q, app* pat, euf::enode* const* binding); + void on_binding(quantifier* q, app* pat, euf::enode* const* binding, unsigned max_generation, unsigned min_gen, unsigned max_gen); std::ostream& display(std::ostream& out) const; std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const; diff --git a/src/sat/smt/q_mam.cpp b/src/sat/smt/q_mam.cpp index e821a3f99..680fc7d23 100644 --- a/src/sat/smt/q_mam.cpp +++ b/src/sat/smt/q_mam.cpp @@ -3821,13 +3821,13 @@ namespace q { TRACE("trigger_bug", tout << "found match " << mk_pp(qa, m) << "\n";); unsigned min_gen = 0, max_gen = 0; m_interpreter.get_min_max_top_generation(min_gen, max_gen); - m_ematch.on_binding(qa, pat, bindings); // max_generation); // , min_gen, max_gen; + m_ematch.on_binding(qa, pat, bindings, max_generation, min_gen, max_gen); } // This method is invoked when n becomes relevant. // If lazy == true, then n is not added to the list of // candidate enodes for matching. That is, the method just updates the lbls. - void relevant_eh(enode * n, bool lazy) override { + void add_node(enode * n, bool lazy) override { TRACE("trigger_bug", tout << "relevant_eh:\n" << mk_ismt2_pp(n->get_expr(), m) << "\n"; tout << "mam: " << this << "\n";); TRACE("mam", tout << "relevant_eh: #" << n->get_expr_id() << "\n";); diff --git a/src/sat/smt/q_mam.h b/src/sat/smt/q_mam.h index 42de86773..c396a319b 100644 --- a/src/sat/smt/q_mam.h +++ b/src/sat/smt/q_mam.h @@ -49,7 +49,7 @@ namespace q { virtual void add_pattern(quantifier * q, app * mp) = 0; - virtual void relevant_eh(enode * n, bool lazy) = 0; + virtual void add_node(enode * n, bool lazy) = 0; virtual void propagate() = 0; @@ -59,15 +59,16 @@ namespace q { virtual void on_merge(enode * root, enode * other) = 0; + virtual void on_match(quantifier * qa, app * pat, unsigned num_bindings, enode * const * bindings, unsigned max_generation) = 0; + virtual void reset() = 0; virtual std::ostream& display(std::ostream& out) = 0; virtual bool check_missing_instances() = 0; - virtual void on_match(quantifier * qa, app * pat, unsigned num_bindings, enode * const * bindings, unsigned max_generation) = 0; - static void ground_subterms(expr* e, ptr_vector& ground); + }; };