diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 528cad536..8d82c80a9 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -2698,24 +2698,49 @@ namespace sls { auto score = lookahead(e, false); TRACE("arith_verbose", tout << "lookahead " << v << " " << mk_bounded_pp(e, m) << " := " << delta + old_value << " " << score << " (" << m_best_score << ")\n";); if (score > m_best_score) { + m_tabu_set = 0; m_best_score = score; m_best_value = new_value; m_best_expr = e; } + else if (m_config.allow_plateau && score == m_best_score && !in_tabu_set(e, new_value)) { + m_best_score = score; + m_best_expr = e; + m_best_value = new_value; + insert_tabu_set(e, num_t(1)); + } // revert back to old value update_args_value(v, old_value); } + template + bool arith_base::in_tabu_set(expr* e, num_t const& n) { + uint64_t h = hash_u_u(e->get_id(), n.hash()); + return (m_tabu_set & (1ull << (h & 64ull))) != 0; + } + + template + void arith_base::insert_tabu_set(expr* e, num_t const& n) { + uint64_t h = hash_u_u(e->get_id(), n.hash()); + m_tabu_set |= (1ull << (h & 64ull)); + } + template void arith_base::lookahead_bool(expr* e) { bool b = get_bool_value(e); set_bool_value(e, !b); auto score = lookahead(e, false); if (score > m_best_score) { + m_tabu_set = 0; m_best_score = score; m_best_expr = e; } + else if (m_config.allow_plateau && score == m_best_score && !in_tabu_set(e, num_t(1))) { + m_best_score = score; + m_best_expr = e; + insert_tabu_set(e, num_t(1)); + } set_bool_value(e, b); lookahead(e, false); } @@ -2877,6 +2902,7 @@ namespace sls { m_best_value = value(v) + delta; break; } + case arith_move_type::hillclimb_plateau: case arith_move_type::hillclimb: { for (unsigned i = 0; i < sz; ++i) add_lookahead(info, vars[(start + i) % sz]); @@ -2885,6 +2911,7 @@ namespace sls { std::stable_sort(m_updates.begin(), m_updates.end(), [](auto const& a, auto const& b) { return a.m_var < b.m_var || (a.m_var == b.m_var && a.m_delta < b.m_delta); }); m_last_expr = nullptr; sz = m_updates.size(); + flet _allow_plateau(m_config.allow_plateau, t == arith_move_type::hillclimb_plateau); for (unsigned i = 0; i < sz; ++i) { auto const& [v, delta, score] = m_updates[(start + i) % m_updates.size()]; lookahead_num(v, delta); diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index 9581cf82e..9a1f9344c 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -29,6 +29,7 @@ namespace sls { enum arith_move_type { hillclimb, + hillclimb_plateau, random_update, random_inc_dec }; @@ -66,6 +67,7 @@ namespace sls { unsigned restart_next = 1000; unsigned restart_init = 1000; bool arith_use_lookahead = false; + bool allow_plateau = false; }; struct stats { @@ -331,7 +333,10 @@ namespace sls { expr_mark m_is_root; unsigned m_touched = 1; sat::bool_var_set m_fixed_atoms; + uint64_t m_tabu_set = 0; + bool in_tabu_set(expr* e, num_t const& n); + void insert_tabu_set(expr* e, num_t const& n); bool_info& get_bool_info(expr* e); bool get_bool_value(expr* e); bool get_bool_value_rec(expr* e);