3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-08-24 03:57:51 +00:00

add clausal lookahead to arithmetic solver as part of portfolio

have legacy qfbv-sls solver use nnf pre-processing. It relies on it for correctness of the score updates.
This commit is contained in:
Nikolaj Bjorner 2025-01-20 16:16:46 -08:00
parent a941f5ae84
commit 22e4054674
13 changed files with 678 additions and 92 deletions

View file

@ -111,7 +111,8 @@ namespace sls {
arith_base<num_t>::arith_base(context& ctx) :
plugin(ctx),
a(m),
m_new_terms(m) {
m_new_terms(m),
m_clausal_sls(*this) {
m_fid = a.get_family_id();
}
@ -447,12 +448,12 @@ namespace sls {
delta_out = delta;
if (m_last_var == v && m_last_delta == -delta) {
TRACE("arith", tout << "flip back " << v << " " << delta << "\n";);
TRACE("arith_verbose", tout << "flip back " << v << " " << delta << "\n";);
return false;
}
if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) {
TRACE("arith", tout << "tabu\n");
if (m_use_tabu && vi.is_tabu(m_stats.m_steps, delta)) {
TRACE("arith_verbose", tout << "tabu v" << v << " delta:" << delta << "\n");
return false;
}
@ -545,8 +546,8 @@ namespace sls {
if (update(v, new_value)) {
m_last_delta = delta;
m_stats.m_num_steps++;
m_vars[v].set_step(m_stats.m_num_steps, m_stats.m_num_steps + 3 + ctx.rand(10), delta);
m_stats.m_steps++;
m_vars[v].set_step(m_stats.m_steps, m_stats.m_steps + 3 + ctx.rand(10), delta);
return true;
}
sum_score -= score;
@ -1106,6 +1107,7 @@ namespace sls {
// attach i to bv
m_ineqs.set(bv, &i);
m_bool_var_atoms.insert(bv);
}
template<typename num_t>
@ -1403,6 +1405,40 @@ namespace sls {
throw default_exception("repair is not supported for " + mk_pp(e, m));
}
}
for (unsigned v = 0; v < m_vars.size(); ++v)
initialize_bool_vars_of(v);
}
template<typename num_t>
void arith_base<num_t>::initialize_bool_vars_of(var_t v) {
if (!m_vars[v].m_bool_vars_of.empty())
return;
buffer<var_t> todo;
todo.push_back(v);
auto& vi = m_vars[v];
for (unsigned i = 0; i < todo.size(); ++i) {
var_t u = todo[i];
auto& ui = m_vars[u];
for (auto const& idx : ui.m_muls) {
auto& [x, monomial] = m_muls[idx];
if (all_of(todo, [x](var_t v) { return x != v; }))
todo.push_back(x);
}
for (auto const& idx : ui.m_adds) {
auto x = m_adds[idx].m_var;
if (all_of(todo, [x](var_t v) { return x != v; }))
todo.push_back(x);
}
for (auto const& [coeff, bv] : ui.m_linear_occurs)
vi.m_bool_vars_of.insert(bv);
}
;
for (auto bv : vi.m_bool_vars_of) {
for (auto i : ctx.get_use_list(sat::literal(bv, true)))
vi.m_clauses_of.insert(i);
for (auto i : ctx.get_use_list(sat::literal(bv, false)))
vi.m_clauses_of.insert(i);
}
}
template<typename num_t>
@ -2274,7 +2310,7 @@ namespace sls {
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
return true;
IF_VERBOSE(4, verbose_stream() << vi.m_op << " repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
IF_VERBOSE(10, verbose_stream() << vi.m_op << " repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
switch (vi.m_op) {
case arith_op_kind::LAST_ARITH_OP:
@ -2398,13 +2434,12 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::collect_statistics(statistics& st) const {
st.update("sls-arith-flips", m_stats.m_num_steps);
st.update("sls-arith-moves", m_stats.m_moves);
st.update("sls-arith-steps", m_stats.m_steps);
}
template<typename num_t>
void arith_base<num_t>::reset_statistics() {
m_stats.m_num_steps = 0;
m_stats.m_steps = 0;
}
// global lookahead mode
@ -2708,7 +2743,6 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::lookahead_num(var_t v, num_t const& delta) {
num_t old_value = value(v);
expr* e = m_vars[v].m_expr;
if (m_last_expr != e) {
if (m_last_expr)
@ -2779,6 +2813,31 @@ namespace sls {
m_last_expr = nullptr;
}
template<typename num_t>
void arith_base<num_t>::add_lookahead(bool_info& i, sat::bool_var bv) {
if (!i.fixable_atoms.contains(bv))
return;
if (m_fixed_atoms.contains(bv))
return;
auto* ineq = get_ineq(bv);
if (!ineq)
return;
num_t na, nb;
for (auto const& [x, nl] : ineq->m_nonlinear) {
if (!i.fixable_vars.contains(x))
continue;
if (is_fixed(x))
continue;
if (is_linear(x, nl, nb))
find_linear_moves(*ineq, x, nb);
else if (is_quadratic(x, nl, na, nb))
find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value);
else
;
}
m_fixed_atoms.insert(bv);
}
// for every variable e, for every atom containing e
// add lookahead for e.
// m_fixable_atoms contains atoms that can be fixed.
@ -2786,33 +2845,6 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::add_lookahead(bool_info& i, expr* e) {
auto add_atom = [&](sat::bool_var bv) {
if (!i.fixable_atoms.contains(bv))
return;
if (m_fixed_atoms.contains(bv))
return;
auto a = ctx.atom(bv);
if (!a)
return;
auto* ineq = get_ineq(bv);
if (!ineq)
return;
num_t na, nb;
for (auto const& [x, nl] : ineq->m_nonlinear) {
if (!i.fixable_vars.contains(x))
continue;
if (is_fixed(x))
continue;
if (is_linear(x, nl, nb))
find_linear_moves(*ineq, x, nb);
else if (is_quadratic(x, nl, na, nb))
find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value);
else
;
}
m_fixed_atoms.insert(bv);
};
auto add_finite_domain = [&](var_t v) {
auto old_value = value(v);
for (auto const& n : m_vars[v].m_finite_domain)
@ -2832,13 +2864,8 @@ namespace sls {
add_finite_domain(v);
return;
}
for (auto const& [coeff, bv] : vi.m_linear_occurs)
add_atom(bv);
for (auto const& idx : vi.m_muls) {
auto const& [x, monomial] = m_muls[idx];
for (auto [coeff, bv] : m_vars[x].m_linear_occurs)
add_atom(bv);
}
for (auto bv : vi.m_bool_vars_of)
add_lookahead(i, bv);
}
}
@ -2927,7 +2954,7 @@ namespace sls {
add_lookahead(info, vars[(start + i) % sz]);
if (m_updates.empty())
return false;
unsigned idx = ctx.rand() % m_updates.size();
unsigned idx = ctx.rand(m_updates.size());
auto& [v, delta, score] = m_updates[idx];
m_best_expr = m_vars[v].m_expr;
if (false && !m_vars[v].m_finite_domain.empty())
@ -3015,13 +3042,13 @@ namespace sls {
void arith_base<num_t>::global_search() {
initialize_bool_assignment();
rescore();
m_config.max_moves = m_stats.m_moves + m_config.max_moves_base;
TRACE("arith", tout << "search " << m_stats.m_moves << " " << m_config.max_moves << "\n";);
IF_VERBOSE(3, verbose_stream() << "lookahead-search moves:" << m_stats.m_moves << " max-moves:" << m_config.max_moves << "\n");
m_config.max_moves = m_stats.m_steps + m_config.max_moves_base;
TRACE("arith", tout << "search " << m_stats.m_steps << " " << m_config.max_moves << "\n";);
IF_VERBOSE(3, verbose_stream() << "lookahead-search steps:" << m_stats.m_steps << " max-moves:" << m_config.max_moves << "\n");
TRACE("arith", display(tout));
while (ctx.rlimit().inc() && m_stats.m_moves < m_config.max_moves) {
m_stats.m_moves++;
while (ctx.rlimit().inc() && m_stats.m_steps < m_config.max_moves) {
m_stats.m_steps++;
check_restart();
auto t = get_candidate_unsat();
@ -3043,7 +3070,7 @@ namespace sls {
if (apply_move(t, vars, arith_move_type::random_update))
recalibrate_weights();
}
if (m_stats.m_moves >= m_config.max_moves)
if (m_stats.m_steps >= m_config.max_moves)
m_config.max_moves_base += 100;
finalize_bool_assignment();
}
@ -3098,11 +3125,11 @@ namespace sls {
if (old_value == new_value)
return true;
if (!vi.in_range(new_value)) {
TRACE("arith", tout << "Not in range v" << v << " " << new_value << "\n");
TRACE("arith_verbose", tout << "Not in range v" << v << " " << new_value << "\n");
return false;
}
if (!in_bounds(v, new_value) && in_bounds(v, old_value)) {
TRACE("arith", tout << "out of bounds v" << v << " " << new_value << "\n");
TRACE("arith_verbose", tout << "out of bounds v" << v << " " << new_value << "\n");
//verbose_stream() << "out of bounds v" << v << " " << new_value << "\n";
return false;
}
@ -3166,16 +3193,16 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::check_restart() {
if (m_stats.m_moves % m_config.restart_base == 0) {
if (m_stats.m_steps % m_config.restart_base == 0) {
ucb_forget();
rescore();
}
if (m_stats.m_moves < m_config.restart_next)
if (m_stats.m_steps < m_config.restart_next)
return;
++m_stats.m_restarts;
m_config.restart_next = std::max(m_config.restart_next, m_stats.m_moves);
m_config.restart_next = std::max(m_config.restart_next, m_stats.m_steps);
if (0x1 == (m_stats.m_restarts & 0x1))
m_config.restart_next += m_config.restart_base;
@ -3184,10 +3211,8 @@ namespace sls {
// reset_uninterp_in_false_literals
rescore();
}
template<typename num_t>
void arith_base<num_t>::ucb_forget() {
if (m_config.ucb_forget >= 1.0)
@ -3214,18 +3239,21 @@ namespace sls {
//m_config.ucb_forget = p.ucb_forget();
m_config.wp = p.wp();
m_config.restart_base = p.restart_base();
//m_config.restart_next = p.restart_next();
m_config.restart_next = p.restart_base();
//m_config.max_moves_base = p.max_moves_base();
//m_config.max_moves = p.max_moves();
m_config.arith_use_lookahead = p.arith_use_lookahead();
m_config.use_lookahead = p.arith_use_lookahead();
m_config.use_clausal_lookahead = p.arith_use_clausal_lookahead();
m_config.allow_plateau = p.arith_allow_plateau();
m_config.config_initialized = true;
}
template<typename num_t>
void arith_base<num_t>::start_propagation() {
updt_params();
if (m_config.arith_use_lookahead)
updt_params();
if (m_config.use_clausal_lookahead)
m_clausal_sls.search();
else if (m_config.use_lookahead)
global_search();
}