From 7d915eb295318eb24aa7f221c8abb437a03a2ac0 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 19 Jul 2021 07:40:46 -0700 Subject: [PATCH] #5417 - revise q_eval based on bug based on non-chronological dependencies with post-hoc explain function --- src/ast/euf/euf_egraph.h | 20 ++--- src/ast/euf/euf_etable.h | 2 +- src/sat/smt/q_clause.h | 8 +- src/sat/smt/q_ematch.cpp | 18 +++-- src/sat/smt/q_ematch.h | 13 +-- src/sat/smt/q_eval.cpp | 168 ++++++++++++++------------------------- src/sat/smt/q_eval.h | 15 ++-- 7 files changed, 99 insertions(+), 145 deletions(-) diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 100d00460..9533a83e6 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -156,26 +156,26 @@ namespace euf { svector m_updates; unsigned_vector m_scopes; enode_vector m_expr2enode; - enode* m_tmp_eq { nullptr }; - enode* m_tmp_node { nullptr }; - unsigned m_tmp_node_capacity { 0 }; + enode* m_tmp_eq = nullptr; + enode* m_tmp_node = nullptr; + unsigned m_tmp_node_capacity = 0; enode_vector m_nodes; expr_ref_vector m_exprs; vector m_decl2enodes; enode_vector m_empty_enodes; - unsigned m_num_scopes { 0 }; - bool m_inconsistent { false }; - enode *m_n1 { nullptr }; - enode *m_n2 { nullptr }; + unsigned m_num_scopes = 0; + bool m_inconsistent = false; + enode *m_n1 = nullptr; + enode *m_n2 = nullptr; justification m_justification; - unsigned m_new_lits_qhead { 0 }; - unsigned m_new_th_eqs_qhead { 0 }; + unsigned m_new_lits_qhead = 0; + unsigned m_new_th_eqs_qhead = 0; svector m_new_lits; svector m_new_th_eqs; bool_vector m_th_propagates_diseqs; enode_vector m_todo; stats m_stats; - bool m_uses_congruence { false }; + bool m_uses_congruence = false; std::function m_on_merge; std::function m_on_make; std::function m_used_eq; diff --git a/src/ast/euf/euf_etable.h b/src/ast/euf/euf_etable.h index 68ae95cd2..d6b64e756 100644 --- a/src/ast/euf/euf_etable.h +++ b/src/ast/euf/euf_etable.h @@ -125,7 +125,7 @@ namespace euf { ast_manager & m_manager; - bool m_commutativity{ false }; //!< true if the last found congruence used commutativity + bool m_commutativity = false; //!< true if the last found congruence used commutativity ptr_vector m_tables; map m_func_decl2id; diff --git a/src/sat/smt/q_clause.h b/src/sat/smt/q_clause.h index 7278d1db1..8eca0df76 100644 --- a/src/sat/smt/q_clause.h +++ b/src/sat/smt/q_clause.h @@ -57,7 +57,7 @@ namespace q { unsigned m_index; vector m_lits; quantifier_ref m_q; - sat::literal m_literal; + sat::literal m_literal = sat::null_literal; q::quantifier_stat* m_stat = nullptr; binding* m_bindings = nullptr; @@ -75,10 +75,12 @@ namespace q { struct justification { expr* m_lhs, *m_rhs; bool m_sign; + unsigned m_num_ev; + euf::enode_pair* m_evidence; clause& m_clause; euf::enode* const* m_binding; - justification(lit const& l, clause& c, euf::enode* const* b): - m_lhs(l.lhs), m_rhs(l.rhs), m_sign(l.sign), m_clause(c), m_binding(b) {} + justification(lit const& l, clause& c, euf::enode* const* b, unsigned n, euf::enode_pair* ev): + m_lhs(l.lhs), m_rhs(l.rhs), m_sign(l.sign), m_clause(c), m_binding(b), m_num_ev(n), m_evidence(ev) {} sat::ext_constraint_idx to_index() const { return sat::constraint_base::mem2base(this); } diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp index e997476ed..7de922f40 100644 --- a/src/sat/smt/q_ematch.cpp +++ b/src/sat/smt/q_ematch.cpp @@ -94,7 +94,10 @@ namespace q { lit lit(expr_ref(l, m), expr_ref(r, m), sign); if (idx != UINT_MAX) lit = c[idx]; - auto* constraint = new (sat::constraint_base::ptr2mem(mem)) justification(lit, c, b); + auto* ev = static_cast(ctx.get_region().allocate(sizeof(euf::enode_pair) * m_evidence.size())); + for (unsigned i = m_evidence.size(); i-- > 0; ) + ev[i] = m_evidence[i]; + auto* constraint = new (sat::constraint_base::ptr2mem(mem)) justification(lit, c, b, m_evidence.size(), ev); return constraint->to_index(); } @@ -251,7 +254,8 @@ namespace q { bool ematch::propagate(bool is_owned, euf::enode* const* binding, unsigned max_generation, clause& c, bool& propagated) { TRACE("q", c.display(ctx, tout) << "\n";); unsigned idx = UINT_MAX; - lbool ev = m_eval(binding, c, idx); + m_evidence.reset(); + lbool ev = m_eval(binding, c, idx, m_evidence); if (ev == l_true) { ++m_stats.m_num_redundant; return true; @@ -267,15 +271,18 @@ namespace q { if (ev == l_undef && max_generation > m_generation_propagation_threshold) return false; if (!is_owned) - binding = alloc_binding(c, binding); - auto j_idx = mk_justification(idx, c, binding); + binding = alloc_binding(c, binding); + + auto j_idx = mk_justification(idx, c, binding); + if (ev == l_false) { ++m_stats.m_num_conflicts; ctx.set_conflict(j_idx); } else { ++m_stats.m_num_propagations; - ctx.propagate(instantiate(c, binding, c[idx]), j_idx); + auto lit = instantiate(c, binding, c[idx]); + ctx.propagate(lit, j_idx); } propagated = true; return true; @@ -295,6 +302,7 @@ namespace q { } void ematch::add_instantiation(clause& c, binding& b, sat::literal lit) { + m_evidence.reset(); ctx.propagate(lit, mk_justification(UINT_MAX, c, b.nodes())); } diff --git a/src/sat/smt/q_ematch.h b/src/sat/smt/q_ematch.h index 443b9d947..0f8add4c1 100644 --- a/src/sat/smt/q_ematch.h +++ b/src/sat/smt/q_ematch.h @@ -63,7 +63,7 @@ namespace q { quantifier_stat_gen m_qstat_gen; fingerprints m_fingerprints; scoped_ptr m_tmp_binding; - unsigned m_tmp_binding_capacity { 0 }; + unsigned m_tmp_binding_capacity = 0; queue m_inst_queue; pattern_inference_rw m_infer_patterns; scoped_ptr m_mam, m_lazy_mam; @@ -72,13 +72,14 @@ namespace q { vector m_watch; // expr_id -> clause-index* stats m_stats; expr_fast_mark1 m_mark; - unsigned m_generation_propagation_threshold{ 3 }; + unsigned m_generation_propagation_threshold = 3; ptr_vector m_ground; - bool m_in_queue_set{ false }; + bool m_in_queue_set = false; nat_set m_node_in_queue; nat_set m_clause_in_queue; - unsigned m_qhead { 0 }; + unsigned m_qhead = 0; unsigned_vector m_clause_queue; + euf::enode_pair_vector m_evidence; binding* alloc_binding(unsigned n, app* pat, unsigned max_generation, unsigned min_top, unsigned max_top); euf::enode* const* alloc_binding(clause& c, euf::enode* const* _binding); @@ -115,7 +116,7 @@ namespace q { bool propagate(bool flush); - void init_search(); + // void init_search(); void add(quantifier* q); @@ -127,7 +128,7 @@ namespace q { void on_binding(quantifier* q, app* pat, euf::enode* const* binding, unsigned max_generation, unsigned min_gen, unsigned max_gen); // callbacks from queue - lbool evaluate(euf::enode* const* binding, clause& c) { return m_eval(binding, c); } + lbool evaluate(euf::enode* const* binding, clause& c) { m_evidence.reset(); return m_eval(binding, c, m_evidence); } void add_instantiation(clause& c, binding& b, sat::literal lit); diff --git a/src/sat/smt/q_eval.cpp b/src/sat/smt/q_eval.cpp index 5e5cc83b8..9dcf5755f 100644 --- a/src/sat/smt/q_eval.cpp +++ b/src/sat/smt/q_eval.cpp @@ -32,7 +32,7 @@ namespace q { m(ctx.get_manager()) {} - lbool eval::operator()(euf::enode* const* binding, clause& c, unsigned& idx) { + lbool eval::operator()(euf::enode* const* binding, clause& c, unsigned& idx, euf::enode_pair_vector& evidence) { scoped_mark_reset _sr(*this); idx = UINT_MAX; unsigned sz = c.m_lits.size(); @@ -41,7 +41,7 @@ namespace q { for (unsigned i = 0; i < sz; ++i) { unsigned lim = m_indirect_nodes.size(); lit l = c[i]; - lbool cmp = compare(n, binding, l.lhs, l.rhs); + lbool cmp = compare(n, binding, l.lhs, l.rhs, evidence); switch (cmp) { case l_false: m_indirect_nodes.shrink(lim); @@ -75,46 +75,55 @@ namespace q { return l_undef; } - lbool eval::operator()(euf::enode* const* binding, clause& c) { + lbool eval::operator()(euf::enode* const* binding, clause& c, euf::enode_pair_vector& evidence) { unsigned idx = 0; - return (*this)(binding, c, idx); + return (*this)(binding, c, idx, evidence); } - lbool eval::compare(unsigned n, euf::enode* const* binding, expr* s, expr* t) { + lbool eval::compare(unsigned n, euf::enode* const* binding, expr* s, expr* t, euf::enode_pair_vector& evidence) { if (s == t) return l_true; if (m.are_distinct(s, t)) return l_false; - euf::enode* sn = (*this)(n, binding, s); - euf::enode* tn = (*this)(n, binding, t); - if (sn) sn = sn->get_root(); - if (tn) tn = tn->get_root(); + euf::enode* sn = (*this)(n, binding, s, evidence); + euf::enode* tn = (*this)(n, binding, t, evidence); + euf::enode* sr = sn ? sn->get_root() : sn; + euf::enode* tr = tn ? tn->get_root() : tn; + if (sn != sr) evidence.push_back(euf::enode_pair(sn, sr)), sn = sr; + if (tn != tr) evidence.push_back(euf::enode_pair(tn, tr)), tn = tr; TRACE("q", tout << mk_pp(s, m) << " ~~ " << mk_pp(t, m) << "\n"; tout << ctx.bpp(sn) << " " << ctx.bpp(tn) << "\n";); lbool c; if (sn && sn == tn) return l_true; - if (sn && tn && ctx.get_egraph().are_diseq(sn, tn)) + if (sn && tn && ctx.get_egraph().are_diseq(sn, tn)) { + evidence.push_back(euf::enode_pair(sn, tn)); return l_false; + } if (sn && tn) return l_undef; if (!sn && !tn) - return compare_rec(n, binding, s, t); + return compare_rec(n, binding, s, t, evidence); if (!tn && sn) { std::swap(tn, sn); std::swap(t, s); } - for (euf::enode* t1 : euf::enode_class(tn)) - if (c = compare_rec(n, binding, s, t1->get_expr()), c != l_undef) + unsigned sz = evidence.size(); + for (euf::enode* t1 : euf::enode_class(tn)) { + if (c = compare_rec(n, binding, s, t1->get_expr(), evidence), c != l_undef) { + evidence.push_back(euf::enode_pair(t1, tn)); return c; + } + evidence.shrink(sz); + } return l_undef; } // f(p1) = f(p2) if p1 = p2 // f(p1) != f(p2) if p1 != p2 and f is injective - lbool eval::compare_rec(unsigned n, euf::enode* const* binding, expr* s, expr* t) { + lbool eval::compare_rec(unsigned n, euf::enode* const* binding, expr* s, expr* t, euf::enode_pair_vector& evidence) { if (m.are_equal(s, t)) return l_true; if (m.are_distinct(s, t)) @@ -127,14 +136,20 @@ namespace q { return l_undef; bool is_injective = to_app(s)->get_decl()->is_injective(); bool has_undef = false; + unsigned sz = evidence.size(); for (unsigned i = to_app(s)->get_num_args(); i-- > 0; ) { - switch (compare(n, binding, to_app(s)->get_arg(i), to_app(t)->get_arg(i))) { + unsigned sz1 = evidence.size(), sz2; + switch (compare(n, binding, to_app(s)->get_arg(i), to_app(t)->get_arg(i), evidence)) { case l_true: break; - case l_false: - if (is_injective) - return l_false; - return l_undef; + case l_false: + if (!is_injective) + return l_undef; + sz2 = evidence.size(); + for (unsigned i = 0; i < sz2 - sz1; ++i) + evidence[sz + i] = evidence[sz1 + i]; + evidence.shrink(sz + sz2 - sz1); + return l_false; case l_undef: if (!is_injective) return l_undef; @@ -142,10 +157,15 @@ namespace q { break; } } - return has_undef ? l_undef : l_true; + + if (!has_undef) + return l_true; + + evidence.shrink(sz); + return l_undef; } - euf::enode* eval::operator()(unsigned n, euf::enode* const* binding, expr* e) { + euf::enode* eval::operator()(unsigned n, euf::enode* const* binding, expr* e, euf::enode_pair_vector& evidence) { if (is_ground(e)) return ctx.get_egraph().find(e); if (m_mark.is_marked(e)) @@ -186,6 +206,15 @@ namespace q { euf::enode* n = ctx.get_egraph().find(t, args.size(), args.data()); if (!n) return nullptr; + for (unsigned i = args.size(); i-- > 0; ) { + if (args[i] != n->get_arg(i)) { + // roots could be different when using commutativity + // instead of compensating for this, we just bail out + if (args[i]->get_root() != n->get_arg(i)->get_root()) + return nullptr; + evidence.push_back(euf::enode_pair(args[i], n->get_arg(i))); + } + } m_indirect_nodes.push_back(n); m_eval.setx(t->get_id(), n, nullptr); m_mark.mark(t); @@ -195,99 +224,18 @@ namespace q { return m_eval[e->get_id()]; } - void eval::explain(clause& c, unsigned literal_idx, euf::enode* const* b) { - unsigned n = c.num_decls(); - for (unsigned i = c.size(); i-- > 0; ) { - if (i == literal_idx) - continue; - auto const& lit = c[i]; - if (lit.sign) - explain_eq(n, b, lit.lhs, lit.rhs); - else - explain_diseq(n, b, lit.lhs, lit.rhs); - } - } - - void eval::explain_eq(unsigned n, euf::enode* const* binding, expr* s, expr* t) { - SASSERT(l_true == compare(n, binding, s, t)); - if (s == t) - return; - euf::enode* sn = (*this)(n, binding, s); - euf::enode* tn = (*this)(n, binding, t); - if (sn && tn) { - SASSERT(sn->get_root() == tn->get_root()); - ctx.add_antecedent(sn, tn); - return; - } - if (!sn && tn) { - std::swap(sn, tn); - std::swap(s, t); - } - if (sn && !tn) { - for (euf::enode* s1 : euf::enode_class(sn)) { - if (l_true == compare_rec(n, binding, t, s1->get_expr())) { - ctx.add_antecedent(sn, s1); - explain_eq(n, binding, t, s1->get_expr()); - return; - } - } - UNREACHABLE(); - } - SASSERT(is_app(s) && is_app(t)); - SASSERT(to_app(s)->get_decl() == to_app(t)->get_decl()); - for (unsigned i = to_app(s)->get_num_args(); i-- > 0; ) - explain_eq(n, binding, to_app(s)->get_arg(i), to_app(t)->get_arg(i)); - } - - void eval::explain_diseq(unsigned n, euf::enode* const* binding, expr* s, expr* t) { - SASSERT(l_false == compare(n, binding, s, t)); - if (m.are_distinct(s, t)) - return; - euf::enode* sn = (*this)(n, binding, s); - euf::enode* tn = (*this)(n, binding, t); - if (sn && tn && ctx.get_egraph().are_diseq(sn, tn)) { - ctx.add_diseq_antecedent(sn, tn); - return; - } - if (!sn && tn) { - std::swap(sn, tn); - std::swap(s, t); - } - if (sn && !tn) { - for (euf::enode* s1 : euf::enode_class(sn)) { - if (l_false == compare_rec(n, binding, t, s1->get_expr())) { - ctx.add_antecedent(sn, s1); - explain_diseq(n, binding, t, s1->get_expr()); - return; - } - } - UNREACHABLE(); - } - SASSERT(is_app(s) && is_app(t)); - app* at = to_app(t); - app* as = to_app(s); - SASSERT(as->get_decl() == at->get_decl()); - for (unsigned i = as->get_num_args(); i-- > 0; ) { - if (l_false == compare_rec(n, binding, as->get_arg(i), at->get_arg(i))) { - explain_eq(n, binding, as->get_arg(i), at->get_arg(i)); - return; - } - } - UNREACHABLE(); - } - - void eval::explain(sat::literal l, justification& j, sat::literal_vector& r, bool probing) { - scoped_mark_reset _sr(*this); - unsigned l_idx = 0; clause& c = j.m_clause; - for (; l_idx < c.size(); ++l_idx) { - if (c[l_idx].lhs == j.m_lhs && c[l_idx].rhs == j.m_rhs && c[l_idx].sign == j.m_sign) - break; + for (unsigned i = 0; i < j.m_num_ev; ++i) { + auto [a, b] = j.m_evidence[i]; + SASSERT(a->get_root() == b->get_root() || ctx.get_egraph().are_diseq(a, b)); + if (a->get_root() == b->get_root()) + ctx.add_antecedent(a, b); + else + ctx.add_diseq_antecedent(a, b); } - explain(c, l_idx, j.m_binding); r.push_back(c.m_literal); - (void)probing; // ignored + (void)probing; // ignored } diff --git a/src/sat/smt/q_eval.h b/src/sat/smt/q_eval.h index 6219c473a..5e520dc17 100644 --- a/src/sat/smt/q_eval.h +++ b/src/sat/smt/q_eval.h @@ -30,25 +30,20 @@ namespace q { expr_fast_mark1 m_mark; euf::enode_vector m_eval; euf::enode_vector m_indirect_nodes; - ptr_vector m_explain; struct scoped_mark_reset; - void explain(clause& c, unsigned literal_idx, euf::enode* const* binding); - void explain_eq(unsigned n, euf::enode* const* binding, expr* s, expr* t); - void explain_diseq(unsigned n, euf::enode* const* binding, expr* s, expr* t); - // compare s, t modulo binding - lbool compare(unsigned n, euf::enode* const* binding, expr* s, expr* t); - lbool compare_rec(unsigned n, euf::enode* const* binding, expr* s, expr* t); + lbool compare(unsigned n, euf::enode* const* binding, expr* s, expr* t, euf::enode_pair_vector& evidence); + lbool compare_rec(unsigned n, euf::enode* const* binding, expr* s, expr* t, euf::enode_pair_vector& evidence); public: eval(euf::solver& ctx); void explain(sat::literal l, justification& j, sat::literal_vector& r, bool probing); - lbool operator()(euf::enode* const* binding, clause& c); - lbool operator()(euf::enode* const* binding, clause& c, unsigned& idx); - euf::enode* operator()(unsigned n, euf::enode* const* binding, expr* e); + lbool operator()(euf::enode* const* binding, clause& c, euf::enode_pair_vector& evidence); + lbool operator()(euf::enode* const* binding, clause& c, unsigned& idx, euf::enode_pair_vector& evidence); + euf::enode* operator()(unsigned n, euf::enode* const* binding, expr* e, euf::enode_pair_vector& evidence); euf::enode_vector const& get_watch() { return m_indirect_nodes; } };