/*++
Copyright (c) 2023 Microsoft Corporation

Module Name:

    arith_local_search.cpp

Abstract:

    Local search dispatch for SMT

Author:

    Nikolaj Bjorner (nbjorner) 2023-02-07

--*/

#include "sat/sat_solver.h"
#include "sat/smt/arith_solver.h"

namespace arith {

    sls::sls(solver& s):
        s(s), m(s.m) {}

    // Drop all per-search state: atoms attached to Boolean variables,
    // arithmetic variable records and the list of registered terms.
    void sls::reset() {
        m_bool_vars.reset();
        m_vars.reset();
        m_terms.reset();
    }

    // Snapshot the current value of every arithmetic variable as its best value
    // and re-validate the atom assignment. When exactly one clause is unsatisfied
    // the clause and the arithmetic atoms of its literals are dumped to the
    // verbose stream as a diagnostic.
    void sls::save_best_values() {
        for (unsigned v = 0; v < s.get_num_vars(); ++v)
            m_vars[v].m_best_value = m_vars[v].m_value;
        check_ineqs();
        if (unsat().size() == 1) {
            auto idx = *unsat().begin();
            verbose_stream() << idx << "\n";
            auto const& c = *m_bool_search->m_clauses[idx].m_clause;
            verbose_stream() << c << "\n";
            for (auto lit : c) {
                bool_var bv = lit.var();
                ineq* i = atom(bv);
                if (i)
                    verbose_stream() << lit << ": " << *i << "\n";
            }
            verbose_stream() << "\n";
        }
    }

    // Transfer the best values found by local search into the LP solver.
    // Nothing is done unless every clause is satisfied.
    void sls::store_best_values() {
        // first compute assignment to terms
        // then update non-basic variables in tableau.

        if (!unsat().empty())
            return;

        for (auto const& [t, v] : m_terms) {
            int64_t val = 0;
            lp::lar_term const& term = s.lp().get_term(t);
            for (lp::lar_term::ival const& arg : term) {
                auto t2 = s.lp().column2tv(arg.column());
                auto w = s.lp().local_to_external(t2.id());
                val += to_numeral(arg.coeff()) * m_vars[w].m_best_value;
            }
            m_vars[v].m_best_value = val;
        }

        for (unsigned v = 0; v < s.get_num_vars(); ++v) {
            if (s.is_bool(v))
                continue;
            if (!s.lp().external_is_used(v))
                continue;
            int64_t new_value = m_vars[v].m_best_value;
            s.ensure_column(v);
            lp::column_index vj = s.lp().to_column_index(v);
            SASSERT(!vj.is_null());
            // only non-basic columns are assigned directly
            if (!s.lp().is_base(vj.index())) {
                rational new_value_(new_value, rational::i64());
                lp::impq val(new_value_, rational::zero());
                s.lp().set_value_for_nbasic_column(vj.index(), val);
            }
        }

        lbool r = s.make_feasible();
        VERIFY (!unsat().empty() || r == l_true);
#if 0
        if (unsat().empty())
            s.m_num_conflicts = s.get_config().m_arith_propagation_threshold;
#endif

        // Consistency check: the phase of every bound atom must agree with the
        // model of the Boolean search. On disagreement the state is dumped and
        // the process is terminated.
        auto check_bool_var = [&](sat::bool_var bv) {
            auto* ineq = m_bool_vars.get(bv, nullptr);
            if (!ineq)
                return;
            api_bound* b = nullptr;
            s.m_bool_var2bound.find(bv, b);
            if (!b)
                return;
            auto bound = b->get_value();
            theory_var v = b->get_var();
            if (s.get_phase(bv) == m_bool_search->get_model()[bv])
                return;
            switch (b->get_bound_kind()) {
            case lp_api::lower_t:
                verbose_stream() << "v" << v << " " << bound << " <= " << s.get_value(v) << " " << m_vars[v].m_best_value << "\n";
                break;
            case lp_api::upper_t:
                verbose_stream() << "v" << v << " " << bound << " >= " << s.get_value(v) << " " << m_vars[v].m_best_value << "\n";
                break;
            }
            int64_t value = 0;
            for (auto const& [coeff, w] : ineq->m_args)
                value += coeff * m_vars[w].m_best_value;
            ineq->m_args_value = value;
            verbose_stream() << *ineq << " dtt " << dtt(false, *ineq) << " phase " << s.get_phase(bv) << " model " << m_bool_search->get_model()[bv] << "\n";
            for (auto const& [coeff, w] : ineq->m_args)
                verbose_stream() << "v" << w << " := " << m_vars[w].m_best_value << "\n";
            s.display(verbose_stream());
            display(verbose_stream());
            UNREACHABLE();
            // NOTE(review): terminates with a success status on an internal inconsistency.
            exit(0);
        };

        if (unsat().empty()) {
            for (bool_var v = 0; v < s.s().num_vars(); ++v)
                check_bool_var(v);
        }
    }

    // Attach this plugin to the Boolean local search engine d:
    // rebuild the variable table, create atoms for every Boolean variable that
    // occurs in a clause of d, align the Boolean assignment with the atom values
    // and finally register the plugin with d.
    void sls::set(sat::ddfw* d) {
        m_bool_search = d;
        reset();
        m_bool_vars.reserve(s.s().num_vars());
        add_vars();
        for (unsigned i = 0; i < d->num_clauses(); ++i)
            for (sat::literal lit : *d->get_clause_info(i).m_clause)
                init_bool_var(lit.var());
        for (unsigned v = 0; v < s.s().num_vars(); ++v)
            init_bool_var_assignment(v);

        d->set(this);
    }

    // distance to true
    //
    // args is the value of the left-hand side of ineq.
    // sign == false: distance for the atom itself; sign == true: distance for its negation.
    // The result is 0 exactly when the (possibly negated) atom holds for args.
    int64_t sls::dtt(bool sign, int64_t args, ineq const& ineq) const {
        switch (ineq.m_op) {
        case ineq_kind::LE:
            if (sign) {
                if (args <= ineq.m_bound)
                    return ineq.m_bound - args + 1;
                return 0;
            }
            if (args <= ineq.m_bound)
                return 0;
            return args - ineq.m_bound;
        case ineq_kind::EQ:
            if (sign) {
                if (args == ineq.m_bound)
                    return 1;
                return 0;
            }
            if (args == ineq.m_bound)
                return 0;
            return 1;
        case ineq_kind::NE:
            if (sign) {
                if (args == ineq.m_bound)
                    return 0;
                return 1;
            }
            if (args == ineq.m_bound)
                return 1;
            return 0;
        case ineq_kind::LT:
            if (sign) {
                if (args < ineq.m_bound)
                    return ineq.m_bound - args;
                return 0;
            }
            if (args < ineq.m_bound)
                return 0;
            return args - ineq.m_bound + 1;
        default:
            UNREACHABLE();
            return 0;
        }
    }

