From 7bef2f3e6f312b979077ed53dfb4218659321b84 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 11 Feb 2023 09:33:35 -0800 Subject: [PATCH] wip - local search for euf/arithmetic --- src/sat/sat_ddfw.cpp | 38 +-- src/sat/sat_ddfw.h | 26 +- src/sat/sat_local_search.cpp | 17 +- src/sat/sat_local_search.h | 16 +- src/sat/smt/CMakeLists.txt | 2 +- src/sat/smt/arith_sls.cpp | 511 +++++++++++++++++++++++++++++++ src/sat/smt/arith_sls.h | 147 +++++++++ src/sat/smt/arith_solver.h | 113 +------ src/sat/smt/euf_local_search.cpp | 18 +- 9 files changed, 716 insertions(+), 172 deletions(-) create mode 100644 src/sat/smt/arith_sls.cpp create mode 100644 src/sat/smt/arith_sls.h diff --git a/src/sat/sat_ddfw.cpp b/src/sat/sat_ddfw.cpp index 74a0e7777..98e3ce2bd 100644 --- a/src/sat/sat_ddfw.cpp +++ b/src/sat/sat_ddfw.cpp @@ -49,6 +49,7 @@ namespace sat { else if (should_parallel_sync()) do_parallel_sync(); else shift_weights(); } + remove_assumptions(); log(); return m_min_sz == 0 ? l_true : l_undef; } @@ -244,7 +245,6 @@ namespace sat { m_use_list_index.push_back(m_flat_use_list.size()); } - void ddfw::flip(bool_var v) { ++m_flips; literal lit = literal(v, !value(v)); @@ -309,19 +309,15 @@ namespace sat { log(); if (m_reinit_count % 2 == 0) { - for (auto& ci : m_clauses) { - ci.m_weight += 1; - } + for (auto& ci : m_clauses) + ci.m_weight += 1; } else { - for (auto& ci : m_clauses) { - if (ci.is_true()) { - ci.m_weight = m_config.m_init_clause_weight; - } - else { - ci.m_weight = m_config.m_init_clause_weight + 1; - } - } + for (auto& ci : m_clauses) + if (ci.is_true()) + ci.m_weight = m_config.m_init_clause_weight; + else + ci.m_weight = m_config.m_init_clause_weight + 1; } init_clause_data(); ++m_reinit_count; @@ -341,11 +337,9 @@ namespace sat { clause const& c = get_clause(i); ci.m_trues = 0; ci.m_num_trues = 0; - for (literal lit : c) { - if (is_true(lit)) { - ci.add(lit); - } - } + for (literal lit : c) + if (is_true(lit)) + ci.add(lit); switch (ci.m_num_trues) { case 0: for (literal lit : c) { @@ -384,12 +378,10 @@ namespace sat { void ddfw::reinit_values() { for (unsigned i = 0; i < num_vars(); ++i) { int b = bias(i); - if (0 == (m_rand() % (1 + abs(b)))) { - value(i) = (m_rand() % 2) == 0; - } - else { - value(i) = bias(i) > 0; - } + if (0 == (m_rand() % (1 + abs(b)))) + value(i) = (m_rand() % 2) == 0; + else + value(i) = bias(i) > 0; } } diff --git a/src/sat/sat_ddfw.h b/src/sat/sat_ddfw.h index 7dd69de81..9971033e3 100644 --- a/src/sat/sat_ddfw.h +++ b/src/sat/sat_ddfw.h @@ -43,6 +43,17 @@ namespace sat { void add(literal lit) { ++m_num_trues; m_trues += lit.index(); } void del(literal lit) { SASSERT(m_num_trues > 0); --m_num_trues; m_trues -= lit.index(); } }; + + class use_list { + ddfw& p; + unsigned i; + public: + use_list(ddfw& p, literal lit) : + p(p), i(lit.index()) {} + unsigned const* begin() { return p.m_flat_use_list.data() + p.m_use_list_index[i]; } + unsigned const* end() { return p.m_flat_use_list.data() + p.m_use_list_index[i + 1]; } + }; + protected: struct config { @@ -102,15 +113,7 @@ namespace sat { parallel* m_par; - class use_list { - ddfw& p; - unsigned i; - public: - use_list(ddfw& p, literal lit): - p(p), i(lit.index()) {} - unsigned const* begin() { return p.m_flat_use_list.data() + p.m_use_list_index[i]; } - unsigned const* end() { return p.m_flat_use_list.data() + p.m_use_list_index[i + 1]; } - }; + void flatten_use_list(); @@ -163,7 +166,6 @@ namespace sat { // flip activity bool do_flip(); bool_var pick_var(); - void flip(bool_var v); void save_best_values(); void save_model(); void save_priorities(); @@ -245,6 +247,10 @@ namespace sat { void remove_assumptions(); + void flip(bool_var v); + + use_list get_use_list(literal lit) { return use_list(*this, lit); } + }; } diff --git a/src/sat/sat_local_search.cpp b/src/sat/sat_local_search.cpp index 8cc90f05e..c3cb0fb37 100644 --- a/src/sat/sat_local_search.cpp +++ b/src/sat/sat_local_search.cpp @@ -353,10 +353,7 @@ namespace sat { DEBUG_CODE(verify_unsat_stack();); } - local_search::local_search() : - m_is_unsat(false), - m_initializing(false), - m_par(nullptr) { + local_search::local_search() { } void local_search::reinit(solver& s, bool_vector const& phase) { @@ -375,11 +372,10 @@ namespace sat { m_vars.reserve(s.num_vars()); m_config.set_config(s.get_config()); - if (m_config.phase_sticky()) { - unsigned v = 0; + unsigned v = 0; + if (m_config.phase_sticky()) for (var_info& vi : m_vars) - vi.m_bias = s.m_phase[v++] ? 98 : 2; - } + vi.m_bias = s.m_phase[v++] ? 98 : 2; // copy units unsigned trail_sz = s.init_trail_size(); @@ -419,9 +415,8 @@ namespace sat { if (ext && (!ext->is_pb() || !ext->extract_pb(card, pb))) throw default_exception("local search is incomplete with extensions beyond PB"); - if (_init) { - init(); - } + if (_init) + init(); } local_search::~local_search() { diff --git a/src/sat/sat_local_search.h b/src/sat/sat_local_search.h index 7295b851a..b62234522 100644 --- a/src/sat/sat_local_search.h +++ b/src/sat/sat_local_search.h @@ -133,21 +133,21 @@ namespace sat { vector m_constraints; // all constraints literal_vector m_assumptions; // temporary assumptions literal_vector m_prop_queue; // propagation queue - unsigned m_num_non_binary_clauses; - bool m_is_pb; - bool m_is_unsat; + unsigned m_num_non_binary_clauses = 0; + bool m_is_pb = false; + bool m_is_unsat = false; unsigned_vector m_unsat_stack; // store all the unsat constraints unsigned_vector m_index_in_unsat_stack; // which position is a constraint in the unsat_stack // configuration changed decreasing variables (score>0 and conf_change==true) bool_var_vector m_goodvar_stack; - bool m_initializing; + bool m_initializing = false; // information about solution - unsigned m_best_unsat; - double m_best_unsat_rate; - double m_last_best_unsat_rate; + unsigned m_best_unsat = 0; + double m_best_unsat_rate = 0; + double m_last_best_unsat_rate = 0; // for non-known instance, set as maximal int m_best_known_value = INT_MAX; // best known value for this instance @@ -159,7 +159,7 @@ namespace sat { reslimit m_limit; random_gen m_rand; - parallel* m_par; + parallel* m_par = nullptr; model m_model; inline int score(bool_var v) const { return m_vars[v].m_score; } diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index dbbfc3856..4a899ca9d 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -3,7 +3,7 @@ z3_add_component(sat_smt arith_axioms.cpp arith_diagnostics.cpp arith_internalize.cpp - arith_local_search.cpp + arith_sls.cpp arith_solver.cpp array_axioms.cpp array_diagnostics.cpp diff --git a/src/sat/smt/arith_sls.cpp b/src/sat/smt/arith_sls.cpp new file mode 100644 index 000000000..460bdedce --- /dev/null +++ b/src/sat/smt/arith_sls.cpp @@ -0,0 +1,511 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + arith_local_search.cpp + +Abstract: + + Local search dispatch for SMT + +Author: + + Nikolaj Bjorner (nbjorner) 2023-02-07 + +--*/ +#include "sat/sat_solver.h" +#include "sat/smt/arith_solver.h" + + +namespace arith { + + + /// + /// need to initialize ineqs (arithmetical atoms) + /// + + + sls::sls(solver& s): + s(s), m(s.m) {} + + void sls::operator()(bool_vector& phase) { + + m_best_min_unsat = unsat().size(); + unsigned num_steps = 0; + while (m.inc() && m_best_min_unsat > 0 && num_steps < m_max_arith_steps) { + if (!flip()) + break; + ++m_stats.m_num_flips; + ++num_steps; + unsigned num_unsat = unsat().size(); + if (num_unsat < m_best_min_unsat) { + m_best_min_unsat = num_unsat; + num_steps = 0; + save_best_values(); + } + } + IF_VERBOSE(2, verbose_stream() << "(sls " << m_stats.m_num_flips << " " << unsat().size() << ")\n"); + } + + void sls::save_best_values() { + // first compute assignment to terms + // then update non-basic variables in tableau, assuming a sat solution was found. +#if false + for (auto const& [t, v] : terms) { + rational val; + lp::lar_term const& term = lp().get_term(t); + for (lp::lar_term::ival arg : term) { + auto t2 = lp().column2tv(arg.column()); + auto w = lp().local_to_external(t2.id()); + val += arg.coeff() * local_search.value(w); + } + update(v, val); + } +#endif + + for (unsigned v = 0; v < s.get_num_vars(); ++v) { + if (s.is_bool(v)) + continue; + if (!s.lp().external_is_used(v)) + continue; + rational old_value = s.is_registered_var(v) ? s.get_ivalue(v).x : rational::zero(); + rational new_value = value(v); + if (old_value == new_value) + continue; + s.ensure_column(v); + lp::column_index vj = s.lp().to_column_index(v); + SASSERT(!vj.is_null()); + if (!s.lp().is_base(vj.index())) { + lp::impq val(new_value); + s.lp().set_value_for_nbasic_column(vj.index(), val); + } + } + } + + void sls::set(sat::ddfw* d) { + m_bool_search = d; + add_vars(); + m_clauses.resize(d->num_clauses()); + for (unsigned i = 0; i < d->num_clauses(); ++i) + for (sat::literal lit : *d->get_clause_info(i).m_clause) + init_literal(lit); + } + + void sls::set_bounds_begin() { + m_max_arith_steps = 0; + } + + void sls::set_bounds(enode* n) { + ++m_max_arith_steps; + } + + void sls::set_bounds_end(unsigned num_literals) { + m_max_arith_steps = (m_config.L * m_max_arith_steps) / num_literals; + } + + bool sls::flip() { + log(); + return flip_unsat() || flip_clauses() || flip_dscore(); + } + + // distance to true + rational sls::dtt(rational const& args, ineq const& ineq) const { + switch (ineq.m_op) { + case ineq_kind::LE: + if (args <= ineq.m_bound) + return rational::zero(); + return args - ineq.m_bound; + case ineq_kind::EQ: + if (args == ineq.m_bound) + return rational::zero(); + return rational::one(); + case ineq_kind::NE: + if (args == ineq.m_bound) + return rational::one(); + return rational::zero(); + case ineq_kind::LT: + if (args < ineq.m_bound) + return rational::zero(); + return args - ineq.m_bound + 1; + default: + UNREACHABLE(); + return rational::zero(); + } + } + + rational sls::dtt(ineq const& ineq, var_t v, rational const& new_value) const { + auto new_args_value = ineq.m_args_value; + for (auto const& [coeff, w] : ineq.m_args) { + if (w == v) { + new_args_value += coeff * (new_value - m_vars[w].m_value); + break; + } + } + return dtt(new_args_value, ineq); + } + + // critical move + bool sls::cm(ineq const& ineq, var_t v, rational& new_value) { + SASSERT(!ineq.is_true()); + auto delta = ineq.m_args_value - ineq.m_bound; + for (auto const& [coeff, w] : ineq.m_args) { + if (w == v) { + if (coeff > 0) + new_value = value(v) - abs(ceil(delta / coeff)); + else + new_value = value(v) + abs(floor(delta / coeff)); + switch (ineq.m_op) { + case ineq_kind::LE: + SASSERT(delta + coeff * (new_value - value(v)) <= 0); + return true; + case ineq_kind::EQ: + return delta + coeff * (new_value - value(v)) == 0; + case ineq_kind::NE: + return delta + coeff * (new_value - value(v)) != 0; + case ineq_kind::LT: + return delta + coeff * (new_value - value(v)) < 0; + default: + UNREACHABLE(); + break; + } + } + } + return false; + } + + bool sls::flip_unsat() { + unsigned start = s.random(); + unsigned sz = unsat().size(); + for (unsigned i = sz; i-- > 0; ) + if (flip(unsat().elem_at((i + start) % sz))) + return true; + return false; + } + + bool sls::flip(unsigned cl) { + auto const& clause = get_clause(cl); + rational new_value; + for (literal lit : clause) { + auto const* ineq = atom(lit); + if (!ineq || ineq->is_true()) + continue; + for (auto const& [coeff, v] : ineq->m_args) { + if (!cm(*ineq, v, new_value)) + continue; + int score = cm_score(v, new_value); + if (score <= 0) + continue; + unsigned num_unsat = unsat().size(); + update(v, new_value); + IF_VERBOSE(2, + verbose_stream() << "score " << v << " " << score << "\n" + << num_unsat << " -> " << unsat().size() << "\n"); + return true; + } + } + return false; + } + + bool sls::flip_clauses() { + unsigned start = s.random(); + unsigned sz = m_bool_search->num_clauses(); + for (unsigned i = sz; i-- > 0; ) + if (flip((i + start) % sz)) + return true; + return false; + } + + bool sls::flip_dscore() { + paws(); + unsigned start = s.random(); + unsigned sz = unsat().size(); + for (unsigned i = sz; i-- > 0; ) + if (flip_dscore(unsat().elem_at((i + start) % sz))) + return true; + return false; + } + + bool sls::flip_dscore(unsigned cl) { + auto const& clause = get_clause(cl); + rational new_value, min_value, min_score(-1); + var_t min_var = UINT_MAX; + for (auto lit : clause) { + auto const* ineq = atom(lit); + if (!ineq || ineq->is_true()) + continue; + for (auto const& [coeff, v] : ineq->m_args) { + if (cm(*ineq, v, new_value)) { + rational score = dscore(v, new_value); + if (UINT_MAX == min_var || score < min_score) { + min_var = v; + min_value = new_value; + min_score = score; + } + } + } + } + if (min_var != UINT_MAX) { + update(min_var, min_value); + return true; + } + return false; + } + + /** + * redistribute weights of clauses. TODO - re-use ddfw weights instead. + */ + void sls::paws() { + for (unsigned cl = num_clauses(); cl-- > 0; ) { + auto& clause = get_clause_info(cl); + bool above = 10000 * m_config.sp <= (s.random() % 10000); + if (!above && clause.is_true() && get_weight(cl) > 1) + get_weight(cl) -= 1; + if (above && !clause.is_true()) + get_weight(cl) += 1; + } + } + + // + // dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c) + // + rational sls::dscore(var_t v, rational const& new_value) const { + auto const& vi = m_vars[v]; + rational score(0); + for (auto const& [coeff, lit] : vi.m_literals) + for (auto cl : m_bool_search->get_use_list(lit)) + score += (dts(cl) - dts(cl, v, new_value)) * rational(get_weight(cl)); + return score; + } + + int sls::cm_score(var_t v, rational const& new_value) { + int score = 0; + auto& vi = m_vars[v]; + for (auto const& [coeff, lit] : vi.m_literals) { + auto const& ineq = *atom(lit); + rational dtt_old = dtt(ineq); + rational dtt_new = dtt(ineq, v, new_value); + for (auto cl : m_bool_search->get_use_list(lit)) { + auto const& clause = get_clause_info(cl); + if (!clause.is_true()) { + if (dtt_new == 0) + ++score; // false -> true + } + else if (dtt_new == 0 || dtt_old > 0 || clause.m_num_trues > 0) // true -> true ?? TODO + continue; + else if (all_of(*clause.m_clause, [&](auto lit) { return !atom(lit) || dtt(*atom(lit), v, new_value) > 0; })) // ?? TODO + --score; + } + } + return score; + } + + rational sls::dts(unsigned cl) const { + rational d(1), d2; + bool first = true; + for (auto a : get_clause(cl)) { + auto const* ineq = atom(a); + if (!ineq) + continue; + d2 = dtt(*ineq); + if (first) + d = d2, first = false; + else + d = std::min(d, d2); + if (d == 0) + break; + } + return d; + } + + rational sls::dts(unsigned cl, var_t v, rational const& new_value) const { + rational d(1), d2; + bool first = true; + for (auto lit : get_clause(cl)) { + auto const* ineq = atom(lit); + if (!ineq) + continue; + d2 = dtt(*ineq, v, new_value); + if (first) + d = d2, first = false; + else + d = std::min(d, d2); + if (d == 0) + break; + } + return d; + } + + void sls::update(var_t v, rational const& new_value) { + auto& vi = m_vars[v]; + auto const& old_value = vi.m_value; + for (auto const& [coeff, lit] : vi.m_literals) { + auto& ineq = *atom(lit); + rational dtt_old = dtt(ineq); + ineq.m_args_value += coeff * (new_value - old_value); + rational dtt_new = dtt(ineq); + SASSERT(!(dtt_new == 0 && dtt_new < dtt_old) || m_bool_search->get_value(lit.var()) == lit.sign()); + SASSERT(!(dtt_old == 0 && dtt_new > dtt_old) || m_bool_search->get_value(lit.var()) != lit.sign()); + if (dtt_new == 0 && dtt_new < dtt_old) // flip from false to true + m_bool_search->flip(lit.var()); + else if (dtt_old == 0 && dtt_old < dtt_new) // flip from true to false + m_bool_search->flip(lit.var()); + dtt(ineq) = dtt_new; + SASSERT((dtt_new == 0) == (m_bool_search->get_value(lit.var()) != lit.sign())); + } + vi.m_value = new_value; + } + + void sls::add_vars() { + SASSERT(m_vars.empty()); + for (unsigned v = 0; v < s.get_num_vars(); ++v) { + rational value = s.is_registered_var(v) ? s.get_ivalue(v).x : rational::zero(); + value = s.is_int(v) ? ceil(value) : value; + auto k = s.is_int(v) ? sls::var_kind::INT : sls::var_kind::REAL; + m_vars.push_back({ value, value, k, {} }); + } + } + + sls::ineq& sls::new_ineq(ineq_kind op, rational const& bound) { + auto* i = alloc(ineq); + i->m_bound = bound; + i->m_op = op; + return *i; + } + + void sls::add_arg(sat::literal lit, ineq& ineq, rational const& c, var_t v) { + ineq.m_args.push_back({ c, v }); + ineq.m_args_value += c * value(v); + m_vars[v].m_literals.push_back({ c, lit }); + } + + void sls::add_bounds(sat::literal_vector& bounds) { + unsigned bvars = s.s().num_vars(); + auto add_ineq = [&](sat::literal lit, ineq& i) { + m_literals.set(lit.index(), &i); + bounds.push_back(lit); + }; + for (unsigned v = 0; v < s.get_num_vars(); ++v) { + rational lo, hi; + bool is_strict_lo = false, is_strict_hi = false; + lp::constraint_index ci; + if (!s.is_registered_var(v)) + continue; + lp::column_index vi = s.lp().to_column_index(v); + if (vi.is_null()) + continue; + bool has_lo = s.lp().has_lower_bound(vi.index(), ci, lo, is_strict_lo); + bool has_hi = s.lp().has_upper_bound(vi.index(), ci, hi, is_strict_hi); + + if (has_lo && has_hi && lo == hi) { + auto& ineq = new_ineq(sls::ineq_kind::EQ, lo); + sat::literal lit(bvars++); + add_arg(lit, ineq, rational::one(), v); + add_ineq(lit, ineq); + continue; + } + if (has_lo) { + auto& ineq = new_ineq(is_strict_lo ? sls::ineq_kind::LT : sls::ineq_kind::LE, -lo); + sat::literal lit(bvars++); + add_arg(lit, ineq, -rational::one(), v); + add_ineq(lit, ineq); + } + if (has_hi) { + auto& ineq = new_ineq(is_strict_hi ? sls::ineq_kind::LT : sls::ineq_kind::LE, hi); + sat::literal lit(bvars++); + add_arg(lit, ineq, rational::one(), v); + add_ineq(lit, ineq); + } + } + } + + + void sls::add_args(ineq& ineq, lp::tv t, theory_var v, rational sign) { + if (t.is_term()) { + lp::lar_term const& term = s.lp().get_term(t); + + for (lp::lar_term::ival arg : term) { + auto t2 = s.lp().column2tv(arg.column()); + auto w = s.lp().local_to_external(t2.id()); + ineq.m_args.push_back({ sign * arg.coeff(), w }); + } + } + else + ineq.m_args.push_back({ sign, s.lp().local_to_external(t.id()) }); + } + + + void sls::init_literal(sat::literal lit) { + if (m_literals.get(lit.index(), nullptr)) + return; + api_bound* b = nullptr; + s.m_bool_var2bound.find(lit.var(), b); + if (b) { + auto t = b->tv(); + rational bound = b->get_value(); + bool should_minus = false; + sls::ineq_kind op; + if (!lit.sign()) { + should_minus = b->get_bound_kind() == lp::GE; + op = sls::ineq_kind::LE; + } + else { + should_minus = b->get_bound_kind() == lp::LE; + if (s.is_int(b->get_var())) { + bound -= 1; + op = sls::ineq_kind::LE; + } + else + op = sls::ineq_kind::LT; + + } + if (should_minus) + bound.neg(); + auto& ineq = new_ineq(op, bound); + + add_args(ineq, t, b->get_var(), should_minus ? rational::minus_one() :rational::one()); + set_literal(lit, ineq); + return; + } + + expr* e = s.bool_var2expr(lit.var()); + expr* l = nullptr, * r = nullptr; + if (e && m.is_eq(e, l, r) && s.a.is_int_real(l)) { + theory_var u = s.get_th_var(l); + theory_var v = s.get_th_var(r); + lp::tv tu = s.get_tv(u); + lp::tv tv = s.get_tv(v); + auto& ineq = new_ineq(lit.sign() ? sls::ineq_kind::NE : sls::ineq_kind::EQ, rational::zero()); + add_args(ineq, tu, u, rational::one()); + add_args(ineq, tv, v, -rational::one()); + set_literal(lit, ineq); + return; + } + } + + /** + * Associate literal with inequality and synchronize truth assignment based on arithmetic values. + */ + void sls::set_literal(sat::literal lit, ineq& ineq) { + m_literals.set(lit.index(), &ineq); + if (m_bool_search->get_value(lit.var())) { + if (dtt(ineq) != 0) + m_bool_search->flip(lit.var()); + } + else { + if (dtt(ineq) == 0) + m_bool_search->flip(lit.var()); + } + } + +#if 0 + + { + + } +} + + +#endif +} + diff --git a/src/sat/smt/arith_sls.h b/src/sat/smt/arith_sls.h new file mode 100644 index 000000000..dcb935610 --- /dev/null +++ b/src/sat/smt/arith_sls.h @@ -0,0 +1,147 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + arith_local_search.h + +Abstract: + + Theory plugin for arithmetic local search + +Author: + + Nikolaj Bjorner (nbjorner) 2020-09-08 + +--*/ +#pragma once + +#include "util/obj_pair_set.h" +#include "ast/ast_trail.h" +#include "ast/arith_decl_plugin.h" +#include "math/lp/lp_solver.h" +#include "math/lp/lp_primal_simplex.h" +#include "math/lp/lp_dual_simplex.h" +#include "math/lp/indexed_value.h" +#include "math/lp/lar_solver.h" +#include "math/lp/nla_solver.h" +#include "math/lp/lp_types.h" +#include "math/lp/lp_api.h" +#include "math/polynomial/algebraic_numbers.h" +#include "math/polynomial/polynomial.h" +#include "sat/smt/sat_th.h" +#include "sat/sat_ddfw.h" + +namespace arith { + + class solver; + + // local search portion for arithmetic + class sls { + enum class ineq_kind { EQ, LE, LT, NE }; + enum class var_kind { INT, REAL }; + typedef unsigned var_t; + typedef unsigned atom_t; + + struct config { + double cb = 0.0; + unsigned L = 20; + unsigned t = 45; + unsigned max_no_improve = 500000; + double sp = 0.0003; + }; + + struct stats { + unsigned m_num_flips = 0; + }; + + // encode args <= bound, args = bound, args < bound + struct ineq { + vector> m_args; + ineq_kind m_op = ineq_kind::LE; + rational m_bound; + rational m_args_value; + + bool is_true() const { + switch (m_op) { + case ineq_kind::LE: + return m_args_value <= m_bound; + case ineq_kind::EQ: + return m_args_value == m_bound; + case ineq_kind::NE: + return m_args_value != m_bound; + default: + return m_args_value < m_bound; + } + } + }; + + struct var_info { + rational m_value; + rational m_best_value; + var_kind m_kind = var_kind::INT; + vector> m_literals; + }; + + struct clause { + unsigned m_weight = 1; + }; + + solver& s; + ast_manager& m; + sat::ddfw* m_bool_search = nullptr; + unsigned m_max_arith_steps = 0; + unsigned m_best_min_unsat = UINT_MAX; + stats m_stats; + config m_config; + scoped_ptr_vector m_literals; + vector m_vars; + vector m_clauses; + + indexed_uint_set& unsat() { return m_bool_search->unsat_set(); } + unsigned num_clauses() const { return m_bool_search->num_clauses(); } + sat::clause& get_clause(unsigned idx) { return *get_clause_info(idx).m_clause; } + sat::clause const& get_clause(unsigned idx) const { return *get_clause_info(idx).m_clause; } + sat::ddfw::clause_info& get_clause_info(unsigned idx) { return m_bool_search->get_clause_info(idx); } + sat::ddfw::clause_info const& get_clause_info(unsigned idx) const { return m_bool_search->get_clause_info(idx); } + + ineq* atom(sat::literal lit) const { return m_literals[lit.index()]; } + unsigned& get_weight(unsigned idx) { return m_clauses[idx].m_weight; } + unsigned get_weight(unsigned idx) const { return m_clauses[idx].m_weight; } + bool flip(); + void log() {} + bool flip_unsat(); + bool flip_clauses(); + bool flip_dscore(); + bool flip_dscore(unsigned cl); + bool flip(unsigned cl); + rational dtt(ineq const& ineq) const { return dtt(ineq.m_args_value, ineq); } + rational dtt(rational const& args, ineq const& ineq) const; + rational dtt(ineq const& ineq, var_t v, rational const& new_value) const; + rational dts(unsigned cl, var_t v, rational const& new_value) const; + rational dts(unsigned cl) const; + bool cm(ineq const& ineq, var_t v, rational& new_value); + int cm_score(var_t v, rational const& new_value); + void update(var_t v, rational const& new_value); + void paws(); + rational dscore(var_t v, rational const& new_value) const; + void save_best_values(); + void add_vars(); + sls::ineq& new_ineq(ineq_kind op, rational const& bound); + void add_arg(sat::literal lit, ineq& ineq, rational const& c, var_t v); + void add_bounds(sat::literal_vector& bounds); + void add_args(ineq& ineq, lp::tv t, euf::theory_var v, rational sign); + void init_literal(sat::literal lit); + void set_literal(sat::literal lit, ineq& ineq); + + rational value(var_t v) const { return m_vars[v].m_value; } + public: + sls(solver& s); + void operator ()(bool_vector& phase); + void set_bounds_begin(); + void set_bounds_end(unsigned num_literals); + void set_bounds(euf::enode* n); + void set(sat::ddfw* d); + }; + +} diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 17207bd80..732e291b1 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -30,6 +30,7 @@ Author: #include "math/polynomial/algebraic_numbers.h" #include "math/polynomial/polynomial.h" #include "sat/smt/sat_th.h" +#include "sat/smt/arith_sls.h" #include "sat/sat_ddfw.h" namespace euf { @@ -98,6 +99,7 @@ namespace arith { class solver : public euf::th_euf_solver { friend struct arith_proof_hint; + friend class sls; struct scope { unsigned m_bounds_lim; @@ -190,116 +192,7 @@ namespace arith { coeffs().pop_back(); } }; - - // local search portion for arithmetic - class sls { - enum class ineq_kind { EQ, LE, LT, NE }; - enum class var_kind { INT, REAL }; - typedef unsigned var_t; - typedef unsigned atom_t; - - struct config { - double cb = 0.0; - unsigned L = 20; - unsigned t = 45; - unsigned max_no_improve = 500000; - double sp = 0.0003; - }; - - struct stats { - unsigned m_num_flips = 0; - }; - // encode args <= bound, args = bound, args < bound - struct ineq { - vector> m_args; - ineq_kind m_op = ineq_kind::LE; - rational m_bound; - rational m_args_value; - - bool is_true() const { - switch (m_op) { - case ineq_kind::LE: - return m_args_value <= m_bound; - case ineq_kind::EQ: - return m_args_value == m_bound; - case ineq_kind::NE: - return m_args_value != m_bound; - default: - return m_args_value < m_bound; - } - } - }; - - struct var_info { - rational m_value; - rational m_best_value; - var_kind m_kind = var_kind::INT; - vector> m_atoms; - }; - - struct atom_info { - ineq m_ineq; - unsigned m_clause_idx; - bool m_is_bool = false; - bool m_phase = false; - bool m_best_phase = false; - unsigned m_breaks = 0; - }; - - struct clause { - unsigned m_weight = 1; - rational m_dts = rational::one(); - }; - - solver& s; - ast_manager& m; - sat::ddfw* m_bool_search = nullptr; - unsigned m_max_arith_steps = 0; - unsigned m_best_min_unsat = UINT_MAX; - stats m_stats; - config m_config; - scoped_ptr_vector m_atoms; - vector m_vars; - vector m_clauses; - - indexed_uint_set& unsat() { return m_bool_search->unsat_set(); } - unsigned num_clauses() const { return m_bool_search->num_clauses(); } - sat::clause& get_clause(unsigned idx) { return *get_clause_info(idx).m_clause; } - sat::clause const& get_clause(unsigned idx) const { return *get_clause_info(idx).m_clause; } - sat::ddfw::clause_info& get_clause_info(unsigned idx) { return m_bool_search->get_clause_info(idx); } - sat::ddfw::clause_info const& get_clause_info(unsigned idx) const { return m_bool_search->get_clause_info(idx); } - - atom_info* atom(sat::literal lit) const { return m_atoms[lit.index()]; } - rational& dts(unsigned idx) { return m_clauses[idx].m_dts; } - bool flip(); - void log() {} - bool flip_unsat(); - bool flip_clauses(); - bool flip_dscore(); - bool flip_dscore(unsigned cl); - bool flip(unsigned cl); - rational dtt(ineq const& ineq) const { return dtt(ineq.m_args_value, ineq); } - rational dtt(rational const& args, ineq const& ineq) const; - rational dtt(ineq const& ineq, var_t v, rational const& new_value) const; - rational dts(unsigned cl, var_t v, rational const& new_value) const; - rational dts(unsigned cl) const; - bool cm(ineq const& ineq, var_t v, rational& new_value); - int cm_score(var_t v, rational const& new_value); - void update(var_t v, rational const& new_value); - void paws(); - rational dscore(var_t v, rational const& new_value) const; - void save_best_values() {} - - rational value(var_t v) const { return m_vars[v].m_value; } - public: - sls(solver& s); - void operator ()(bool_vector& phase); - void set_bounds_begin(); - void set_bounds_end(unsigned num_literals); - void set_bounds(enode* n); - void set(sat::ddfw* d) { m_bool_search = d; } - }; - + sls m_local_search; typedef vector> var_coeffs; diff --git a/src/sat/smt/euf_local_search.cpp b/src/sat/smt/euf_local_search.cpp index 4ee6490b7..873a64b7e 100644 --- a/src/sat/smt/euf_local_search.cpp +++ b/src/sat/smt/euf_local_search.cpp @@ -24,7 +24,7 @@ namespace euf { void solver::local_search(bool_vector& phase) { scoped_limits scoped_rl(m.limit()); sat::ddfw bool_search; - bool_search.add(s()); + bool_search.reinit(s(), phase); bool_search.updt_params(s().params()); bool_search.set_seed(rand()); scoped_rl.push_child(&(bool_search.rlimit())); @@ -36,29 +36,29 @@ namespace euf { for (unsigned rounds = 0; m.inc() && rounds < max_rounds; ++rounds) { - bool_search.reinit(s(), phase); - setup_bounds(phase); // Non-boolean literals are assumptions to Boolean search literal_vector assumptions; for (unsigned v = 0; v < phase.size(); ++v) if (!is_propositional(literal(v))) - assumptions.push_back(literal(v, !phase[v])); + assumptions.push_back(literal(v, !bool_search.get_value(v))); bool_search.rlimit().push(m_max_bool_steps); lbool r = bool_search.check(assumptions.size(), assumptions.data(), nullptr); - - - auto const& mdl = bool_search.get_model(); - for (unsigned i = 0; i < mdl.size(); ++i) - phase[i] = mdl[i] == l_true; + bool_search.rlimit().pop(); for (auto* th : m_solvers) th->local_search(phase); // if is_sat break; + } + + + auto const& mdl = bool_search.get_model(); + for (unsigned i = 0; i < mdl.size(); ++i) + phase[i] = mdl[i] == l_true; }