From c5e33b79b580ddf228c508549fc93bdbd855a0d6 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 18 Feb 2023 14:11:42 -0800 Subject: [PATCH] wip - arith sls overhaul to tier inequalities with Boolean variables instead of literals --- src/sat/sat_ddfw.cpp | 23 +- src/sat/sat_ddfw.h | 4 +- src/sat/sat_solver.cpp | 18 +- src/sat/smt/arith_sls.cpp | 450 ++++++++++++++++++++----------- src/sat/smt/arith_sls.h | 32 +-- src/sat/smt/arith_solver.cpp | 8 +- src/sat/smt/euf_local_search.cpp | 7 + 7 files changed, 352 insertions(+), 190 deletions(-) diff --git a/src/sat/sat_ddfw.cpp b/src/sat/sat_ddfw.cpp index 64589dc6e..329d783f4 100644 --- a/src/sat/sat_ddfw.cpp +++ b/src/sat/sat_ddfw.cpp @@ -98,15 +98,17 @@ namespace sat { template bool ddfw::do_flip() { - bool_var v = pick_var(); - return apply_flip(v); + double reward = 0; + bool_var v = pick_var(reward); + return apply_flip(v, reward); } template - bool ddfw::apply_flip(bool_var v) { - if (v == null_bool_var) + bool ddfw::apply_flip(bool_var v, double reward) { + if (v == null_bool_var) return false; - if (reward(v) > 0 || (reward(v) == 0 && m_rand(100) <= m_config.m_use_reward_zero_pct)) { + + if (reward > 0 || (reward == 0 && m_rand(100) <= m_config.m_use_reward_zero_pct)) { if (uses_plugin && is_external(v)) m_plugin->flip(v); else @@ -119,10 +121,9 @@ namespace sat { } template - bool_var ddfw::pick_var() { + bool_var ddfw::pick_var(double& r) { double sum_pos = 0; unsigned n = 1; - double r; bool_var v0 = null_bool_var; for (bool_var v : m_unsat_vars) { r = uses_plugin ? plugin_reward(v) : reward(v); @@ -142,16 +143,18 @@ namespace sat { } } } + r = 0; if (v0 != null_bool_var) return v0; if (m_unsat_vars.empty()) - return 0; + return null_bool_var; return m_unsat_vars.elem_at(m_rand(m_unsat_vars.size())); } template bool ddfw::do_literal_flip() { - return apply_flip(pick_literal_var()); + double reward = 1; + return apply_flip(pick_literal_var(), reward); } /* @@ -414,7 +417,7 @@ namespace sat { bool ddfw::should_restart() { return m_flips >= m_restart_next; } - + void ddfw::do_restart() { reinit_values(); init_clause_data(); diff --git a/src/sat/sat_ddfw.h b/src/sat/sat_ddfw.h index 654910111..8c4f9287f 100644 --- a/src/sat/sat_ddfw.h +++ b/src/sat/sat_ddfw.h @@ -197,10 +197,10 @@ namespace sat { bool do_flip(); template - bool_var pick_var(); + bool_var pick_var(double& reward); template - bool apply_flip(bool_var v); + bool apply_flip(bool_var v, double reward); template bool do_literal_flip(); diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index c570843b5..5ac0fbcf3 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -1363,7 +1363,13 @@ namespace sat { if (m_ext) { verbose_stream() << "bounded local search\n"; do_restart(true); - m_ext->local_search(m_best_phase); + lbool r = m_ext->local_search(m_best_phase); + verbose_stream() << r << "\n"; + if (r == l_true) { + m_conflicts_since_restart = 0; + m_conflicts_since_gc = 0; + m_next_simplify = std::max(m_next_simplify, m_conflicts_since_init + 1); + } return; } literal_vector _lits; @@ -1728,6 +1734,8 @@ namespace sat { push(); m_stats.m_decision++; + CTRACE("sat", m_best_phase[next] != guess(next), tout << "phase " << phase << " " << m_best_phase[next] << " " << guess(next) << "\n"); + if (phase == l_undef) phase = guess(next) ? l_true: l_false; @@ -1738,12 +1746,12 @@ namespace sat { m_case_split_queue.unassign_var_eh(next); next_lit = literal(next, false); } - + if (phase == l_undef) is_pos = guess(next); else is_pos = phase == l_true; - + if (!is_pos) next_lit.neg(); @@ -2966,7 +2974,7 @@ namespace sat { } bool solver::should_rephase() { - return m_conflicts_since_init > m_rephase_lim; + return m_conflicts_since_init > 5 && m_conflicts_since_init > m_rephase_lim; } void solver::do_rephase() { @@ -3015,7 +3023,7 @@ namespace sat { UNREACHABLE(); break; } - m_rephase_inc += m_config.m_rephase_base; + m_rephase_inc = m_config.m_rephase_base; m_rephase_lim += m_rephase_inc; } diff --git a/src/sat/smt/arith_sls.cpp b/src/sat/smt/arith_sls.cpp index 54141e635..aeddd319b 100644 --- a/src/sat/smt/arith_sls.cpp +++ b/src/sat/smt/arith_sls.cpp @@ -24,7 +24,7 @@ namespace arith { s(s), m(s.m) {} void sls::reset() { - m_literals.reset(); + m_bool_vars.reset(); m_vars.reset(); m_terms.reset(); } @@ -36,6 +36,21 @@ namespace arith { void sls::save_best_values() { for (unsigned v = 0; v < s.get_num_vars(); ++v) m_vars[v].m_best_value = m_vars[v].m_value; + + auto check_bool_var = [&](sat::bool_var bv) { + auto const* ineq = atom(bv); + if (!ineq) + return; + sat::literal lit(bv, !m_bool_search->get_value(bv)); + int64_t d = dtt(lit.sign(), *ineq); + // verbose_stream() << "check " << lit << " " << *ineq << "\n"; + if (is_true(lit) != (d == 0)) { + verbose_stream() << lit << " " << *ineq << "\n"; + } + VERIFY(is_true(lit) == (d == 0)); + }; + for (unsigned v = 0; v < s.get_num_vars(); ++v) + check_bool_var(v); } void sls::store_best_values() { @@ -47,7 +62,7 @@ namespace arith { for (lp::lar_term::ival arg : term) { auto t2 = s.lp().column2tv(arg.column()); auto w = s.lp().local_to_external(t2.id()); - val += to_numeral(arg.coeff()) * value(w); + val += to_numeral(arg.coeff()) * m_vars[w].m_best_value; } update(v, val); } @@ -55,14 +70,12 @@ namespace arith { for (unsigned v = 0; v < s.get_num_vars(); ++v) { if (s.is_bool(v)) continue; - if (!s.lp().external_is_used(v)) - continue; + if (!s.lp().external_is_used(v)) + continue; int64_t old_value = 0; if (s.is_registered_var(v)) old_value = to_numeral(s.get_ivalue(v).x); int64_t new_value = m_vars[v].m_best_value; - if (old_value == new_value) - continue; s.ensure_column(v); lp::column_index vj = s.lp().to_column_index(v); SASSERT(!vj.is_null()); @@ -73,40 +86,98 @@ namespace arith { // TODO - figure out why this leads to unsound (unsat). } } + + lbool r = s.make_feasible(); + VERIFY (!unsat().empty() || r == l_true); + if (unsat().empty()) { + s.m_num_conflicts = s.get_config().m_arith_propagation_threshold; + } + verbose_stream() << "has changed " << s.m_solver->has_changed_columns() << "\n"; + + auto check_bool_var = [&](sat::bool_var bv) { + auto* ineq = m_bool_vars.get(bv, nullptr); + if (!ineq) + return; + api_bound* b = nullptr; + s.m_bool_var2bound.find(bv, b); + if (!b) + return; + auto bound = b->get_value(); + theory_var v = b->get_var(); + if (s.get_phase(bv) == m_bool_search->get_model()[bv]) + return; + switch (b->get_bound_kind()) { + case lp_api::lower_t: + verbose_stream() << bv << " " << bound << " <= " << s.get_value(v) << "\n"; + break; + case lp_api::upper_t: + verbose_stream() << bv << " " << bound << " >= " << s.get_value(v) << "\n"; + break; + } + int64_t value = 0; + for (auto const& [coeff, v] : ineq->m_args) { + value += coeff * m_vars[v].m_best_value; + } + ineq->m_args_value = value; + verbose_stream() << *ineq << " dtt " << dtt(false, *ineq) << " phase " << s.get_phase(bv) << " model " << m_bool_search->get_model()[bv] << "\n"; + }; + + if (unsat().empty()) { + for (bool_var v = 0; v < s.s().num_vars(); ++v) + check_bool_var(v); + } } void sls::set(sat::ddfw* d) { m_bool_search = d; reset(); - m_literals.reserve(s.s().num_vars() * 2); + m_bool_vars.reserve(s.s().num_vars()); add_vars(); for (unsigned i = 0; i < d->num_clauses(); ++i) for (sat::literal lit : *d->get_clause_info(i).m_clause) - init_literal(lit); + init_bool_var(lit.var()); for (unsigned v = 0; v < s.s().num_vars(); ++v) init_bool_var_assignment(v); - m_best_min_unsat = std::numeric_limits::max(); d->set(this); } - // distance to true - int64_t sls::dtt(int64_t args, ineq const& ineq) const { + int64_t sls::dtt(bool sign, int64_t args, ineq const& ineq) const { switch (ineq.m_op) { case ineq_kind::LE: + if (sign) { + if (args <= ineq.m_bound) + return ineq.m_bound - args + 1; + return 0; + } if (args <= ineq.m_bound) return 0; return args - ineq.m_bound; case ineq_kind::EQ: + if (sign) { + if (args == ineq.m_bound) + return 1; + return 0; + } if (args == ineq.m_bound) return 0; return 1; case ineq_kind::NE: + if (sign) { + if (args == ineq.m_bound) + return 0; + return 1; + } if (args == ineq.m_bound) return 1; return 0; case ineq_kind::LT: + if (sign) { + if (args < ineq.m_bound) + return ineq.m_bound - args; + return 0; + } if (args < ineq.m_bound) return 0; return args - ineq.m_bound + 1; @@ -121,45 +192,95 @@ namespace arith { // m_vars[w].m_value can be computed outside and shared among calls // different data-structures for storing coefficients // - int64_t sls::dtt(ineq const& ineq, var_t v, int64_t new_value) const { - auto new_args_value = ineq.m_args_value; - for (auto const& [coeff, w] : ineq.m_args) { - if (w == v) { - new_args_value += coeff * (new_value - m_vars[w].m_value); + int64_t sls::dtt(bool sign, ineq const& ineq, var_t v, int64_t new_value) const { + for (auto const& [coeff, w] : ineq.m_args) + if (w == v) + return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq); + return 1; + } + + int64_t sls::dtt(bool sign, ineq const& ineq, int64_t coeff, int64_t old_value, int64_t new_value) const { + return dtt(sign, ineq.m_args_value + coeff * (new_value - old_value), ineq); + } + + bool sls::cm(bool sign, ineq const& ineq, var_t v, int64_t& new_value) { + for (auto const& [coeff, w] : ineq.m_args) + if (w == v) + return cm(sign, ineq, v, coeff, new_value); + return false; + } + + bool sls::cm(bool sign, ineq const& ineq, var_t v, int64_t coeff, int64_t& new_value) { + VERIFY(ineq.is_true() == sign); + verbose_stream() << "cm " << ineq << " for " << v << " sign " << sign << "\n"; + auto bound = ineq.m_bound; + auto argsv = ineq.m_args_value; + bool solved = false; + int64_t delta = argsv - bound; + if (sign) { + switch (ineq.m_op) { + case ineq_kind::LE: + SASSERT(argsv <= bound); + SASSERT(delta <= 0); + delta--; + new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; + VERIFY(argsv + coeff * (new_value - value(v)) > bound); + return true; + case ineq_kind::LT: + SASSERT(argsv <= ineq.m_bound); + SASSERT(delta <= 0); + new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; + VERIFY(argsv + coeff * (new_value - value(v)) >= bound); + return true; + case ineq_kind::EQ: + if (delta >= 0) + delta++; + else + delta--; + new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; + VERIFY(argsv + coeff * (new_value - value(v)) != bound); + return true; + case ineq_kind::NE: + new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; + solved = argsv + coeff * (new_value - value(v)) == bound; + if (!solved) verbose_stream() << "did not solve disequality " << ineq << " for " << v << "\n"; + return solved; + default: + UNREACHABLE(); break; } } - return dtt(new_args_value, ineq); - } - - // critical move - bool sls::cm(ineq const& ineq, var_t v, int64_t& new_value) { - SASSERT(!ineq.is_true()); - int64_t delta = ineq.m_args_value - ineq.m_bound; - if (ineq.m_op == ineq_kind::NE || ineq.m_op == ineq_kind::LT) - delta--; - for (auto const& [coeff, w] : ineq.m_args) { - if (w == v) { - - if (coeff > 0) - new_value = value(v) - abs((delta + coeff - 1)/ coeff); + else { + switch (ineq.m_op) { + case ineq_kind::LE: + SASSERT(argsv > ineq.m_bound); + SASSERT(delta > 0); + new_value = value(v) - (delta + abs(coeff) - 1) / coeff; + VERIFY(argsv + coeff * (new_value - value(v)) <= bound); + return true; + case ineq_kind::LT: + SASSERT(argsv >= ineq.m_bound); + SASSERT(delta >= 0); + ++delta; + new_value = value(v) - (abs(delta) + abs(coeff) - 1) / coeff; + VERIFY(argsv + coeff * (new_value - value(v)) < bound); + return true; + case ineq_kind::NE: + if (delta >= 0) + delta++; else - new_value = value(v) + abs(delta) / -coeff; - - switch (ineq.m_op) { - case ineq_kind::LE: - SASSERT(delta + coeff * (new_value - value(v)) <= 0); - return true; - case ineq_kind::EQ: - return delta + coeff * (new_value - value(v)) == 0; - case ineq_kind::NE: - return delta + coeff * (new_value - value(v)) != 0; - case ineq_kind::LT: - return delta + coeff * (new_value - value(v)) < 0; - default: - UNREACHABLE(); - break; - } + delta--; + new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; + VERIFY(argsv + coeff * (new_value - value(v)) != bound); + return true; + case ineq_kind::EQ: + new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; + solved = argsv + coeff * (new_value - value(v)) == bound; + if (!solved) verbose_stream() << "did not solve equality " << ineq << " for " << v << "\n"; + return solved; + default: + UNREACHABLE(); + break; } } return false; @@ -169,23 +290,19 @@ namespace arith { // it could be changed to flip on maximal positive score // or flip on maximal non-negative score // or flip on first non-negative score - bool sls::flip(ineq const& ineq) { + bool sls::flip(bool sign, ineq const& ineq) { int64_t new_value; - for (auto const& [coeff, v] : ineq.m_args) { - if (!cm(ineq, v, new_value)) - continue; - int score = cm_score(v, new_value); - if (score <= 0) - continue; - unsigned num_unsat = unsat().size(); - update(v, new_value); - IF_VERBOSE(2, - verbose_stream() << "v" << v << " score " << score << " " - << num_unsat << " -> " << unsat().size() << "\n"); - SASSERT(num_unsat > unsat().size()); - return true; + auto v = ineq.m_var_to_flip; + if (v == UINT_MAX) { + verbose_stream() << "no var to flip\n"; + return false; } - return false; + if (!cm(sign, ineq, v, new_value)) { + verbose_stream() << "no critical move for " << v << "\n"; + return false; + } + update(v, new_value); + return true; } // @@ -195,10 +312,13 @@ namespace arith { // double sls::dscore(var_t v, int64_t new_value) const { auto const& vi = m_vars[v]; + verbose_stream() << "dscore " << v << "\n"; double score = 0; +#if 0 for (auto const& [coeff, lit] : vi.m_literals) for (auto cl : m_bool_search->get_use_list(lit)) - score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl); + score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl); +#endif return score; } @@ -212,25 +332,28 @@ namespace arith { int score = 0; auto& vi = m_vars[v]; int64_t old_value = vi.m_value; - for (auto const& [coeff, lit] : vi.m_literals) { - auto const& ineq = *atom(lit); - int64_t dtt_old = dtt(ineq); - int64_t delta = coeff * (new_value - old_value); - int64_t dtt_new = dtt(ineq.m_args_value + delta, ineq); - - if (dtt_old == dtt_new) + for (auto const& [coeff, bv] : vi.m_bool_vars) { + auto const& ineq = *atom(bv); + bool sign = !m_bool_search->value(bv); + int64_t dtt_old = dtt(sign, ineq); + int64_t dtt_new = dtt(sign, ineq, coeff, old_value, new_value); + if ((dtt_old == 0) == (dtt_new == 0)) continue; - + sat::literal lit(bv, sign); + if (dtt_old == 0) + // flip from true to false + lit.neg(); + + // lit flips form false to true: for (auto cl : m_bool_search->get_use_list(lit)) { auto const& clause = get_clause_info(cl); - if (!clause.is_true()) { - VERIFY(dtt_old != 0); - if (dtt_new == 0) - ++score; // false -> true - } - else if (dtt_new == 0 || dtt_old > 0 || clause.m_num_trues > 1) // true -> true not really, same variable can be in multiple literals - continue; - else if (all_of(*clause.m_clause, [&](auto lit2) { return !atom(lit2) || dtt(*atom(lit2), v, new_value) > 0; })) // ?? TODO + if (!clause.is_true()) + ++score; + } + // ignore the situation where clause contains multiple literals using v + for (auto cl : m_bool_search->get_use_list(~lit)) { + auto const& clause = get_clause_info(cl); + if (clause.m_num_trues == 1) --score; } } @@ -241,10 +364,10 @@ namespace arith { int64_t d(1), d2; bool first = true; for (auto a : get_clause(cl)) { - auto const* ineq = atom(a); + auto const* ineq = atom(a.var()); if (!ineq) continue; - d2 = dtt(*ineq); + d2 = dtt(a.sign(), *ineq); if (first) d = d2, first = false; else @@ -259,10 +382,10 @@ namespace arith { int64_t d(1), d2; bool first = true; for (auto lit : get_clause(cl)) { - auto const* ineq = atom(lit); + auto const* ineq = atom(lit.var()); if (!ineq) continue; - d2 = dtt(*ineq, v, new_value); + d2 = dtt(lit.sign(), *ineq, v, new_value); if (first) d = d2, first = false; else @@ -275,15 +398,17 @@ namespace arith { void sls::update(var_t v, int64_t new_value) { auto& vi = m_vars[v]; - auto const& old_value = vi.m_value; - for (auto const& [coeff, lit] : vi.m_literals) { - auto& ineq = *atom(lit); + auto old_value = vi.m_value; + for (auto const& [coeff, bv] : vi.m_bool_vars) { + auto& ineq = *atom(bv); + bool sign = !m_bool_search->value(bv); + sat::literal lit(bv, sign); + SASSERT(is_true(lit)); ineq.m_args_value += coeff * (new_value - old_value); - int64_t dtt_new = dtt(ineq); - if ((dtt_new == 0) != is_true(lit)) - m_bool_search->flip(lit.var()); - - SASSERT((dtt_new == 0) == is_true(lit)); + int64_t dtt_new = dtt(sign, ineq); + if (dtt_new != 0) + m_bool_search->flip(bv); + SASSERT(dtt(!m_bool_search->value(bv), ineq) == 0); } vi.m_value = new_value; } @@ -304,10 +429,10 @@ namespace arith { return *i; } - void sls::add_arg(sat::literal lit, ineq& ineq, int64_t const& c, var_t v) { + void sls::add_arg(sat::bool_var bv, ineq& ineq, int64_t const& c, var_t v) { ineq.m_args.push_back({ c, v }); ineq.m_args_value += c * value(v); - m_vars[v].m_literals.push_back({ c, lit }); + m_vars[v].m_bool_vars.push_back({ c, bv}); } int64_t sls::to_numeral(rational const& r) { @@ -316,79 +441,63 @@ namespace arith { return 0; } - - void sls::add_args(sat::literal lit, ineq& ineq, lp::tv t, theory_var v, int64_t sign) { + void sls::add_args(sat::bool_var bv, ineq& ineq, lp::tv t, theory_var v, int64_t sign) { if (t.is_term()) { lp::lar_term const& term = s.lp().get_term(t); for (lp::lar_term::ival arg : term) { auto t2 = s.lp().column2tv(arg.column()); auto w = s.lp().local_to_external(t2.id()); - add_arg(lit, ineq, sign * to_numeral(arg.coeff()), w); + add_arg(bv, ineq, sign * to_numeral(arg.coeff()), w); } } else - add_arg(lit, ineq, sign, s.lp().local_to_external(t.id())); + add_arg(bv, ineq, sign, s.lp().local_to_external(t.id())); } - - void sls::init_literal(sat::literal lit) { - if (m_literals.get(lit.index(), nullptr)) + void sls::init_bool_var(sat::bool_var bv) { + if (m_bool_vars.get(bv, nullptr)) return; api_bound* b = nullptr; - s.m_bool_var2bound.find(lit.var(), b); + s.m_bool_var2bound.find(bv, b); if (b) { auto t = b->tv(); rational bound = b->get_value(); bool should_minus = false; sls::ineq_kind op; - if (!lit.sign()) { - should_minus = b->get_bound_kind() == lp_api::bound_kind::upper_t; - op = sls::ineq_kind::LE; - } - else { - should_minus = b->get_bound_kind() == lp_api::bound_kind::lower_t; - if (s.is_int(b->get_var())) { - bound -= 1; - op = sls::ineq_kind::LE; - } - else - op = sls::ineq_kind::LT; - - } + should_minus = b->get_bound_kind() == lp_api::bound_kind::lower_t; + op = sls::ineq_kind::LE; if (should_minus) bound.neg(); + auto& ineq = new_ineq(op, to_numeral(bound)); - add_args(lit, ineq, t, b->get_var(), should_minus ? -1 : 1); - m_literals.set(lit.index(), &ineq); + add_args(bv, ineq, t, b->get_var(), should_minus ? -1 : 1); + m_bool_vars.set(bv, &ineq); + m_bool_search->set_external(bv); return; } - expr* e = s.bool_var2expr(lit.var()); + expr* e = s.bool_var2expr(bv); expr* l = nullptr, * r = nullptr; if (e && m.is_eq(e, l, r) && s.a.is_int_real(l)) { theory_var u = s.get_th_var(l); theory_var v = s.get_th_var(r); lp::tv tu = s.get_tv(u); lp::tv tv = s.get_tv(v); - auto& ineq = new_ineq(lit.sign() ? sls::ineq_kind::NE : sls::ineq_kind::EQ, 0); - add_args(lit, ineq, tu, u, 1); - add_args(lit, ineq, tv, v, -1); - m_literals.set(lit.index(), &ineq); + auto& ineq = new_ineq(sls::ineq_kind::EQ, 0); + add_args(bv, ineq, tu, u, 1); + add_args(bv, ineq, tv, v, -1); + m_bool_vars.set(bv, &ineq); + m_bool_search->set_external(bv); return; } } void sls::init_bool_var_assignment(sat::bool_var v) { - init_literal_assignment(literal(v, false)); - init_literal_assignment(literal(v, true)); - } - - void sls::init_literal_assignment(sat::literal lit) { - auto* ineq = m_literals.get(lit.index(), nullptr); - if (ineq && is_true(lit) != (dtt(*ineq) == 0)) - m_bool_search->flip(lit.var()); + auto* ineq = m_bool_vars.get(v, nullptr); + if (ineq && is_true(sat::literal(v, false)) != (dtt(false, *ineq) == 0)) + m_bool_search->flip(v); } void sls::init_search() { @@ -402,14 +511,13 @@ namespace arith { void sls::flip(sat::bool_var v) { sat::literal lit(v, m_bool_search->get_value(v)); SASSERT(!is_true(lit)); - auto const* ineq = atom(lit); + auto const* ineq = atom(v); if (!ineq) IF_VERBOSE(0, verbose_stream() << "no inequality for variable " << v << "\n"); if (!ineq) return; - IF_VERBOSE(1, verbose_stream() << "flip " << lit << "\n"); - SASSERT(!ineq->is_true()); - flip(*ineq); + SASSERT(ineq->is_true() == lit.sign()); + flip(!lit.sign(), *ineq); } double sls::reward(sat::bool_var v) { @@ -419,39 +527,56 @@ namespace arith { return dtt_reward(v); } - double sls::dtt_reward(sat::bool_var v) { - sat::literal litv(v, m_bool_search->get_value(v)); - auto const* ineq = atom(litv); + double sls::dtt_reward(sat::bool_var bv0) { + bool sign0 = !m_bool_search->get_value(bv0); + auto* ineq = atom(bv0); if (!ineq) - return 0; - int64_t new_value; + return -1; + int64_t new_value; double result = 0; + double max_result = -1; + theory_var max_var = 0; for (auto const & [coeff, x] : ineq->m_args) { - if (!cm(*ineq, x, new_value)) + if (!cm(!sign0, *ineq, x, coeff, new_value)) continue; - for (auto const [coeff, lit] : m_vars[x].m_literals) { - auto dtt_old = dtt(*atom(lit)); - auto dtt_new = dtt(*atom(lit), x, new_value); + double result = 0; + auto old_value = m_vars[x].m_value; + for (auto const [coeff, bv] : m_vars[x].m_bool_vars) { + bool sign = !m_bool_search->value(bv); + auto dtt_old = dtt(sign, *atom(bv)); + auto dtt_new = dtt(sign, *atom(bv), coeff, old_value, new_value); if ((dtt_new == 0) != (dtt_old == 0)) - result += m_bool_search->reward(lit.var()); + result += m_bool_search->reward(bv); + } + if (result > max_result) { + max_result = result; + ineq->m_var_to_flip = x; } } - return result; + return max_result; } - double sls::dscore_reward(sat::bool_var x) { + double sls::dscore_reward(sat::bool_var bv) { m_dscore_mode = false; - sat::literal litv(x, m_bool_search->get_value(x)); - auto const* ineq = atom(litv); + bool sign = !m_bool_search->get_value(bv); + sat::literal litv(bv, sign); + auto* ineq = atom(bv); if (!ineq) return 0; - SASSERT(!ineq->is_true()); + SASSERT(ineq->is_true() == sign); int64_t new_value; - double result = 0; - for (auto const& [coeff, v] : ineq->m_args) - if (cm(*ineq, v, new_value)) - result += dscore(v, new_value); - return result; + + for (auto const& [coeff, v] : ineq->m_args) { + double result = 0; + if (cm(sign, *ineq, v, coeff, new_value)) + result = dscore(v, new_value); + // just pick first positive, or pick a max? + if (result > 0) { + ineq->m_var_to_flip = v; + return result; + } + } + return 0; } // switch to dscore mode @@ -466,5 +591,24 @@ namespace arith { void sls::on_restart() { for (unsigned v = 0; v < s.s().num_vars(); ++v) init_bool_var_assignment(v); + + verbose_stream() << "on-restart\n"; + auto check_bool_var = [&](sat::bool_var bv) { + auto const* ineq = atom(bv); + if (!ineq) + return; + bool sign = !m_bool_search->get_value(bv); + int64_t d = dtt(sign, *ineq); + sat::literal lit(bv, sign); + // verbose_stream() << "check " << lit << " " << *ineq << "\n"; + if (is_true(lit) != (d == 0)) { + verbose_stream() << "restart " << bv << " " << *ineq << "\n"; + } + VERIFY(is_true(lit) == (d == 0)); + }; + for (unsigned v = 0; v < s.get_num_vars(); ++v) + check_bool_var(v); + + verbose_stream() << "on-restart-done\n"; } } diff --git a/src/sat/smt/arith_sls.h b/src/sat/smt/arith_sls.h index 5496be637..3c9daaa51 100644 --- a/src/sat/smt/arith_sls.h +++ b/src/sat/smt/arith_sls.h @@ -60,8 +60,9 @@ namespace arith { struct ineq { vector> m_args; ineq_kind m_op = ineq_kind::LE; - int64_t m_bound; - int64_t m_args_value; + int64_t m_bound; + int64_t m_args_value; + unsigned m_var_to_flip = UINT_MAX; bool is_true() const { switch (m_op) { @@ -97,17 +98,15 @@ namespace arith { int64_t m_value; int64_t m_best_value; var_kind m_kind = var_kind::INT; - svector> m_literals; + svector> m_bool_vars; }; solver& s; ast_manager& m; sat::ddfw* m_bool_search = nullptr; - unsigned m_max_arith_steps = 0; - unsigned m_best_min_unsat = UINT_MAX; stats m_stats; config m_config; - scoped_ptr_vector m_literals; + scoped_ptr_vector m_bool_vars; vector m_vars; svector> m_terms; bool m_dscore_mode = false; @@ -122,17 +121,19 @@ namespace arith { bool is_true(sat::literal lit) { return lit.sign() != m_bool_search->get_value(lit.var()); } void reset(); - ineq* atom(sat::literal lit) const { return m_literals[lit.index()]; } + ineq* atom(sat::bool_var bv) const { return m_bool_vars[bv]; } void log(); - bool flip(ineq const& ineq); - int64_t dtt(ineq const& ineq) const { return dtt(ineq.m_args_value, ineq); } - int64_t dtt(int64_t args, ineq const& ineq) const; - int64_t dtt(ineq const& ineq, var_t v, int64_t new_value) const; + bool flip(bool sign, ineq const& ineq); + int64_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); } + int64_t dtt(bool sign, int64_t args_value, ineq const& ineq) const; + int64_t dtt(bool sign, ineq const& ineq, var_t v, int64_t new_value) const; + int64_t dtt(bool sign, ineq const& ineq, int64_t coeff, int64_t old_value, int64_t new_value) const; int64_t dts(unsigned cl, var_t v, int64_t new_value) const; int64_t compute_dts(unsigned cl) const; - bool cm(ineq const& ineq, var_t v, int64_t& new_value); + bool cm(bool sign, ineq const& ineq, var_t v, int64_t& new_value); + bool cm(bool sign, ineq const& ineq, var_t v, int64_t coeff, int64_t& new_value); int cm_score(var_t v, int64_t new_value); void update(var_t v, int64_t new_value); double dscore_reward(sat::bool_var v); @@ -142,11 +143,10 @@ namespace arith { void store_best_values(); void add_vars(); sls::ineq& new_ineq(ineq_kind op, int64_t const& bound); - void add_arg(sat::literal lit, ineq& ineq, int64_t const& c, var_t v); - void add_args(sat::literal lit, ineq& ineq, lp::tv t, euf::theory_var v, int64_t sign); - void init_literal(sat::literal lit); + void add_arg(sat::bool_var bv, ineq& ineq, int64_t const& c, var_t v); + void add_args(sat::bool_var bv, ineq& ineq, lp::tv t, euf::theory_var v, int64_t sign); + void init_bool_var(sat::bool_var v); void init_bool_var_assignment(sat::bool_var v); - void init_literal_assignment(sat::literal lit); int64_t value(var_t v) const { return m_vars[v].m_value; } int64_t to_numeral(rational const& r); diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index fd507ed15..bd5dd315f 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -101,8 +101,7 @@ namespace arith { return false; switch (lbl) { - case l_false: - TRACE("arith", tout << "propagation conflict\n";); + case l_false: get_infeasibility_explanation_and_set_conflict(); break; case l_true: @@ -382,9 +381,9 @@ namespace arith { void solver::assert_bound(bool is_true, api_bound& b) { - TRACE("arith", tout << b << "\n";); lp::constraint_index ci = b.get_constraint(is_true); lp().activate(ci); + TRACE("arith", tout << b << " " << is_infeasible() << "\n";); if (is_infeasible()) return; lp::lconstraint_kind k = bound2constraint_kind(b.is_int(), b.get_bound_kind(), is_true); @@ -1066,6 +1065,7 @@ namespace arith { TRACE("pcs", tout << lp().constraints();); auto status = lp().find_feasible_solution(); TRACE("arith_verbose", display(tout);); + TRACE("arith", tout << status << "\n"); switch (status) { case lp::lp_status::INFEASIBLE: return l_false; @@ -1202,7 +1202,7 @@ namespace arith { TRACE("arith", tout << "Lemma - " << (is_conflict ? "conflict" : "propagation") << "\n"; - for (literal c : m_core) tout << literal2expr(c) << "\n"; + for (literal c : m_core) tout << c << ": " << literal2expr(c) << "\n"; for (auto p : m_eqs) tout << ctx.bpp(p.first) << " == " << ctx.bpp(p.second) << "\n";); if (is_conflict) { diff --git a/src/sat/smt/euf_local_search.cpp b/src/sat/smt/euf_local_search.cpp index d22a65fb8..1c83b3a69 100644 --- a/src/sat/smt/euf_local_search.cpp +++ b/src/sat/smt/euf_local_search.cpp @@ -38,6 +38,13 @@ namespace euf { for (unsigned i = 0; i < mdl.size(); ++i) phase[i] = mdl[i] == l_true; + if (bool_search.unsat_set().empty()) { + enable_trace("arith"); + enable_trace("sat"); + enable_trace("euf"); + TRACE("sat", s().display(tout)); + } + return bool_search.unsat_set().empty() ? l_true : l_undef; } }