From 75c573877d0bec64e36a7cf69689915a52c1ba81 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 5 Feb 2023 21:35:18 -0800 Subject: [PATCH] updates to ddfw, initial local search phase option --- src/sat/sat_config.cpp | 2 + src/sat/sat_config.h | 1 + src/sat/sat_ddfw.cpp | 118 ++++++++++++++++++++++++----------------- src/sat/sat_ddfw.h | 64 +++++++++++++--------- src/sat/sat_solver.cpp | 79 +++++++++++++++++---------- src/sat/sat_solver.h | 2 + src/sat/sat_types.h | 3 +- 7 files changed, 165 insertions(+), 104 deletions(-) diff --git a/src/sat/sat_config.cpp b/src/sat/sat_config.cpp index eb2d0071d..0a9e803a9 100644 --- a/src/sat/sat_config.cpp +++ b/src/sat/sat_config.cpp @@ -65,6 +65,8 @@ namespace sat { m_phase = PS_RANDOM; else if (s == symbol("frozen")) m_phase = PS_FROZEN; + else if (s == symbol("local_search")) + m_phase = PS_LOCAL_SEARCH; else throw sat_param_exception("invalid phase selection strategy: always_false, always_true, basic_caching, caching, random"); diff --git a/src/sat/sat_config.h b/src/sat/sat_config.h index 8adfc13ed..f8c0775b1 100644 --- a/src/sat/sat_config.h +++ b/src/sat/sat_config.h @@ -28,6 +28,7 @@ namespace sat { PS_ALWAYS_FALSE, PS_BASIC_CACHING, PS_SAT_CACHING, + PS_LOCAL_SEARCH, PS_FROZEN, PS_RANDOM }; diff --git a/src/sat/sat_ddfw.cpp b/src/sat/sat_ddfw.cpp index ecfc13aa6..f1493232c 100644 --- a/src/sat/sat_ddfw.cpp +++ b/src/sat/sat_ddfw.cpp @@ -49,6 +49,7 @@ namespace sat { else if (should_parallel_sync()) do_parallel_sync(); else shift_weights(); } + log(); return m_min_sz == 0 ? l_true : l_undef; } @@ -66,9 +67,9 @@ namespace sat { << std::setw(10) << kflips_per_sec << std::setw(10) << m_flips << std::setw(10) << m_restart_count - << std::setw(10) << m_reinit_count - << std::setw(10) << m_unsat_vars.size() - << std::setw(10) << m_shifts; + << std::setw(11) << m_reinit_count + << std::setw(13) << m_unsat_vars.size() + << std::setw(9) << m_shifts; if (m_par) verbose_stream() << std::setw(10) << m_parsync_count; verbose_stream() << ")\n"); m_stopwatch.start(); @@ -90,18 +91,18 @@ namespace sat { unsigned n = 1; bool_var v0 = null_bool_var; for (bool_var v : m_unsat_vars) { - int r = reward(v); - if (r > 0) { + double r = reward(v); + if (r > 0.0) { sum_pos += score(r); } - else if (r == 0 && sum_pos == 0 && (m_rand() % (n++)) == 0) { + else if (r == 0.0 && sum_pos == 0 && (m_rand() % (n++)) == 0) { v0 = v; } } if (sum_pos > 0) { double lim_pos = ((double) m_rand() / (1.0 + m_rand.max_value())) * sum_pos; for (bool_var v : m_unsat_vars) { - int r = reward(v); + double r = reward(v); if (r > 0) { lim_pos -= score(r); if (lim_pos <= 0) { @@ -121,7 +122,7 @@ namespace sat { * TBD: map reward value to a score, possibly through an exponential function, such as * exp(-tau/r), where tau > 0 */ - double ddfw::mk_score(unsigned r) { + double ddfw::mk_score(double r) { return r; } @@ -201,7 +202,7 @@ namespace sat { m_shifts = 0; m_stopwatch.start(); } - + void ddfw::reinit(solver& s) { add(s); add_assumptions(); @@ -235,7 +236,7 @@ namespace sat { for (unsigned cls_idx : use_list(*this, lit)) { clause_info& ci = m_clauses[cls_idx]; ci.del(lit); - unsigned w = ci.m_weight; + double w = ci.m_weight; // cls becomes false: flip any variable in clause to receive reward w switch (ci.m_num_trues) { case 0: { @@ -257,7 +258,7 @@ namespace sat { } for (unsigned cls_idx : use_list(*this, nlit)) { clause_info& ci = m_clauses[cls_idx]; - unsigned w = ci.m_weight; + double w = ci.m_weight; // the clause used to have a single true (pivot) literal, now it has two. // Then the previous pivot is no longer penalized for flipping. switch (ci.m_num_trues) { @@ -406,9 +407,8 @@ namespace sat { void ddfw::save_best_values() { if (m_unsat.empty()) { m_model.reserve(num_vars()); - for (unsigned i = 0; i < num_vars(); ++i) { + for (unsigned i = 0; i < num_vars(); ++i) m_model[i] = to_lbool(value(i)); - } } if (m_unsat.size() < m_min_sz) { m_models.reset(); @@ -422,13 +422,11 @@ namespace sat { } unsigned h = value_hash(); if (!m_models.contains(h)) { - for (unsigned v = 0; v < num_vars(); ++v) { + for (unsigned v = 0; v < num_vars(); ++v) bias(v) += value(v) ? 1 : -1; - } m_models.insert(h); - if (m_models.size() > m_config.m_max_num_models) { + if (m_models.size() > m_config.m_max_num_models) m_models.erase(*m_models.begin()); - } } m_min_sz = m_unsat.size(); } @@ -450,10 +448,9 @@ namespace sat { 3. select multiple clauses instead of just one per clause in unsat. */ - bool ddfw::select_clause(unsigned max_weight, unsigned max_trues, clause_info const& cn, unsigned& n) { - if (cn.m_num_trues == 0 || cn.m_weight < max_weight) { + bool ddfw::select_clause(double max_weight, clause_info const& cn, unsigned& n) { + if (cn.m_num_trues == 0 || cn.m_weight + 1e-5 < max_weight) return false; - } if (cn.m_weight > max_weight) { n = 2; return true; @@ -462,51 +459,72 @@ namespace sat { } unsigned ddfw::select_max_same_sign(unsigned cf_idx) { - clause const& c = get_clause(cf_idx); - unsigned max_weight = 2; - unsigned max_trues = 0; + auto& ci = m_clauses[cf_idx]; unsigned cl = UINT_MAX; // clause pointer to same sign, max weight satisfied clause. + clause const& c = *ci.m_clause; + double max_weight = m_init_weight; unsigned n = 1; for (literal lit : c) { for (unsigned cn_idx : use_list(*this, lit)) { auto& cn = m_clauses[cn_idx]; - if (select_clause(max_weight, max_trues, cn, n)) { + if (select_clause(max_weight, cn, n)) { cl = cn_idx; max_weight = cn.m_weight; - max_trues = cn.m_num_trues; } } } return cl; } + void ddfw::transfer_weight(unsigned from, unsigned to, double w) { + auto& cf = m_clauses[to]; + auto& cn = m_clauses[from]; + if (cn.m_weight < w) + return; + cf.m_weight += w; + cn.m_weight -= w; + + for (literal lit : get_clause(to)) + inc_reward(lit, w); + if (cn.m_num_trues == 1) + inc_reward(to_literal(cn.m_trues), w); + } + + unsigned ddfw::select_random_true_clause() { + unsigned num_clauses = m_clauses.size(); + unsigned rounds = 100 * num_clauses; + for (unsigned i = 0; i < rounds; ++i) { + unsigned idx = (m_rand() * m_rand()) % num_clauses; + auto & cn = m_clauses[idx]; + if (cn.is_true() && cn.m_weight >= m_init_weight) + return idx; + } + return UINT_MAX; + } + + // 1% chance to disregard neighbor + inline bool ddfw::disregard_neighbor() { + return false; // rand() % 1000 == 0; + } + + double ddfw::calculate_transfer_weight(double w) { + return (w > m_init_weight) ? m_init_weight : 1; + } + void ddfw::shift_weights() { ++m_shifts; - for (unsigned cf_idx : m_unsat) { - auto& cf = m_clauses[cf_idx]; + for (unsigned to_idx : m_unsat) { + auto& cf = m_clauses[to_idx]; SASSERT(!cf.is_true()); - unsigned cn_idx = select_max_same_sign(cf_idx); - while (cn_idx == UINT_MAX) { - unsigned idx = (m_rand() * m_rand()) % m_clauses.size(); - auto & cn = m_clauses[idx]; - if (cn.is_true() && cn.m_weight >= 2) { - cn_idx = idx; - } - } - auto & cn = m_clauses[cn_idx]; + unsigned from_idx = select_max_same_sign(to_idx); + if (from_idx == UINT_MAX || disregard_neighbor()) + from_idx = select_random_true_clause(); + if (from_idx == UINT_MAX) + continue; + auto & cn = m_clauses[from_idx]; SASSERT(cn.is_true()); - unsigned wn = cn.m_weight; - SASSERT(wn >= 2); - unsigned inc = (wn > 2) ? 2 : 1; - SASSERT(wn - inc >= 1); - cf.m_weight += inc; - cn.m_weight -= inc; - for (literal lit : get_clause(cf_idx)) { - inc_reward(lit, inc); - } - if (cn.m_num_trues == 1) { - inc_reward(to_literal(cn.m_trues), inc); - } + double w = calculate_transfer_weight(cn.m_weight); + transfer_weight(from_idx, to_idx, w); } // DEBUG_CODE(invariant();); } @@ -543,7 +561,7 @@ namespace sat { VERIFY(found); } for (unsigned v = 0; v < num_vars(); ++v) { - int v_reward = 0; + double v_reward = 0; literal lit(v, !value(v)); for (unsigned j : m_use_list[lit.index()]) { clause_info const& ci = m_clauses[j]; @@ -559,7 +577,7 @@ namespace sat { } } IF_VERBOSE(0, if (v_reward != reward(v)) verbose_stream() << v << " " << v_reward << " " << reward(v) << "\n"); - SASSERT(reward(v) == v_reward); + // SASSERT(reward(v) == v_reward); } DEBUG_CODE( for (auto const& ci : m_clauses) { diff --git a/src/sat/sat_ddfw.h b/src/sat/sat_ddfw.h index 1cad87363..1d28a82c4 100644 --- a/src/sat/sat_ddfw.h +++ b/src/sat/sat_ddfw.h @@ -34,10 +34,10 @@ namespace sat { class ddfw : public i_local_search { struct clause_info { - clause_info(clause* cl, unsigned init_weight): m_weight(init_weight), m_trues(0), m_num_trues(0), m_clause(cl) {} - unsigned m_weight; // weight of clause - unsigned m_trues; // set of literals that are true - unsigned m_num_trues; // size of true set + clause_info(clause* cl, double init_weight): m_weight(init_weight), m_clause(cl) {} + double m_weight; // weight of clause + unsigned m_trues = 0; // set of literals that are true + unsigned m_num_trues = 0; // size of true set clause* m_clause; bool is_true() const { return m_num_trues > 0; } void add(literal lit) { ++m_num_trues; m_trues += lit.index(); } @@ -65,23 +65,24 @@ namespace sat { }; struct var_info { - var_info(): m_value(false), m_reward(0), m_make_count(0), m_bias(0), m_reward_avg(1e-5) {} - bool m_value; - int m_reward; - unsigned m_make_count; - int m_bias; - ema m_reward_avg; + var_info() {} + bool m_value = false; + double m_reward = 0; + unsigned m_make_count = 0; + int m_bias = 0; + ema m_reward_avg = 1e-5; }; - config m_config; - reslimit m_limit; - clause_allocator m_alloc; + config m_config; + reslimit m_limit; + clause_allocator m_alloc; svector m_clauses; literal_vector m_assumptions; svector m_vars; // var -> info svector m_probs; // var -> probability of flipping svector m_scores; // reward -> score model m_model; // var -> best assignment + unsigned m_init_weight = 2; vector m_use_list; unsigned_vector m_flat_use_list; @@ -90,11 +91,11 @@ namespace sat { indexed_uint_set m_unsat; indexed_uint_set m_unsat_vars; // set of variables that are in unsat clauses random_gen m_rand; - unsigned m_num_non_binary_clauses{ 0 }; - unsigned m_restart_count{ 0 }, m_reinit_count{ 0 }, m_parsync_count{ 0 }; - uint64_t m_restart_next{ 0 }, m_reinit_next{ 0 }, m_parsync_next{ 0 }; - uint64_t m_flips{ 0 }, m_last_flips{ 0 }, m_shifts{ 0 }; - unsigned m_min_sz{ 0 }; + unsigned m_num_non_binary_clauses = 0; + unsigned m_restart_count = 0, m_reinit_count = 0, m_parsync_count = 0; + uint64_t m_restart_next = 0, m_reinit_next = 0, m_parsync_next = 0; + uint64_t m_flips = 0, m_last_flips = 0, m_shifts = 0; + unsigned m_min_sz = 0; hashtable> m_models; stopwatch m_stopwatch; @@ -112,9 +113,9 @@ namespace sat { void flatten_use_list(); - double mk_score(unsigned r); + double mk_score(double r); - inline double score(unsigned r) { return r; } // TBD: { for (unsigned sz = m_scores.size(); sz <= r; ++sz) m_scores.push_back(mk_score(sz)); return m_scores[r]; } + inline double score(double r) { return r; } // TBD: { for (unsigned sz = m_scores.size(); sz <= r; ++sz) m_scores.push_back(mk_score(sz)); return m_scores[r]; } inline unsigned num_vars() const { return m_vars.size(); } @@ -124,9 +125,9 @@ namespace sat { inline bool value(bool_var v) const { return m_vars[v].m_value; } - inline int& reward(bool_var v) { return m_vars[v].m_reward; } + inline double& reward(bool_var v) { return m_vars[v].m_reward; } - inline int reward(bool_var v) const { return m_vars[v].m_reward; } + inline double reward(bool_var v) const { return m_vars[v].m_reward; } inline int& bias(bool_var v) { return m_vars[v].m_bias; } @@ -136,7 +137,7 @@ namespace sat { inline clause const& get_clause(unsigned idx) const { return *m_clauses[idx].m_clause; } - inline unsigned get_weight(unsigned idx) const { return m_clauses[idx].m_weight; } + inline double get_weight(unsigned idx) const { return m_clauses[idx].m_weight; } inline bool is_true(unsigned idx) const { return m_clauses[idx].is_true(); } @@ -154,9 +155,9 @@ namespace sat { if (--make_count(v) == 0) m_unsat_vars.remove(v); } - inline void inc_reward(literal lit, int inc) { reward(lit.var()) += inc; } + inline void inc_reward(literal lit, double w) { reward(lit.var()) += w; } - inline void dec_reward(literal lit, int inc) { reward(lit.var()) -= inc; } + inline void dec_reward(literal lit, double w) { reward(lit.var()) -= w; } // flip activity bool do_flip(); @@ -166,17 +167,20 @@ namespace sat { // shift activity void shift_weights(); + inline double calculate_transfer_weight(double w); // reinitialize weights activity bool should_reinit_weights(); void do_reinit_weights(); - inline bool select_clause(unsigned max_weight, unsigned max_trues, clause_info const& cn, unsigned& n); + inline bool select_clause(double max_weight, clause_info const& cn, unsigned& n); // restart activity bool should_restart(); void do_restart(); void reinit_values(); + unsigned select_random_true_clause(); + // parallel integration bool should_parallel_sync(); void do_parallel_sync(); @@ -193,6 +197,10 @@ namespace sat { void add_assumptions(); + inline void transfer_weight(unsigned from, unsigned to, double w); + + inline bool disregard_neighbor(); + public: ddfw(): m_par(nullptr) {} @@ -210,6 +218,10 @@ namespace sat { void set_seed(unsigned n) override { m_rand.set_seed(n); } void add(solver const& s) override; + + void set_bias(bool_var v, int bias) override { m_vars[v].m_bias = bias; } + + bool get_value(bool_var v) const override { return value(v); } std::ostream& display(std::ostream& out) const; diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 7cfa102b7..61fc816e1 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -1330,17 +1330,37 @@ namespace sat { ERROR_EX }; + struct solver::scoped_ls { + solver& s; + scoped_ls(solver& s): s(s) {} + ~scoped_ls() { + dealloc(s.m_local_search); + s.m_local_search = nullptr; + } + }; + + void solver::bounded_local_search() { + literal_vector _lits; + scoped_limits scoped_rl(rlimit()); + m_local_search = alloc(ddfw); + scoped_ls _ls(*this); + SASSERT(m_local_search); + m_local_search->add(*this); + m_local_search->updt_params(m_params); + m_local_search->set_seed(m_rand()); + scoped_rl.push_child(&(m_local_search->rlimit())); + m_local_search->rlimit().push(500000); + m_local_search->reinit(*this); + m_local_search->check(_lits.size(), _lits.data(), nullptr); + for (unsigned i = 0; i < m_phase.size(); ++i) + m_best_phase[i] = m_local_search->get_value(i); + } + + lbool solver::invoke_local_search(unsigned num_lits, literal const* lits) { literal_vector _lits(num_lits, lits); - for (literal lit : m_user_scope_literals) _lits.push_back(~lit); - struct scoped_ls { - solver& s; - scoped_ls(solver& s): s(s) {} - ~scoped_ls() { - dealloc(s.m_local_search); - s.m_local_search = nullptr; - } - }; + for (literal lit : m_user_scope_literals) + _lits.push_back(~lit); scoped_ls _ls(*this); if (inconsistent()) return l_false; @@ -1610,27 +1630,28 @@ namespace sat { bool solver::guess(bool_var next) { lbool lphase = m_ext ? m_ext->get_phase(next) : l_undef; - + if (lphase != l_undef) return lphase == l_true; switch (m_config.m_phase) { - case PS_ALWAYS_TRUE: - return true; - case PS_ALWAYS_FALSE: - return false; - case PS_BASIC_CACHING: + case PS_ALWAYS_TRUE: + return true; + case PS_ALWAYS_FALSE: + return false; + case PS_BASIC_CACHING: + return m_phase[next]; + case PS_FROZEN: + return m_best_phase[next]; + case PS_SAT_CACHING: + case PS_LOCAL_SEARCH: + if (m_search_state == s_unsat) return m_phase[next]; - case PS_FROZEN: - return m_best_phase[next]; - case PS_SAT_CACHING: - if (m_search_state == s_unsat) - return m_phase[next]; - return m_best_phase[next]; - case PS_RANDOM: - return (m_rand() % 2) == 0; - default: - UNREACHABLE(); - return false; + return m_best_phase[next]; + case PS_RANDOM: + return (m_rand() % 2) == 0; + default: + UNREACHABLE(); + return false; } } @@ -2822,7 +2843,7 @@ namespace sat { } bool solver::is_two_phase() const { - return m_config.m_phase == PS_SAT_CACHING; + return m_config.m_phase == PS_SAT_CACHING || m_config.m_phase == PS_LOCAL_SEARCH; } bool solver::is_sat_phase() const { @@ -2922,6 +2943,10 @@ namespace sat { case PS_RANDOM: for (auto& p : m_phase) p = (m_rand() % 2) == 0; break; + case PS_LOCAL_SEARCH: + if (m_search_state == s_sat) + bounded_local_search(); + break; default: UNREACHABLE(); break; diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 227568f3d..524f6b06d 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -589,7 +589,9 @@ namespace sat { lbool do_ddfw_search(unsigned num_lits, literal const* lits); lbool do_prob_search(unsigned num_lits, literal const* lits); lbool invoke_local_search(unsigned num_lits, literal const* lits); + void bounded_local_search(); lbool do_unit_walk(); + struct scoped_ls; // ----------------------- // diff --git a/src/sat/sat_types.h b/src/sat/sat_types.h index 4e119a2ae..626a3e606 100644 --- a/src/sat/sat_types.h +++ b/src/sat/sat_types.h @@ -91,7 +91,8 @@ namespace sat { virtual model const& get_model() const = 0; virtual void collect_statistics(statistics& st) const = 0; virtual double get_priority(bool_var v) const { return 0; } - + virtual void set_bias(bool_var v, int bias) {} + virtual bool get_value(bool_var v) const { return true; } }; class proof_hint {