From 62e39069576fa199255aafa51c8c98ecf5d8ec07 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 27 Nov 2017 10:53:22 -0800 Subject: [PATCH] add options to perform transitive reduction and add hyper binary clauses Signed-off-by: Nikolaj Bjorner --- src/sat/sat_config.cpp | 2 + src/sat/sat_config.h | 1 + src/sat/sat_elim_eqs.cpp | 8 +-- src/sat/sat_lookahead.cpp | 64 ++++++++++++++----- src/sat/sat_lookahead.h | 3 + src/sat/sat_params.pyg | 1 + src/sat/sat_scc.cpp | 123 ++++++++++++++++++++++++++++++++++++- src/sat/sat_scc.h | 9 +++ src/sat/sat_scc_params.pyg | 3 +- 9 files changed, 189 insertions(+), 25 deletions(-) diff --git a/src/sat/sat_config.cpp b/src/sat/sat_config.cpp index e062841b7..b672fe065 100644 --- a/src/sat/sat_config.cpp +++ b/src/sat/sat_config.cpp @@ -41,6 +41,7 @@ namespace sat { m_local_search = 0; m_lookahead_search = false; m_lookahead_simplify = false; + m_lookahead_simplify_bca = false; m_elim_vars = false; m_incremental = false; updt_params(p); @@ -90,6 +91,7 @@ namespace sat { m_local_search = p.local_search(); m_local_search_threads = p.local_search_threads(); m_lookahead_simplify = p.lookahead_simplify(); + m_lookahead_simplify_bca = p.lookahead_simplify_bca(); m_lookahead_search = p.lookahead_search(); if (p.lookahead_reward() == symbol("heule_schur")) { m_lookahead_reward = heule_schur_reward; diff --git a/src/sat/sat_config.h b/src/sat/sat_config.h index 4fc1f4e7e..132b85b24 100644 --- a/src/sat/sat_config.h +++ b/src/sat/sat_config.h @@ -84,6 +84,7 @@ namespace sat { bool m_local_search; bool m_lookahead_search; bool m_lookahead_simplify; + bool m_lookahead_simplify_bca; unsigned m_lookahead_cube_cutoff; double m_lookahead_cube_fraction; reward_t m_lookahead_reward; diff --git a/src/sat/sat_elim_eqs.cpp b/src/sat/sat_elim_eqs.cpp index 49b371104..cb659247f 100644 --- a/src/sat/sat_elim_eqs.cpp +++ b/src/sat/sat_elim_eqs.cpp @@ -34,11 +34,9 @@ namespace sat { } void elim_eqs::cleanup_bin_watches(literal_vector const & roots) { - vector::iterator it = m_solver.m_watches.begin(); - vector::iterator end = m_solver.m_watches.end(); - for (unsigned l_idx = 0; it != end; ++it, ++l_idx) { - watch_list & wlist = *it; - literal l1 = ~to_literal(l_idx); + unsigned l_idx = 0; + for (watch_list & wlist : m_solver.m_watches) { + literal l1 = ~to_literal(l_idx++); literal r1 = norm(roots, l1); watch_list::iterator it2 = wlist.begin(); watch_list::iterator itprev = it2; diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp index bc5f9b071..b50d7aa2f 100644 --- a/src/sat/sat_lookahead.cpp +++ b/src/sat/sat_lookahead.cpp @@ -17,9 +17,10 @@ Author: Notes: --*/ -#include "sat_solver.h" -#include "sat_extension.h" -#include "sat_lookahead.h" +#include "sat/sat_solver.h" +#include "sat/sat_extension.h" +#include "sat/sat_lookahead.h" +#include "util/union_find.h" namespace sat { lookahead::scoped_ext::scoped_ext(lookahead& p): p(p) { @@ -648,7 +649,6 @@ namespace sat { TRACE("sat", display_scc(tout);); } void lookahead::init_scc() { - std::cerr << "init-scc\n"; inc_bstamp(); for (unsigned i = 0; i < m_candidates.size(); ++i) { literal lit(m_candidates[i].m_var, false); @@ -2290,20 +2290,51 @@ namespace sat { elim_eqs elim(m_s); elim(roots, to_elim); -#if 0 - // TBD: - // Finally create a new graph between parents - // based on the arcs in the the m_dfs[index].m_next structure - // Visit all nodes, assign DFS numbers - // Then prune binary clauses that differ in DFS number more than 1 - // Add binary clauses that have DFS number 1, but no companion binary clause. - // -#endif - + if (get_config().m_lookahead_simplify_bca) { + add_hyper_binary(); + } } } - m_lookahead.reset(); - + m_lookahead.reset(); + } + + /** + \brief reduction based on binary implication graph + */ + + void lookahead::add_hyper_binary() { + unsigned num_lits = m_s.num_vars() * 2; + union_find_default_ctx ufctx; + union_find uf(ufctx); + for (unsigned i = 0; i < num_lits; ++i) uf.mk_var(); + for (unsigned idx = 0; idx < num_lits; ++idx) { + literal u = get_parent(to_literal(idx)); + if (null_literal != u) { + for (watched const& w : m_s.m_watches[idx]) { + if (!w.is_binary_clause()) continue; + literal v = get_parent(w.get_literal()); + if (null_literal != v) { + uf.merge(u.index(), v.index()); + } + } + } + } + + unsigned disconnected = 0; + for (unsigned i = 0; i < m_binary.size(); ++i) { + literal u = to_literal(i); + if (u == get_parent(u)) { + for (literal v : m_binary[i]) { + if (v == get_parent(v) && uf.find(u.index()) != uf.find(v.index())) { + ++disconnected; + uf.merge(u.index(), v.index()); + m_s.mk_clause(~u, v, true); + } + } + } + } + IF_VERBOSE(10, verbose_stream() << "(sat-lookahead :bca " << disconnected << ")\n";); + m_stats.m_bca += disconnected; } void lookahead::normalize_parents() { @@ -2378,6 +2409,7 @@ namespace sat { void lookahead::collect_statistics(statistics& st) const { st.update("lh bool var", m_vprefix.size()); // TBD: keep count of ternary and >3-ary clauses. + st.update("lh bca", m_stats.m_bca); st.update("lh add binary", m_stats.m_add_binary); st.update("lh del binary", m_stats.m_del_binary); st.update("lh propagations", m_stats.m_propagations); diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index 99d5511ea..6f0c256a7 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -115,6 +115,7 @@ namespace sat { struct stats { unsigned m_propagations; + unsigned m_bca; unsigned m_add_binary; unsigned m_del_binary; unsigned m_decisions; @@ -533,6 +534,8 @@ namespace sat { void normalize_parents(); + void add_hyper_binary(); + public: lookahead(solver& s) : m_s(s), diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index bae1d0c11..887cea2a7 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -41,6 +41,7 @@ def_module_params('sat', ('lookahead_search', BOOL, False, 'use lookahead solver'), ('lookahead.preselect', BOOL, False, 'use pre-selection of subset of variables for branching'), ('lookahead_simplify', BOOL, False, 'use lookahead solver during simplification'), + ('lookahead_simplify.bca', BOOL, False, 'add learned binary clauses as part of lookahead simplification'), ('lookahead.global_autarky', BOOL, False, 'prefer to branch on variables that occur in clauses that are reduced'), ('lookahead.reward', SYMBOL, 'march_cu', 'select lookahead heuristic: ternary, heule_schur (Heule Schur), heuleu (Heule Unit), unit, or march_cu'), ('dimacs.inprocess.display', BOOL, False, 'display SAT instance in DIMACS format if unsolved after inprocess.max inprocessing passes'))) diff --git a/src/sat/sat_scc.cpp b/src/sat/sat_scc.cpp index 544ef1f66..df574f4a8 100644 --- a/src/sat/sat_scc.cpp +++ b/src/sat/sat_scc.cpp @@ -45,16 +45,20 @@ namespace sat { scc & m_scc; stopwatch m_watch; unsigned m_num_elim; + unsigned m_num_elim_bin; report(scc & c): m_scc(c), - m_num_elim(c.m_num_elim) { + m_num_elim(c.m_num_elim), + m_num_elim_bin(c.m_num_elim_bin) { m_watch.start(); } ~report() { m_watch.stop(); + unsigned elim_bin = m_scc.m_num_elim_bin - m_num_elim_bin; IF_VERBOSE(SAT_VB_LVL, - verbose_stream() << " (sat-scc :elim-vars " << (m_scc.m_num_elim - m_num_elim) - << mk_stat(m_scc.m_solver) + verbose_stream() << " (sat-scc :elim-vars " << (m_scc.m_num_elim - m_num_elim); + if (elim_bin > 0) verbose_stream() << " :elim-bin " << elim_bin; + verbose_stream() << mk_stat(m_scc.m_solver) << " :time " << std::fixed << std::setprecision(2) << m_watch.get_seconds() << ")\n";); } }; @@ -223,20 +227,133 @@ namespace sat { eliminator(roots, to_elim); TRACE("scc_detail", m_solver.display(tout);); CASSERT("scc_bug", m_solver.check_invariant()); + + if (m_scc_tr) { + reduce_tr(); + } return to_elim.size(); } + void scc::get_dfs_num(svector& dfs, bool learned) { + unsigned num_lits = m_solver.num_vars() * 2; + vector dag(num_lits); + svector roots(num_lits, true); + literal_vector todo; + SASSERT(dfs.size() == num_lits); + unsigned num_edges = 0; + + // retrieve DAG + for (unsigned l_idx = 0; l_idx < num_lits; l_idx++) { + literal u(to_literal(l_idx)); + if (m_solver.was_eliminated(u.var())) + continue; + auto& edges = dag[u.index()]; + for (watched const& w : m_solver.m_watches[l_idx]) { + if (learned ? w.is_binary_clause() : w.is_binary_unblocked_clause()) { + literal v = w.get_literal(); + roots[v.index()] = false; + edges.push_back(v); + ++num_edges; + } + } + unsigned sz = edges.size(); + // shuffle vertices to obtain different DAG traversal each time + if (sz > 1) { + for (unsigned i = sz; i-- > 0; ) { + std::swap(edges[i], edges[m_rand(i+1)]); + } + } + } + // std::cout << "dag num edges: " << num_edges << "\n"; + // retrieve literals that have no predecessors + for (unsigned l_idx = 0; l_idx < num_lits; l_idx++) { + literal u(to_literal(l_idx)); + if (roots[u.index()]) todo.push_back(u); + } + // std::cout << "num roots: " << roots.size() << "\n"; + // traverse DAG, annotate nodes with DFS number + int dfs_num = 0; + while (!todo.empty()) { + literal u = todo.back(); + int& d = dfs[u.index()]; + // already visited + if (d > 0) { + todo.pop_back(); + } + // visited as child: + else if (d < 0) { + d = -d; + for (literal v : dag[u.index()]) { + if (dfs[v.index()] == 0) { + dfs[v.index()] = - d - 1; + todo.push_back(v); + } + } + } + // new root. + else { + d = --dfs_num; + } + } + } + + bool scc::reduce_tr(svector const& dfs, bool learned) { + unsigned idx = 0; + bool reduced = false; + for (watch_list & wlist : m_solver.m_watches) { + literal u = to_literal(idx++); + watch_list::iterator it = wlist.begin(); + watch_list::iterator itprev = it; + watch_list::iterator end = wlist.end(); + for (; it != end; ++it) { + watched& w = *it; + if (learned ? w.is_binary_learned_clause() : w.is_binary_unblocked_clause()) { + literal v = w.get_literal(); + if (dfs[u.index()] + 1 < dfs[v.index()]) { + ++m_num_elim_bin; + reduced = true; + } + else { + *itprev = *it; + itprev++; + } + } + else { + *itprev = *it; + itprev++; + } + } + wlist.set_end(itprev); + } + return reduced; + } + + bool scc::reduce_tr(bool learned) { + unsigned num_lits = m_solver.num_vars() * 2; + svector dfs(num_lits); + get_dfs_num(dfs, learned); + return reduce_tr(dfs, learned); + } + + void scc::reduce_tr() { + while (reduce_tr(false)) {} + while (reduce_tr(true)) {} + } + void scc::collect_statistics(statistics & st) const { st.update("elim bool vars", m_num_elim); + st.update("elim binary", m_num_elim_bin); } void scc::reset_statistics() { m_num_elim = 0; + m_num_elim_bin = 0; } void scc::updt_params(params_ref const & _p) { sat_scc_params p(_p); m_scc = p.scc(); + m_scc_tr = p.scc_tr(); } void scc::collect_param_descrs(param_descrs & d) { diff --git a/src/sat/sat_scc.h b/src/sat/sat_scc.h index c8392685e..cfad76d71 100644 --- a/src/sat/sat_scc.h +++ b/src/sat/sat_scc.h @@ -31,8 +31,17 @@ namespace sat { solver & m_solver; // config bool m_scc; + bool m_scc_tr; // stats unsigned m_num_elim; + unsigned m_num_elim_bin; + random_gen m_rand; + + void get_dfs_num(svector& dfs, bool learned); + void reduce_tr(); + bool reduce_tr(bool learned); + bool reduce_tr(svector const& dfs, bool learned); + public: scc(solver & s, params_ref const & p); unsigned operator()(); diff --git a/src/sat/sat_scc_params.pyg b/src/sat/sat_scc_params.pyg index b88de4de8..0bf88a0cd 100644 --- a/src/sat/sat_scc_params.pyg +++ b/src/sat/sat_scc_params.pyg @@ -1,5 +1,6 @@ def_module_params(module_name='sat', class_name='sat_scc_params', export=True, - params=(('scc', BOOL, True, 'eliminate Boolean variables by computing strongly connected components'),)) + params=(('scc', BOOL, True, 'eliminate Boolean variables by computing strongly connected components'), + ('scc.tr', BOOL, False, 'apply transitive reduction, eliminate redundant binary clauses'), ))