diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt index 8961b535b..a63fc0994 100644 --- a/src/ast/sls/CMakeLists.txt +++ b/src/ast/sls/CMakeLists.txt @@ -15,7 +15,8 @@ z3_add_component(ast_sls sls_context.cpp sls_datatype_plugin.cpp sls_euf_plugin.cpp - sls_smt_solver.cpp + sls_smt_plugin.cpp + sls_smt_solver.cpp COMPONENT_DEPENDENCIES ast euf diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 507bee701..2aa932f8a 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -321,6 +321,8 @@ namespace sls { expr_ref _e(f, m); expr* g, * h, * k; sat::literal_vector clause; + if (m.is_true(f)) + return; if (m.is_not(f, g) && m.is_not(g, g)) { add_clause(g); return; @@ -486,7 +488,7 @@ namespace sls { for (sat::literal lit : m_unit_literals) m_unit_indices.insert(lit.index()); - verbose_stream() << "UNITS " << m_unit_literals << "\n"; + IF_VERBOSE(0, verbose_stream() << "UNITS " << m_unit_literals << "\n"); for (unsigned i = 0; i < m_atoms.size(); ++i) if (m_atoms.get(i)) register_terms(m_atoms.get(i)); diff --git a/src/ast/sls/sls_smt_plugin.cpp b/src/ast/sls/sls_smt_plugin.cpp new file mode 100644 index 000000000..84134ea4e --- /dev/null +++ b/src/ast/sls/sls_smt_plugin.cpp @@ -0,0 +1,263 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_smt_plugin.cpp + +Abstract: + + A Stochastic Local Search (SLS) Plugin. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-10 + +--*/ + + +#include "ast/sls/sls_smt_plugin.h" +#include "ast/for_each_expr.h" +#include "ast/bv_decl_plugin.h" + +namespace sls { + + smt_plugin::smt_plugin(smt_context& ctx) : + ctx(ctx), + m(ctx.get_manager()), + m_sls(), + m_sync(), + m_smt2sync_tr(m, m_sync), + m_smt2sls_tr(m, m_sls), + m_sync_uninterp(m_sync), + m_sls_uninterp(m_sls), + m_sync_values(m_sync), + m_context(m_sls, *this) + { + } + + smt_plugin::~smt_plugin() { + SASSERT(!m_ddfw); + } + + + void smt_plugin::check(expr_ref_vector const& fmls) { + SASSERT(!m_ddfw); + // set up state for local search theory_sls here + m_result = l_undef; + m_completed = false; + m_units.reset(); + m_has_units = false; + m_model = nullptr; + m_sls_model = nullptr; + m_ddfw = alloc(sat::ddfw); + m_ddfw->set_plugin(this); + m_ddfw->updt_params(ctx.get_params()); + + for (auto fml : fmls) + m_context.add_constraint(m_smt2sls_tr(fml)); + + // m_context.display(verbose_stream()); + + for (unsigned v = 0; v < ctx.get_num_bool_vars(); ++v) { + expr* e = ctx.bool_var2expr(v); + if (!e) + continue; + for (auto t : subterms::all(expr_ref(e, m))) + add_shared_term(e); + + expr_ref sls_e(m_sls); + sls_e = m_smt2sls_tr(e); + auto w = m_context.atom2bool_var(sls_e); + if (w != sat::null_bool_var) { + verbose_stream() << mk_bounded_pp(e, m) << ": " << v << " -> " << w << "\n"; + m_smt_bool_var2sls_bool_var.setx(v, w, sat::null_bool_var); + m_sls_bool_var2smt_bool_var.setx(w, v, sat::null_bool_var); + + add_shared_var(w); + } + } + + m_thread = std::thread([this]() { run(); }); + } + + void smt_plugin::run() { + if (!m_ddfw) + return; + m_result = m_ddfw->check(0, nullptr); + IF_VERBOSE(1, verbose_stream() << "sls-result " << m_result << "\n"); + m_completed = true; + } + + void smt_plugin::finalize(model_ref& mdl, ::statistics& st) { + auto* d = m_ddfw; + if (!d) + return; + bool canceled = !m_completed; + IF_VERBOSE(0, verbose_stream() << "finalize\n"); + mdl = m_model; + if (!m_completed) { + d->rlimit().cancel(); + if (m_thread.joinable()) + m_thread.join(); + } + if (m_result == l_true && m_sls_model) { + ast_translation tr(m_sls, m); + m_model = m_sls_model->translate(tr); + TRACE("sls", tout << "model: " << *m_sls_model << "\n";); + if (!canceled) + ctx.set_finished(); + } + m_ddfw = nullptr; + // m_ddfw owns the pointer to smt_plugin and destructs it. + dealloc(d); + } + + void smt_plugin::collect_statistics(statistics& st) { + + } + std::ostream& smt_plugin::display(std::ostream& out) { + m_ddfw->display(out); + m_context.display(out); + return out; + } + + bool smt_plugin::is_shared(sat::literal lit) { + auto w = m_smt_bool_var2sls_bool_var.get(lit.var(), sat::null_bool_var); + if (w != sat::null_bool_var) + return true; + auto e = ctx.bool_var2expr(lit.var()); + expr* t = nullptr; + if (!e) + return false; + bv_util bv(m); + if (bv.is_bit2bool(e, t) && m_shared_terms.contains(t->get_id())) { + verbose_stream() << "shared bit2bool " << mk_bounded_pp(e, ctx.get_manager()) << "\n"; + return true; + } + + // if arith.is_le(e, s, t) && t is a numeral, s is shared-term.... + return false; + } + + + void smt_plugin::add_unit(sat::literal lit) { + if (!is_shared(lit)) + return; + std::lock_guard lock(m_mutex); + m_units.push_back(lit); + m_has_units = true; + } + + void smt_plugin::import_phase_from_smt() { + if (m_has_new_sat_phase) + return; + m_has_new_sat_phase = true; + IF_VERBOSE(3, verbose_stream() << "new SMT -> SLS phase\n"); + ctx.set_has_new_best_phase(false); + std::lock_guard lock(m_mutex); + for (auto v : m_shared_bool_vars) + m_sat_phase[v] = ctx.get_best_phase(v); + } + + bool smt_plugin::export_to_sls() { + bool updated = false; + if (export_units_to_sls()) + updated = true; + if (export_phase_to_sls()) + updated = true; + return updated; + } + + bool smt_plugin::export_phase_to_sls() { + if (!m_has_new_sat_phase) + return false; + std::lock_guard lock(m_mutex); + IF_VERBOSE(3, verbose_stream() << "SMT -> SLS phase\n"); + for (auto i : m_shared_bool_vars) { + auto v = m_smt_bool_var2sls_bool_var[i]; + if (m_sat_phase[v] != is_true(sat::literal(v, false))) + flip(v); + m_ddfw->bias(v) = m_sat_phase[v] ? 1 : -1; + } + m_has_new_sat_phase = false; + return true; + } + + bool smt_plugin::export_units_to_sls() { + if (!m_has_units) + return false; + std::lock_guard lock(m_mutex); + IF_VERBOSE(2, verbose_stream() << "SMT -> SLS units " << m_units << "\n"); + for (auto lit : m_units) { + if (m_shared_bool_vars.contains(lit.var())) { + sat::literal sls_lit(m_smt_bool_var2sls_bool_var[lit.var()], false); + IF_VERBOSE(10, verbose_stream() << "unit " << sls_lit << "\n"); + m_ddfw->add(1, &sls_lit); + } + else { + IF_VERBOSE(0, verbose_stream() << "value restriction " << lit << " " + << mk_bounded_pp(ctx.bool_var2expr(lit.var()), m) << "\n"); + } + } + m_has_units = false; + m_units.reset(); + return true; + } + + void smt_plugin::import_from_sls() { + if (unsat().size() > m_min_unsat_size) + return; + m_min_unsat_size = unsat().size(); + std::lock_guard lock(m_mutex); + for (auto v : m_shared_bool_vars) { + m_rewards[v] = m_ddfw->get_reward_avg(v); + m_sls_phase[v] = l_true == m_ddfw->get_model()[v]; + m_has_new_sls_phase = true; + } + // import_values_from_sls(); + } + + void smt_plugin::import_values_from_sls() { + IF_VERBOSE(3, verbose_stream() << "import values from sls\n"); + std::lock_guard lock(m_mutex); + for (auto const& [t, t_sync] : m_sls2sync_uninterp) { + expr_ref val_t = m_context.get_value(t); + m_sync_values.set(t_sync->get_id(), m_smt2sync_tr(val_t.get())); + } + m_has_new_sls_values = true; + } + + + + void smt_plugin::export_activity_to_smt() { + + } + + void smt_plugin::export_values_to_smt() { + + } + + void smt_plugin::add_shared_term(expr* t) { + m_shared_terms.insert(t->get_id()); + if (is_uninterp(t)) + add_uninterp(t); + } + + void smt_plugin::add_uninterp(expr* smt_t) { + auto sync_t = m_smt2sync_tr(smt_t); + auto sls_t = m_smt2sls_tr(smt_t); + m_sync_uninterp.push_back(sync_t); + m_sls_uninterp.push_back(sls_t); + m_smt2sync_uninterp.insert(smt_t, sync_t); + m_sls2sync_uninterp.insert(sls_t, sync_t); + } + + void smt_plugin::add_shared_var(sat::bool_var v) { + m_sls_phase.reserve(v + 1); + m_sat_phase.reserve(v + 1); + m_rewards.reserve(v + 1); + m_shared_bool_vars.insert(v); + } + +} diff --git a/src/ast/sls/sls_smt_plugin.h b/src/ast/sls/sls_smt_plugin.h new file mode 100644 index 000000000..18164217a --- /dev/null +++ b/src/ast/sls/sls_smt_plugin.h @@ -0,0 +1,168 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_smt_plugin.h + +Abstract: + + A Stochastic Local Search (SLS) Plugin. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-10 + +--*/ + +#pragma once + +#include "ast/sls/sls_context.h" +#include "ast/sls/sat_ddfw.h" +#include "util/statistics.h" +#include +#include + +namespace sls { + + class smt_context { + public: + virtual ast_manager& get_manager() = 0; + virtual params_ref get_params() = 0; + virtual void initialize_value(expr* t, expr* v) = 0; + virtual void force_phase(sat::literal lit) = 0; + virtual void set_has_new_best_phase(bool b) = 0; + virtual bool get_best_phase(sat::bool_var v) = 0; + virtual expr* bool_var2expr(sat::bool_var v) = 0; + virtual void set_finished() = 0; + virtual unsigned get_num_bool_vars() const = 0; + }; + + class smt_plugin : public sat::local_search_plugin, public sat_solver_context { + smt_context& ctx; + ast_manager& m; + ast_manager m_sls; + ast_manager m_sync; + ast_translation m_smt2sync_tr, m_smt2sls_tr; + expr_ref_vector m_sync_uninterp, m_sls_uninterp; + expr_ref_vector m_sync_values; + sat::ddfw* m_ddfw = nullptr; + sls::context m_context; + std::atomic m_result; + std::atomic m_completed, m_has_units; + std::thread m_thread; + std::mutex m_mutex; + // m is accessed by the main thread + // m_slsm is accessed by the sls thread + sat::literal_vector m_units; + model_ref m_model, m_sls_model; + unsigned m_trail_lim = 0; + ::statistics m_st; + bool m_new_clause_added = false; + unsigned m_min_unsat_size = UINT_MAX; + obj_map m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp + obj_map m_smt2sync_uninterp; // hashtable from external uninterp to sync uninterp + std::atomic m_has_new_sls_values = false; + + uint_set m_shared_bool_vars, m_shared_terms; + svector m_sat_phase; + std::atomic m_has_new_sat_phase = false; + + std::atomic m_has_new_sls_phase = false; + svector m_sls_phase; + + svector m_rewards; + svector m_smt_bool_var2sls_bool_var, m_sls_bool_var2smt_bool_var; + + + + bool is_shared(sat::literal lit); + void run(); + void add_shared_term(expr* t); + + void add_uninterp(expr* smt_t); + void add_shared_var(sat::bool_var v); + + void import_phase_from_smt(); + void import_values_from_sls(); + bool export_phase_to_sls(); + bool export_units_to_sls(); + void export_activity_to_smt(); + void export_values_to_smt(); + + + friend class sat::ddfw; + ~smt_plugin(); + + public: + smt_plugin(smt_context& ctx); + + // interface to calling solver: + void check(expr_ref_vector const& fmls); + void finalize(model_ref& md, ::statistics& st); + void updt_params(params_ref& p) {} + void collect_statistics(statistics& st); + std::ostream& display(std::ostream& out) override; + void import_from_sls(); + bool export_to_sls(); + bool completed() { return m_completed; } + void add_unit(sat::literal lit); + + + // local_search_plugin: + void on_restart() override { + if (export_to_sls()) + m_ddfw->reinit(); + } + + void on_save_model() override { + TRACE("sls", display(tout)); + while (unsat().empty()) { + m_context.check(); + if (!m_new_clause_added) + break; + m_ddfw->reinit(); + m_new_clause_added = false; + } + //import_from_sls(); + } + + void on_model(model_ref& mdl) override { + IF_VERBOSE(3, verbose_stream() << "on-model " << "\n"); + m_sls_model = mdl; + } + + void init_search() override {} + + void finish_search() override {} + + void on_rescale() override {} + + + + // sat_solver_context: + vector const& clauses() const override { return m_ddfw->clauses(); } + sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); } + ptr_iterator get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); } + void flip(sat::bool_var v) override { + m_ddfw->flip(v); + } + double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); } + double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; } + bool is_true(sat::literal lit) override { + return m_ddfw->get_value(lit.var()) != lit.sign(); + } + unsigned num_vars() const override { return m_ddfw->num_vars(); } + indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); } + sat::bool_var add_var() override { + return m_ddfw->add_var(); + } + void add_clause(unsigned n, sat::literal const* lits) override { + m_ddfw->add(n, lits); + m_new_clause_added = true; + } + void force_restart() override { m_ddfw->force_restart(); } + + + }; +} diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 3da8fa4a0..da9440761 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -1248,7 +1248,7 @@ namespace arith { for (auto ev : m_explanation) set_evidence(ev.ci()); - TRACE("arith", + TRACE("arith_conflict", tout << "Lemma - " << (is_conflict ? "conflict" : "propagation") << "\n"; for (literal c : m_core) tout << c << ": " << literal2expr(c) << "\n"; for (auto p : m_eqs) tout << ctx.bpp(p.first) << " == " << ctx.bpp(p.second) << "\n";); @@ -1270,6 +1270,7 @@ namespace arith { m_core.push_back(ctx.mk_literal(m.mk_eq(eq.first->get_expr(), eq.second->get_expr()))); for (literal& c : m_core) c.neg(); + DEBUG_CODE(for (literal c : m_core) { SASSERT(s().value(c) != l_true); }); add_redundant(m_core, explain(ty)); } @@ -1520,10 +1521,13 @@ namespace arith { } for (auto const& ineq : m_nla->literals()) { auto lit = mk_ineq_literal(ineq); + if (s().value(lit) == l_true) + continue; ctx.mark_relevant(lit); s().set_phase(lit); + verbose_stream() << lit << ":= " << s().value(lit) << "\n"; // force trichotomy axiom for equality literals - if (ineq.cmp() == lp::EQ) { + if (ineq.cmp() == lp::EQ && false) { nla::lemma l; l.push_back(ineq); l.push_back(nla::ineq(lp::LT, ineq.term(), ineq.rs())); diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index a80ef1511..b78b7e1b9 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -62,12 +62,17 @@ namespace sls { if (!s.m_has_units) return false; std::lock_guard lock(s.m_mutex); - IF_VERBOSE(3, verbose_stream() << "SMT -> SLS units " << s.m_units << "\n"); - for (auto lit : s.m_units) - if (m_shared_vars.contains(lit.var())) { - IF_VERBOSE(10, verbose_stream() << "unit " << lit << "\n"); + IF_VERBOSE(2, verbose_stream() << "SMT -> SLS units " << s.m_units << "\n"); + for (auto lit : s.m_units) { + if (m_shared_bool_vars.contains(lit.var())) { + IF_VERBOSE(10, verbose_stream() << "unit " << lit << "\n"); m_ddfw->add(1, &lit); } + else { + IF_VERBOSE(0, verbose_stream() << "value restriction " << lit << " " << mk_bounded_pp(s.ctx.bool_var2expr(lit.var()), s.ctx.get_manager()) << "\n"); + } + } + s.m_has_units = false; s.m_units.reset(); return true; @@ -78,7 +83,7 @@ namespace sls { return false; std::lock_guard lock(s.m_mutex); IF_VERBOSE(3, verbose_stream() << "SMT -> SLS phase\n"); - for (auto i : m_shared_vars) { + for (auto i : m_shared_bool_vars) { if (m_sat_phase[i] != is_true(sat::literal(i, false))) flip(i); m_ddfw->bias(i) = m_sat_phase[i] ? 1 : -1; @@ -105,7 +110,7 @@ namespace sls { return; m_min_unsat_size = unsat().size(); std::lock_guard lock(s.m_mutex); - for (auto v : m_shared_vars) { + for (auto v : m_shared_bool_vars) { m_rewards[v] = m_ddfw->get_reward_avg(v); m_sls_phase[v] = l_true == m_ddfw->get_model()[v]; m_has_new_sls_phase = true; @@ -123,7 +128,16 @@ namespace sls { } m_has_new_sls_values = true; } + + + void add_uninterp(expr* smt_t) { + auto sync_t = m_smt2sync_tr(smt_t); + auto sls_t = m_smt2sls_tr(smt_t); + m_sync_uninterp.push_back(sync_t); + m_smt2sync_uninterp.insert(smt_t, sync_t); + m_sls2sync_uninterp.insert(sls_t, sync_t); + } public: smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) : @@ -135,7 +149,7 @@ namespace sls { { } - uint_set m_shared_vars; + uint_set m_shared_bool_vars, m_shared_terms; svector m_sat_phase; std::atomic m_has_new_sat_phase = false; @@ -144,19 +158,19 @@ namespace sls { svector m_rewards; - void add_uninterp(expr* smt_t) { - auto sync_t = m_smt2sync_tr(smt_t); - auto sls_t = m_smt2sls_tr(smt_t); - m_sync_uninterp.push_back(sync_t); - m_smt2sync_uninterp.insert(smt_t, sync_t); - m_sls2sync_uninterp.insert(sls_t, sync_t); + + + void add_shared_term(expr* t) { + m_shared_terms.insert(t->get_id()); + if (is_uninterp(t)) + add_uninterp(t); } void add_shared_var(sat::bool_var v) { m_sls_phase.reserve(v + 1); m_sat_phase.reserve(v + 1); m_rewards.reserve(v + 1); - m_shared_vars.insert(v); + m_shared_bool_vars.insert(v); } void init_search() override {} @@ -250,7 +264,7 @@ namespace sls { IF_VERBOSE(3, verbose_stream() << "new SMT -> SLS phase\n"); s.s().set_has_new_best_phase(false); std::lock_guard lock(s.m_mutex); - for (auto v : m_shared_vars) + for (auto v : m_shared_bool_vars) m_sat_phase[v] = s.s().get_best_phase(v); } @@ -258,6 +272,22 @@ namespace sls { // TODO } + // determine if unit literal restricts values of shared subterms. + bool is_value_restriction(sat::literal lit) { + auto e = s.ctx.bool_var2expr(lit.var()); + expr* t = nullptr; + if (!e) + return false; + bv_util bv(s.ctx.get_manager()); + if (bv.is_bit2bool(e, t) && m_shared_terms.contains(t->get_id())) { + verbose_stream() << "shared bit2bool " << mk_bounded_pp(e, s.ctx.get_manager()) << "\n"; + return true; + } + + // if arith.is_le(e, s, t) && t is a numeral, s is shared-term.... + return false; + } + }; void solver::finalize() { @@ -287,12 +317,12 @@ namespace sls { return; for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) { auto lit = s().trail_literal(m_trail_lim); - if (!m_smt_plugin->m_shared_vars.contains(lit.var())) - continue; - IF_VERBOSE(10, verbose_stream() << "push unit " << lit << " " << mk_bounded_pp(ctx.literal2expr(lit), m) << "\n"); - std::lock_guard lock(m_mutex); - m_units.push_back(lit); - m_has_units = true; + if (m_smt_plugin->is_value_restriction(lit) || + m_smt_plugin->m_shared_bool_vars.contains(lit.var())) { + std::lock_guard lock(m_mutex); + m_units.push_back(lit); + m_has_units = true; + } } if (s().has_new_best_phase()) m_smt_plugin->import_phase_from_smt(); @@ -329,15 +359,14 @@ namespace sls { for (auto lit : clause) m_smt_plugin->add_shared_var(lit.var()); } - for (auto v : m_smt_plugin->m_shared_vars) { + for (auto v : m_smt_plugin->m_shared_bool_vars) { expr* e = ctx.bool_var2expr(v); if (!e) continue; m_smt_plugin->register_atom(v, tr(e)); - for (auto t : subterms::all(expr_ref(e, m))) { - if (is_uninterp(t)) - m_smt_plugin->add_uninterp(t); - } + for (auto t : subterms::all(expr_ref(e, m))) + m_smt_plugin->add_shared_term(e); + } m_thread = std::thread([this]() { run_local_search_async(); }); diff --git a/src/smt/CMakeLists.txt b/src/smt/CMakeLists.txt index e6ee97046..3922e9373 100644 --- a/src/smt/CMakeLists.txt +++ b/src/smt/CMakeLists.txt @@ -66,6 +66,7 @@ z3_add_component(smt theory_pb.cpp theory_recfun.cpp theory_seq.cpp + theory_sls.cpp theory_special_relations.cpp theory_str.cpp theory_str_mc.cpp diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 9ceee136f..e81e19eb2 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -37,6 +37,7 @@ Revision History: #include "smt/uses_theory.h" #include "smt/theory_special_relations.h" #include "smt/theory_polymorphism.h" +#include "smt/theory_sls.h" #include "smt/smt_for_each_relevant_expr.h" #include "smt/smt_model_generator.h" #include "smt/smt_model_checker.h" @@ -3506,6 +3507,10 @@ namespace smt { if (r == l_true && get_cancel_flag()) { r = l_undef; } + if (r == l_undef && get_cancel_flag() && has_sls_model()) { + m.limit().reset_cancel(); + r = l_true; + } if (r == l_true && gparams::get_value("model_validate") == "true") { recfun::util u(m); if (u.get_rec_funs().empty() && m_proto_model) { @@ -3581,6 +3586,20 @@ namespace smt { return r; } + bool context::has_sls_model() { + if (!m_fparams.m_sls_enable) + return false; + auto tid = m.get_family_id("sls"); + auto p = m_theories.get_plugin(tid); + if (!p) + return false; + auto mdl = dynamic_cast(p)->get_model(); + if (!mdl) + return false; + m_model = mdl; + return true; + } + /** \brief Setup the logical context based on the current set of asserted formulas and execute the check command. diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 715b28f23..fe2bc0f6f 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -619,6 +619,9 @@ namespace smt { friend class set_var_theory_trail; void set_var_theory(bool_var v, theory_id tid); + + bool has_sls_model(); + // ----------------------------------- // // Backtracking support diff --git a/src/smt/smt_setup.cpp b/src/smt/smt_setup.cpp index caeca9659..b227d9913 100644 --- a/src/smt/smt_setup.cpp +++ b/src/smt/smt_setup.cpp @@ -35,6 +35,7 @@ Revision History: #include "smt/theory_seq.h" #include "smt/theory_char.h" #include "smt/theory_special_relations.h" +#include "smt/theory_sls.h" #include "smt/theory_pb.h" #include "smt/theory_fpa.h" #include "smt/theory_str.h" @@ -67,6 +68,7 @@ namespace smt { case CFG_AUTO: setup_auto_config(); break; } setup_card(); + setup_sls(); } void setup::setup_default() { @@ -766,6 +768,11 @@ namespace smt { m_context.register_plugin(alloc(theory_pb, m_context)); } + void setup::setup_sls() { + if (m_params.m_sls_enable) + m_context.register_plugin(alloc(theory_sls, m_context)); + } + void setup::setup_fpa() { setup_bv(); m_context.register_plugin(alloc(theory_fpa, m_context)); diff --git a/src/smt/smt_setup.h b/src/smt/smt_setup.h index bb4a81671..acbea59cb 100644 --- a/src/smt/smt_setup.h +++ b/src/smt/smt_setup.h @@ -103,6 +103,7 @@ namespace smt { void setup_seq(); void setup_char(); void setup_card(); + void setup_sls(); void setup_i_arith(); void setup_mi_arith(); void setup_lra_arith(); diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 5a3cbbd1a..c9ea77612 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -3254,7 +3254,7 @@ public: tout << "@" << ctx().get_scope_level() << (is_conflict ? " conflict":" lemma"); for (auto const& p : m_params) tout << " " << p; tout << "\n"; - display_evidence(tout, m_explanation);); + display_evidence(tout << core << " ", m_explanation);); for (auto ev : m_explanation) set_evidence(ev.ci(), m_core, m_eqs); diff --git a/src/smt/theory_sls.cpp b/src/smt/theory_sls.cpp new file mode 100644 index 000000000..3db0b1820 --- /dev/null +++ b/src/smt/theory_sls.cpp @@ -0,0 +1,136 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + theory_sls + +Abstract: + + Interface to Concurrent SLS solver + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-21 + + +--*/ + +#include "smt/theory_sls.h" +#include "smt/smt_context.h" +#include "ast/sls/sls_context.h" +#include "ast/for_each_expr.h" + +namespace smt { + +#ifdef SINGLE_THREAD + + +#else + theory_sls::theory_sls(smt::context& ctx): + theory(ctx, ctx.get_manager().mk_family_id("sls")) + {} + + theory_sls::~theory_sls() { + finalize(); + } + + params_ref theory_sls::get_params() { + return ctx.get_params(); + } + + void theory_sls::initialize_value(expr* t, expr* v) { + ctx.user_propagate_initialize_value(t, v); + } + + void theory_sls::force_phase(sat::literal lit) { + ctx.force_phase(lit); + } + + void theory_sls::set_has_new_best_phase(bool b) { + + } + + bool theory_sls::get_best_phase(sat::bool_var v) { + return false; + } + + expr* theory_sls::bool_var2expr(sat::bool_var v) { + return ctx.bool_var2expr(v); + } + + void theory_sls::set_finished() { + m.limit().cancel(); + } + + unsigned theory_sls::get_num_bool_vars() const { + return ctx.get_num_bool_vars(); + } + + void theory_sls::finalize() { + if (!m_smt_plugin) + return; + + m_smt_plugin->finalize(m_model, m_st); + m_model = nullptr; + m_smt_plugin = nullptr; + } + + void theory_sls::propagate() { + if (m_smt_plugin && !m_checking) { + expr_ref_vector fmls(m); + for (unsigned i = 0; i < ctx.get_num_asserted_formulas(); ++i) + fmls.push_back(ctx.get_asserted_formula(i)); + m_checking = true; + m_smt_plugin->check(fmls); + return; + } + if (!m_smt_plugin) + return; + if (!m_smt_plugin->completed()) + return; + m_smt_plugin->finalize(m_model, m_st); + m_smt_plugin = nullptr; + } + + void theory_sls::pop_scope_eh(unsigned n) { + if (!m_smt_plugin) + return; + + unsigned scope_lvl = ctx.get_scope_level(); + if (ctx.get_search_level() == scope_lvl - n) { + auto& lits = ctx.assigned_literals(); + for (; m_trail_lim < lits.size() && ctx.get_assign_level(lits[m_trail_lim]) == scope_lvl; ++m_trail_lim) { + auto lit = lits[m_trail_lim]; + m_smt_plugin->add_unit(lit); + } + } +#if 0 + if (ctx.has_new_best_phase()) + m_smt_plugin->import_phase_from_smt(); + +#endif + + m_smt_plugin->import_from_sls(); + } + + + void theory_sls::init() { + if (m_smt_plugin) + finalize(); + m_smt_plugin = alloc(sls::smt_plugin, *this); + m_checking = false; + } + + void theory_sls::collect_statistics(::statistics& st) const { + st.copy(m_st); + } + + void theory_sls::display(std::ostream& out) const { + out << "theory-sls\n"; + } + + + +#endif +} diff --git a/src/smt/theory_sls.h b/src/smt/theory_sls.h new file mode 100644 index 000000000..9174d413c --- /dev/null +++ b/src/smt/theory_sls.h @@ -0,0 +1,79 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + theory_sls + +Abstract: + + Interface to Concurrent SLS solver + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-21 + +--*/ +#pragma once + + +#include "util/rlimit.h" +#include "ast/sls/sat_ddfw.h" +#include "smt/smt_theory.h" +#include "model/model.h" + + +#ifdef SINGLE_THREAD + + +namespace sls { + + +} + +#else + +#include "ast/sls/sls_smt_plugin.h" +class context; + +namespace smt { + class theory_sls : public smt::theory, public sls::smt_context { + model_ref m_model; + sls::smt_plugin* m_smt_plugin = nullptr; + unsigned m_trail_lim = 0; + bool m_checking = false; + ::statistics m_st; + + void finalize(); + + public: + theory_sls(context& ctx); + ~theory_sls() override; + model_ref get_model() { return m_model; } + char const* get_name() const override { return "sls"; } + void init() override; + void pop_scope_eh(unsigned n) override; + smt::theory* mk_fresh(context* new_ctx) override { return alloc(theory_sls, *new_ctx); } + void collect_statistics(::statistics& st) const override; + void propagate() override; + void display(std::ostream& out) const override; + bool internalize_atom(app * atom, bool gate_ctx) override { return false; } + bool internalize_term(app* term) override { return false; } + void new_eq_eh(theory_var v1, theory_var v2) override {} + void new_diseq_eh(theory_var v1, theory_var v2) override {} + + ast_manager& get_manager() override { return m; } + params_ref get_params() override; + void initialize_value(expr* t, expr* v) override; + void force_phase(sat::literal lit) override; + void set_has_new_best_phase(bool b) override; + bool get_best_phase(sat::bool_var v) override; + expr* bool_var2expr(sat::bool_var v) override; + void set_finished() override; + unsigned get_num_bool_vars() const override; + + }; + +} + +#endif