From 4af9132f2ea755425148f4ef2bda4d56e9e58df6 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 29 Jan 2021 13:39:14 -0800 Subject: [PATCH] more ematching --- src/ast/euf/euf_egraph.cpp | 27 +++ src/ast/euf/euf_egraph.h | 5 + src/ast/euf/euf_enode.h | 8 +- src/ast/pattern/pattern_inference.cpp | 4 +- src/ast/pattern/pattern_inference.h | 6 +- src/sat/smt/euf_solver.cpp | 7 + src/sat/smt/euf_solver.h | 1 + src/sat/smt/q_ematch.cpp | 233 ++++++++++++++++++-------- src/sat/smt/q_ematch.h | 33 +++- src/sat/smt/q_mam.cpp | 32 +--- src/sat/smt/q_solver.cpp | 9 +- src/sat/smt/q_solver.h | 6 +- 12 files changed, 263 insertions(+), 108 deletions(-) diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index a1e83cb46..3bb35320b 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -286,6 +286,10 @@ namespace euf { m_updates.push_back(update_record(n, update_record::value_assignment())); } + void egraph::set_lbl_hash(enode* n) { + NOT_IMPLEMENTED_YET(); + } + void egraph::pop(unsigned num_scopes) { if (num_scopes <= m_num_scopes) { m_num_scopes -= num_scopes; @@ -654,6 +658,27 @@ namespace euf { explain_todo(justifications); } + template + unsigned egraph::explain_diseq(ptr_vector& justifications, enode* a, enode* b) { + enode* ra = a->get_root(), * rb = b->get_root(); + SASSERT(ra != rb); + if (ra->interpreted() && rb->interpreted()) { + explain_eq(justifications, a, ra); + explain_eq(justifications, b, rb); + return UINT_MAX; + } + expr_ref eq(m.mk_eq(a->get_expr(), b->get_expr()), m); + m_tmp_eq->m_args[0] = a; + m_tmp_eq->m_args[1] = b; + m_tmp_eq->m_expr = eq; + SASSERT(m_tmp_eq->num_args() == 2); + enode* r = m_table.find(m_tmp_eq); + SASSERT(r && r->get_root()->value() == l_false); + explain_eq(justifications, r, r->get_root()); + return r->get_root()->bool_var(); + } + + template void egraph::explain_todo(ptr_vector& justifications) { for (unsigned i = 0; i < m_todo.size(); ++i) { @@ -771,10 +796,12 @@ namespace euf { template void euf::egraph::explain(ptr_vector& justifications); template void euf::egraph::explain_todo(ptr_vector& justifications); template void euf::egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b); +template unsigned euf::egraph::explain_diseq(ptr_vector& justifications, enode* a, enode* b); template void euf::egraph::explain(ptr_vector& justifications); template void euf::egraph::explain_todo(ptr_vector& justifications); template void euf::egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b); +template unsigned euf::egraph::explain_diseq(ptr_vector& justifications, enode* a, enode* b); diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index dda68767c..b475600e7 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -267,6 +267,9 @@ namespace euf { void next_literal() { force_push(); SASSERT(m_new_lits_qhead < m_new_lits.size()); m_new_lits_qhead++; } void next_th_eq() { force_push(); SASSERT(m_new_th_eqs_qhead < m_new_th_eqs.size()); m_new_th_eqs_qhead++; } + void set_lbl_hash(enode* n); + + void add_th_var(enode* n, theory_var v, theory_id id); void set_th_propagates_diseqs(theory_id id); void set_merge_enabled(enode* n, bool enable_merge); @@ -285,6 +288,8 @@ namespace euf { void explain(ptr_vector& justifications); template void explain_eq(ptr_vector& justifications, enode* a, enode* b); + template + unsigned explain_diseq(ptr_vector& justifications, enode* a, enode* b); enode_vector const& nodes() const { return m_nodes; } ast_manager& get_manager() { return m; } diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index 17136d31d..61e68665d 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -205,9 +205,11 @@ namespace euf { bool children_are_roots() const; enode* get_next() const { return m_next; } - bool has_lbl_hash() const { UNREACHABLE(); return false; } // TODO - unsigned char get_lbl_hash() const { UNREACHABLE(); return 0; } // TOD0 - void set_lbl_hash(egraph& e) { UNREACHABLE(); } + bool has_lbl_hash() const { return m_lbl_hash >= 0; } + unsigned char get_lbl_hash() const { + SASSERT(m_lbl_hash >= 0 && static_cast(m_lbl_hash) < approx_set_traits::capacity); + return static_cast(m_lbl_hash); + } approx_set & get_lbls() { return m_lbls; } approx_set & get_plbls() { return m_plbls; } const approx_set & get_lbls() const { return m_lbls; } diff --git a/src/ast/pattern/pattern_inference.cpp b/src/ast/pattern/pattern_inference.cpp index 1cc9cfe52..03e4300b5 100644 --- a/src/ast/pattern/pattern_inference.cpp +++ b/src/ast/pattern/pattern_inference.cpp @@ -101,7 +101,7 @@ static void dump_app_vector(std::ostream & out, ptr_vector const & v, ast_m #include "ast/pattern/database.h" -pattern_inference_cfg::pattern_inference_cfg(ast_manager & m, pattern_inference_params & params): +pattern_inference_cfg::pattern_inference_cfg(ast_manager & m, pattern_inference_params const & params): m(m), m_params(params), m_bfid(m.get_basic_family_id()), @@ -724,7 +724,7 @@ bool pattern_inference_cfg::reduce_quantifier( return true; } -pattern_inference_rw::pattern_inference_rw(ast_manager& m, pattern_inference_params & params): +pattern_inference_rw::pattern_inference_rw(ast_manager& m, pattern_inference_params const & params): rewriter_tpl(m, m.proofs_enabled(), m_cfg), m_cfg(m, params) {} diff --git a/src/ast/pattern/pattern_inference.h b/src/ast/pattern/pattern_inference.h index af95bcb6c..d036ad789 100644 --- a/src/ast/pattern/pattern_inference.h +++ b/src/ast/pattern/pattern_inference.h @@ -61,7 +61,7 @@ public: class pattern_inference_cfg : public default_rewriter_cfg { ast_manager& m; - pattern_inference_params & m_params; + pattern_inference_params const & m_params; family_id m_bfid; family_id m_afid; svector m_forbidden; @@ -215,7 +215,7 @@ class pattern_inference_cfg : public default_rewriter_cfg { app_ref_buffer & result); // OUT result public: - pattern_inference_cfg(ast_manager & m, pattern_inference_params & params); + pattern_inference_cfg(ast_manager & m, pattern_inference_params const & params); void register_forbidden_family(family_id fid) { SASSERT(fid != m_bfid); @@ -252,7 +252,7 @@ public: class pattern_inference_rw : public rewriter_tpl { pattern_inference_cfg m_cfg; public: - pattern_inference_rw(ast_manager& m, pattern_inference_params & params); + pattern_inference_rw(ast_manager& m, pattern_inference_params const & params); }; diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 7ae2875d3..6658caf6c 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -222,6 +222,13 @@ namespace euf { m_egraph.explain_eq(m_explain, a, b); } + void solver::add_diseq_antecedent(enode* a, enode* b) { + sat::bool_var v = get_egraph().explain_diseq(m_explain, a, b); + SASSERT(v == sat::null_bool_var || s().value(v) == l_false); + if (v != sat::null_bool_var) + m_explain.push_back(to_ptr(sat::literal(v, false))); + } + bool solver::propagate(enode* a, enode* b, ext_justification_idx idx) { if (a->get_root() == b->get_root()) return false; diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index f03a7a109..d73bfba01 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -279,6 +279,7 @@ namespace euf { void get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) override; void get_antecedents(literal l, th_propagation& jst, literal_vector& r, bool probing); void add_antecedent(enode* a, enode* b); + void add_diseq_antecedent(enode* a, enode* b); void asserted(literal l) override; sat::check_result check() override; void push() override; diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp index 77b20348b..b22b01cc0 100644 --- a/src/sat/smt/q_ematch.cpp +++ b/src/sat/smt/q_ematch.cpp @@ -16,13 +16,16 @@ Author: Todo: - clausify -- propagate without instantiations, produce explanations for eval - generations - insert instantiations into priority queue - cache instantiations and substitutions - nested quantifiers - non-cnf quantifiers (handled in q_solver) +Done: + +- propagate without instantiations, produce explanations for eval + --*/ #include "ast/ast_util.h" @@ -46,7 +49,8 @@ namespace q { ematch::ematch(euf::solver& ctx, solver& s): ctx(ctx), m_qs(s), - m(ctx.get_manager()) + m(ctx.get_manager()), + m_infer_patterns(m, ctx.get_config()) { std::function _on_merge = [&](euf::enode* root, euf::enode* other) { @@ -73,30 +77,64 @@ namespace q { } } - void ematch::explain(clause& c, unsigned literal_idx, binding& b) { - ctx.get_egraph().begin_explain(); - m_explain.reset(); - unsigned n = c.m_q->get_num_decls(); + sat::ext_justification_idx ematch::mk_justification(unsigned idx, clause& c, euf::enode* const* b) { + void* mem = ctx.get_region().allocate(justification::get_obj_size()); + sat::constraint_base::initialize(mem, &m_qs); + bool sign = false; + expr* l = nullptr, *r = nullptr; + lit lit(expr_ref(l,m), expr_ref(r, m), sign); + if (idx != UINT_MAX) + lit = c[idx]; + auto* constraint = new (sat::constraint_base::ptr2mem(mem)) justification(lit, c, b); + return constraint->to_index(); + } + + void ematch::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) { + auto& j = justification::from_index(idx); + clause& c = j.m_clause; + unsigned l_idx = 0; + for (; l_idx < c.size(); ++l_idx) { + if (c[l_idx].lhs == j.m_lhs && c[l_idx].rhs == j.m_rhs && c[l_idx].sign == j.m_sign) + break; + } + explain(c, l_idx, j.m_binding); + r.push_back(c.m_literal); + (void)probing; // ignored + } + + std::ostream& ematch::display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const { + auto& j = justification::from_index(idx); + auto& c = j.m_clause; + out << "ematch: "; + for (auto const& lit : c.m_lits) + lit.display(out) << " "; + unsigned num_decls = c.num_decls(); + for (unsigned i = 0; i < num_decls; ++i) + out << ctx.bpp(j.m_binding[i]) << " "; + out << "-> "; + lit lit(expr_ref(j.m_lhs, m), expr_ref(j.m_rhs, m), j.m_sign); + if (j.m_lhs) + lit.display(out); + else + out << "false"; + return out; + } + + void ematch::explain(clause& c, unsigned literal_idx, euf::enode* const* b) { + unsigned n = c.num_decls(); for (unsigned i = c.size(); i-- > 0; ) { if (i == literal_idx) continue; auto const& lit = c[i]; - lit.sign; - lit.lhs; - lit.rhs; - if (lit.sign) { - SASSERT(l_true == compare(n, b.m_nodes, lit.lhs, lit.rhs)); - explain_eq(n, b.m_nodes, lit.lhs, lit.rhs); - } - else { - SASSERT(l_false == compare(n, b.m_nodes, lit.lhs, lit.rhs)); - explain_diseq(n, b.m_nodes, lit.lhs, lit.rhs); - } + if (lit.sign) + explain_eq(n, b, lit.lhs, lit.rhs); + else + explain_diseq(n, b, lit.lhs, lit.rhs); } - ctx.get_egraph().end_explain(); } void ematch::explain_eq(unsigned n, euf::enode* const* binding, expr* s, expr* t) { + SASSERT(l_true == compare(n, binding, s, t)); if (s == t) return; euf::enode* sn = eval(n, binding, s); @@ -111,28 +149,29 @@ namespace q { std::swap(s, t); } if (sn && !tn) { - ctx.add_antecedent(sn, sn->get_root()); for (euf::enode* s1 : euf::enode_class(sn)) { if (l_true == compare_rec(n, binding, t, s1->get_expr())) { + ctx.add_antecedent(sn, s1); explain_eq(n, binding, t, s1->get_expr()); return; } } UNREACHABLE(); } - SASSERT(is_app(s) && is_app(t) && to_app(s)->get_decl() == to_app(t)->get_decl()); + SASSERT(is_app(s) && is_app(t)); + SASSERT(to_app(s)->get_decl() == to_app(t)->get_decl()); for (unsigned i = to_app(s)->get_num_args(); i-- > 0; ) explain_eq(n, binding, to_app(s)->get_arg(i), to_app(t)->get_arg(i)); } void ematch::explain_diseq(unsigned n, euf::enode* const* binding, expr* s, expr* t) { + SASSERT(l_false == compare(n, binding, s, t)); if (m.are_distinct(s, t)) return; euf::enode* sn = eval(n, binding, s); euf::enode* tn = eval(n, binding, t); - if (sn && tn) { - SASSERT(sn->get_root() == tn->get_root()); - ctx.add_antecedent(sn, tn); + if (sn && tn && ctx.get_egraph().are_diseq(sn, tn)) { + ctx.add_diseq_antecedent(sn, tn); return; } if (!sn && tn) { @@ -140,19 +179,22 @@ namespace q { std::swap(s, t); } if (sn && !tn) { - ctx.add_antecedent(sn, sn->get_root()); for (euf::enode* s1 : euf::enode_class(sn)) { if (l_false == compare_rec(n, binding, t, s1->get_expr())) { + ctx.add_antecedent(sn, s1); explain_diseq(n, binding, t, s1->get_expr()); return; } } UNREACHABLE(); } - SASSERT(is_app(s) && is_app(t) && to_app(s)->get_decl() == to_app(t)->get_decl()); - for (unsigned i = to_app(s)->get_num_args(); i-- > 0; ) { - if (l_false == compare_rec(n, binding, to_app(s)->get_arg(i), to_app(t)->get_arg(i))) { - explain_eq(n, binding, to_app(s)->get_arg(i), to_app(t)->get_arg(i)); + SASSERT(is_app(s) && is_app(t)); + app* at = to_app(t); + app* as = to_app(s); + SASSERT(as->get_decl() == at->get_decl()); + for (unsigned i = as->get_num_args(); i-- > 0; ) { + if (l_false == compare_rec(n, binding, as->get_arg(i), at->get_arg(i))) { + explain_eq(n, binding, as->get_arg(i), at->get_arg(i)); return; } } @@ -170,6 +212,7 @@ namespace q { }; void ematch::on_merge(euf::enode* root, euf::enode* other) { + TRACE("q", tout << "on-merge " << ctx.bpp(root) << " " << ctx.bpp(other) << "\n";); SASSERT(root->get_root() == other->get_root()); unsigned root_id = root->get_expr_id(); unsigned other_id = other->get_expr_id(); @@ -252,9 +295,23 @@ namespace q { return new (mem) binding(); } + std::ostream& ematch::lit::display(std::ostream& out) const { + ast_manager& m = lhs.m(); + if (m.is_true(rhs) && !sign) + return out << lhs; + if (m.is_false(rhs) && !sign) + return out << "(not " << lhs << ")"; + return + out << mk_bounded_pp(lhs, lhs.m(), 2) + << (sign ? " != " : " == ") + << mk_bounded_pp(rhs, rhs.m(), 2); + } + + void ematch::clause::add_binding(ematch& em, euf::enode* const* _binding) { - unsigned n = m_q->get_num_decls(); + unsigned n = num_decls(); binding* b = em.alloc_binding(n); + b->init(b); for (unsigned i = 0; i < n; ++i) b->m_nodes[i] = _binding[i]; binding::push_to_front(m_bindings, b); @@ -262,6 +319,7 @@ namespace q { } void ematch::on_binding(quantifier* q, app* pat, euf::enode* const* _binding) { + TRACE("q", tout << "on-binding " << mk_pp(q, m) << "\n";); clause& c = *m_clauses[m_q2clauses[q]]; if (!propagate(_binding, c)) c.add_binding(*this, _binding); @@ -270,14 +328,11 @@ namespace q { std::ostream& ematch::clause::display(euf::solver& ctx, std::ostream& out) const { out << "clause:\n"; for (auto const& lit : m_lits) - out << mk_bounded_pp(lit.lhs, lit.lhs.m(), 2) - << (lit.sign ? " != " : " == ") - << mk_bounded_pp(lit.rhs, lit.rhs.m(), 2) << "\n"; - unsigned num_decls = m_q->get_num_decls(); + lit.display(out) << "\n"; binding* b = m_bindings; if (b) { do { - for (unsigned i = 0; i < num_decls; ++i) + for (unsigned i = 0; i < num_decls(); ++i) out << ctx.bpp(b->nodes()[i]) << " "; out << "\n"; b = b->next(); @@ -294,19 +349,22 @@ namespace q { unsigned idx = UINT_MAX; unsigned sz = c.m_lits.size(); - unsigned n = c.m_q->get_num_decls(); + unsigned n = c.num_decls(); + m_indirect_nodes.reset(); for (unsigned i = 0; i < sz; ++i) { + unsigned lim = m_indirect_nodes.size(); lit l = c[i]; - m_indirect_nodes.reset(); lbool cmp = compare(n, binding, l.lhs, l.rhs); switch (cmp) { case l_false: + m_indirect_nodes.shrink(lim); if (!l.sign) break; if (i > 0) std::swap(c[0], c[i]); return true; case l_true: + m_indirect_nodes.shrink(lim); if (l.sign) break; if (i > 0) @@ -319,7 +377,7 @@ namespace q { // to watch for (euf::enode* n : m_indirect_nodes) add_watch(n, clause_idx); - for (unsigned j = c.m_q->get_num_decls(); j-- > 0; ) + for (unsigned j = c.num_decls(); j-- > 0; ) add_watch(binding[j], clause_idx); if (i > 1) std::swap(c[1], c[i]); @@ -332,7 +390,16 @@ namespace q { } } TRACE("q", tout << "instantiate " << (idx == UINT_MAX ? "clause is false":"unit propagate") << "\n";); - instantiate(binding, c); + +#if 1 + auto j_idx = mk_justification(idx, c, binding); + if (idx == UINT_MAX) + ctx.set_conflict(j_idx); + else + ctx.propagate(instantiate(c, binding, c[idx]), j_idx); +#else + instantiate(c, binding); +#endif return true; } @@ -340,14 +407,33 @@ namespace q { void ematch::instantiate(euf::enode* const* binding, clause& c) { expr_ref_vector _binding(m); quantifier* q = c.m_q; - for (unsigned i = 0; i < q->get_num_decls(); ++i) + for (unsigned i = 0; i < c.num_decls(); ++i) _binding.push_back(binding[i]->get_expr()); var_subst subst(m); expr_ref result = subst(q->get_expr(), _binding); - if (is_forall(q)) - m_qs.add_clause(~ctx.mk_literal(q), ctx.mk_literal(result)); - else - m_qs.add_clause(ctx.mk_literal(q), ~ctx.mk_literal(result)); + sat::literal result_l = ctx.mk_literal(result); + if (is_exists(q)) + result_l.neg(); + m_qs.add_clause(c.m_literal, result_l); + } + + sat::literal ematch::instantiate(clause& c, euf::enode* const* binding, lit const& l) { + expr_ref_vector _binding(m); + quantifier* q = c.m_q; + for (unsigned i = 0; i < c.num_decls(); ++i) + _binding.push_back(binding[i]->get_expr()); + var_subst subst(m); + if (m.is_true(l.rhs)) { + SASSERT(!l.sign); + return ctx.mk_literal(subst(l.lhs, _binding)); + } + else if (m.is_false(l.rhs)) { + SASSERT(!l.sign); + return ~ctx.mk_literal(subst(l.lhs, _binding)); + } + expr_ref fml(m.mk_eq(l.lhs, l.rhs), m); + fml = subst(fml, _binding); + return l.sign ? ~ctx.mk_literal(fml) : ctx.mk_literal(fml); } lbool ematch::compare(unsigned n, euf::enode* const* binding, expr* s, expr* t) { @@ -357,7 +443,7 @@ namespace q { if (tn) tn = tn->get_root(); TRACE("q", tout << mk_pp(s, m) << " ~~ " << mk_pp(t, m) << "\n"; tout << ctx.bpp(sn) << " " << ctx.bpp(tn) << "\n";); - + lbool c; if (sn && sn == tn) return l_true; @@ -367,14 +453,15 @@ namespace q { return l_undef; if (!sn && !tn) return compare_rec(n, binding, s, t); - if (!sn && tn) - for (euf::enode* t1 : euf::enode_class(tn)) - if (c = compare_rec(n, binding, s, t1->get_expr()), c != l_undef) - return c; - if (sn && !tn) - for (euf::enode* s1 : euf::enode_class(sn)) - if (c = compare_rec(n, binding, t, s1->get_expr()), c != l_undef) - return c; + if (!tn && !sn) + return l_undef; + if (!tn && sn) { + std::swap(tn, sn); + std::swap(t, s); + } + for (euf::enode* t1 : euf::enode_class(tn)) + if (c = compare_rec(n, binding, s, t1->get_expr()), c != l_undef) + return c; return l_undef; } @@ -480,6 +567,7 @@ namespace q { return false; bool propagated = false; ctx.push(value_trail(m_qhead)); + ptr_buffer to_remove; for (; m_qhead < m_queue.size(); ++m_qhead) { unsigned idx = m_queue[m_qhead]; clause& c = *m_clauses[idx]; @@ -487,14 +575,17 @@ namespace q { if (!b) continue; do { - binding* next = b->next(); - if (propagate(b->m_nodes, c)) { - binding::remove_from(c.m_bindings, b); - ctx.push(insert_binding(c, b)); - } - b = next; + if (propagate(b->m_nodes, c)) + to_remove.push_back(b); + b = b->next(); } while (b != c.m_bindings); + + for (binding* b : to_remove) { + binding::remove_from(c.m_bindings, b); + ctx.push(insert_binding(c, b)); + } + to_remove.reset(); } m_clause_in_queue.reset(); m_node_in_queue.reset(); @@ -504,16 +595,17 @@ namespace q { /** * basic clausifier, assumes q has been normalized. */ - ematch::clause* ematch::clausify(quantifier* q) { - clause* cl = alloc(clause); + ematch::clause* ematch::clausify(quantifier* _q) { + clause* cl = alloc(clause, m); + cl->m_literal = ctx.mk_literal(_q); + quantifier_ref q(_q, m); + if (is_exists(q)) { + cl->m_literal.neg(); + expr_ref body(mk_not(m, q->get_expr()), m); + q = m.update_quantifier(q, forall_k, body); + } expr_ref_vector ors(m); - if (is_forall(q)) - flatten_or(q->get_expr(), ors); - else { - flatten_and(q->get_expr(), ors); - for (unsigned i = 0; i < ors.size(); ++i) - ors[i] = mk_not(m, ors.get(i)); - } + flatten_or(q->get_expr(), ors); for (expr* arg : ors) { bool sign = m.is_not(arg, arg); expr* l, *r; @@ -524,7 +616,13 @@ namespace q { } cl->m_lits.push_back(lit(expr_ref(l, m), expr_ref(r, m), sign)); } + if (q->get_num_patterns() == 0) { + expr_ref tmp(m); + m_infer_patterns(q, tmp); + q = to_quantifier(tmp); + } cl->m_q = q; + SASSERT(ctx.s().value(cl->m_literal) == l_true); return cl; } @@ -591,6 +689,7 @@ namespace q { } bool ematch::operator()() { + TRACE("q", m_mam->display(tout);); if (propagate()) return true; if (m_lazy_mam) { diff --git a/src/sat/smt/q_ematch.h b/src/sat/smt/q_ematch.h index b54597553..b19362761 100644 --- a/src/sat/smt/q_ematch.h +++ b/src/sat/smt/q_ematch.h @@ -18,6 +18,7 @@ Author: #include "util/nat_set.h" #include "util/dlist.h" +#include "ast/pattern/pattern_inference.h" #include "solver/solver.h" #include "sat/smt/sat_th.h" #include "sat/smt/q_mam.h" @@ -48,7 +49,7 @@ namespace q { bool sign; lit(expr_ref const& lhs, expr_ref const& rhs, bool sign): lhs(lhs), rhs(rhs), sign(sign) {} - + std::ostream& display(std::ostream& out) const; }; struct remove_binding; @@ -67,23 +68,43 @@ namespace q { struct clause { vector m_lits; - quantifier* m_q; + quantifier_ref m_q; + sat::literal m_literal; binding* m_bindings { nullptr }; + clause(ast_manager& m): m_q(m) {} + void add_binding(ematch& em, euf::enode* const* b); std::ostream& display(euf::solver& ctx, std::ostream& out) const; lit const& operator[](unsigned i) const { return m_lits[i]; } lit& operator[](unsigned i) { return m_lits[i]; } unsigned size() const { return m_lits.size(); } - + unsigned num_decls() const { return m_q->get_num_decls(); } }; + struct justification { + expr* m_lhs, *m_rhs; + bool m_sign; + clause& m_clause; + euf::enode* const* m_binding; + justification(lit const& l, clause& c, euf::enode* const* b): + m_lhs(l.lhs), m_rhs(l.rhs), m_sign(l.sign), m_clause(c), m_binding(b) {} + sat::ext_constraint_idx to_index() const { + return sat::constraint_base::mem2base(this); + } + static justification& from_index(size_t idx) { + return *reinterpret_cast(sat::constraint_base::from_index(idx)->mem()); + } + static size_t get_obj_size() { return sat::constraint_base::obj_size(sizeof(justification)); } + }; + sat::ext_justification_idx mk_justification(unsigned idx, clause& c, euf::enode* const* b); struct pop_clause; euf::solver& ctx; solver& m_qs; ast_manager& m; + pattern_inference_rw m_infer_patterns; scoped_ptr m_mam, m_lazy_mam; ptr_vector m_clauses; obj_map m_q2clauses; @@ -110,6 +131,7 @@ namespace q { bool propagate(euf::enode* const* binding, clause& c); void instantiate(euf::enode* const* binding, clause& c); + sat::literal instantiate(clause& c, euf::enode* const* binding, lit const& l); // register as callback into egraph. void on_merge(euf::enode* root, euf::enode* other); @@ -121,7 +143,7 @@ namespace q { // extract explanation ptr_vector m_explain; - void explain(clause& c, unsigned literal_idx, binding& b); + void explain(clause& c, unsigned literal_idx, euf::enode* const* binding); void explain_eq(unsigned n, euf::enode* const* binding, expr* s, expr* t); void explain_diseq(unsigned n, euf::enode* const* binding, expr* s, expr* t); @@ -143,10 +165,13 @@ namespace q { void collect_statistics(statistics& st) const; + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing); + // callback from mam void on_binding(quantifier* q, app* pat, euf::enode* const* binding); std::ostream& display(std::ostream& out) const; + std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const; }; diff --git a/src/sat/smt/q_mam.cpp b/src/sat/smt/q_mam.cpp index 2a4dbc19e..f5030648b 100644 --- a/src/sat/smt/q_mam.cpp +++ b/src/sat/smt/q_mam.cpp @@ -1344,7 +1344,7 @@ namespace q { if (p->is_ground()) { enode * e = m_egraph.find(p); if (!e->has_lbl_hash()) - e->set_lbl_hash(m_egraph); + m_egraph.set_lbl_hash(e); return e->get_lbl_hash(); } else { @@ -1365,7 +1365,7 @@ namespace q { bool is_semi_compatible(check * instr) const { unsigned reg = instr->m_reg; if (instr->m_enode && !instr->m_enode->has_lbl_hash()) - instr->m_enode->set_lbl_hash(m_egraph); + m_egraph.set_lbl_hash(instr->m_enode); return m_registers[reg] != 0 && // if the register was already checked by another filter, then it doesn't make sense @@ -1551,7 +1551,7 @@ namespace q { // So, when the pattern (f (g b) x) is compiled a check instruction // is created for a ground subterm b of the maximal ground term (g b). if (!n1->has_lbl_hash()) - n1->set_lbl_hash(m_egraph); + m_egraph.set_lbl_hash(n1); unsigned h1 = n1->get_lbl_hash(); unsigned h2 = get_pat_lbl_hash(reg); approx_set s(h1); @@ -3103,9 +3103,9 @@ namespace q { if (t != nullptr) { TRACE("mam_candidate", tout << "adding candidate:\n" << mk_ll_pp(app->get_expr(), m);); if (!t->has_candidates()) { - m_to_match.push_back(t); ctx.push(reset_to_match(*this)); } + m_to_match.push_back(t); t->add_candidate(app); } } @@ -3392,7 +3392,7 @@ namespace q { enode * n = m_egraph.find(child); update_plbls(plbl); if (!n->has_lbl_hash()) - n->set_lbl_hash(m_egraph); + m_egraph.set_lbl_hash(n); TRACE("mam_bug", tout << "updating pc labels " << plbl->get_name() << " " << static_cast(n->get_lbl_hash()) << "\n"; @@ -3467,6 +3467,7 @@ namespace q { \brief Collect new E-matching candidates using the inverted path index t. */ void collect_parents(enode * r, path_tree * t) { + TRACE("mam", tout << ctx.bpp(r) << " " << t << "\n";); if (t == nullptr) return; #ifdef _PROFILE_PATH_TREE @@ -3817,30 +3818,11 @@ namespace q { void on_match(quantifier * qa, app * pat, unsigned num_bindings, enode * const * bindings, unsigned max_generation) override { TRACE("trigger_bug", tout << "found match " << mk_pp(qa, m) << "\n";); -#ifdef Z3DEBUG - if (m_check_missing_instances) { -#if 0 - if (!m_egraph.slow_contains_instance(qa, num_bindings, bindings)) { - TRACE("missing_instance", - tout << "qa:\n" << mk_ll_pp(qa, m) << "\npat:\n" << mk_ll_pp(pat, m); - for (unsigned i = 0; i < num_bindings; i++) - tout << "#" << bindings[i]->get_expr_id() << "\n" << mk_ll_pp(bindings[i]->get_expr(), m) << "\n"; - ); - UNREACHABLE(); - } -#endif - return; - } - for (unsigned i = 0; i < num_bindings; i++) { - SASSERT(bindings[i]->generation() <= max_generation); - } -#endif unsigned min_gen = 0, max_gen = 0; m_interpreter.get_min_max_top_generation(min_gen, max_gen); m_ematch.on_binding(qa, pat, bindings); // max_generation); // , min_gen, max_gen; } - // This method is invoked when n becomes relevant. // If lazy == true, then n is not added to the list of candidate enodes for matching. That is, the method just updates the lbls. void relevant_eh(enode * n, bool lazy) { @@ -3874,7 +3856,7 @@ namespace q { flet l1(m_other, other); flet l2(m_root, root); - TRACE("mam", tout << "add_eq_eh: #" << other->get_expr_id() << " #" << root->get_expr_id() << "\n";); + TRACE("mam", tout << "on_merge: #" << other->get_expr_id() << " #" << root->get_expr_id() << "\n";); TRACE("mam_inc_bug_detail", m_egraph.display(tout);); TRACE("mam_inc_bug", tout << "before:\n#" << other->get_expr_id() << " #" << root->get_expr_id() << "\n"; diff --git a/src/sat/smt/q_solver.cpp b/src/sat/smt/q_solver.cpp index 0e8e9dbb9..e85878fe5 100644 --- a/src/sat/smt/q_solver.cpp +++ b/src/sat/smt/q_solver.cpp @@ -76,6 +76,10 @@ namespace q { return out; } + std::ostream& solver::display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const { + return m_ematch.display_constraint(out, idx); + } + void solver::collect_statistics(statistics& st) const { st.update("quantifier asserts", m_stats.m_num_quantifier_asserts); m_mbqi.collect_statistics(st); @@ -87,7 +91,6 @@ namespace q { } bool solver::unit_propagate() { - TRACE("q", tout << "propagate\n";); return ctx.get_config().m_ematching && m_ematch.propagate(); } @@ -281,4 +284,8 @@ namespace q { return m_expanded; } + void solver::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) { + m_ematch.get_antecedents(l, idx, r, probing); + } + } diff --git a/src/sat/smt/q_solver.h b/src/sat/smt/q_solver.h index 19ad42d18..6f1077dc9 100644 --- a/src/sat/smt/q_solver.h +++ b/src/sat/smt/q_solver.h @@ -64,13 +64,13 @@ namespace q { solver(euf::solver& ctx, family_id fid); ~solver() override {} bool is_external(sat::bool_var v) override { return false; } - void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override {} + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override; void asserted(sat::literal l) override; sat::check_result check() override; std::ostream& display(std::ostream& out) const override; - std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; } - std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; } + std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { return display_constraint(out, idx); } + std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override; void collect_statistics(statistics& st) const override; euf::th_solver* clone(euf::solver& ctx) override; bool unit_propagate() override;