From 848bfb14a1764a25c0564226f0c5915003245de6 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 25 Oct 2024 23:29:26 -0700 Subject: [PATCH] use common infrastructure for sls-smt --- src/ast/sls/sls_smt_plugin.cpp | 43 ++-- src/ast/sls/sls_smt_plugin.h | 5 +- src/sat/smt/sls_solver.cpp | 441 ++++++--------------------------- src/sat/smt/sls_solver.h | 41 ++- src/smt/theory_sls.cpp | 3 +- 5 files changed, 126 insertions(+), 407 deletions(-) diff --git a/src/ast/sls/sls_smt_plugin.cpp b/src/ast/sls/sls_smt_plugin.cpp index e4ecb9cd1..6f0a342cc 100644 --- a/src/ast/sls/sls_smt_plugin.cpp +++ b/src/ast/sls/sls_smt_plugin.cpp @@ -39,9 +39,8 @@ namespace sls { smt_plugin::~smt_plugin() { SASSERT(!m_ddfw); } - - void smt_plugin::check(expr_ref_vector const& fmls) { + void smt_plugin::check(expr_ref_vector const& fmls, vector const& clauses) { SASSERT(!m_ddfw); // set up state for local search theory_sls here m_result = l_undef; @@ -53,29 +52,37 @@ namespace sls { m_ddfw->set_plugin(this); m_ddfw->updt_params(ctx.get_params()); + for (auto const& clause : clauses) { + m_ddfw->add(clause.size(), clause.data()); + for (auto lit : clause) + add_shared_var(lit.var(), lit.var()); + } + + for (auto v : m_shared_bool_vars) { + expr* e = ctx.bool_var2expr(v); + if (!e) + continue; + m_context.register_atom(v, m_smt2sls_tr(e)); + for (auto t : subterms::all(expr_ref(e, m))) + add_shared_term(e); + } + for (auto fml : fmls) m_context.add_constraint(m_smt2sls_tr(fml)); - // m_context.display(verbose_stream()); - for (unsigned v = 0; v < ctx.get_num_bool_vars(); ++v) { expr* e = ctx.bool_var2expr(v); if (!e) continue; - for (auto t : subterms::all(expr_ref(e, m))) - add_shared_term(e); expr_ref sls_e(m_sls); sls_e = m_smt2sls_tr(e); auto w = m_context.atom2bool_var(sls_e); - if (w != sat::null_bool_var) { - m_smt_bool_var2sls_bool_var.setx(v, w, sat::null_bool_var); - m_sls_bool_var2smt_bool_var.setx(w, v, sat::null_bool_var); - m_sls_phase.reserve(v + 1); - m_sat_phase.reserve(v + 1); - m_rewards.reserve(v + 1); - m_shared_bool_vars.insert(v); - } + if (w == sat::null_bool_var) + continue; + add_shared_var(v, w); + for (auto t : subterms::all(expr_ref(e, m))) + add_shared_term(e); } m_thread = std::thread([this]() { run(); }); @@ -139,6 +146,14 @@ namespace sls { return false; } + void smt_plugin::add_shared_var(sat::bool_var v, sat::bool_var w) { + m_smt_bool_var2sls_bool_var.setx(v, w, sat::null_bool_var); + m_sls_bool_var2smt_bool_var.setx(w, v, sat::null_bool_var); + m_sls_phase.reserve(v + 1); + m_sat_phase.reserve(v + 1); + m_rewards.reserve(v + 1); + m_shared_bool_vars.insert(v); + } void smt_plugin::add_unit(sat::literal lit) { if (!is_shared(lit)) diff --git a/src/ast/sls/sls_smt_plugin.h b/src/ast/sls/sls_smt_plugin.h index fb50df6bc..ad7fd73d5 100644 --- a/src/ast/sls/sls_smt_plugin.h +++ b/src/ast/sls/sls_smt_plugin.h @@ -80,8 +80,8 @@ namespace sls { bool is_shared(sat::literal lit); void run(); void add_shared_term(expr* t); - void add_uninterp(expr* smt_t); + void add_shared_var(sat::bool_var v, sat::bool_var w); void import_phase_from_smt(); void import_values_from_sls(); @@ -93,7 +93,6 @@ namespace sls { void export_activity_to_smt(); void export_phase_to_smt(); - void export_from_sls(); friend class sat::ddfw; @@ -103,7 +102,7 @@ namespace sls { smt_plugin(smt_context& ctx); // interface to calling solver: - void check(expr_ref_vector const& fmls); + void check(expr_ref_vector const& fmls, vector const& clauses); void finalize(model_ref& md, ::statistics& st); void updt_params(params_ref& p) {} std::ostream& display(std::ostream& out) override; diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index b78b7e1b9..46524cbd1 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -23,399 +23,108 @@ Author: namespace sls { -#ifdef SINGLE_THREAD solver::solver(euf::solver& ctx) : th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")) - {} + {} + +#ifdef SINGLE_THREAD #else - solver::solver(euf::solver& ctx): - th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")) - {} solver::~solver() { finalize(); } - - class solver::smt_plugin : public sat::local_search_plugin, public sls::sat_solver_context { - ast_manager& m; - solver& s; - sat::ddfw* m_ddfw; - sls::context m_context; - bool m_new_clause_added = false; - unsigned m_min_unsat_size = UINT_MAX; - ast_manager m_sync_manager; - obj_map m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp - obj_map m_smt2sync_uninterp; // hashtable from external uninterp to sync uninterp - ast_translation m_smt2sync_tr, m_smt2sls_tr; - expr_ref_vector m_sync_uninterp; - expr_ref_vector m_sync_values; - std::atomic m_has_new_sls_values = false; - - // export from SAT to SLS: - // - unit literals - // - phase - - bool export_units_to_sls() { - if (!s.m_has_units) - return false; - std::lock_guard lock(s.m_mutex); - IF_VERBOSE(2, verbose_stream() << "SMT -> SLS units " << s.m_units << "\n"); - for (auto lit : s.m_units) { - if (m_shared_bool_vars.contains(lit.var())) { - IF_VERBOSE(10, verbose_stream() << "unit " << lit << "\n"); - m_ddfw->add(1, &lit); - } - else { - IF_VERBOSE(0, verbose_stream() << "value restriction " << lit << " " << mk_bounded_pp(s.ctx.bool_var2expr(lit.var()), s.ctx.get_manager()) << "\n"); - } - } - - s.m_has_units = false; - s.m_units.reset(); - return true; - } - - bool export_phase_to_sls() { - if (!m_has_new_sat_phase) - return false; - std::lock_guard lock(s.m_mutex); - IF_VERBOSE(3, verbose_stream() << "SMT -> SLS phase\n"); - for (auto i : m_shared_bool_vars) { - if (m_sat_phase[i] != is_true(sat::literal(i, false))) - flip(i); - m_ddfw->bias(i) = m_sat_phase[i] ? 1 : -1; - } - m_has_new_sat_phase = false; - return true; - } - - bool export_to_sls() { - bool updated = false; - if (export_units_to_sls()) - updated = true; - if (export_phase_to_sls()) - updated = true; - return updated; - } - - // import from SLS: - // - phase - // - values - // - activity - void import_from_sls() { - if (unsat().size() > m_min_unsat_size) - return; - m_min_unsat_size = unsat().size(); - std::lock_guard lock(s.m_mutex); - for (auto v : m_shared_bool_vars) { - m_rewards[v] = m_ddfw->get_reward_avg(v); - m_sls_phase[v] = l_true == m_ddfw->get_model()[v]; - m_has_new_sls_phase = true; - } - // import_values_from_sls(); - } - - void import_values_from_sls() { - IF_VERBOSE(3, verbose_stream() << "import values from sls\n"); - std::lock_guard lock(s.m_mutex); - ast_translation tr(m, m_sync_manager); - for (auto const& [t, t_sync] : m_sls2sync_uninterp) { - expr_ref val_t = m_context.get_value(t); - m_sync_values.set(t_sync->get_id(), tr(val_t.get())); - } - m_has_new_sls_values = true; - } - - - - void add_uninterp(expr* smt_t) { - auto sync_t = m_smt2sync_tr(smt_t); - auto sls_t = m_smt2sls_tr(smt_t); - m_sync_uninterp.push_back(sync_t); - m_smt2sync_uninterp.insert(smt_t, sync_t); - m_sls2sync_uninterp.insert(sls_t, sync_t); - } - - public: - smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) : - m(m), s(s), m_ddfw(d), m_context(m, *this), - m_sync_uninterp(m_sync_manager), - m_sync_values(m_sync_manager), - m_smt2sync_tr(s.ctx.get_manager(), m_sync_manager), - m_smt2sls_tr(s.ctx.get_manager(), m) - { - } - - uint_set m_shared_bool_vars, m_shared_terms; - svector m_sat_phase; - std::atomic m_has_new_sat_phase = false; - - std::atomic m_has_new_sls_phase = false; - svector m_sls_phase; - - svector m_rewards; - - - - void add_shared_term(expr* t) { - m_shared_terms.insert(t->get_id()); - if (is_uninterp(t)) - add_uninterp(t); - } - - void add_shared_var(sat::bool_var v) { - m_sls_phase.reserve(v + 1); - m_sat_phase.reserve(v + 1); - m_rewards.reserve(v + 1); - m_shared_bool_vars.insert(v); - } - - void init_search() override {} - - void finish_search() override {} - - void on_rescale() override {} - - void on_restart() override { - if (export_to_sls()) - m_ddfw->reinit(); - } - - void on_save_model() override { - TRACE("sls", display(tout)); - while (unsat().empty()) { - m_context.check(); - if (!m_new_clause_added) - break; - m_ddfw->reinit(); - m_new_clause_added = false; - } - //import_from_sls(); - } - - void on_model(model_ref& mdl) override { - IF_VERBOSE(3, verbose_stream() << "on-model " << "\n"); - s.m_sls_model = mdl; - } - - void register_atom(sat::bool_var v, expr* e) { - m_context.register_atom(v, e); - } - - std::ostream& display(std::ostream& out) { - m_ddfw->display(out); - m_context.display(out); - return out; - } - - vector const& clauses() const override { return m_ddfw->clauses(); } - sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); } - ptr_iterator get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); } - void flip(sat::bool_var v) override { - m_ddfw->flip(v); - } - double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); } - double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; } - bool is_true(sat::literal lit) override { - return m_ddfw->get_value(lit.var()) != lit.sign(); - } - unsigned num_vars() const override { return m_ddfw->num_vars(); } - indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); } - sat::bool_var add_var() override { - return m_ddfw->add_var(); - } - void add_clause(unsigned n, sat::literal const* lits) override { - m_ddfw->add(n, lits); - m_new_clause_added = true; - } - void force_restart() override { m_ddfw->force_restart(); } - - void export_values_to_smt() { - if (!m_has_new_sls_values) - return; - IF_VERBOSE(3, verbose_stream() << "SLS -> SMT values\n"); - std::lock_guard lock(s.m_mutex); - ast_translation tr(m_sync_manager, s.ctx.get_manager()); - for (auto const& [t, t_sync] : m_smt2sync_uninterp) { - expr* sync_val = m_sync_values.get(t_sync->get_id(), nullptr); - if (sync_val) - s.ctx.user_propagate_initialize_value(t, tr(sync_val)); - } - m_has_new_sls_values = false; - } - - void export_phase_to_smt() { - if (!m_has_new_sls_phase) - return; - IF_VERBOSE(3, verbose_stream() << "new SLS -> SMT phase\n"); - std::lock_guard lock(s.m_mutex); - for (unsigned i = 0; i < m_sls_phase.size(); ++i) - s.s().set_phase(sat::literal(i, !m_sls_phase[i])); - m_has_new_sls_phase = false; - } - - void import_phase_from_smt() { - if (m_has_new_sat_phase) - return; - m_has_new_sat_phase = true; - IF_VERBOSE(3, verbose_stream() << "new SMT -> SLS phase\n"); - s.s().set_has_new_best_phase(false); - std::lock_guard lock(s.m_mutex); - for (auto v : m_shared_bool_vars) - m_sat_phase[v] = s.s().get_best_phase(v); - } - - void export_activity_to_smt() { - // TODO - } - - // determine if unit literal restricts values of shared subterms. - bool is_value_restriction(sat::literal lit) { - auto e = s.ctx.bool_var2expr(lit.var()); - expr* t = nullptr; - if (!e) - return false; - bv_util bv(s.ctx.get_manager()); - if (bv.is_bit2bool(e, t) && m_shared_terms.contains(t->get_id())) { - verbose_stream() << "shared bit2bool " << mk_bounded_pp(e, s.ctx.get_manager()) << "\n"; - return true; - } - - // if arith.is_le(e, s, t) && t is a numeral, s is shared-term.... - return false; - } - - }; - - void solver::finalize() { - if (!m_completed && m_ddfw) { - m_ddfw->rlimit().cancel(); - m_thread.join(); - m_ddfw->collect_statistics(m_st); - m_ddfw = nullptr; - m_slsm = nullptr; - m_smt_plugin = nullptr; - m_units.reset(); - } + params_ref solver::get_params() { + return s().params(); } - sat::check_result solver::check() { - return sat::check_result::CR_DONE; + void solver::initialize_value(expr* t, expr* v) { + ctx.user_propagate_initialize_value(t, v); + } + + void solver::force_phase(sat::literal lit) { + ctx.s().set_phase(lit); + } + + void solver::set_has_new_best_phase(bool b) { + + } + + bool solver::get_best_phase(sat::bool_var v) { + return false; + } + + expr* solver::bool_var2expr(sat::bool_var v) { + return ctx.bool_var2expr(v); + } + + void solver::set_finished() { + m.limit().cancel(); + } + + unsigned solver::get_num_bool_vars() const { + return s().num_vars(); + } + + void solver::finalize() { + if (!m_smt_plugin) + return; + + m_smt_plugin->finalize(m_model, m_st); + m_model = nullptr; + m_smt_plugin = nullptr; } bool solver::unit_propagate() { force_push(); - sample_local_search(); - return false; + if (m_smt_plugin && !m_checking) { + expr_ref_vector fmls(m); + m_checking = true; + m_smt_plugin->check(fmls, ctx.top_level_clauses()); + return true; + } + if (!m_smt_plugin) + return false; + if (!m_smt_plugin->completed()) + return false; + m_smt_plugin->finalize(m_model, m_st); + m_smt_plugin = nullptr; + return true; } void solver::pop_core(unsigned n) { if (!m_smt_plugin) return; - for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) { - auto lit = s().trail_literal(m_trail_lim); - if (m_smt_plugin->is_value_restriction(lit) || - m_smt_plugin->m_shared_bool_vars.contains(lit.var())) { - std::lock_guard lock(m_mutex); - m_units.push_back(lit); - m_has_units = true; - } - } - if (s().has_new_best_phase()) - m_smt_plugin->import_phase_from_smt(); - - m_smt_plugin->export_phase_to_smt(); - m_smt_plugin->export_activity_to_smt(); - m_smt_plugin->export_values_to_smt(); - } + unsigned scope_lvl = s().scope_lvl(); + if (s().search_lvl() == scope_lvl - n) { + for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) { + auto lit = s().trail_literal(m_trail_lim); + m_smt_plugin->add_unit(lit); + } + } +#if 0 + if (ctx.has_new_best_phase()) + m_smt_plugin->import_phase_from_smt(); + +#endif + + m_smt_plugin->import_from_sls(); + } void solver::init_search() { - if (m_ddfw) { - m_ddfw->rlimit().cancel(); - m_thread.join(); - } - // set up state for local search solver here - m_result = l_undef; - m_completed = false; - m_slsm = alloc(ast_manager); - m_units.reset(); - m_has_units = false; - m_model = nullptr; - m_sls_model = nullptr; - m_ddfw = alloc(sat::ddfw); - ast_translation tr(m, *m_slsm); - scoped_limits scoped_limits(m.limit()); - scoped_limits.push_child(&m_slsm->limit()); - scoped_limits.push_child(&m_ddfw->rlimit()); - m_smt_plugin = alloc(smt_plugin, *m_slsm, *this, m_ddfw.get()); - m_ddfw->set_plugin(m_smt_plugin); - m_ddfw->updt_params(s().params()); - for (auto const& clause : ctx.top_level_clauses()) { - m_ddfw->add(clause.size(), clause.data()); - for (auto lit : clause) - m_smt_plugin->add_shared_var(lit.var()); - } - for (auto v : m_smt_plugin->m_shared_bool_vars) { - expr* e = ctx.bool_var2expr(v); - if (!e) - continue; - m_smt_plugin->register_atom(v, tr(e)); - for (auto t : subterms::all(expr_ref(e, m))) - m_smt_plugin->add_shared_term(e); - - } - - m_thread = std::thread([this]() { run_local_search_async(); }); - } - - void solver::sample_local_search() { - if (!m_completed) - return; - m_thread.join(); - local_search_done(); - } - - void solver::local_search_done() { - IF_VERBOSE(1, verbose_stream() << "local-search-done\n"); - m_completed = false; - - CTRACE("sls", m_smt_plugin, m_smt_plugin->display(tout)); - if (m_ddfw) - m_ddfw->collect_statistics(m_st); - - TRACE("sls", tout << "result " << m_result << "\n"); - - if (m_result == l_true && m_sls_model) { - ast_translation tr(*m_slsm, m); - m_model = m_sls_model->translate(tr); - TRACE("sls", tout << "model: " << *m_sls_model << "\n";); - s().set_canceled(); - } - m_ddfw = nullptr; - m_smt_plugin = nullptr; - m_sls_model = nullptr; - } - - void solver::run_local_search_async() { - if (m_ddfw) { - m_result = m_ddfw->check(0, nullptr); - IF_VERBOSE(1, verbose_stream() << "sls-result " << m_result << "\n"); - m_completed = true; - } - } - - void solver::run_local_search_sync() { - m_result = m_ddfw->check(0, nullptr); - local_search_done(); + if (m_smt_plugin) + finalize(); + m_smt_plugin = alloc(sls::smt_plugin, *this); + m_checking = false; } std::ostream& solver::display(std::ostream& out) const { - out << "sls-solver\n"; - return out; + return out << "theory-sls\n"; } + + #endif } diff --git a/src/sat/smt/sls_solver.h b/src/sat/smt/sls_solver.h index 696867157..d14b403c9 100644 --- a/src/sat/smt/sls_solver.h +++ b/src/sat/smt/sls_solver.h @@ -54,6 +54,7 @@ namespace sls { #include #include +#include "ast/sls/sls_smt_plugin.h" namespace euf { class solver; @@ -61,29 +62,12 @@ namespace euf { namespace sls { - class solver : public euf::th_euf_solver { - class smt_plugin; - - std::atomic m_result; - std::atomic m_completed, m_has_units; - std::thread m_thread; - std::mutex m_mutex; - // m is accessed by the main thread - // m_slsm is accessed by the sls thread - scoped_ptr m_slsm; - scoped_ptr m_ddfw; - sat::literal_vector m_units; - smt_plugin* m_smt_plugin = nullptr; - model_ref m_model, m_sls_model; + class solver : public euf::th_euf_solver, public sls::smt_context { + model_ref m_model; + sls::smt_plugin* m_smt_plugin = nullptr; unsigned m_trail_lim = 0; - statistics m_st; - - - - void run_local_search_async(); - void run_local_search_sync(); - void sample_local_search(); - void local_search_done(); + bool m_checking = false; + ::statistics m_st; public: solver(euf::solver& ctx); @@ -102,10 +86,21 @@ namespace sls { sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; } void internalize(expr* e) override { UNREACHABLE(); } void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override { UNREACHABLE(); } - sat::check_result check() override; + sat::check_result check() override { return sat::check_result::CR_DONE; } std::ostream& display(std::ostream& out) const override; std::ostream & display_justification(std::ostream & out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; } std::ostream & display_constraint(std::ostream & out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; } + + + ast_manager& get_manager() override { return m; } + params_ref get_params() override; + void initialize_value(expr* t, expr* v) override; + void force_phase(sat::literal lit) override; + void set_has_new_best_phase(bool b) override; + bool get_best_phase(sat::bool_var v) override; + expr* bool_var2expr(sat::bool_var v) override; + void set_finished() override; + unsigned get_num_bool_vars() const override; }; diff --git a/src/smt/theory_sls.cpp b/src/smt/theory_sls.cpp index 063ef287b..a3a2d3e21 100644 --- a/src/smt/theory_sls.cpp +++ b/src/smt/theory_sls.cpp @@ -79,7 +79,8 @@ namespace smt { for (unsigned i = 0; i < ctx.get_num_asserted_formulas(); ++i) fmls.push_back(ctx.get_asserted_formula(i)); m_checking = true; - m_smt_plugin->check(fmls); + vector clauses; + m_smt_plugin->check(fmls, clauses); return; } if (!m_smt_plugin)