mirror of
https://github.com/Z3Prover/z3
synced 2025-08-24 03:57:51 +00:00
add clausal lookahead to arithmetic solver as part of portfolio
have legacy qfbv-sls solver use nnf pre-processing. It relies on it for correctness of the score updates.
This commit is contained in:
parent
a941f5ae84
commit
22e4054674
13 changed files with 678 additions and 92 deletions
|
@ -111,7 +111,8 @@ namespace sls {
|
|||
arith_base<num_t>::arith_base(context& ctx) :
|
||||
plugin(ctx),
|
||||
a(m),
|
||||
m_new_terms(m) {
|
||||
m_new_terms(m),
|
||||
m_clausal_sls(*this) {
|
||||
m_fid = a.get_family_id();
|
||||
}
|
||||
|
||||
|
@ -447,12 +448,12 @@ namespace sls {
|
|||
delta_out = delta;
|
||||
|
||||
if (m_last_var == v && m_last_delta == -delta) {
|
||||
TRACE("arith", tout << "flip back " << v << " " << delta << "\n";);
|
||||
TRACE("arith_verbose", tout << "flip back " << v << " " << delta << "\n";);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) {
|
||||
TRACE("arith", tout << "tabu\n");
|
||||
if (m_use_tabu && vi.is_tabu(m_stats.m_steps, delta)) {
|
||||
TRACE("arith_verbose", tout << "tabu v" << v << " delta:" << delta << "\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -545,8 +546,8 @@ namespace sls {
|
|||
|
||||
if (update(v, new_value)) {
|
||||
m_last_delta = delta;
|
||||
m_stats.m_num_steps++;
|
||||
m_vars[v].set_step(m_stats.m_num_steps, m_stats.m_num_steps + 3 + ctx.rand(10), delta);
|
||||
m_stats.m_steps++;
|
||||
m_vars[v].set_step(m_stats.m_steps, m_stats.m_steps + 3 + ctx.rand(10), delta);
|
||||
return true;
|
||||
}
|
||||
sum_score -= score;
|
||||
|
@ -1106,6 +1107,7 @@ namespace sls {
|
|||
|
||||
// attach i to bv
|
||||
m_ineqs.set(bv, &i);
|
||||
m_bool_var_atoms.insert(bv);
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
|
@ -1403,6 +1405,40 @@ namespace sls {
|
|||
throw default_exception("repair is not supported for " + mk_pp(e, m));
|
||||
}
|
||||
}
|
||||
for (unsigned v = 0; v < m_vars.size(); ++v)
|
||||
initialize_bool_vars_of(v);
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::initialize_bool_vars_of(var_t v) {
|
||||
if (!m_vars[v].m_bool_vars_of.empty())
|
||||
return;
|
||||
buffer<var_t> todo;
|
||||
todo.push_back(v);
|
||||
auto& vi = m_vars[v];
|
||||
for (unsigned i = 0; i < todo.size(); ++i) {
|
||||
var_t u = todo[i];
|
||||
auto& ui = m_vars[u];
|
||||
for (auto const& idx : ui.m_muls) {
|
||||
auto& [x, monomial] = m_muls[idx];
|
||||
if (all_of(todo, [x](var_t v) { return x != v; }))
|
||||
todo.push_back(x);
|
||||
}
|
||||
for (auto const& idx : ui.m_adds) {
|
||||
auto x = m_adds[idx].m_var;
|
||||
if (all_of(todo, [x](var_t v) { return x != v; }))
|
||||
todo.push_back(x);
|
||||
}
|
||||
for (auto const& [coeff, bv] : ui.m_linear_occurs)
|
||||
vi.m_bool_vars_of.insert(bv);
|
||||
}
|
||||
;
|
||||
for (auto bv : vi.m_bool_vars_of) {
|
||||
for (auto i : ctx.get_use_list(sat::literal(bv, true)))
|
||||
vi.m_clauses_of.insert(i);
|
||||
for (auto i : ctx.get_use_list(sat::literal(bv, false)))
|
||||
vi.m_clauses_of.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
|
@ -2274,7 +2310,7 @@ namespace sls {
|
|||
auto const& vi = m_vars[v];
|
||||
if (vi.m_def_idx == UINT_MAX)
|
||||
return true;
|
||||
IF_VERBOSE(4, verbose_stream() << vi.m_op << " repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
|
||||
IF_VERBOSE(10, verbose_stream() << vi.m_op << " repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
|
||||
TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
|
||||
switch (vi.m_op) {
|
||||
case arith_op_kind::LAST_ARITH_OP:
|
||||
|
@ -2398,13 +2434,12 @@ namespace sls {
|
|||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::collect_statistics(statistics& st) const {
|
||||
st.update("sls-arith-flips", m_stats.m_num_steps);
|
||||
st.update("sls-arith-moves", m_stats.m_moves);
|
||||
st.update("sls-arith-steps", m_stats.m_steps);
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::reset_statistics() {
|
||||
m_stats.m_num_steps = 0;
|
||||
m_stats.m_steps = 0;
|
||||
}
|
||||
|
||||
// global lookahead mode
|
||||
|
@ -2708,7 +2743,6 @@ namespace sls {
|
|||
template<typename num_t>
|
||||
void arith_base<num_t>::lookahead_num(var_t v, num_t const& delta) {
|
||||
num_t old_value = value(v);
|
||||
|
||||
expr* e = m_vars[v].m_expr;
|
||||
if (m_last_expr != e) {
|
||||
if (m_last_expr)
|
||||
|
@ -2779,6 +2813,31 @@ namespace sls {
|
|||
m_last_expr = nullptr;
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::add_lookahead(bool_info& i, sat::bool_var bv) {
|
||||
if (!i.fixable_atoms.contains(bv))
|
||||
return;
|
||||
if (m_fixed_atoms.contains(bv))
|
||||
return;
|
||||
auto* ineq = get_ineq(bv);
|
||||
if (!ineq)
|
||||
return;
|
||||
num_t na, nb;
|
||||
for (auto const& [x, nl] : ineq->m_nonlinear) {
|
||||
if (!i.fixable_vars.contains(x))
|
||||
continue;
|
||||
if (is_fixed(x))
|
||||
continue;
|
||||
if (is_linear(x, nl, nb))
|
||||
find_linear_moves(*ineq, x, nb);
|
||||
else if (is_quadratic(x, nl, na, nb))
|
||||
find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value);
|
||||
else
|
||||
;
|
||||
}
|
||||
m_fixed_atoms.insert(bv);
|
||||
}
|
||||
|
||||
// for every variable e, for every atom containing e
|
||||
// add lookahead for e.
|
||||
// m_fixable_atoms contains atoms that can be fixed.
|
||||
|
@ -2786,33 +2845,6 @@ namespace sls {
|
|||
template<typename num_t>
|
||||
void arith_base<num_t>::add_lookahead(bool_info& i, expr* e) {
|
||||
|
||||
auto add_atom = [&](sat::bool_var bv) {
|
||||
if (!i.fixable_atoms.contains(bv))
|
||||
return;
|
||||
if (m_fixed_atoms.contains(bv))
|
||||
return;
|
||||
auto a = ctx.atom(bv);
|
||||
if (!a)
|
||||
return;
|
||||
auto* ineq = get_ineq(bv);
|
||||
if (!ineq)
|
||||
return;
|
||||
num_t na, nb;
|
||||
for (auto const& [x, nl] : ineq->m_nonlinear) {
|
||||
if (!i.fixable_vars.contains(x))
|
||||
continue;
|
||||
if (is_fixed(x))
|
||||
continue;
|
||||
if (is_linear(x, nl, nb))
|
||||
find_linear_moves(*ineq, x, nb);
|
||||
else if (is_quadratic(x, nl, na, nb))
|
||||
find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value);
|
||||
else
|
||||
;
|
||||
}
|
||||
m_fixed_atoms.insert(bv);
|
||||
};
|
||||
|
||||
auto add_finite_domain = [&](var_t v) {
|
||||
auto old_value = value(v);
|
||||
for (auto const& n : m_vars[v].m_finite_domain)
|
||||
|
@ -2832,13 +2864,8 @@ namespace sls {
|
|||
add_finite_domain(v);
|
||||
return;
|
||||
}
|
||||
for (auto const& [coeff, bv] : vi.m_linear_occurs)
|
||||
add_atom(bv);
|
||||
for (auto const& idx : vi.m_muls) {
|
||||
auto const& [x, monomial] = m_muls[idx];
|
||||
for (auto [coeff, bv] : m_vars[x].m_linear_occurs)
|
||||
add_atom(bv);
|
||||
}
|
||||
for (auto bv : vi.m_bool_vars_of)
|
||||
add_lookahead(i, bv);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2927,7 +2954,7 @@ namespace sls {
|
|||
add_lookahead(info, vars[(start + i) % sz]);
|
||||
if (m_updates.empty())
|
||||
return false;
|
||||
unsigned idx = ctx.rand() % m_updates.size();
|
||||
unsigned idx = ctx.rand(m_updates.size());
|
||||
auto& [v, delta, score] = m_updates[idx];
|
||||
m_best_expr = m_vars[v].m_expr;
|
||||
if (false && !m_vars[v].m_finite_domain.empty())
|
||||
|
@ -3015,13 +3042,13 @@ namespace sls {
|
|||
void arith_base<num_t>::global_search() {
|
||||
initialize_bool_assignment();
|
||||
rescore();
|
||||
m_config.max_moves = m_stats.m_moves + m_config.max_moves_base;
|
||||
TRACE("arith", tout << "search " << m_stats.m_moves << " " << m_config.max_moves << "\n";);
|
||||
IF_VERBOSE(3, verbose_stream() << "lookahead-search moves:" << m_stats.m_moves << " max-moves:" << m_config.max_moves << "\n");
|
||||
m_config.max_moves = m_stats.m_steps + m_config.max_moves_base;
|
||||
TRACE("arith", tout << "search " << m_stats.m_steps << " " << m_config.max_moves << "\n";);
|
||||
IF_VERBOSE(3, verbose_stream() << "lookahead-search steps:" << m_stats.m_steps << " max-moves:" << m_config.max_moves << "\n");
|
||||
TRACE("arith", display(tout));
|
||||
|
||||
while (ctx.rlimit().inc() && m_stats.m_moves < m_config.max_moves) {
|
||||
m_stats.m_moves++;
|
||||
while (ctx.rlimit().inc() && m_stats.m_steps < m_config.max_moves) {
|
||||
m_stats.m_steps++;
|
||||
check_restart();
|
||||
|
||||
auto t = get_candidate_unsat();
|
||||
|
@ -3043,7 +3070,7 @@ namespace sls {
|
|||
if (apply_move(t, vars, arith_move_type::random_update))
|
||||
recalibrate_weights();
|
||||
}
|
||||
if (m_stats.m_moves >= m_config.max_moves)
|
||||
if (m_stats.m_steps >= m_config.max_moves)
|
||||
m_config.max_moves_base += 100;
|
||||
finalize_bool_assignment();
|
||||
}
|
||||
|
@ -3098,11 +3125,11 @@ namespace sls {
|
|||
if (old_value == new_value)
|
||||
return true;
|
||||
if (!vi.in_range(new_value)) {
|
||||
TRACE("arith", tout << "Not in range v" << v << " " << new_value << "\n");
|
||||
TRACE("arith_verbose", tout << "Not in range v" << v << " " << new_value << "\n");
|
||||
return false;
|
||||
}
|
||||
if (!in_bounds(v, new_value) && in_bounds(v, old_value)) {
|
||||
TRACE("arith", tout << "out of bounds v" << v << " " << new_value << "\n");
|
||||
TRACE("arith_verbose", tout << "out of bounds v" << v << " " << new_value << "\n");
|
||||
//verbose_stream() << "out of bounds v" << v << " " << new_value << "\n";
|
||||
return false;
|
||||
}
|
||||
|
@ -3166,16 +3193,16 @@ namespace sls {
|
|||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::check_restart() {
|
||||
if (m_stats.m_moves % m_config.restart_base == 0) {
|
||||
if (m_stats.m_steps % m_config.restart_base == 0) {
|
||||
ucb_forget();
|
||||
rescore();
|
||||
}
|
||||
|
||||
if (m_stats.m_moves < m_config.restart_next)
|
||||
if (m_stats.m_steps < m_config.restart_next)
|
||||
return;
|
||||
|
||||
++m_stats.m_restarts;
|
||||
m_config.restart_next = std::max(m_config.restart_next, m_stats.m_moves);
|
||||
m_config.restart_next = std::max(m_config.restart_next, m_stats.m_steps);
|
||||
|
||||
if (0x1 == (m_stats.m_restarts & 0x1))
|
||||
m_config.restart_next += m_config.restart_base;
|
||||
|
@ -3184,10 +3211,8 @@ namespace sls {
|
|||
|
||||
// reset_uninterp_in_false_literals
|
||||
rescore();
|
||||
|
||||
}
|
||||
|
||||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::ucb_forget() {
|
||||
if (m_config.ucb_forget >= 1.0)
|
||||
|
@ -3214,18 +3239,21 @@ namespace sls {
|
|||
//m_config.ucb_forget = p.ucb_forget();
|
||||
m_config.wp = p.wp();
|
||||
m_config.restart_base = p.restart_base();
|
||||
//m_config.restart_next = p.restart_next();
|
||||
m_config.restart_next = p.restart_base();
|
||||
//m_config.max_moves_base = p.max_moves_base();
|
||||
//m_config.max_moves = p.max_moves();
|
||||
m_config.arith_use_lookahead = p.arith_use_lookahead();
|
||||
m_config.use_lookahead = p.arith_use_lookahead();
|
||||
m_config.use_clausal_lookahead = p.arith_use_clausal_lookahead();
|
||||
m_config.allow_plateau = p.arith_allow_plateau();
|
||||
m_config.config_initialized = true;
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::start_propagation() {
|
||||
updt_params();
|
||||
if (m_config.arith_use_lookahead)
|
||||
updt_params();
|
||||
if (m_config.use_clausal_lookahead)
|
||||
m_clausal_sls.search();
|
||||
else if (m_config.use_lookahead)
|
||||
global_search();
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue