3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-13 20:38:43 +00:00

add options to perform transitive reduction and add hyper binary clauses

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-11-27 10:53:22 -08:00
parent 15d8532d27
commit 62e3906957
9 changed files with 189 additions and 25 deletions

View file

@ -41,6 +41,7 @@ namespace sat {
m_local_search = 0;
m_lookahead_search = false;
m_lookahead_simplify = false;
m_lookahead_simplify_bca = false;
m_elim_vars = false;
m_incremental = false;
updt_params(p);
@ -90,6 +91,7 @@ namespace sat {
m_local_search = p.local_search();
m_local_search_threads = p.local_search_threads();
m_lookahead_simplify = p.lookahead_simplify();
m_lookahead_simplify_bca = p.lookahead_simplify_bca();
m_lookahead_search = p.lookahead_search();
if (p.lookahead_reward() == symbol("heule_schur")) {
m_lookahead_reward = heule_schur_reward;

View file

@ -84,6 +84,7 @@ namespace sat {
bool m_local_search;
bool m_lookahead_search;
bool m_lookahead_simplify;
bool m_lookahead_simplify_bca;
unsigned m_lookahead_cube_cutoff;
double m_lookahead_cube_fraction;
reward_t m_lookahead_reward;

View file

@ -34,11 +34,9 @@ namespace sat {
}
void elim_eqs::cleanup_bin_watches(literal_vector const & roots) {
vector<watch_list>::iterator it = m_solver.m_watches.begin();
vector<watch_list>::iterator end = m_solver.m_watches.end();
for (unsigned l_idx = 0; it != end; ++it, ++l_idx) {
watch_list & wlist = *it;
literal l1 = ~to_literal(l_idx);
unsigned l_idx = 0;
for (watch_list & wlist : m_solver.m_watches) {
literal l1 = ~to_literal(l_idx++);
literal r1 = norm(roots, l1);
watch_list::iterator it2 = wlist.begin();
watch_list::iterator itprev = it2;

View file

@ -17,9 +17,10 @@ Author:
Notes:
--*/
#include "sat_solver.h"
#include "sat_extension.h"
#include "sat_lookahead.h"
#include "sat/sat_solver.h"
#include "sat/sat_extension.h"
#include "sat/sat_lookahead.h"
#include "util/union_find.h"
namespace sat {
lookahead::scoped_ext::scoped_ext(lookahead& p): p(p) {
@ -648,7 +649,6 @@ namespace sat {
TRACE("sat", display_scc(tout););
}
void lookahead::init_scc() {
std::cerr << "init-scc\n";
inc_bstamp();
for (unsigned i = 0; i < m_candidates.size(); ++i) {
literal lit(m_candidates[i].m_var, false);
@ -2290,20 +2290,51 @@ namespace sat {
elim_eqs elim(m_s);
elim(roots, to_elim);
#if 0
// TBD:
// Finally create a new graph between parents
// based on the arcs in the the m_dfs[index].m_next structure
// Visit all nodes, assign DFS numbers
// Then prune binary clauses that differ in DFS number more than 1
// Add binary clauses that have DFS number 1, but no companion binary clause.
//
#endif
if (get_config().m_lookahead_simplify_bca) {
add_hyper_binary();
}
}
}
m_lookahead.reset();
m_lookahead.reset();
}
/**
\brief reduction based on binary implication graph
*/
void lookahead::add_hyper_binary() {
unsigned num_lits = m_s.num_vars() * 2;
union_find_default_ctx ufctx;
union_find<union_find_default_ctx> uf(ufctx);
for (unsigned i = 0; i < num_lits; ++i) uf.mk_var();
for (unsigned idx = 0; idx < num_lits; ++idx) {
literal u = get_parent(to_literal(idx));
if (null_literal != u) {
for (watched const& w : m_s.m_watches[idx]) {
if (!w.is_binary_clause()) continue;
literal v = get_parent(w.get_literal());
if (null_literal != v) {
uf.merge(u.index(), v.index());
}
}
}
}
unsigned disconnected = 0;
for (unsigned i = 0; i < m_binary.size(); ++i) {
literal u = to_literal(i);
if (u == get_parent(u)) {
for (literal v : m_binary[i]) {
if (v == get_parent(v) && uf.find(u.index()) != uf.find(v.index())) {
++disconnected;
uf.merge(u.index(), v.index());
m_s.mk_clause(~u, v, true);
}
}
}
}
IF_VERBOSE(10, verbose_stream() << "(sat-lookahead :bca " << disconnected << ")\n";);
m_stats.m_bca += disconnected;
}
void lookahead::normalize_parents() {
@ -2378,6 +2409,7 @@ namespace sat {
void lookahead::collect_statistics(statistics& st) const {
st.update("lh bool var", m_vprefix.size());
// TBD: keep count of ternary and >3-ary clauses.
st.update("lh bca", m_stats.m_bca);
st.update("lh add binary", m_stats.m_add_binary);
st.update("lh del binary", m_stats.m_del_binary);
st.update("lh propagations", m_stats.m_propagations);

View file

@ -115,6 +115,7 @@ namespace sat {
struct stats {
unsigned m_propagations;
unsigned m_bca;
unsigned m_add_binary;
unsigned m_del_binary;
unsigned m_decisions;
@ -533,6 +534,8 @@ namespace sat {
void normalize_parents();
void add_hyper_binary();
public:
lookahead(solver& s) :
m_s(s),

View file

@ -41,6 +41,7 @@ def_module_params('sat',
('lookahead_search', BOOL, False, 'use lookahead solver'),
('lookahead.preselect', BOOL, False, 'use pre-selection of subset of variables for branching'),
('lookahead_simplify', BOOL, False, 'use lookahead solver during simplification'),
('lookahead_simplify.bca', BOOL, False, 'add learned binary clauses as part of lookahead simplification'),
('lookahead.global_autarky', BOOL, False, 'prefer to branch on variables that occur in clauses that are reduced'),
('lookahead.reward', SYMBOL, 'march_cu', 'select lookahead heuristic: ternary, heule_schur (Heule Schur), heuleu (Heule Unit), unit, or march_cu'),
('dimacs.inprocess.display', BOOL, False, 'display SAT instance in DIMACS format if unsolved after inprocess.max inprocessing passes')))

View file

@ -45,16 +45,20 @@ namespace sat {
scc & m_scc;
stopwatch m_watch;
unsigned m_num_elim;
unsigned m_num_elim_bin;
report(scc & c):
m_scc(c),
m_num_elim(c.m_num_elim) {
m_num_elim(c.m_num_elim),
m_num_elim_bin(c.m_num_elim_bin) {
m_watch.start();
}
~report() {
m_watch.stop();
unsigned elim_bin = m_scc.m_num_elim_bin - m_num_elim_bin;
IF_VERBOSE(SAT_VB_LVL,
verbose_stream() << " (sat-scc :elim-vars " << (m_scc.m_num_elim - m_num_elim)
<< mk_stat(m_scc.m_solver)
verbose_stream() << " (sat-scc :elim-vars " << (m_scc.m_num_elim - m_num_elim);
if (elim_bin > 0) verbose_stream() << " :elim-bin " << elim_bin;
verbose_stream() << mk_stat(m_scc.m_solver)
<< " :time " << std::fixed << std::setprecision(2) << m_watch.get_seconds() << ")\n";);
}
};
@ -223,20 +227,133 @@ namespace sat {
eliminator(roots, to_elim);
TRACE("scc_detail", m_solver.display(tout););
CASSERT("scc_bug", m_solver.check_invariant());
if (m_scc_tr) {
reduce_tr();
}
return to_elim.size();
}
void scc::get_dfs_num(svector<int>& dfs, bool learned) {
unsigned num_lits = m_solver.num_vars() * 2;
vector<literal_vector> dag(num_lits);
svector<bool> roots(num_lits, true);
literal_vector todo;
SASSERT(dfs.size() == num_lits);
unsigned num_edges = 0;
// retrieve DAG
for (unsigned l_idx = 0; l_idx < num_lits; l_idx++) {
literal u(to_literal(l_idx));
if (m_solver.was_eliminated(u.var()))
continue;
auto& edges = dag[u.index()];
for (watched const& w : m_solver.m_watches[l_idx]) {
if (learned ? w.is_binary_clause() : w.is_binary_unblocked_clause()) {
literal v = w.get_literal();
roots[v.index()] = false;
edges.push_back(v);
++num_edges;
}
}
unsigned sz = edges.size();
// shuffle vertices to obtain different DAG traversal each time
if (sz > 1) {
for (unsigned i = sz; i-- > 0; ) {
std::swap(edges[i], edges[m_rand(i+1)]);
}
}
}
// std::cout << "dag num edges: " << num_edges << "\n";
// retrieve literals that have no predecessors
for (unsigned l_idx = 0; l_idx < num_lits; l_idx++) {
literal u(to_literal(l_idx));
if (roots[u.index()]) todo.push_back(u);
}
// std::cout << "num roots: " << roots.size() << "\n";
// traverse DAG, annotate nodes with DFS number
int dfs_num = 0;
while (!todo.empty()) {
literal u = todo.back();
int& d = dfs[u.index()];
// already visited
if (d > 0) {
todo.pop_back();
}
// visited as child:
else if (d < 0) {
d = -d;
for (literal v : dag[u.index()]) {
if (dfs[v.index()] == 0) {
dfs[v.index()] = - d - 1;
todo.push_back(v);
}
}
}
// new root.
else {
d = --dfs_num;
}
}
}
bool scc::reduce_tr(svector<int> const& dfs, bool learned) {
unsigned idx = 0;
bool reduced = false;
for (watch_list & wlist : m_solver.m_watches) {
literal u = to_literal(idx++);
watch_list::iterator it = wlist.begin();
watch_list::iterator itprev = it;
watch_list::iterator end = wlist.end();
for (; it != end; ++it) {
watched& w = *it;
if (learned ? w.is_binary_learned_clause() : w.is_binary_unblocked_clause()) {
literal v = w.get_literal();
if (dfs[u.index()] + 1 < dfs[v.index()]) {
++m_num_elim_bin;
reduced = true;
}
else {
*itprev = *it;
itprev++;
}
}
else {
*itprev = *it;
itprev++;
}
}
wlist.set_end(itprev);
}
return reduced;
}
bool scc::reduce_tr(bool learned) {
unsigned num_lits = m_solver.num_vars() * 2;
svector<int> dfs(num_lits);
get_dfs_num(dfs, learned);
return reduce_tr(dfs, learned);
}
void scc::reduce_tr() {
while (reduce_tr(false)) {}
while (reduce_tr(true)) {}
}
void scc::collect_statistics(statistics & st) const {
st.update("elim bool vars", m_num_elim);
st.update("elim binary", m_num_elim_bin);
}
void scc::reset_statistics() {
m_num_elim = 0;
m_num_elim_bin = 0;
}
void scc::updt_params(params_ref const & _p) {
sat_scc_params p(_p);
m_scc = p.scc();
m_scc_tr = p.scc_tr();
}
void scc::collect_param_descrs(param_descrs & d) {

View file

@ -31,8 +31,17 @@ namespace sat {
solver & m_solver;
// config
bool m_scc;
bool m_scc_tr;
// stats
unsigned m_num_elim;
unsigned m_num_elim_bin;
random_gen m_rand;
void get_dfs_num(svector<int>& dfs, bool learned);
void reduce_tr();
bool reduce_tr(bool learned);
bool reduce_tr(svector<int> const& dfs, bool learned);
public:
scc(solver & s, params_ref const & p);
unsigned operator()();

View file

@ -1,5 +1,6 @@
def_module_params(module_name='sat',
class_name='sat_scc_params',
export=True,
params=(('scc', BOOL, True, 'eliminate Boolean variables by computing strongly connected components'),))
params=(('scc', BOOL, True, 'eliminate Boolean variables by computing strongly connected components'),
('scc.tr', BOOL, False, 'apply transitive reduction, eliminate redundant binary clauses'), ))