diff --git a/src/sat/sat_sls.cpp b/src/sat/sat_sls.cpp index 04c8d0605..78d9b66e5 100644 --- a/src/sat/sat_sls.cpp +++ b/src/sat/sat_sls.cpp @@ -65,7 +65,7 @@ namespace sat { lbool sls::operator()(unsigned sz, literal const* tabu, bool reuse_model) { init(sz, tabu, reuse_model); unsigned i; - for (i = 0; !m_false.empty() && i < m_max_tries; ++i) { + for (i = 0; !m_false.empty() && !m_cancel && i < m_max_tries; ++i) { flip(); } IF_VERBOSE(2, verbose_stream() << "tries " << i << "\n";); @@ -265,5 +265,233 @@ namespace sat { } } + void sls::check_invariant() { + + } + + wsls::wsls(solver& s): + sls(s) + { + m_smoothing_probability = 1; // 1/1000 + } + + wsls::~wsls() {} + + void wsls::set_soft(unsigned sz, double const* weights, literal const* lits) { + m_soft.reset(); + m_weights.reset(); + m_soft.append(sz, lits); + m_weights.append(sz, weights); + } + + void wsls::opt(unsigned sz, literal const* tabu, bool reuse_model) { + init(sz, tabu, reuse_model); + + // + // Initialize m_clause_weights, m_hscore, m_sscore. + // + m_best_value = m_false.empty()?evaluate_model():-1.0; + m_clause_weights.reset(); + m_hscore.reset(); + m_sscore.reset(); + m_H.reset(); + m_S.reset(); + m_clause_weights.resize(m_clauses.size(), 1); + m_sscore.resize(s.num_vars(), 0.0); + m_hscore.resize(s.num_vars(), 0); + for (unsigned i = 0; i < m_soft.size(); ++i) { + literal lit = m_soft[i]; + m_sscore[lit.var()] = m_weights[i]; + if (value_at(lit, m_model) == l_true) { + m_sscore[lit.var()] = -m_sscore[lit.var()]; + } + } + for (unsigned i = 0; i < s.num_vars(); ++i) { + m_hscore[i] = compute_hscore(i); + refresh_scores(i); + } + unsigned i = 0; + for (; !m_cancel && i < m_max_tries; ++i) { + wflip(); + } + IF_VERBOSE(2, verbose_stream() << "tries " << i << "\n";); + } + + void wsls::wflip() { + literal lit; + if (pick_wflip(lit)) { + wflip(lit); + } + } + + bool wsls::pick_wflip(literal & lit) { + if (m_false.empty()) { + double val = evaluate_model(); + if (val < m_best_value || m_best_value < 0.0) { + m_best_model.reset(); + m_best_model.append(m_model); + } + } + unsigned idx; + if (!m_H.empty()) { + idx = m_H.choose(m_rand); + lit = literal(idx, false); + if (value_at(lit, m_model) == l_true) lit.neg(); + } + else if (!m_S.empty()) { + double score = 0.0; + m_min_vars.reset(); + for (unsigned i = 0; i < m_S.num_elems(); ++i) { + unsigned v = m_S[i]; + SASSERT(m_sscore[v] > 0.0); + if (m_sscore[v] > score) { + m_min_vars.reset(); + m_min_vars.push_back(literal(v, false)); + score = m_sscore[v]; + } + else if (m_sscore[v] == score) { + m_min_vars.push_back(literal(v, false)); + } + } + idx = m_min_vars[m_rand(m_min_vars.size())].var(); // pick with largest sscore. + } + else { + update_hard_weights(); + if (!m_false.empty()) { + unsigned cls_idx = m_false.choose(m_rand); + } + else { + lit = m_soft[m_rand(m_soft.size())]; + } + } + return !m_tabu[lit.var()]; + } + + void wsls::wflip(literal lit) { + flip(lit); + unsigned v = lit.var(); + m_hscore[v] = compute_hscore(v); + m_sscore[v] = -m_sscore[v]; + refresh_scores(v); + } + + void wsls::update_hard_weights() { + unsigned csz = m_clauses.size(); + if (m_smoothing_probability >= m_rand(1000)) { + for (unsigned i = 0; i < csz; ++i) { + if (m_clause_weights[i] > 1 && !m_false.contains(i)) { + --m_clause_weights[i]; + if (m_num_true[i] == 1) { + clause const& c = *m_clauses[i]; + unsigned sz = c.size(); + for (unsigned j = 0; j < sz; ++j) { + if (value_at(c[j], m_model) == l_true) { + ++m_hscore[c[j].var()]; + refresh_scores(c[j].var()); + break; + } + } + } + } + } + } + else { + for (unsigned i = 0; i < csz; ++i) { + if (m_false.contains(i)) { + ++m_clause_weights[i]; + clause const& c = *m_clauses[i]; + unsigned sz = c.size(); + for (unsigned j = 0; j < sz; ++j) { + ++m_hscore[c[j].var()]; + refresh_scores(c[j].var()); + } + } + } + } + + DEBUG_CODE(check_invariant();); + } + + double wsls::evaluate_model() { + SASSERT(m_false.empty()); + double result = 0.0; + for (unsigned i = 0; i < m_soft.size(); ++i) { + literal lit = m_soft[i]; + if (value_at(lit, m_model) != l_true) { + result += m_weights[i]; + } + } + return result; + } + + int wsls::compute_hscore(unsigned v) { + literal lit(v, false); + if (value_at(lit, m_model) == l_false) { + lit.neg(); + } + SASSERT(value_at(lit, m_model) == l_true); + int hs = 0; + unsigned_vector const& use1 = get_use(~lit); + unsigned sz = use1.size(); + for (unsigned i = 0; i < sz; ++i) { + unsigned cl = use1[i]; + if (m_num_true[cl] == 0) { + SASSERT(m_false.contains(cl)); + hs += m_clause_weights[cl]; + } + else { + SASSERT(!m_false.contains(cl)); + } + } + unsigned_vector const& use2 = get_use(lit); + sz = use2.size(); + for (unsigned i = 0; i < sz; ++i) { + unsigned cl = use2[i]; + if (m_num_true[cl] == 1) { + SASSERT(!m_false.contains(cl)); + hs -= m_clause_weights[cl]; + } + } + return hs; + } + + void wsls::refresh_scores(unsigned v) { + if (m_hscore[v] > 0) { + m_H.insert(v); + } + else { + m_H.remove(v); + } + if (m_sscore[v] > 0) { + if (m_hscore[v] == 0) { + m_S.insert(v); + } + else { + m_S.remove(v); + } + } + else if (m_sscore[v] < 0) { + m_S.remove(v); + } + } + + void wsls::check_invariant() { + sls::check_invariant(); + // The hscore is the reward for flipping the truth value of variable v. + // hscore(v) = Sum weight(c) for num_true(c) = 0 and v in c + // - Sum weight(c) for num_true(c) = 1 and (v in c, M(v) or !v in c and !M(v)) + for (unsigned v = 0; v < s.num_vars(); ++v) { + int hs = compute_hscore(v); + SASSERT(m_hscore[v] == hs); + } + + // The score(v) is the reward on soft clauses for flipping v. + for (unsigned j = 0; j < m_soft.size(); ++j) { + unsigned v = m_soft[j].var(); + double ss = value_at(m_soft[j], m_model)?(-m_weights[j]):m_weights[j]; + SASSERT(m_sscore[v] == ss); + } + } + }; diff --git a/src/sat/sat_sls.h b/src/sat/sat_sls.h index 6237c5b13..5546b90ac 100644 --- a/src/sat/sat_sls.h +++ b/src/sat/sat_sls.h @@ -28,7 +28,8 @@ namespace sat { unsigned_vector m_elems; unsigned_vector m_index; public: - unsigned num_elems() const { return m_elems.size(); } + unsigned num_elems() const { return m_elems.size(); } + unsigned operator[](unsigned idx) const { return m_elems[idx]; } void reset() { m_elems.reset(); m_index.reset(); } bool empty() const { return m_elems.empty(); } bool contains(unsigned idx) const; @@ -38,6 +39,7 @@ namespace sat { }; class sls { + protected: solver& s; random_gen m_rand; unsigned m_max_tries; @@ -52,22 +54,55 @@ namespace sat { clause_allocator m_alloc; // clause allocator clause_vector m_bin_clauses; // binary clauses svector m_tabu; // variables that cannot be swapped + volatile bool m_cancel; public: sls(solver& s); - ~sls(); + virtual ~sls(); lbool operator()(unsigned sz, literal const* tabu, bool reuse_model); - private: - bool local_search(); + void set_cancel(bool f) { m_cancel = f; } + protected: void init(unsigned sz, literal const* tabu, bool reuse_model); void init_tabu(unsigned sz, literal const* tabu); void init_model(); void init_use(); void init_clauses(); + unsigned_vector const& get_use(literal lit); + void flip(literal lit); + virtual void check_invariant(); + private: bool pick_flip(literal& lit); void flip(); - void flip(literal lit); unsigned get_break_count(literal lit, unsigned min_break); - unsigned_vector const& get_use(literal lit); + }; + + /** + \brief sls with weighted soft clauses. + */ + class wsls : public sls { + unsigned_vector m_clause_weights; + svector m_hscore; + svector m_sscore; + literal_vector m_soft; + svector m_weights; + double m_best_value; + model m_best_model; + index_set m_H, m_S; + unsigned m_smoothing_probability; + public: + wsls(solver& s); + virtual ~wsls(); + void set_soft(unsigned sz, double const* weights, literal const* lits); + void opt(unsigned sz, literal const* tabu, bool reuse_model); + model const& get_model() { return m_best_model; } + private: + void wflip(); + void wflip(literal lit); + void update_hard_weights(); + bool pick_wflip(literal & lit); + double evaluate_model(); + virtual void check_invariant(); + void refresh_scores(unsigned v); + int compute_hscore(unsigned v); }; }; diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 33e88f64e..869330e06 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -137,6 +137,7 @@ namespace sat { friend class iff3_finder; friend class mus; friend class sls; + friend class wsls; friend struct mk_stat; public: solver(params_ref const & p, extension * ext);