From 28bce8f09ccdd194c8ebcab1a128edc0e8e2d721 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 28 Dec 2021 11:00:02 -0800 Subject: [PATCH] working on relevant --- src/ast/euf/euf_egraph.cpp | 4 +- src/ast/euf/euf_egraph.h | 5 ++- src/sat/smt/euf_relevancy.cpp | 73 ++++++++++++++++------------------- src/sat/smt/euf_solver.cpp | 53 ++++++++++++------------- src/sat/smt/euf_solver.h | 8 ++-- src/sat/smt/q_ematch.cpp | 12 +++++- src/sat/smt/q_ematch.h | 3 +- src/sat/smt/q_solver.h | 1 + src/sat/smt/sat_th.h | 2 + src/sat/smt/smt_relevant.cpp | 43 ++++++++++++--------- src/sat/smt/smt_relevant.h | 15 ++++--- 11 files changed, 121 insertions(+), 98 deletions(-) diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 1f8e14565..f1e51ee1d 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -446,8 +446,8 @@ namespace euf { r2->inc_class_size(r1->class_size()); merge_th_eq(r1, r2); reinsert_parents(r1, r2); - if (m_on_merge) - m_on_merge(r2, r1); + for (auto& cb : m_on_merge) + cb(r2, r1); } void egraph::remove_parents(enode* r1, enode* r2) { diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 7c1f9e566..ef70fafd5 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -32,6 +32,7 @@ Notes: #include "ast/euf/euf_enode.h" #include "ast/euf/euf_etable.h" #include "ast/ast_ll_pp.h" +#include namespace euf { @@ -181,7 +182,7 @@ namespace euf { enode_vector m_todo; stats m_stats; bool m_uses_congruence = false; - std::function m_on_merge; + std::vector> m_on_merge; std::function m_on_make; std::function m_used_eq; std::function m_used_cc; @@ -293,7 +294,7 @@ namespace euf { void set_value(enode* n, lbool value); void set_bool_var(enode* n, unsigned v) { n->set_bool_var(v); } - void set_on_merge(std::function& on_merge) { m_on_merge = on_merge; } + void set_on_merge(std::function& on_merge) { m_on_merge.push_back(on_merge); } void set_on_make(std::function& on_make) { m_on_make = on_make; } void set_used_eq(std::function& used_eq) { m_used_eq = used_eq; } void set_used_cc(std::function& used_cc) { m_used_cc = used_cc; } diff --git a/src/sat/smt/euf_relevancy.cpp b/src/sat/smt/euf_relevancy.cpp index 64aea0c8c..a15e7bacc 100644 --- a/src/sat/smt/euf_relevancy.cpp +++ b/src/sat/smt/euf_relevancy.cpp @@ -23,10 +23,10 @@ Author: namespace euf { void solver::add_auto_relevant(sat::literal lit) { -#if NEW_RELEVANCY - m_relevancy.mark_relevant(lit); - return; -#endif + if (m_relevancy.enabled()) { + m_relevancy.mark_relevant(lit); + return; + } if (!relevancy_enabled()) return; for (; m_auto_relevant_scopes > 0; --m_auto_relevant_scopes) @@ -37,10 +37,10 @@ namespace euf { } void solver::pop_relevant(unsigned n) { -#if NEW_RELEVANCY - m_relevancy.pop(n); - return; -#endif + if (m_relevancy.enabled()) { + m_relevancy.pop(n); + return; + } if (m_auto_relevant_scopes >= n) { m_auto_relevant_scopes -= n; return; @@ -54,31 +54,28 @@ namespace euf { } void solver::push_relevant() { -#if NEW_RELEVANCY - m_relevancy.push(); - return; -#endif + if (m_relevancy.enabled()) { + m_relevancy.push(); + return; + } ++m_auto_relevant_scopes; } bool solver::is_relevant(expr* e) const { -#if NEW_RELEVANCY - return m_relevancy.is_relevant(e); -#endif + if (m_relevancy.enabled()) + return m_relevancy.is_relevant(e); return m_relevant_expr_ids.get(e->get_id(), true); } bool solver::is_relevant(enode* n) const { -#if NEW_RELEVANCY - return m_relevancy.is_relevant(n); -#endif + if (m_relevancy.enabled()) + return m_relevancy.is_relevant(n); return m_relevant_expr_ids.get(n->get_expr_id(), true); } void solver::ensure_dual_solver() { -#if NEW_RELEVANCY - return; -#endif + if (m_relevancy.enabled()) + return; if (m_dual_solver) return; m_dual_solver = alloc(sat::dual_solver, s(), s().rlimit()); @@ -93,10 +90,10 @@ namespace euf { * not tracked. */ void solver::add_root(unsigned n, sat::literal const* lits) { -#if NEW_RELEVANCY - m_relevancy.add_root(n, lits); - return; -#endif + if (m_relevancy.enabled()) { + m_relevancy.add_root(n, lits); + return; + } if (!relevancy_enabled()) return; ensure_dual_solver(); @@ -104,10 +101,10 @@ namespace euf { } void solver::add_aux(unsigned n, sat::literal const* lits) { -#if NEW_RELEVANCY - m_relevancy.add_def(n, lits); - return; -#endif + if (m_relevancy.enabled()) { + m_relevancy.add_def(n, lits); + return; + } if (!relevancy_enabled()) return; ensure_dual_solver(); @@ -115,17 +112,15 @@ namespace euf { } void solver::track_relevancy(sat::bool_var v) { -#if NEW_RELEVANCY - return; -#endif + if (m_relevancy.enabled()) + return; ensure_dual_solver(); m_dual_solver->track_relevancy(v); } bool solver::init_relevancy() { -#if NEW_RELEVANCY - return true; -#endif + if (m_relevancy.enabled()) + return true; m_relevant_expr_ids.reset(); if (!relevancy_enabled()) return true; @@ -144,19 +139,19 @@ namespace euf { } void solver::push_relevant(sat::bool_var v) { - SASSERT(!NEW_RELEVANCY); + SASSERT(!m_relevancy.enabled()); expr* e = m_bool_var2expr.get(v, nullptr); if (e) m_relevant_todo.push_back(e); } bool solver::is_propagated(sat::literal lit) { - SASSERT(!NEW_RELEVANCY); + SASSERT(!m_relevancy.enabled()); return s().value(lit) == l_true && !s().get_justification(lit.var()).is_none(); } void solver::init_relevant_expr_ids() { - SASSERT(!NEW_RELEVANCY); + SASSERT(!m_relevancy.enabled()); unsigned max_id = 0; for (enode* n : m_egraph.nodes()) max_id = std::max(max_id, n->get_expr_id()); @@ -166,7 +161,7 @@ namespace euf { } void solver::relevant_subterms() { - SASSERT(!NEW_RELEVANCY); + SASSERT(!m_relevancy.enabled()); ptr_vector& todo = m_relevant_todo; bool_vector& visited = m_relevant_visited; for (unsigned i = 0; i < todo.size(); ++i) { diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 2c602d205..ca931ce8f 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -58,6 +58,14 @@ namespace euf { display_justification_ptr(out, reinterpret_cast(j)); }; m_egraph.set_display_justification(disp); + + if (m_relevancy.enabled()) { + std::function on_merge = + [&](enode* root, enode* other) { + m_relevancy.merge(root, other); + }; + m_egraph.set_on_merge(on_merge); + } } void solver::updt_params(params_ref const& p) { @@ -244,17 +252,10 @@ namespace euf { bool solver::propagate(enode* a, enode* b, ext_justification_idx idx) { if (a->get_root() == b->get_root()) return false; - merge(a, b, to_ptr(idx)); + m_egraph.merge(a, b, to_ptr(idx)); return true; } - void solver::merge(enode* a, enode* b, void* r) { -#if NEW_RELEVANCY - m_relevancy.merge(a, b); -#endif - m_egraph.merge(a, b, r); - } - void solver::get_antecedents(literal l, constraint& j, literal_vector& r, bool probing) { expr* e = nullptr; euf::enode* n = nullptr; @@ -294,12 +295,11 @@ namespace euf { } void solver::asserted(literal l) { -#if NEW_RELEVANCY - if (!m_relevancy.is_relevant(l)) { + if (m_relevancy.enabled() && !m_relevancy.is_relevant(l)) { m_relevancy.asserted(l); return; } -#endif + expr* e = m_bool_var2expr.get(l.var(), nullptr); TRACE("euf", tout << "asserted: " << l << "@" << s().scope_lvl() << " := " << mk_bounded_pp(e, m) << "\n";); if (!e) @@ -320,14 +320,14 @@ namespace euf { euf::enode* r = n->get_root(); euf::enode* rb = sign ? mk_true() : mk_false(); sat::literal rl(r->bool_var(), r->value() == l_false); - merge(n, nb, c); - merge(r, rb, to_ptr(rl)); + m_egraph.merge(n, nb, c); + m_egraph.merge(r, rb, to_ptr(rl)); SASSERT(m_egraph.inconsistent()); return; } if (n->merge_tf()) { euf::enode* nb = sign ? mk_false() : mk_true(); - merge(n, nb, c); + m_egraph.merge(n, nb, c); } if (n->is_equality()) { SASSERT(!m.is_iff(e)); @@ -335,7 +335,7 @@ namespace euf { if (sign) m_egraph.new_diseq(n); else - merge(n->get_arg(0), n->get_arg(1), c); + m_egraph.merge(n->get_arg(0), n->get_arg(1), c); } } @@ -343,9 +343,8 @@ namespace euf { bool solver::unit_propagate() { bool propagated = false; while (!s().inconsistent()) { -#if NEW_RELEVANCY - m_relevancy.propagate(); -#endif + if (m_relevancy.enabled()) + m_relevancy.propagate(); if (m_egraph.inconsistent()) { unsigned lvl = s().scope_lvl(); s().set_conflict(sat::justification::mk_ext_justification(lvl, conflict_constraint().to_index())); @@ -362,9 +361,13 @@ namespace euf { if (m_solvers[i]->unit_propagate()) propagated1 = true; - if (!propagated1) - break; - propagated = true; + if (propagated1) { + propagated = true; + continue; + } + if (m_relevancy.enabled() && m_relevancy.can_propagate()) + continue; + break; } DEBUG_CODE(if (!propagated && !s().inconsistent()) check_missing_eq_propagation();); return propagated; @@ -439,12 +442,10 @@ namespace euf { void solver::propagate_th_eqs() { for (; m_egraph.has_th_eq() && !s().inconsistent() && !m_egraph.inconsistent(); m_egraph.next_th_eq()) { th_eq eq = m_egraph.get_th_eq(); - if (eq.is_eq()) { - if (!is_self_propagated(eq)) - m_id2solver[eq.id()]->new_eq_eh(eq); - } - else + if (!eq.is_eq()) m_id2solver[eq.id()]->new_diseq_eh(eq); + else if (!is_self_propagated(eq)) + m_id2solver[eq.id()]->new_eq_eh(eq); } } diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 1396e00fb..2145594f1 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -31,7 +31,6 @@ Author: #include "sat/smt/smt_relevant.h" #include "smt/params/smt_params.h" -#define NEW_RELEVANCY 0 namespace euf { typedef sat::literal literal; @@ -257,6 +256,10 @@ namespace euf { sat::sat_internalizer& get_si() { return si; } ast_manager& get_manager() { return m; } enode* get_enode(expr* e) const { return m_egraph.find(e); } + enode* bool_var2enode(sat::bool_var b) const { + expr* e = m_bool_var2expr.get(b); + return e ? get_enode(e) : nullptr; + } sat::literal expr2literal(expr* e) const { return enode2literal(get_enode(e)); } sat::literal enode2literal(enode* n) const { return sat::literal(n->bool_var(), false); } lbool value(enode* n) const { return s().value(enode2literal(n)); } @@ -293,7 +296,6 @@ namespace euf { void propagate(literal lit, ext_justification_idx idx); bool propagate(enode* a, enode* b, ext_justification_idx idx); - void merge(enode* a, enode* b, void* r); void set_conflict(ext_justification_idx idx); void propagate(literal lit, th_explain* p) { propagate(lit, p->to_index()); } @@ -395,7 +397,7 @@ namespace euf { void add_auto_relevant(sat::literal lit); void pop_relevant(unsigned n); void push_relevant(); - + smt::relevancy& relevancy() { return m_relevancy; } // model construction void update_model(model_ref& mdl); diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp index e4e01b964..a5f64f677 100644 --- a/src/sat/smt/q_ematch.cpp +++ b/src/sat/smt/q_ematch.cpp @@ -66,13 +66,21 @@ namespace q { }; std::function _on_make = [&](euf::enode* n) { - m_mam->add_node(n, false); + relevant_eh(n); + }; ctx.get_egraph().set_on_merge(_on_merge); - ctx.get_egraph().set_on_make(_on_make); + if (ctx.relevancy().enabled()) + ctx.get_egraph().set_on_make(_on_make); + else + ctx.relevancy().add_relevant(&s); m_mam = mam::mk(ctx, *this); } + void ematch::relevant_eh(euf::enode* n) { + m_mam->add_node(n, false); + } + void ematch::ensure_ground_enodes(expr* e) { mam::ground_subterms(e, m_ground); for (expr* g : m_ground) diff --git a/src/sat/smt/q_ematch.h b/src/sat/smt/q_ematch.h index 3cdcfc80e..0db541c1c 100644 --- a/src/sat/smt/q_ematch.h +++ b/src/sat/smt/q_ematch.h @@ -135,9 +135,10 @@ namespace q { bool unit_propagate(); - void add(quantifier* q); + void relevant_eh(euf::enode* n); + void collect_statistics(statistics& st) const; void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing); diff --git a/src/sat/smt/q_solver.h b/src/sat/smt/q_solver.h index 934864669..5d6a52c8f 100644 --- a/src/sat/smt/q_solver.h +++ b/src/sat/smt/q_solver.h @@ -83,6 +83,7 @@ namespace q { void init_search() override; void finalize_model(model& mdl) override; bool is_shared(euf::theory_var v) const override { return true; } + void relevant_eh(euf::enode* n) override { m_ematch.relevant_eh(n); } ast_manager& get_manager() { return m; } sat::literal_vector const& universal() const { return m_universal; } diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index 7e66a7156..fecb745c5 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -111,6 +111,8 @@ namespace euf { virtual void new_diseq_eh(euf::th_eq const& eq) {} + virtual void relevant_eh(euf::enode* n) {} + /** \brief Parametric theories (e.g. Arrays) should implement this method. */ diff --git a/src/sat/smt/smt_relevant.cpp b/src/sat/smt/smt_relevant.cpp index d70d37beb..9d118001c 100644 --- a/src/sat/smt/smt_relevant.cpp +++ b/src/sat/smt/smt_relevant.cpp @@ -3,7 +3,7 @@ Copyright (c) 2020 Microsoft Corporation Module Name: - relevancy.cpp + smt_relevant.cpp Abstract: @@ -22,11 +22,13 @@ Author: namespace smt { relevancy::relevancy(euf::solver& ctx): ctx(ctx) { - m_enabled = ctx.relevancy_enabled(); + m_enabled = ctx.get_config().m_relevancy_lvl > 2; } void relevancy::relevant_eh(euf::enode* n) { - // nothing + SASSERT(is_relevant(n)); + for (auto* th : m_relevant_eh) + th->relevant_eh(n); } void relevancy::relevant_eh(sat::literal lit) { @@ -102,8 +104,10 @@ namespace smt { m_clauses.push_back(cl); m_roots.push_back(true); m_trail.push_back(std::make_pair(update::add_clause, 0)); - for (sat::literal lit : *cl) + for (sat::literal lit : *cl) { + ctx.s().set_external(lit.var()); occurs(lit).push_back(sz); + } } void relevancy::add_def(unsigned n, sat::literal const* lits) { @@ -121,15 +125,17 @@ namespace smt { m_clauses.push_back(cl); m_roots.push_back(false); m_trail.push_back(std::make_pair(update::add_clause, 0)); - for (sat::literal lit : *cl) - occurs(lit).push_back(sz); + for (sat::literal lit : *cl) { + ctx.s().set_external(lit.var()); + occurs(lit).push_back(sz); + } } void relevancy::asserted(sat::literal lit) { if (!m_enabled) return; flush(); - if (ctx.s().lvl(lit) == 0) { + if (ctx.s().lvl(lit) <= ctx.s().search_lvl()) { mark_relevant(lit); return; } @@ -164,11 +170,11 @@ namespace smt { } } - void relevancy::merge(euf::enode* n1, euf::enode* n2) { - if (is_relevant(n1)) - mark_relevant(n2); - else if (is_relevant(n2)) - mark_relevant(n1); + void relevancy::merge(euf::enode* root, euf::enode* other) { + if (is_relevant(root)) + mark_relevant(other); + else if (is_relevant(other)) + mark_relevant(root); } void relevancy::mark_relevant(euf::enode* n) { @@ -177,6 +183,8 @@ namespace smt { flush(); if (is_relevant(n)) return; + if (ctx.get_si().is_bool_op(n->get_expr())) + return; for (euf::enode* sib : euf::enode_class(n)) set_relevant(sib); } @@ -195,12 +203,12 @@ namespace smt { flush(); if (is_relevant(lit)) return; + euf::enode* n = ctx.bool_var2enode(lit.var()); + if (n) + mark_relevant(n); m_relevant_var_ids.setx(lit.var(), true, false); m_trail.push_back(std::make_pair(update::relevant_var, lit.var())); m_queue.push_back(std::make_pair(lit, nullptr)); - euf::enode* n = nullptr; - if (n) - mark_relevant(n); } void relevancy::propagate_relevant(sat::literal lit) { @@ -208,9 +216,9 @@ namespace smt { for (auto idx : occurs(~lit)) { if (m_roots[idx]) continue; - sat::clause& cl = *m_clauses[idx]; + sat::clause* cl = m_clauses[idx]; sat::literal true_lit = sat::null_literal; - for (sat::literal lit2 : cl) { + for (sat::literal lit2 : *cl) { if (ctx.s().value(lit2) == l_true) { if (is_relevant(lit2)) goto next; @@ -231,7 +239,6 @@ namespace smt { void relevancy::propagate_relevant(euf::enode* n) { relevant_eh(n); - // if is_bool_op n, return; for (euf::enode* arg : euf::enode_args(n)) mark_relevant(arg); } diff --git a/src/sat/smt/smt_relevant.h b/src/sat/smt/smt_relevant.h index 6faa66b0f..a20ca1b44 100644 --- a/src/sat/smt/smt_relevant.h +++ b/src/sat/smt/smt_relevant.h @@ -3,7 +3,7 @@ Copyright (c) 2020 Microsoft Corporation Module Name: - relevancy.h + smt_relevant.h Abstract: @@ -26,7 +26,7 @@ The state transitions are: -> lit is set relevant - lit is justified at level 0 + lit is justified at search level -> lit is set relevant @@ -39,7 +39,7 @@ The state transitions are: -> all clauses C in Defs where lit appears negatively are added to Roots - - When a clause R is added to Roots: +- When a clause R is added to Roots: R contains a positive literal lit that is relevant -> skip adding R to Roots @@ -72,7 +72,7 @@ Can a literal that is not in a root be set relevant? - yes, if we propagate over expressions Do we need full watch lists instead of 2-watch lists? - - probably, but unclear. The dual SAT solver only uses 2-watch lists, but has uses a large clause for tracking + - probably, but unclear. The dual SAT solver only uses 2-watch lists, but uses a large clause for tracking roots. @@ -105,6 +105,7 @@ namespace smt { vector m_occurs; // where do literals occur unsigned m_qhead = 0; // queue head for relevancy svector> m_queue; // propagation queue for relevancy + ptr_vector m_relevant_eh; // callbacks during propagation void relevant_eh(euf::enode* n); @@ -131,6 +132,7 @@ namespace smt { void add_def(unsigned n, sat::literal const* lits); void asserted(sat::literal lit); void propagate(); + bool can_propagate() const { return m_qhead < m_queue.size(); } void mark_relevant(euf::enode* n); void mark_relevant(sat::literal lit); @@ -139,6 +141,9 @@ namespace smt { bool is_relevant(sat::literal lit) const { return !m_enabled || m_relevant_var_ids.get(lit.var(), false); } bool is_relevant(euf::enode* n) const { return !m_enabled || m_relevant_expr_ids.get(n->get_expr_id(), false); } bool is_relevant(expr* e) const { return !m_enabled || m_relevant_expr_ids.get(e->get_id(), false); } - + + bool enabled() const { return m_enabled; } + + void add_relevant(euf::th_solver* th) { m_relevant_eh.push_back(th); } }; }