From 96d815b9049e7feba128fa32c27a81aa2c42350c Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 7 Feb 2023 19:27:19 -0800 Subject: [PATCH] adding arith sls --- src/sat/sat_ddfw.cpp | 14 +- src/sat/sat_ddfw.h | 2 +- src/sat/sat_extension.h | 1 + src/sat/sat_local_search.cpp | 11 +- src/sat/sat_local_search.h | 2 +- src/sat/sat_parallel.cpp | 2 +- src/sat/sat_prob.h | 2 +- src/sat/sat_solver.cpp | 2 +- src/sat/sat_types.h | 2 +- src/sat/smt/CMakeLists.txt | 2 + src/sat/smt/arith_diagnostics.cpp | 11 + src/sat/smt/arith_local_search.cpp | 352 +++++++++++++++++++++++++++++ src/sat/smt/arith_solver.cpp | 1 + src/sat/smt/arith_solver.h | 105 +++++++-- src/sat/smt/euf_local_search.cpp | 126 +++++++++++ src/sat/smt/euf_solver.h | 16 +- src/sat/smt/sat_th.h | 9 + 17 files changed, 625 insertions(+), 35 deletions(-) create mode 100644 src/sat/smt/arith_local_search.cpp create mode 100644 src/sat/smt/euf_local_search.cpp diff --git a/src/sat/sat_ddfw.cpp b/src/sat/sat_ddfw.cpp index 418070c64..723b38586 100644 --- a/src/sat/sat_ddfw.cpp +++ b/src/sat/sat_ddfw.cpp @@ -200,16 +200,14 @@ namespace sat { m_shifts = 0; m_stopwatch.start(); } - - void ddfw::reinit(solver& s) { + + void ddfw::reinit(solver& s, bool_vector const& phase) { add(s); add_assumptions(); - if (s.m_best_phase_size > 0) { - for (unsigned v = 0; v < num_vars(); ++v) { - value(v) = s.m_best_phase[v]; - reward(v) = 0; - make_count(v) = 0; - } + for (unsigned v = 0; v < phase.size(); ++v) { + value(v) = phase[v]; + reward(v) = 0; + make_count(v) = 0; } init_clause_data(); flatten_use_list(); diff --git a/src/sat/sat_ddfw.h b/src/sat/sat_ddfw.h index d5e7df773..ce5ff9fdf 100644 --- a/src/sat/sat_ddfw.h +++ b/src/sat/sat_ddfw.h @@ -227,7 +227,7 @@ namespace sat { // for parallel integration unsigned num_non_binary_clauses() const override { return m_num_non_binary_clauses; } - void reinit(solver& s) override; + void reinit(solver& s, bool_vector const& phase) override; void collect_statistics(statistics& st) const override {} diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index d6a956a32..3a1f363a3 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -126,6 +126,7 @@ namespace sat { virtual void add_assumptions(literal_set& ext_assumptions) {} virtual bool tracking_assumptions() { return false; } virtual bool enable_self_propagate() const { return false; } + virtual void local_search(bool_vector& phase) {} virtual bool extract_pb(std::function& card, std::function& pb) { diff --git a/src/sat/sat_local_search.cpp b/src/sat/sat_local_search.cpp index 61ddd13d8..8cc90f05e 100644 --- a/src/sat/sat_local_search.cpp +++ b/src/sat/sat_local_search.cpp @@ -359,13 +359,10 @@ namespace sat { m_par(nullptr) { } - void local_search::reinit(solver& s) { - import(s, true); - if (s.m_best_phase_size > 0) { - for (unsigned i = num_vars(); i-- > 0; ) { - set_phase(i, s.m_best_phase[i]); - } - } + void local_search::reinit(solver& s, bool_vector const& phase) { + import(s, true); + for (unsigned i = phase.size(); i-- > 0; ) + set_phase(i, phase[i]); } void local_search::import(solver const& s, bool _init) { diff --git a/src/sat/sat_local_search.h b/src/sat/sat_local_search.h index e46d4b009..7295b851a 100644 --- a/src/sat/sat_local_search.h +++ b/src/sat/sat_local_search.h @@ -248,7 +248,7 @@ namespace sat { void set_seed(unsigned n) override { config().set_random_seed(n); } - void reinit(solver& s) override; + void reinit(solver& s, bool_vector const& phase) override; // used by unit-walk void set_phase(bool_var v, bool f); diff --git a/src/sat/sat_parallel.cpp b/src/sat/sat_parallel.cpp index 3e493168a..cdb13706f 100644 --- a/src/sat/sat_parallel.cpp +++ b/src/sat/sat_parallel.cpp @@ -252,7 +252,7 @@ namespace sat { m_consumer_ready = true; if (m_solver_copy) { copied = true; - s.reinit(*m_solver_copy.get()); + s.reinit(*m_solver_copy.get(), m_solver_copy->m_best_phase); } return copied; } diff --git a/src/sat/sat_prob.h b/src/sat/sat_prob.h index f05365e39..d8d58d091 100644 --- a/src/sat/sat_prob.h +++ b/src/sat/sat_prob.h @@ -150,7 +150,7 @@ namespace sat { void collect_statistics(statistics& st) const override {} - void reinit(solver& s) override { UNREACHABLE(); } + void reinit(solver& s, bool_vector const& phase) override { UNREACHABLE(); } }; } diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 373885f34..10aac6dcb 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -1373,7 +1373,7 @@ namespace sat { m_backoffs.m_local_search.delta_effort(*this); m_local_search->rlimit().push(m_backoffs.m_local_search.limit); - m_local_search->reinit(*this); + m_local_search->reinit(*this, m_best_phase); lbool r = m_local_search->check(_lits.size(), _lits.data(), nullptr); auto const& mdl = m_local_search->get_model(); if (mdl.size() == m_best_phase.size()) { diff --git a/src/sat/sat_types.h b/src/sat/sat_types.h index 3026b3c5e..d5d457cb0 100644 --- a/src/sat/sat_types.h +++ b/src/sat/sat_types.h @@ -85,7 +85,7 @@ namespace sat { virtual void updt_params(params_ref const& p) = 0; virtual void set_seed(unsigned s) = 0; virtual lbool check(unsigned sz, literal const* assumptions, parallel* par) = 0; - virtual void reinit(solver& s) = 0; + virtual void reinit(solver& s, bool_vector const& phase) = 0; virtual unsigned num_non_binary_clauses() const = 0; virtual reslimit& rlimit() = 0; virtual model const& get_model() const = 0; diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 22fc9963c..dbbfc3856 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -3,6 +3,7 @@ z3_add_component(sat_smt arith_axioms.cpp arith_diagnostics.cpp arith_internalize.cpp + arith_local_search.cpp arith_solver.cpp array_axioms.cpp array_diagnostics.cpp @@ -20,6 +21,7 @@ z3_add_component(sat_smt euf_ackerman.cpp euf_internalize.cpp euf_invariant.cpp + euf_local_search.cpp euf_model.cpp euf_proof.cpp euf_proof_checker.cpp diff --git a/src/sat/smt/arith_diagnostics.cpp b/src/sat/smt/arith_diagnostics.cpp index 8ead3d980..a3e48256d 100644 --- a/src/sat/smt/arith_diagnostics.cpp +++ b/src/sat/smt/arith_diagnostics.cpp @@ -23,6 +23,17 @@ Author: namespace arith { + + void arith_proof_hint_builder::set_type(euf::solver& ctx, hint_type ty) { + ctx.push(value_trail(m_eq_tail)); + ctx.push(value_trail(m_lit_tail)); + m_ty = ty; + reset(); + } + + arith_proof_hint* arith_proof_hint_builder::mk(euf::solver& s) { + return new (s.get_region()) arith_proof_hint(m_ty, m_num_le, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail); + } std::ostream& solver::display(std::ostream& out) const { lp().display(out); diff --git a/src/sat/smt/arith_local_search.cpp b/src/sat/smt/arith_local_search.cpp new file mode 100644 index 000000000..c6d02e436 --- /dev/null +++ b/src/sat/smt/arith_local_search.cpp @@ -0,0 +1,352 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + arith_local_search.cpp + +Abstract: + + Local search dispatch for SMT + +Author: + + Nikolaj Bjorner (nbjorner) 2023-02-07 + +--*/ +#include "sat/sat_solver.h" +#include "sat/smt/arith_solver.h" + + +namespace arith { + + + /// + /// need access to clauses + /// need access to m_unsat + /// need update of phase + /// need to initialize ineqs (arithmetical atoms) + /// + + solver::sls::sls(solver& s): + s(s), m(s.m) {} + + void solver::sls::operator()(bool_vector& phase) { + + // need to init variables/atoms/ineqs + + m.limit().push(m_max_arith_steps); + + unsigned m_best_min_unsat = 1; + unsigned best = m_best_min_unsat; + + while (m.inc() && m_best_min_unsat > 0) { + // unsigned prev = m_unsat.size(); + if (!flip()) + return; +#if 0 + if (m_unsat.size() < best) { + best = m_unsat.size(); + num_steps = 0; + } + if (m_unsat.size() < m_best_min_unsat) + save_best_values(); +#endif + } + } + + void solver::sls::set_bounds_begin() { + m_max_arith_steps = 0; + } + + void solver::sls::set_bounds_end(unsigned num_literals) { + // m_max_arith_steps = s.ctx.m_sl_config.L * + } + + void solver::sls::set_bounds(enode* n) { + ++m_max_arith_steps; + } + + bool solver::sls::flip() { + ++m_stats.m_num_flips; + log(); + if (flip_unsat()) + return true; + if (flip_clauses()) + return true; + if (flip_dscore()) + return true; + return false; + } + + // distance to true + rational solver::sls::dtt(rational const& args, ineq const& ineq) const { + switch (ineq.m_op) { + case ineq_kind::LE: + if (args <= ineq.m_bound) + return rational::zero(); + return args - ineq.m_bound; + case ineq_kind::EQ: + if (args == ineq.m_bound) + return rational::zero(); + return rational::one(); + case ineq_kind::NE: + if (args == ineq.m_bound) + return rational::one(); + return rational::zero(); + case ineq_kind::LT: + default: + if (args < ineq.m_bound) + return rational::zero(); + return args - ineq.m_bound + 1; + } + } + + rational solver::sls::dtt(ineq const& ineq, var_t v, rational const& new_value) const { + auto new_args_value = ineq.m_args_value; + for (auto const& [coeff, w] : ineq.m_args) { + if (w == v) { + new_args_value += coeff * (new_value - m_vars[w].m_value); + break; + } + } + return dtt(new_args_value, ineq); + } + + // critical move + bool solver::sls::cm(ineq const& ineq, var_t v, rational& new_value) { + SASSERT(!ineq.is_true()); + auto delta = ineq.m_args_value - ineq.m_bound; + for (auto const& [coeff, w] : ineq.m_args) { + if (w == v) { + if (coeff > 0) + new_value = value(v) - abs(ceil(delta / coeff)); + else + new_value = value(v) + abs(floor(delta / coeff)); + switch (ineq.m_op) { + case ineq_kind::LE: + SASSERT(delta + coeff * (new_value - value(v)) <= 0); + return true; + case ineq_kind::EQ: + return delta + coeff * (new_value - value(v)) == 0; + case ineq_kind::NE: + return delta + coeff * (new_value - value(v)) != 0; + case ineq_kind::LT: + return delta + coeff * (new_value - value(v)) < 0; + default: + UNREACHABLE(); break; + } + } + } + return false; + } + +#if 0 + + bool solver::sls::flip_unsat() { + unsigned start = m_rand(); + for (unsigned i = m_unsat.size(); i-- > 0; ) { + unsigned cl = m_unsat.elem_at((i + start) % m_unsat.size()); + if (flip(m_clauses[cl])) + return true; + } + return false; + } + + bool solver::sls::flip_clauses() { + unsigned start = m_rand(); + for (unsigned i = m_clauses.size(); i-- > 0; ) + if (flip_arith(m_clauses[(i + start) % m_clauses.size()])) + return true; + return false; + } + + bool solver::sls::flip_dscore() { + paws(); + unsigned start = m_rand(); + for (unsigned i = m_unsat.size(); i-- > 0; ) { + unsigned cl = m_unsat.elem_at((i + start) % m_unsat.size()); + if (flip_dscore(m_clauses[cl])) + return true; + } + std::cout << "flip dscore\n"; + IF_VERBOSE(2, verbose_stream() << "(sls " << m_stats.m_num_flips << " " << m_unsat.size() << ")\n"); + return false; + } + + bool solver::sls::flip_dscore(clause const& clause) { + rational new_value, min_value, min_score(-1); + var_t min_var = UINT_MAX; + for (auto a : clause.m_arith) { + auto const& ai = m_atoms[a]; + ineq const& ineq = ai.m_ineq; + for (auto const& [coeff, v] : ineq.m_args) { + if (!ineq.is_true() && cm(ineq, v, new_value)) { + rational score = dscore(v, new_value); + if (UINT_MAX == min_var || score < min_score) { + min_var = v; + min_value = new_value; + min_score = score; + } + } + } + } + if (min_var != UINT_MAX) { + update(min_var, min_value); + return true; + } + return false; + } + + void solver::sls::paws() { + for (auto& clause : m_clauses) { + bool above = 10000 * m_config.sp <= (m_rand() % 10000); + if (!above && clause.is_true() && clause.m_weight > 1) + clause.m_weight -= 1; + if (above && !clause.is_true()) + clause.m_weight += 1; + } + } + + void solver::sls::update(var_t v, rational const& new_value) { + auto& vi = m_vars[v]; + auto const& old_value = vi.m_value; + for (auto const& [coeff, atm] : vi.m_atoms) { + auto& ai = m_atoms[atm]; + SASSERT(!ai.m_is_bool); + auto& clause = m_clauses[ai.m_clause_idx]; + rational dtt_old = dtt(ai.m_ineq); + ai.m_ineq.m_args_value += coeff * (new_value - old_value); + rational dtt_new = dtt(ai.m_ineq); + bool was_true = clause.is_true(); + if (dtt_new < clause.m_dts) { + if (was_true && clause.m_dts > 0 && dtt_new == 0 && 1 == clause.m_num_trues) { + for (auto lit : clause.m_bools) { + if (is_true(lit)) { + dec_break(lit); + break; + } + } + } + clause.m_dts = dtt_new; + if (!was_true && clause.is_true()) + m_unsat.remove(ai.m_clause_idx); + } + else if (clause.m_dts == dtt_old && dtt_old < dtt_new) { + clause.m_dts = dts(clause); + if (was_true && !clause.is_true()) + m_unsat.insert(ai.m_clause_idx); + if (was_true && clause.is_true() && clause.m_dts > 0 && dtt_old == 0 && 1 == clause.m_num_trues) { + for (auto lit : clause.m_bools) { + if (is_true(lit)) { + inc_break(lit); + break; + } + } + } + } + SASSERT(clause.m_dts >= 0); + } + vi.m_value = new_value; + } + + bool solver::sls::flip_arith(clause const& clause) { + rational new_value; + for (auto a : clause.m_arith) { + auto const& ai = m_atoms[a]; + ineq const& ineq = ai.m_ineq; + for (auto const& [coeff, v] : ineq.m_args) { + if (!ineq.is_true() && cm(ineq, v, new_value)) { + int score = cm_score(v, new_value); + if (score <= 0) + continue; + unsigned num_unsat = m_unsat.size(); + update(v, new_value); + std::cout << "score " << v << " " << score << "\n"; + std::cout << num_unsat << " -> " << m_unsat.size() << "\n"; + return true; + } + } + } + return false; + } + + + + rational solver::sls::dts(clause const& cl) const { + rational d(1), d2; + bool first = true; + for (auto a : cl.m_arith) { + auto const& ai = m_atoms[a]; + d2 = dtt(ai.m_ineq); + if (first) + d = d2, first = false; + else + d = std::min(d, d2); + if (d == 0) + break; + } + return d; + } + + rational solver::sls::dts(clause const& cl, var_t v, rational const& new_value) const { + rational d(1), d2; + bool first = true; + for (auto a : cl.m_arith) { + auto const& ai = m_atoms[a]; + d2 = dtt(ai.m_ineq, v, new_value); + if (first) + d = d2, first = false; + else + d = std::min(d, d2); + if (d == 0) + break; + } + return d; + } + + // + // dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c) + // + rational solver::sls::dscore(var_t v, rational const& new_value) const { + auto const& vi = m_vars[v]; + rational score(0); + for (auto const& [coeff, atm] : vi.m_atoms) { + auto const& ai = m_atoms[atm]; + auto const& cl = m_clauses[ai.m_clause_idx]; + score += (cl.m_dts - dts(cl, v, new_value)) * rational(cl.m_weight); + } + return score; + } + + int solver::sls::cm_score(var_t v, rational const& new_value) { + int score = 0; + auto& vi = m_vars[v]; + for (auto const& [coeff, atm] : vi.m_atoms) { + auto const& ai = m_atoms[atm]; + auto const& clause = m_clauses[ai.m_clause_idx]; + rational dtt_old = dtt(ai.m_ineq); + rational dtt_new = dtt(ai.m_ineq, v, new_value); + if (!clause.is_true()) { + if (dtt_new == 0) + ++score; + } + else if (dtt_new == 0 || dtt_old > 0 || clause.m_num_trues > 0) + continue; + else { + bool has_true = false; + for (auto a : clause.m_arith) { + auto const& ai = m_atoms[a]; + rational d = dtt(ai.m_ineq, v, new_value); + has_true |= (d == 0); + } + if (!has_true) + --score; + } + } + return score; + } + +#endif +} + diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 35e0795b7..fd507ed15 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -24,6 +24,7 @@ namespace arith { solver::solver(euf::solver& ctx, theory_id id) : th_euf_solver(ctx, symbol("arith"), id), m_model_eqs(DEFAULT_HASHTABLE_INITIAL_CAPACITY, var_value_hash(*this), var_value_eq(*this)), + m_local_search(*this), m_resource_limit(*this), m_bp(*this), a(m), diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index a13ef6684..1717b93d1 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -78,12 +78,7 @@ namespace arith { m_eq_tail++; } public: - void set_type(euf::solver& ctx, hint_type ty) { - ctx.push(value_trail(m_eq_tail)); - ctx.push(value_trail(m_lit_tail)); - m_ty = ty; - reset(); - } + void set_type(euf::solver& ctx, hint_type ty); void set_num_le(unsigned n) { m_num_le = n; } void add_eq(euf::enode* a, euf::enode* b) { add(a, b, true); } void add_diseq(euf::enode* a, euf::enode* b) { add(a, b, false); } @@ -96,12 +91,9 @@ namespace arith { } std::pair const& lit(unsigned i) const { return m_literals[i]; } std::tuple const& eq(unsigned i) const { return m_eqs[i]; } - arith_proof_hint* mk(euf::solver& s) { - return new (s.get_region()) arith_proof_hint(m_ty, m_num_le, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail); - } + arith_proof_hint* mk(euf::solver& s); }; - class solver : public euf::th_euf_solver { friend struct arith_proof_hint; @@ -144,7 +136,7 @@ namespace arith { }; int_hashtable m_model_eqs; - bool m_new_eq { false }; + bool m_new_eq = false; // temporary values kept during internalization @@ -198,6 +190,85 @@ namespace arith { } }; + // local search portion for arithmetic + class sls { + enum class ineq_kind { EQ, LE, LT, NE }; + enum class var_kind { INT, REAL }; + typedef unsigned var_t; + typedef unsigned atom_t; + + struct stats { + unsigned m_num_flips = 0; + }; + // encode args <= bound, args = bound, args < bound + struct ineq { + vector> m_args; + ineq_kind m_op = ineq_kind::LE; + rational m_bound; + rational m_args_value; + + bool is_true() const { + switch (m_op) { + case ineq_kind::LE: + return m_args_value <= m_bound; + case ineq_kind::EQ: + return m_args_value == m_bound; + case ineq_kind::NE: + return m_args_value != m_bound; + default: + return m_args_value < m_bound; + } + } + }; + + struct var_info { + rational m_value; + rational m_best_value; + var_kind m_kind = var_kind::INT; + vector> m_atoms; + }; + + struct atom_info { + ineq m_ineq; + unsigned m_clause_idx; + bool m_is_bool = false; + bool m_phase = false; + bool m_best_phase = false; + unsigned m_breaks = 0; + }; + + solver& s; + ast_manager& m; + unsigned m_max_arith_steps = 0; + stats m_stats; + vector m_atoms; + vector m_vars; + + bool flip(); + void log() {} + bool flip_unsat() { return false; } + bool flip_clauses() { return false; } + bool flip_dscore() { return false; } +// bool flip_dscore(clause const&); +// bool flip(clause const&); + rational dtt(ineq const& ineq) const { return dtt(ineq.m_args_value, ineq); } + rational dtt(rational const& args, ineq const& ineq) const; + rational dtt(ineq const& ineq, var_t v, rational const& new_value) const; +// rational dts(clause const& cl, var_t v, rational const& new_value) const; +// rational dts(clause const& cl) const; + bool cm(ineq const& ineq, var_t v, rational& new_value); + + rational value(var_t v) const { return m_vars[v].m_value; } + public: + sls(solver& s); + void operator ()(bool_vector& phase); + void set_bounds_begin(); + void set_bounds_end(unsigned num_literals); + void set_bounds(enode* n); + }; + + sls m_local_search; + typedef vector> var_coeffs; vector m_columns; var_coeffs m_left_side; // constraint left side @@ -233,10 +304,10 @@ namespace arith { unsigned m_asserted_qhead = 0; svector > m_assume_eq_candidates; - unsigned m_assume_eq_head{ 0 }; + unsigned m_assume_eq_head = 0; lp::u_set m_tmp_var_set; - unsigned m_num_conflicts{ 0 }; + unsigned m_num_conflicts = 0; lp_api::stats m_stats; svector m_scopes; @@ -515,6 +586,11 @@ namespace arith { bool enable_ackerman_axioms(euf::enode* n) const override { return !a.is_add(n->get_expr()); } bool has_unhandled() const override { return m_not_handled != nullptr; } + void set_bounds_begin() override { m_local_search.set_bounds_begin(); } + void set_bounds_end(unsigned num_literals) override { m_local_search.set_bounds_end(num_literals); } + void set_bounds(enode* n) override { m_local_search.set_bounds(n); } + void local_search(bool_vector& phase) override { m_local_search(phase); } + // bounds and equality propagation callbacks lp::lar_solver& lp() { return *m_solver; } lp::lar_solver const& lp() const { return *m_solver; } @@ -523,4 +599,7 @@ namespace arith { void consume(rational const& v, lp::constraint_index j); bool bound_is_interesting(unsigned vi, lp::lconstraint_kind kind, const rational& bval) const; }; + + + } diff --git a/src/sat/smt/euf_local_search.cpp b/src/sat/smt/euf_local_search.cpp new file mode 100644 index 000000000..5ea30b96c --- /dev/null +++ b/src/sat/smt/euf_local_search.cpp @@ -0,0 +1,126 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + euf_local_search.cpp + +Abstract: + + Local search dispatch for SMT + +Author: + + Nikolaj Bjorner (nbjorner) 2023-02-07 + +--*/ +#include "sat/sat_solver.h" +#include "sat/sat_ddfw.h" +#include "sat/smt/euf_solver.h" + + +namespace euf { + + void solver::local_search(bool_vector& phase) { + + + scoped_limits scoped_rl(m.limit()); + sat::ddfw bool_search; + bool_search.add(s()); + bool_search.updt_params(s().params()); + bool_search.set_seed(rand()); + scoped_rl.push_child(&(bool_search.rlimit())); + + unsigned rounds = 0; + unsigned max_rounds = 30; + + sat::model mdl(s().num_vars()); + for (unsigned v = 0; v < s().num_vars(); ++v) + mdl[v] = s().value(v); + + + while (m.inc() && rounds < max_rounds) { + setup_bounds(mdl); + bool_search.reinit(s(), phase); + + // Non-boolean literals are assumptions to Boolean search + literal_vector _lits; + for (unsigned v = 0; v < mdl.size(); ++v) + if (!is_propositional(literal(v))) + _lits.push_back(literal(v, mdl[v] == l_false)); + + bool_search.rlimit().push(m_max_bool_steps); + + lbool r = bool_search.check(_lits.size(), _lits.data(), nullptr); + + + auto const& mdl = bool_search.get_model(); + for (unsigned i = 0; i < mdl.size(); ++i) + phase[i] = mdl[i] == l_true; + + for (auto* th : m_solvers) + th->local_search(phase); + ++rounds; + // if is_sat break; + } + + } + + bool solver::is_propositional(sat::literal lit) { + expr* e = m_bool_var2expr.get(lit.var(), nullptr); + if (!e) + return true; + if (is_uninterp_const(e)) + return true; + euf::enode* n = m_egraph.find(e); + if (!n) + return true; + } + + void solver::setup_bounds(sat::model const& mdl) { + unsigned num_literals = 0; + unsigned num_bool = 0; + for (auto* th : m_solvers) + th->set_bounds_begin(); + + auto init_literal = [&](sat::literal l) { + if (is_propositional(l)) { + ++num_bool; + return; + } + euf::enode* n = m_egraph.find(m_bool_var2expr.get(l.var(), nullptr)); + for (auto const& thv : enode_th_vars(n)) { + auto* th = m_id2solver.get(thv.get_id(), nullptr); + if (th) + th->set_bounds(n); + } + }; + + auto is_true = [&](auto lit) { + return mdl[lit.var()] == to_lbool(!lit.sign()); + }; + + svector bin_clauses; + s().collect_bin_clauses(bin_clauses, false, false); + for (auto* cp : s().clauses()) { + if (any_of(*cp, [&](auto lit) { return is_true(lit); })) + continue; + num_literals += cp->size(); + for (auto l : *cp) + init_literal(l); + } + + for (auto [l1, l2] : bin_clauses) { + if (is_true(l1) || is_true(l2)) + continue; + num_literals += 2; + init_literal(l1); + init_literal(l2); + }; + + m_max_bool_steps = (m_ls_config.L * num_bool) / num_literals; + + for (auto* th : m_solvers) + th->set_bounds_end(num_literals); + } +} diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 96079fbec..44e3df4b0 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -100,6 +100,14 @@ namespace euf { scope(unsigned l) : m_var_lim(l) {} }; + struct local_search_config { + double cb = 0.0; + unsigned L = 20; + unsigned t = 45; + unsigned max_no_improve = 500000; + double sp = 0.0003; + }; + size_t* to_ptr(sat::literal l) { return TAG(size_t*, reinterpret_cast((size_t)(l.index() << 4)), 1); } size_t* to_ptr(size_t jst) { return TAG(size_t*, reinterpret_cast(jst), 2); } @@ -119,6 +127,7 @@ namespace euf { sat::sat_internalizer& si; relevancy m_relevancy; smt_params m_config; + local_search_config m_ls_config; euf::egraph m_egraph; trail_stack m_trail; stats m_stats; @@ -253,6 +262,11 @@ namespace euf { constraint& eq_constraint() { return mk_constraint(m_eq, constraint::kind_t::eq); } constraint& lit_constraint(enode* n); + // local search + unsigned m_max_bool_steps = 10; + bool is_propositional(sat::literal lit); + void setup_bounds(sat::model const& mdl); + // user propagator void check_for_user_propagator() { if (!m_user_propagator) @@ -339,6 +353,7 @@ namespace euf { void add_assumptions(sat::literal_set& assumptions) override; bool tracking_assumptions() override; std::string reason_unknown() override { return m_reason_unknown; } + void local_search(bool_vector& phase) override; void propagate(literal lit, ext_justification_idx idx); bool propagate(enode* a, enode* b, ext_justification_idx idx); @@ -551,4 +566,3 @@ namespace euf { inline std::ostream& operator<<(std::ostream& out, euf::solver const& s) { return s.display(out); } - diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index a3b81a08d..2101dfd64 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -136,6 +136,15 @@ namespace euf { sat::status status() const { return sat::status::th(false, get_id()); } + /** + * Local search interface + */ + virtual void set_bounds_begin() {} + + virtual void set_bounds_end(unsigned num_literals) {} + + virtual void set_bounds(enode* n) {} + }; class th_proof_hint : public sat::proof_hint {