From a5b663c52df3d15d5b83987c00587f01b0f453c1 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 17 Dec 2017 16:09:07 -0800 Subject: [PATCH] add unit walk engine Signed-off-by: Nikolaj Bjorner --- src/sat/CMakeLists.txt | 1 + src/sat/ba_solver.cpp | 2 +- src/sat/ba_solver.h | 25 ++- src/sat/sat_config.cpp | 2 + src/sat/sat_config.h | 2 + src/sat/sat_extension.h | 1 + src/sat/sat_parallel.cpp | 32 +++- src/sat/sat_parallel.h | 3 + src/sat/sat_params.pyg | 2 + src/sat/sat_solver.cpp | 121 ++++++++++--- src/sat/sat_solver.h | 4 +- src/sat/sat_types.h | 1 + src/sat/sat_unit_walk.cpp | 362 ++++++++++++++++++++++++++++++++++++++ src/sat/sat_unit_walk.h | 79 +++++++++ src/sat/sat_watched.cpp | 12 +- src/sat/sat_watched.h | 2 + 16 files changed, 604 insertions(+), 47 deletions(-) create mode 100644 src/sat/sat_unit_walk.cpp create mode 100644 src/sat/sat_unit_walk.h diff --git a/src/sat/CMakeLists.txt b/src/sat/CMakeLists.txt index d7da09826..320a674a2 100644 --- a/src/sat/CMakeLists.txt +++ b/src/sat/CMakeLists.txt @@ -24,6 +24,7 @@ z3_add_component(sat sat_scc.cpp sat_simplifier.cpp sat_solver.cpp + sat_unit_walk.cpp sat_watched.cpp COMPONENT_DEPENDENCIES util diff --git a/src/sat/ba_solver.cpp b/src/sat/ba_solver.cpp index edf1c82a2..e5623c05f 100644 --- a/src/sat/ba_solver.cpp +++ b/src/sat/ba_solver.cpp @@ -1518,7 +1518,7 @@ namespace sat { return p; } - ba_solver::ba_solver(): m_solver(0), m_lookahead(0), m_constraint_id(0) { + ba_solver::ba_solver(): m_solver(0), m_lookahead(0), m_unit_walk(0), m_constraint_id(0) { TRACE("ba", tout << this << "\n";); } diff --git a/src/sat/ba_solver.h b/src/sat/ba_solver.h index b5ea1e1ad..e7a220f5c 100644 --- a/src/sat/ba_solver.h +++ b/src/sat/ba_solver.h @@ -24,6 +24,7 @@ Revision History: #include "sat/sat_extension.h" #include "sat/sat_solver.h" #include "sat/sat_lookahead.h" +#include "sat/sat_unit_walk.h" #include "util/scoped_ptr_vector.h" #include "util/lp/lar_solver.h" @@ -204,6 +205,7 @@ namespace sat { solver* m_solver; lookahead* m_lookahead; + unit_walk* m_unit_walk; stats m_stats; small_object_allocator m_allocator; @@ -362,13 +364,25 @@ namespace sat { // access solver inline lbool value(bool_var v) const { return value(literal(v, false)); } inline lbool value(literal lit) const { return m_lookahead ? m_lookahead->value(lit) : m_solver->value(lit); } - inline unsigned lvl(literal lit) const { return m_lookahead ? 0 : m_solver->lvl(lit); } - inline unsigned lvl(bool_var v) const { return m_lookahead ? 0 : m_solver->lvl(v); } - inline bool inconsistent() const { return m_lookahead ? m_lookahead->inconsistent() : m_solver->inconsistent(); } + inline unsigned lvl(literal lit) const { return m_lookahead || m_unit_walk ? 0 : m_solver->lvl(lit); } + inline unsigned lvl(bool_var v) const { return m_lookahead || m_unit_walk ? 0 : m_solver->lvl(v); } + inline bool inconsistent() const { + if (m_lookahead) return m_lookahead->inconsistent(); + if (m_unit_walk) return m_unit_walk->inconsistent(); + return m_solver->inconsistent(); + } inline watch_list& get_wlist(literal l) { return m_lookahead ? m_lookahead->get_wlist(l) : m_solver->get_wlist(l); } inline watch_list const& get_wlist(literal l) const { return m_lookahead ? m_lookahead->get_wlist(l) : m_solver->get_wlist(l); } - inline void assign(literal l, justification j) { if (m_lookahead) m_lookahead->assign(l); else m_solver->assign(l, j); } - inline void set_conflict(justification j, literal l) { if (m_lookahead) m_lookahead->set_conflict(); else m_solver->set_conflict(j, l); } + inline void assign(literal l, justification j) { + if (m_lookahead) m_lookahead->assign(l); + else if (m_unit_walk) m_unit_walk->assign(l); + else m_solver->assign(l, j); + } + inline void set_conflict(justification j, literal l) { + if (m_lookahead) m_lookahead->set_conflict(); + else if (m_unit_walk) m_unit_walk->set_conflict(); + else m_solver->set_conflict(j, l); + } inline config const& get_config() const { return m_lookahead ? m_lookahead->get_config() : m_solver->get_config(); } inline void drat_add(literal_vector const& c, svector const& premises) { if (m_solver) m_solver->m_drat.add(c, premises); } @@ -434,6 +448,7 @@ namespace sat { virtual ~ba_solver(); virtual void set_solver(solver* s) { m_solver = s; } virtual void set_lookahead(lookahead* l) { m_lookahead = l; } + virtual void set_unit_walk(unit_walk* u) { m_unit_walk = u; } void add_at_least(bool_var v, literal_vector const& lits, unsigned k); void add_pb_ge(bool_var v, svector const& wlits, unsigned k); void add_xor(bool_var v, literal_vector const& lits); diff --git a/src/sat/sat_config.cpp b/src/sat/sat_config.cpp index cf26fc09e..4cf448394 100644 --- a/src/sat/sat_config.cpp +++ b/src/sat/sat_config.cpp @@ -72,6 +72,8 @@ namespace sat { m_num_threads = p.threads(); m_local_search = p.local_search(); m_local_search_threads = p.local_search_threads(); + m_unit_walk = p.unit_walk(); + m_unit_walk_threads = p.unit_walk_threads(); m_lookahead_simplify = p.lookahead_simplify(); m_lookahead_simplify_bca = p.lookahead_simplify_bca(); m_lookahead_simplify_asymm_branch = p.lookahead_simplify_asymm_branch(); diff --git a/src/sat/sat_config.h b/src/sat/sat_config.h index c40266c93..2aa5a325b 100644 --- a/src/sat/sat_config.h +++ b/src/sat/sat_config.h @@ -90,6 +90,8 @@ namespace sat { unsigned m_num_threads; unsigned m_local_search_threads; bool m_local_search; + unsigned m_unit_walk_threads; + bool m_unit_walk; bool m_lookahead_simplify; bool m_lookahead_simplify_bca; bool m_lookahead_simplify_asymm_branch; diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index c2a9197c1..7db5d63ec 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -52,6 +52,7 @@ namespace sat { virtual ~extension() {} virtual void set_solver(solver* s) = 0; virtual void set_lookahead(lookahead* s) = 0; + virtual void set_unit_walk(unit_walk* u) = 0; virtual bool propagate(literal l, ext_constraint_idx idx) = 0; virtual double get_reward(literal l, ext_constraint_idx idx, literal_occs_fun& occs) const = 0; virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) = 0; diff --git a/src/sat/sat_parallel.cpp b/src/sat/sat_parallel.cpp index 7ba1ba7c6..462439499 100644 --- a/src/sat/sat_parallel.cpp +++ b/src/sat/sat_parallel.cpp @@ -91,7 +91,7 @@ namespace sat { return false; } - parallel::parallel(solver& s): m_scoped_rlimit(s.rlimit()), m_num_clauses(0) {} + parallel::parallel(solver& s): m_scoped_rlimit(s.rlimit()), m_num_clauses(0), m_consumer_ready(false) {} parallel::~parallel() { for (unsigned i = 0; i < m_solvers.size(); ++i) { @@ -230,14 +230,14 @@ namespace sat { } } IF_VERBOSE(1, verbose_stream() << "set phase: " << m_num_clauses << " " << s.m_clauses.size() << " " << m_solver_copy << "\n";); - if (m_num_clauses == 0 || (m_num_clauses > s.m_clauses.size())) { - // time to update local search with new clauses. - // there could be multiple local search engines runing at the same time. - IF_VERBOSE(1, verbose_stream() << "(sat-parallel refresh local search " << m_num_clauses << " -> " << s.m_clauses.size() << ")\n";); - m_solver_copy = alloc(solver, s.m_params, s.rlimit()); - m_solver_copy->copy(s); - m_num_clauses = s.m_clauses.size(); - } + } + if (m_consumer_ready && (m_num_clauses == 0 || (m_num_clauses > s.m_clauses.size()))) { + // time to update local search with new clauses. + // there could be multiple local search engines runing at the same time. + IF_VERBOSE(1, verbose_stream() << "(sat-parallel refresh :from " << m_num_clauses << " :to " << s.m_clauses.size() << ")\n";); + m_solver_copy = alloc(solver, s.m_params, s.rlimit()); + m_solver_copy->copy(s); + m_num_clauses = s.m_clauses.size(); } } @@ -285,6 +285,7 @@ namespace sat { void parallel::set_phase(local_search& s) { #pragma omp critical (par_solver) { + m_consumer_ready = true; m_phase.reserve(s.num_vars(), l_undef); for (unsigned i = 0; i < s.num_vars(); ++i) { m_phase[i] = s.get_phase(i) ? l_true : l_false; @@ -293,6 +294,19 @@ namespace sat { } } + bool parallel::copy_solver(solver& s) { + bool copied = false; + #pragma omp critical (par_solver) + { + m_consumer_ready = true; + if (m_solver_copy && s.m_clauses.size() > m_solver_copy->m_clauses.size()) { + s.copy(*m_solver_copy); + copied = true; + m_num_clauses = s.m_clauses.size(); + } + } + return copied; + } }; diff --git a/src/sat/sat_parallel.h b/src/sat/sat_parallel.h index f09f07f51..f37c99151 100644 --- a/src/sat/sat_parallel.h +++ b/src/sat/sat_parallel.h @@ -65,6 +65,7 @@ namespace sat { svector m_phase; unsigned m_num_clauses; scoped_ptr m_solver_copy; + bool m_consumer_ready; scoped_limits m_scoped_rlimit; vector m_limits; @@ -106,6 +107,8 @@ namespace sat { void set_phase(local_search& s); void get_phase(local_search& s); + + bool copy_solver(solver& s); }; }; diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index 9d1585e7a..c8d47ddd3 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -38,6 +38,8 @@ def_module_params('sat', ('atmost1_encoding', SYMBOL, 'grouped', 'encoding used for at-most-1 constraints grouped, bimander, ordered'), ('local_search', BOOL, False, 'use local search instead of CDCL'), ('local_search_threads', UINT, 0, 'number of local search threads to find satisfiable solution'), + ('unit_walk', BOOL, False, 'use unit-walk search instead of CDCL'), + ('unit_walk_threads', UINT, 0, 'number of unit-walk search threads to find satisfiable solution'), ('lookahead.cube.cutoff', SYMBOL, 'adaptive_freevars', 'cutoff type used to create lookahead cubes: depth, freevars, psat, adaptive_freevars, adaptive_psat'), ('lookahead.cube.fraction', DOUBLE, 0.4, 'adaptive fraction to create lookahead cubes. Used when lookahead.cube.cutoff is adaptive_freevars or adaptive_psat'), ('lookahead.cube.depth', UINT, 10, 'cut-off depth to create cubes. Used when lookahead.cube.cutoff is depth.'), diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index b9d4b0132..de75882af 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -20,6 +20,7 @@ Revision History: #include "sat/sat_solver.h" #include "sat/sat_integrity_checker.h" #include "sat/sat_lookahead.h" +#include "sat/sat_unit_walk.h" #include "util/luby.h" #include "util/trace.h" #include "util/max_cliques.h" @@ -69,15 +70,15 @@ namespace sat { m_ext = 0; SASSERT(check_invariant()); TRACE("sat", tout << "Delete clauses\n";); - del_clauses(m_clauses.begin(), m_clauses.end()); + del_clauses(m_clauses); TRACE("sat", tout << "Delete learned\n";); - del_clauses(m_learned.begin(), m_learned.end()); + del_clauses(m_learned); } - void solver::del_clauses(clause * const * begin, clause * const * end) { - for (clause * const * it = begin; it != end; ++it) { - m_cls_allocator.del_clause(*it); - } + void solver::del_clauses(clause_vector& clauses) { + for (clause * cp : clauses) + m_cls_allocator.del_clause(cp); + clauses.reset(); ++m_stats.m_non_learned_generation; } @@ -88,24 +89,46 @@ namespace sat { void solver::copy(solver const & src) { pop_to_base_level(); + del_clauses(m_clauses); + del_clauses(m_learned); + m_watches.reset(); + m_assignment.reset(); + m_justification.reset(); + m_decision.reset(); + m_eliminated.reset(); + m_activity.reset(); + m_level.reset(); + m_mark.reset(); + m_lit_mark.reset(); + m_phase.reset(); + m_prev_phase.reset(); + m_assigned_since_gc.reset(); + m_last_conflict.reset(); + m_last_propagation.reset(); + m_participated.reset(); + m_canceled.reset(); + m_reasoned.reset(); + m_simplifier.reset_todos(); + m_qhead = 0; + m_trail.reset(); + m_scopes.reset(); + // create new vars - if (num_vars() < src.num_vars()) { - for (bool_var v = num_vars(); v < src.num_vars(); v++) { - bool ext = src.m_external[v] != 0; - bool dvar = src.m_decision[v] != 0; - VERIFY(v == mk_var(ext, dvar)); - if (src.was_eliminated(v)) { - m_eliminated[v] = true; - } - m_phase[v] = src.m_phase[v]; - m_prev_phase[v] = src.m_prev_phase[v]; - -#if 1 - // inherit activity: - m_activity[v] = src.m_activity[v]; - m_case_split_queue.activity_changed_eh(v, false); -#endif + for (bool_var v = num_vars(); v < src.num_vars(); v++) { + bool ext = src.m_external[v] != 0; + bool dvar = src.m_decision[v] != 0; + VERIFY(v == mk_var(ext, dvar)); + if (src.was_eliminated(v)) { + m_eliminated[v] = true; } + m_phase[v] = src.m_phase[v]; + m_prev_phase[v] = src.m_prev_phase[v]; + +#if 1 + // inherit activity: + m_activity[v] = src.m_activity[v]; + m_case_split_queue.activity_changed_eh(v, false); +#endif } // @@ -891,7 +914,7 @@ namespace sat { if (m_config.m_local_search) { return do_local_search(num_lits, lits); } - if ((m_config.m_num_threads > 1 || m_config.m_local_search_threads > 0) && !m_par) { + if ((m_config.m_num_threads > 1 || m_config.m_local_search_threads > 0 || m_config.m_unit_walk_threads > 0) && !m_par) { SASSERT(scope_lvl() == 0); return check_par(num_lits, lits); } @@ -909,6 +932,10 @@ namespace sat { propagate(false); if (check_inconsistent()) return l_false; cleanup(); + + if (m_config.m_unit_walk) { + return do_unit_walk(); + } if (m_config.m_gc_burst) { // force gc m_conflicts_since_gc = m_gc_threshold + 1; @@ -988,11 +1015,19 @@ namespace sat { return r; } + lbool solver::do_unit_walk() { + unit_walk srch(*this); + lbool r = srch(); + return r; + } + lbool solver::check_par(unsigned num_lits, literal const* lits) { scoped_ptr_vector ls; - int num_threads = static_cast(m_config.m_num_threads + m_config.m_local_search_threads); + scoped_ptr_vector uw; int num_extra_solvers = m_config.m_num_threads - 1; int num_local_search = static_cast(m_config.m_local_search_threads); + int num_unit_walk = static_cast(m_config.m_unit_walk_threads); + int num_threads = num_extra_solvers + 1 + num_local_search + num_unit_walk; for (int i = 0; i < num_local_search; ++i) { local_search* l = alloc(local_search); l->config().set_seed(m_config.m_random_seed + i); @@ -1000,9 +1035,23 @@ namespace sat { ls.push_back(l); } + // set up unit walk + vector lims(num_unit_walk); + for (int i = 0; i < num_unit_walk; ++i) { + solver* s = alloc(solver, m_params, lims[i]); + s->copy(*this); + s->m_config.m_unit_walk = true; + uw.push_back(s); + } + + int local_search_offset = num_extra_solvers; + int unit_walk_offset = num_extra_solvers + num_local_search; + int main_solver_offset = unit_walk_offset + num_unit_walk; + #define IS_AUX_SOLVER(i) (0 <= i && i < num_extra_solvers) -#define IS_LOCAL_SEARCH(i) (num_extra_solvers <= i && i + 1 < num_threads) -#define IS_MAIN_SOLVER(i) (i + 1 == num_threads) +#define IS_LOCAL_SEARCH(i) (local_search_offset <= i && i < unit_walk_offset) +#define IS_UNIT_WALK(i) (unit_walk_offset <= i && i < main_solver_offset) +#define IS_MAIN_SOLVER(i) (i == main_solver_offset) sat::parallel par(*this); par.reserve(num_threads, 1 << 12); @@ -1010,6 +1059,12 @@ namespace sat { for (unsigned i = 0; i < ls.size(); ++i) { par.push_child(ls[i]->rlimit()); } + for (reslimit& rl : lims) { + par.push_child(rl); + } + for (unsigned i = 0; i < uw.size(); ++i) { + uw[i]->set_par(&par, 0); + } int finished_id = -1; std::string ex_msg; par_exception_kind ex_kind = DEFAULT_EX; @@ -1024,7 +1079,10 @@ namespace sat { r = par.get_solver(i).check(num_lits, lits); } else if (IS_LOCAL_SEARCH(i)) { - r = ls[i-num_extra_solvers]->check(num_lits, lits, &par); + r = ls[i-local_search_offset]->check(num_lits, lits); + } + else if (IS_UNIT_WALK(i)) { + r = uw[i-unit_walk_offset]->check(num_lits, lits); } else { r = check(num_lits, lits); @@ -1042,6 +1100,9 @@ namespace sat { for (unsigned j = 0; j < ls.size(); ++j) { ls[j]->rlimit().cancel(); } + for (auto& rl : lims) { + rl.cancel(); + } for (int j = 0; j < num_extra_solvers; ++j) { if (i != j) { par.cancel_solver(j); @@ -1076,13 +1137,17 @@ namespace sat { m_core.append(par.get_solver(finished_id).get_core()); } if (result == l_true && IS_LOCAL_SEARCH(finished_id)) { - set_model(ls[finished_id - num_extra_solvers]->get_model()); + set_model(ls[finished_id - local_search_offset]->get_model()); + } + if (result == l_true && IS_UNIT_WALK(finished_id)) { + set_model(uw[finished_id - unit_walk_offset]->get_model()); } if (!canceled) { rlimit().reset_cancel(); } set_par(0, 0); ls.reset(); + uw.reset(); if (finished_id == -1) { switch (ex_kind) { case ERROR_EX: throw z3_error(error_code); diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index bac9c4bb3..56d1ace39 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -163,7 +163,7 @@ namespace sat { statistics m_aux_stats; - void del_clauses(clause * const * begin, clause * const * end); + void del_clauses(clause_vector& clauses); friend class integrity_checker; friend class cleaner; @@ -180,6 +180,7 @@ namespace sat { friend class parallel; friend class lookahead; friend class local_search; + friend class unit_walk; friend struct mk_stat; friend class elim_vars; friend class scoped_detach; @@ -398,6 +399,7 @@ namespace sat { void exchange_par(); lbool check_par(unsigned num_lits, literal const* lits); lbool do_local_search(unsigned num_lits, literal const* lits); + lbool do_unit_walk(); // ----------------------- // diff --git a/src/sat/sat_types.h b/src/sat/sat_types.h index 5eded92ec..002e49006 100644 --- a/src/sat/sat_types.h +++ b/src/sat/sat_types.h @@ -125,6 +125,7 @@ namespace sat { class solver; class lookahead; + class unit_walk; class clause; class clause_wrapper; class integrity_checker; diff --git a/src/sat/sat_unit_walk.cpp b/src/sat/sat_unit_walk.cpp new file mode 100644 index 000000000..7d556c740 --- /dev/null +++ b/src/sat/sat_unit_walk.cpp @@ -0,0 +1,362 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Module Name: + + sat_unit_walk.cpp + +Abstract: + + unit walk local search procedure. + A variant of UnitWalk. Hirsch and Kojevinkov, SAT 2001. + This version uses a trail to reset assignments and integrates directly with the + watch list structure. Thus, assignments are not delayed and we avoid treating + pending units as a multi-set. + + It uses standard DPLL approach for backracking, flipping the last decision literal that + lead to a conflict. It restarts after evern 100 conflicts. + + It does not attempt to add conflict clauses or alternate with + walksat. + + It can receive conflict clauses from a concurrent CDCL solver and does not + create its own conflict clauses. + + The phase of variables is optionally sticky between rounds. We use a decay rate + to compute stickiness of a variable. + +Author: + + Nikolaj Bjorner (nbjorner) 2017-12-15. + +Revision History: + +--*/ + +#include "sat_unit_walk.h" + +namespace sat { + + unit_walk::unit_walk(solver& s): + s(s) + { + m_runs = 0; + m_periods = 0; + m_max_runs = UINT_MAX; + m_max_periods = 100; // 5000; // UINT_MAX; // TBD configure + m_max_conflicts = 100; + m_sticky_phase = true; + m_flips = 0; + } + + class scoped_set_unit_walk { + solver& s; + public: + scoped_set_unit_walk(unit_walk* u, solver& s): s(s) { + if (s.get_extension()) s.get_extension()->set_unit_walk(u); + } + ~scoped_set_unit_walk() { + if (s.get_extension()) s.get_extension()->set_unit_walk(nullptr); + } + }; + + lbool unit_walk::operator()() { + scoped_set_unit_walk _scoped_set(this, s); + init_runs(); + for (m_runs = 0; m_runs < m_max_runs || m_max_runs == UINT_MAX; ++m_runs) { + init_propagation(); + init_phase(); + for (m_periods = 0; m_periods < m_max_periods || m_max_periods == UINT_MAX; ++m_periods) { + if (!s.rlimit().inc()) return l_undef; + lbool r = unit_propagation(); + if (r != l_undef) return r; + } + } + return l_undef; + } + + lbool unit_walk::unit_propagation() { + init_propagation(); + while (!m_freevars.empty() && !inconsistent()) { + bool_var v = m_freevars.begin()[m_rand(m_freevars.size())]; + literal lit(v, !m_phase[v]); + ++s.m_stats.m_decision; + m_decisions.push_back(lit); + assign(lit); + propagate(); + while (inconsistent() && !m_decisions.empty()) { + ++m_conflicts; + backtrack(); + propagate(); + } + if (m_conflicts >= m_max_conflicts && !m_freevars.empty()) { + set_conflict(); + break; + } + } + if (!inconsistent()) { + log_status(); + IF_VERBOSE(1, verbose_stream() << "(sat-unit-walk sat)\n";); + s.mk_model(); + return l_true; + } + return l_undef; + } + + void unit_walk::init_runs() { + m_freevars.reset(); + m_trail.reset(); + m_decisions.reset(); + m_phase.resize(s.num_vars()); + double2 d2; + d2.t = 1.0; + d2.f = 1.0; + m_phase_tf.resize(s.num_vars(), d2); + for (unsigned i = 0; i < s.num_vars(); ++i) { + literal l(i, false); + if (!s.was_eliminated(l.var()) && s.m_assignment[l.index()] == l_undef) + m_freevars.insert(l.var()); + } + IF_VERBOSE(1, verbose_stream() << "num vars: " << s.num_vars() << " free vars: " << m_freevars.size() << "\n";); + } + + void unit_walk::init_phase() { + m_max_trail = 0; + if (m_sticky_phase) { + for (bool_var v : m_freevars) { + m_phase[v] = m_rand(100 * static_cast(m_phase_tf[v].t + m_phase_tf[v].f)) <= 100 * m_phase_tf[v].t; + } + } + else { + for (bool_var v : m_freevars) + m_phase[v] = (m_rand(2) == 0); + } + } + + void unit_walk::init_propagation() { + if (s.m_par && s.m_par->copy_solver(s)) { + IF_VERBOSE(1, verbose_stream() << "(sat-unit-walk fresh copy)\n";); + if (s.get_extension()) s.get_extension()->set_unit_walk(this); + init_runs(); + init_phase(); + } + if (m_max_trail == 0 || m_trail.size() > m_max_trail) { + m_max_trail = m_trail.size(); + log_status(); + } + for (literal lit : m_trail) { + s.m_assignment[lit.index()] = l_undef; + s.m_assignment[(~lit).index()] = l_undef; + m_freevars.insert(lit.var()); + } + m_flips = 0; + m_trail.reset(); + m_conflicts = 0; + m_decisions.reset(); + m_qhead = 0; + m_inconsistent = false; + } + + void unit_walk::propagate() { + while (m_qhead < m_trail.size() && !inconsistent()) + propagate(choose_literal()); + // IF_VERBOSE(1, verbose_stream() << m_trail.size() << " " << inconsistent() << "\n";); + } + + void unit_walk::propagate(literal l) { + ++s.m_stats.m_propagate; + literal not_l = ~l; + literal l1, l2; + lbool val1, val2; + bool keep; + watch_list & wlist = s.get_wlist(l); + watch_list::iterator it = wlist.begin(); + watch_list::iterator it2 = it; + watch_list::iterator end = wlist.end(); + for (; it != end; ++it) { + switch (it->get_kind()) { + case watched::BINARY: + l1 = it->get_literal(); + switch (value(l1)) { + case l_false: + conflict_cleanup(it, it2, wlist); + set_conflict(l,l1); + return; + case l_undef: + assign(l1); + break; + case l_true: + break; // skip + } + *it2 = *it; + it2++; + break; + case watched::TERNARY: + l1 = it->get_literal1(); + l2 = it->get_literal2(); + val1 = value(l1); + val2 = value(l2); + if (val1 == l_false && val2 == l_undef) { + assign(l2); + } + else if (val1 == l_undef && val2 == l_false) { + assign(l1); + } + else if (val1 == l_false && val2 == l_false) { + conflict_cleanup(it, it2, wlist); + set_conflict(l,l1,l2); + return; + } + *it2 = *it; + it2++; + break; + case watched::CLAUSE: { + if (value(it->get_blocked_literal()) == l_true) { + *it2 = *it; + it2++; + break; + } + clause_offset cls_off = it->get_clause_offset(); + clause & c = s.get_clause(cls_off); + if (c[0] == not_l) + std::swap(c[0], c[1]); + if (c[1] != not_l) { + *it2 = *it; + it2++; + break; + } + if (value(c[0]) == l_true) { + it2->set_clause(c[0], cls_off); + it2++; + break; + } + SASSERT(c[1] == not_l); + literal * l_it = c.begin() + 2; + literal * l_end = c.end(); + for (; l_it != l_end; ++l_it) { + if (value(*l_it) != l_false) { + c[1] = *l_it; + *l_it = not_l; + s.get_wlist((~c[1]).index()).push_back(watched(c[0], cls_off)); + goto end_clause_case; + } + } + SASSERT(value(c[0]) == l_false || value(c[0]) == l_undef); + if (value(c[0]) == l_false) { + c.mark_used(); + conflict_cleanup(it, it2, wlist); + set_conflict(c); + return; + } + else { + *it2 = *it; + it2++; + assign(c[0]); + } + end_clause_case: + break; + } + case watched::EXT_CONSTRAINT: + SASSERT(s.get_extension()); + keep = s.get_extension()->propagate(l, it->get_ext_constraint_idx()); + if (inconsistent()) { + if (!keep) { + ++it; + } + set_conflict(l, l); + conflict_cleanup(it, it2, wlist); + return; + } + if (keep) { + *it2 = *it; + it2++; + } + break; + default: + UNREACHABLE(); + break; + } + } + wlist.set_end(it2); + } + + void unit_walk::assign(literal lit) { + SASSERT(value(lit) == l_undef); + s.m_assignment[lit.index()] = l_true; + s.m_assignment[(~lit).index()] = l_false; + m_trail.push_back(lit); + m_freevars.remove(lit.var()); + if (s.get_extension() && s.is_external(lit.var())) { + s.get_extension()->asserted(lit); + } + if (m_phase[lit.var()] == lit.sign()) { + ++m_flips; + flip_phase(lit); + } + } + + void unit_walk::flip_phase(literal l) { + bool_var v = l.var(); + m_phase[v] = !m_phase[v]; + if (m_sticky_phase) { + m_phase_tf[v].f *= 0.98; + m_phase_tf[v].t *= 0.98; + if (m_phase[v]) m_phase_tf[v].t += 1; else m_phase_tf[v].f += 1; + } + } + + void unit_walk::log_status() { + IF_VERBOSE(1, verbose_stream() << "(sat-unit-walk :trail " << m_max_trail + << " :branches " << m_decisions.size() + << " :free " << m_freevars.size() + << " :periods " << m_periods + << " :decisions " << s.m_stats.m_decision + << " :propagations " << s.m_stats.m_propagate + << ")\n";); + } + + literal unit_walk::choose_literal() { + SASSERT(m_qhead < m_trail.size()); + unsigned idx = m_rand(m_trail.size() - m_qhead); + std::swap(m_trail[m_qhead], m_trail[m_qhead + idx]); + literal lit = m_trail[m_qhead++]; + return lit; + } + + void unit_walk::set_conflict(literal l1, literal l2) { + set_conflict(); + } + + void unit_walk::set_conflict(literal l1, literal l2, literal l3) { + set_conflict(); + } + + void unit_walk::set_conflict(clause const& c) { + set_conflict(); + } + + void unit_walk::set_conflict() { + m_inconsistent = true; + } + + void unit_walk::backtrack() { + if (m_decisions.empty()) return; + literal dlit = m_decisions.back(); + literal lit; + do { + SASSERT(!m_trail.empty()); + lit = m_trail.back(); + s.m_assignment[lit.index()] = l_undef; + s.m_assignment[(~lit).index()] = l_undef; + m_freevars.insert(lit.var()); + m_trail.pop_back(); + } + while (lit != dlit); + m_inconsistent = false; + m_decisions.pop_back(); + m_qhead = m_trail.size(); + assign(~dlit); + } + +}; + diff --git a/src/sat/sat_unit_walk.h b/src/sat/sat_unit_walk.h new file mode 100644 index 000000000..8ab9bab70 --- /dev/null +++ b/src/sat/sat_unit_walk.h @@ -0,0 +1,79 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Module Name: + + sat_unit_walk.h + +Abstract: + + unit walk local search procedure. + +Author: + + Nikolaj Bjorner (nbjorner) 2017-12-15. + +Revision History: + +--*/ +#ifndef SAT_UNIT_WALK_H_ +#define SAT_UNIT_WALK_H_ + +#include "sat/sat_solver.h" + +namespace sat { + + class unit_walk { + struct double2 { + double t, f; + }; + solver& s; + random_gen m_rand; + svector m_phase; + svector m_phase_tf; + indexed_uint_set m_freevars; + unsigned m_runs; + unsigned m_periods; + + // settings + unsigned m_max_runs; + unsigned m_max_periods; + unsigned m_max_conflicts; + bool m_sticky_phase; + + unsigned m_propagations; + unsigned m_flips; + unsigned m_max_trail; + unsigned m_qhead; + literal_vector m_trail; + bool m_inconsistent; + literal_vector m_decisions; + unsigned m_conflicts; + + void push(); + void backtrack(); + void init_runs(); + void init_phase(); + void init_propagation(); + void flip_phase(literal l); + lbool unit_propagation(); + void propagate(); + void propagate(literal lit); + literal choose_literal(); + void set_conflict(literal l1, literal l2); + void set_conflict(literal l1, literal l2, literal l3); + void set_conflict(clause const& c); + inline lbool value(literal lit) { return s.value(lit); } + void log_status(); + public: + + unit_walk(solver& s); + lbool operator()(); + std::ostream& display(std::ostream& out) const; + bool inconsistent() const { return m_inconsistent; } + void set_conflict(); + void assign(literal lit); + }; +}; + +#endif diff --git a/src/sat/sat_watched.cpp b/src/sat/sat_watched.cpp index 369e95034..1a0880282 100644 --- a/src/sat/sat_watched.cpp +++ b/src/sat/sat_watched.cpp @@ -27,11 +27,10 @@ namespace sat { for (; it != end; ++it) { if (it->is_clause() && it->get_clause_offset() == c) { watch_list::iterator it2 = it; - ++it; - for (; it != end; ++it) { + ++it; + for (; it != end; ++it, ++it2) { SASSERT(!((it->is_clause() && it->get_clause_offset() == c))); *it2 = *it; - ++it2; } wlist.set_end(it2); return true; @@ -71,6 +70,13 @@ namespace sat { VERIFY(found); } + void conflict_cleanup(watch_list::iterator it, watch_list::iterator it2, watch_list& wlist) { + watch_list::iterator end = wlist.end(); + for (; it != end; ++it, ++it2) + *it2 = *it; + wlist.set_end(it2); + } + std::ostream& display_watch_list(std::ostream & out, clause_allocator const & ca, watch_list const & wlist) { bool first = true; diff --git a/src/sat/sat_watched.h b/src/sat/sat_watched.h index 305948251..e66197878 100644 --- a/src/sat/sat_watched.h +++ b/src/sat/sat_watched.h @@ -139,6 +139,8 @@ namespace sat { class clause_allocator; std::ostream& display_watch_list(std::ostream & out, clause_allocator const & ca, watch_list const & wlist); + + void conflict_cleanup(watch_list::iterator it, watch_list::iterator it2, watch_list& wlist); }; #endif