diff --git a/src/ast/sls/bv_sls.cpp b/src/ast/sls/bv_sls.cpp index 9af87d3c5..c0972349b 100644 --- a/src/ast/sls/bv_sls.cpp +++ b/src/ast/sls/bv_sls.cpp @@ -61,6 +61,17 @@ namespace bv { } } + + void sls::set_model() { + if (!m_set_model) + return; + if (m_repair_roots.size() >= m_min_repair_size) + return; + m_min_repair_size = m_repair_roots.size(); + IF_VERBOSE(2, verbose_stream() << "(sls-update-model :num-unsat " << m_min_repair_size << ")\n"); + m_set_model(*get_model()); + } + void sls::init_repair_goal(app* t) { m_eval.init_eval(t); } @@ -94,6 +105,9 @@ namespace bv { if (m_to_repair.empty()) return; + // refresh the best model so far to a callback + set_model(); + // add fresh units, if any bool new_assertion = false; while (m_get_unit) { @@ -130,7 +144,7 @@ namespace bv { return m_rand() % 2 == 0; }; m_eval.init_eval(m_terms.assertions(), eval); - init_repair(); + init_repair(); // m_engine_init = false; } @@ -295,10 +309,12 @@ namespace bv { model_ref mdl = alloc(model, m); auto& terms = m_eval.sort_assertions(m_terms.assertions()); for (expr* e : terms) { +#if 0 if (!m_eval.re_eval_is_correct(to_app(e))) { verbose_stream() << "missed evaluation #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n"; m_eval.display_value(verbose_stream(), e) << "\n"; } +#endif if (!is_uninterp_const(e)) continue; diff --git a/src/ast/sls/bv_sls.h b/src/ast/sls/bv_sls.h index 690b618bf..987cebcdb 100644 --- a/src/ast/sls/bv_sls.h +++ b/src/ast/sls/bv_sls.h @@ -54,10 +54,13 @@ namespace bv { bool m_engine_model = false; bool m_engine_init = false; std::function m_get_unit; + std::function m_set_model; + unsigned m_min_repair_size = UINT_MAX; std::pair next_to_repair(); void init_repair_goal(app* e); + void set_model(); void try_repair_down(app* e); void try_repair_up(app* e); void set_repair_down(expr* e) { m_repair_down = e->get_id(); } @@ -96,6 +99,11 @@ namespace bv { */ void init_unit(std::function get_unit) { m_get_unit = get_unit; } + /** + * Add callback to set model + */ + void set_model(std::function f) { m_set_model = f; } + /** * Run (bounded) local search to find feasible assignments. */ diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index f4491896b..459b26339 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -1069,7 +1069,7 @@ namespace intblast { if (e->get_family_id() != bv.get_family_id()) return false; for (euf::enode* arg : euf::enode_args(n)) - dep.add(n, arg->get_root()); + dep.add(n, arg); return true; } diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index e12ff5ba7..a507619ee 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -22,31 +22,37 @@ Author: namespace sls { +#ifdef SINGLE_THREAD + + solver::solver(euf::solver& ctx) : + th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")) + {} + +#else solver::solver(euf::solver& ctx): - th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")), - m_units(m) {} + th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")) + {} solver::~solver() { finalize(); } void solver::finalize() { - if (!m_completed && m_bvsls) { - m_bvsls->cancel(); + if (!m_completed && m_sls) { + m_sls->cancel(); m_thread.join(); - m_bvsls->collect_statistics(m_st); - m_bvsls = nullptr; + m_sls->collect_statistics(m_st); + m_sls = nullptr; + m_shared = nullptr; + m_slsm = nullptr; + m_units = nullptr; } } sat::check_result solver::check() { - return sat::check_result::CR_DONE; } - void solver::simplify() { - } - bool solver::unit_propagate() { force_push(); sample_local_search(); @@ -66,10 +72,6 @@ namespace sls { return false; } - void solver::push_core() { - - } - void solver::pop_core(unsigned n) { for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) { auto lit = s().trail_literal(m_trail_lim); @@ -77,60 +79,63 @@ namespace sls { if (is_unit(e)) { // IF_VERBOSE(1, verbose_stream() << "add unit " << mk_pp(e, m) << "\n"); std::lock_guard lock(m_mutex); - m_units.push_back(e); + ast_translation tr(m, *m_shared); + m_units->push_back(tr(e.get())); m_has_units = true; } } - } - - void solver::init_search() { - init_local_search(); - } + } - void solver::init_local_search() { - if (m_bvsls) { - m_bvsls->cancel(); + void solver::init_search() { + if (m_sls) { + m_sls->cancel(); m_thread.join(); m_result = l_undef; m_completed = false; m_has_units = false; m_model = nullptr; - m_units.reset(); + m_units = nullptr; } // set up state for local search solver here - m_m = alloc(ast_manager, m); - ast_translation tr(m, *m_m); + m_shared = alloc(ast_manager); + m_slsm = alloc(ast_manager); + m_units = alloc(expr_ref_vector, *m_shared); + ast_translation tr(m, *m_slsm); - params_ref p; m_completed = false; m_result = l_undef; m_model = nullptr; - m_bvsls = alloc(bv::sls, *m_m, p); + m_sls = alloc(bv::sls, *m_slsm, s().params()); for (expr* a : ctx.get_assertions()) - m_bvsls->assert_expr(tr(a)); + m_sls->assert_expr(tr(a)); std::function eval = [&](expr* e, unsigned r) { return false; }; - m_bvsls->init(); - m_bvsls->init_eval(eval); - m_bvsls->updt_params(s().params()); - m_bvsls->init_unit([&]() { + m_sls->init(); + m_sls->init_eval(eval); + m_sls->updt_params(s().params()); + m_sls->init_unit([&]() { if (!m_has_units) - return expr_ref(*m_m); - expr_ref e(m); + return expr_ref(*m_slsm); + expr_ref e(*m_slsm); { std::lock_guard lock(m_mutex); - if (m_units.empty()) - return expr_ref(*m_m); - e = m_units.back(); - m_units.pop_back(); + if (m_units->empty()) + return expr_ref(*m_slsm); + ast_translation tr(*m_shared, *m_slsm); + e = tr(m_units->back()); + m_units->pop_back(); } - ast_translation tr(m, *m_m); - return expr_ref(tr(e.get()), *m_m); + return e; + }); + m_sls->set_model([&](model& mdl) { + std::lock_guard lock(m_mutex); + ast_translation tr(*m_shared, m); + m_model = mdl.translate(tr); }); m_thread = std::thread([this]() { run_local_search(); }); @@ -141,20 +146,21 @@ namespace sls { return; m_thread.join(); m_completed = false; - m_bvsls->collect_statistics(m_st); + m_sls->collect_statistics(m_st); if (m_result == l_true) { IF_VERBOSE(2, verbose_stream() << "(sat.sls :model-completed)\n";); - auto mdl = m_bvsls->get_model(); - ast_translation tr(*m_m, m); + auto mdl = m_sls->get_model(); + ast_translation tr(*m_slsm, m); m_model = mdl->translate(tr); s().set_canceled(); } - m_bvsls = nullptr; + m_sls = nullptr; } void solver::run_local_search() { - m_result = (*m_bvsls)(); + m_result = (*m_sls)(); m_completed = true; } +#endif } diff --git a/src/sat/smt/sls_solver.h b/src/sat/smt/sls_solver.h index 5a6c9950b..e1d8a95b5 100644 --- a/src/sat/smt/sls_solver.h +++ b/src/sat/smt/sls_solver.h @@ -16,13 +16,45 @@ Author: --*/ #pragma once -#include -#include + #include "util/rlimit.h" #include "ast/sls/bv_sls.h" #include "sat/smt/sat_th.h" +#ifdef SINGLE_THREAD + + +namespace euf { + class solver; +} + +namespace sls { + + class solver : public euf::th_euf_solver { + public: + solver(euf::solver& ctx); + + sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; } + void internalize(expr* e) override { UNREACHABLE(); } + th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); } + + model_ref get_model() { return model_ref(nullptr); } + bool unit_propagate() override { return false; } + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override { UNREACHABLE(); } + sat::check_result check() override { return sat::check_result::CR_DONE;} + std::ostream& display(std::ostream& out) const override { return out; } + std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; } + std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; } + + }; +} + +#else + +#include +#include + namespace euf { class solver; } @@ -34,38 +66,36 @@ namespace sls { std::atomic m_completed, m_has_units; std::thread m_thread; std::mutex m_mutex; - scoped_ptr m_m; - scoped_ptr m_bvsls; + // m is accessed by the main thread + // m_slsm is accessed by the sls thread + // m_shared is only accessed at synchronization points + scoped_ptr m_shared, m_slsm; + scoped_ptr m_sls; + scoped_ptr m_units; model_ref m_model; unsigned m_trail_lim = 0; - expr_ref_vector m_units; statistics m_st; void run_local_search(); - void init_local_search(); void sample_local_search(); - bool is_unit(expr*); + public: solver(euf::solver& ctx); ~solver(); - void simplify() override; - void init_search() override; + model_ref get_model() { return m_model; } - void push_core() override; + void init_search() override; + void push_core() override {} void pop_core(unsigned n) override; + th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); } + void collect_statistics(statistics& st) const override { st.copy(m_st); } + void finalize() override; + bool unit_propagate() override; sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; } void internalize(expr* e) override { UNREACHABLE(); } - th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); } - void collect_statistics(statistics& st) const override { st.copy(m_st); } - - model_ref get_model() { return m_model; } - - void finalize() override; - - bool unit_propagate() override; void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override { UNREACHABLE(); } sat::check_result check() override; std::ostream & display(std::ostream & out) const override { return out; } @@ -75,3 +105,5 @@ namespace sls { }; } + +#endif \ No newline at end of file