diff --git a/src/sat/sat_ddfw.cpp b/src/sat/sat_ddfw.cpp index 747ea4940..478e793e3 100644 --- a/src/sat/sat_ddfw.cpp +++ b/src/sat/sat_ddfw.cpp @@ -9,7 +9,7 @@ DDFW Local search module for clauses - Author: + Author: Nikolaj Bjorner, Marijn Heule 2019-4-23 @@ -33,27 +33,45 @@ namespace sat { ddfw::~ddfw() { - for (auto& ci : m_clauses) { - m_alloc.del_clause(ci.m_clause); - } + for (auto& ci : m_clauses) + m_alloc.del_clause(ci.m_clause); } - lbool ddfw::check(unsigned sz, literal const* assumptions, parallel* p) { init(sz, assumptions); flet _p(m_par, p); - while (m_limit.inc() && m_min_sz > 0) { - if (should_reinit_weights()) do_reinit_weights(); - else if (do_flip()) ; - else if (should_restart()) do_restart(); - else if (should_parallel_sync()) do_parallel_sync(); - else shift_weights(); - } + if (m_plugin) + check_with_plugin(); + else + check_without_plugin(); remove_assumptions(); log(); return m_min_sz == 0 ? l_true : l_undef; } + void ddfw::check_without_plugin() { + while (m_limit.inc() && m_min_sz > 0) { + if (should_reinit_weights()) do_reinit_weights(); + else if (do_flip()); + else if (should_restart()) do_restart(); + else if (should_parallel_sync()) do_parallel_sync(); + else shift_weights(); + } + } + + void ddfw::check_with_plugin() { + m_plugin->init_search(); + while (m_limit.inc() && m_min_sz > 0) { + if (should_reinit_weights()) do_reinit_weights(); + else if (do_flip()); + else if (do_literal_flip()); + else if (should_restart()) do_restart(), m_plugin->on_restart(); + else if (should_parallel_sync()) do_parallel_sync(); + else shift_weights(), m_plugin->on_rescale(); + } + m_plugin->finish_search(); + } + void ddfw::log() { double sec = m_stopwatch.get_current_seconds(); double kflips_per_sec = (m_flips - m_last_flips) / (1000.0 * sec); @@ -77,55 +95,72 @@ namespace sat { m_last_flips = m_flips; } + template bool ddfw::do_flip() { - bool_var v = pick_var(); + bool_var v = pick_var(); + return apply_flip(v); + } + + template + bool ddfw::apply_flip(bool_var v) { + if (v == null_bool_var) + return false; if (reward(v) > 0 || (reward(v) == 0 && m_rand(100) <= m_config.m_use_reward_zero_pct)) { - flip(v); - if (m_unsat.size() <= m_min_sz) save_best_values(); + if (uses_plugin) + m_plugin->flip(v); + else + flip(v); + if (m_unsat.size() <= m_min_sz) + save_best_values(); return true; } return false; } + template bool_var ddfw::pick_var() { double sum_pos = 0; unsigned n = 1; + double r; bool_var v0 = null_bool_var; for (bool_var v : m_unsat_vars) { - double r = reward(v); - if (r > 0.0) { - sum_pos += score(r); - } - else if (r == 0.0 && sum_pos == 0 && (m_rand() % (n++)) == 0) { - v0 = v; - } + r = uses_plugin ? plugin_reward(v) : reward(v); + if (r > 0.0) + sum_pos += score(r); + else if (r == 0.0 && sum_pos == 0 && (m_rand() % (n++)) == 0) + v0 = v; } if (sum_pos > 0) { double lim_pos = ((double) m_rand() / (1.0 + m_rand.max_value())) * sum_pos; for (bool_var v : m_unsat_vars) { - double r = reward(v); + r = uses_plugin ? plugin_reward(v) : reward(v); if (r > 0) { lim_pos -= score(r); - if (lim_pos <= 0) { - return v; - } + if (lim_pos <= 0) + return v; } } } - if (v0 != null_bool_var) { + if (v0 != null_bool_var) return v0; - } + if (m_unsat_vars.empty()) + return 0; return m_unsat_vars.elem_at(m_rand(m_unsat_vars.size())); } - /** - * TBD: map reward value to a score, possibly through an exponential function, such as - * exp(-tau/r), where tau > 0 - */ - double ddfw::mk_score(double r) { - return r; + template + bool ddfw::do_literal_flip() { + return apply_flip(pick_literal_var()); } + /* + * Pick a random false literal from a satisfied clause such that + * the literal has zero break count and positive reward. + */ + template + bool_var ddfw::pick_literal_var() { + return null_bool_var; + } void ddfw::add(unsigned n, literal const* c) { clause* cls = m_alloc.mk_clause(n, c, false); @@ -409,6 +444,8 @@ namespace sat { for (unsigned i = 0; i < num_vars(); ++i) m_model[i] = to_lbool(value(i)); save_priorities(); + if (m_plugin) + m_plugin->on_save_model(); } diff --git a/src/sat/sat_ddfw.h b/src/sat/sat_ddfw.h index ee1c73b6e..5d56c3adc 100644 --- a/src/sat/sat_ddfw.h +++ b/src/sat/sat_ddfw.h @@ -31,6 +31,17 @@ namespace sat { class solver; class parallel; + class local_search_plugin { + public: + virtual void init_search() = 0; + virtual void finish_search() = 0; + virtual void flip(bool_var v) = 0; + virtual double reward(bool_var v) = 0; + virtual void on_rescale() = 0; + virtual void on_save_model() = 0; + virtual void on_restart() = 0; + }; + class ddfw : public i_local_search { public: struct clause_info { @@ -83,6 +94,7 @@ namespace sat { double m_reward = 0; unsigned m_make_count = 0; int m_bias = 0; + bool m_external = false; ema m_reward_avg = 1e-5; }; @@ -113,14 +125,15 @@ namespace sat { stopwatch m_stopwatch; parallel* m_par; - - + scoped_ptr< local_search_plugin> m_plugin; void flatten_use_list(); - double mk_score(double r); - - inline double score(double r) { return r; } // TBD: { for (unsigned sz = m_scores.size(); sz <= r; ++sz) m_scores.push_back(mk_score(sz)); return m_scores[r]; } + /** + * TBD: map reward value to a score, possibly through an exponential function, such as + * exp(-tau/r), where tau > 0 + */ + inline double score(double r) { return r; } inline unsigned num_vars() const { return m_vars.size(); } @@ -134,6 +147,12 @@ namespace sat { inline double reward(bool_var v) const { return m_vars[v].m_reward; } + inline double plugin_reward(bool_var v) const { return m_plugin->reward(v); } + + void set_external(bool_var v) { m_vars[v].m_external = true; } + + inline bool is_external(bool_var v) const { return m_vars[v].m_external; } + inline int& bias(bool_var v) { return m_vars[v].m_bias; } unsigned value_hash() const; @@ -164,9 +183,25 @@ namespace sat { inline void dec_reward(literal lit, double w) { reward(lit.var()) -= w; } + void check_with_plugin(); + void check_without_plugin(); + // flip activity + template bool do_flip(); - bool_var pick_var(); + + template + bool_var pick_var(); + + template + bool apply_flip(bool_var v); + + template + bool do_literal_flip(); + + template + bool_var pick_literal_var(); + void save_best_values(); void save_model(); void save_priorities(); @@ -215,6 +250,8 @@ namespace sat { ~ddfw() override; + void set(local_search_plugin* p) { m_plugin = p; } + lbool check(unsigned sz, literal const* assumptions, parallel* p) override; void updt_params(params_ref const& p) override;