3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

update the interface in sls_solver to transfer phase between SAT and SLS

This commit is contained in:
Nikolaj Bjorner 2024-10-20 15:42:26 -07:00
parent a48044c6e0
commit 68ee5108d8
7 changed files with 175 additions and 141 deletions

View file

@ -49,7 +49,7 @@ namespace sat {
void ddfw::check_without_plugin() {
while (m_limit.inc() && m_min_sz > 0) {
if (should_reinit_weights()) do_reinit_weights();
else if (do_flip<false>());
else if (do_flip());
else if (should_restart()) do_restart();
else if (m_parallel_sync && m_parallel_sync());
else shift_weights();
@ -67,7 +67,7 @@ namespace sat {
if (should_reinit_weights()) do_reinit_weights();
else if (steps % 5000 == 0) shift_weights(), m_plugin->on_rescale();
else if (should_restart()) do_restart(), m_plugin->on_restart();
else if (do_flip<true>());
else if (do_flip());
else shift_weights(), m_plugin->on_rescale();
//verbose_stream() << "steps: " << steps << " min_sz: " << m_min_sz << " unsat: " << m_unsat.size() << "\n";
++steps;
@ -102,15 +102,13 @@ namespace sat {
m_last_flips = m_flips;
}
template<bool uses_plugin>
bool ddfw::do_flip() {
double reward = 0;
bool_var v = pick_var<uses_plugin>(reward);
bool_var v = pick_var(reward);
//verbose_stream() << "flip " << v << " " << reward << "\n";
return apply_flip<uses_plugin>(v, reward);
return apply_flip(v, reward);
}
template<bool uses_plugin>
bool ddfw::apply_flip(bool_var v, double reward) {
if (v == null_bool_var)
return false;
@ -124,7 +122,6 @@ namespace sat {
return false;
}
template<bool uses_plugin>
bool_var ddfw::pick_var(double& r) {
double sum_pos = 0;
unsigned n = 1;
@ -167,13 +164,17 @@ namespace sat {
}
}
sat::bool_var ddfw::add_var(bool is_internal) {
sat::bool_var ddfw::add_var() {
auto v = m_vars.size();
m_vars.reserve(v + 1);
m_vars[v].m_internal = is_internal;
return v;
}
void ddfw::reserve_vars(unsigned n) {
m_vars.reserve(n);
}
/**
* Remove the last clause that was added
*/
@ -215,11 +216,6 @@ namespace sat {
m_restart_count = 0;
m_restart_next = m_config.m_restart_base*2;
#if 0
m_parsync_count = 0;
m_parsync_next = m_config.m_parsync_base;
#endif
m_min_sz = m_unsat.size();
m_flips = 0;
m_last_flips = 0;
@ -244,9 +240,8 @@ namespace sat {
m_use_list_index.push_back(m_flat_use_list.size());
}
bool ddfw::flip(bool_var v) {
void ddfw::flip(bool_var v) {
++m_flips;
bool new_unsat = false;
literal lit = literal(v, !value(v));
literal nlit = ~lit;
SASSERT(is_true(lit));
@ -262,7 +257,6 @@ namespace sat {
verbose_stream() << "flipping unit clause " << ci << "\n";
#endif
m_unsat.insert_fresh(cls_idx);
new_unsat = true;
auto const& c = get_clause(cls_idx);
for (literal l : c) {
inc_reward(l, w);
@ -304,7 +298,6 @@ namespace sat {
}
value(v) = !value(v);
update_reward_avg(v);
return new_unsat;
}
bool ddfw::should_reinit_weights() {
@ -404,38 +397,20 @@ namespace sat {
for (unsigned i = 0; i < num_vars(); ++i)
m_model[i] = to_lbool(value(i));
save_priorities();
if (m_plugin && m_unsat.empty())
m_plugin->on_save_model();
if (m_plugin)
m_plugin->on_save_model();
}
void ddfw::save_best_values() {
if (m_unsat.size() < m_min_sz || m_unsat.empty()) {
if (m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11)
save_model();
if ((m_unsat.size() < m_min_sz || m_unsat.empty()) &&
((m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11)))
save_model();
if (m_unsat.size() < m_min_sz) {
m_models.reset();
m_min_sz = m_unsat.size();
}
if (m_unsat.size() < m_min_sz)
m_models.reset();
m_min_sz = m_unsat.size();
#if 0
m_num_models.reserve(m_min_sz + 1);
unsigned nm = m_num_models[m_min_sz]++;
if (nm >= 10) {
if (nm >= 200)
m_num_models[m_min_sz] = 10, m_restart_next = m_flips;
if (nm % 1 == 0) {
for (unsigned v = 0; v < num_vars(); ++v)
bias(v) += value(v) ? 1 : -1;
}
return;
}
#endif
unsigned h = value_hash();
unsigned occs = 0;
bool contains = m_models.find(h, occs);
@ -449,8 +424,7 @@ namespace sat {
if (occs > 100) {
m_restart_next = m_flips;
m_models.erase(h);
}
}
}
unsigned ddfw::value_hash() const {

View file

@ -69,13 +69,11 @@ namespace sat {
struct var_info {
var_info() {}
bool m_internal = false;
bool m_value = false;
double m_reward = 0;
double m_last_reward = 0;
unsigned m_make_count = 0;
int m_bias = 0;
bool m_external = false;
ema m_reward_avg = 1e-5;
};
@ -124,11 +122,6 @@ namespace sat {
inline double& reward(bool_var v) { return m_vars[v].m_reward; }
void set_external(bool_var v) { m_vars[v].m_external = true; }
inline bool is_external(bool_var v) const { return m_vars[v].m_external; }
inline int& bias(bool_var v) { return m_vars[v].m_bias; }
unsigned value_hash() const;
@ -162,13 +155,10 @@ namespace sat {
void check_without_plugin();
// flip activity
template<bool uses_plugin>
bool do_flip();
template<bool uses_plugin>
bool_var pick_var(double& reward);
template<bool uses_plugin>
bool apply_flip(bool_var v, double reward);
@ -253,18 +243,19 @@ namespace sat {
void remove_assumptions();
bool flip(bool_var v);
void flip(bool_var v);
inline double get_reward(bool_var v) const { return m_vars[v].m_reward; }
double get_reward_avg(bool_var v) const { return m_vars[v].m_reward_avg; }
inline int& bias(bool_var v) { return m_vars[v].m_bias; }
void reserve_vars(unsigned n);
void add(unsigned sz, literal const* c);
sat::bool_var add_var(bool is_internal = true);
// is this a variable that was added during initialization?
bool is_initial_var(sat::bool_var v) const {
return m_vars.size() > v && !m_vars[v].m_internal;
}
sat::bool_var add_var();
void reinit();

View file

@ -67,7 +67,7 @@ namespace sls {
virtual vector<sat::clause_info> const& clauses() const = 0;
virtual sat::clause_info const& get_clause(unsigned idx) const = 0;
virtual ptr_iterator<unsigned> get_use_list(sat::literal lit) = 0;
virtual bool flip(sat::bool_var v) = 0;
virtual void flip(sat::bool_var v) = 0;
virtual double reward(sat::bool_var v) = 0;
virtual double get_weigth(unsigned clause_idx) = 0;
virtual bool is_true(sat::literal lit) = 0;
@ -173,7 +173,7 @@ namespace sls {
sat::literal mk_literal(expr* e);
void add_clause(expr* f);
void add_clause(sat::literal_vector const& lits);
bool flip(sat::bool_var v) { return s.flip(v); }
void flip(sat::bool_var v) { s.flip(v); }
double reward(sat::bool_var v) { return s.reward(v); }
indexed_uint_set const& unsat() const { return s.unsat(); }
unsigned rand() { return m_rand(); }

View file

@ -57,7 +57,7 @@ namespace sls {
if (m_on_save_model)
return;
flet<bool> _on_save_model(m_on_save_model, true);
TRACE("sls", display(tout));
CTRACE("sls", unsat().empty(), display(tout));
while (unsat().empty()) {
m_context.check();
if (!m_new_constraint)
@ -87,7 +87,7 @@ namespace sls {
vector<sat::clause_info> const& clauses() const override { return m_ddfw.clauses(); }
sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw.get_clause_info(idx); }
ptr_iterator<unsigned> get_use_list(sat::literal lit) override { return m_ddfw.use_list(lit); }
bool flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; return m_ddfw.flip(v); }
void flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; m_ddfw.flip(v); }
double reward(sat::bool_var v) override { return m_ddfw.get_reward(v); }
double get_weigth(unsigned clause_idx) override { return m_ddfw.get_clause_info(clause_idx).m_weight; }
bool is_true(sat::literal lit) override { return m_ddfw.get_value(lit.var()) != lit.sign(); }

View file

@ -2932,6 +2932,7 @@ namespace sat {
bool_var v = m_trail[i].var();
m_best_phase[v] = m_phase[v];
}
set_has_new_best_phase(true);
}
}

View file

@ -152,6 +152,7 @@ namespace sat {
bool_vector m_phase;
bool_vector m_best_phase;
bool_vector m_prev_phase;
bool m_new_best_phase = false;
svector<char> m_assigned_since_gc;
search_state m_search_state;
unsigned m_search_unsat_conflicts;
@ -380,6 +381,9 @@ namespace sat {
bool was_eliminated(literal l) const { return was_eliminated(l.var()); }
void set_phase(literal l) override { if (l.var() < num_vars()) m_best_phase[l.var()] = m_phase[l.var()] = !l.sign(); }
bool get_phase(bool_var b) { return m_phase.get(b, false); }
bool get_best_phase(bool_var b) { return m_best_phase.get(b, false); }
void set_has_new_best_phase(bool b) { m_new_best_phase = b; }
bool has_new_best_phase() const { return m_new_best_phase; }
void move_to_front(bool_var b);
unsigned scope_lvl() const { return m_scope_lvl; }
unsigned search_lvl() const { return m_search_lvl; }

View file

@ -13,6 +13,7 @@ Author:
Nikolaj Bjorner (nbjorner) 2024-02-21
--*/
#include "sat/smt/sls_solver.h"
@ -36,7 +37,125 @@ namespace sls {
finalize();
}
void solver::finalize() {
class solver::smt_plugin : public sat::local_search_plugin, public sls::sat_solver_context {
solver& s;
sat::ddfw* m_ddfw;
sls::context m_context;
bool m_new_clause_added = false;
unsigned m_num_shared_vars = 0;
// export from SAT to SLS:
// - unit literals
// - phase
// - values
bool export_to_sls() {
bool updated = false;
if (s.m_has_units) {
std::lock_guard<std::mutex> lock(s.m_mutex);
IF_VERBOSE(1, verbose_stream() << "SAT->SLS units " << s.m_units << "\n");
for (auto lit : s.m_units)
if (lit.var() < m_num_shared_vars)
m_ddfw->add(1, &lit);
s.m_has_units = false;
s.m_units.reset();
updated = true;
}
if (m_has_new_sat_phase) {
std::lock_guard<std::mutex> lock(s.m_mutex);
IF_VERBOSE(1, verbose_stream() << "SAT->SLS phase\n");
for (unsigned i = 0; i < m_sat_phase.size(); ++i) {
if (m_sat_phase[i] != is_true(sat::literal(i, false)))
flip(i);
m_ddfw->bias(i) = m_sat_phase[i] ? 1 : -1;
}
m_has_new_sat_phase = false;
}
return updated;
}
// import from SLS:
// - activity
// - phase
// - values
void import_from_sls() {
std::lock_guard<std::mutex> lock(s.m_mutex);
for (unsigned v = 0; v < m_num_shared_vars; ++v) {
m_rewards[v] = m_ddfw->get_reward_avg(v);
m_sls_phase[v] = l_true == m_ddfw->get_model()[v];
m_has_new_sls_phase = true;
}
}
public:
smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) :
s(s), m_ddfw(d), m_context(m, *this) {}
svector<bool> m_sat_phase;
std::atomic<bool> m_has_new_sat_phase = false;
std::atomic<bool> m_has_new_sls_phase = false;
svector<bool> m_sls_phase;
svector<double> m_rewards;
void init_search() override {}
void finish_search() override {}
void on_rescale() override {}
void on_restart() override {
if (export_to_sls())
m_ddfw->reinit();
}
void on_save_model() override {
TRACE("sls", display(tout));
while (unsat().empty()) {
m_context.check();
if (!m_new_clause_added)
break;
m_ddfw->reinit();
m_new_clause_added = false;
}
import_from_sls();
}
void on_model(model_ref& mdl) override {
IF_VERBOSE(1, verbose_stream() << "on-model " << "\n");
s.m_sls_model = mdl;
}
void register_atom(sat::bool_var v, expr* e) {
m_context.register_atom(v, e);
}
std::ostream& display(std::ostream& out) {
m_ddfw->display(out);
m_context.display(out);
return out;
}
vector<sat::clause_info> const& clauses() const override { return m_ddfw->clauses(); }
sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); }
ptr_iterator<unsigned> get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); }
void flip(sat::bool_var v) override { m_ddfw->flip(v); }
double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); }
double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; }
bool is_true(sat::literal lit) override { return m_ddfw->get_value(lit.var()) != lit.sign(); }
unsigned num_vars() const override { return m_ddfw->num_vars(); }
indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); }
sat::bool_var add_var() override { return m_ddfw->add_var(); }
void add_clause(unsigned n, sat::literal const* lits) override {
m_ddfw->add(n, lits);
m_new_clause_added = true;
}
void force_restart() override { m_ddfw->force_restart(); }
};
void solver::finalize() {
if (!m_completed && m_ddfw) {
m_ddfw->rlimit().cancel();
m_thread.join();
@ -65,79 +184,25 @@ namespace sls {
m_units.push_back(lit);
m_has_units = true;
}
if (s().at_base_lvl()) {
if (s().has_new_best_phase()) {
IF_VERBOSE(1, verbose_stream() << "new SAT->SLS phase\n");
m_smt_plugin->m_has_new_sat_phase = true;
s().set_has_new_best_phase(false);
std::lock_guard<std::mutex> lock(m_mutex);
for (unsigned i = 0; i < m_smt_plugin->m_sat_phase.size(); ++i)
m_smt_plugin->m_sat_phase[i] = s().get_best_phase(i);
}
}
if (m_smt_plugin->m_has_new_sls_phase) {
IF_VERBOSE(1, verbose_stream() << "new SLS->SAT phase\n");
std::lock_guard<std::mutex> lock(m_mutex);
for (unsigned i = 0; i < m_smt_plugin->m_sls_phase.size(); ++i)
s().set_phase(sat::literal(i, !m_smt_plugin->m_sls_phase[i]));
m_smt_plugin->m_has_new_sls_phase = false;
}
}
class solver::smt_plugin : public sat::local_search_plugin, public sls::sat_solver_context {
solver& s;
sat::ddfw* m_ddfw;
sls::context m_context;
bool m_new_clause_added = false;
public:
smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) :
s(s), m_ddfw(d), m_context(m, *this) {}
void init_search() override {}
void finish_search() override {}
void on_rescale() override {}
void on_restart() override {
if (!s.m_has_units)
return;
{
std::lock_guard<std::mutex> lock(s.m_mutex);
for (auto lit : s.m_units)
if (m_ddfw->is_initial_var(lit.var()))
m_ddfw->add(1, &lit);
s.m_has_units = false;
s.m_units.reset();
}
m_ddfw->reinit();
}
void on_save_model() override {
TRACE("sls", display(tout));
while (unsat().empty()) {
m_context.check();
if (!m_new_clause_added)
break;
m_ddfw->reinit();
m_new_clause_added = false;
}
}
void on_model(model_ref& mdl) override {
IF_VERBOSE(1, verbose_stream() << "on-model " << "\n");
s.m_sls_model = mdl;
}
void register_atom(sat::bool_var v, expr* e) {
m_context.register_atom(v, e);
}
std::ostream& display(std::ostream& out) {
m_ddfw->display(out);
m_context.display(out);
return out;
}
vector<sat::clause_info> const& clauses() const override { return m_ddfw->clauses(); }
sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); }
ptr_iterator<unsigned> get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); }
bool flip(sat::bool_var v) override { return m_ddfw->flip(v); }
double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); }
double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; }
bool is_true(sat::literal lit) override { return m_ddfw->get_value(lit.var()) != lit.sign(); }
unsigned num_vars() const override { return m_ddfw->num_vars(); }
indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); }
sat::bool_var add_var() override { return m_ddfw->add_var(); }
void add_clause(unsigned n, sat::literal const* lits) override {
m_ddfw->add(n, lits);
m_new_clause_added = true;
}
void force_restart() override { m_ddfw->force_restart(); }
};
void solver::init_search() {
if (m_ddfw) {
@ -215,6 +280,5 @@ namespace sls {
out << "sls-solver\n";
return out;
}
#endif
}