From da0aa710827912f1a8e6579285ed18283174ea1a Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 29 Nov 2017 21:21:56 -0800 Subject: [PATCH] adding uhle/uhte for faster asymmetric branching Signed-off-by: Nikolaj Bjorner --- src/sat/sat_asymm_branch.cpp | 124 ++++++++++++++++++++++++++++++++--- src/sat/sat_asymm_branch.h | 29 ++++++-- src/sat/sat_scc.cpp | 99 ++++++++++++++++------------ src/sat/sat_scc.h | 26 +++++++- 4 files changed, 221 insertions(+), 57 deletions(-) diff --git a/src/sat/sat_asymm_branch.cpp b/src/sat/sat_asymm_branch.cpp index aed29312a..f88f4c5a7 100644 --- a/src/sat/sat_asymm_branch.cpp +++ b/src/sat/sat_asymm_branch.cpp @@ -19,6 +19,7 @@ Revision History: #include "sat/sat_asymm_branch.h" #include "sat/sat_asymm_branch_params.hpp" #include "sat/sat_solver.h" +#include "sat/sat_scc.h" #include "util/stopwatch.h" #include "util/trace.h" @@ -26,6 +27,7 @@ namespace sat { asymm_branch::asymm_branch(solver & _s, params_ref const & p): s(_s), + m_params(p), m_counter(0) { updt_params(p); reset_statistics(); @@ -59,12 +61,12 @@ namespace sat { void asymm_branch::process(clause_vector& clauses) { int64 limit = -m_asymm_branch_limit; - std::stable_sort(s.m_clauses.begin(), s.m_clauses.end(), clause_size_lt()); - m_counter -= s.m_clauses.size(); + std::stable_sort(clauses.begin(), clauses.end(), clause_size_lt()); + m_counter -= clauses.size(); SASSERT(s.m_qhead == s.m_trail.size()); - clause_vector::iterator it = s.m_clauses.begin(); + clause_vector::iterator it = clauses.begin(); clause_vector::iterator it2 = it; - clause_vector::iterator end = s.m_clauses.end(); + clause_vector::iterator end = clauses.end(); try { for (; it != end; ++it) { if (s.inconsistent()) { @@ -86,14 +88,14 @@ namespace sat { *it2 = *it; ++it2; } - s.m_clauses.set_end(it2); + clauses.set_end(it2); } catch (solver_exception & ex) { // put m_clauses in a consistent state... for (; it != end; ++it, ++it2) { *it2 = *it; } - s.m_clauses.set_end(it2); + clauses.set_end(it2); m_counter = -m_counter; throw ex; } @@ -143,6 +145,99 @@ namespace sat { return true; } + void asymm_branch::setup_big() { + scc scc(s, m_params); + vector const& big = scc.get_big(true); // include learned binary clauses + + } + + struct asymm_branch::compare_left { + scc& s; + compare_left(scc& s): s(s) {} + bool operator()(literal u, literal v) const { + return s.get_left(u) < s.get_left(v); + } + }; + + void asymm_branch::sort(scc& scc, clause const& c) { + m_pos.reset(); m_neg.reset(); + for (literal l : c) { + m_pos.push_back(l); + m_neg.push_back(~l); + } + compare_left cmp(scc); + std::sort(m_pos.begin(), m_pos.end(), cmp); + std::sort(m_neg.begin(), m_neg.end(), cmp); + } + + bool asymm_branch::uhte(scc& scc, clause & c) { + unsigned pindex = 0, nindex = 0; + literal lpos = m_pos[pindex++]; + literal lneg = m_neg[nindex++]; + while (true) { + if (scc.get_left(lneg) > scc.get_left(lpos)) { + if (pindex == m_pos.size()) return false; + lpos = m_pos[pindex++]; + } + else if (scc.get_right(lneg) < scc.get_right(lpos) || + (m_pos.size() == 2 && (lpos == ~lneg || scc.get_parent(lpos) == lneg))) { + if (nindex == m_neg.size()) return false; + lneg = m_neg[nindex++]; + } + else { + return true; + } + } + return false; + } + + bool asymm_branch::uhle(scoped_detach& scoped_d, scc& scc, clause & c) { + int right = scc.get_right(m_pos.back()); + m_to_delete.reset(); + for (unsigned i = m_pos.size() - 1; i-- > 0; ) { + literal lit = m_pos[i]; + SASSERT(scc.get_left(lit) < scc.get_left(last)); + int right2 = scc.get_right(lit); + if (right2 > right) { + // lit => last, so lit can be deleted + m_to_delete.push_back(lit); + } + else { + right = right2; + } + } + right = scc.get_right(m_neg[0]); + for (unsigned i = 1; i < m_neg.size(); ++i) { + literal lit = m_neg[i]; + int right2 = scc.get_right(lit); + if (right > right2) { + // ~first => ~lit + m_to_delete.push_back(~lit); + } + else { + right = right2; + } + } + if (!m_to_delete.empty()) { + unsigned j = 0; + for (unsigned i = 0; i < c.size(); ++i) { + if (!m_to_delete.contains(c[i])) { + c[j] = c[i]; + ++j; + } + else { + m_pos.erase(c[i]); + m_neg.erase(~c[i]); + } + } + return re_attach(scoped_d, c, j); + } + else { + return true; + } + } + + bool asymm_branch::propagate_literal(clause const& c, literal l) { SASSERT(!s.inconsistent()); TRACE("asymm_branch_detail", tout << "assigning: " << l << "\n";); @@ -190,8 +285,12 @@ namespace sat { new_sz = j; m_elim_literals += c.size() - new_sz; // std::cout << "cleanup: " << c.id() << ": " << literal_vector(new_sz, c.begin()) << " delta: " << (c.size() - new_sz) << " " << skip_idx << " " << new_sz << "\n"; - switch(new_sz) { - case 0: + return re_attach(scoped_d, c, new_sz); + } + + bool asymm_branch::re_attach(scoped_detach& scoped_d, clause& c, unsigned new_sz) { + switch(new_sz) { + case 0: s.set_conflict(justification()); return false; case 1: @@ -216,6 +315,15 @@ namespace sat { } } + bool asymm_branch::process2(scc& scc, clause & c) { + scoped_detach scoped_d(s, c); + if (uhte(scc, c)) { + scoped_d.del_clause(); + return false; + } + return uhle(scoped_d, scc, c); + } + bool asymm_branch::process(clause & c) { if (c.is_blocked()) return true; TRACE("asymm_branch_detail", tout << "processing: " << c << "\n";); diff --git a/src/sat/sat_asymm_branch.h b/src/sat/sat_asymm_branch.h index 0de4c8d99..87e5cbac5 100644 --- a/src/sat/sat_asymm_branch.h +++ b/src/sat/sat_asymm_branch.h @@ -20,6 +20,7 @@ Revision History: #define SAT_ASYMM_BRANCH_H_ #include "sat/sat_types.h" +#include "sat/sat_scc.h" #include "util/statistics.h" #include "util/params.h" @@ -30,21 +31,37 @@ namespace sat { class asymm_branch { struct report; - solver & s; + solver & s; + params_ref m_params; int64 m_counter; random_gen m_rand; unsigned m_calls; // config - bool m_asymm_branch; - bool m_asymm_branch_all; - int64 m_asymm_branch_limit; + bool m_asymm_branch; + bool m_asymm_branch_all; + int64 m_asymm_branch_limit; // stats - unsigned m_elim_literals; + unsigned m_elim_literals; + + literal_vector m_pos, m_neg; // literals (complements of literals) in clauses sorted by discovery time (m_left in scc). + literal_vector m_to_delete; + + struct compare_left; + + void sort(scc & scc, clause const& c); + + bool uhle(scoped_detach& scoped_d, scc & scc, clause & c); + + bool uhte(scc & scc, clause & c); + + bool re_attach(scoped_detach& scoped_d, clause& c, unsigned new_sz); bool process(clause & c); + bool process2(scc& scc, clause & c); + void process(clause_vector & c); bool process_all(clause & c); @@ -55,6 +72,8 @@ namespace sat { bool propagate_literal(clause const& c, literal l); + void setup_big(); + public: asymm_branch(solver & s, params_ref const & p); diff --git a/src/sat/sat_scc.cpp b/src/sat/sat_scc.cpp index df574f4a8..9da54da29 100644 --- a/src/sat/sat_scc.cpp +++ b/src/sat/sat_scc.cpp @@ -234,58 +234,70 @@ namespace sat { return to_elim.size(); } - void scc::get_dfs_num(svector& dfs, bool learned) { - unsigned num_lits = m_solver.num_vars() * 2; - vector dag(num_lits); - svector roots(num_lits, true); - literal_vector todo; - SASSERT(dfs.size() == num_lits); - unsigned num_edges = 0; + // shuffle vertices to obtain different DAG traversal each time + void scc::shuffle(literal_vector& lits) { + unsigned sz = lits.size(); + if (sz > 1) { + for (unsigned i = sz; i-- > 0; ) { + std::swap(lits[i], lits[m_rand(i+1)]); + } + } + } - // retrieve DAG + vector const& scc::get_big(bool learned) { + unsigned num_lits = m_solver.num_vars() * 2; + m_dag.reset(); + m_roots.reset(); + m_dag.resize(num_lits, 0); + m_roots.resize(num_lits, true); + SASSERT(num_lits == m_dag.size() && num_lits == m_roots.size()); for (unsigned l_idx = 0; l_idx < num_lits; l_idx++) { - literal u(to_literal(l_idx)); + literal u = to_literal(l_idx); if (m_solver.was_eliminated(u.var())) continue; - auto& edges = dag[u.index()]; + auto& edges = m_dag[l_idx]; for (watched const& w : m_solver.m_watches[l_idx]) { if (learned ? w.is_binary_clause() : w.is_binary_unblocked_clause()) { literal v = w.get_literal(); - roots[v.index()] = false; + m_roots[v.index()] = false; edges.push_back(v); - ++num_edges; - } - } - unsigned sz = edges.size(); - // shuffle vertices to obtain different DAG traversal each time - if (sz > 1) { - for (unsigned i = sz; i-- > 0; ) { - std::swap(edges[i], edges[m_rand(i+1)]); } } + shuffle(edges); } - // std::cout << "dag num edges: " << num_edges << "\n"; + return m_dag; + } + + void scc::get_dfs_num(bool learned) { + unsigned num_lits = m_solver.num_vars() * 2; + SASSERT(m_left.size() == num_lits); + SASSERT(m_right.size() == num_lits); + literal_vector todo; // retrieve literals that have no predecessors for (unsigned l_idx = 0; l_idx < num_lits; l_idx++) { literal u(to_literal(l_idx)); - if (roots[u.index()]) todo.push_back(u); + if (m_roots[u.index()]) todo.push_back(u); } - // std::cout << "num roots: " << roots.size() << "\n"; - // traverse DAG, annotate nodes with DFS number + shuffle(todo); int dfs_num = 0; while (!todo.empty()) { literal u = todo.back(); - int& d = dfs[u.index()]; + int& d = m_left[u.index()]; // already visited if (d > 0) { + if (m_right[u.index()] < 0) { + m_right[u.index()] = dfs_num; + } todo.pop_back(); } // visited as child: else if (d < 0) { d = -d; - for (literal v : dag[u.index()]) { - if (dfs[v.index()] == 0) { - dfs[v.index()] = - d - 1; + for (literal v : m_dag[u.index()]) { + if (m_left[v.index()] == 0) { + m_left[v.index()] = - d - 1; + m_root[v.index()] = m_root[u.index()]; + m_parent[v.index()] = u; todo.push_back(v); } } @@ -297,9 +309,21 @@ namespace sat { } } - bool scc::reduce_tr(svector const& dfs, bool learned) { + unsigned scc::reduce_tr(bool learned) { + unsigned num_lits = m_solver.num_vars() * 2; + m_left.reset(); + m_right.reset(); + m_root.reset(); + m_parent.reset(); + m_left.resize(num_lits, 0); + m_right.resize(num_lits, -1); + for (unsigned i = 0; i < num_lits; ++i) { + m_root[i] = to_literal(i); + m_parent[i] = to_literal(i); + } + get_dfs_num(learned); unsigned idx = 0; - bool reduced = false; + unsigned elim = m_num_elim_bin; for (watch_list & wlist : m_solver.m_watches) { literal u = to_literal(idx++); watch_list::iterator it = wlist.begin(); @@ -309,9 +333,8 @@ namespace sat { watched& w = *it; if (learned ? w.is_binary_learned_clause() : w.is_binary_unblocked_clause()) { literal v = w.get_literal(); - if (dfs[u.index()] + 1 < dfs[v.index()]) { + if (m_left[u.index()] + 1 < m_left[v.index()]) { ++m_num_elim_bin; - reduced = true; } else { *itprev = *it; @@ -325,19 +348,13 @@ namespace sat { } wlist.set_end(itprev); } - return reduced; - } - - bool scc::reduce_tr(bool learned) { - unsigned num_lits = m_solver.num_vars() * 2; - svector dfs(num_lits); - get_dfs_num(dfs, learned); - return reduce_tr(dfs, learned); + return m_num_elim_bin - elim; } void scc::reduce_tr() { - while (reduce_tr(false)) {} - while (reduce_tr(true)) {} + unsigned quota = 0, num_reduced = 0; + while ((num_reduced = reduce_tr(false)) > quota) { quota = std::max(100u, num_reduced / 2); } + while ((num_reduced = reduce_tr(true)) > quota) { quota = std::max(100u, num_reduced / 2); } } void scc::collect_statistics(statistics & st) const { diff --git a/src/sat/sat_scc.h b/src/sat/sat_scc.h index cfad76d71..806cf8f33 100644 --- a/src/sat/sat_scc.h +++ b/src/sat/sat_scc.h @@ -37,12 +37,19 @@ namespace sat { unsigned m_num_elim_bin; random_gen m_rand; - void get_dfs_num(svector& dfs, bool learned); + // BIG state: + + vector m_dag; + svector m_roots; + svector m_left, m_right; + literal_vector m_root, m_parent; + + void shuffle(literal_vector& lits); void reduce_tr(); - bool reduce_tr(bool learned); - bool reduce_tr(svector const& dfs, bool learned); + unsigned reduce_tr(bool learned); public: + scc(solver & s, params_ref const & p); unsigned operator()(); @@ -51,6 +58,19 @@ namespace sat { void collect_statistics(statistics & st) const; void reset_statistics(); + + /* + \brief retrieve binary implication graph + */ + vector const& get_big(bool learned); + + int get_left(literal l) const { return m_left[l.index()]; } + int get_right(literal l) const { return m_right[l.index()]; } + literal get_parent(literal l) const { return m_parent[l.index()]; } + literal get_root(literal l) const { return m_root[l.index()]; } + + void get_dfs_num(bool learned); + }; };