diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 2aa932f8a..626348ee8 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -488,7 +488,7 @@ namespace sls { for (sat::literal lit : m_unit_literals) m_unit_indices.insert(lit.index()); - IF_VERBOSE(0, verbose_stream() << "UNITS " << m_unit_literals << "\n"); + IF_VERBOSE(3, verbose_stream() << "UNITS " << m_unit_literals << "\n"); for (unsigned i = 0; i < m_atoms.size(); ++i) if (m_atoms.get(i)) register_terms(m_atoms.get(i)); diff --git a/src/ast/sls/sls_smt_plugin.cpp b/src/ast/sls/sls_smt_plugin.cpp index 58a4df173..e4ecb9cd1 100644 --- a/src/ast/sls/sls_smt_plugin.cpp +++ b/src/ast/sls/sls_smt_plugin.cpp @@ -48,7 +48,6 @@ namespace sls { m_completed = false; m_units.reset(); m_has_units = false; - m_model = nullptr; m_sls_model = nullptr; m_ddfw = alloc(sat::ddfw); m_ddfw->set_plugin(this); @@ -70,7 +69,6 @@ namespace sls { sls_e = m_smt2sls_tr(e); auto w = m_context.atom2bool_var(sls_e); if (w != sat::null_bool_var) { - IF_VERBOSE(0, verbose_stream() << mk_bounded_pp(e, m) << ": " << v << " -> " << w << "\n"); m_smt_bool_var2sls_bool_var.setx(v, w, sat::null_bool_var); m_sls_bool_var2smt_bool_var.setx(w, v, sat::null_bool_var); m_sls_phase.reserve(v + 1); @@ -87,6 +85,7 @@ namespace sls { if (!m_ddfw) return; m_result = m_ddfw->check(0, nullptr); + m_ddfw->collect_statistics(m_st); IF_VERBOSE(1, verbose_stream() << "sls-result " << m_result << "\n"); m_completed = true; } @@ -96,16 +95,17 @@ namespace sls { if (!d) return; bool canceled = !m_completed; - IF_VERBOSE(0, verbose_stream() << "finalize\n"); - mdl = m_model; - if (!m_completed) { - d->rlimit().cancel(); - if (m_thread.joinable()) - m_thread.join(); - } + IF_VERBOSE(3, verbose_stream() << "finalize\n"); + if (!m_completed) + d->rlimit().cancel(); + if (m_thread.joinable()) + m_thread.join(); + SASSERT(m_completed); + st.copy(m_st); + mdl = nullptr; if (m_result == l_true && m_sls_model) { ast_translation tr(m_sls, m); - m_model = m_sls_model->translate(tr); + mdl = m_sls_model->translate(tr); TRACE("sls", tout << "model: " << *m_sls_model << "\n";); if (!canceled) ctx.set_finished(); @@ -115,9 +115,6 @@ namespace sls { dealloc(d); } - void smt_plugin::collect_statistics(statistics& st) { - - } std::ostream& smt_plugin::display(std::ostream& out) { m_ddfw->display(out); m_context.display(out); @@ -209,7 +206,7 @@ namespace sls { return true; } - void smt_plugin::import_from_sls() { + void smt_plugin::export_from_sls() { if (unsat().size() > m_min_unsat_size) return; m_min_unsat_size = unsat().size(); @@ -217,13 +214,16 @@ namespace sls { for (auto v : m_shared_bool_vars) { auto w = m_smt_bool_var2sls_bool_var[v]; m_rewards[v] = m_ddfw->get_reward_avg(w); - m_sls_phase[v] = l_true == m_ddfw->get_model()[w]; + //verbose_stream() << v << " " << w << "\n"; + VERIFY(m_ddfw->get_model().size() > w); + VERIFY(m_sls_phase.size() > v); + m_sls_phase[v] = l_true == m_ddfw->get_model()[w]; m_has_new_sls_phase = true; } - // import_values_from_sls(); + // export_values_from_sls(); } - void smt_plugin::import_values_from_sls() { + void smt_plugin::export_values_from_sls() { IF_VERBOSE(3, verbose_stream() << "import values from sls\n"); std::lock_guard lock(m_mutex); for (auto const& [t, t_sync] : m_sls2sync_uninterp) { @@ -232,13 +232,43 @@ namespace sls { } m_has_new_sls_values = true; } + + void smt_plugin::import_from_sls() { + export_activity_to_smt(); + export_values_to_smt(); + export_phase_to_smt(); + } void smt_plugin::export_activity_to_smt() { } void smt_plugin::export_values_to_smt() { + if (!m_has_new_sls_values) + return; + IF_VERBOSE(3, verbose_stream() << "SLS -> SMT values\n"); + std::lock_guard lock(m_mutex); + ast_translation tr(m_sync, m); + for (auto const& [t, t_sync] : m_smt2sync_uninterp) { + expr* sync_val = m_sync_values.get(t_sync->get_id(), nullptr); + if (!sync_val) + continue; + expr_ref val(tr(sync_val), m); + ctx.initialize_value(t, val); + } + m_has_new_sls_values = false; + } + void smt_plugin::export_phase_to_smt() { + if (!m_has_new_sls_phase) + return; + std::lock_guard lock(m_mutex); + IF_VERBOSE(3, verbose_stream() << "SLS -> SMT phase\n"); + for (auto v : m_shared_bool_vars) { + auto w = m_smt_bool_var2sls_bool_var[v]; + ctx.force_phase(sat::literal(w, m_sls_phase[v])); + } + m_has_new_sls_phase = false; } void smt_plugin::add_shared_term(expr* t) { @@ -256,5 +286,15 @@ namespace sls { m_sls2sync_uninterp.insert(sls_t, sync_t); } - + void smt_plugin::on_save_model() { + TRACE("sls", display(tout)); + while (unsat().empty()) { + m_context.check(); + if (!m_new_clause_added) + break; + m_ddfw->reinit(); + m_new_clause_added = false; + } + // export_from_sls(); + } } diff --git a/src/ast/sls/sls_smt_plugin.h b/src/ast/sls/sls_smt_plugin.h index ee236d6b3..fb50df6bc 100644 --- a/src/ast/sls/sls_smt_plugin.h +++ b/src/ast/sls/sls_smt_plugin.h @@ -38,6 +38,12 @@ namespace sls { virtual unsigned get_num_bool_vars() const = 0; }; + + // + // m is accessed by the main thread + // m_sls is accessed by the sls thread + // m_sync is accessed by both + // class smt_plugin : public sat::local_search_plugin, public sat_solver_context { smt_context& ctx; ast_manager& m; @@ -52,25 +58,20 @@ namespace sls { std::atomic m_completed, m_has_units; std::thread m_thread; std::mutex m_mutex; - // m is accessed by the main thread - // m_slsm is accessed by the sls thread + sat::literal_vector m_units; - model_ref m_model, m_sls_model; - unsigned m_trail_lim = 0; + model_ref m_sls_model; ::statistics m_st; bool m_new_clause_added = false; unsigned m_min_unsat_size = UINT_MAX; obj_map m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp obj_map m_smt2sync_uninterp; // hashtable from external uninterp to sync uninterp std::atomic m_has_new_sls_values = false; - uint_set m_shared_bool_vars, m_shared_terms; svector m_sat_phase; std::atomic m_has_new_sat_phase = false; - std::atomic m_has_new_sls_phase = false; svector m_sls_phase; - svector m_rewards; svector m_smt_bool_var2sls_bool_var, m_sls_bool_var2smt_bool_var; @@ -84,12 +85,17 @@ namespace sls { void import_phase_from_smt(); void import_values_from_sls(); + void export_values_from_sls(); + void import_activity_from_sls(); bool export_phase_to_sls(); bool export_units_to_sls(); - void export_activity_to_smt(); void export_values_to_smt(); + void export_activity_to_smt(); + void export_phase_to_smt(); + void export_from_sls(); + friend class sat::ddfw; ~smt_plugin(); @@ -100,31 +106,20 @@ namespace sls { void check(expr_ref_vector const& fmls); void finalize(model_ref& md, ::statistics& st); void updt_params(params_ref& p) {} - void collect_statistics(statistics& st); std::ostream& display(std::ostream& out) override; - void import_from_sls(); + bool export_to_sls(); + void import_from_sls(); bool completed() { return m_completed; } void add_unit(sat::literal lit); - // local_search_plugin: void on_restart() override { if (export_to_sls()) m_ddfw->reinit(); } - void on_save_model() override { - TRACE("sls", display(tout)); - while (unsat().empty()) { - m_context.check(); - if (!m_new_clause_added) - break; - m_ddfw->reinit(); - m_new_clause_added = false; - } - //import_from_sls(); - } + void on_save_model() override; void on_model(model_ref& mdl) override { IF_VERBOSE(3, verbose_stream() << "on-model " << "\n"); diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index e81e19eb2..087f02440 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -3593,11 +3593,8 @@ namespace smt { auto p = m_theories.get_plugin(tid); if (!p) return false; - auto mdl = dynamic_cast(p)->get_model(); - if (!mdl) - return false; - m_model = mdl; - return true; + m_model = dynamic_cast(p)->get_model(); + return m_model.get() != nullptr; } /** diff --git a/src/smt/theory_sls.cpp b/src/smt/theory_sls.cpp index 3db0b1820..063ef287b 100644 --- a/src/smt/theory_sls.cpp +++ b/src/smt/theory_sls.cpp @@ -11,8 +11,7 @@ Abstract: Author: - Nikolaj Bjorner (nbjorner) 2024-02-21 - + Nikolaj Bjorner (nbjorner) 2024-10-24 --*/ @@ -23,14 +22,12 @@ Author: namespace smt { -#ifdef SINGLE_THREAD - - -#else - theory_sls::theory_sls(smt::context& ctx): + theory_sls::theory_sls(smt::context& ctx) : theory(ctx, ctx.get_manager().mk_family_id("sls")) {} +#ifndef SINGLE_THREAD + theory_sls::~theory_sls() { finalize(); } @@ -100,10 +97,8 @@ namespace smt { unsigned scope_lvl = ctx.get_scope_level(); if (ctx.get_search_level() == scope_lvl - n) { auto& lits = ctx.assigned_literals(); - for (; m_trail_lim < lits.size() && ctx.get_assign_level(lits[m_trail_lim]) == scope_lvl; ++m_trail_lim) { - auto lit = lits[m_trail_lim]; - m_smt_plugin->add_unit(lit); - } + for (; m_trail_lim < lits.size() && ctx.get_assign_level(lits[m_trail_lim]) == scope_lvl; ++m_trail_lim) + m_smt_plugin->add_unit(lits[m_trail_lim]); } #if 0 if (ctx.has_new_best_phase()) @@ -114,7 +109,6 @@ namespace smt { m_smt_plugin->import_from_sls(); } - void theory_sls::init() { if (m_smt_plugin) finalize(); diff --git a/src/smt/theory_sls.h b/src/smt/theory_sls.h index 9174d413c..7e93a8fed 100644 --- a/src/smt/theory_sls.h +++ b/src/smt/theory_sls.h @@ -11,7 +11,7 @@ Abstract: Author: - Nikolaj Bjorner (nbjorner) 2024-02-21 + Nikolaj Bjorner (nbjorner) 2024-10-24 --*/ #pragma once @@ -25,10 +25,21 @@ Author: #ifdef SINGLE_THREAD - namespace sls { - - + class theory_sls : public smt::theory { + model_ref m_model; + public: + theory_sls(context& ctx); + ~theory_sls() override {} + model_ref get_model() { return m_model; } + char const* get_name() const override { return "sls"; } + smt::theory* mk_fresh(context* new_ctx) override { return alloc(theory_sls, *new_ctx); } + void display(std::ostream& out) const override {} + bool internalize_atom(app* atom, bool gate_ctx) override { return false; } + bool internalize_term(app* term) override { return false; } + void new_eq_eh(theory_var v1, theory_var v2) override {} + void new_diseq_eh(theory_var v1, theory_var v2) override {} + }; } #else