diff --git a/src/sat/sat_ddfw_wrapper.cpp b/src/sat/sat_ddfw_wrapper.cpp index b453c3718..2fba213de 100644 --- a/src/sat/sat_ddfw_wrapper.cpp +++ b/src/sat/sat_ddfw_wrapper.cpp @@ -81,11 +81,5 @@ namespace sat { } for (clause* c : s.m_clauses) m_ddfw.add(c->size(), c->begin()); - - } - - - - + } } - diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index bf40278b7..cd1107a87 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -19,6 +19,7 @@ Author: #include "sat/smt/sls_solver.h" #include "sat/smt/euf_solver.h" #include "ast/sls/sls_context.h" +#include "ast/for_each_expr.h" namespace sls { @@ -39,12 +40,21 @@ namespace sls { class solver::smt_plugin : public sat::local_search_plugin, public sls::sat_solver_context { + ast_manager& m; solver& s; sat::ddfw* m_ddfw; sls::context m_context; bool m_new_clause_added = false; unsigned m_num_shared_vars = 0; - + unsigned m_min_unsat_size = UINT_MAX; + ast_manager m_sync_manager; + obj_map m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp + obj_map m_smt2sync_uninterp; // hashtable from external uninterp to sync uninterp + expr_ref_vector m_sync_uninterp; + expr_ref_vector m_sync_values; + std::atomic m_has_new_sls_values = false; + + // export from SAT to SLS: // - unit literals // - phase @@ -75,22 +85,39 @@ namespace sls { } // import from SLS: - // - activity // - phase // - values + // - activity void import_from_sls() { + if (unsat().size() > m_min_unsat_size) + return; + m_min_unsat_size = unsat().size(); std::lock_guard lock(s.m_mutex); for (unsigned v = 0; v < m_num_shared_vars; ++v) { m_rewards[v] = m_ddfw->get_reward_avg(v); m_sls_phase[v] = l_true == m_ddfw->get_model()[v]; m_has_new_sls_phase = true; } + // import_values_from_sls(); } + void import_values_from_sls() { + IF_VERBOSE(1, verbose_stream() << "import values from sls\n"); + std::lock_guard lock(s.m_mutex); + ast_translation tr(m, m_sync_manager); + for (auto const& [t, t_sync] : m_sls2sync_uninterp) { + expr_ref val_t = m_context.get_value(t); + m_sync_values.set(t_sync->get_id(), tr(val_t.get())); + } + m_has_new_sls_values = true; + } + + public: smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) : - s(s), m_ddfw(d), m_context(m, *this) {} - + m(m), s(s), m_ddfw(d), m_context(m, *this), + m_sync_uninterp(m_sync_manager), + m_sync_values(m_sync_manager) {} svector m_sat_phase; std::atomic m_has_new_sat_phase = false; @@ -99,6 +126,17 @@ namespace sls { svector m_sls_phase; svector m_rewards; + + void add_uninterp(expr* smt_t) { + ast_translation tr1(s.ctx.get_manager(), m_sync_manager); + ast_translation tr2(s.ctx.get_manager(), m); + auto sync_t = tr1(smt_t); + auto sls_t = tr2(smt_t); + m_sync_uninterp.push_back(sync_t); + m_smt2sync_uninterp.insert(smt_t, sync_t); + m_sls2sync_uninterp.insert(sls_t, sync_t); + } + void init_search() override {} @@ -153,9 +191,48 @@ namespace sls { m_new_clause_added = true; } void force_restart() override { m_ddfw->force_restart(); } + + void export_values_to_smt() { + if (!m_has_new_sls_values) + return; + IF_VERBOSE(1, verbose_stream() << "export values to smt\n"); + std::lock_guard lock(s.m_mutex); + ast_translation tr(m_sync_manager, s.ctx.get_manager()); + for (auto const& [t, t_sync] : m_smt2sync_uninterp) { + expr* sync_val = m_sync_values.get(t_sync->get_id(), nullptr); + if (sync_val) { + s.ctx.user_propagate_initialize_value(t, tr(sync_val)); + } + } + m_has_new_sls_values = false; + } + + void export_phase_to_smt() { + if (m_has_new_sls_phase) { + IF_VERBOSE(1, verbose_stream() << "new SLS->SAT phase\n"); + std::lock_guard lock(s.m_mutex); + for (unsigned i = 0; i < m_sls_phase.size(); ++i) + s.s().set_phase(sat::literal(i, !m_sls_phase[i])); + m_has_new_sls_phase = false; + } + } + + void import_phase_from_smt() { + IF_VERBOSE(1, verbose_stream() << "new SAT->SLS phase\n"); + m_has_new_sat_phase = true; + s.s().set_has_new_best_phase(false); + std::lock_guard lock(s.m_mutex); + for (unsigned i = 0; i < m_sat_phase.size(); ++i) + m_sat_phase[i] = s.s().get_best_phase(i); + } + + void export_activity_to_smt() { + // TODO + } + }; - void solver::finalize() { + void solver::finalize() { if (!m_completed && m_ddfw) { m_ddfw->rlimit().cancel(); m_thread.join(); @@ -184,23 +261,12 @@ namespace sls { m_units.push_back(lit); m_has_units = true; } - if (s().at_base_lvl()) { - if (s().has_new_best_phase()) { - IF_VERBOSE(1, verbose_stream() << "new SAT->SLS phase\n"); - m_smt_plugin->m_has_new_sat_phase = true; - s().set_has_new_best_phase(false); - std::lock_guard lock(m_mutex); - for (unsigned i = 0; i < m_smt_plugin->m_sat_phase.size(); ++i) - m_smt_plugin->m_sat_phase[i] = s().get_best_phase(i); - } - } - if (m_smt_plugin->m_has_new_sls_phase) { - IF_VERBOSE(1, verbose_stream() << "new SLS->SAT phase\n"); - std::lock_guard lock(m_mutex); - for (unsigned i = 0; i < m_smt_plugin->m_sls_phase.size(); ++i) - s().set_phase(sat::literal(i, !m_smt_plugin->m_sls_phase[i])); - m_smt_plugin->m_has_new_sls_phase = false; - } + if (s().at_base_lvl() && s().has_new_best_phase()) + m_smt_plugin->import_phase_from_smt(); + + m_smt_plugin->export_phase_to_smt(); + m_smt_plugin->export_activity_to_smt(); + m_smt_plugin->export_values_to_smt(); } @@ -229,12 +295,16 @@ namespace sls { m_ddfw->add(clause.size(), clause.data()); for (sat::bool_var v = 0; v < s().num_vars(); ++v) { expr* e = ctx.bool_var2expr(v); - if (e) - m_smt_plugin->register_atom(v, tr(e)); + if (!e) + continue; + m_smt_plugin->register_atom(v, tr(e)); + for (auto t : subterms::all(expr_ref(e, m))) { + if (is_uninterp(t)) + m_smt_plugin->add_uninterp(t); + } } - run_local_search_sync(); - // m_thread = std::thread([this]() { run_local_search_async(); }); + m_thread = std::thread([this]() { run_local_search_async(); }); } void solver::sample_local_search() {