    //
    // dtt is high overhead. It walks ineq.m_args
    // m_vars[w].m_value can be computed outside and shared among calls
    // different data-structures for storing coefficients
    //
    // Distance to true of ineq after v is moved to new_value.
    // NOTE(review): returns 1 when v does not occur in ineq, not the current
    // distance of ineq; kept as is because it affects the search heuristic.
    //
    int64_t sls::dtt(bool sign, ineq const& ineq, var_t v, int64_t new_value) const {
        for (auto const& [coeff, w] : ineq.m_args)
            if (w == v)
                return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq);
        return 1;
    }

    // Distance to true of ineq after a variable with coefficient coeff
    // changes from old_value to new_value.
    int64_t sls::dtt(bool sign, ineq const& ineq, int64_t coeff, int64_t old_value, int64_t new_value) const {
        return dtt(sign, ineq.m_args_value + coeff * (new_value - old_value), ineq);
    }

    // Critical move for v in ineq: looks up the coefficient of v and delegates.
    // Returns false when v does not occur in ineq.
    bool sls::cm(bool old_sign, ineq const& ineq, var_t v, int64_t& new_value) {
        for (auto const& [coeff, w] : ineq.m_args)
            if (w == v)
                return cm(old_sign, ineq, v, coeff, new_value);
        return false;
    }

    // Critical move: compute in new_value a value for v (occurring in ineq with
    // coefficient coeff) that flips the truth value of ineq.
    // old_sign is the current sign of the literal; ineq is required to be true
    // exactly when old_sign is false.
    // Returns false when no such value is produced (equality not solvable by
    // moving v, or coeff is 0).
    bool sls::cm(bool old_sign, ineq const& ineq, var_t v, int64_t coeff, int64_t& new_value) {
        VERIFY(ineq.is_true() != old_sign);
        // to_numeral maps coefficients that are not 64-bit integers to 0;
        // a variable with coefficient 0 cannot change the inequality and
        // every computation below divides by coeff.
        if (coeff == 0)
            return false;
        auto bound = ineq.m_bound;
        auto argsv = ineq.m_args_value;
        bool solved = false;
        int64_t delta = argsv - bound;

        auto make_eq = [&]() {
            SASSERT(delta != 0);
            if (delta < 0)
                new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff;
            else
                new_value = value(v) - (delta + abs(coeff) - 1) / coeff;
            solved = argsv + coeff * (new_value - value(v)) == bound;
            if (!solved && abs(coeff) == 1) {
                verbose_stream() << "did not solve equality " << ineq << " for " << v << "\n";
                verbose_stream() << new_value << " " << value(v) << " delta " << delta << " lhs " << (argsv + coeff * (new_value - value(v))) << " bound " << bound << "\n";
                UNREACHABLE();
            }
            return solved;
        };

        auto make_diseq = [&]() {
            if (delta >= 0)
                delta++;
            else
                delta--;
            new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff;
            VERIFY(argsv + coeff * (new_value - value(v)) != bound);
            return true;
        };

        if (!old_sign) {
            switch (ineq.m_op) {
            case ineq_kind::LE:
                // args <= bound -> args > bound
                SASSERT(argsv <= bound);
                SASSERT(delta <= 0);
                --delta;
                new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff;
                VERIFY(argsv + coeff * (new_value - value(v)) > bound);
                return true;
            case ineq_kind::LT:
                // args < bound -> args >= bound
                SASSERT(argsv <= ineq.m_bound);
                SASSERT(delta <= 0);
                new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff;
                VERIFY(argsv + coeff * (new_value - value(v)) >= bound);
                return true;
            case ineq_kind::EQ:
                return make_diseq();
            case ineq_kind::NE:
                return make_eq();
            default:
                UNREACHABLE();
                break;
            }
        }
        else {
            switch (ineq.m_op) {
            case ineq_kind::LE:
                // args > bound -> args <= bound
                SASSERT(argsv > ineq.m_bound);
                SASSERT(delta > 0);
                new_value = value(v) - (delta + abs(coeff) - 1) / coeff;
                VERIFY(argsv + coeff * (new_value - value(v)) <= bound);
                return true;
            case ineq_kind::LT:
                // args >= bound -> args < bound
                SASSERT(argsv >= ineq.m_bound);
                SASSERT(delta >= 0);
                ++delta;
                new_value = value(v) - (abs(delta) + abs(coeff) - 1) / coeff;
                VERIFY(argsv + coeff * (new_value - value(v)) < bound);
                return true;
            case ineq_kind::NE:
                return make_diseq();
            case ineq_kind::EQ:
                return make_eq();
            default:
                UNREACHABLE();
                break;
            }
        }
        return false;
    }

    // flip on the first positive score
    // it could be changed to flip on maximal positive score
    // or flip on maximal non-negative score
    // or flip on first non-negative score
    //
    // Applies the critical move of the variable previously selected in
    // ineq.m_var_to_flip. Returns false when no variable was selected or
    // no critical move exists for it.
    bool sls::flip(bool sign, ineq const& ineq) {
        int64_t new_value;
        auto v = ineq.m_var_to_flip;
        if (v == UINT_MAX) {
            IF_VERBOSE(1, verbose_stream() << "no var to flip\n");
            return false;
        }
        if (!cm(sign, ineq, v, new_value)) {
            verbose_stream() << "no critical move for " << v << "\n";
            return false;
        }
        update(v, new_value);
        return true;
    }

    //
    // dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c)
    // TODO - use cached dts instead of computed dts
    // cached dts has to be updated when the score of literals are updated.
    //
    double sls::dscore(var_t v, int64_t new_value) const {
        double score = 0;
        auto const& vi = m_vars[v];
        for (auto const& [coeff, bv] : vi.m_bool_vars) {
            sat::literal lit(bv, false);
            for (auto cl : m_bool_search->get_use_list(lit))
                score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl);
            for (auto cl : m_bool_search->get_use_list(~lit))
                score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl);
        }
        return score;
    }

    //
    // cm_score is costly. It involves several cache misses.
    // Note that
    // - m_bool_search->get_use_list(lit).size() is "often" 1 or 2
    // - dtt_old can be saved
    //
    // Score of moving v to new_value: clauses that become true are counted
    // positively, clauses whose single true literal becomes false negatively.
    //
    int sls::cm_score(var_t v, int64_t new_value) {
        int score = 0;
        auto& vi = m_vars[v];
        int64_t old_value = vi.m_value;
        for (auto const& [coeff, bv] : vi.m_bool_vars) {
            auto const& ineq = *atom(bv);
            bool old_sign = sign(bv);
            int64_t dtt_old = dtt(old_sign, ineq);
            int64_t dtt_new = dtt(old_sign, ineq, coeff, old_value, new_value);
            // truth value of the literal is unchanged by the move
            if ((dtt_old == 0) == (dtt_new == 0))
                continue;
            sat::literal lit(bv, old_sign);
            if (dtt_old == 0)
                // flip from true to false
                lit.neg();

            // lit flips from false to true:
            for (auto cl : m_bool_search->get_use_list(lit)) {
                auto const& clause = get_clause_info(cl);
                if (!clause.is_true())
                    ++score;
            }
            // ignore the situation where clause contains multiple literals using v
            for (auto cl : m_bool_search->get_use_list(~lit)) {
                auto const& clause = get_clause_info(cl);
                if (clause.m_num_trues == 1)
                    --score;
            }
        }
        return score;
    }

    // Distance to satisfaction of clause cl: the minimum distance to true over
    // its arithmetic literals. Literals without an atom are skipped; the result
    // is 1 when the clause has no arithmetic literal.
    int64_t sls::compute_dts(unsigned cl) const {
        int64_t d(1), d2;
        bool first = true;
        for (auto a : get_clause(cl)) {
            auto const* ineq = atom(a.var());
            if (!ineq)
                continue;
            d2 = dtt(a.sign(), *ineq);
            if (first)
                d = d2, first = false;
            else
                d = std::min(d, d2);
            if (d == 0)
                break;
        }
        return d;
    }

    // Same as compute_dts, evaluated as if v had the value new_value.
    int64_t sls::dts(unsigned cl, var_t v, int64_t new_value) const {
        int64_t d(1), d2;
        bool first = true;
        for (auto lit : get_clause(cl)) {
            auto const* ineq = atom(lit.var());
            if (!ineq)
                continue;
            d2 = dtt(lit.sign(), *ineq, v, new_value);
            if (first)
                d = d2, first = false;
            else
                d = std::min(d, d2);
            if (d == 0)
                break;
        }
        return d;
    }

    // Assign new_value to v: refresh the cached left-hand side of every atom
    // that mentions v and flip the Boolean variable of each atom whose current
    // literal is no longer true.
    void sls::update(var_t v, int64_t new_value) {
        auto& vi = m_vars[v];
        auto old_value = vi.m_value;
        for (auto const& [coeff, bv] : vi.m_bool_vars) {
            auto& ineq = *atom(bv);
            bool old_sign = sign(bv);
            sat::literal lit(bv, old_sign);
            SASSERT(is_true(lit));
            ineq.m_args_value += coeff * (new_value - old_value);
            int64_t dtt_new = dtt(old_sign, ineq);
            if (dtt_new != 0)
                m_bool_search->flip(bv);
            SASSERT(dtt(sign(bv), ineq) == 0);
        }
        vi.m_value = new_value;
    }

    // Create one record per theory variable, seeded with the current LP value
    // for registered variables and 0 otherwise.
    void sls::add_vars() {
        SASSERT(m_vars.empty());
        for (unsigned v = 0; v < s.get_num_vars(); ++v) {
            int64_t value = s.is_registered_var(v) ? to_numeral(s.get_ivalue(v).x) : 0;
            auto k = s.is_int(v) ? sls::var_kind::INT : sls::var_kind::REAL;
            m_vars.push_back({ value, value, k, {} });
        }
    }

    // Allocate an atom "args op bound" with no arguments yet.
    sls::ineq& sls::new_ineq(ineq_kind op, int64_t const& bound) {
        auto* i = alloc(ineq);
        i->m_bound = bound;
        i->m_op = op;
        return *i;
    }

    // Append the monomial c * v to ineq and record that Boolean variable bv depends on v.
    void sls::add_arg(sat::bool_var bv, ineq& ineq, int64_t const& c, var_t v) {
        ineq.m_args.push_back({ c, v });
        ineq.m_args_value += c * value(v);
        m_vars[v].m_bool_vars.push_back({ c, bv});
    }

    // Convert r to a 64-bit integer; values that are not 64-bit integers map to 0.
    int64_t sls::to_numeral(rational const& r) {
        if (r.is_int64())
            return r.get_int64();
        return 0;
    }

    // Add sign * t to ineq. A term is expanded into its monomials and
    // remembered in m_terms together with its theory variable v.
    void sls::add_args(sat::bool_var bv, ineq& ineq, lp::tv t, theory_var v, int64_t sign) {
        if (t.is_term()) {
            lp::lar_term const& term = s.lp().get_term(t);
            m_terms.push_back({t,v});
            for (lp::lar_term::ival arg : term) {
                auto t2 = s.lp().column2tv(arg.column());
                auto w = s.lp().local_to_external(t2.id());
                add_arg(bv, ineq, sign * to_numeral(arg.coeff()), w);
            }
        }
        else
            add_arg(bv, ineq, sign, s.lp().local_to_external(t.id()));
    }

    // Create the atom for Boolean variable bv, unless it already has one.
    // Bounds become LE atoms (a lower bound is negated on both sides);
    // equalities between arithmetic terms become EQ atoms l - r = 0.
    void sls::init_bool_var(sat::bool_var bv) {
        if (m_bool_vars.get(bv, nullptr))
            return;
        api_bound* b = nullptr;
        s.m_bool_var2bound.find(bv, b);
        if (b) {
            auto t = b->tv();
            rational bound = b->get_value();
            // t >= bound is encoded as -t <= -bound
            bool const should_minus = b->get_bound_kind() == lp_api::bound_kind::lower_t;
            if (should_minus)
                bound.neg();
            auto& ineq = new_ineq(sls::ineq_kind::LE, to_numeral(bound));
            add_args(bv, ineq, t, b->get_var(), should_minus ? -1 : 1);
            m_bool_vars.set(bv, &ineq);
            m_bool_search->set_external(bv);
            return;
        }

        expr* e = s.bool_var2expr(bv);
        expr* l = nullptr, * r = nullptr;
        if (e && m.is_eq(e, l, r) && s.a.is_int_real(l)) {
            theory_var u = s.get_th_var(l);
            theory_var v = s.get_th_var(r);
            lp::tv tu = s.get_tv(u);
            lp::tv tv = s.get_tv(v);
            auto& ineq = new_ineq(sls::ineq_kind::EQ, 0);
            add_args(bv, ineq, tu, u, 1);
            add_args(bv, ineq, tv, v, -1);
            m_bool_vars.set(bv, &ineq);
            m_bool_search->set_external(bv);
            return;
        }
    }

    // Flip v in the Boolean search when its assignment disagrees with the
    // current truth value of its atom.
    void sls::init_bool_var_assignment(sat::bool_var v) {
        auto* ineq = m_bool_vars.get(v, nullptr);
        if (ineq && is_true(sat::literal(v, false)) != (dtt(false, *ineq) == 0))
            m_bool_search->flip(v);
    }

    void sls::init_search() {
        on_restart();
    }

    void sls::finish_search() {
        store_best_values();
    }

    // Make the currently false literal over v true by moving an arithmetic variable.
    void sls::flip(sat::bool_var v) {
        sat::literal lit(v, !sign(v));
        SASSERT(!is_true(lit));
        auto const* ineq = atom(v);
        if (!ineq) {
            IF_VERBOSE(0, verbose_stream() << "no inequality for variable " << v << "\n");
            return;
        }
        SASSERT(ineq->is_true() == lit.sign());
        flip(sign(v), *ineq);
    }

    double sls::reward(sat::bool_var v) {
        if (m_dscore_mode)
            return dscore_reward(v);
        else
            return dtt_reward(v);
    }

    // For every variable x of the atom of bv0 that admits a critical move, sum
    // the Boolean rewards of all atoms mentioning x. The variable with the
    // largest sum is stored in m_var_to_flip and that sum is returned.
    // Returns -1 when bv0 has no atom or no variable admits a critical move.
    double sls::dtt_reward(sat::bool_var bv0) {
        bool sign0 = sign(bv0);
        auto* ineq = atom(bv0);
        if (!ineq)
            return -1;
        int64_t new_value;
        double max_result = -1;
        for (auto const& [coeff, x] : ineq->m_args) {
            if (!cm(sign0, *ineq, x, coeff, new_value))
                continue;
            double result = 0;
            for (auto const& [bcoeff, bv] : m_vars[x].m_bool_vars)
                result += m_bool_search->reward(bv);
            if (result > max_result) {
                max_result = result;
                ineq->m_var_to_flip = x;
            }
        }
        return max_result;
    }

    // Reward based on dscore. Leaves dscore mode on entry, so the mode is used
    // for a single reward computation after each rescale.
    // Selects the first variable with a positive dscore for its critical move.
    double sls::dscore_reward(sat::bool_var bv) {
        m_dscore_mode = false;
        bool old_sign = sign(bv);
        auto* ineq = atom(bv);
        if (!ineq)
            return 0;
        SASSERT(ineq->is_true() != old_sign);
        int64_t new_value;

        for (auto const& [coeff, v] : ineq->m_args) {
            double result = 0;
            if (cm(old_sign, *ineq, v, coeff, new_value))
                result = dscore(v, new_value);
            // just pick first positive, or pick a max?
            if (result > 0) {
                ineq->m_var_to_flip = v;
                return result;
            }
        }
        return 0;
    }

    // switch to dscore mode
    void sls::on_rescale() {
        m_dscore_mode = true;
    }

    void sls::on_save_model() {
        save_best_values();
    }

    void sls::on_restart() {
        for (unsigned v = 0; v < s.s().num_vars(); ++v)
            init_bool_var_assignment(v);

        check_ineqs();
    }

    // Verify, for every Boolean variable with an atom, that the literal that is
    // currently assigned true has distance to true 0.
    void sls::check_ineqs() {
        auto check_bool_var = [&](sat::bool_var bv) {
            auto const* ineq = atom(bv);
            if (!ineq)
                return;
            int64_t d = dtt(sign(bv), *ineq);
            sat::literal lit(bv, sign(bv));
            if (is_true(lit) != (d == 0)) {
                verbose_stream() << "invalid assignment " << bv << " " << *ineq << "\n";
            }
            VERIFY(is_true(lit) == (d == 0));
        };
        // iterate over Boolean variables of the SAT solver, not theory variables
        for (sat::bool_var bv = 0; bv < s.s().num_vars(); ++bv)
            check_bool_var(bv);
    }

    // Print every atom followed by the current and best value of every
    // non-Boolean theory variable.
    std::ostream& sls::display(std::ostream& out) const {
        for (bool_var bv = 0; bv < s.s().num_vars(); ++bv) {
            auto const* ineq = atom(bv);
            if (!ineq)
                continue;
            out << bv << " " << *ineq << "\n";
        }
        for (unsigned v = 0; v < s.get_num_vars(); ++v) {
            if (s.is_bool(v))
                continue;
            out << "v" << v << " := " << m_vars[v].m_value << " " << m_vars[v].m_best_value << "\n";
        }
        return out;
    }

}