From c7ea4964f21b7ed6a0789a2e1aab5528d77887de Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 17 Nov 2024 13:07:28 -0800 Subject: [PATCH] bug fixes to sls --- src/ast/sls/sls_arith_base.cpp | 121 ++++++------------------------- src/ast/sls/sls_arith_base.h | 56 ++++++++------ src/ast/sls/sls_arith_plugin.cpp | 10 +-- src/ast/sls/sls_smt_plugin.cpp | 11 ++- src/ast/sls/sls_smt_plugin.h | 3 +- src/sat/smt/sls_solver.cpp | 6 +- src/smt/theory_sls.cpp | 16 +++- 7 files changed, 84 insertions(+), 139 deletions(-) diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index e3aa9ff01..8e68781cc 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -59,8 +59,6 @@ namespace sls { } } - - template std::ostream& arith_base::ineq::display(std::ostream& out) const { bool first = true; @@ -118,7 +116,7 @@ namespace sls { template void arith_base::save_best_values() { for (auto& v : m_vars) - v.m_best_value = v.m_value; + v.set_best_value(v.value()); check_ineqs(); } @@ -168,8 +166,8 @@ namespace sls { template num_t arith_base::dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const { for (auto const& [coeff, w] : ineq.m_args) - if (w == v) - return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq); + if (w == v) + return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].value()), ineq); return num_t(1); } @@ -444,17 +442,19 @@ namespace sls { delta_out = delta; - if (m_last_var == v && m_last_delta == -delta) - return false; + if (m_last_var == v && m_last_delta == -delta) + return false; - if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) + if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) return false; + auto old_value = value(v); auto new_value = old_value + delta; if (!vi.in_range(new_value)) return false; + if (m_use_tabu && !in_bounds(v, new_value) && in_bounds(v, old_value)) { auto const& lo = m_vars[v].m_lo; auto const& hi = m_vars[v].m_hi; @@ -490,9 +490,7 @@ namespace sls { void arith_base::add_update(var_t v, num_t delta) { num_t delta_out; if (!is_permitted_update(v, delta, delta_out)) - return; - - + return; m_updates.push_back({ v, delta_out, 0 }); } @@ -647,7 +645,7 @@ namespace sls { bool arith_base::update(var_t v, num_t const& new_value) { auto& vi = m_vars[v]; expr* e = vi.m_expr; - auto old_value = vi.m_value; + auto old_value = vi.value(); if (old_value == new_value) return true; if (!vi.in_range(new_value)) @@ -665,15 +663,10 @@ namespace sls { } } catch (overflow_exception const&) { + verbose_stream() << "overflow1\n"; return false; } -#if 0 - if (!check_update(v, new_value)) - return false; - apply_checked_update(); -#else - buffer to_flip; for (auto const& [coeff, bv] : vi.m_bool_vars) { auto& ineq = *atom(bv); @@ -687,12 +680,13 @@ namespace sls { } IF_VERBOSE(5, verbose_stream() << "repair: v" << v << " := " << old_value << " -> " << new_value << "\n"); - vi.m_value = new_value; + vi.set_value(new_value); ctx.new_value_eh(e); m_last_var = v; for (auto bv : to_flip) { - ctx.flip(bv); + if (dtt(sign(bv), *atom(bv)) != 0) + ctx.flip(bv); SASSERT(dtt(sign(bv), *atom(bv)) == 0); } @@ -711,6 +705,7 @@ namespace sls { prod *= power_of(value(w), p); } catch (overflow_exception const&) { + verbose_stream() << "overflow\n"; return false; } if (value(w) != prod && !update(w, prod)) @@ -727,82 +722,10 @@ namespace sls { if (!update(ad.m_var, sum)) return false; } -#endif return true; } - - template - bool arith_base::check_update(var_t v, num_t new_value) { - - ++m_update_timestamp; - if (m_update_timestamp == 0) { - for (auto& vi : m_vars) - vi.set_update_value(num_t(0), 0); - ++m_update_timestamp; - } - auto& vi = m_vars[v]; - m_update_trail.reset(); - m_update_trail.push_back(v); - vi.set_update_value(new_value, m_update_timestamp); - - num_t delta; - for (unsigned i = 0; i < m_update_trail.size(); ++i) { - auto v = m_update_trail[i]; - auto& vi = m_vars[v]; - for (auto idx : vi.m_muls) { - auto const& [w, monomial] = m_muls[idx]; - num_t prod(1); - try { - for (auto [w, p] : monomial) - prod *= power_of(get_update_value(w), p); - } - catch (overflow_exception const&) { - return false; - } - if (get_update_value(w) != prod && (!is_permitted_update(w, prod - value(w), delta) || prod - value(w) != delta)) - return false; - m_update_trail.push_back(w); - m_vars[w].set_update_value(prod, m_update_timestamp); - } - - for (auto idx : vi.m_adds) { - auto const& ad = m_adds[idx]; - auto w = ad.m_var; - num_t sum(ad.m_coeff); - for (auto const& [coeff, w] : ad.m_args) - sum += coeff * get_update_value(w); - if (get_update_value(v) != sum && !(is_permitted_update(w, sum - value(w), delta) || sum - value(w) != delta)) - return false; - m_update_trail.push_back(w); - m_vars[w].set_update_value(sum, m_update_timestamp); - } - } - return true; - } - - template - void arith_base::apply_checked_update() { - for (auto v : m_update_trail) { - auto & vi = m_vars[v]; - auto old_value = vi.m_value; - vi.m_value = vi.get_update_value(m_update_timestamp); - auto new_value = vi.m_value; - ctx.new_value_eh(vi.m_expr); - for (auto const& [coeff, bv] : vi.m_bool_vars) { - auto& ineq = *atom(bv); - bool old_sign = sign(bv); - sat::literal lit(bv, old_sign); - SASSERT(ctx.is_true(lit)); - ineq.m_args_value += coeff * (new_value - old_value); - num_t dtt_new = dtt(old_sign, ineq); - if (dtt_new != 0) - ctx.flip(bv); - SASSERT(dtt(sign(bv), ineq) == 0); - } - } - } - + template typename arith_base::ineq& arith_base::new_ineq(ineq_kind op, num_t const& coeff) { auto* i = alloc(ineq); @@ -906,7 +829,7 @@ namespace sls { m_vars[w].m_muls.push_back(idx), prod *= power_of(value(w), p); m_vars[v].m_def_idx = idx; m_vars[v].m_op = arith_op_kind::OP_MUL; - m_vars[v].m_value = prod; + m_vars[v].set_value(prod); add_arg(term, coeff, v); break; } @@ -972,7 +895,7 @@ namespace sls { m_ops.push_back({v, k, v, w}); m_vars[v].m_def_idx = idx; m_vars[v].m_op = k; - m_vars[v].m_value = val; + m_vars[v].set_value(val); return v; } @@ -993,7 +916,7 @@ namespace sls { m_vars[w].m_adds.push_back(idx), sum += c * value(w); m_vars[v].m_def_idx = idx; m_vars[v].m_op = arith_op_kind::OP_ADD; - m_vars[v].m_value = sum; + m_vars[v].set_value(sum); return v; } @@ -1055,6 +978,7 @@ namespace sls { else { SASSERT(!a.is_arith_expr(e)); } + } template @@ -1345,6 +1269,7 @@ namespace sls { hi_valid = false; } catch (overflow_exception&) { + verbose_stream() << "overflow3\n"; hi_valid = false; } } @@ -2021,7 +1946,7 @@ namespace sls { if (is_num(e, n)) return expr_ref(a.mk_numeral(n.to_rational(), a.is_int(e)), m); auto v = mk_term(e); - return expr_ref(a.mk_numeral(m_vars[v].m_value.to_rational(), a.is_int(e)), m); + return expr_ref(a.mk_numeral(m_vars[v].value().to_rational(), a.is_int(e)), m); } template @@ -2112,7 +2037,7 @@ namespace sls { auto const& vi = m_vars[v]; auto const& lo = vi.m_lo; auto const& hi = vi.m_hi; - out << "v" << v << " := " << vi.m_value << " "; + out << "v" << v << " := " << vi.value() << " "; if (lo || hi) { if (lo) out << (lo->is_strict ? "(": "[") << lo->value; diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index fe9876660..a816ba518 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -76,13 +76,14 @@ namespace sls { class var_info { num_t m_range{ 100000000 }; - num_t m_update_value{ 0 }; - unsigned m_update_timestamp = 0; + unsigned m_num_out_of_range = 0; + unsigned m_num_in_range = 0; + num_t m_value{ 0 }; + num_t m_best_value{ 0 }; public: var_info(expr* e, var_sort k): m_expr(e), m_sort(k) {} expr* m_expr; - num_t m_value{ 0 }; - num_t m_best_value{ 0 }; + var_sort m_sort; arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP; unsigned m_def_idx = UINT_MAX; @@ -91,23 +92,27 @@ namespace sls { unsigned_vector m_adds; optional m_lo, m_hi; - // retrieve temporary value during an update. - void set_update_value(num_t const& v, unsigned timestamp) { - m_update_value = v; - m_update_timestamp = timestamp; - } - num_t const& get_update_value(unsigned ts) const { - return ts == m_update_timestamp ? m_update_value : m_value; - } + num_t const& value() const { return m_value; } + void set_value(num_t const& v) { m_value = v; } - bool in_range(num_t const& n) const { + num_t const& best_value() const { return m_best_value; } + void set_best_value(num_t const& v) { m_best_value = v; } + + bool in_range(num_t const& n) { if (-m_range < n && n < m_range) return true; + bool result = false; if (m_lo && !m_hi) - return n < m_lo->value + m_range; - if (!m_lo && m_hi) - return n > m_hi->value - m_range; - return false; + result = n < m_lo->value + m_range; + else if (!m_lo && m_hi) + result = n > m_hi->value - m_range; +#if 0 + if (!result) + out_of_range(); + else + ++m_num_in_range; +#endif + return result; } unsigned m_tabu_pos = 0, m_tabu_neg = 0; unsigned m_last_pos = 0, m_last_neg = 0; @@ -120,6 +125,15 @@ namespace sls { else m_tabu_neg = tabu_step, m_last_neg = step; } + void out_of_range() { + ++m_num_out_of_range; + if (m_num_out_of_range < 1000 * (1 + m_num_in_range)) + return; + IF_VERBOSE(2, verbose_stream() << "increase range " << m_range << "\n"); + m_range *= 2; + m_num_out_of_range = 0; + m_num_in_range = 0; + } }; struct mul_def { @@ -187,10 +201,7 @@ namespace sls { void add_update(var_t v, num_t delta); bool is_permitted_update(var_t v, num_t const& delta, num_t& delta_out); - unsigned m_update_timestamp = 0; - svector m_update_trail; - bool check_update(var_t v, num_t new_value); - void apply_checked_update(); + num_t value1(var_t v); @@ -247,8 +258,7 @@ namespace sls { bool is_int(var_t v) const { return m_vars[v].m_sort == var_sort::INT; } - num_t value(var_t v) const { return m_vars[v].m_value; } - num_t const& get_update_value(var_t v) const { return m_vars[v].get_update_value(m_update_timestamp); } + num_t value(var_t v) const { return m_vars[v].value(); } bool is_num(expr* e, num_t& i); expr_ref from_num(sort* s, num_t const& n); void check_ineqs(); diff --git a/src/ast/sls/sls_arith_plugin.cpp b/src/ast/sls/sls_arith_plugin.cpp index 310d4009f..d9275b4e7 100644 --- a/src/ast/sls/sls_arith_plugin.cpp +++ b/src/ast/sls/sls_arith_plugin.cpp @@ -27,7 +27,7 @@ namespace sls { return m_arith64->_fn_;\ }\ catch (overflow_exception&) {\ - throw;\ + IF_VERBOSE(1, verbose_stream() << "revert to bignum solver " << #_fn_ << "\n");\ init_backup();\ }\ }\ @@ -39,7 +39,7 @@ namespace sls { m_arith64->_fn_;\ }\ catch (overflow_exception&) {\ - throw;\ + IF_VERBOSE(1, verbose_stream() << "revert to bignum solver " << #_fn_ << "\n");\ init_backup();\ }\ }\ @@ -49,11 +49,7 @@ namespace sls { plugin(ctx), m_shared(ctx.get_manager()) { m_arith64 = alloc(arith_base>, ctx); m_arith = alloc(arith_base, ctx); - m_arith64 = nullptr; - if (m_arith) - m_fid = m_arith->fid(); - else - m_fid = m_arith64->fid(); + m_fid = m_arith->fid(); } void arith_plugin::init_backup() { diff --git a/src/ast/sls/sls_smt_plugin.cpp b/src/ast/sls/sls_smt_plugin.cpp index 08d37ffa4..4e8d2f0bb 100644 --- a/src/ast/sls/sls_smt_plugin.cpp +++ b/src/ast/sls/sls_smt_plugin.cpp @@ -115,7 +115,7 @@ namespace sls { m_ddfw->rlimit().pop(); } - void smt_plugin::finalize(model_ref& mdl, ::statistics& st) { + void smt_plugin::finalize(model_ref& mdl) { auto* d = m_ddfw; if (!d) return; @@ -126,7 +126,6 @@ namespace sls { if (m_thread.joinable()) m_thread.join(); SASSERT(m_completed); - st.copy(m_st); mdl = nullptr; if (m_result == l_true && m_sls_model) { ast_translation tr(m_sls, m); @@ -140,6 +139,10 @@ namespace sls { dealloc(d); } + void smt_plugin::collect_statistics(::statistics& st) const { + st.copy(m_st); + } + void smt_plugin::get_shared_clauses(vector& _clauses) { _clauses.reset(); for (auto const& clause : clauses()) { @@ -257,7 +260,7 @@ namespace sls { void smt_plugin::sls_phase_to_smt() { if (!m_has_new_sls_phase) return; - IF_VERBOSE(2, verbose_stream() << "SLS -> SMT phase\n"); + IF_VERBOSE(2, verbose_stream() << "SLS -> SMT phase " << m_min_unsat_size << "\n"); for (auto v : m_shared_bool_vars) ctx.force_phase(sat::literal(v, !m_sls_phase[v])); m_has_new_sls_phase = false; @@ -290,7 +293,7 @@ namespace sls { } void smt_plugin::export_from_sls() { - if (unsat().size() > m_min_unsat_size) + if (unsat().size() >= m_min_unsat_size) return; m_min_unsat_size = unsat().size(); export_phase_from_sls(); diff --git a/src/ast/sls/sls_smt_plugin.h b/src/ast/sls/sls_smt_plugin.h index 2e74b793f..6b46c3fbf 100644 --- a/src/ast/sls/sls_smt_plugin.h +++ b/src/ast/sls/sls_smt_plugin.h @@ -106,7 +106,8 @@ namespace sls { // interface to calling solver: void check(expr_ref_vector const& fmls, vector const& clauses); - void finalize(model_ref& md, ::statistics& st); + void collect_statistics(::statistics& st) const; + void finalize(model_ref& md); void get_shared_clauses(vector& clauses); void updt_params(params_ref& p) {} std::ostream& display(std::ostream& out) override; diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index 45b4c5951..f9ff3ea0f 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -72,7 +72,8 @@ namespace sls { if (!m_smt_plugin) return; - m_smt_plugin->finalize(m_model, m_st); + m_smt_plugin->collect_statistics(m_st); + m_smt_plugin->finalize(m_model); m_model = nullptr; m_smt_plugin = nullptr; } @@ -89,7 +90,8 @@ namespace sls { return false; if (!m_smt_plugin->completed()) return false; - m_smt_plugin->finalize(m_model, m_st); + m_smt_plugin->collect_statistics(m_st); + m_smt_plugin->finalize(m_model); m_smt_plugin = nullptr; return true; } diff --git a/src/smt/theory_sls.cpp b/src/smt/theory_sls.cpp index 0127c229c..a475a551f 100644 --- a/src/smt/theory_sls.cpp +++ b/src/smt/theory_sls.cpp @@ -62,6 +62,8 @@ namespace smt { } bool theory_sls::get_smt_value(expr* v, expr_ref& value) { + if (!ctx.e_internalized(v)) + return false; auto* n = ctx.get_enode(v); return n && ctx.get_value(n, value); } @@ -78,7 +80,8 @@ namespace smt { if (!m_smt_plugin) return; - m_smt_plugin->finalize(m_model, m_st); + m_smt_plugin->collect_statistics(m_st); + m_smt_plugin->finalize(m_model); m_model = nullptr; m_smt_plugin = nullptr; } @@ -98,7 +101,8 @@ namespace smt { else if (!m_parallel_mode) propagate_local_search(); else if (m_smt_plugin->completed()) { - m_smt_plugin->finalize(m_model, m_st); + m_smt_plugin->collect_statistics(m_st); + m_smt_plugin->finalize(m_model); m_smt_plugin = nullptr; } } @@ -184,7 +188,10 @@ namespace smt { } void theory_sls::collect_statistics(::statistics& st) const { - st.copy(m_st); + if (m_smt_plugin) + m_smt_plugin->collect_statistics(st); + else + st.copy(m_st); } void theory_sls::restart_eh() { @@ -205,7 +212,8 @@ namespace smt { void theory_sls::bounded_run(unsigned num_steps) { m_smt_plugin->bounded_run(num_steps); if (m_smt_plugin->result() == l_true) { - m_smt_plugin->finalize(m_model, m_st); + m_smt_plugin->collect_statistics(m_st); + m_smt_plugin->finalize(m_model); m_smt_plugin = nullptr; } }