diff --git a/src/ast/sls/sat_ddfw.h b/src/ast/sls/sat_ddfw.h index 36ec30b27..206289708 100644 --- a/src/ast/sls/sat_ddfw.h +++ b/src/ast/sls/sat_ddfw.h @@ -267,6 +267,8 @@ namespace sat { void reinit(); + void force_restart() { m_restart_next = m_flips; } + inline unsigned num_vars() const { return m_vars.size(); } diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 2fd0ebe33..f452c24a8 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -437,8 +437,8 @@ namespace sls { auto old_value = value(v); auto new_value = old_value + delta; - if (!vi.in_range(new_value)) - return false; + if (!vi.in_range(new_value)) + return false; if (m_use_tabu && !in_bounds(v, new_value) && in_bounds(v, old_value)) { auto const& lo = m_vars[v].m_lo; @@ -474,9 +474,11 @@ namespace sls { template void arith_base::add_update(var_t v, num_t delta) { num_t delta_out; - if (!is_permitted_update(v, delta, delta_out)) + if (!is_permitted_update(v, delta, delta_out)) return; - m_updates.push_back({ v, delta_out, compute_score(v, delta_out) }); + + + m_updates.push_back({ v, delta_out, 0 }); } // flip on the first positive score @@ -490,6 +492,16 @@ namespace sls { template bool arith_base::apply_update() { + + while (m_updates.size() > m_updates_max_size) { + auto idx = ctx.rand(m_updates.size()); + m_updates[idx] = m_updates.back(); + m_updates.pop_back(); + } + + for (auto & [v, delta, score] : m_updates) + score = compute_score(v, delta); + double sum_score = 0; for (auto const& [v, delta, score] : m_updates) @@ -509,7 +521,6 @@ namespace sls { IF_VERBOSE(10, verbose_stream() << "repair: v" << v << " := " << value(v) << " -> " << new_value << "\n"); if (update(v, new_value)) { - m_last_var = v; m_last_delta = delta; m_stats.m_num_steps++; m_vars[v].set_step(m_stats.m_num_steps, m_stats.m_num_steps + 3 + ctx.rand(10), delta); @@ -525,16 +536,28 @@ namespace sls { template bool arith_base::repair(sat::literal lit) { + verbose_stream() << "repair " << lit << " " << (ctx.is_unit(lit)?"unit":"") << "\n"; + //flet _tabu(m_use_tabu, m_use_tabu && lit != m_last_literal); + m_last_literal = lit; find_moves(lit); + static unsigned num_fail = 0; - if (apply_update()) + if (apply_update()) return true; + find_reset_moves(lit); - if (apply_update()) + if (apply_update()) return true; + + ++num_fail; + if (num_fail > 3) { + + ctx.force_restart(); + num_fail = 0; + } return false; } @@ -650,6 +673,7 @@ namespace sls { } vi.m_value = new_value; ctx.new_value_eh(e); + m_last_var = v; IF_VERBOSE(10, verbose_stream() << "new value eh " << mk_bounded_pp(e, m) << "\n"); @@ -1100,6 +1124,7 @@ namespace sls { template bool arith_base::propagate() { + // m_last_var = UINT_MAX; // allow to change last variable. return false; } @@ -1259,16 +1284,25 @@ namespace sls { bool lo_valid = true, hi_valid = true; bool lo_strict = false, hi_strict = false; for (auto [w, p] : monomial) { - if (!lo_valid && !hi_valid) + if (!lo_valid) break; auto const& wi = m_vars[w]; if (wi.m_lo && !wi.m_lo->is_strict && wi.m_lo->value >= 0) - lo *= power_of(value(w), p); + lo *= power_of(wi.m_lo->value, p); else lo_valid = false; - - if (hi_valid) { - // TODO + } + for (auto [w, p] : monomial) { + if (!lo_valid && !hi_valid) + break; + auto const& wi = m_vars[w]; + try { + if (wi.m_hi && !wi.m_hi->is_strict) + hi *= power_of(wi.m_hi->value, p); + else + hi_valid = false; + } + catch (overflow_exception&) { hi_valid = false; } } @@ -1278,7 +1312,7 @@ namespace sls { else add_ge(v, lo); } - if (hi_valid) { + if (lo_valid && hi_valid) { if (hi_strict) add_lt(v, hi); else @@ -1298,6 +1332,33 @@ namespace sls { if (vith.m_hi && viel.m_hi && !vith.m_hi->is_strict && !viel.m_hi->is_strict) add_le(v, std::max(vith.m_hi->value, viel.m_hi->value)); + } + switch (vi.m_op) { + case LAST_ARITH_OP: + case OP_ADD: + case OP_MUL: + break; + case OP_MOD: { + auto v2 = m_ops[vi.m_def_idx].m_arg2; + auto const& vi2 = m_vars[v2]; + if (vi2.m_lo && vi2.m_hi && vi2.m_lo->value == vi2.m_hi->value && vi2.m_lo->value > 0) { + add_le(v, vi2.m_lo->value - 1); + add_ge(v, num_t(0)); + } + break; + } + case OP_DIV: + break; + case OP_IDIV: + break; + case OP_REM: + break; + case OP_ABS: + add_ge(v, num_t(0)); + break; + default: + NOT_IMPLEMENTED_YET(); + } // TBD: can also do with other operators. } @@ -1604,12 +1665,17 @@ namespace sls { int result = 0; for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) { bool old_sign = sign(bv); + auto lit = sat::literal(bv, old_sign); auto dtt_old = dtt(old_sign, *atom(bv)); auto dtt_new = dtt(old_sign, *atom(bv), coeff, delta); - if (dtt_new == 0 && dtt_old != 0) + if (dtt_new == 0 && dtt_old != 0) result += 1; - if (dtt_new != 0 && dtt_old == 0) + + if (dtt_new != 0 && dtt_old == 0) { + if (m_use_tabu && ctx.is_unit(lit)) + return 0; result -= 1; + } } if (result < 0) @@ -1698,7 +1764,7 @@ namespace sls { auto const& vi = m_vars[x]; auto const& lo = vi.m_lo; auto const& hi = vi.m_hi; - auto new_value = num_t(ctx.rand(5) - 2); + auto new_value = num_t(-2 + (int)ctx.rand(5)); if (lo && lo->value > new_value) new_value = lo->value; else if (hi && hi->value < new_value) @@ -1706,8 +1772,12 @@ namespace sls { if (new_value != value(x)) add_update(x, new_value - value(x)); else { - add_update(x, num_t(1)); - add_update(x, -num_t(1)); + add_update(x, num_t(1) - value(x)); + add_update(x, -num_t(1) - value(x)); + if (value(x) != 0) { + add_update(x, num_t(1)); + add_update(x, -num_t(1)); + } } } @@ -1722,10 +1792,9 @@ namespace sls { IF_VERBOSE(10, if (m_updates.empty()) { - verbose_stream() << *ineq << "\n"; + verbose_stream() << lit << ": " << * ineq << "\n"; for (auto const& [x, nl] : ineq->m_nonlinear) { - auto const& vi = m_vars[x]; - display(verbose_stream() << "v" << x << "\n", x) << "\n"; + display(verbose_stream(), x) << "\n"; } } verbose_stream() << "RESET moves num updates: " << lit << " " << m_updates.size() << "\n"); diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index f219c2d8b..90925b706 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -151,8 +151,10 @@ namespace sls { bool m_dscore_mode = false; vector m_updates; var_t m_last_var = 0; + sat::literal m_last_literal = sat::null_literal; num_t m_last_delta { 0 }; bool m_use_tabu = true; + unsigned m_updates_max_size = 45; arith_util a; void invariant(); diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index c08f231d8..44dcec925 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -368,9 +368,13 @@ namespace sls { return; m_initialized = true; m_unit_literals.reset(); + m_unit_indices.reset(); for (auto const& clause : s.clauses()) - if (clause.m_clause.size() == 1) + if (clause.m_clause.size() == 1) m_unit_literals.push_back(clause.m_clause[0]); + for (sat::literal lit : m_unit_literals) + m_unit_indices.insert(lit.index()); + verbose_stream() << "UNITS " << m_unit_literals << "\n"; for (auto a : m_atoms) if (a) diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h index 0a7d7bb5b..4530f7042 100644 --- a/src/ast/sls/sls_context.h +++ b/src/ast/sls/sls_context.h @@ -73,6 +73,7 @@ namespace sls { virtual void on_model(model_ref& mdl) = 0; virtual sat::bool_var add_var() = 0; virtual void add_clause(unsigned n, sat::literal const* lits) = 0; + virtual void force_restart() = 0; virtual std::ostream& display(std::ostream& out) = 0; }; @@ -101,6 +102,7 @@ namespace sls { unsigned_vector m_atom2bool_var; vector> m_parents; sat::literal_vector m_root_literals, m_unit_literals; + indexed_uint_set m_unit_indices; random_gen m_rand; bool m_initialized = false; bool m_new_constraint = false; @@ -154,8 +156,9 @@ namespace sls { unsigned rand(unsigned n) { return m_rand(n); } sat::literal_vector const& root_literals() const { return m_root_literals; } sat::literal_vector const& unit_literals() const { return m_unit_literals; } - + bool is_unit(sat::literal lit) const { return m_unit_indices.contains(lit.index()); } void reinit_relevant(); + void force_restart() { s.force_restart(); } ptr_vector const& parents(expr* e) { m_parents.reserve(e->get_id() + 1); @@ -173,6 +176,7 @@ namespace sls { ptr_vector const& subterms(); ast_manager& get_manager() { return m; } std::ostream& display(std::ostream& out) const; + std::ostream& display_all(std::ostream& out) const { return s.display(out); } void collect_statistics(statistics& st) const; void reset_statistics(); diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp index d1f99d149..1960b809c 100644 --- a/src/ast/sls/sls_smt_solver.cpp +++ b/src/ast/sls/sls_smt_solver.cpp @@ -91,6 +91,8 @@ namespace sls { sat::bool_var add_var() override { m_dirty = true; return m_ddfw.add_var(); } void add_clause(expr* f) { m_context.add_clause(f); } + void force_restart() override { m_ddfw.force_restart(); } + void add_clause(unsigned n, sat::literal const* lits) override { m_ddfw.add(n, lits); m_new_constraint = true; diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index e079e72cc..6ba839c2a 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -136,6 +136,7 @@ namespace sls { m_ddfw->add(n, lits); m_new_clause_added = true; } + void force_restart() override { m_ddfw->force_restart(); } }; void solver::init_search() {