diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt index b5e0a0eca..d1c2968ba 100644 --- a/src/ast/sls/CMakeLists.txt +++ b/src/ast/sls/CMakeLists.txt @@ -4,6 +4,7 @@ z3_add_component(ast_sls sat_ddfw.cpp sls_arith_base.cpp sls_arith_clausal.cpp + sls_arith_lookahead.cpp sls_arith_plugin.cpp sls_array_plugin.cpp sls_basic_plugin.cpp diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 58473ad31..6dc1f2eeb 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -94,7 +94,8 @@ namespace sls { plugin(ctx), m_new_terms(m), a(m), - m_clausal_sls(*this) { + m_clausal_sls(*this), + m_lookahead_sls(*this) { m_fid = a.get_family_id(); } @@ -2460,685 +2461,22 @@ namespace sls { m_stats.m_steps = 0; } - // global lookahead mode - // - - template - typename arith_base::bool_info& arith_base::get_bool_info(expr* e) { - unsigned id = e->get_id(); - if (id >= m_bool_info.size()) - m_bool_info.reserve(id + 1, bool_info(m_config.paws_init)); - return m_bool_info[id]; - } - - template - bool arith_base::get_bool_value_rec(expr* e) { - if (!is_app(e)) - return ctx.get_value(e) == l_true; - - if (is_uninterp(e)) - return ctx.get_value(e) == l_true; - - app* ap = to_app(e); - bool is_arith_eq = m.is_eq(e) && a.is_int_real(ap->get_arg(0)); - - if (ap->get_family_id() == basic_family_id && !is_arith_eq) - return get_basic_bool_value(ap); - - auto v = ctx.atom2bool_var(e); - if (v == sat::null_bool_var) - return false; - auto const* ineq = get_ineq(v); - if (!ineq) - return false; - return ineq->is_true(); - } - - template - bool arith_base::get_bool_value(expr* e) { - auto& info = get_bool_info(e); - if (info.value != l_undef) - return info.value == l_true; - - auto r = get_bool_value_rec(e); - info.value = to_lbool(r); - return r; - } + template - bool arith_base::get_basic_bool_value(app* e) { - switch (e->get_decl_kind()) { - case OP_TRUE: + bool arith_base::update_num(var_t v, num_t const& delta) { + if (delta == 0) return true; - case OP_FALSE: + if (!can_update_num(v, delta)) return false; - case OP_NOT: - return !get_bool_value(e->get_arg(0)); - case OP_AND: - return all_of(*e, [&](expr* arg) { return get_bool_value(arg); }); - case OP_OR: - return any_of(*e, [&](expr* arg) { return get_bool_value(arg); }); - case OP_XOR: - return xor_of(*e, [&](expr* arg) { return get_bool_value(arg); }); - case OP_IMPLIES: - return !get_bool_value(e->get_arg(0)) || get_bool_value(e->get_arg(1)); - case OP_EQ: - if (m.is_bool(e->get_arg(0))) - return get_bool_value(e->get_arg(0)) == get_bool_value(e->get_arg(1)); - return ctx.get_value(e->get_arg(0)) == ctx.get_value(e->get_arg(1)); - case OP_DISTINCT: - return false; - case OP_ITE: - return get_bool_value(e->get_arg(0)) ? get_bool_value(e->get_arg(1)) : get_bool_value(e->get_arg(2)); - default: - verbose_stream() << mk_pp(e, m) << "\n"; - NOT_IMPLEMENTED_YET(); - } - return false; - } - - template - void arith_base::initialize_bool_assignment() { - for (auto t : ctx.subterms()) - if (m.is_bool(t)) - set_bool_value(t, get_bool_value_rec(t)); -#if 0 - for (auto t : ctx.subterms()) { - if (m.is_bool(t)) - verbose_stream() << mk_bounded_pp(t, m) << " := " << get_bool_value(t) << "\n"; - else - verbose_stream() << mk_bounded_pp(t, m) << " := " << ctx.get_value(t) << "\n"; - } -#endif - } - - template - void arith_base::finalize_bool_assignment() { - for (unsigned v = ctx.num_bool_vars(); v-- > 0; ) { - auto a = ctx.atom(v); - if (!a) - continue; - if (get_bool_value(a) != ctx.is_true(v)) - ctx.flip(v); - } -#if 0 - for (auto idx : ctx.unsat()) { - auto const& cl = ctx.get_clause(idx); - verbose_stream() << "clause " << cl << "\n"; - for (auto lit : cl) { - auto a = ctx.atom(lit.var()); - if (a) - verbose_stream() << lit << " " << mk_bounded_pp(a, m) << " " << get_bool_value(a) << " " << ctx.is_true(lit) << "\n"; - else - verbose_stream() << lit << " " << ctx.is_true(lit) << "\n"; - } - } -#endif - - } - - template - double arith_base::new_score(expr* e) { - return new_score(e, true); - } - - template - double arith_base::new_score(expr* a, bool is_true) { - bool is_true_new = get_bool_value(a); - - if (is_true == is_true_new) - return 1; - if (is_uninterp(a)) - return 0; - if (m.is_true(a)) - return is_true ? 1 : 0; - if (m.is_false(a)) - return is_true ? 0 : 1; - expr* x, * y, * z; - if (m.is_not(a, x)) - return new_score(x, !is_true); - if ((m.is_and(a) && is_true) || (m.is_or(a) && !is_true)) { - double score = 1; - for (auto arg : *to_app(a)) - score = std::min(score, new_score(arg, is_true)); - return score; - } - if ((m.is_and(a) && !is_true) || (m.is_or(a) && is_true)) { - double score = 0; - for (auto arg : *to_app(a)) - score = std::max(score, new_score(arg, is_true)); - return score; - } - if (m.is_iff(a, x, y)) { - auto v0 = get_bool_value(x); - auto v1 = get_bool_value(y); - return (is_true == (v0 == v1)) ? 1 : 0; - } - if (m.is_ite(a, x, y, z)) - return get_bool_value(x) ? new_score(y, is_true) : new_score(z, is_true); - - - auto v = ctx.atom2bool_var(a); - if (v == sat::null_bool_var) - return 0; - auto const* ineq = get_ineq(v); - if (!ineq) - return 0; - - auto const& args = ineq->m_args_value; - auto const& coeff = ineq->m_coeff; - auto value = args + coeff; - - switch (ineq->m_op) { - case ineq_kind::LE: - if (is_true) { - if (value <= 0) - return 1.0; - } - else { - if (value > 0) - return 1.0; - value = -value + 1; - } - break; - case ineq_kind::LT: - if (is_true) { - if (value < 0) - return 1.0; - } - else { - if (value >= 0) - return 1.0; - value = -value; - } - break; - case ineq_kind::EQ: - if (is_true) { - if (value == 0) - return 1.0; - if (value < 0) - value = -value; - } - else { - if (value != 0) - return 1.0; - return 0.0; - } - break; - } - - - SASSERT(value > 0); - unsigned max_value = 1000; - if (value > max_value) - return 0.0; - auto d = value.get_double(); - double score = 1.0 - ((d * d) / ((double)max_value * (double)max_value)); - //score = 1.0 - d / max_value; - return score; - } - - template - void arith_base::rescore() { - m_top_score = 0; - m_is_root.reset(); - for (auto a : ctx.input_assertions()) { - double score = new_score(a); - set_score(a, score); - m_top_score += score; - m_is_root.mark(a); - } - } - - template - void arith_base::recalibrate_weights() { - for (auto a : ctx.input_assertions()) { - if (ctx.rand(2047) < m_config.paws_sp) { - if (get_bool_value(a)) - dec_weight(a); - } - else if (!get_bool_value(a)) - inc_weight(a); - } - } - - template - void arith_base::insert_update_stack_rec(expr* t) { - m_min_depth = m_max_depth = get_depth(t); - insert_update_stack(t); - for (unsigned depth = m_max_depth; depth <= m_max_depth; ++depth) { - for (unsigned i = 0; i < m_update_stack[depth].size(); ++i) { - auto a = m_update_stack[depth][i]; - for (auto p : ctx.parents(a)) { - insert_update_stack(p); - m_max_depth = std::max(m_max_depth, get_depth(p)); - } - } - } - m_update_stack.reserve(m_max_depth + 1); - } - template - double arith_base::lookahead(expr* t, bool update_score) { - ctx.rlimit().inc(); - SASSERT(a.is_int_real(t) || m.is_bool(t)); - double score = m_top_score; - for (unsigned depth = m_min_depth; depth <= m_max_depth; ++depth) { - for (unsigned i = 0; i < m_update_stack[depth].size(); ++i) { - auto* a = m_update_stack[depth][i]; - TRACE("arith_verbose", tout << "update " << mk_bounded_pp(a, m) << " depth: " << depth << "\n";); - if (t != a) - set_bool_value(a, get_bool_value_rec(a)); - if (m_is_root.is_marked(a)) { - auto nscore = new_score(a); - score += get_weight(a) * (nscore - old_score(a)); - if (update_score) - set_score(a, nscore); - } - } - } - return score; - } - - template - void arith_base::insert_update_stack(expr* t) { - unsigned depth = get_depth(t); - m_update_stack.reserve(depth + 1); - if (!m_in_update_stack.is_marked(t) && is_app(t)) { - m_in_update_stack.mark(t); - m_update_stack[depth].push_back(to_app(t)); - } - } - - template - void arith_base::clear_update_stack() { - m_in_update_stack.reset(); - m_update_stack.reserve(m_max_depth + 1); - for (unsigned i = m_min_depth; i <= m_max_depth; ++i) - m_update_stack[i].reset(); - } - - template - void arith_base::lookahead_num(var_t v, num_t const& delta) { - num_t old_value = value(v); - expr* e = m_vars[v].m_expr; - if (m_last_expr != e) { - if (m_last_expr) - lookahead(m_last_expr, false); - clear_update_stack(); - insert_update_stack_rec(e); - m_last_expr = e; - } - else if (m_last_delta == delta) - return; - m_last_delta = delta; - + auto& vi = m_vars[v]; + auto old_value = vi.value(); num_t new_value = old_value + delta; - - if (!update_num(v, delta)) - return; - auto score = lookahead(e, false); - TRACE("arith_verbose", tout << "lookahead " << v << " " << mk_bounded_pp(e, m) << " := " << delta + old_value << " " << score << " (" << m_best_score << ")\n";); - if (score > m_best_score) { - m_tabu_set = 0; - m_best_score = score; - m_best_value = new_value; - m_best_expr = e; - } - else if (m_config.allow_plateau && score == m_best_score && !in_tabu_set(e, new_value)) { - m_best_score = score; - m_best_expr = e; - m_best_value = new_value; - insert_tabu_set(e, new_value); - //verbose_stream() << "plateau " << mk_bounded_pp(e, m) << " := " << m_best_value << "\n"; - } - - // revert back to old value - update_args_value(v, old_value); + update_args_value(v, new_value); + return true; } - - template - bool arith_base::in_tabu_set(expr* e, num_t const& n) { - uint64_t h = hash_u_u(e->get_id(), n.hash()); - return (m_tabu_set & (1ull << (h & 63ull))) != 0; - } - - template - void arith_base::insert_tabu_set(expr* e, num_t const& n) { - uint64_t h = hash_u_u(e->get_id(), n.hash()); - m_tabu_set |= (1ull << (h & 63ull)); - } - - template - void arith_base::lookahead_bool(expr* e) { - bool b = get_bool_value(e); - set_bool_value(e, !b); - insert_update_stack_rec(e); - auto score = lookahead(e, false); - if (score > m_best_score) { - m_tabu_set = 0; - m_best_score = score; - m_best_expr = e; - } - else if (m_config.allow_plateau && score == m_best_score && !in_tabu_set(e, num_t(1))) { - m_best_score = score; - m_best_expr = e; - insert_tabu_set(e, num_t(1)); - } - set_bool_value(e, b); - lookahead(e, false); - clear_update_stack(); - m_last_expr = nullptr; - } - - template - void arith_base::add_lookahead(bool_info& i, sat::bool_var bv) { - if (!i.fixable_atoms.contains(bv)) - return; - if (m_fixed_atoms.contains(bv)) - return; - auto* ineq = get_ineq(bv); - if (!ineq) - return; - num_t na, nb; - for (auto const& [x, nl] : ineq->m_nonlinear) { - if (!i.fixable_vars.contains(x)) - continue; - if (is_fixed(x)) - continue; - if (is_linear(x, nl, nb)) - find_linear_moves(*ineq, x, nb); - else if (is_quadratic(x, nl, na, nb)) - find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value); - else - ; - } - m_fixed_atoms.insert(bv); - } - - - // for every variable e, for every atom containing e - // add lookahead for e. - // m_fixable_atoms contains atoms that can be fixed. - // m_fixable_vars contains variables that can be updated. - template - void arith_base::add_lookahead(bool_info& i, expr* e) { - - auto add_finite_domain = [&](var_t v) { - auto old_value = value(v); - for (auto const& n : m_vars[v].m_finite_domain) - add_update(v, n - old_value); - }; - - - if (m.is_bool(e)) { - auto bv = ctx.atom2bool_var(e); - if (i.fixable_atoms.contains(bv)) - lookahead_bool(e); - } - else if (a.is_int_real(e)) { - auto v = mk_term(e); - auto& vi = m_vars[v]; - if (false && !vi.m_finite_domain.empty()) { - add_finite_domain(v); - return; - } - for (auto bv : vi.m_bool_vars_of) - add_lookahead(i, bv); - } - } - - // - // e is a formula that is false, - // assemble candidates that can flip the formula to true. - // candidate expressions may be either numeric or boolean variables. - // - template - ptr_vector const& arith_base::get_fixable_exprs(expr* e) { - auto& i = get_bool_info(e); - if (!i.fixable_exprs.empty()) - return i.fixable_exprs; - expr_mark visited; - ptr_buffer todo; - - m_tmp_set.reset(); - - todo.push_back(e); - while (!todo.empty()) { - auto e = todo.back(); - todo.pop_back(); - if (visited.is_marked(e)) - continue; - visited.mark(e); - if (m.is_xor(e) || m.is_and(e) || m.is_or(e) || m.is_implies(e) || m.is_iff(e) || m.is_ite(e) || m.is_not(e)) { - for (auto arg : *to_app(e)) - todo.push_back(arg); - } - else { - auto bv = ctx.atom2bool_var(e); - if (bv == sat::null_bool_var) - continue; - if (is_uninterp(e)) { - if (!i.fixable_atoms.contains(bv)) { - i.fixable_atoms.push_back(bv); - i.fixable_exprs.push_back(e); - } - continue; - } - auto* ineq = get_ineq(bv); - if (!ineq) - continue; - i.fixable_atoms.push_back(bv); - buffer vars; - - for (auto& [v, occ] : ineq->m_nonlinear) - vars.push_back(v); - - for (unsigned j = 0; j < vars.size(); ++j) { - auto v = vars[j]; - if (m_tmp_set.contains(v)) - continue; - - if (is_add(v)) { - for (auto [c, w] : get_add(v).m_args) - vars.push_back(w); - } - else if (is_mul(v)) { - for (auto [w, p] : get_mul(v).m_monomial) - vars.push_back(w); - } - else { - i.fixable_exprs.push_back(m_vars[v].m_expr); - m_tmp_set.insert(v); - } - } - } - } - for (auto v : m_tmp_set) - i.fixable_vars.push_back(v); - return i.fixable_exprs; - } - - template - bool arith_base::apply_move(expr* f, ptr_vector const& vars, arith_move_type t) { - if (vars.empty()) - return false; - auto& info = get_bool_info(f); - m_best_expr = nullptr; - m_best_score = m_top_score; - unsigned sz = vars.size(); - unsigned start = ctx.rand(); - m_updates.reset(); - m_fixed_atoms.reset(); - - switch (t) { - case arith_move_type::random_update: { - for (unsigned i = 0; i < sz; ++i) - add_lookahead(info, vars[(start + i) % sz]); - if (m_updates.empty()) - return false; - unsigned idx = ctx.rand(m_updates.size()); - auto& [v, delta, score] = m_updates[idx]; - m_best_expr = m_vars[v].m_expr; - if (false && !m_vars[v].m_finite_domain.empty()) - m_best_value = m_vars[v].m_finite_domain[ctx.rand() % m_vars[v].m_finite_domain.size()]; - else - m_best_value = value(v) + delta; - m_tabu_set = 0; - break; - } - case arith_move_type::hillclimb_plateau: - case arith_move_type::hillclimb: { - for (unsigned i = 0; i < sz; ++i) - add_lookahead(info, vars[(start + i) % sz]); - if (m_updates.empty()) - return false; - std::stable_sort(m_updates.begin(), m_updates.end(), [](auto const& a, auto const& b) { return a.m_var < b.m_var || (a.m_var == b.m_var && a.m_delta < b.m_delta); }); - m_last_expr = nullptr; - sz = m_updates.size(); - flet _allow_plateau(m_config.allow_plateau, m_config.allow_plateau || t == arith_move_type::hillclimb_plateau); - for (unsigned i = 0; i < sz; ++i) { - auto const& [v, delta, score] = m_updates[(start + i) % m_updates.size()]; - lookahead_num(v, delta); - } - if (m_last_expr) { - lookahead(m_last_expr, false); - clear_update_stack(); - } - break; - } - case arith_move_type::random_inc_dec: { - auto e = vars[ctx.rand() % sz]; - m_best_expr = e; - if (a.is_int_real(e)) { - var_t v = mk_term(e); - auto& vi = m_vars[v]; - if (!vi.m_finite_domain.empty()) - m_best_value = vi.m_finite_domain[ctx.rand() % vi.m_finite_domain.size()]; - else if (ctx.rand(2) == 0) - m_best_value = value(v) + 1; - else - m_best_value = value(v) - 1; - } - m_tabu_set = 0; - break; - } - } - - if (m_best_expr) { - if (m.is_bool(m_best_expr)) - set_bool_value(m_best_expr, !get_bool_value(m_best_expr)); - else { - var_t v = mk_term(m_best_expr); - if (!update_num(v, m_best_value - value(v))) { - TRACE("arith", - tout << "could not move v" << v << " " << t << " " << mk_bounded_pp(m_best_expr, m) << " := " << value(v) << " " << m_top_score << "\n"; - ); - return false; - } - } - insert_update_stack_rec(m_best_expr); - m_top_score = lookahead(m_best_expr, true); - clear_update_stack(); - } - - CTRACE("arith", !m_best_expr, tout << "no move " << t << "\n";); - CTRACE("arith", m_best_expr && a.is_int_real(m_best_expr), { - var_t v = mk_term(m_best_expr); - tout << t << " v" << v << " " << mk_bounded_pp(m_best_expr, m) << " := " << value(v) << " " << m_top_score << "\n"; - }); - return !!m_best_expr; - } - - - std::ostream& operator<<(std::ostream& out, arith_move_type mt) { - switch (mt) { - case arith_move_type::random_update: out << "random-update"; break; - case arith_move_type::hillclimb: out << "hillclimb"; break; - case arith_move_type::random_inc_dec: out << "random-inc-dec"; break; - case arith_move_type::hillclimb_plateau: out << "hillclimb-plateau"; break; - } - return out; - } - - template - void arith_base::global_search() { - initialize_bool_assignment(); - rescore(); - m_config.max_moves = m_stats.m_steps + m_config.max_moves_base; - TRACE("arith", tout << "search " << m_stats.m_steps << " " << m_config.max_moves << "\n";); - IF_VERBOSE(3, verbose_stream() << "lookahead-search steps:" << m_stats.m_steps << " max-moves:" << m_config.max_moves << "\n"); - TRACE("arith", display(tout)); - - while (ctx.rlimit().inc() && m_stats.m_steps < m_config.max_moves) { - m_stats.m_steps++; - check_restart(); - - auto t = get_candidate_unsat(); - - if (!t) - break; - - auto& vars = get_fixable_exprs(t); - - if (vars.empty()) - break; - - if (ctx.rand(2047) < m_config.wp && apply_move(t, vars, arith_move_type::random_inc_dec)) - continue; - - if (apply_move(t, vars, arith_move_type::hillclimb)) - continue; - - if (apply_move(t, vars, arith_move_type::random_update)) - recalibrate_weights(); - } - if (m_stats.m_steps >= m_config.max_moves) - m_config.max_moves_base += 100; - finalize_bool_assignment(); - } - - template - expr* arith_base::get_candidate_unsat() { - expr* e = nullptr; - if (m_config.ucb) { - double max = -1.0; - for (auto a : ctx.input_assertions()) { - if (get_bool_value(a)) - continue; - - auto const& vars = get_fixable_exprs(a); - if (vars.empty()) - continue; - auto score = old_score(a); - auto q = score - + m_config.ucb_constant * ::sqrt(log((double)m_touched) / get_touched(a)) - + m_config.ucb_noise * ctx.rand(512); - if (q > max) - max = q, e = a; - } - if (e) { - m_touched++; - inc_touched(e); - } - } - else { - unsigned n = 0; - for (auto a : ctx.input_assertions()) - if (!get_bool_value(a) && !get_fixable_exprs(a).empty() && ctx.rand() % ++n == 0) - e = a; - } - - m_last_atom = e; - CTRACE("arith", !e, tout << "no unsatisfiable candidate\n";); - CTRACE("arith", e, - tout << "select " << mk_bounded_pp(e, m) << " "; - for (auto v : get_fixable_exprs(e)) - tout << mk_bounded_pp(v, m) << " "; - tout << "\n"); - return e; - } - + template bool arith_base::can_update_num(var_t v, num_t const& delta) { num_t old_value = value(v); @@ -3172,19 +2510,6 @@ namespace sls { return true; } - template - bool arith_base::update_num(var_t v, num_t const& delta) { - if (delta == 0) - return true; - if (!can_update_num(v, delta)) - return false; - auto& vi = m_vars[v]; - auto old_value = vi.value(); - num_t new_value = old_value + delta; - update_args_value(v, new_value); - return true; - } - template void arith_base::update_args_value(var_t v, num_t const& new_value) { auto& vi = m_vars[v]; @@ -3217,39 +2542,8 @@ namespace sls { } } - template - void arith_base::check_restart() { - if (m_stats.m_steps % m_config.restart_base == 0) { - ucb_forget(); - rescore(); - } + - if (m_stats.m_steps < m_config.restart_next) - return; - - ++m_stats.m_restarts; - m_config.restart_next = std::max(m_config.restart_next, m_stats.m_steps); - - if (0x1 == (m_stats.m_restarts & 0x1)) - m_config.restart_next += m_config.restart_base; - else - m_config.restart_next += (2 * (m_stats.m_restarts >> 1)) * m_config.restart_base; - - // reset_uninterp_in_false_literals - rescore(); - } - - template - void arith_base::ucb_forget() { - if (m_config.ucb_forget >= 1.0) - return; - for (auto a : ctx.input_assertions()) { - auto touched_old = get_touched(a); - auto touched_new = static_cast((touched_old - 1) * m_config.ucb_forget + 1); - set_touched(a, touched_new); - m_touched += touched_new - touched_old; - } - } template void arith_base::updt_params() { @@ -3281,7 +2575,7 @@ namespace sls { if (m_config.use_clausal_lookahead) m_clausal_sls.search(); else if (m_config.use_lookahead) - global_search(); + m_lookahead_sls.search(); } } diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index 80a85fca3..e537019c3 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -23,20 +23,12 @@ Author: #include "ast/arith_decl_plugin.h" #include "ast/sls/sls_context.h" #include "ast/sls/sls_arith_clausal.h" +#include "ast/sls/sls_arith_lookahead.h" namespace sls { using theory_var = int; - enum arith_move_type { - hillclimb, - hillclimb_plateau, - random_update, - random_inc_dec - }; - - std::ostream& operator<<(std::ostream& out, arith_move_type mt); - static const unsigned null_arith_var = UINT_MAX; // local search portion for arithmetic @@ -213,7 +205,9 @@ namespace sls { unsigned m_updates_max_size = 45; arith_util a; friend class arith_clausal; + friend class arith_lookahead; arith_clausal m_clausal_sls; + arith_lookahead m_lookahead_sls; svector m_prob_break; indexed_uint_set m_bool_var_atoms; indexed_uint_set m_tmp_set; @@ -325,72 +319,9 @@ namespace sls { std::ostream& display(std::ostream& out, add_def const& ad) const; std::ostream& display(std::ostream& out, mul_def const& md) const; - - - // for global lookahead search mode - void global_search(); - struct bool_info { - unsigned weight = 0; - double score = 0; - unsigned touched = 1; - lbool value = l_undef; - sat::bool_var_vector fixable_atoms; - svector fixable_vars; - ptr_vector fixable_exprs; - bool_info(unsigned w) : weight(w) {} - }; - - vector> m_update_stack; - expr_mark m_in_update_stack; - svector m_bool_info; - double m_best_score = 0, m_top_score = 0; - unsigned m_min_depth = 0, m_max_depth = 0; - num_t m_best_value; - expr* m_best_expr = nullptr, * m_last_atom = nullptr, * m_last_expr = nullptr; - expr_mark m_is_root; - unsigned m_touched = 1; - sat::bool_var_set m_fixed_atoms; - uint64_t m_tabu_set = 0; - unsigned m_global_search_count = 0; - - bool in_tabu_set(expr* e, num_t const& n); - void insert_tabu_set(expr* e, num_t const& n); - bool_info& get_bool_info(expr* e); - bool get_bool_value(expr* e); - bool get_bool_value_rec(expr* e); - void set_bool_value(expr* e, bool v) { get_bool_info(e).value = to_lbool(v); } - bool get_basic_bool_value(app* e); - void initialize_bool_assignment(); - - void finalize_bool_assignment(); - double old_score(expr* e) { return get_bool_info(e).score; } - double new_score(expr* e); - double new_score(expr* e, bool is_true); - void set_score(expr* e, double s) { get_bool_info(e).score = s; } - void rescore(); - void recalibrate_weights(); - void inc_weight(expr* e) { ++get_bool_info(e).weight; } - void dec_weight(expr* e) { auto& i = get_bool_info(e); i.weight = i.weight > m_config.paws_init ? i.weight - 1 : m_config.paws_init; } - unsigned get_weight(expr* e) { return get_bool_info(e).weight; } - unsigned get_touched(expr* e) { return get_bool_info(e).touched; } - void inc_touched(expr* e) { ++get_bool_info(e).touched; } - void set_touched(expr* e, unsigned t) { get_bool_info(e).touched = t; } - void insert_update_stack(expr* t); - void insert_update_stack_rec(expr* t); - void clear_update_stack(); - void lookahead_num(var_t v, num_t const& value); + void update_args_value(var_t v, num_t const& new_value); bool can_update_num(var_t v, num_t const& delta); bool update_num(var_t v, num_t const& delta); - void lookahead_bool(expr* e); - double lookahead(expr* e, bool update_score); - void add_lookahead(bool_info& i, expr* e); - void add_lookahead(bool_info& i, sat::bool_var bv); - ptr_vector const& get_fixable_exprs(expr* e); - bool apply_move(expr* f, ptr_vector const& vars, arith_move_type t); - expr* get_candidate_unsat(); - void check_restart(); - void ucb_forget(); - void update_args_value(var_t v, num_t const& new_value); public: arith_base(context& ctx); ~arith_base() override {} diff --git a/src/ast/sls/sls_arith_clausal.cpp b/src/ast/sls/sls_arith_clausal.cpp index 837a67b9f..43fe047c2 100644 --- a/src/ast/sls/sls_arith_clausal.cpp +++ b/src/ast/sls/sls_arith_clausal.cpp @@ -91,7 +91,7 @@ namespace sls { var_t v = null_arith_var; { - a.m_best_score = 1; + m_best_score = 1; flet _use_tabu(a.m_use_tabu, true); if (v == null_arith_var) { add_lookahead_on_unsat_vars(); @@ -109,7 +109,7 @@ namespace sls { ctx.shift_weights(); if (v == null_arith_var) { - a.m_best_score = -1; + m_best_score = -1; flet _use_tabu(a.m_use_tabu, false); add_lookahead_on_unsat_vars(); v = random_move_on_updates(); @@ -245,13 +245,13 @@ namespace sls { num_t abs_value = abs(vi.value() + delta); unsigned last_step = vi.last_step(delta); ++m_num_lookaheads; - if (score < a.m_best_score) + if (score < m_best_score) return; - if (score > a.m_best_score || + if (score > m_best_score || (m_best_abs_value == -1) || (abs_value < m_best_abs_value) || (abs_value == m_best_abs_value && last_step < m_best_last_step)) { - a.m_best_score = score; + m_best_score = score; m_best_var = v; m_best_delta = delta; m_best_last_step = last_step; @@ -357,7 +357,6 @@ namespace sls { template void arith_clausal::initialize() { - a.initialize_bool_assignment(); for (sat::bool_var v = 0; v < ctx.num_bool_vars(); ++v) a.init_bool_var_assignment(v); diff --git a/src/ast/sls/sls_arith_clausal.h b/src/ast/sls/sls_arith_clausal.h index 32413bf06..b8d1c4294 100644 --- a/src/ast/sls/sls_arith_clausal.h +++ b/src/ast/sls/sls_arith_clausal.h @@ -80,6 +80,7 @@ namespace sls { var_t m_best_var = UINT_MAX; unsigned m_best_last_step = 0; unsigned m_num_lookaheads = 0; + double m_best_score = 0; // avoid checking the same updates twice var_t m_last_var = UINT_MAX; diff --git a/src/ast/sls/sls_arith_lookahead.cpp b/src/ast/sls/sls_arith_lookahead.cpp new file mode 100644 index 000000000..7c665390f --- /dev/null +++ b/src/ast/sls/sls_arith_lookahead.cpp @@ -0,0 +1,756 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + sls_arith_lookahead + + +Author: + + Nikolaj Bjorner (nbjorner) 2025-01-16 + +--*/ + +#include "ast/ast_pp.h" +#include "ast/ast_ll_pp.h" +#include "ast/sls/sls_arith_lookahead.h" +#include "ast/sls/sls_arith_base.h" + +namespace sls { + template + arith_lookahead::arith_lookahead(arith_base& a) : + ctx(a.ctx), + m(a.m), + a(a), + autil(m) { + } + + template + typename arith_lookahead::bool_info& arith_lookahead::get_bool_info(expr* e) { + unsigned id = e->get_id(); + if (id >= m_bool_info.size()) + m_bool_info.reserve(id + 1, bool_info(a.m_config.paws_init)); + return m_bool_info[id]; + } + + template + bool arith_lookahead::get_bool_value_rec(expr* e) { + if (!is_app(e)) + return ctx.get_value(e) == l_true; + + if (is_uninterp(e)) + return ctx.get_value(e) == l_true; + + app* ap = to_app(e); + bool is_arith_eq = m.is_eq(e) && autil.is_int_real(ap->get_arg(0)); + + if (ap->get_family_id() == basic_family_id && !is_arith_eq) + return get_basic_bool_value(ap); + + auto v = ctx.atom2bool_var(e); + if (v == sat::null_bool_var) + return false; + auto const* ineq = a.get_ineq(v); + if (!ineq) + return false; + return ineq->is_true(); + } + + template + bool arith_lookahead::get_bool_value(expr* e) { + auto& info = get_bool_info(e); + if (info.value != l_undef) + return info.value == l_true; + + auto r = get_bool_value_rec(e); + info.value = to_lbool(r); + return r; + } + + + template + bool arith_lookahead::get_basic_bool_value(app* e) { + switch (e->get_decl_kind()) { + case OP_TRUE: + return true; + case OP_FALSE: + return false; + case OP_NOT: + return !get_bool_value(e->get_arg(0)); + case OP_AND: + return all_of(*e, [&](expr* arg) { return get_bool_value(arg); }); + case OP_OR: + return any_of(*e, [&](expr* arg) { return get_bool_value(arg); }); + case OP_XOR: + return xor_of(*e, [&](expr* arg) { return get_bool_value(arg); }); + case OP_IMPLIES: + return !get_bool_value(e->get_arg(0)) || get_bool_value(e->get_arg(1)); + case OP_EQ: + if (m.is_bool(e->get_arg(0))) + return get_bool_value(e->get_arg(0)) == get_bool_value(e->get_arg(1)); + return ctx.get_value(e->get_arg(0)) == ctx.get_value(e->get_arg(1)); + case OP_DISTINCT: + return false; + case OP_ITE: + return get_bool_value(e->get_arg(0)) ? get_bool_value(e->get_arg(1)) : get_bool_value(e->get_arg(2)); + default: + verbose_stream() << mk_pp(e, m) << "\n"; + NOT_IMPLEMENTED_YET(); + } + return false; + } + + + template + double arith_lookahead::new_score(expr* e) { + return new_score(e, true); + } + + template + double arith_lookahead::new_score(expr* e, bool is_true) { + bool is_true_new = get_bool_value(e); + + if (is_true == is_true_new) + return 1; + if (is_uninterp(e)) + return 0; + if (m.is_true(e)) + return is_true ? 1 : 0; + if (m.is_false(e)) + return is_true ? 0 : 1; + expr* x, * y, * z; + if (m.is_not(e, x)) + return new_score(x, !is_true); + if ((m.is_and(e) && is_true) || (m.is_or(e) && !is_true)) { + double score = 1; + for (auto arg : *to_app(e)) + score = std::min(score, new_score(arg, is_true)); + return score; + } + if ((m.is_and(e) && !is_true) || (m.is_or(e) && is_true)) { + double score = 0; + for (auto arg : *to_app(e)) + score = std::max(score, new_score(arg, is_true)); + return score; + } + if (m.is_iff(e, x, y)) { + auto v0 = get_bool_value(x); + auto v1 = get_bool_value(y); + return (is_true == (v0 == v1)) ? 1 : 0; + } + if (m.is_ite(e, x, y, z)) + return get_bool_value(x) ? new_score(y, is_true) : new_score(z, is_true); + + + auto v = ctx.atom2bool_var(e); + if (v == sat::null_bool_var) + return 0; + auto const* ineq = a.get_ineq(v); + if (!ineq) + return 0; + + auto const& args = ineq->m_args_value; + auto const& coeff = ineq->m_coeff; + auto value = args + coeff; + + switch (ineq->m_op) { + case arith_base::ineq_kind::LE: + if (is_true) { + if (value <= 0) + return 1.0; + } + else { + if (value > 0) + return 1.0; + value = -value + 1; + } + break; + case arith_base::ineq_kind::LT: + if (is_true) { + if (value < 0) + return 1.0; + } + else { + if (value >= 0) + return 1.0; + value = -value; + } + break; + case arith_base::ineq_kind::EQ: + if (is_true) { + if (value == 0) + return 1.0; + if (value < 0) + value = -value; + } + else { + if (value != 0) + return 1.0; + return 0.0; + } + break; + } + + + SASSERT(value > 0); + unsigned max_value = 1000; + if (value > max_value) + return 0.0; + auto d = value.get_double(); + double score = 1.0 - ((d * d) / ((double)max_value * (double)max_value)); + //score = 1.0 - d / max_value; + return score; + } + + template + void arith_lookahead::rescore() { + m_top_score = 0; + m_is_root.reset(); + for (auto a : ctx.input_assertions()) { + double score = new_score(a); + set_score(a, score); + m_top_score += score; + m_is_root.mark(a); + } + } + + template + void arith_lookahead::recalibrate_weights() { + for (auto f : ctx.input_assertions()) { + if (ctx.rand(2047) < a.m_config.paws_sp) { + if (get_bool_value(f)) + dec_weight(f); + } + else if (!get_bool_value(f)) + inc_weight(f); + } + } + + template + void arith_lookahead::dec_weight(expr* e) { + auto& i = get_bool_info(e); + i.weight = i.weight > a.m_config.paws_init ? i.weight - 1 : a.m_config.paws_init; + } + + template + void arith_lookahead::insert_update_stack_rec(expr* t) { + m_min_depth = m_max_depth = get_depth(t); + insert_update_stack(t); + for (unsigned depth = m_max_depth; depth <= m_max_depth; ++depth) { + for (unsigned i = 0; i < m_update_stack[depth].size(); ++i) { + auto a = m_update_stack[depth][i]; + for (auto p : ctx.parents(a)) { + insert_update_stack(p); + m_max_depth = std::max(m_max_depth, get_depth(p)); + } + } + } + m_update_stack.reserve(m_max_depth + 1); + } + template + double arith_lookahead::lookahead(expr* t, bool update_score) { + ctx.rlimit().inc(); + SASSERT(a.is_int_real(t) || m.is_bool(t)); + double score = m_top_score; + for (unsigned depth = m_min_depth; depth <= m_max_depth; ++depth) { + for (unsigned i = 0; i < m_update_stack[depth].size(); ++i) { + auto* a = m_update_stack[depth][i]; + TRACE("arith_verbose", tout << "update " << mk_bounded_pp(a, m) << " depth: " << depth << "\n";); + if (t != a) + set_bool_value(a, get_bool_value_rec(a)); + if (m_is_root.is_marked(a)) { + auto nscore = new_score(a); + score += get_weight(a) * (nscore - old_score(a)); + if (update_score) + set_score(a, nscore); + } + } + } + return score; + } + + template + void arith_lookahead::insert_update_stack(expr* t) { + unsigned depth = get_depth(t); + m_update_stack.reserve(depth + 1); + if (!m_in_update_stack.is_marked(t) && is_app(t)) { + m_in_update_stack.mark(t); + m_update_stack[depth].push_back(to_app(t)); + } + } + + template + void arith_lookahead::clear_update_stack() { + m_in_update_stack.reset(); + m_update_stack.reserve(m_max_depth + 1); + for (unsigned i = m_min_depth; i <= m_max_depth; ++i) + m_update_stack[i].reset(); + } + + template + void arith_lookahead::lookahead_num(var_t v, num_t const& delta) { + num_t old_value = a.value(v); + expr* e = a.m_vars[v].m_expr; + if (m_last_expr != e) { + if (m_last_expr) + lookahead(m_last_expr, false); + clear_update_stack(); + insert_update_stack_rec(e); + m_last_expr = e; + } + else if (a.m_last_delta == delta) + return; + a.m_last_delta = delta; + + num_t new_value = old_value + delta; + + if (!a.update_num(v, delta)) + return; + auto score = lookahead(e, false); + TRACE("arith_verbose", tout << "lookahead " << v << " " << mk_bounded_pp(e, m) << " := " << delta + old_value << " " << score << " (" << m_best_score << ")\n";); + if (score > m_best_score) { + m_tabu_set = 0; + m_best_score = score; + m_best_value = new_value; + m_best_expr = e; + } + else if (a.m_config.allow_plateau && score == m_best_score && !in_tabu_set(e, new_value)) { + m_best_score = score; + m_best_expr = e; + m_best_value = new_value; + insert_tabu_set(e, new_value); + //verbose_stream() << "plateau " << mk_bounded_pp(e, m) << " := " << m_best_value << "\n"; + } + + // revert back to old value + a.update_args_value(v, old_value); + } + + template + bool arith_lookahead::in_tabu_set(expr* e, num_t const& n) { + uint64_t h = hash_u_u(e->get_id(), n.hash()); + return (m_tabu_set & (1ull << (h & 63ull))) != 0; + } + + template + void arith_lookahead::insert_tabu_set(expr* e, num_t const& n) { + uint64_t h = hash_u_u(e->get_id(), n.hash()); + m_tabu_set |= (1ull << (h & 63ull)); + } + + template + void arith_lookahead::lookahead_bool(expr* e) { + bool b = get_bool_value(e); + set_bool_value(e, !b); + insert_update_stack_rec(e); + auto score = lookahead(e, false); + if (score > m_best_score) { + m_tabu_set = 0; + m_best_score = score; + m_best_expr = e; + } + else if (a.m_config.allow_plateau && score == m_best_score && !in_tabu_set(e, num_t(1))) { + m_best_score = score; + m_best_expr = e; + insert_tabu_set(e, num_t(1)); + } + set_bool_value(e, b); + lookahead(e, false); + clear_update_stack(); + m_last_expr = nullptr; + } + + template + void arith_lookahead::add_lookahead(bool_info& i, sat::bool_var bv) { + if (!i.fixable_atoms.contains(bv)) + return; + if (m_fixed_atoms.contains(bv)) + return; + auto* ineq = a.get_ineq(bv); + if (!ineq) + return; + num_t na, nb; + for (auto const& [x, nl] : ineq->m_nonlinear) { + if (!i.fixable_vars.contains(x)) + continue; + if (a.is_fixed(x)) + continue; + if (a.is_linear(x, nl, nb)) + a.find_linear_moves(*ineq, x, nb); + else if (a.is_quadratic(x, nl, na, nb)) + a.find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value); + else + ; + } + m_fixed_atoms.insert(bv); + } + + + // for every variable e, for every atom containing e + // add lookahead for e. + // m_fixable_atoms contains atoms that can be fixed. + // m_fixable_vars contains variables that can be updated. + template + void arith_lookahead::add_lookahead(bool_info& i, expr* e) { + + auto add_finite_domain = [&](var_t v) { + auto old_value = a.value(v); + for (auto const& n : a.m_vars[v].m_finite_domain) + a.add_update(v, n - old_value); + }; + + + if (m.is_bool(e)) { + auto bv = ctx.atom2bool_var(e); + if (i.fixable_atoms.contains(bv)) + lookahead_bool(e); + } + else if (autil.is_int_real(e)) { + auto v = a.mk_term(e); + auto& vi = a.m_vars[v]; + if (false && !vi.m_finite_domain.empty()) { + add_finite_domain(v); + return; + } + for (auto bv : vi.m_bool_vars_of) + add_lookahead(i, bv); + } + } + + // + // e is a formula that is false, + // assemble candidates that can flip the formula to true. + // candidate expressions may be either numeric or boolean variables. + // + template + ptr_vector const& arith_lookahead::get_fixable_exprs(expr* e) { + auto& i = get_bool_info(e); + if (!i.fixable_exprs.empty()) + return i.fixable_exprs; + expr_mark visited; + ptr_buffer todo; + + auto& tmp_set = a.m_tmp_set; + tmp_set.reset(); + + todo.push_back(e); + while (!todo.empty()) { + auto e = todo.back(); + todo.pop_back(); + if (visited.is_marked(e)) + continue; + visited.mark(e); + if (m.is_xor(e) || m.is_and(e) || m.is_or(e) || m.is_implies(e) || m.is_iff(e) || m.is_ite(e) || m.is_not(e)) { + for (auto arg : *to_app(e)) + todo.push_back(arg); + } + else { + auto bv = ctx.atom2bool_var(e); + if (bv == sat::null_bool_var) + continue; + if (is_uninterp(e)) { + if (!i.fixable_atoms.contains(bv)) { + i.fixable_atoms.push_back(bv); + i.fixable_exprs.push_back(e); + } + continue; + } + auto* ineq = a.get_ineq(bv); + if (!ineq) + continue; + i.fixable_atoms.push_back(bv); + buffer vars; + + for (auto& [v, occ] : ineq->m_nonlinear) + vars.push_back(v); + + for (unsigned j = 0; j < vars.size(); ++j) { + auto v = vars[j]; + if (tmp_set.contains(v)) + continue; + + if (a.is_add(v)) { + for (auto [c, w] : a.get_add(v).m_args) + vars.push_back(w); + } + else if (a.is_mul(v)) { + for (auto [w, p] : a.get_mul(v).m_monomial) + vars.push_back(w); + } + else { + i.fixable_exprs.push_back(a.m_vars[v].m_expr); + tmp_set.insert(v); + } + } + } + } + for (auto v : tmp_set) + i.fixable_vars.push_back(v); + return i.fixable_exprs; + } + + template + bool arith_lookahead::apply_move(expr* f, ptr_vector const& vars, arith_move_type t) { + if (vars.empty()) + return false; + auto& info = get_bool_info(f); + m_best_expr = nullptr; + m_best_score = m_top_score; + unsigned sz = vars.size(); + unsigned start = ctx.rand(); + a.m_updates.reset(); + m_fixed_atoms.reset(); + + switch (t) { + case arith_move_type::random_update: { + for (unsigned i = 0; i < sz; ++i) + add_lookahead(info, vars[(start + i) % sz]); + if (a.m_updates.empty()) + return false; + unsigned idx = ctx.rand(a.m_updates.size()); + auto& [v, delta, score] = a.m_updates[idx]; + m_best_expr = a.m_vars[v].m_expr; + if (false && !a.m_vars[v].m_finite_domain.empty()) + m_best_value = a.m_vars[v].m_finite_domain[ctx.rand() % a.m_vars[v].m_finite_domain.size()]; + else + m_best_value = a.value(v) + delta; + m_tabu_set = 0; + break; + } + case arith_move_type::hillclimb_plateau: + case arith_move_type::hillclimb: { + for (unsigned i = 0; i < sz; ++i) + add_lookahead(info, vars[(start + i) % sz]); + if (a.m_updates.empty()) + return false; + std::stable_sort(a.m_updates.begin(), a.m_updates.end(), [](auto const& a, auto const& b) { return a.m_var < b.m_var || (a.m_var == b.m_var && a.m_delta < b.m_delta); }); + m_last_expr = nullptr; + sz = a.m_updates.size(); + flet _allow_plateau(a.m_config.allow_plateau, a.m_config.allow_plateau || t == arith_move_type::hillclimb_plateau); + for (unsigned i = 0; i < sz; ++i) { + auto const& [v, delta, score] = a.m_updates[(start + i) % a.m_updates.size()]; + lookahead_num(v, delta); + } + if (m_last_expr) { + lookahead(m_last_expr, false); + clear_update_stack(); + } + break; + } + case arith_move_type::random_inc_dec: { + auto e = vars[ctx.rand() % sz]; + m_best_expr = e; + if (autil.is_int_real(e)) { + var_t v = a.mk_term(e); + auto& vi = a.m_vars[v]; + if (!vi.m_finite_domain.empty()) + m_best_value = vi.m_finite_domain[ctx.rand() % vi.m_finite_domain.size()]; + else if (ctx.rand(2) == 0) + m_best_value = a.value(v) + 1; + else + m_best_value = a.value(v) - 1; + } + m_tabu_set = 0; + break; + } + } + + if (m_best_expr) { + if (m.is_bool(m_best_expr)) + set_bool_value(m_best_expr, !get_bool_value(m_best_expr)); + else { + var_t v = a.mk_term(m_best_expr); + if (!a.update_num(v, m_best_value - a.value(v))) { + TRACE("arith", + tout << "could not move v" << v << " " << t << " " << mk_bounded_pp(m_best_expr, m) << " := " << a.value(v) << " " << m_top_score << "\n"; + ); + return false; + } + } + insert_update_stack_rec(m_best_expr); + m_top_score = lookahead(m_best_expr, true); + clear_update_stack(); + } + + CTRACE("arith", !m_best_expr, tout << "no move " << t << "\n";); + CTRACE("arith", m_best_expr && a.is_int_real(m_best_expr), { + var_t v = mk_term(m_best_expr); + tout << t << " v" << v << " " << mk_bounded_pp(m_best_expr, m) << " := " << value(v) << " " << m_top_score << "\n"; + }); + return !!m_best_expr; + } + + + std::ostream& operator<<(std::ostream& out, arith_move_type mt) { + switch (mt) { + case arith_move_type::random_update: out << "random-update"; break; + case arith_move_type::hillclimb: out << "hillclimb"; break; + case arith_move_type::random_inc_dec: out << "random-inc-dec"; break; + case arith_move_type::hillclimb_plateau: out << "hillclimb-plateau"; break; + } + return out; + } + + + + template + void arith_lookahead::check_restart() { + if (a.m_stats.m_steps % a.m_config.restart_base == 0) { + ucb_forget(); + rescore(); + } + + if (a.m_stats.m_steps < a.m_config.restart_next) + return; + + ++a.m_stats.m_restarts; + a.m_config.restart_next = std::max(a.m_config.restart_next, a.m_stats.m_steps); + + if (0x1 == (a.m_stats.m_restarts & 0x1)) + a.m_config.restart_next += a.m_config.restart_base; + else + a.m_config.restart_next += (2 * (a.m_stats.m_restarts >> 1)) * a.m_config.restart_base; + + // reset_uninterp_in_false_literals + rescore(); + } + + template + void arith_lookahead::ucb_forget() { + if (a.m_config.ucb_forget >= 1.0) + return; + for (auto f : ctx.input_assertions()) { + auto touched_old = get_touched(f); + auto touched_new = static_cast((touched_old - 1) * a.m_config.ucb_forget + 1); + set_touched(f, touched_new); + m_touched += touched_new - touched_old; + } + } + + template + void arith_lookahead::initialize_bool_assignment() { + for (auto t : ctx.subterms()) + if (m.is_bool(t)) + set_bool_value(t, get_bool_value_rec(t)); +#if 0 + for (auto t : ctx.subterms()) { + if (m.is_bool(t)) + verbose_stream() << mk_bounded_pp(t, m) << " := " << get_bool_value(t) << "\n"; + else + verbose_stream() << mk_bounded_pp(t, m) << " := " << ctx.get_value(t) << "\n"; + } +#endif + } + + template + void arith_lookahead::finalize_bool_assignment() { + for (unsigned v = ctx.num_bool_vars(); v-- > 0; ) { + auto a = ctx.atom(v); + if (!a) + continue; + if (get_bool_value(a) != ctx.is_true(v)) + ctx.flip(v); + } +#if 0 + for (auto idx : ctx.unsat()) { + auto const& cl = ctx.get_clause(idx); + verbose_stream() << "clause " << cl << "\n"; + for (auto lit : cl) { + auto a = ctx.atom(lit.var()); + if (a) + verbose_stream() << lit << " " << mk_bounded_pp(a, m) << " " << get_bool_value(a) << " " << ctx.is_true(lit) << "\n"; + else + verbose_stream() << lit << " " << ctx.is_true(lit) << "\n"; + } + } +#endif + + } + + + template + void arith_lookahead::search() { + initialize_bool_assignment(); + rescore(); + a.m_config.max_moves = a.m_stats.m_steps + a.m_config.max_moves_base; + TRACE("arith", tout << "search " << a.m_stats.m_steps << " " << a.m_config.max_moves << "\n";); + IF_VERBOSE(3, verbose_stream() << "lookahead-search steps:" << a.m_stats.m_steps << " max-moves:" << a.m_config.max_moves << "\n"); + TRACE("arith", display(tout)); + + while (ctx.rlimit().inc() && a.m_stats.m_steps < a.m_config.max_moves) { + a.m_stats.m_steps++; + check_restart(); + + auto t = get_candidate_unsat(); + + if (!t) + break; + + auto& vars = get_fixable_exprs(t); + + if (vars.empty()) + break; + + if (ctx.rand(2047) < a.m_config.wp && apply_move(t, vars, arith_move_type::random_inc_dec)) + continue; + + if (apply_move(t, vars, arith_move_type::hillclimb)) + continue; + + if (apply_move(t, vars, arith_move_type::random_update)) + recalibrate_weights(); + } + if (a.m_stats.m_steps >= a.m_config.max_moves) + a.m_config.max_moves_base += 100; + finalize_bool_assignment(); + } + + template + expr* arith_lookahead::get_candidate_unsat() { + expr* e = nullptr; + if (a.m_config.ucb) { + double max = -1.0; + for (auto f : ctx.input_assertions()) { + if (get_bool_value(f)) + continue; + + auto const& vars = get_fixable_exprs(f); + if (vars.empty()) + continue; + auto score = old_score(f); + auto q = score + + a.m_config.ucb_constant * ::sqrt(log((double)m_touched) / get_touched(f)) + + a.m_config.ucb_noise * ctx.rand(512); + if (q > max) + max = q, e = f; + } + if (e) { + m_touched++; + inc_touched(e); + } + } + else { + unsigned n = 0; + for (auto a : ctx.input_assertions()) + if (!get_bool_value(a) && !get_fixable_exprs(a).empty() && ctx.rand() % ++n == 0) + e = a; + } + + m_last_atom = e; + CTRACE("arith", !e, tout << "no unsatisfiable candidate\n";); + CTRACE("arith", e, + tout << "select " << mk_bounded_pp(e, m) << " "; + for (auto v : get_fixable_exprs(e)) + tout << mk_bounded_pp(v, m) << " "; + tout << "\n"); + return e; + } + + + +} + +template class sls::arith_lookahead>; +template class sls::arith_lookahead; + diff --git a/src/ast/sls/sls_arith_lookahead.h b/src/ast/sls/sls_arith_lookahead.h new file mode 100644 index 000000000..356facbfc --- /dev/null +++ b/src/ast/sls/sls_arith_lookahead.h @@ -0,0 +1,116 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + sls_arith_lookahead + +Abstract: + + Theory plugin for arithmetic local search + based on lookahead search as used in HybridSMT + +Author: + + Nikolaj Bjorner (nbjorner) 2025-01-16 + +--*/ +#pragma once + +#include "util/checked_int64.h" +#include "util/optional.h" +#include "util/nat_set.h" +#include "ast/ast_trail.h" +#include "ast/arith_decl_plugin.h" +#include "ast/sls/sls_context.h" + + +namespace sls { + + template + class arith_base; + + using var_t = unsigned; + + enum arith_move_type { + hillclimb, + hillclimb_plateau, + random_update, + random_inc_dec + }; + + std::ostream& operator<<(std::ostream& out, arith_move_type mt); + + template + class arith_lookahead { + context& ctx; + ast_manager& m; + class arith_base& a; + arith_util autil; + + struct bool_info { + unsigned weight = 0; + double score = 0; + unsigned touched = 1; + lbool value = l_undef; + sat::bool_var_vector fixable_atoms; + svector fixable_vars; + ptr_vector fixable_exprs; + bool_info(unsigned w) : weight(w) {} + }; + + vector> m_update_stack; + expr_mark m_in_update_stack; + svector m_bool_info; + double m_best_score = 0, m_top_score = 0; + unsigned m_min_depth = 0, m_max_depth = 0; + num_t m_best_value; + expr* m_best_expr = nullptr, * m_last_atom = nullptr, * m_last_expr = nullptr; + expr_mark m_is_root; + unsigned m_touched = 1; + sat::bool_var_set m_fixed_atoms; + uint64_t m_tabu_set = 0; + unsigned m_global_search_count = 0; + + bool in_tabu_set(expr* e, num_t const& n); + void insert_tabu_set(expr* e, num_t const& n); + bool_info& get_bool_info(expr* e); + bool get_bool_value(expr* e); + bool get_bool_value_rec(expr* e); + void set_bool_value(expr* e, bool v) { get_bool_info(e).value = to_lbool(v); } + bool get_basic_bool_value(app* e); + double old_score(expr* e) { return get_bool_info(e).score; } + double new_score(expr* e); + double new_score(expr* e, bool is_true); + void set_score(expr* e, double s) { get_bool_info(e).score = s; } + void rescore(); + void recalibrate_weights(); + void inc_weight(expr* e) { ++get_bool_info(e).weight; } + void dec_weight(expr* e); + unsigned get_weight(expr* e) { return get_bool_info(e).weight; } + unsigned get_touched(expr* e) { return get_bool_info(e).touched; } + void inc_touched(expr* e) { ++get_bool_info(e).touched; } + void set_touched(expr* e, unsigned t) { get_bool_info(e).touched = t; } + void insert_update_stack(expr* t); + void insert_update_stack_rec(expr* t); + void clear_update_stack(); + void lookahead_num(var_t v, num_t const& value); + void lookahead_bool(expr* e); + double lookahead(expr* e, bool update_score); + void add_lookahead(bool_info& i, expr* e); + void add_lookahead(bool_info& i, sat::bool_var bv); + ptr_vector const& get_fixable_exprs(expr* e); + bool apply_move(expr* f, ptr_vector const& vars, arith_move_type t); + expr* get_candidate_unsat(); + void check_restart(); + void ucb_forget(); + void initialize_bool_assignment(); + void finalize_bool_assignment(); + + public: + arith_lookahead(arith_base& a); + void search(); + }; +} + +