3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-18 14:49:01 +00:00

add value transfer option

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2024-10-20 16:38:00 -07:00
parent 68ee5108d8
commit cc430987b7
2 changed files with 97 additions and 33 deletions

View file

@ -81,11 +81,5 @@ namespace sat {
}
for (clause* c : s.m_clauses)
m_ddfw.add(c->size(), c->begin());
}
}
}

View file

@ -19,6 +19,7 @@ Author:
#include "sat/smt/sls_solver.h"
#include "sat/smt/euf_solver.h"
#include "ast/sls/sls_context.h"
#include "ast/for_each_expr.h"
namespace sls {
@ -39,12 +40,21 @@ namespace sls {
class solver::smt_plugin : public sat::local_search_plugin, public sls::sat_solver_context {
ast_manager& m;
solver& s;
sat::ddfw* m_ddfw;
sls::context m_context;
bool m_new_clause_added = false;
unsigned m_num_shared_vars = 0;
unsigned m_min_unsat_size = UINT_MAX;
ast_manager m_sync_manager;
obj_map<expr, expr*> m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp
obj_map<expr, expr*> m_smt2sync_uninterp; // hashtable from external uninterp to sync uninterp
expr_ref_vector m_sync_uninterp;
expr_ref_vector m_sync_values;
std::atomic<bool> m_has_new_sls_values = false;
// export from SAT to SLS:
// - unit literals
// - phase
@ -75,22 +85,39 @@ namespace sls {
}
// import from SLS:
// - activity
// - phase
// - values
// - activity
void import_from_sls() {
if (unsat().size() > m_min_unsat_size)
return;
m_min_unsat_size = unsat().size();
std::lock_guard<std::mutex> lock(s.m_mutex);
for (unsigned v = 0; v < m_num_shared_vars; ++v) {
m_rewards[v] = m_ddfw->get_reward_avg(v);
m_sls_phase[v] = l_true == m_ddfw->get_model()[v];
m_has_new_sls_phase = true;
}
// import_values_from_sls();
}
void import_values_from_sls() {
IF_VERBOSE(1, verbose_stream() << "import values from sls\n");
std::lock_guard<std::mutex> lock(s.m_mutex);
ast_translation tr(m, m_sync_manager);
for (auto const& [t, t_sync] : m_sls2sync_uninterp) {
expr_ref val_t = m_context.get_value(t);
m_sync_values.set(t_sync->get_id(), tr(val_t.get()));
}
m_has_new_sls_values = true;
}
public:
smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) :
s(s), m_ddfw(d), m_context(m, *this) {}
m(m), s(s), m_ddfw(d), m_context(m, *this),
m_sync_uninterp(m_sync_manager),
m_sync_values(m_sync_manager) {}
svector<bool> m_sat_phase;
std::atomic<bool> m_has_new_sat_phase = false;
@ -99,6 +126,17 @@ namespace sls {
svector<bool> m_sls_phase;
svector<double> m_rewards;
void add_uninterp(expr* smt_t) {
ast_translation tr1(s.ctx.get_manager(), m_sync_manager);
ast_translation tr2(s.ctx.get_manager(), m);
auto sync_t = tr1(smt_t);
auto sls_t = tr2(smt_t);
m_sync_uninterp.push_back(sync_t);
m_smt2sync_uninterp.insert(smt_t, sync_t);
m_sls2sync_uninterp.insert(sls_t, sync_t);
}
void init_search() override {}
@ -153,9 +191,48 @@ namespace sls {
m_new_clause_added = true;
}
void force_restart() override { m_ddfw->force_restart(); }
void export_values_to_smt() {
if (!m_has_new_sls_values)
return;
IF_VERBOSE(1, verbose_stream() << "export values to smt\n");
std::lock_guard<std::mutex> lock(s.m_mutex);
ast_translation tr(m_sync_manager, s.ctx.get_manager());
for (auto const& [t, t_sync] : m_smt2sync_uninterp) {
expr* sync_val = m_sync_values.get(t_sync->get_id(), nullptr);
if (sync_val) {
s.ctx.user_propagate_initialize_value(t, tr(sync_val));
}
}
m_has_new_sls_values = false;
}
void export_phase_to_smt() {
if (m_has_new_sls_phase) {
IF_VERBOSE(1, verbose_stream() << "new SLS->SAT phase\n");
std::lock_guard<std::mutex> lock(s.m_mutex);
for (unsigned i = 0; i < m_sls_phase.size(); ++i)
s.s().set_phase(sat::literal(i, !m_sls_phase[i]));
m_has_new_sls_phase = false;
}
}
void import_phase_from_smt() {
IF_VERBOSE(1, verbose_stream() << "new SAT->SLS phase\n");
m_has_new_sat_phase = true;
s.s().set_has_new_best_phase(false);
std::lock_guard<std::mutex> lock(s.m_mutex);
for (unsigned i = 0; i < m_sat_phase.size(); ++i)
m_sat_phase[i] = s.s().get_best_phase(i);
}
void export_activity_to_smt() {
// TODO
}
};
void solver::finalize() {
void solver::finalize() {
if (!m_completed && m_ddfw) {
m_ddfw->rlimit().cancel();
m_thread.join();
@ -184,23 +261,12 @@ namespace sls {
m_units.push_back(lit);
m_has_units = true;
}
if (s().at_base_lvl()) {
if (s().has_new_best_phase()) {
IF_VERBOSE(1, verbose_stream() << "new SAT->SLS phase\n");
m_smt_plugin->m_has_new_sat_phase = true;
s().set_has_new_best_phase(false);
std::lock_guard<std::mutex> lock(m_mutex);
for (unsigned i = 0; i < m_smt_plugin->m_sat_phase.size(); ++i)
m_smt_plugin->m_sat_phase[i] = s().get_best_phase(i);
}
}
if (m_smt_plugin->m_has_new_sls_phase) {
IF_VERBOSE(1, verbose_stream() << "new SLS->SAT phase\n");
std::lock_guard<std::mutex> lock(m_mutex);
for (unsigned i = 0; i < m_smt_plugin->m_sls_phase.size(); ++i)
s().set_phase(sat::literal(i, !m_smt_plugin->m_sls_phase[i]));
m_smt_plugin->m_has_new_sls_phase = false;
}
if (s().at_base_lvl() && s().has_new_best_phase())
m_smt_plugin->import_phase_from_smt();
m_smt_plugin->export_phase_to_smt();
m_smt_plugin->export_activity_to_smt();
m_smt_plugin->export_values_to_smt();
}
@ -229,12 +295,16 @@ namespace sls {
m_ddfw->add(clause.size(), clause.data());
for (sat::bool_var v = 0; v < s().num_vars(); ++v) {
expr* e = ctx.bool_var2expr(v);
if (e)
m_smt_plugin->register_atom(v, tr(e));
if (!e)
continue;
m_smt_plugin->register_atom(v, tr(e));
for (auto t : subterms::all(expr_ref(e, m))) {
if (is_uninterp(t))
m_smt_plugin->add_uninterp(t);
}
}
run_local_search_sync();
// m_thread = std::thread([this]() { run_local_search_async(); });
m_thread = std::thread([this]() { run_local_search_async(); });
}
void solver::sample_local_search() {