diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt index 77e39ce55..b5e0a0eca 100644 --- a/src/ast/sls/CMakeLists.txt +++ b/src/ast/sls/CMakeLists.txt @@ -3,6 +3,7 @@ z3_add_component(ast_sls bvsls_opt_engine.cpp sat_ddfw.cpp sls_arith_base.cpp + sls_arith_clausal.cpp sls_arith_plugin.cpp sls_array_plugin.cpp sls_basic_plugin.cpp diff --git a/src/ast/sls/sat_ddfw.cpp b/src/ast/sls/sat_ddfw.cpp index 415368e2c..f08e1db37 100644 --- a/src/ast/sls/sat_ddfw.cpp +++ b/src/ast/sls/sat_ddfw.cpp @@ -99,6 +99,16 @@ namespace sat { m_last_flips = m_flips; } + sat::bool_var ddfw::bool_flip() { + flet _in_bool_flip(m_in_bool_flip, true); + double reward = 0; + bool_var v = pick_var(reward); + if (apply_flip(v, reward)) + return v; + shift_weights(); + return sat::null_bool_var; + } + bool ddfw::do_flip() { double reward = 0; bool_var v = pick_var(reward); @@ -125,7 +135,9 @@ namespace sat { bool_var v0 = null_bool_var; for (bool_var v : m_unsat_vars) { r = reward(v); - if (r > 0.0) + if (m_in_bool_flip && m_plugin->is_external(v)) + ; + else if (r > 0.0) sum_pos += score(r); else if (r == 0.0 && sum_pos == 0 && (m_rand() % (n++)) == 0) v0 = v; @@ -134,6 +146,8 @@ namespace sat { double lim_pos = ((double) m_rand() / (1.0 + m_rand.max_value())) * sum_pos; for (bool_var v : m_unsat_vars) { r = reward(v); + if (m_in_bool_flip && m_plugin->is_external(v)) + continue; if (r > 0) { lim_pos -= score(r); if (lim_pos <= 0) @@ -146,6 +160,8 @@ namespace sat { return v0; if (m_unsat_vars.empty()) return null_bool_var; + if (m_in_bool_flip) + return false; return m_unsat_vars.elem_at(m_rand(m_unsat_vars.size())); } @@ -332,6 +348,7 @@ namespace sat { m_vars[v].m_reward = 0; } m_unsat_vars.reset(); + m_num_external_in_unsat_vars = 0; m_unsat.reset(); unsigned sz = m_clauses.size(); for (unsigned i = 0; i < sz; ++i) { @@ -400,7 +417,7 @@ namespace sat { for (unsigned i = 0; i < num_vars(); ++i) m_model[i] = to_lbool(value(i)); save_priorities(); - if (m_plugin) + if (m_plugin && !m_in_bool_flip) m_last_result = m_plugin->on_save_model(); } diff --git a/src/ast/sls/sat_ddfw.h b/src/ast/sls/sat_ddfw.h index 5c027d759..217a0ca0b 100644 --- a/src/ast/sls/sat_ddfw.h +++ b/src/ast/sls/sat_ddfw.h @@ -36,11 +36,10 @@ namespace sat { class local_search_plugin { public: virtual ~local_search_plugin() {} - //virtual void init_search() = 0; - //virtual void finish_search() = 0; virtual void on_rescale() = 0; virtual lbool on_save_model() = 0; virtual void on_restart() = 0; + virtual bool is_external(sat::bool_var v) = 0; }; class ddfw { @@ -140,14 +139,26 @@ namespace sat { unsigned select_max_same_sign(unsigned cf_idx); + unsigned m_num_external_in_unsat_vars = 0; + inline void inc_make(literal lit) { bool_var v = lit.var(); - if (make_count(v)++ == 0) m_unsat_vars.insert_fresh(v); + if (make_count(v)++ == 0) { + m_unsat_vars.insert_fresh(v); + if (m_plugin && m_plugin->is_external(v)) + ++m_num_external_in_unsat_vars; + } } inline void dec_make(literal lit) { bool_var v = lit.var(); - if (--make_count(v) == 0) m_unsat_vars.remove(v); + if (--make_count(v) == 0) { + if (m_unsat_vars.contains(v)) { + m_unsat_vars.remove(v); + if (m_plugin && m_plugin->is_external(v)) + --m_num_external_in_unsat_vars; + } + } } inline void inc_reward(literal lit, double w) { m_vars[lit.var()].m_reward += w; } @@ -164,13 +175,12 @@ namespace sat { bool apply_flip(bool_var v, double reward); - void save_best_values(); void save_model(); void save_priorities(); // shift activity - void shift_weights(); + inline double calculate_transfer_weight(double w); // reinitialize weights activity @@ -204,6 +214,8 @@ namespace sat { bool_var_set m_rotate_tabu; bool_var_vector m_new_tabu_vars; + bool m_in_bool_flip = false; + public: ddfw() {} @@ -241,6 +253,10 @@ namespace sat { indexed_uint_set const& unsat_set() const { return m_unsat; } + indexed_uint_set const& unsat_vars() const { return m_unsat_vars; } + + unsigned num_external_in_unsat_vars() const { return m_num_external_in_unsat_vars; } + vector const& clauses() const { return m_clauses; } clause_info& get_clause_info(unsigned idx) { return m_clauses[idx]; } @@ -251,6 +267,10 @@ namespace sat { void flip(bool_var v); + sat::bool_var bool_flip(); + + void shift_weights(); + inline double reward(bool_var v) const { return m_vars[v].m_reward; } void set_reward(bool_var v, double r) { m_vars[v].m_reward = r; } diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index ee0e4cced..8ad22e40e 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -111,7 +111,8 @@ namespace sls { arith_base::arith_base(context& ctx) : plugin(ctx), a(m), - m_new_terms(m) { + m_new_terms(m), + m_clausal_sls(*this) { m_fid = a.get_family_id(); } @@ -447,12 +448,12 @@ namespace sls { delta_out = delta; if (m_last_var == v && m_last_delta == -delta) { - TRACE("arith", tout << "flip back " << v << " " << delta << "\n";); + TRACE("arith_verbose", tout << "flip back " << v << " " << delta << "\n";); return false; } - if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) { - TRACE("arith", tout << "tabu\n"); + if (m_use_tabu && vi.is_tabu(m_stats.m_steps, delta)) { + TRACE("arith_verbose", tout << "tabu v" << v << " delta:" << delta << "\n"); return false; } @@ -545,8 +546,8 @@ namespace sls { if (update(v, new_value)) { m_last_delta = delta; - m_stats.m_num_steps++; - m_vars[v].set_step(m_stats.m_num_steps, m_stats.m_num_steps + 3 + ctx.rand(10), delta); + m_stats.m_steps++; + m_vars[v].set_step(m_stats.m_steps, m_stats.m_steps + 3 + ctx.rand(10), delta); return true; } sum_score -= score; @@ -1106,6 +1107,7 @@ namespace sls { // attach i to bv m_ineqs.set(bv, &i); + m_bool_var_atoms.insert(bv); } template @@ -1403,6 +1405,40 @@ namespace sls { throw default_exception("repair is not supported for " + mk_pp(e, m)); } } + for (unsigned v = 0; v < m_vars.size(); ++v) + initialize_bool_vars_of(v); + } + + template + void arith_base::initialize_bool_vars_of(var_t v) { + if (!m_vars[v].m_bool_vars_of.empty()) + return; + buffer todo; + todo.push_back(v); + auto& vi = m_vars[v]; + for (unsigned i = 0; i < todo.size(); ++i) { + var_t u = todo[i]; + auto& ui = m_vars[u]; + for (auto const& idx : ui.m_muls) { + auto& [x, monomial] = m_muls[idx]; + if (all_of(todo, [x](var_t v) { return x != v; })) + todo.push_back(x); + } + for (auto const& idx : ui.m_adds) { + auto x = m_adds[idx].m_var; + if (all_of(todo, [x](var_t v) { return x != v; })) + todo.push_back(x); + } + for (auto const& [coeff, bv] : ui.m_linear_occurs) + vi.m_bool_vars_of.insert(bv); + } + ; + for (auto bv : vi.m_bool_vars_of) { + for (auto i : ctx.get_use_list(sat::literal(bv, true))) + vi.m_clauses_of.insert(i); + for (auto i : ctx.get_use_list(sat::literal(bv, false))) + vi.m_clauses_of.insert(i); + } } template @@ -2274,7 +2310,7 @@ namespace sls { auto const& vi = m_vars[v]; if (vi.m_def_idx == UINT_MAX) return true; - IF_VERBOSE(4, verbose_stream() << vi.m_op << " repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); + IF_VERBOSE(10, verbose_stream() << vi.m_op << " repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); switch (vi.m_op) { case arith_op_kind::LAST_ARITH_OP: @@ -2398,13 +2434,12 @@ namespace sls { template void arith_base::collect_statistics(statistics& st) const { - st.update("sls-arith-flips", m_stats.m_num_steps); - st.update("sls-arith-moves", m_stats.m_moves); + st.update("sls-arith-steps", m_stats.m_steps); } template void arith_base::reset_statistics() { - m_stats.m_num_steps = 0; + m_stats.m_steps = 0; } // global lookahead mode @@ -2708,7 +2743,6 @@ namespace sls { template void arith_base::lookahead_num(var_t v, num_t const& delta) { num_t old_value = value(v); - expr* e = m_vars[v].m_expr; if (m_last_expr != e) { if (m_last_expr) @@ -2779,6 +2813,31 @@ namespace sls { m_last_expr = nullptr; } + template + void arith_base::add_lookahead(bool_info& i, sat::bool_var bv) { + if (!i.fixable_atoms.contains(bv)) + return; + if (m_fixed_atoms.contains(bv)) + return; + auto* ineq = get_ineq(bv); + if (!ineq) + return; + num_t na, nb; + for (auto const& [x, nl] : ineq->m_nonlinear) { + if (!i.fixable_vars.contains(x)) + continue; + if (is_fixed(x)) + continue; + if (is_linear(x, nl, nb)) + find_linear_moves(*ineq, x, nb); + else if (is_quadratic(x, nl, na, nb)) + find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value); + else + ; + } + m_fixed_atoms.insert(bv); + } + // for every variable e, for every atom containing e // add lookahead for e. // m_fixable_atoms contains atoms that can be fixed. @@ -2786,33 +2845,6 @@ namespace sls { template void arith_base::add_lookahead(bool_info& i, expr* e) { - auto add_atom = [&](sat::bool_var bv) { - if (!i.fixable_atoms.contains(bv)) - return; - if (m_fixed_atoms.contains(bv)) - return; - auto a = ctx.atom(bv); - if (!a) - return; - auto* ineq = get_ineq(bv); - if (!ineq) - return; - num_t na, nb; - for (auto const& [x, nl] : ineq->m_nonlinear) { - if (!i.fixable_vars.contains(x)) - continue; - if (is_fixed(x)) - continue; - if (is_linear(x, nl, nb)) - find_linear_moves(*ineq, x, nb); - else if (is_quadratic(x, nl, na, nb)) - find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value); - else - ; - } - m_fixed_atoms.insert(bv); - }; - auto add_finite_domain = [&](var_t v) { auto old_value = value(v); for (auto const& n : m_vars[v].m_finite_domain) @@ -2832,13 +2864,8 @@ namespace sls { add_finite_domain(v); return; } - for (auto const& [coeff, bv] : vi.m_linear_occurs) - add_atom(bv); - for (auto const& idx : vi.m_muls) { - auto const& [x, monomial] = m_muls[idx]; - for (auto [coeff, bv] : m_vars[x].m_linear_occurs) - add_atom(bv); - } + for (auto bv : vi.m_bool_vars_of) + add_lookahead(i, bv); } } @@ -2927,7 +2954,7 @@ namespace sls { add_lookahead(info, vars[(start + i) % sz]); if (m_updates.empty()) return false; - unsigned idx = ctx.rand() % m_updates.size(); + unsigned idx = ctx.rand(m_updates.size()); auto& [v, delta, score] = m_updates[idx]; m_best_expr = m_vars[v].m_expr; if (false && !m_vars[v].m_finite_domain.empty()) @@ -3015,13 +3042,13 @@ namespace sls { void arith_base::global_search() { initialize_bool_assignment(); rescore(); - m_config.max_moves = m_stats.m_moves + m_config.max_moves_base; - TRACE("arith", tout << "search " << m_stats.m_moves << " " << m_config.max_moves << "\n";); - IF_VERBOSE(3, verbose_stream() << "lookahead-search moves:" << m_stats.m_moves << " max-moves:" << m_config.max_moves << "\n"); + m_config.max_moves = m_stats.m_steps + m_config.max_moves_base; + TRACE("arith", tout << "search " << m_stats.m_steps << " " << m_config.max_moves << "\n";); + IF_VERBOSE(3, verbose_stream() << "lookahead-search steps:" << m_stats.m_steps << " max-moves:" << m_config.max_moves << "\n"); TRACE("arith", display(tout)); - while (ctx.rlimit().inc() && m_stats.m_moves < m_config.max_moves) { - m_stats.m_moves++; + while (ctx.rlimit().inc() && m_stats.m_steps < m_config.max_moves) { + m_stats.m_steps++; check_restart(); auto t = get_candidate_unsat(); @@ -3043,7 +3070,7 @@ namespace sls { if (apply_move(t, vars, arith_move_type::random_update)) recalibrate_weights(); } - if (m_stats.m_moves >= m_config.max_moves) + if (m_stats.m_steps >= m_config.max_moves) m_config.max_moves_base += 100; finalize_bool_assignment(); } @@ -3098,11 +3125,11 @@ namespace sls { if (old_value == new_value) return true; if (!vi.in_range(new_value)) { - TRACE("arith", tout << "Not in range v" << v << " " << new_value << "\n"); + TRACE("arith_verbose", tout << "Not in range v" << v << " " << new_value << "\n"); return false; } if (!in_bounds(v, new_value) && in_bounds(v, old_value)) { - TRACE("arith", tout << "out of bounds v" << v << " " << new_value << "\n"); + TRACE("arith_verbose", tout << "out of bounds v" << v << " " << new_value << "\n"); //verbose_stream() << "out of bounds v" << v << " " << new_value << "\n"; return false; } @@ -3166,16 +3193,16 @@ namespace sls { template void arith_base::check_restart() { - if (m_stats.m_moves % m_config.restart_base == 0) { + if (m_stats.m_steps % m_config.restart_base == 0) { ucb_forget(); rescore(); } - if (m_stats.m_moves < m_config.restart_next) + if (m_stats.m_steps < m_config.restart_next) return; ++m_stats.m_restarts; - m_config.restart_next = std::max(m_config.restart_next, m_stats.m_moves); + m_config.restart_next = std::max(m_config.restart_next, m_stats.m_steps); if (0x1 == (m_stats.m_restarts & 0x1)) m_config.restart_next += m_config.restart_base; @@ -3184,10 +3211,8 @@ namespace sls { // reset_uninterp_in_false_literals rescore(); - } - template void arith_base::ucb_forget() { if (m_config.ucb_forget >= 1.0) @@ -3214,18 +3239,21 @@ namespace sls { //m_config.ucb_forget = p.ucb_forget(); m_config.wp = p.wp(); m_config.restart_base = p.restart_base(); - //m_config.restart_next = p.restart_next(); + m_config.restart_next = p.restart_base(); //m_config.max_moves_base = p.max_moves_base(); //m_config.max_moves = p.max_moves(); - m_config.arith_use_lookahead = p.arith_use_lookahead(); + m_config.use_lookahead = p.arith_use_lookahead(); + m_config.use_clausal_lookahead = p.arith_use_clausal_lookahead(); m_config.allow_plateau = p.arith_allow_plateau(); m_config.config_initialized = true; } template void arith_base::start_propagation() { - updt_params(); - if (m_config.arith_use_lookahead) + updt_params(); + if (m_config.use_clausal_lookahead) + m_clausal_sls.search(); + else if (m_config.use_lookahead) global_search(); } diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index effec089d..d073bf9bf 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -22,6 +22,7 @@ Author: #include "ast/ast_trail.h" #include "ast/arith_decl_plugin.h" #include "ast/sls/sls_context.h" +#include "ast/sls/sls_arith_clausal.h" namespace sls { @@ -36,6 +37,8 @@ namespace sls { std::ostream& operator<<(std::ostream& out, arith_move_type mt); + static const unsigned null_arith_var = UINT_MAX; + // local search portion for arithmetic template class arith_base : public plugin { @@ -66,13 +69,13 @@ namespace sls { unsigned restart_base = 1000; unsigned restart_next = 1000; unsigned restart_init = 1000; - bool arith_use_lookahead = false; + bool use_lookahead = false; + bool use_clausal_lookahead = false; bool allow_plateau = false; }; struct stats { - unsigned m_num_steps = 0; - unsigned m_moves = 0; + unsigned m_steps = 0; unsigned m_restarts = 0; }; @@ -116,6 +119,8 @@ namespace sls { arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP; unsigned m_def_idx = UINT_MAX; vector> m_linear_occurs; + indexed_uint_set m_bool_vars_of; + indexed_uint_set m_clauses_of; unsigned_vector m_muls; unsigned_vector m_adds; optional m_lo, m_hi; @@ -154,6 +159,9 @@ namespace sls { else m_tabu_neg = tabu_step, m_last_neg = step; } + unsigned last_step(num_t const& delta) const { + return delta > 0 ? m_last_pos : m_last_neg; + } void out_of_range() { ++m_num_out_of_range; if (m_num_out_of_range < 1000 * (1 + m_num_in_range)) @@ -204,7 +212,10 @@ namespace sls { bool m_use_tabu = true; unsigned m_updates_max_size = 45; arith_util a; + friend class arith_clausal; + arith_clausal m_clausal_sls; svector m_prob_break; + indexed_uint_set m_bool_var_atoms; void invariant(); void invariant(ineq const& i); @@ -277,6 +288,7 @@ namespace sls { double compute_score(var_t x, num_t const& delta); void save_best_values(); + void initialize_bool_vars_of(var_t v); var_t mk_var(expr* e); var_t mk_term(expr* e); var_t mk_op(arith_op_kind k, expr* e, expr* x, expr* y); @@ -318,7 +330,7 @@ namespace sls { double score = 0; unsigned touched = 1; lbool value = l_undef; - sat::bool_var_set fixable_atoms; + indexed_uint_set fixable_atoms; uint_set fixable_vars; ptr_vector fixable_exprs; bool_info(unsigned w) : weight(w) {} @@ -335,6 +347,7 @@ namespace sls { unsigned m_touched = 1; sat::bool_var_set m_fixed_atoms; uint64_t m_tabu_set = 0; + unsigned m_global_search_count = 0; bool in_tabu_set(expr* e, num_t const& n); void insert_tabu_set(expr* e, num_t const& n); @@ -344,6 +357,7 @@ namespace sls { void set_bool_value(expr* e, bool v) { get_bool_info(e).value = to_lbool(v); } bool get_basic_bool_value(app* e); void initialize_bool_assignment(); + void finalize_bool_assignment(); double old_score(expr* e) { return get_bool_info(e).score; } double new_score(expr* e); @@ -366,6 +380,7 @@ namespace sls { void lookahead_bool(expr* e); double lookahead(expr* e, bool update_score); void add_lookahead(bool_info& i, expr* e); + void add_lookahead(bool_info& i, sat::bool_var bv); ptr_vector const& get_fixable_exprs(expr* e); bool apply_move(expr* f, ptr_vector const& vars, arith_move_type t); expr* get_candidate_unsat(); diff --git a/src/ast/sls/sls_arith_clausal.cpp b/src/ast/sls/sls_arith_clausal.cpp new file mode 100644 index 000000000..a0dcad290 --- /dev/null +++ b/src/ast/sls/sls_arith_clausal.cpp @@ -0,0 +1,368 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + sls_arith_clausal + +Abstract: + + Theory plugin for arithmetic local search + based on clausal search as used in HybridSMT (nia_ls) + + In contrast to HybridSMT/nia_ls we reuse ddfw + for everything Boolean. It requiers exposing the following: + + - unsat_vars - Boolean variables that are in unsat clauses. + - num_external_vars_in_unsat - External variables in unsat clauses + - shift_weights - allow plugin to invoke shift-weights + + +Author: + + Nikolaj Bjorner (nbjorner) 2025-01-16 + +--*/ + +#include "ast/sls/sls_arith_clausal.h" +#include "ast/sls/sls_arith_base.h" + +namespace sls { + template + arith_clausal::arith_clausal(arith_base& a) : + ctx(a.ctx), + a(a) { + } + + template + void arith_clausal::search() { + num_t delta; + + initialize(); + + TRACE("arith", ctx.display_all(tout)); + + a.m_config.max_moves = a.m_stats.m_steps + a.m_config.max_moves_base; + + while (ctx.rlimit().inc() && a.m_stats.m_steps < a.m_config.max_moves && !ctx.unsat().empty()) { + a.m_stats.m_steps++; + + check_restart(); + + unsigned vars_in_unsat = ctx.unsat_vars().size(); + unsigned ext_in_unsat = ctx.num_external_in_unsat_vars(); + unsigned bool_in_unsat = vars_in_unsat - ext_in_unsat; + bool time_up_bool = m_no_improve_bool * vars_in_unsat > 5 * bool_in_unsat; + bool time_up_arith = m_no_improve_arith * vars_in_unsat > 20 * ext_in_unsat; + if ((m_bool_mode && bool_in_unsat < vars_in_unsat && time_up_bool) || bool_in_unsat == 0) + enter_arith_mode(); + else if ((!m_bool_mode && bool_in_unsat > 0 && time_up_arith) || vars_in_unsat == bool_in_unsat) + enter_bool_mode(); + if (m_bool_mode) { + sat::bool_var v = ctx.bool_flip(); + TRACE("arith", tout << "bool flip v:" << v << "\n"; + tout << "unsat-vars " << vars_in_unsat << "\n"; + tout << "bools: " << bool_in_unsat << " timeup-bool " << time_up_bool << "\n"; + tout << "no-improve bool: " << m_no_improve_bool << "\n"; + tout << "ext: " << ext_in_unsat << " timeup-arith " << time_up_arith << "\n";); + + m_no_improve_bool = update_outer_best_solution() ? 0 : m_no_improve_bool + 1; + } + else { + move_arith_variable(); + m_no_improve_arith = update_inner_best_solution() ? 0 : m_no_improve_arith + 1; + } + m_no_improve = update_best_solution() ? 0 : m_no_improve + 1; + } + if (a.m_stats.m_steps >= a.m_config.max_moves) + a.m_config.max_moves_base += 100; + } + + template + void arith_clausal::move_arith_variable() { + + var_t v = null_arith_var; + + { + a.m_best_score = 1; + flet _use_tabu(a.m_use_tabu, true); + if (v == null_arith_var) { + add_lookahead_on_unsat_vars(); + v = critical_move_on_updates(unsat_var_move); + } + if (v == null_arith_var) { + add_lookahead_on_false_literals(); + v = critical_move_on_updates(false_literal_move); + } + } + + // tabu flips were not possible + + if (v == null_arith_var) + ctx.shift_weights(); + + if (v == null_arith_var) { + a.m_best_score = -1; + flet _use_tabu(a.m_use_tabu, false); + add_lookahead_on_unsat_vars(); + v = random_move_on_updates(); + } + } + + template + void arith_clausal::add_lookahead_on_unsat_vars() { + a.m_updates.reset(); + a.m_fixed_atoms.reset(); + TRACE("arith_verbose", tout << "unsat-vars "; + for (auto v : ctx.unsat_vars()) + if (a.get_ineq(v)) tout << mk_bounded_pp(ctx.atom(v), a.m) << " "; + tout << "\n";); + + for (auto v : ctx.unsat_vars()) { + + auto* ineq = a.get_ineq(v); + if (!ineq) + continue; + auto e = ctx.atom(v); + auto& i = a.get_bool_info(e); + auto const& vars = a.get_fixable_exprs(e); + for (auto v : vars) + a.add_lookahead(i, v); + } + } + + /** + * \brief walk over literals that are false in some clause. + * Try to determine if flipping them to true improves the overall score. + */ + template + void arith_clausal::add_lookahead_on_false_literals() { + a.m_updates.reset(); + a.m_fixed_atoms.reset(); + + for (auto bv : a.m_bool_var_atoms) { + if (ctx.unsat_vars().contains(bv)) + continue; + auto* ineq = a.get_ineq(bv); + if (!ineq) + continue; + sat::literal lit(bv, !ineq->is_true()); + auto const& ul = ctx.get_use_list(~lit); + if (ul.begin() == ul.end()) + continue; + auto v = lit.var(); + // literal is false in some clause but none of the clauses where it occurs false are unsat. + + auto e = ctx.atom(v); + auto& i = a.get_bool_info(e); + a.add_lookahead(i, v); + } + } + + template + var_t arith_clausal::critical_move_on_updates(move_t mt) { + if (a.m_updates.empty()) + return null_arith_var; + std::stable_sort(a.m_updates.begin(), a.m_updates.end(), [](auto const& a, auto const& b) { return a.m_var < b.m_var || (a.m_var == b.m_var && a.m_delta < b.m_delta); }); + m_last_var = null_arith_var; + m_last_delta = 0; + m_best_var = null_arith_var; + m_best_delta = 0; + m_best_abs_value = num_t(-1); + m_best_last_step = UINT_MAX; + for (auto const& u : a.m_updates) + lookahead(u.m_var, u.m_delta); + critical_move(m_best_var, m_best_delta, mt); + return m_best_var; + } + + template + var_t arith_clausal::random_move_on_updates() { + if (a.m_updates.empty()) + return null_arith_var; + unsigned idx = ctx.rand(a.m_updates.size()); + auto& [v, delta, score] = a.m_updates[idx]; + if (!a.can_update_num(v, delta)) + return null_arith_var; + critical_move(v, delta, random_move); + return v; + } + + + template + void arith_clausal::lookahead(var_t v, num_t const& delta) { + if (v == m_last_var && delta == m_last_delta) + return; + if (delta == 0) + return; + m_last_var = v; + m_last_delta = delta; + if (!a.can_update_num(v, delta)) + return; + auto score = get_score(v, delta); + auto& vi = a.m_vars[v]; + num_t abs_value = abs(vi.value() + delta); + unsigned last_step = vi.last_step(delta); + if (score < a.m_best_score) + return; + if (score > a.m_best_score || + (m_best_abs_value == -1) || + (abs_value < m_best_abs_value) || + (abs_value == m_best_abs_value && last_step < m_best_last_step)) { + a.m_best_score = score; + m_best_var = v; + m_best_delta = delta; + m_best_last_step = last_step; + m_best_abs_value = abs_value; + } + } + + template + void arith_clausal::critical_move(var_t v, num_t const& delta, move_t mt) { + if (v == null_arith_var) + return; + a.m_last_delta = delta; + a.m_last_var = v; + TRACE("arith", tout << mt << " v" << v << " " << mk_bounded_pp(a.m_vars[v].m_expr, a.m) + << " += " << delta << " score:" << a.m_best_score << "\n"); + a.m_vars[v].set_step(a.m_stats.m_steps, a.m_stats.m_steps + 3 + ctx.rand(10), delta); + VERIFY(a.update_num(v, delta)); + for (auto bv : a.m_vars[v].m_bool_vars_of) + if (a.get_ineq(bv) && a.get_ineq(bv)->is_true() != ctx.is_true(bv)) + ctx.flip(bv); + + DEBUG_CODE( + for (sat::bool_var bv = 0; bv < ctx.num_bool_vars(); ++bv) { + if (a.get_ineq(bv) && a.get_ineq(bv)->is_true() != ctx.is_true(bv)) { + TRACE("arith", tout << bv << " " << *a.get_ineq(bv) << "\n"; + tout << a.m_vars[v].m_bool_vars_of << "\n"); + } + VERIFY(!a.get_ineq(bv) || a.get_ineq(bv)->is_true() == ctx.is_true(bv)); + }); + } + + template + double arith_clausal::get_score(var_t v, num_t const& delta) { + auto& vi = a.m_vars[v]; + VERIFY(a.update_num(v, delta)); + double score = 0; + for (auto ci : vi.m_clauses_of) { + auto const& c = ctx.get_clause(ci); + unsigned num_true = 0; + for (auto lit : c) { + auto bv = lit.var(); + auto ineq = a.get_ineq(bv); + if (ineq) { + if (ineq->is_true() != lit.sign()) + ++num_true; + } + else if (ctx.is_true(lit)) + ++num_true; + } + CTRACE("arith_verbose", c.m_num_trues != num_true && (c.m_num_trues == 0 || num_true == 0), + tout << "clause: " << c + << " v" << v << " += " << delta + << " new-true lits: " << num_true + << " old-true lits: " << c.m_num_trues + << " w: " << c.m_weight << "\n"; + for (auto lit : c) + if (a.get_ineq(lit.var())) + tout << lit << " " << *a.get_ineq(lit.var()) << "\n";); + if (c.m_num_trues > 0 && num_true == 0) + score -= c.m_weight; + else if (c.m_num_trues == 0 && num_true > 0) + score += c.m_weight; + } + // revert the update + VERIFY(a.update_num(v, -delta)); + return score; + } + + template + void arith_clausal::check_restart() { + if (m_no_improve <= 500000) + return; + + IF_VERBOSE(2, verbose_stream() << "restart sls-arith\n"); + TRACE("arith", tout << "restart\n";); + // reset values of (arithmetical) variables at bounds. + for (auto& vi : a.m_vars) { + if (vi.m_lo && !vi.m_lo->is_strict && vi.m_lo->value > 0) + vi.set_value(vi.m_lo->value); + else if (vi.m_hi && !vi.m_hi->is_strict && vi.m_hi->value < 0) + vi.set_value(vi.m_hi->value); + else + vi.set_value(num_t(0)); + vi.m_bool_vars_of.reset(); + vi.m_clauses_of.reset(); + } + initialize(); + } + + template + void arith_clausal::initialize() { + a.initialize_bool_assignment(); + for (sat::bool_var v = 0; v < ctx.num_bool_vars(); ++v) + a.init_bool_var_assignment(v); + + m_best_found_cost_bool = ctx.unsat().size(); + m_best_found_cost_arith = ctx.unsat().size(); + m_best_found_cost_restart = ctx.unsat().size(); + m_no_improve = 0; + m_no_improve_bool = 0; + m_no_improve_arith = 0; + } + + + template + bool arith_clausal::update_outer_best_solution() { + if (ctx.unsat().size() >= m_best_found_cost_bool) + return false; + m_best_found_cost_bool = ctx.unsat().size(); + return true; + } + + template + void arith_clausal::enter_bool_mode() { + CTRACE("arith", !m_bool_mode, tout << "enter bool mode\n";); + m_best_found_cost_bool = ctx.unsat().size(); + if (!m_bool_mode) + m_no_improve_bool = 0; + m_bool_mode = true; + } + + template + bool arith_clausal::update_inner_best_solution() { + if (ctx.unsat().size() >= m_best_found_cost_arith) + return false; + m_best_found_cost_arith = ctx.unsat().size(); + return true; + } + + template + void arith_clausal::enter_arith_mode() { + CTRACE("arith", m_bool_mode, tout << "enter arith mode\n";); + m_best_found_cost_arith = ctx.unsat().size(); + if (m_bool_mode) + m_no_improve_arith = 0; + m_bool_mode = false; + } + + template + bool arith_clausal::update_best_solution() { + bool improved = false; + if (ctx.unsat().size() < m_best_found_cost_restart) { + improved = true; + m_best_found_cost_restart = ctx.unsat().size(); + } + if (ctx.unsat().size() < m_best_found_cost) { + improved = true; + m_best_found_cost = ctx.unsat().size(); + } + return improved; + } +} + +template class sls::arith_clausal>; +template class sls::arith_clausal; + diff --git a/src/ast/sls/sls_arith_clausal.h b/src/ast/sls/sls_arith_clausal.h new file mode 100644 index 000000000..0eefb3955 --- /dev/null +++ b/src/ast/sls/sls_arith_clausal.h @@ -0,0 +1,90 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + sls_arith_clausal + +Abstract: + + Theory plugin for arithmetic local search + based on clausal search as used in HybridSMT + +Author: + + Nikolaj Bjorner (nbjorner) 2025-01-16 + +--*/ +#pragma once + +#include "util/checked_int64.h" +#include "util/optional.h" +#include "ast/ast_trail.h" +#include "ast/arith_decl_plugin.h" +#include "ast/sls/sls_context.h" + +namespace sls { + + template + class arith_base; + + using var_t = unsigned; + + template + class arith_clausal { + context& ctx; + class arith_base& a; + + void check_restart(); + void initialize(); + + enum move_t { + unsat_var_move, + false_literal_move, + random_move + }; + friend std::ostream& operator<<(std::ostream& out, move_t mt) { + return out << (mt == unsat_var_move ? + "unsat-var" : mt == false_literal_move ? + "false-literal" : "random"); + } + void enter_arith_mode(); + void enter_bool_mode(); + + bool update_outer_best_solution(); + bool update_inner_best_solution(); + bool update_best_solution(); + void move_arith_variable(); + var_t critical_move_on_updates(move_t mt); + var_t random_move_on_updates(); + void add_lookahead_on_unsat_vars(); + void add_lookahead_on_false_literals(); + void critical_move(var_t v, num_t const& delta, move_t mt); + void lookahead(var_t v, num_t const& delta); + double get_score(var_t v, num_t const& delta); + + + unsigned m_no_improve_bool = 0; + unsigned m_no_improve_arith = 0; + unsigned m_no_improve = 0; + bool m_bool_mode = true; + unsigned m_best_found_cost_bool = 0; + unsigned m_best_found_cost_arith = 0; + unsigned m_best_found_cost_restart = 0; + unsigned m_best_found_cost = 0; + num_t m_best_abs_value; + num_t m_best_delta; + var_t m_best_var = UINT_MAX; + unsigned m_best_last_step = 0; + + // avoid checking the same updates twice + var_t m_last_var = UINT_MAX; + num_t m_last_delta; + + public: + arith_clausal(arith_base& a); + void search(); + }; +} + + diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index acae72599..5333e43f7 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -112,6 +112,18 @@ namespace sls { if (p) p->on_restart(); } + + bool context::is_external(sat::bool_var v) { + auto a = atom(v); + if (!a) + return false; + family_id fid = get_fid(a); + if (fid == basic_family_id) + return false; + auto p = m_plugins.get(fid, nullptr); + CTRACE("sls_verbose", p != nullptr, tout << "external " << mk_bounded_pp(a, m) << "\n"); + return p != nullptr; + } lbool context::check() { // @@ -438,6 +450,7 @@ namespace sls { sat::literal context::mk_literal(expr* e) { expr_ref _e(e, m); + SASSERT(!m_input_assertions.contains(e)); sat::literal lit; bool neg = false; expr* a, * b, * c; @@ -528,8 +541,11 @@ namespace sls { for (unsigned i = 0; i < m_atoms.size(); ++i) if (m_atoms.get(i)) register_terms(m_atoms.get(i)); - for (auto e : m_input_assertions) - register_terms(e); + { + flet _is_input_assertion(m_is_input_assertion, true); + for (auto e : m_input_assertions) + register_terms(e); + } for (auto p : m_plugins) if (p) p->initialize(); @@ -564,7 +580,7 @@ namespace sls { m_parents.reserve(arg->get_id() + 1); m_parents[arg->get_id()].push_back(e); } - if (m.is_bool(e)) + if (m.is_bool(e) && !m_is_input_assertion) mk_literal(e); visit(e); } @@ -629,7 +645,6 @@ namespace sls { m_visited.reset(); m_root_literals.reset(); - for (auto const& clause : s.clauses()) { bool has_relevant = false; unsigned n = 0; diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h index ee3696d87..1b69682d9 100644 --- a/src/ast/sls/sls_context.h +++ b/src/ast/sls/sls_context.h @@ -69,12 +69,16 @@ namespace sls { virtual sat::clause_info const& get_clause(unsigned idx) const = 0; virtual ptr_iterator get_use_list(sat::literal lit) = 0; virtual void flip(sat::bool_var v) = 0; + virtual sat::bool_var bool_flip() = 0; virtual bool try_rotate(sat::bool_var v, sat::bool_var_set& rotated, unsigned& budget) = 0; virtual double reward(sat::bool_var v) = 0; virtual double get_weigth(unsigned clause_idx) = 0; virtual bool is_true(sat::literal lit) = 0; virtual unsigned num_vars() const = 0; virtual indexed_uint_set const& unsat() const = 0; + virtual indexed_uint_set const& unsat_vars() const = 0; + virtual void shift_weights() = 0; + virtual unsigned num_external_in_unsat_vars() const = 0; virtual void on_model(model_ref& mdl) = 0; virtual sat::bool_var add_var() = 0; virtual void add_clause(unsigned n, sat::literal const* lits) = 0; @@ -136,6 +140,7 @@ namespace sls { void init(); expr_ref_vector m_todo; + bool m_is_input_assertion = false; void register_terms(expr* e); void register_term(expr* e); @@ -162,6 +167,7 @@ namespace sls { void register_atom(sat::bool_var v, expr* e); lbool check(); + bool is_external(sat::bool_var v); void on_restart(); void updt_params(params_ref const& p); params_ref const& get_params() const { return m_params; } @@ -183,9 +189,13 @@ namespace sls { void add_theory_axiom(expr* f) { add_assertion(f, false); } void add_clause(sat::literal_vector const& lits); void flip(sat::bool_var v) { s.flip(v); } + sat::bool_var bool_flip() { return s.bool_flip(); } + void shift_weights() { s.shift_weights(); } bool try_rotate(sat::bool_var v, sat::bool_var_set& rotated, unsigned& budget) { return s.try_rotate(v, rotated, budget); } double reward(sat::bool_var v) { return s.reward(v); } indexed_uint_set const& unsat() const { return s.unsat(); } + indexed_uint_set const& unsat_vars() const { return s.unsat_vars(); } + unsigned num_external_in_unsat_vars() const { return s.num_external_in_unsat_vars(); } unsigned rand() { return m_rand(); } unsigned rand(unsigned n) { return m_rand(n); } reslimit& rlimit() { return s.rlimit(); } diff --git a/src/ast/sls/sls_smt_plugin.h b/src/ast/sls/sls_smt_plugin.h index 711fec4d6..d8bf025e2 100644 --- a/src/ast/sls/sls_smt_plugin.h +++ b/src/ast/sls/sls_smt_plugin.h @@ -124,6 +124,8 @@ namespace sls { m_ddfw->reinit(); } + void shift_weights() override { m_ddfw->shift_weights(); } + lbool on_save_model() override; void on_model(model_ref& mdl) override { @@ -131,6 +133,14 @@ namespace sls { m_sls_model = mdl; } + sat::bool_var bool_flip() override { + return m_ddfw->bool_flip(); + } + + bool is_external(sat::bool_var v) override { + return m_context.is_external(v); + } + void on_rescale() override {} reslimit& rlimit() override { return m_ddfw->rlimit(); } @@ -160,6 +170,8 @@ namespace sls { } unsigned num_vars() const override { return m_ddfw->num_vars(); } indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); } + indexed_uint_set const& unsat_vars() const override { return m_ddfw->unsat_vars(); } + unsigned num_external_in_unsat_vars() const override { return m_ddfw->num_external_in_unsat_vars(); } sat::bool_var add_var() override { return m_ddfw->add_var(); } diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp index 0258fc183..5a11e4dbb 100644 --- a/src/ast/sls/sls_smt_solver.cpp +++ b/src/ast/sls/sls_smt_solver.cpp @@ -67,10 +67,14 @@ namespace sls { return r; } - void on_model(model_ref& mdl) override { + void on_model(model_ref& mdl) override { m_model = mdl; } + bool is_external(sat::bool_var v) override { + return m_context.is_external(v); + } + void register_atom(sat::bool_var v, expr* e) { m_context.register_atom(v, e); } @@ -85,15 +89,19 @@ namespace sls { sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw.get_clause_info(idx); } ptr_iterator get_use_list(sat::literal lit) override { return m_ddfw.use_list(lit); } void flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; m_ddfw.flip(v); } + sat::bool_var bool_flip() override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; return m_ddfw.bool_flip(); } bool try_rotate(sat::bool_var v, sat::bool_var_set& rotated, unsigned& budget) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; return m_ddfw.try_rotate(v, rotated, budget); } double reward(sat::bool_var v) override { return m_ddfw.reward(v); } double get_weigth(unsigned clause_idx) override { return m_ddfw.get_clause_info(clause_idx).m_weight; } bool is_true(sat::literal lit) override { return m_ddfw.get_value(lit.var()) != lit.sign(); } unsigned num_vars() const override { return m_ddfw.num_vars(); } indexed_uint_set const& unsat() const override { return m_ddfw.unsat_set(); } + indexed_uint_set const& unsat_vars() const override { return m_ddfw.unsat_vars(); } + unsigned num_external_in_unsat_vars() const override { return m_ddfw.num_external_in_unsat_vars(); } sat::bool_var add_var() override { m_dirty = true; return m_ddfw.add_var(); } void add_input_assertion(expr* f) { m_context.add_input_assertion(f); } reslimit& rlimit() { return m_ddfw.rlimit(); } + void shift_weights() override { m_ddfw.shift_weights(); } void force_restart() override { m_ddfw.force_restart(); } diff --git a/src/params/sls_params.pyg b/src/params/sls_params.pyg index 1ec140c95..5df6c1c63 100644 --- a/src/params/sls_params.pyg +++ b/src/params/sls_params.pyg @@ -27,6 +27,7 @@ def_module_params('sls', ('random_seed', UINT, 0, 'random seed'), ('arith_use_lookahead', BOOL, True, 'use lookahead solver for NIRA'), ('arith_allow_plateau', BOOL, False, 'allow plateau moves during NIRA solving'), + ('arith_use_clausal_lookahead', BOOL, False, 'use clause based lookahead for NIRA'), ('bv_use_top_level_assertions', BOOL, True, 'use top-level assertions for BV lookahead solver'), ('bv_use_lookahead', BOOL, True, 'use lookahead solver for BV'), ('bv_allow_rotation', BOOL, True, 'allow model rotation when repairing literal assignment'), diff --git a/src/tactic/sls/sls_tactic.cpp b/src/tactic/sls/sls_tactic.cpp index 5e04900a4..2d9f18808 100644 --- a/src/tactic/sls/sls_tactic.cpp +++ b/src/tactic/sls/sls_tactic.cpp @@ -227,7 +227,7 @@ static tactic * mk_sls_tactic(ast_manager & m, params_ref const & p) { clean(alloc(sls_tactic, m, p))); } -static tactic * mk_preamble(ast_manager & m, params_ref const & p) { +static tactic * mk_preamble(ast_manager & m, params_ref const & p, bool add_nnf) { params_ref simp2_p = p; simp2_p.set_bool("som", true); @@ -244,27 +244,28 @@ static tactic * mk_preamble(ast_manager & m, params_ref const & p) { // conservative gaussian elimination. gaussian_p.set_uint("gaussian_max_occs", 2); - return and_then(and_then(mk_simplify_tactic(m, p), - mk_propagate_values_tactic(m), - using_params(mk_solve_eqs_tactic(m), gaussian_p), - mk_elim_uncnstr_tactic(m), - mk_bv_size_reduction_tactic(m), - using_params(mk_simplify_tactic(m), simp2_p)), - using_params(mk_simplify_tactic(m), hoist_p), - mk_max_bv_sharing_tactic(m)//, - // mk_nnf_tactic(m, p) + return and_then( + and_then(mk_simplify_tactic(m, p), + mk_propagate_values_tactic(m), + using_params(mk_solve_eqs_tactic(m), gaussian_p), + mk_elim_uncnstr_tactic(m), + mk_bv_size_reduction_tactic(m), + using_params(mk_simplify_tactic(m), simp2_p)), + using_params(mk_simplify_tactic(m), hoist_p), + mk_max_bv_sharing_tactic(m), + add_nnf ? mk_nnf_tactic(m, p) : mk_skip_tactic() ); } tactic * mk_qfbv_sls_tactic(ast_manager & m, params_ref const & p) { - tactic * t = and_then(mk_preamble(m, p), mk_sls_tactic(m, p)); + tactic * t = and_then(mk_preamble(m, p, true), mk_sls_tactic(m, p)); t->updt_params(p); return t; } tactic* mk_sls_smt_tactic(ast_manager& m, params_ref const& p) { - tactic* t = and_then(mk_preamble(m, p), alloc(sls_smt_tactic, m, p)); + tactic* t = and_then(mk_preamble(m, p, false), alloc(sls_smt_tactic, m, p)); t->updt_params(p); return t; }