mirror of
https://github.com/Z3Prover/z3
synced 2025-04-18 14:49:01 +00:00
add value transfer option
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
parent
68ee5108d8
commit
cc430987b7
|
@ -81,11 +81,5 @@ namespace sat {
|
|||
}
|
||||
for (clause* c : s.m_clauses)
|
||||
m_ddfw.add(c->size(), c->begin());
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ Author:
|
|||
#include "sat/smt/sls_solver.h"
|
||||
#include "sat/smt/euf_solver.h"
|
||||
#include "ast/sls/sls_context.h"
|
||||
#include "ast/for_each_expr.h"
|
||||
|
||||
namespace sls {
|
||||
|
||||
|
@ -39,12 +40,21 @@ namespace sls {
|
|||
|
||||
|
||||
class solver::smt_plugin : public sat::local_search_plugin, public sls::sat_solver_context {
|
||||
ast_manager& m;
|
||||
solver& s;
|
||||
sat::ddfw* m_ddfw;
|
||||
sls::context m_context;
|
||||
bool m_new_clause_added = false;
|
||||
unsigned m_num_shared_vars = 0;
|
||||
|
||||
unsigned m_min_unsat_size = UINT_MAX;
|
||||
ast_manager m_sync_manager;
|
||||
obj_map<expr, expr*> m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp
|
||||
obj_map<expr, expr*> m_smt2sync_uninterp; // hashtable from external uninterp to sync uninterp
|
||||
expr_ref_vector m_sync_uninterp;
|
||||
expr_ref_vector m_sync_values;
|
||||
std::atomic<bool> m_has_new_sls_values = false;
|
||||
|
||||
|
||||
// export from SAT to SLS:
|
||||
// - unit literals
|
||||
// - phase
|
||||
|
@ -75,22 +85,39 @@ namespace sls {
|
|||
}
|
||||
|
||||
// import from SLS:
|
||||
// - activity
|
||||
// - phase
|
||||
// - values
|
||||
// - activity
|
||||
void import_from_sls() {
|
||||
if (unsat().size() > m_min_unsat_size)
|
||||
return;
|
||||
m_min_unsat_size = unsat().size();
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
for (unsigned v = 0; v < m_num_shared_vars; ++v) {
|
||||
m_rewards[v] = m_ddfw->get_reward_avg(v);
|
||||
m_sls_phase[v] = l_true == m_ddfw->get_model()[v];
|
||||
m_has_new_sls_phase = true;
|
||||
}
|
||||
// import_values_from_sls();
|
||||
}
|
||||
|
||||
void import_values_from_sls() {
|
||||
IF_VERBOSE(1, verbose_stream() << "import values from sls\n");
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
ast_translation tr(m, m_sync_manager);
|
||||
for (auto const& [t, t_sync] : m_sls2sync_uninterp) {
|
||||
expr_ref val_t = m_context.get_value(t);
|
||||
m_sync_values.set(t_sync->get_id(), tr(val_t.get()));
|
||||
}
|
||||
m_has_new_sls_values = true;
|
||||
}
|
||||
|
||||
|
||||
public:
|
||||
smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) :
|
||||
s(s), m_ddfw(d), m_context(m, *this) {}
|
||||
|
||||
m(m), s(s), m_ddfw(d), m_context(m, *this),
|
||||
m_sync_uninterp(m_sync_manager),
|
||||
m_sync_values(m_sync_manager) {}
|
||||
|
||||
svector<bool> m_sat_phase;
|
||||
std::atomic<bool> m_has_new_sat_phase = false;
|
||||
|
@ -99,6 +126,17 @@ namespace sls {
|
|||
svector<bool> m_sls_phase;
|
||||
|
||||
svector<double> m_rewards;
|
||||
|
||||
void add_uninterp(expr* smt_t) {
|
||||
ast_translation tr1(s.ctx.get_manager(), m_sync_manager);
|
||||
ast_translation tr2(s.ctx.get_manager(), m);
|
||||
auto sync_t = tr1(smt_t);
|
||||
auto sls_t = tr2(smt_t);
|
||||
m_sync_uninterp.push_back(sync_t);
|
||||
m_smt2sync_uninterp.insert(smt_t, sync_t);
|
||||
m_sls2sync_uninterp.insert(sls_t, sync_t);
|
||||
}
|
||||
|
||||
|
||||
void init_search() override {}
|
||||
|
||||
|
@ -153,9 +191,48 @@ namespace sls {
|
|||
m_new_clause_added = true;
|
||||
}
|
||||
void force_restart() override { m_ddfw->force_restart(); }
|
||||
|
||||
void export_values_to_smt() {
|
||||
if (!m_has_new_sls_values)
|
||||
return;
|
||||
IF_VERBOSE(1, verbose_stream() << "export values to smt\n");
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
ast_translation tr(m_sync_manager, s.ctx.get_manager());
|
||||
for (auto const& [t, t_sync] : m_smt2sync_uninterp) {
|
||||
expr* sync_val = m_sync_values.get(t_sync->get_id(), nullptr);
|
||||
if (sync_val) {
|
||||
s.ctx.user_propagate_initialize_value(t, tr(sync_val));
|
||||
}
|
||||
}
|
||||
m_has_new_sls_values = false;
|
||||
}
|
||||
|
||||
void export_phase_to_smt() {
|
||||
if (m_has_new_sls_phase) {
|
||||
IF_VERBOSE(1, verbose_stream() << "new SLS->SAT phase\n");
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
for (unsigned i = 0; i < m_sls_phase.size(); ++i)
|
||||
s.s().set_phase(sat::literal(i, !m_sls_phase[i]));
|
||||
m_has_new_sls_phase = false;
|
||||
}
|
||||
}
|
||||
|
||||
void import_phase_from_smt() {
|
||||
IF_VERBOSE(1, verbose_stream() << "new SAT->SLS phase\n");
|
||||
m_has_new_sat_phase = true;
|
||||
s.s().set_has_new_best_phase(false);
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
for (unsigned i = 0; i < m_sat_phase.size(); ++i)
|
||||
m_sat_phase[i] = s.s().get_best_phase(i);
|
||||
}
|
||||
|
||||
void export_activity_to_smt() {
|
||||
// TODO
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
void solver::finalize() {
|
||||
void solver::finalize() {
|
||||
if (!m_completed && m_ddfw) {
|
||||
m_ddfw->rlimit().cancel();
|
||||
m_thread.join();
|
||||
|
@ -184,23 +261,12 @@ namespace sls {
|
|||
m_units.push_back(lit);
|
||||
m_has_units = true;
|
||||
}
|
||||
if (s().at_base_lvl()) {
|
||||
if (s().has_new_best_phase()) {
|
||||
IF_VERBOSE(1, verbose_stream() << "new SAT->SLS phase\n");
|
||||
m_smt_plugin->m_has_new_sat_phase = true;
|
||||
s().set_has_new_best_phase(false);
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
for (unsigned i = 0; i < m_smt_plugin->m_sat_phase.size(); ++i)
|
||||
m_smt_plugin->m_sat_phase[i] = s().get_best_phase(i);
|
||||
}
|
||||
}
|
||||
if (m_smt_plugin->m_has_new_sls_phase) {
|
||||
IF_VERBOSE(1, verbose_stream() << "new SLS->SAT phase\n");
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
for (unsigned i = 0; i < m_smt_plugin->m_sls_phase.size(); ++i)
|
||||
s().set_phase(sat::literal(i, !m_smt_plugin->m_sls_phase[i]));
|
||||
m_smt_plugin->m_has_new_sls_phase = false;
|
||||
}
|
||||
if (s().at_base_lvl() && s().has_new_best_phase())
|
||||
m_smt_plugin->import_phase_from_smt();
|
||||
|
||||
m_smt_plugin->export_phase_to_smt();
|
||||
m_smt_plugin->export_activity_to_smt();
|
||||
m_smt_plugin->export_values_to_smt();
|
||||
}
|
||||
|
||||
|
||||
|
@ -229,12 +295,16 @@ namespace sls {
|
|||
m_ddfw->add(clause.size(), clause.data());
|
||||
for (sat::bool_var v = 0; v < s().num_vars(); ++v) {
|
||||
expr* e = ctx.bool_var2expr(v);
|
||||
if (e)
|
||||
m_smt_plugin->register_atom(v, tr(e));
|
||||
if (!e)
|
||||
continue;
|
||||
m_smt_plugin->register_atom(v, tr(e));
|
||||
for (auto t : subterms::all(expr_ref(e, m))) {
|
||||
if (is_uninterp(t))
|
||||
m_smt_plugin->add_uninterp(t);
|
||||
}
|
||||
}
|
||||
|
||||
run_local_search_sync();
|
||||
// m_thread = std::thread([this]() { run_local_search_async(); });
|
||||
m_thread = std::thread([this]() { run_local_search_async(); });
|
||||
}
|
||||
|
||||
void solver::sample_local_search() {
|
||||
|
|
Loading…
Reference in a new issue