diff --git a/src/ast/sls/sat_ddfw.cpp b/src/ast/sls/sat_ddfw.cpp index 9192a6f21..581c1a355 100644 --- a/src/ast/sls/sat_ddfw.cpp +++ b/src/ast/sls/sat_ddfw.cpp @@ -49,7 +49,7 @@ namespace sat { void ddfw::check_without_plugin() { while (m_limit.inc() && m_min_sz > 0) { if (should_reinit_weights()) do_reinit_weights(); - else if (do_flip()); + else if (do_flip()); else if (should_restart()) do_restart(); else if (m_parallel_sync && m_parallel_sync()); else shift_weights(); @@ -67,7 +67,7 @@ namespace sat { if (should_reinit_weights()) do_reinit_weights(); else if (steps % 5000 == 0) shift_weights(), m_plugin->on_rescale(); else if (should_restart()) do_restart(), m_plugin->on_restart(); - else if (do_flip()); + else if (do_flip()); else shift_weights(), m_plugin->on_rescale(); //verbose_stream() << "steps: " << steps << " min_sz: " << m_min_sz << " unsat: " << m_unsat.size() << "\n"; ++steps; @@ -102,15 +102,13 @@ namespace sat { m_last_flips = m_flips; } - template bool ddfw::do_flip() { double reward = 0; - bool_var v = pick_var(reward); + bool_var v = pick_var(reward); //verbose_stream() << "flip " << v << " " << reward << "\n"; - return apply_flip(v, reward); + return apply_flip(v, reward); } - template bool ddfw::apply_flip(bool_var v, double reward) { if (v == null_bool_var) return false; @@ -124,7 +122,6 @@ namespace sat { return false; } - template bool_var ddfw::pick_var(double& r) { double sum_pos = 0; unsigned n = 1; @@ -167,13 +164,17 @@ namespace sat { } } - sat::bool_var ddfw::add_var(bool is_internal) { + sat::bool_var ddfw::add_var() { auto v = m_vars.size(); m_vars.reserve(v + 1); - m_vars[v].m_internal = is_internal; return v; } + void ddfw::reserve_vars(unsigned n) { + m_vars.reserve(n); + } + + /** * Remove the last clause that was added */ @@ -215,11 +216,6 @@ namespace sat { m_restart_count = 0; m_restart_next = m_config.m_restart_base*2; -#if 0 - m_parsync_count = 0; - m_parsync_next = m_config.m_parsync_base; -#endif - m_min_sz = m_unsat.size(); m_flips = 0; m_last_flips = 0; @@ -244,9 +240,8 @@ namespace sat { m_use_list_index.push_back(m_flat_use_list.size()); } - bool ddfw::flip(bool_var v) { + void ddfw::flip(bool_var v) { ++m_flips; - bool new_unsat = false; literal lit = literal(v, !value(v)); literal nlit = ~lit; SASSERT(is_true(lit)); @@ -262,7 +257,6 @@ namespace sat { verbose_stream() << "flipping unit clause " << ci << "\n"; #endif m_unsat.insert_fresh(cls_idx); - new_unsat = true; auto const& c = get_clause(cls_idx); for (literal l : c) { inc_reward(l, w); @@ -304,7 +298,6 @@ namespace sat { } value(v) = !value(v); update_reward_avg(v); - return new_unsat; } bool ddfw::should_reinit_weights() { @@ -404,38 +397,20 @@ namespace sat { for (unsigned i = 0; i < num_vars(); ++i) m_model[i] = to_lbool(value(i)); save_priorities(); - if (m_plugin && m_unsat.empty()) - m_plugin->on_save_model(); + if (m_plugin) + m_plugin->on_save_model(); } - void ddfw::save_best_values() { - if (m_unsat.size() < m_min_sz || m_unsat.empty()) { - if (m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11) - save_model(); + if ((m_unsat.size() < m_min_sz || m_unsat.empty()) && + ((m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11))) + save_model(); + + if (m_unsat.size() < m_min_sz) { + m_models.reset(); + m_min_sz = m_unsat.size(); } - - if (m_unsat.size() < m_min_sz) - m_models.reset(); - m_min_sz = m_unsat.size(); - -#if 0 - m_num_models.reserve(m_min_sz + 1); - unsigned nm = m_num_models[m_min_sz]++; - - - if (nm >= 10) { - if (nm >= 200) - m_num_models[m_min_sz] = 10, m_restart_next = m_flips; - if (nm % 1 == 0) { - for (unsigned v = 0; v < num_vars(); ++v) - bias(v) += value(v) ? 1 : -1; - } - return; - } -#endif - unsigned h = value_hash(); unsigned occs = 0; bool contains = m_models.find(h, occs); @@ -449,8 +424,7 @@ namespace sat { if (occs > 100) { m_restart_next = m_flips; m_models.erase(h); - } - + } } unsigned ddfw::value_hash() const { diff --git a/src/ast/sls/sat_ddfw.h b/src/ast/sls/sat_ddfw.h index d44cb2bb7..60e9424a7 100644 --- a/src/ast/sls/sat_ddfw.h +++ b/src/ast/sls/sat_ddfw.h @@ -69,13 +69,11 @@ namespace sat { struct var_info { var_info() {} - bool m_internal = false; bool m_value = false; double m_reward = 0; double m_last_reward = 0; unsigned m_make_count = 0; int m_bias = 0; - bool m_external = false; ema m_reward_avg = 1e-5; }; @@ -124,11 +122,6 @@ namespace sat { inline double& reward(bool_var v) { return m_vars[v].m_reward; } - void set_external(bool_var v) { m_vars[v].m_external = true; } - - inline bool is_external(bool_var v) const { return m_vars[v].m_external; } - - inline int& bias(bool_var v) { return m_vars[v].m_bias; } unsigned value_hash() const; @@ -162,13 +155,10 @@ namespace sat { void check_without_plugin(); // flip activity - template bool do_flip(); - template bool_var pick_var(double& reward); - template bool apply_flip(bool_var v, double reward); @@ -253,18 +243,19 @@ namespace sat { void remove_assumptions(); - bool flip(bool_var v); + void flip(bool_var v); inline double get_reward(bool_var v) const { return m_vars[v].m_reward; } + double get_reward_avg(bool_var v) const { return m_vars[v].m_reward_avg; } + + inline int& bias(bool_var v) { return m_vars[v].m_bias; } + + void reserve_vars(unsigned n); + void add(unsigned sz, literal const* c); - sat::bool_var add_var(bool is_internal = true); - - // is this a variable that was added during initialization? - bool is_initial_var(sat::bool_var v) const { - return m_vars.size() > v && !m_vars[v].m_internal; - } + sat::bool_var add_var(); void reinit(); diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h index 4d334e8c0..95f914292 100644 --- a/src/ast/sls/sls_context.h +++ b/src/ast/sls/sls_context.h @@ -67,7 +67,7 @@ namespace sls { virtual vector const& clauses() const = 0; virtual sat::clause_info const& get_clause(unsigned idx) const = 0; virtual ptr_iterator get_use_list(sat::literal lit) = 0; - virtual bool flip(sat::bool_var v) = 0; + virtual void flip(sat::bool_var v) = 0; virtual double reward(sat::bool_var v) = 0; virtual double get_weigth(unsigned clause_idx) = 0; virtual bool is_true(sat::literal lit) = 0; @@ -173,7 +173,7 @@ namespace sls { sat::literal mk_literal(expr* e); void add_clause(expr* f); void add_clause(sat::literal_vector const& lits); - bool flip(sat::bool_var v) { return s.flip(v); } + void flip(sat::bool_var v) { s.flip(v); } double reward(sat::bool_var v) { return s.reward(v); } indexed_uint_set const& unsat() const { return s.unsat(); } unsigned rand() { return m_rand(); } diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp index 78daa678a..e4e9f1200 100644 --- a/src/ast/sls/sls_smt_solver.cpp +++ b/src/ast/sls/sls_smt_solver.cpp @@ -57,7 +57,7 @@ namespace sls { if (m_on_save_model) return; flet _on_save_model(m_on_save_model, true); - TRACE("sls", display(tout)); + CTRACE("sls", unsat().empty(), display(tout)); while (unsat().empty()) { m_context.check(); if (!m_new_constraint) @@ -87,7 +87,7 @@ namespace sls { vector const& clauses() const override { return m_ddfw.clauses(); } sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw.get_clause_info(idx); } ptr_iterator get_use_list(sat::literal lit) override { return m_ddfw.use_list(lit); } - bool flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; return m_ddfw.flip(v); } + void flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; m_ddfw.flip(v); } double reward(sat::bool_var v) override { return m_ddfw.get_reward(v); } double get_weigth(unsigned clause_idx) override { return m_ddfw.get_clause_info(clause_idx).m_weight; } bool is_true(sat::literal lit) override { return m_ddfw.get_value(lit.var()) != lit.sign(); } diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index ed312a76c..69f3c0f37 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -2932,6 +2932,7 @@ namespace sat { bool_var v = m_trail[i].var(); m_best_phase[v] = m_phase[v]; } + set_has_new_best_phase(true); } } diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 9e7186a34..657c92178 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -152,6 +152,7 @@ namespace sat { bool_vector m_phase; bool_vector m_best_phase; bool_vector m_prev_phase; + bool m_new_best_phase = false; svector m_assigned_since_gc; search_state m_search_state; unsigned m_search_unsat_conflicts; @@ -380,6 +381,9 @@ namespace sat { bool was_eliminated(literal l) const { return was_eliminated(l.var()); } void set_phase(literal l) override { if (l.var() < num_vars()) m_best_phase[l.var()] = m_phase[l.var()] = !l.sign(); } bool get_phase(bool_var b) { return m_phase.get(b, false); } + bool get_best_phase(bool_var b) { return m_best_phase.get(b, false); } + void set_has_new_best_phase(bool b) { m_new_best_phase = b; } + bool has_new_best_phase() const { return m_new_best_phase; } void move_to_front(bool_var b); unsigned scope_lvl() const { return m_scope_lvl; } unsigned search_lvl() const { return m_search_lvl; } diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index 950a62ca0..bf40278b7 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -13,6 +13,7 @@ Author: Nikolaj Bjorner (nbjorner) 2024-02-21 + --*/ #include "sat/smt/sls_solver.h" @@ -36,7 +37,125 @@ namespace sls { finalize(); } - void solver::finalize() { + + class solver::smt_plugin : public sat::local_search_plugin, public sls::sat_solver_context { + solver& s; + sat::ddfw* m_ddfw; + sls::context m_context; + bool m_new_clause_added = false; + unsigned m_num_shared_vars = 0; + + // export from SAT to SLS: + // - unit literals + // - phase + // - values + bool export_to_sls() { + bool updated = false; + if (s.m_has_units) { + std::lock_guard lock(s.m_mutex); + IF_VERBOSE(1, verbose_stream() << "SAT->SLS units " << s.m_units << "\n"); + for (auto lit : s.m_units) + if (lit.var() < m_num_shared_vars) + m_ddfw->add(1, &lit); + s.m_has_units = false; + s.m_units.reset(); + updated = true; + } + if (m_has_new_sat_phase) { + std::lock_guard lock(s.m_mutex); + IF_VERBOSE(1, verbose_stream() << "SAT->SLS phase\n"); + for (unsigned i = 0; i < m_sat_phase.size(); ++i) { + if (m_sat_phase[i] != is_true(sat::literal(i, false))) + flip(i); + m_ddfw->bias(i) = m_sat_phase[i] ? 1 : -1; + } + m_has_new_sat_phase = false; + } + return updated; + } + + // import from SLS: + // - activity + // - phase + // - values + void import_from_sls() { + std::lock_guard lock(s.m_mutex); + for (unsigned v = 0; v < m_num_shared_vars; ++v) { + m_rewards[v] = m_ddfw->get_reward_avg(v); + m_sls_phase[v] = l_true == m_ddfw->get_model()[v]; + m_has_new_sls_phase = true; + } + } + + public: + smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) : + s(s), m_ddfw(d), m_context(m, *this) {} + + + svector m_sat_phase; + std::atomic m_has_new_sat_phase = false; + + std::atomic m_has_new_sls_phase = false; + svector m_sls_phase; + + svector m_rewards; + + void init_search() override {} + + void finish_search() override {} + + void on_rescale() override {} + + void on_restart() override { + if (export_to_sls()) + m_ddfw->reinit(); + } + + void on_save_model() override { + TRACE("sls", display(tout)); + while (unsat().empty()) { + m_context.check(); + if (!m_new_clause_added) + break; + m_ddfw->reinit(); + m_new_clause_added = false; + } + import_from_sls(); + } + + void on_model(model_ref& mdl) override { + IF_VERBOSE(1, verbose_stream() << "on-model " << "\n"); + s.m_sls_model = mdl; + } + + void register_atom(sat::bool_var v, expr* e) { + m_context.register_atom(v, e); + } + + std::ostream& display(std::ostream& out) { + m_ddfw->display(out); + m_context.display(out); + return out; + } + + vector const& clauses() const override { return m_ddfw->clauses(); } + sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); } + ptr_iterator get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); } + void flip(sat::bool_var v) override { m_ddfw->flip(v); } + double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); } + double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; } + bool is_true(sat::literal lit) override { return m_ddfw->get_value(lit.var()) != lit.sign(); } + unsigned num_vars() const override { return m_ddfw->num_vars(); } + indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); } + sat::bool_var add_var() override { return m_ddfw->add_var(); } + void add_clause(unsigned n, sat::literal const* lits) override { + m_ddfw->add(n, lits); + m_new_clause_added = true; + } + void force_restart() override { m_ddfw->force_restart(); } + }; + + void solver::finalize() { if (!m_completed && m_ddfw) { m_ddfw->rlimit().cancel(); m_thread.join(); @@ -65,79 +184,25 @@ namespace sls { m_units.push_back(lit); m_has_units = true; } + if (s().at_base_lvl()) { + if (s().has_new_best_phase()) { + IF_VERBOSE(1, verbose_stream() << "new SAT->SLS phase\n"); + m_smt_plugin->m_has_new_sat_phase = true; + s().set_has_new_best_phase(false); + std::lock_guard lock(m_mutex); + for (unsigned i = 0; i < m_smt_plugin->m_sat_phase.size(); ++i) + m_smt_plugin->m_sat_phase[i] = s().get_best_phase(i); + } + } + if (m_smt_plugin->m_has_new_sls_phase) { + IF_VERBOSE(1, verbose_stream() << "new SLS->SAT phase\n"); + std::lock_guard lock(m_mutex); + for (unsigned i = 0; i < m_smt_plugin->m_sls_phase.size(); ++i) + s().set_phase(sat::literal(i, !m_smt_plugin->m_sls_phase[i])); + m_smt_plugin->m_has_new_sls_phase = false; + } } - class solver::smt_plugin : public sat::local_search_plugin, public sls::sat_solver_context { - solver& s; - sat::ddfw* m_ddfw; - sls::context m_context; - bool m_new_clause_added = false; - public: - smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) : - s(s), m_ddfw(d), m_context(m, *this) {} - - void init_search() override {} - - void finish_search() override {} - - void on_rescale() override {} - - void on_restart() override { - if (!s.m_has_units) - return; - { - std::lock_guard lock(s.m_mutex); - for (auto lit : s.m_units) - if (m_ddfw->is_initial_var(lit.var())) - m_ddfw->add(1, &lit); - s.m_has_units = false; - s.m_units.reset(); - } - m_ddfw->reinit(); - } - - void on_save_model() override { - TRACE("sls", display(tout)); - while (unsat().empty()) { - m_context.check(); - if (!m_new_clause_added) - break; - m_ddfw->reinit(); - m_new_clause_added = false; - } - } - - void on_model(model_ref& mdl) override { - IF_VERBOSE(1, verbose_stream() << "on-model " << "\n"); - s.m_sls_model = mdl; - } - - void register_atom(sat::bool_var v, expr* e) { - m_context.register_atom(v, e); - } - - std::ostream& display(std::ostream& out) { - m_ddfw->display(out); - m_context.display(out); - return out; - } - - vector const& clauses() const override { return m_ddfw->clauses(); } - sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); } - ptr_iterator get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); } - bool flip(sat::bool_var v) override { return m_ddfw->flip(v); } - double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); } - double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; } - bool is_true(sat::literal lit) override { return m_ddfw->get_value(lit.var()) != lit.sign(); } - unsigned num_vars() const override { return m_ddfw->num_vars(); } - indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); } - sat::bool_var add_var() override { return m_ddfw->add_var(); } - void add_clause(unsigned n, sat::literal const* lits) override { - m_ddfw->add(n, lits); - m_new_clause_added = true; - } - void force_restart() override { m_ddfw->force_restart(); } - }; void solver::init_search() { if (m_ddfw) { @@ -215,6 +280,5 @@ namespace sls { out << "sls-solver\n"; return out; } - #endif }