diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt index 92094ba3e..1703a39ac 100644 --- a/src/ast/sls/CMakeLists.txt +++ b/src/ast/sls/CMakeLists.txt @@ -4,9 +4,11 @@ z3_add_component(ast_sls bv_sls.cpp bv_sls_eval.cpp bv_sls_fixed.cpp - bv_sls_terms.cpp + bv_sls_terms.cpp + sat_ddfw.cpp sls_arith_base.cpp sls_arith_plugin.cpp + sls_bv.cpp sls_cc.cpp sls_engine.cpp sls_smt.cpp diff --git a/src/sat/sat_ddfw.cpp b/src/ast/sls/sat_ddfw.cpp similarity index 90% rename from src/sat/sat_ddfw.cpp rename to src/ast/sls/sat_ddfw.cpp index 52f17887d..21820c8df 100644 --- a/src/sat/sat_ddfw.cpp +++ b/src/ast/sls/sat_ddfw.cpp @@ -26,18 +26,18 @@ --*/ #include "util/luby.h" -#include "sat/sat_ddfw.h" -#include "sat/sat_solver.h" +#include "util/trace.h" +#include "ast/sls/sat_ddfw.h" #include "params/sat_params.hpp" + namespace sat { ddfw::~ddfw() { } - lbool ddfw::check(unsigned sz, literal const* assumptions, parallel* p) { - init(sz, assumptions); - flet _p(m_par, p); + lbool ddfw::check(unsigned sz, literal const* assumptions) { + init(sz, assumptions); if (m_plugin) check_with_plugin(); else @@ -52,7 +52,7 @@ namespace sat { if (should_reinit_weights()) do_reinit_weights(); else if (do_flip()); else if (should_restart()) do_restart(); - else if (should_parallel_sync()) do_parallel_sync(); + else if (m_parallel_sync && m_parallel_sync()); else shift_weights(); } } @@ -78,7 +78,6 @@ namespace sat { double kflips_per_sec = sec > 0 ? (m_flips - m_last_flips) / (1000.0 * sec) : 0.0; if (m_last_flips == 0) { IF_VERBOSE(1, verbose_stream() << "(sat.ddfw :unsat :models :kflips/sec :flips :restarts :reinits :unsat_vars :shifts"; - if (m_par) verbose_stream() << " :par"; verbose_stream() << ")\n"); } IF_VERBOSE(1, verbose_stream() << "(sat.ddfw " @@ -90,7 +89,6 @@ namespace sat { << std::setw(11) << m_reinit_count << std::setw(13) << m_unsat_vars.size() << std::setw(9) << m_shifts; - if (m_par) verbose_stream() << std::setw(10) << m_parsync_count; verbose_stream() << ")\n"); m_stopwatch.start(); m_last_flips = m_flips; @@ -151,6 +149,8 @@ namespace sat { void ddfw::add(unsigned n, literal const* c) { unsigned idx = m_clauses.size(); m_clauses.push_back(clause_info(n, c, m_config.m_init_clause_weight)); + if (n > 2) + ++m_num_non_binary_clauses; for (literal lit : m_clauses.back().m_clause) { m_use_list.reserve(2*(lit.var()+1)); m_vars.reserve(lit.var()+1); @@ -177,35 +177,6 @@ namespace sat { m_unsat.remove(m_clauses.size()); } - void ddfw::add(solver const& s) { - m_clauses.reset(); - m_use_list.reset(); - m_num_non_binary_clauses = 0; - - unsigned trail_sz = s.init_trail_size(); - for (unsigned i = 0; i < trail_sz; ++i) { - add(1, s.m_trail.data() + i); - } - unsigned sz = s.m_watches.size(); - for (unsigned l_idx = 0; l_idx < sz; ++l_idx) { - literal l1 = ~to_literal(l_idx); - watch_list const & wlist = s.m_watches[l_idx]; - for (watched const& w : wlist) { - if (!w.is_binary_non_learned_clause()) - continue; - literal l2 = w.get_literal(); - if (l1.index() > l2.index()) - continue; - literal ls[2] = { l1, l2 }; - add(2, ls); - } - } - for (clause* c : s.m_clauses) { - add(c->size(), c->begin()); - } - m_num_non_binary_clauses = s.m_clauses.size(); - } - void ddfw::add_assumptions() { for (unsigned i = 0; i < m_assumptions.size(); ++i) add(1, m_assumptions.data() + i); @@ -236,8 +207,10 @@ namespace sat { m_restart_count = 0; m_restart_next = m_config.m_restart_base*2; +#if 0 m_parsync_count = 0; m_parsync_next = m_config.m_parsync_base; +#endif m_min_sz = m_unsat.size(); m_flips = 0; @@ -246,18 +219,6 @@ namespace sat { m_stopwatch.start(); } - void ddfw::reinit(solver& s, bool_vector const& phase) { - add(s); - add_assumptions(); - for (unsigned v = 0; v < phase.size(); ++v) { - value(v) = phase[v]; - reward(v) = 0; - make_count(v) = 0; - } - init_clause_data(); - flatten_use_list(); - } - void ddfw::reinit() { add_assumptions(); init_clause_data(); @@ -414,25 +375,12 @@ namespace sat { } } - bool ddfw::should_parallel_sync() { - return m_par != nullptr && m_flips >= m_parsync_next; - } - void ddfw::save_priorities() { m_probs.reset(); for (unsigned v = 0; v < num_vars(); ++v) m_probs.push_back(-m_vars[v].m_reward_avg); } - void ddfw::do_parallel_sync() { - if (m_par->from_solver(*this)) - m_par->to_solver(*this); - - ++m_parsync_count; - m_parsync_next *= 3; - m_parsync_next /= 2; - } - void ddfw::save_model() { m_model.reserve(num_vars()); for (unsigned i = 0; i < num_vars(); ++i) diff --git a/src/sat/sat_ddfw.h b/src/ast/sls/sat_ddfw.h similarity index 85% rename from src/sat/sat_ddfw.h rename to src/ast/sls/sat_ddfw.h index 2580e3bab..d74da3d54 100644 --- a/src/sat/sat_ddfw.h +++ b/src/ast/sls/sat_ddfw.h @@ -26,12 +26,12 @@ #include "util/ema.h" #include "util/sat_sls.h" #include "util/map.h" -#include "sat/sat_types.h" +#include "util/sat_literal.h" +#include "util/statistics.h" +#include "util/stopwatch.h" namespace sat { - class solver; - class parallel; class local_search_plugin { public: @@ -43,8 +43,9 @@ namespace sat { virtual void on_save_model() = 0; virtual void on_restart() = 0; }; - - class ddfw : public i_local_search { + + class ddfw { + friend class ddfw_wrapper; protected: struct config { @@ -86,7 +87,7 @@ namespace sat { svector m_vars; // var -> info svector m_probs; // var -> probability of flipping svector m_scores; // reward -> score - model m_model; // var -> best assignment + svector m_model; // var -> best assignment unsigned m_init_weight = 2; vector m_use_list; @@ -97,15 +98,15 @@ namespace sat { indexed_uint_set m_unsat_vars; // set of variables that are in unsat clauses random_gen m_rand; unsigned m_num_non_binary_clauses = 0; - unsigned m_restart_count = 0, m_reinit_count = 0, m_parsync_count = 0; - uint64_t m_restart_next = 0, m_reinit_next = 0, m_parsync_next = 0; + unsigned m_restart_count = 0, m_reinit_count = 0; + uint64_t m_restart_next = 0, m_reinit_next = 0; uint64_t m_flips = 0, m_last_flips = 0, m_shifts = 0; unsigned m_min_sz = 0, m_steps_since_progress = 0; u_map m_models; stopwatch m_stopwatch; - parallel* m_par; scoped_ptr m_plugin = nullptr; + std::function m_parallel_sync; void flatten_use_list(); @@ -191,11 +192,7 @@ namespace sat { void do_restart(); void reinit_values(); - unsigned select_random_true_clause(); - - // parallel integration - bool should_parallel_sync(); - void do_parallel_sync(); + unsigned select_random_true_clause(); void log(); @@ -205,8 +202,6 @@ namespace sat { void invariant(); - - void del(); void add_assumptions(); @@ -217,35 +212,33 @@ namespace sat { public: - ddfw(): m_par(nullptr) {} + ddfw() {} - ~ddfw() override; + ~ddfw(); void set_plugin(local_search_plugin* p) { m_plugin = p; } - lbool check(unsigned sz, literal const* assumptions, parallel* p) override; + lbool check(unsigned sz, literal const* assumptions); - void updt_params(params_ref const& p) override; + void updt_params(params_ref const& p); - model const& get_model() const override { return m_model; } + svector const& get_model() const { return m_model; } - reslimit& rlimit() override { return m_limit; } + reslimit& rlimit() { return m_limit; } - void set_seed(unsigned n) override { m_rand.set_seed(n); } + void set_seed(unsigned n) { m_rand.set_seed(n); } - void add(solver const& s) override; - bool get_value(bool_var v) const override { return value(v); } + bool get_value(bool_var v) const { return value(v); } std::ostream& display(std::ostream& out) const; // for parallel integration - unsigned num_non_binary_clauses() const override { return m_num_non_binary_clauses; } - void reinit(solver& s, bool_vector const& phase) override; + unsigned num_non_binary_clauses() const { return m_num_non_binary_clauses; } - void collect_statistics(statistics& st) const override {} + void collect_statistics(statistics& st) const {} - double get_priority(bool_var v) const override { return m_probs[v]; } + double get_priority(bool_var v) const { return m_probs[v]; } // access clause information and state of Boolean search indexed_uint_set& unsat_set() { return m_unsat; } diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index ca4a95e41..1c6e7e404 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -324,6 +324,11 @@ namespace sls { SASSERT(dtt(sign(bv), ineq) == 0); } vi.m_value = new_value; + if (vi.m_shared) { + sort* s = vi.m_sort == var_sort::INT ? a.mk_int() : a.mk_real(); + expr_ref num = from_num(s, new_value); + ctx.set_value(vi.m_expr, num); + } for (auto idx : vi.m_muls) { auto const& [v, monomial] = m_muls[idx]; num_t prod(1); @@ -380,6 +385,20 @@ namespace sls { return false; } + + expr_ref arith_base::from_num(sort* s, rational const& n) { + return expr_ref(a.mk_numeral(n, s), m); + } + + expr_ref arith_base>::from_num(sort* s, checked_int64 const& n) { + return expr_ref(a.mk_numeral(rational(n.get_int64(), rational::i64()), s), m); + } + + template + expr_ref arith_base::from_num(sort* s, num_t const& n) { + return expr_ref(m); + } + template void arith_base::add_args(linear_term& term, expr* e, num_t const& coeff) { auto v = m_expr2var.get(e->get_id(), UINT_MAX); @@ -444,15 +463,12 @@ namespace sls { else if (a.is_to_int(e, x)) add_arg(term, coeff, mk_op(arith_op_kind::OP_TO_INT, e, x, x)); else if (a.is_to_real(e, x)) - add_arg(term, coeff, mk_op(arith_op_kind::OP_TO_REAL, e, x, x)); - else if (is_uninterp(e)) - add_arg(term, coeff, mk_var(e)); + add_arg(term, coeff, mk_op(arith_op_kind::OP_TO_REAL, e, x, x)); else if (a.is_arith_expr(e)) { NOT_IMPLEMENTED_YET(); } - else { - NOT_IMPLEMENTED_YET(); - } + else + add_arg(term, coeff, mk_var(e)); } template @@ -950,6 +966,29 @@ namespace sls { void arith_base::register_term(expr* e) { } + template + void arith_base::set_shared(expr* e) { + if (!a.is_int_real(e)) + return; + var_t v = m_expr2var.get(e->get_id(), UINT_MAX); + if (v == UINT_MAX) + v = mk_term(e); + m_vars[v].m_shared = true; + } + + template + void arith_base::set_value(expr* e, expr* v) { + auto w = m_expr2var.get(e->get_id(), UINT_MAX); + if (w == UINT_MAX) + return; + num_t n; + if (!is_num(v, n)) + return; + if (n == value(w)) + return; + update(w, n); + } + template expr_ref arith_base::get_value(expr* e) { auto v = mk_var(e); diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index cd71a3911..135e8e2e7 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -90,6 +90,7 @@ namespace sls { expr* m_expr; num_t m_value{ 0 }; num_t m_best_value{ 0 }; + bool m_shared = false; var_sort m_sort; arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP; unsigned m_def_idx = UINT_MAX; @@ -147,9 +148,7 @@ namespace sls { double reward(sat::literal lit); bool sign(sat::bool_var v) const { return !ctx.is_true(sat::literal(v, false)); } - ineq* atom(sat::bool_var bv) const { return m_bool_vars.get(bv, nullptr); } - - + ineq* atom(sat::bool_var bv) const { return m_bool_vars.get(bv, nullptr); } num_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); } num_t dtt(bool sign, num_t const& args_value, ineq const& ineq) const; num_t dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const; @@ -178,19 +177,19 @@ namespace sls { num_t value(var_t v) const { return m_vars[v].m_value; } bool is_num(expr* e, num_t& i); - + expr_ref from_num(sort* s, num_t const& n); void check_ineqs(); - public: arith_base(context& ctx); ~arith_base() override {} void init_bool_var(sat::bool_var v) override; void register_term(expr* e) override; + void set_shared(expr* e) override; + void set_value(expr* e, expr* v) override; expr_ref get_value(expr* e) override; lbool check() override; bool is_sat() override; void reset() override; - void on_rescale() override; void on_restart() override; std::ostream& display(std::ostream& out) const override; diff --git a/src/ast/sls/sls_arith_plugin.cpp b/src/ast/sls/sls_arith_plugin.cpp index 23b657192..e8d237fb0 100644 --- a/src/ast/sls/sls_arith_plugin.cpp +++ b/src/ast/sls/sls_arith_plugin.cpp @@ -29,6 +29,8 @@ namespace sls { } catch (overflow_exception&) { m_arith = alloc(arith_base, ctx); + for (auto e : m_shared) + m_arith->set_shared(e); return; // initialization happens on check-sat calls } } @@ -44,6 +46,8 @@ namespace sls { } catch (overflow_exception&) { m_arith = alloc(arith_base, ctx); + for (auto e : m_shared) + m_arith->set_shared(e); } } m_arith->register_term(e); @@ -56,6 +60,8 @@ namespace sls { } catch (overflow_exception&) { m_arith = alloc(arith_base, ctx); + for (auto e : m_shared) + m_arith->set_shared(e); } } return m_arith->get_value(e); @@ -68,6 +74,8 @@ namespace sls { } catch (overflow_exception&) { m_arith = alloc(arith_base, ctx); + for (auto e : m_shared) + m_arith->set_shared(e); } } return m_arith->check(); @@ -79,35 +87,54 @@ namespace sls { return m_arith->is_sat(); } void arith_plugin::reset() { - if (!m_arith) - m_arith64->reset(); - else + if (m_arith) m_arith->reset(); + else + m_arith64->reset(); + m_shared.reset(); } void arith_plugin::on_rescale() { - if (!m_arith) - m_arith64->on_rescale(); - else + if (m_arith) m_arith->on_rescale(); - } - void arith_plugin::on_restart() { - if (!m_arith) - m_arith64->on_restart(); else - m_arith->on_restart(); + m_arith64->on_rescale(); + } + + void arith_plugin::on_restart() { + if (m_arith) + m_arith->on_restart(); + else + m_arith64->on_restart(); } std::ostream& arith_plugin::display(std::ostream& out) const { - if (!m_arith) - return m_arith64->display(out); - return m_arith->display(out); + if (m_arith) + return m_arith->display(out); + else + return m_arith64->display(out); } void arith_plugin::mk_model(model& mdl) { - if (!m_arith) - m_arith64->mk_model(mdl); - else + if (m_arith) m_arith->mk_model(mdl); + else + m_arith64->mk_model(mdl); + } + + void arith_plugin::set_shared(expr* e) { + if (m_arith) + m_arith->set_shared(e); + else { + m_arith64->set_shared(e); + m_shared.push_back(e); + } + } + + void arith_plugin::set_value(expr* e, expr* v) { + if (m_arith) + m_arith->set_value(e, v); + else + m_arith->set_value(e, v); } } diff --git a/src/ast/sls/sls_arith_plugin.h b/src/ast/sls/sls_arith_plugin.h index 494a20b9b..1686cf3b2 100644 --- a/src/ast/sls/sls_arith_plugin.h +++ b/src/ast/sls/sls_arith_plugin.h @@ -24,8 +24,12 @@ namespace sls { class arith_plugin : public plugin { scoped_ptr>> m_arith64; scoped_ptr> m_arith; + expr_ref_vector m_shared; public: - arith_plugin(context& ctx) : plugin(ctx) { m_arith64 = alloc(arith_base>,ctx); } + arith_plugin(context& ctx) : + plugin(ctx), m_shared(ctx.get_manager()) { + m_arith64 = alloc(arith_base>,ctx); + } ~arith_plugin() override {} void init_bool_var(sat::bool_var v) override; void register_term(expr* e) override; @@ -38,6 +42,8 @@ namespace sls { void on_restart() override; std::ostream& display(std::ostream& out) const override; void mk_model(model& mdl) override; + void set_shared(expr* e) override; + void set_value(expr* e, expr* v) override; }; } diff --git a/src/ast/sls/sls_bv.cpp b/src/ast/sls/sls_bv.cpp new file mode 100644 index 000000000..1a6cf04cd --- /dev/null +++ b/src/ast/sls/sls_bv.cpp @@ -0,0 +1,93 @@ + +#include "ast/sls/sls_bv.h" + +namespace sls { + + bv_plugin::bv_plugin(context& ctx): + plugin(ctx), + bv(m), + m_terms(m), + m_eval(m) + {} + + void bv_plugin::init_bool_var(sat::bool_var v) { + } + + void bv_plugin::register_term(expr* e) { + } + + expr_ref bv_plugin::get_value(expr* e) { + return expr_ref(m); + } + + lbool bv_plugin::check() { + return l_undef; + } + + bool bv_plugin::is_sat() { + return false; + } + + void bv_plugin::reset() { + } + + void bv_plugin::on_rescale() { + + } + + void bv_plugin::on_restart() { + } + + std::ostream& bv_plugin::display(std::ostream& out) const { + return out; + } + + void bv_plugin::mk_model(model& mdl) { + + } + + void bv_plugin::set_shared(expr* e) { + + } + + void bv_plugin::set_value(expr* e, expr* v) { + + } + + std::pair bv_plugin::next_to_repair() { + app* e = nullptr; + if (m_repair_down != UINT_MAX) { + e = m_terms.term(m_repair_down); + m_repair_down = UINT_MAX; + return { true, e }; + } + + if (!m_repair_up.empty()) { + unsigned index = m_repair_up.elem_at(ctx.rand(m_repair_up.size())); + m_repair_up.remove(index); + e = m_terms.term(index); + return { false, e }; + } + + while (!m_repair_roots.empty()) { + unsigned index = m_repair_roots.elem_at(ctx.rand(m_repair_roots.size())); + e = m_terms.term(index); + if (m_terms.is_assertion(e) && !m_eval.bval1(e)) { + SASSERT(m_eval.bval0(e)); + return { true, e }; + } + if (!m_eval.re_eval_is_correct(e)) { + init_repair_goal(e); + return { true, e }; + } + m_repair_roots.remove(index); + } + + return { false, nullptr }; + } + + void bv_plugin::init_repair_goal(app* e) { + m_eval.init_eval(e); + } + +} diff --git a/src/ast/sls/sls_bv.h b/src/ast/sls/sls_bv.h new file mode 100644 index 000000000..c1ad0464a --- /dev/null +++ b/src/ast/sls/sls_bv.h @@ -0,0 +1,55 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_bv.h + +Abstract: + + Theory plugin for bit-vector local search + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-06 + +--*/ +#pragma once + +#include "ast/sls/sls_smt.h" +#include "ast/bv_decl_plugin.h" +#include "ast/sls/bv_sls_terms.h" +#include "ast/sls/bv_sls_eval.h" + +namespace sls { + + class bv_plugin : public plugin { + bv_util bv; + bv::sls_terms m_terms; + bv::sls_eval m_eval; + bv::sls_stats m_stats; + + indexed_uint_set m_repair_up, m_repair_roots; + unsigned m_repair_down = UINT_MAX; + + std::pair next_to_repair(); + void init_repair_goal(app* e); + public: + bv_plugin(context& ctx); + ~bv_plugin() override {} + void init_bool_var(sat::bool_var v) override; + void register_term(expr* e) override; + expr_ref get_value(expr* e) override; + lbool check() override; + bool is_sat() override; + void reset() override; + + void on_rescale() override; + void on_restart() override; + std::ostream& display(std::ostream& out) const override; + void mk_model(model& mdl) override; + void set_shared(expr* e) override; + void set_value(expr* e, expr* v) override; + }; + +} diff --git a/src/ast/sls/sls_cc.h b/src/ast/sls/sls_cc.h index c68671909..06204bb28 100644 --- a/src/ast/sls/sls_cc.h +++ b/src/ast/sls/sls_cc.h @@ -46,6 +46,8 @@ namespace sls { void init_bool_var(sat::bool_var v) override {} std::ostream& display(std::ostream& out) const override; void mk_model(model& mdl) override; + void set_value(expr* e, expr* v) override {} + void set_shared(expr* e) override {} }; } diff --git a/src/ast/sls/sls_smt.cpp b/src/ast/sls/sls_smt.cpp index fa696b3f7..4bbd7c136 100644 --- a/src/ast/sls/sls_smt.cpp +++ b/src/ast/sls/sls_smt.cpp @@ -101,7 +101,9 @@ namespace sls { } void context::set_value(expr* e, expr* v) { - NOT_IMPLEMENTED_YET(); + for (auto p : m_plugins) + if (p) + p->set_value(e, v); } bool context::is_relevant(expr* e) { diff --git a/src/ast/sls/sls_smt.h b/src/ast/sls/sls_smt.h index 23565a1e5..3f555300c 100644 --- a/src/ast/sls/sls_smt.h +++ b/src/ast/sls/sls_smt.h @@ -46,6 +46,8 @@ namespace sls { virtual void on_restart() {}; virtual std::ostream& display(std::ostream& out) const = 0; virtual void mk_model(model& mdl) = 0; + virtual void set_shared(expr* e) = 0; + virtual void set_value(expr* e, expr* v) = 0; }; using clause = std::initializer_list ; @@ -110,6 +112,7 @@ namespace sls { double reward(sat::bool_var v) { return s.reward(v); } indexed_uint_set const& unsat() const { return s.unsat(); } unsigned rand() { return m_rand(); } + unsigned rand(unsigned n) { return m_rand(n); } sat::literal_vector const& root_literals() const { return m_root_literals; } void reinit_relevant(); diff --git a/src/sat/CMakeLists.txt b/src/sat/CMakeLists.txt index 77fabcbcf..48e6959b3 100644 --- a/src/sat/CMakeLists.txt +++ b/src/sat/CMakeLists.txt @@ -15,7 +15,7 @@ z3_add_component(sat sat_config.cpp sat_cut_simplifier.cpp sat_cutset.cpp - sat_ddfw.cpp + sat_ddfw_wrapper.cpp sat_drat.cpp sat_elim_eqs.cpp sat_elim_vars.cpp diff --git a/src/sat/sat_ddfw_wrapper.cpp b/src/sat/sat_ddfw_wrapper.cpp new file mode 100644 index 000000000..39c767073 --- /dev/null +++ b/src/sat/sat_ddfw_wrapper.cpp @@ -0,0 +1,90 @@ +/*++ + Copyright (c) 2019 Microsoft Corporation + + Module Name: + + sat_ddfw_wrapper.cpp + +*/ + +#include "sat/sat_ddfw_wrapper.h" +#include "sat/sat_solver.h" +#include "sat/sat_parallel.h" + +namespace sat { + + lbool ddfw_wrapper::check(unsigned sz, literal const* assumptions, parallel* p) { + flet _p(m_par, p); + m_ddfw.m_parallel_sync = nullptr; + if (m_par) { + m_ddfw.m_parallel_sync = [&]() -> bool { + if (should_parallel_sync()) { + do_parallel_sync(); + return true; + } + else + return false; + }; + } + return m_ddfw.check(sz, assumptions); + } + + bool ddfw_wrapper::should_parallel_sync() { + return m_par != nullptr && m_ddfw.m_flips >= m_parsync_next; + } + + void ddfw_wrapper::do_parallel_sync() { + if (m_par->from_solver(*this)) + m_par->to_solver(*this); + + ++m_parsync_count; + m_parsync_next *= 3; + m_parsync_next /= 2; + } + + + void ddfw_wrapper::reinit(solver& s, bool_vector const& phase) { + add(s); + m_ddfw.add_assumptions(); + for (unsigned v = 0; v < phase.size(); ++v) { + m_ddfw.value(v) = phase[v]; + m_ddfw.reward(v) = 0; + m_ddfw.make_count(v) = 0; + } + m_ddfw.init_clause_data(); + m_ddfw.flatten_use_list(); + } + + void ddfw_wrapper::add(solver const& s) { + m_ddfw.m_clauses.reset(); + m_ddfw.m_use_list.reset(); + m_ddfw.m_num_non_binary_clauses = 0; + + unsigned trail_sz = s.init_trail_size(); + for (unsigned i = 0; i < trail_sz; ++i) { + m_ddfw.add(1, s.m_trail.data() + i); + } + unsigned sz = s.m_watches.size(); + for (unsigned l_idx = 0; l_idx < sz; ++l_idx) { + literal l1 = ~to_literal(l_idx); + watch_list const & wlist = s.m_watches[l_idx]; + for (watched const& w : wlist) { + if (!w.is_binary_non_learned_clause()) + continue; + literal l2 = w.get_literal(); + if (l1.index() > l2.index()) + continue; + literal ls[2] = { l1, l2 }; + m_ddfw.add(2, ls); + } + } + for (clause* c : s.m_clauses) + m_ddfw.add(c->size(), c->begin()); + + } + + + + +} + diff --git a/src/sat/sat_ddfw_wrapper.h b/src/sat/sat_ddfw_wrapper.h new file mode 100644 index 000000000..6c87c72bd --- /dev/null +++ b/src/sat/sat_ddfw_wrapper.h @@ -0,0 +1,89 @@ +/*++ + Copyright (c) 2019 Microsoft Corporation + + Module Name: + + sat_ddfw_wrapper.h + + + --*/ +#pragma once + +#include "util/uint_set.h" +#include "util/rlimit.h" +#include "util/params.h" +#include "util/ema.h" +#include "util/sat_sls.h" +#include "util/map.h" +#include "ast/sls/sat_ddfw.h" +#include "sat/sat_types.h" + +namespace sat { + class solver; + class parallel; + + + class ddfw_wrapper : public i_local_search { + protected: + ddfw m_ddfw; + parallel* m_par = nullptr; + unsigned m_parsync_count = 0; + uint64_t m_parsync_next = 0; + + void do_parallel_sync(); + bool should_parallel_sync(); + + public: + + ddfw_wrapper() {} + + ~ddfw_wrapper() override {} + + void set_plugin(local_search_plugin* p) { m_ddfw.set_plugin(p); } + + lbool check(unsigned sz, literal const* assumptions, parallel* p) override; + + void updt_params(params_ref const& p) override { m_ddfw.updt_params(p); } + + model const& get_model() const override { return m_ddfw.get_model(); } + + reslimit& rlimit() override { return m_ddfw.rlimit(); } + + void set_seed(unsigned n) override { m_ddfw.set_seed(n); } + + void add(solver const& s) override; + + bool get_value(bool_var v) const override { return m_ddfw.get_value(v); } + + std::ostream& display(std::ostream& out) const { return m_ddfw.display(out); } + + // for parallel integration + unsigned num_non_binary_clauses() const override { return m_ddfw.num_non_binary_clauses(); } + + void reinit(solver& s, bool_vector const& phase) override; + + void collect_statistics(statistics& st) const override {} + + double get_priority(bool_var v) const override { return m_ddfw.get_priority(v); } + + // access clause information and state of Boolean search + indexed_uint_set& unsat_set() { return m_ddfw.unsat_set(); } + + vector const& clauses() const { return m_ddfw.clauses(); } + + clause_info& get_clause_info(unsigned idx) { return m_ddfw.get_clause_info(idx); } + + void remove_assumptions() { m_ddfw.remove_assumptions(); } + + void flip(bool_var v) { m_ddfw.flip(v); } + + inline double get_reward(bool_var v) const { return m_ddfw.get_reward(v); } + + void add(unsigned sz, literal const* c) { m_ddfw.add(sz, c); } + + void reinit() { m_ddfw.reinit(); } + + + }; +} + diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 1cd0c1400..6a70e8bad 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -29,7 +29,7 @@ Revision History: #include "sat/sat_solver.h" #include "sat/sat_integrity_checker.h" #include "sat/sat_lookahead.h" -#include "sat/sat_ddfw.h" +#include "sat/sat_ddfw_wrapper.h" #include "sat/sat_prob.h" #include "sat/sat_anf_simplifier.h" #include "sat/sat_cut_simplifier.h" @@ -1362,7 +1362,7 @@ namespace sat { } literal_vector _lits; scoped_limits scoped_rl(rlimit()); - m_local_search = alloc(ddfw); + m_local_search = alloc(ddfw_wrapper); scoped_ls _ls(*this); SASSERT(m_local_search); m_local_search->add(*this); @@ -1439,7 +1439,7 @@ namespace sat { lbool solver::do_ddfw_search(unsigned num_lits, literal const* lits) { if (m_ext) return l_undef; SASSERT(!m_local_search); - m_local_search = alloc(ddfw); + m_local_search = alloc(ddfw_wrapper); return invoke_local_search(num_lits, lits); } @@ -1480,7 +1480,7 @@ namespace sat { vector lims(num_ddfw); // set up ddfw search for (int i = 0; i < num_ddfw; ++i) { - ddfw* d = alloc(ddfw); + ddfw_wrapper* d = alloc(ddfw_wrapper); d->updt_params(m_params); d->set_seed(m_config.m_random_seed + i); d->add(*this); diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 57477f686..9e7186a34 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -228,7 +228,7 @@ namespace sat { friend class parallel; friend class lookahead; friend class local_search; - friend class ddfw; + friend class ddfw_wrapper; friend class prob; friend class unit_walk; friend struct mk_stat; diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 48530bc83..770e1cf5d 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -28,7 +28,6 @@ Author: #include "math/polynomial/algebraic_numbers.h" #include "math/polynomial/polynomial.h" #include "sat/smt/sat_th.h" -#include "sat/sat_ddfw.h" namespace euf { class solver; diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index 373948014..cc437173e 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -18,7 +18,6 @@ Author: #include "util/top_sort.h" #include "sat/smt/sat_smt.h" -#include "sat/sat_ddfw.h" #include "ast/euf/euf_egraph.h" #include "model/model.h" #include "smt/params/smt_params.h" @@ -139,10 +138,6 @@ namespace euf { virtual euf::enode_pair get_justification_eq(size_t j); - /** - * Local search interface - */ - virtual void set_bool_search(sat::ddfw* ddfw) {} virtual void set_bounds_begin() {} diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index 5028bf239..035385b65 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -201,13 +201,13 @@ namespace sls { void solver::run_local_search_async() { if (m_ddfw) { - m_result = m_ddfw->check(0, nullptr, nullptr); + m_result = m_ddfw->check(0, nullptr); m_completed = true; } } void solver::run_local_search_sync() { - m_result = m_ddfw->check(0, nullptr, nullptr); + m_result = m_ddfw->check(0, nullptr); local_search_done(); } diff --git a/src/sat/smt/sls_solver.h b/src/sat/smt/sls_solver.h index 9d009b805..92a64955c 100644 --- a/src/sat/smt/sls_solver.h +++ b/src/sat/smt/sls_solver.h @@ -20,7 +20,7 @@ Author: #include "util/rlimit.h" #include "ast/sls/bv_sls.h" #include "sat/smt/sat_th.h" -#include "sat/sat_ddfw.h" +#include "ast/sls/sat_ddfw.h" #ifdef SINGLE_THREAD diff --git a/src/sat/tactic/sat_tactic.cpp b/src/sat/tactic/sat_tactic.cpp index 9fe7a6947..dabda88d7 100644 --- a/src/sat/tactic/sat_tactic.cpp +++ b/src/sat/tactic/sat_tactic.cpp @@ -16,13 +16,14 @@ Author: Notes: --*/ +#include "params/sat_params.hpp" #include "ast/ast_pp.h" #include "model/model_v2_pp.h" #include "tactic/tactical.h" #include "sat/tactic/goal2sat.h" #include "sat/tactic/sat2goal.h" #include "sat/sat_solver.h" -#include "params/sat_params.hpp" + class sat_tactic : public tactic { diff --git a/src/tactic/smtlogics/smt_tactic.cpp b/src/tactic/smtlogics/smt_tactic.cpp index 288d728c3..7bae01a81 100644 --- a/src/tactic/smtlogics/smt_tactic.cpp +++ b/src/tactic/smtlogics/smt_tactic.cpp @@ -15,11 +15,11 @@ Author: --*/ -#include "smt/tactic/smt_tactic_core.h" -#include "sat/tactic/sat_tactic.h" #include "params/sat_params.hpp" #include "solver/solver2tactic.h" #include "solver/solver.h" +#include "smt/tactic/smt_tactic_core.h" +#include "sat/tactic/sat_tactic.h" tactic * mk_smt_tactic(ast_manager & m, params_ref const & p) { sat_params sp(p);