mirror of
https://github.com/Z3Prover/z3
synced 2025-04-22 16:45:31 +00:00
update the interface in sls_solver to transfer phase between SAT and SLS
This commit is contained in:
parent
a48044c6e0
commit
68ee5108d8
7 changed files with 175 additions and 141 deletions
|
@ -49,7 +49,7 @@ namespace sat {
|
|||
void ddfw::check_without_plugin() {
|
||||
while (m_limit.inc() && m_min_sz > 0) {
|
||||
if (should_reinit_weights()) do_reinit_weights();
|
||||
else if (do_flip<false>());
|
||||
else if (do_flip());
|
||||
else if (should_restart()) do_restart();
|
||||
else if (m_parallel_sync && m_parallel_sync());
|
||||
else shift_weights();
|
||||
|
@ -67,7 +67,7 @@ namespace sat {
|
|||
if (should_reinit_weights()) do_reinit_weights();
|
||||
else if (steps % 5000 == 0) shift_weights(), m_plugin->on_rescale();
|
||||
else if (should_restart()) do_restart(), m_plugin->on_restart();
|
||||
else if (do_flip<true>());
|
||||
else if (do_flip());
|
||||
else shift_weights(), m_plugin->on_rescale();
|
||||
//verbose_stream() << "steps: " << steps << " min_sz: " << m_min_sz << " unsat: " << m_unsat.size() << "\n";
|
||||
++steps;
|
||||
|
@ -102,15 +102,13 @@ namespace sat {
|
|||
m_last_flips = m_flips;
|
||||
}
|
||||
|
||||
template<bool uses_plugin>
|
||||
bool ddfw::do_flip() {
|
||||
double reward = 0;
|
||||
bool_var v = pick_var<uses_plugin>(reward);
|
||||
bool_var v = pick_var(reward);
|
||||
//verbose_stream() << "flip " << v << " " << reward << "\n";
|
||||
return apply_flip<uses_plugin>(v, reward);
|
||||
return apply_flip(v, reward);
|
||||
}
|
||||
|
||||
template<bool uses_plugin>
|
||||
bool ddfw::apply_flip(bool_var v, double reward) {
|
||||
if (v == null_bool_var)
|
||||
return false;
|
||||
|
@ -124,7 +122,6 @@ namespace sat {
|
|||
return false;
|
||||
}
|
||||
|
||||
template<bool uses_plugin>
|
||||
bool_var ddfw::pick_var(double& r) {
|
||||
double sum_pos = 0;
|
||||
unsigned n = 1;
|
||||
|
@ -167,13 +164,17 @@ namespace sat {
|
|||
}
|
||||
}
|
||||
|
||||
sat::bool_var ddfw::add_var(bool is_internal) {
|
||||
sat::bool_var ddfw::add_var() {
|
||||
auto v = m_vars.size();
|
||||
m_vars.reserve(v + 1);
|
||||
m_vars[v].m_internal = is_internal;
|
||||
return v;
|
||||
}
|
||||
|
||||
void ddfw::reserve_vars(unsigned n) {
|
||||
m_vars.reserve(n);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Remove the last clause that was added
|
||||
*/
|
||||
|
@ -215,11 +216,6 @@ namespace sat {
|
|||
m_restart_count = 0;
|
||||
m_restart_next = m_config.m_restart_base*2;
|
||||
|
||||
#if 0
|
||||
m_parsync_count = 0;
|
||||
m_parsync_next = m_config.m_parsync_base;
|
||||
#endif
|
||||
|
||||
m_min_sz = m_unsat.size();
|
||||
m_flips = 0;
|
||||
m_last_flips = 0;
|
||||
|
@ -244,9 +240,8 @@ namespace sat {
|
|||
m_use_list_index.push_back(m_flat_use_list.size());
|
||||
}
|
||||
|
||||
bool ddfw::flip(bool_var v) {
|
||||
void ddfw::flip(bool_var v) {
|
||||
++m_flips;
|
||||
bool new_unsat = false;
|
||||
literal lit = literal(v, !value(v));
|
||||
literal nlit = ~lit;
|
||||
SASSERT(is_true(lit));
|
||||
|
@ -262,7 +257,6 @@ namespace sat {
|
|||
verbose_stream() << "flipping unit clause " << ci << "\n";
|
||||
#endif
|
||||
m_unsat.insert_fresh(cls_idx);
|
||||
new_unsat = true;
|
||||
auto const& c = get_clause(cls_idx);
|
||||
for (literal l : c) {
|
||||
inc_reward(l, w);
|
||||
|
@ -304,7 +298,6 @@ namespace sat {
|
|||
}
|
||||
value(v) = !value(v);
|
||||
update_reward_avg(v);
|
||||
return new_unsat;
|
||||
}
|
||||
|
||||
bool ddfw::should_reinit_weights() {
|
||||
|
@ -404,38 +397,20 @@ namespace sat {
|
|||
for (unsigned i = 0; i < num_vars(); ++i)
|
||||
m_model[i] = to_lbool(value(i));
|
||||
save_priorities();
|
||||
if (m_plugin && m_unsat.empty())
|
||||
m_plugin->on_save_model();
|
||||
if (m_plugin)
|
||||
m_plugin->on_save_model();
|
||||
}
|
||||
|
||||
|
||||
void ddfw::save_best_values() {
|
||||
if (m_unsat.size() < m_min_sz || m_unsat.empty()) {
|
||||
if (m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11)
|
||||
save_model();
|
||||
if ((m_unsat.size() < m_min_sz || m_unsat.empty()) &&
|
||||
((m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11)))
|
||||
save_model();
|
||||
|
||||
if (m_unsat.size() < m_min_sz) {
|
||||
m_models.reset();
|
||||
m_min_sz = m_unsat.size();
|
||||
}
|
||||
|
||||
if (m_unsat.size() < m_min_sz)
|
||||
m_models.reset();
|
||||
|
||||
m_min_sz = m_unsat.size();
|
||||
|
||||
#if 0
|
||||
m_num_models.reserve(m_min_sz + 1);
|
||||
unsigned nm = m_num_models[m_min_sz]++;
|
||||
|
||||
|
||||
if (nm >= 10) {
|
||||
if (nm >= 200)
|
||||
m_num_models[m_min_sz] = 10, m_restart_next = m_flips;
|
||||
if (nm % 1 == 0) {
|
||||
for (unsigned v = 0; v < num_vars(); ++v)
|
||||
bias(v) += value(v) ? 1 : -1;
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
unsigned h = value_hash();
|
||||
unsigned occs = 0;
|
||||
bool contains = m_models.find(h, occs);
|
||||
|
@ -449,8 +424,7 @@ namespace sat {
|
|||
if (occs > 100) {
|
||||
m_restart_next = m_flips;
|
||||
m_models.erase(h);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
unsigned ddfw::value_hash() const {
|
||||
|
|
|
@ -69,13 +69,11 @@ namespace sat {
|
|||
|
||||
struct var_info {
|
||||
var_info() {}
|
||||
bool m_internal = false;
|
||||
bool m_value = false;
|
||||
double m_reward = 0;
|
||||
double m_last_reward = 0;
|
||||
unsigned m_make_count = 0;
|
||||
int m_bias = 0;
|
||||
bool m_external = false;
|
||||
ema m_reward_avg = 1e-5;
|
||||
};
|
||||
|
||||
|
@ -124,11 +122,6 @@ namespace sat {
|
|||
|
||||
inline double& reward(bool_var v) { return m_vars[v].m_reward; }
|
||||
|
||||
void set_external(bool_var v) { m_vars[v].m_external = true; }
|
||||
|
||||
inline bool is_external(bool_var v) const { return m_vars[v].m_external; }
|
||||
|
||||
inline int& bias(bool_var v) { return m_vars[v].m_bias; }
|
||||
|
||||
unsigned value_hash() const;
|
||||
|
||||
|
@ -162,13 +155,10 @@ namespace sat {
|
|||
void check_without_plugin();
|
||||
|
||||
// flip activity
|
||||
template<bool uses_plugin>
|
||||
bool do_flip();
|
||||
|
||||
template<bool uses_plugin>
|
||||
bool_var pick_var(double& reward);
|
||||
|
||||
template<bool uses_plugin>
|
||||
bool apply_flip(bool_var v, double reward);
|
||||
|
||||
|
||||
|
@ -253,18 +243,19 @@ namespace sat {
|
|||
|
||||
void remove_assumptions();
|
||||
|
||||
bool flip(bool_var v);
|
||||
void flip(bool_var v);
|
||||
|
||||
inline double get_reward(bool_var v) const { return m_vars[v].m_reward; }
|
||||
|
||||
double get_reward_avg(bool_var v) const { return m_vars[v].m_reward_avg; }
|
||||
|
||||
inline int& bias(bool_var v) { return m_vars[v].m_bias; }
|
||||
|
||||
void reserve_vars(unsigned n);
|
||||
|
||||
void add(unsigned sz, literal const* c);
|
||||
|
||||
sat::bool_var add_var(bool is_internal = true);
|
||||
|
||||
// is this a variable that was added during initialization?
|
||||
bool is_initial_var(sat::bool_var v) const {
|
||||
return m_vars.size() > v && !m_vars[v].m_internal;
|
||||
}
|
||||
sat::bool_var add_var();
|
||||
|
||||
void reinit();
|
||||
|
||||
|
|
|
@ -67,7 +67,7 @@ namespace sls {
|
|||
virtual vector<sat::clause_info> const& clauses() const = 0;
|
||||
virtual sat::clause_info const& get_clause(unsigned idx) const = 0;
|
||||
virtual ptr_iterator<unsigned> get_use_list(sat::literal lit) = 0;
|
||||
virtual bool flip(sat::bool_var v) = 0;
|
||||
virtual void flip(sat::bool_var v) = 0;
|
||||
virtual double reward(sat::bool_var v) = 0;
|
||||
virtual double get_weigth(unsigned clause_idx) = 0;
|
||||
virtual bool is_true(sat::literal lit) = 0;
|
||||
|
@ -173,7 +173,7 @@ namespace sls {
|
|||
sat::literal mk_literal(expr* e);
|
||||
void add_clause(expr* f);
|
||||
void add_clause(sat::literal_vector const& lits);
|
||||
bool flip(sat::bool_var v) { return s.flip(v); }
|
||||
void flip(sat::bool_var v) { s.flip(v); }
|
||||
double reward(sat::bool_var v) { return s.reward(v); }
|
||||
indexed_uint_set const& unsat() const { return s.unsat(); }
|
||||
unsigned rand() { return m_rand(); }
|
||||
|
|
|
@ -57,7 +57,7 @@ namespace sls {
|
|||
if (m_on_save_model)
|
||||
return;
|
||||
flet<bool> _on_save_model(m_on_save_model, true);
|
||||
TRACE("sls", display(tout));
|
||||
CTRACE("sls", unsat().empty(), display(tout));
|
||||
while (unsat().empty()) {
|
||||
m_context.check();
|
||||
if (!m_new_constraint)
|
||||
|
@ -87,7 +87,7 @@ namespace sls {
|
|||
vector<sat::clause_info> const& clauses() const override { return m_ddfw.clauses(); }
|
||||
sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw.get_clause_info(idx); }
|
||||
ptr_iterator<unsigned> get_use_list(sat::literal lit) override { return m_ddfw.use_list(lit); }
|
||||
bool flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; return m_ddfw.flip(v); }
|
||||
void flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; m_ddfw.flip(v); }
|
||||
double reward(sat::bool_var v) override { return m_ddfw.get_reward(v); }
|
||||
double get_weigth(unsigned clause_idx) override { return m_ddfw.get_clause_info(clause_idx).m_weight; }
|
||||
bool is_true(sat::literal lit) override { return m_ddfw.get_value(lit.var()) != lit.sign(); }
|
||||
|
|
|
@ -2932,6 +2932,7 @@ namespace sat {
|
|||
bool_var v = m_trail[i].var();
|
||||
m_best_phase[v] = m_phase[v];
|
||||
}
|
||||
set_has_new_best_phase(true);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -152,6 +152,7 @@ namespace sat {
|
|||
bool_vector m_phase;
|
||||
bool_vector m_best_phase;
|
||||
bool_vector m_prev_phase;
|
||||
bool m_new_best_phase = false;
|
||||
svector<char> m_assigned_since_gc;
|
||||
search_state m_search_state;
|
||||
unsigned m_search_unsat_conflicts;
|
||||
|
@ -380,6 +381,9 @@ namespace sat {
|
|||
bool was_eliminated(literal l) const { return was_eliminated(l.var()); }
|
||||
void set_phase(literal l) override { if (l.var() < num_vars()) m_best_phase[l.var()] = m_phase[l.var()] = !l.sign(); }
|
||||
bool get_phase(bool_var b) { return m_phase.get(b, false); }
|
||||
bool get_best_phase(bool_var b) { return m_best_phase.get(b, false); }
|
||||
void set_has_new_best_phase(bool b) { m_new_best_phase = b; }
|
||||
bool has_new_best_phase() const { return m_new_best_phase; }
|
||||
void move_to_front(bool_var b);
|
||||
unsigned scope_lvl() const { return m_scope_lvl; }
|
||||
unsigned search_lvl() const { return m_search_lvl; }
|
||||
|
|
|
@ -13,6 +13,7 @@ Author:
|
|||
|
||||
Nikolaj Bjorner (nbjorner) 2024-02-21
|
||||
|
||||
|
||||
--*/
|
||||
|
||||
#include "sat/smt/sls_solver.h"
|
||||
|
@ -36,7 +37,125 @@ namespace sls {
|
|||
finalize();
|
||||
}
|
||||
|
||||
void solver::finalize() {
|
||||
|
||||
class solver::smt_plugin : public sat::local_search_plugin, public sls::sat_solver_context {
|
||||
solver& s;
|
||||
sat::ddfw* m_ddfw;
|
||||
sls::context m_context;
|
||||
bool m_new_clause_added = false;
|
||||
unsigned m_num_shared_vars = 0;
|
||||
|
||||
// export from SAT to SLS:
|
||||
// - unit literals
|
||||
// - phase
|
||||
// - values
|
||||
bool export_to_sls() {
|
||||
bool updated = false;
|
||||
if (s.m_has_units) {
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
IF_VERBOSE(1, verbose_stream() << "SAT->SLS units " << s.m_units << "\n");
|
||||
for (auto lit : s.m_units)
|
||||
if (lit.var() < m_num_shared_vars)
|
||||
m_ddfw->add(1, &lit);
|
||||
s.m_has_units = false;
|
||||
s.m_units.reset();
|
||||
updated = true;
|
||||
}
|
||||
if (m_has_new_sat_phase) {
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
IF_VERBOSE(1, verbose_stream() << "SAT->SLS phase\n");
|
||||
for (unsigned i = 0; i < m_sat_phase.size(); ++i) {
|
||||
if (m_sat_phase[i] != is_true(sat::literal(i, false)))
|
||||
flip(i);
|
||||
m_ddfw->bias(i) = m_sat_phase[i] ? 1 : -1;
|
||||
}
|
||||
m_has_new_sat_phase = false;
|
||||
}
|
||||
return updated;
|
||||
}
|
||||
|
||||
// import from SLS:
|
||||
// - activity
|
||||
// - phase
|
||||
// - values
|
||||
void import_from_sls() {
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
for (unsigned v = 0; v < m_num_shared_vars; ++v) {
|
||||
m_rewards[v] = m_ddfw->get_reward_avg(v);
|
||||
m_sls_phase[v] = l_true == m_ddfw->get_model()[v];
|
||||
m_has_new_sls_phase = true;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) :
|
||||
s(s), m_ddfw(d), m_context(m, *this) {}
|
||||
|
||||
|
||||
svector<bool> m_sat_phase;
|
||||
std::atomic<bool> m_has_new_sat_phase = false;
|
||||
|
||||
std::atomic<bool> m_has_new_sls_phase = false;
|
||||
svector<bool> m_sls_phase;
|
||||
|
||||
svector<double> m_rewards;
|
||||
|
||||
void init_search() override {}
|
||||
|
||||
void finish_search() override {}
|
||||
|
||||
void on_rescale() override {}
|
||||
|
||||
void on_restart() override {
|
||||
if (export_to_sls())
|
||||
m_ddfw->reinit();
|
||||
}
|
||||
|
||||
void on_save_model() override {
|
||||
TRACE("sls", display(tout));
|
||||
while (unsat().empty()) {
|
||||
m_context.check();
|
||||
if (!m_new_clause_added)
|
||||
break;
|
||||
m_ddfw->reinit();
|
||||
m_new_clause_added = false;
|
||||
}
|
||||
import_from_sls();
|
||||
}
|
||||
|
||||
void on_model(model_ref& mdl) override {
|
||||
IF_VERBOSE(1, verbose_stream() << "on-model " << "\n");
|
||||
s.m_sls_model = mdl;
|
||||
}
|
||||
|
||||
void register_atom(sat::bool_var v, expr* e) {
|
||||
m_context.register_atom(v, e);
|
||||
}
|
||||
|
||||
std::ostream& display(std::ostream& out) {
|
||||
m_ddfw->display(out);
|
||||
m_context.display(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
vector<sat::clause_info> const& clauses() const override { return m_ddfw->clauses(); }
|
||||
sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); }
|
||||
ptr_iterator<unsigned> get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); }
|
||||
void flip(sat::bool_var v) override { m_ddfw->flip(v); }
|
||||
double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); }
|
||||
double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; }
|
||||
bool is_true(sat::literal lit) override { return m_ddfw->get_value(lit.var()) != lit.sign(); }
|
||||
unsigned num_vars() const override { return m_ddfw->num_vars(); }
|
||||
indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); }
|
||||
sat::bool_var add_var() override { return m_ddfw->add_var(); }
|
||||
void add_clause(unsigned n, sat::literal const* lits) override {
|
||||
m_ddfw->add(n, lits);
|
||||
m_new_clause_added = true;
|
||||
}
|
||||
void force_restart() override { m_ddfw->force_restart(); }
|
||||
};
|
||||
|
||||
void solver::finalize() {
|
||||
if (!m_completed && m_ddfw) {
|
||||
m_ddfw->rlimit().cancel();
|
||||
m_thread.join();
|
||||
|
@ -65,79 +184,25 @@ namespace sls {
|
|||
m_units.push_back(lit);
|
||||
m_has_units = true;
|
||||
}
|
||||
if (s().at_base_lvl()) {
|
||||
if (s().has_new_best_phase()) {
|
||||
IF_VERBOSE(1, verbose_stream() << "new SAT->SLS phase\n");
|
||||
m_smt_plugin->m_has_new_sat_phase = true;
|
||||
s().set_has_new_best_phase(false);
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
for (unsigned i = 0; i < m_smt_plugin->m_sat_phase.size(); ++i)
|
||||
m_smt_plugin->m_sat_phase[i] = s().get_best_phase(i);
|
||||
}
|
||||
}
|
||||
if (m_smt_plugin->m_has_new_sls_phase) {
|
||||
IF_VERBOSE(1, verbose_stream() << "new SLS->SAT phase\n");
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
for (unsigned i = 0; i < m_smt_plugin->m_sls_phase.size(); ++i)
|
||||
s().set_phase(sat::literal(i, !m_smt_plugin->m_sls_phase[i]));
|
||||
m_smt_plugin->m_has_new_sls_phase = false;
|
||||
}
|
||||
}
|
||||
|
||||
class solver::smt_plugin : public sat::local_search_plugin, public sls::sat_solver_context {
|
||||
solver& s;
|
||||
sat::ddfw* m_ddfw;
|
||||
sls::context m_context;
|
||||
bool m_new_clause_added = false;
|
||||
public:
|
||||
smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) :
|
||||
s(s), m_ddfw(d), m_context(m, *this) {}
|
||||
|
||||
void init_search() override {}
|
||||
|
||||
void finish_search() override {}
|
||||
|
||||
void on_rescale() override {}
|
||||
|
||||
void on_restart() override {
|
||||
if (!s.m_has_units)
|
||||
return;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
for (auto lit : s.m_units)
|
||||
if (m_ddfw->is_initial_var(lit.var()))
|
||||
m_ddfw->add(1, &lit);
|
||||
s.m_has_units = false;
|
||||
s.m_units.reset();
|
||||
}
|
||||
m_ddfw->reinit();
|
||||
}
|
||||
|
||||
void on_save_model() override {
|
||||
TRACE("sls", display(tout));
|
||||
while (unsat().empty()) {
|
||||
m_context.check();
|
||||
if (!m_new_clause_added)
|
||||
break;
|
||||
m_ddfw->reinit();
|
||||
m_new_clause_added = false;
|
||||
}
|
||||
}
|
||||
|
||||
void on_model(model_ref& mdl) override {
|
||||
IF_VERBOSE(1, verbose_stream() << "on-model " << "\n");
|
||||
s.m_sls_model = mdl;
|
||||
}
|
||||
|
||||
void register_atom(sat::bool_var v, expr* e) {
|
||||
m_context.register_atom(v, e);
|
||||
}
|
||||
|
||||
std::ostream& display(std::ostream& out) {
|
||||
m_ddfw->display(out);
|
||||
m_context.display(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
vector<sat::clause_info> const& clauses() const override { return m_ddfw->clauses(); }
|
||||
sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); }
|
||||
ptr_iterator<unsigned> get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); }
|
||||
bool flip(sat::bool_var v) override { return m_ddfw->flip(v); }
|
||||
double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); }
|
||||
double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; }
|
||||
bool is_true(sat::literal lit) override { return m_ddfw->get_value(lit.var()) != lit.sign(); }
|
||||
unsigned num_vars() const override { return m_ddfw->num_vars(); }
|
||||
indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); }
|
||||
sat::bool_var add_var() override { return m_ddfw->add_var(); }
|
||||
void add_clause(unsigned n, sat::literal const* lits) override {
|
||||
m_ddfw->add(n, lits);
|
||||
m_new_clause_added = true;
|
||||
}
|
||||
void force_restart() override { m_ddfw->force_restart(); }
|
||||
};
|
||||
|
||||
void solver::init_search() {
|
||||
if (m_ddfw) {
|
||||
|
@ -215,6 +280,5 @@ namespace sls {
|
|||
out << "sls-solver\n";
|
||||
return out;
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue