diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index da97ca823..1e6eef3a3 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -43,6 +43,7 @@ Done: #include "ast/sls/sls_arith_base.h" #include "ast/ast_ll_pp.h" #include "ast/ast_pp.h" +#include "params/sls_params.hpp" #include namespace sls { @@ -388,7 +389,9 @@ namespace sls { } template - void arith_base::find_linear_moves(ineq const& ineq, var_t v, num_t const& coeff, num_t const& sum) { + void arith_base::find_linear_moves(ineq const& ineq, var_t v, num_t const& coeff) { + num_t const& sum = ineq.m_args_value; + TRACE("arith_verbose", tout << ineq << " " << v << " " << value(v) << "\n"); if (ineq.is_true()) { switch (ineq.m_op) { case ineq_kind::LE: @@ -426,6 +429,7 @@ namespace sls { delta = sum < 0 ? divide(v, abs(sum), coeff) : -divide(v, sum, coeff); if (sum + coeff * delta == 0) add_update(v, delta); + break; } default: @@ -441,17 +445,23 @@ namespace sls { delta_out = delta; - if (m_last_var == v && m_last_delta == -delta) - return false; - - if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) + if (m_last_var == v && m_last_delta == -delta) { + TRACE("arith", tout << "flip back " << v << " " << delta << "\n";); return false; + } + + if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) { + TRACE("arith", tout << "tabu\n"); + return false; + } auto old_value = value(v); auto new_value = old_value + delta; - if (!vi.in_range(new_value)) - return false; + if (!vi.in_range(new_value)) { + TRACE("arith", tout << "out of range: v" << v << " " << old_value << " " << delta << " " << new_value << "\n";); + return false; + } if (m_use_tabu && !in_bounds(v, new_value) && in_bounds(v, old_value)) { @@ -556,7 +566,7 @@ namespace sls { for (auto const& [coeff, x] : ineq->m_args) { if (is_fixed(x)) continue; - find_linear_moves(*ineq, x, coeff, ineq->m_args_value); + find_linear_moves(*ineq, x, coeff); } } return apply_update(); @@ -667,7 +677,7 @@ namespace sls { } buffer to_flip; - for (auto const& [coeff, bv] : vi.m_ineqs) { + for (auto const& [coeff, bv] : vi.m_linear_occurs) { auto& ineq = *get_ineq(bv); bool old_sign = sign(bv); sat::literal lit(bv, old_sign); @@ -781,11 +791,9 @@ namespace sls { template void arith_base::add_args(linear_term& term, expr* e, num_t const& coeff) { auto v = m_expr2var.get(e->get_id(), UINT_MAX); - expr* x, * y; + expr* x, * y, * z, * u; num_t i; - if (v != UINT_MAX) - add_arg(term, coeff, v); - else if (is_num(e, i)) + if (is_num(e, i)) term.m_coeff += coeff * i; else if (a.is_add(e)) { for (expr* arg : *to_app(e)) @@ -798,6 +806,18 @@ namespace sls { else if (a.is_mul(e, x, y) && is_num(x, i)) { add_args(term, y, i * coeff); } + else if (a.is_mul(e, x, y) && a.is_add(y, z, u)) { + expr_ref t(a.mk_mul(x, z), m); + add_args(term, t, coeff); + t = a.mk_mul(x, u); + add_args(term, t, coeff); + } + else if (a.is_mul(e, x, y) && a.is_add(x, z, u)) { + expr_ref t(a.mk_mul(y, z), m); + add_args(term, t, coeff); + t = a.mk_mul(y, u); + add_args(term, t, coeff); + } else if (a.is_mul(e)) { unsigned_vector ms; for (expr* arg : *to_app(e)) @@ -812,23 +832,31 @@ namespace sls { break; default: { v = mk_var(e); + unsigned idx = m_muls.size(); - std::stable_sort(ms.begin(), ms.end(), [&](unsigned a, unsigned b) { return a < b; }); - svector> mp; - for (unsigned i = 0; i < ms.size(); ++i) { - auto w = ms[i]; - auto p = 1; - while (i + 1 < ms.size() && ms[i + 1] == w) - ++p, ++i; - mp.push_back({ w, p }); + for (idx = 0; idx < m_muls.size(); ++idx) + if (m_muls[idx].m_var == v) + break; + + if (idx == m_muls.size()) { + std::stable_sort(ms.begin(), ms.end(), [&](unsigned a, unsigned b) { return a < b; }); + svector> mp; + for (unsigned i = 0; i < ms.size(); ++i) { + auto w = ms[i]; + auto p = 1; + while (i + 1 < ms.size() && ms[i + 1] == w) + ++p, ++i; + mp.push_back({ w, p }); + } + + m_muls.push_back({ v, mp }); + num_t prod(1); + for (auto [w, p] : mp) + m_vars[w].m_muls.push_back(idx), prod *= power_of(value(w), p); + m_vars[v].m_def_idx = idx; + m_vars[v].m_op = arith_op_kind::OP_MUL; + m_vars[v].set_value(prod); } - m_muls.push_back({ v, mp }); - num_t prod(1); - for (auto [w, p] : mp) - m_vars[w].m_muls.push_back(idx), prod *= power_of(value(w), p); - m_vars[v].m_def_idx = idx; - m_vars[v].m_op = arith_op_kind::OP_MUL; - m_vars[v].set_value(prod); add_arg(term, coeff, v); break; } @@ -836,6 +864,8 @@ namespace sls { } else if (a.is_uminus(e, x)) add_args(term, x, -coeff); + else if (v != UINT_MAX) + add_arg(term, coeff, v); else if (a.is_mod(e, x, y) || a.is_mod0(e, x, y)) add_arg(term, coeff, mk_op(arith_op_kind::OP_MOD, e, x, y)); else if (a.is_idiv(e, x, y) || a.is_idiv0(e, x, y)) @@ -1004,7 +1034,7 @@ namespace sls { // compute the value of the linear term, and accumulate non-linear sub-terms i.m_args_value = i.m_coeff; for (auto const& [coeff, v] : i.m_args) { - m_vars[v].m_ineqs.push_back({ coeff, bv }); + m_vars[v].m_linear_occurs.push_back({ coeff, bv }); i.m_args_value += coeff * value(v); if (is_mul(v)) { auto const& [w, monomial] = get_mul(v); @@ -1202,6 +1232,8 @@ namespace sls { void arith_base::initialize() { for (auto lit : ctx.unit_literals()) initialize_unit(lit); + for (auto f : ctx.input_assertions()) + initialize_input_assertion(f); for (unsigned v = 0; v < m_vars.size(); ++v) { auto const& vi = m_vars[v]; if (vi.m_lo || vi.m_hi) @@ -1337,6 +1369,29 @@ namespace sls { } } + template + void arith_base::initialize_input_assertion(expr* f) { + if (m.is_or(f)) { + var_t v = UINT_MAX; + expr* x, * y; + vector values; + for (expr* arg : *to_app(f)) { + num_t n; + if (m.is_eq(arg, x, y) && is_num(y, n)) { + var_t w = m_expr2var.get(x->get_id(), UINT_MAX); + if (w != UINT_MAX && (v == w || v == UINT_MAX)) + v = w, values.push_back(n); + else + return; + } + else + return; + } + m_vars[v].m_finite_domain.append(values); + return; + } + } + template void arith_base::initialize_unit(sat::literal lit) { init_bool_var(lit.var()); @@ -1630,12 +1685,11 @@ namespace sls { double arith_base::compute_score(var_t x, num_t const& delta) { int result = 0; int breaks = 0; - for (auto const& [coeff, bv] : m_vars[x].m_ineqs) { + for (auto const& [coeff, bv] : m_vars[x].m_linear_occurs) { bool old_sign = sign(bv); auto lit = sat::literal(bv, old_sign); auto dtt_old = dtt(old_sign, *get_ineq(bv)); auto dtt_new = dtt(old_sign, *get_ineq(bv), coeff, delta); -#if 1 if (dtt_new == 0 && dtt_old != 0) result += 1; @@ -1645,14 +1699,6 @@ namespace sls { result -= 1; breaks += 1; } -#else - if (dtt_new == dtt_old) - continue; - if (m_use_tabu && ctx.is_unit(lit) && dtt_new != 0) - return 0; - double reward = ctx.reward(bv); - result += reward; -#endif } if (result < 0) @@ -1726,7 +1772,7 @@ namespace sls { if (is_fixed(x)) continue; if (is_linear(x, nl, b)) - find_linear_moves(*ineq, x, b, ineq->m_args_value); + find_linear_moves(*ineq, x, b); else if (is_quadratic(x, nl, a, b)) find_quadratic_moves(*ineq, x, a, b, ineq->m_args_value); else @@ -2149,9 +2195,9 @@ namespace sls { out << " "; } - if (!vi.m_ineqs.empty()) { + if (!vi.m_linear_occurs.empty()) { out << " bool: "; - for (auto [c, bv] : vi.m_ineqs) + for (auto [c, bv] : vi.m_linear_occurs) out << c << "@" << bv << " "; } return out; @@ -2260,36 +2306,38 @@ namespace sls { if (ineq) invariant(*ineq); } + auto report_error = [&](std::ostream& out, var_t v) { + display(out); + display(out << "variable: ", v) << "\n"; + out << mk_bounded_pp(m_vars[v].m_expr, m) << "\n"; + + if (is_mul(v)) { + auto const& [w, monomial] = get_mul(v); + num_t prod(1); + for (auto [v, p] : monomial) + prod *= power_of(value(v), p); + out << "product " << prod << " value " << value(w) << "\n"; + out << "v" << w << " := "; + for (auto [w, p] : monomial) { + out << "(v" << w; + if (p > 1) + out << "^" << p; + out << " := " << value(w); + out << ") "; + } + out << "\n"; + } + else if (is_add(v)) { + auto const& ad = get_add(v); + out << "v" << ad.m_var << " := "; + display(out, ad) << "\n"; + } + }; auto& out = verbose_stream(); for (var_t v = 0; v < m_vars.size(); ++v) { if (!eval_is_correct(v)) { - - display(out); - display(out, v) << "\n"; - out << mk_bounded_pp(m_vars[v].m_expr, m) << "\n"; - - if (is_mul(v)) { - auto const& [w, monomial] = get_mul(v); - num_t prod(1); - for (auto [v, p] : monomial) - prod *= power_of(value(v), p); - out << "product " << prod << " value " << value(w) << "\n"; - out << "v" << w << " := "; - for (auto [w, p] : monomial) { - out << "(v" << w; - if (p > 1) - out << "^" << p; - out << " := " << value(w); - out << ") "; - } - out << "\n"; - } - else if (is_add(v)) { - auto const& ad = get_add(v); - out << "v" << ad.m_var << " := "; - display(out, ad) << "\n"; - } - + report_error(verbose_stream(), v); + TRACE("arith", report_error(tout, v)); UNREACHABLE(); } } @@ -2300,8 +2348,11 @@ namespace sls { num_t val = i.m_coeff; for (auto const& [c, v] : i.m_args) val += c * value(v); - if (val != i.m_args_value) + if (val != i.m_args_value) { verbose_stream() << val << ": " << i << "\n"; + display(verbose_stream()); + TRACE("arith", display(tout << val << ": " << i << "\n")); + } SASSERT(val == i.m_args_value); VERIFY(val == i.m_args_value); } @@ -2310,6 +2361,7 @@ namespace sls { template void arith_base::collect_statistics(statistics& st) const { st.update("sls-arith-flips", m_stats.m_num_steps); + st.update("sls-arith-moves", m_stats.m_moves); } template @@ -2321,8 +2373,8 @@ namespace sls { // template - arith_base::bool_info& arith_base::get_bool_info(expr* e) { - m_bool_info.reserve(e->get_id() + 1, { m_config.paws_init, 0, 1, l_undef }); + arith_base::bool_info& arith_base::get_bool_info(expr* e) { + m_bool_info.reserve(e->get_id() + 1, bool_info(m_config.paws_init)); return m_bool_info[e->get_id()]; } @@ -2334,12 +2386,11 @@ namespace sls { if (is_uninterp(e)) return ctx.get_value(e) == l_true; - app* a = to_app(e); - if (a->get_family_id() == basic_family_id) { - bool r = get_basic_bool_value(a); - get_bool_info(e).value = to_lbool(r); - return r; - } + app* ap = to_app(e); + bool is_arith_eq = m.is_eq(e) && a.is_int_real(ap->get_arg(0)); + + if (ap->get_family_id() == basic_family_id && !is_arith_eq) + return get_basic_bool_value(ap); auto v = ctx.atom2bool_var(e); if (v == sat::null_bool_var) @@ -2347,9 +2398,7 @@ namespace sls { auto const* ineq = get_ineq(v); if (!ineq) return false; - auto r = ineq->is_true() == ctx.is_true(v); - get_bool_info(e).value = to_lbool(r); - return r; + return ineq->is_true(); } template @@ -2384,9 +2433,9 @@ namespace sls { case OP_EQ: if (m.is_bool(e->get_arg(0))) return get_bool_value(e->get_arg(0)) == get_bool_value(e->get_arg(1)); - NOT_IMPLEMENTED_YET(); + return ctx.get_value(e->get_arg(0)) == ctx.get_value(e->get_arg(1)); case OP_DISTINCT: - NOT_IMPLEMENTED_YET(); + return false; default: NOT_IMPLEMENTED_YET(); } @@ -2397,7 +2446,15 @@ namespace sls { void arith_base::initialize_bool_assignment() { for (auto t : ctx.subterms()) if (m.is_bool(t)) - get_bool_value(t); + set_bool_value(t, get_bool_value_rec(t)); +#if 0 + for (auto t : ctx.subterms()) { + if (m.is_bool(t)) + verbose_stream() << mk_bounded_pp(t, m) << " := " << get_bool_value(t) << "\n"; + else + verbose_stream() << mk_bounded_pp(t, m) << " := " << ctx.get_value(t) << "\n"; + } +#endif } template @@ -2420,7 +2477,6 @@ namespace sls { double arith_base::new_score(expr* a, bool is_true) { bool is_true_new = get_bool_value(a); - //verbose_stream() << "compute score " << mk_bounded_pp(a, m) << " is-true " << is_true << " is-true-new " << is_true_new << "\n"; if (is_true == is_true_new) return 1; if (is_uninterp(a)) @@ -2502,12 +2558,14 @@ namespace sls { break; } + SASSERT(value > 0); - unsigned max_value = 10000; + unsigned max_value = 1000; if (value > max_value) - return 1.0; + return 0.0; auto d = value.get_double(); - return 1.0 - ((d * d) / ((double)max_value * (double)max_value)); + double score = 1.0 - ((d * d) / ((double)max_value * (double)max_value)); + return score; } template @@ -2547,9 +2605,10 @@ namespace sls { } } } + m_update_stack.reserve(m_max_depth + 1); } template - double arith_base::lookahead(expr* t) { + double arith_base::lookahead(expr* t, bool update_score) { SASSERT(a.is_int_real(t) || m.is_bool(t)); double score = m_top_score; for (unsigned depth = m_min_depth; depth <= m_max_depth; ++depth) { @@ -2558,8 +2617,12 @@ namespace sls { TRACE("bv_verbose", tout << "update " << mk_bounded_pp(a, m) << " depth: " << depth << "\n";); if (t != a) set_bool_value(a, get_bool_value_rec(a)); - if (m_is_root.is_marked(a)) - score += get_weight(a) * (new_score(a) - old_score(a)); + if (m_is_root.is_marked(a)) { + auto nscore = new_score(a); + score += get_weight(a) * (nscore - old_score(a)); + if (update_score) + set_score(a, nscore); + } } } return score; @@ -2567,8 +2630,6 @@ namespace sls { template void arith_base::insert_update_stack(expr* t) { - if (!m.is_bool(t)) - return; unsigned depth = get_depth(t); m_update_stack.reserve(depth + 1); if (!m_in_update_stack.is_marked(t) && is_app(t)) { @@ -2579,39 +2640,51 @@ namespace sls { template void arith_base::clear_update_stack() { - lookahead(nullptr); m_in_update_stack.reset(); + m_update_stack.reserve(m_max_depth + 1); for (unsigned i = m_min_depth; i <= m_max_depth; ++i) m_update_stack[i].reset(); } template - void arith_base::lookahead_num(var_t v, num_t const& new_value) { + void arith_base::lookahead_num(var_t v, num_t const& delta) { num_t old_value = value(v); - if (!update(v, new_value)) + + if (!update_num(v, delta)) return; + num_t new_value = old_value + delta; expr* e = m_vars[v].m_expr; - auto score = lookahead(e); + if (m_last_expr != e) { + if (m_last_expr) + lookahead(m_last_expr, false); + clear_update_stack(); + insert_update_stack_rec(e); + m_last_expr = e; + } + auto score = lookahead(e, false); + TRACE("arith_verbose", tout << "lookahead " << v << " " << mk_bounded_pp(e, m) << " := " << delta + old_value << " " << score << " (" << m_best_score << ")\n";); if (score > m_best_score) { m_best_score = score; m_best_value = new_value; m_best_expr = e; } - VERIFY(update(v, old_value)); - lookahead(e); + + // revert back to old value + update_args_value(v, old_value); } template void arith_base::lookahead_bool(expr* e) { bool b = get_bool_value(e); set_bool_value(e, !b); - auto score = lookahead(e); + auto score = lookahead(e, false); if (score > m_best_score) { m_best_score = score; m_best_expr = e; } - set_bool_value(e, b); + set_bool_value(e, b); + lookahead(e, false); } // for every variable e, for every atom containing e @@ -2619,38 +2692,60 @@ namespace sls { // m_fixable_atoms contains atoms that can be fixed. // m_fixable_vars contains variables that can be updated. template - void arith_base::add_lookahead(expr* e) { + void arith_base::add_lookahead(bool_info& i, expr* e) { + + auto add_atom = [&](sat::bool_var bv) { + if (!i.fixable_atoms.contains(bv)) + return; + if (m_fixed_atoms.contains(bv)) + return; + auto a = ctx.atom(bv); + if (!a) + return; + auto* ineq = get_ineq(bv); + if (!ineq) + return; + num_t na, nb; + for (auto const& [x, nl] : ineq->m_nonlinear) { + if (!i.fixable_vars.contains(x)) + continue; + if (is_fixed(x)) + continue; + if (is_linear(x, nl, nb)) + find_linear_moves(*ineq, x, nb); + else if (is_quadratic(x, nl, na, nb)) + find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value); + else + ; + } + m_fixed_atoms.insert(bv); + }; + + auto add_finite_domain = [&](var_t v) { + auto old_value = value(v); + for (auto const& n : m_vars[v].m_finite_domain) + add_update(v, n - old_value); + }; + + if (m.is_bool(e)) { auto bv = ctx.atom2bool_var(e); - if (m_fixable_atoms.contains(bv)) + if (i.fixable_atoms.contains(bv)) lookahead_bool(e); } else if (a.is_int_real(e)) { auto v = mk_term(e); auto& vi = m_vars[v]; - for (auto [coeff, bv] : vi.m_ineqs) { - if (!m_fixable_atoms.contains(bv)) - continue; - auto a = ctx.atom(bv); - if (!a) - continue; - auto* ineq = get_ineq(bv); - if (!ineq) - continue; - num_t na, nb; - for (auto const& [x, nl] : ineq->m_nonlinear) { - if (!m_fixable_vars.contains(x)) - continue; - if (is_fixed(x)) - continue; - if (is_linear(x, nl, nb)) - find_linear_moves(*ineq, x, nb, ineq->m_args_value); - else if (is_quadratic(x, nl, na, nb)) - find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value); - else - ; - } - m_fixable_atoms.remove(bv); + if (false && !vi.m_finite_domain.empty()) { + add_finite_domain(v); + return; + } + for (auto const& [coeff, bv] : vi.m_linear_occurs) + add_atom(bv); + for (auto const& idx : vi.m_muls) { + auto const& [x, monomial] = m_muls[idx]; + for (auto [coeff, bv] : m_vars[x].m_linear_occurs) + add_atom(bv); } } } @@ -2661,111 +2756,170 @@ namespace sls { // candidate expressions may be either numeric or boolean variables. // template - void arith_base::add_fixable(expr* e) { - m_fixable_exprs.reset(); - m_fixable_atoms.reset(); - m_fixable_vars.reset(); + ptr_vector const& arith_base::get_fixable_exprs(expr* e) { + auto& i = get_bool_info(e); + if (!i.fixable_exprs.empty()) + return i.fixable_exprs; expr_mark visited; - buffer> todo; - expr* x, * y, * z; - todo.push_back({ e, l_true }); + ptr_buffer todo; + + todo.push_back(e); while (!todo.empty()) { - auto [e, is_true] = todo.back(); + auto e = todo.back(); todo.pop_back(); if (visited.is_marked(e)) continue; - visited.mark(e); - if (is_true == l_true && get_bool_value(e)) - continue; - if (is_true == l_false && !get_bool_value(e)) - continue; - if (m.is_not(e, e)) - todo.push_back({ e, ~is_true }); - else if (m.is_and(e) || m.is_or(e)) { + visited.mark(e); + if (m.is_xor(e) || m.is_and(e) || m.is_or(e) || m.is_implies(e) || m.is_iff(e) || m.is_ite(e) || m.is_not(e)) { for (auto arg : *to_app(e)) - todo.push_back({ arg, is_true }); - } - else if (m.is_implies(e, x, y)) { - todo.push_back({ x, ~is_true }); - todo.push_back({ y, is_true }); - } - else if (m.is_iff(e, x, y)) { - todo.push_back({ x, l_undef }); - todo.push_back({ y, l_undef }); - } - else if (m.is_ite(e, x, y, z)) { - todo.push_back({ x, l_undef }); - todo.push_back({ y, is_true }); - todo.push_back({ z, ~is_true }); + todo.push_back(arg); } else { auto bv = ctx.atom2bool_var(e); if (bv == sat::null_bool_var) continue; if (is_uninterp(e)) { - if (!m_fixable_atoms.contains(bv)) { - m_fixable_atoms.insert(bv); - m_fixable_exprs.push_back(e); + if (!i.fixable_atoms.contains(bv)) { + i.fixable_atoms.insert(bv); + i.fixable_exprs.push_back(e); } continue; } auto* ineq = get_ineq(bv); if (!ineq) continue; - m_fixable_atoms.insert(bv); - for (auto& [v, occ] : ineq->m_nonlinear) { - if (m_fixable_vars.contains(v)) + i.fixable_atoms.insert(bv); + buffer vars; + + for (auto& [v, occ] : ineq->m_nonlinear) + vars.push_back(v); + + for (unsigned j = 0; j < vars.size(); ++j) { + auto v = vars[j]; + if (i.fixable_vars.contains(v)) continue; - m_fixable_vars.insert(v); - m_fixable_exprs.push_back(m_vars[v].m_expr); + + if (is_add(v)) { + for (auto [c, w] : get_add(v).m_args) + vars.push_back(w); + } + else if (is_mul(v)) { + for (auto [w, p] : get_mul(v).m_monomial) + vars.push_back(w); + } + else { + i.fixable_exprs.push_back(m_vars[v].m_expr); + i.fixable_vars.insert(v); + } } } } + return i.fixable_exprs; } template - bool arith_base::apply_move(expr* t, bool randomize) { - add_fixable(t); - auto& vars = m_fixable_exprs; + bool arith_base::apply_move(expr* f, ptr_vector const& vars, arith_move_type t) { if (vars.empty()) return false; + auto& info = get_bool_info(f); m_best_expr = nullptr; m_best_score = m_top_score; unsigned sz = vars.size(); unsigned start = ctx.rand(); m_updates.reset(); - insert_update_stack_rec(t); - for (unsigned i = 0; i < sz; ++i) - add_lookahead(vars[(start + i) % sz]); + m_fixed_atoms.reset(); - if (randomize) { + switch (t) { + case arith_move_type::random_update: { + for (unsigned i = 0; i < sz; ++i) + add_lookahead(info, vars[(start + i) % sz]); if (m_updates.empty()) return false; - auto& [v, new_value, score] = m_updates[ctx.rand() % m_updates.size()]; + unsigned idx = ctx.rand() % m_updates.size(); + auto& [v, delta, score] = m_updates[idx]; m_best_expr = m_vars[v].m_expr; + m_best_value = value(v) + delta; + break; + } + case arith_move_type::hillclimb: { + for (unsigned i = 0; i < sz; ++i) + add_lookahead(info, vars[(start + i) % sz]); + if (m_updates.empty()) + return false; + std::stable_sort(m_updates.begin(), m_updates.end(), [](auto const& a, auto const& b) { return a.m_var < b.m_var; }); + m_last_expr = nullptr; + sz = m_updates.size(); + for (unsigned i = 0; i < sz; ++i) { + auto const& [v, delta, score] = m_updates[(start + i) % m_updates.size()]; + lookahead_num(v, delta); + } + if (m_last_expr) { + lookahead(m_last_expr, false); + clear_update_stack(); + } + break; + } + case arith_move_type::random_inc_dec: { + auto e = vars[ctx.rand() % sz]; + m_best_expr = e; + if (a.is_int_real(e)) { + var_t v = mk_term(e); + if (ctx.rand(2) == 0) + m_best_value = value(v) + 1; + else + m_best_value = value(v) - 1; + } + break; } - else { - for (auto const& [v, new_value, score] : m_updates) - lookahead_num(v, new_value); } - if (m_best_expr) - m_top_score = lookahead(m_best_expr); - clear_update_stack(); - - CTRACE("bv", !m_best_expr, tout << "no guided move\n";); + if (m_best_expr) { + if (m.is_bool(m_best_expr)) + set_bool_value(m_best_expr, !get_bool_value(m_best_expr)); + else { + var_t v = mk_term(m_best_expr); + if (!update_num(v, m_best_value - value(v))) { + TRACE("arith", + tout << "could not move v" << v << " " << t << " " << mk_bounded_pp(m_best_expr, m) << " := " << value(v) << " " << m_top_score << "\n"; + ); + return false; + } + } + insert_update_stack_rec(m_best_expr); + m_top_score = lookahead(m_best_expr, true); + clear_update_stack(); + } + + CTRACE("arith", !m_best_expr, tout << "no move " << t << "\n";); + CTRACE("arith", m_best_expr && a.is_int_real(m_best_expr), { + var_t v = mk_term(m_best_expr); + tout << t << " v" << v << " " << mk_bounded_pp(m_best_expr, m) << " := " << value(v) << " " << m_top_score << "\n"; + }); return !!m_best_expr; } + + std::ostream& operator<<(std::ostream& out, arith_move_type mt) { + switch (mt) { + case arith_move_type::random_update: out << "random-update"; break; + case arith_move_type::hillclimb: out << "hillclimb"; break; + case arith_move_type::random_inc_dec: out << "random-inc-dec"; break; + } + return out; + } + template void arith_base::global_search() { initialize_bool_assignment(); rescore(); m_config.max_moves = m_stats.m_moves + m_config.max_moves_base; - TRACE("bv", tout << "search " << m_stats.m_moves << " " << m_config.max_moves << "\n";); + TRACE("arith", tout << "search " << m_stats.m_moves << " " << m_config.max_moves << "\n";); IF_VERBOSE(1, verbose_stream() << "lookahead-search moves:" << m_stats.m_moves << " max-moves:" << m_config.max_moves << "\n"); + TRACE("arith", display(tout)); + bool loop_again = true; while (m.inc() && m_stats.m_moves < m_config.max_moves) { + loop_again = false; m_stats.m_moves++; check_restart(); @@ -2774,31 +2928,204 @@ namespace sls { if (!t) break; - if (apply_move(t, false)) + auto& vars = get_fixable_exprs(t); + + if (vars.empty()) + return; + + if (ctx.rand(2047) < m_config.wp && apply_move(t, vars, arith_move_type::random_inc_dec)) continue; - if (apply_move(t, true)) + if (apply_move(t, vars, arith_move_type::hillclimb)) + continue; + + if (apply_move(t, vars, arith_move_type::random_update)) recalibrate_weights(); + loop_again = true; } - m_config.max_moves_base += 100; + if (loop_again) + m_config.max_moves_base += 100; finalize_bool_assignment(); } template expr* arith_base::get_candidate_unsat() { - unsigned n = 0; - expr* r = nullptr; - for (auto a : ctx.input_assertions()) { - if (!get_bool_value(a) && (ctx.rand() % (++n)) == 0) - r = a; + expr* e = nullptr; + if (m_config.ucb) { + double max = -1.0; + for (auto a : ctx.input_assertions()) { + if (get_bool_value(a)) + continue; + + auto const& vars = get_fixable_exprs(a); + if (vars.empty()) + continue; + auto score = old_score(a); + auto q = score + + m_config.ucb_constant * ::sqrt(log((double)m_touched) / get_touched(a)) + + m_config.ucb_noise * ctx.rand(512); + if (q > max) + max = q, e = a; + } + if (e) { + m_touched++; + inc_touched(e); + } } - return r; + else { + unsigned n = 0; + for (auto a : ctx.input_assertions()) + if (!get_bool_value(a) && !get_fixable_exprs(a).empty() && ctx.rand() % ++n == 0) + e = a; + } + + m_last_atom = e; + CTRACE("arith", !e, "no candidate\n";); + CTRACE("arith", e, + tout << "select " << mk_bounded_pp(e, m) << " "; + for (auto v : get_fixable_exprs(e)) + tout << mk_bounded_pp(v, m) << " "; + tout << "\n"); + return e; + } + + template + bool arith_base::can_update_num(var_t v, num_t const& delta) { + num_t old_value = value(v); + num_t new_value = old_value + delta; + auto& vi = m_vars[v]; + //expr* e = vi.m_expr; + if (old_value == new_value) + return true; + if (!vi.in_range(new_value)) { + TRACE("arith", tout << "Not in range v" << v << " " << new_value << "\n"); + return false; + } + if (!in_bounds(v, new_value) && in_bounds(v, old_value)) { + TRACE("arith", tout << "out of bounds v" << v << " " << new_value << "\n"); + //verbose_stream() << "out of bounds v" << v << " " << new_value << "\n"; + return false; + } + + // check for overflow + try { + for (auto idx : vi.m_muls) { + auto const& [w, monomial] = m_muls[idx]; + num_t prod(1); + for (auto [w, p] : monomial) + prod *= power_of(v == w ? new_value : value(w), p); + } + } + catch (overflow_exception const&) { + return false; + } + return true; + } + + template + bool arith_base::update_num(var_t v, num_t const& delta) { + if (delta == 0) + return true; + if (!can_update_num(v, delta)) + return false; + auto& vi = m_vars[v]; + auto old_value = vi.value(); + num_t new_value = old_value + delta; + update_args_value(v, new_value); + return true; + } + + template + void arith_base::update_args_value(var_t v, num_t const& new_value) { + auto& vi = m_vars[v]; + + for (auto const& idx : vi.m_muls) { + auto& [x, monomial] = m_muls[idx]; + num_t new_prod(1); + for (auto [w, p] : monomial) + new_prod *= power_of(v == w ? new_value : value(w), p); + update_args_value(x, new_prod); + } + + for (auto const& idx : vi.m_adds) { + auto& ad = m_adds[idx]; + num_t new_sum(ad.m_coeff); + for (auto [c, w] : ad.m_args) + new_sum += c * (v == w ? new_value : value(w)); + update_args_value(ad.m_var, new_sum); + } + + auto old_value = value(v); + for (auto const& [coeff, bv] : vi.m_linear_occurs) { + auto& ineq = *get_ineq(bv); + ineq.m_args_value += coeff * (new_value - old_value); + } + IF_VERBOSE(5, verbose_stream() << "update: v" << v << " " << mk_bounded_pp(vi.m_expr, m) << " := " << old_value << " -> " << new_value << "\n"); + vi.set_value(new_value); } template void arith_base::check_restart() { + if (m_stats.m_moves % m_config.restart_base == 0) { + ucb_forget(); + rescore(); + } + + if (m_stats.m_moves < m_config.restart_next) + return; + + ++m_stats.m_restarts; + m_config.restart_next = std::max(m_config.restart_next, m_stats.m_moves); + + if (0x1 == (m_stats.m_restarts & 0x1)) + m_config.restart_next += m_config.restart_base; + else + m_config.restart_next += (2 * (m_stats.m_restarts >> 1)) * m_config.restart_base; + + // reset_uninterp_in_false_literals + rescore(); + + } + template + void arith_base::ucb_forget() { + if (m_config.ucb_forget >= 1.0) + return; + for (auto a : ctx.input_assertions()) { + auto touched_old = get_touched(a); + auto touched_new = static_cast((touched_old - 1) * m_config.ucb_forget + 1); + set_touched(a, touched_new); + m_touched += touched_new - touched_old; + } + } + + template + void arith_base::updt_params() { + if (m_config.config_initialized) + return; + + sls_params p(ctx.get_params()); + m_config.paws_init = p.paws_init(); + m_config.paws_sp = p.paws_sp(); + //m_config.ucb = p.ucb(); + //m_config.ucb_constant = p.ucb_constant(); + //m_config.ucb_noise = p.ucb_noise(); + //m_config.ucb_forget = p.ucb_forget(); + m_config.wp = p.wp(); + m_config.restart_base = p.restart_base(); + //m_config.restart_next = p.restart_next(); + //m_config.max_moves_base = p.max_moves_base(); + //m_config.max_moves = p.max_moves(); + m_config.arith_use_lookahead = p.arith_use_lookahead(); + m_config.config_initialized = true; + } + + template + void arith_base::start_propagation() { + updt_params(); + if (m_config.arith_use_lookahead) + global_search(); } } diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index 07b952d14..928f02a35 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -27,6 +27,14 @@ namespace sls { using theory_var = int; + enum arith_move_type { + hillclimb, + random_update, + random_inc_dec + }; + + std::ostream& operator<<(std::ostream& out, arith_move_type mt); + // local search portion for arithmetic template class arith_base : public plugin { @@ -37,6 +45,7 @@ namespace sls { typedef unsigned atom_t; struct config { + bool config_initialized = false; double cb = 2.85; unsigned L = 20; unsigned t = 45; @@ -47,11 +56,22 @@ namespace sls { bool paws = true; unsigned max_moves = 500; unsigned max_moves_base = 500; + unsigned wp = 100; + bool ucb = true; + double ucb_constant = 1.0; + double ucb_forget = 0.1; + bool ucb_init = false; + double ucb_noise = 0.1; + unsigned restart_base = 1000; + unsigned restart_next = 1000; + unsigned restart_init = 1000; + bool arith_use_lookahead = false; }; struct stats { unsigned m_num_steps = 0; unsigned m_moves = 0; + unsigned m_restarts = 0; }; public: @@ -93,10 +113,11 @@ namespace sls { var_sort m_sort; arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP; unsigned m_def_idx = UINT_MAX; - vector> m_ineqs; + vector> m_linear_occurs; unsigned_vector m_muls; unsigned_vector m_adds; optional m_lo, m_hi; + vector m_finite_domain; num_t const& value() const { return m_value; } void set_value(num_t const& v) { m_value = v; } @@ -187,6 +208,7 @@ namespace sls { unsigned get_num_vars() const { return m_vars.size(); } + void updt_params(); bool is_distinct(expr* e); bool eval_distinct(expr* e); void repair_distinct(expr* e); @@ -247,7 +269,7 @@ namespace sls { bool find_lin_moves(sat::literal lit); bool find_reset_moves(sat::literal lit); void add_reset_update(var_t v); - void find_linear_moves(ineq const& i, var_t x, num_t const& coeff, num_t const& sum); + void find_linear_moves(ineq const& i, var_t x, num_t const& coeff); void find_quadratic_moves(ineq const& i, var_t x, num_t const& a, num_t const& b, num_t const& sum); double compute_score(var_t x, num_t const& delta); void save_best_values(); @@ -273,6 +295,7 @@ namespace sls { void check_ineqs(); void init_bool_var(sat::bool_var bv); void initialize_unit(sat::literal lit); + void initialize_input_assertion(expr* f); void add_le(var_t v, num_t const& n); void add_ge(var_t v, num_t const& n); void add_lt(var_t v, num_t const& n); @@ -288,20 +311,25 @@ namespace sls { struct bool_info { unsigned weight = 0; double score = 0; - unsigned touched = 0; + unsigned touched = 1; lbool value = l_undef; + sat::bool_var_set fixable_atoms; + uint_set fixable_vars; + ptr_vector fixable_exprs; + bool_info(unsigned w) : weight(w) {} }; + vector> m_update_stack; expr_mark m_in_update_stack; svector m_bool_info; double m_best_score = 0, m_top_score = 0; unsigned m_min_depth = 0, m_max_depth = 0; num_t m_best_value; - expr* m_best_expr = nullptr, * m_last_atom = nullptr; + expr* m_best_expr = nullptr, * m_last_atom = nullptr, * m_last_expr = nullptr; expr_mark m_is_root; - sat::bool_var_set m_fixable_atoms; - uint_set m_fixable_vars; - ptr_vector m_fixable_exprs; + unsigned m_touched = 1; + sat::bool_var_set m_fixed_atoms; + bool_info& get_bool_info(expr* e); bool get_bool_value(expr* e); bool get_bool_value_rec(expr* e); @@ -313,29 +341,36 @@ namespace sls { double new_score(expr* e); double new_score(expr* e, bool is_true); void set_score(expr* e, double s) { get_bool_info(e).score = s; } - void rescore(); void recalibrate_weights(); void inc_weight(expr* e) { ++get_bool_info(e).weight; } void dec_weight(expr* e) { auto& i = get_bool_info(e); i.weight = i.weight > m_config.paws_init ? i.weight - 1 : m_config.paws_init; } unsigned get_weight(expr* e) { return get_bool_info(e).weight; } + unsigned get_touched(expr* e) { return get_bool_info(e).touched; } + void inc_touched(expr* e) { ++get_bool_info(e).touched; } + void set_touched(expr* e, unsigned t) { get_bool_info(e).touched = t; } void insert_update_stack(expr* t); void insert_update_stack_rec(expr* t); void clear_update_stack(); void lookahead_num(var_t v, num_t const& value); + bool can_update_num(var_t v, num_t const& delta); + bool update_num(var_t v, num_t const& delta); void lookahead_bool(expr* e); - double lookahead(expr* e); - void add_lookahead(expr* e); - void add_fixable(expr* e); - bool apply_move(expr* f, bool randomize); + double lookahead(expr* e, bool update_score); + void add_lookahead(bool_info& i, expr* e); + ptr_vector const& get_fixable_exprs(expr* e); + bool apply_move(expr* f, ptr_vector const& vars, arith_move_type t); expr* get_candidate_unsat(); void check_restart(); + void ucb_forget(); + void update_args_value(var_t v, num_t const& new_value); public: arith_base(context& ctx); ~arith_base() override {} void register_term(expr* e) override; bool set_value(expr* e, expr* v) override; expr_ref get_value(expr* e) override; + void start_propagation() override; bool is_fixed(expr* e, expr_ref& value) override; void initialize() override; void propagate_literal(sat::literal lit) override; diff --git a/src/ast/sls/sls_arith_plugin.cpp b/src/ast/sls/sls_arith_plugin.cpp index c8bbbfd51..da640db81 100644 --- a/src/ast/sls/sls_arith_plugin.cpp +++ b/src/ast/sls/sls_arith_plugin.cpp @@ -72,6 +72,10 @@ namespace sls { APPLY_BOTH(initialize()); } + void arith_plugin::start_propagation() { + WITH_FALLBACK(start_propagation()); + } + void arith_plugin::propagate_literal(sat::literal lit) { WITH_FALLBACK(propagate_literal(lit)); } diff --git a/src/ast/sls/sls_arith_plugin.h b/src/ast/sls/sls_arith_plugin.h index 7a0471110..15dca5b4e 100644 --- a/src/ast/sls/sls_arith_plugin.h +++ b/src/ast/sls/sls_arith_plugin.h @@ -32,6 +32,7 @@ namespace sls { ~arith_plugin() override {} void register_term(expr* e) override; expr_ref get_value(expr* e) override; + void start_propagation() override; bool is_fixed(expr* e, expr_ref& value) override; void initialize() override; void propagate_literal(sat::literal lit) override; diff --git a/src/ast/sls/sls_bv_lookahead.cpp b/src/ast/sls/sls_bv_lookahead.cpp index 630642f1d..2d47cfb33 100644 --- a/src/ast/sls/sls_bv_lookahead.cpp +++ b/src/ast/sls/sls_bv_lookahead.cpp @@ -304,9 +304,9 @@ namespace sls { void bv_lookahead::updt_params(params_ref const& _p) { sls_params p(_p); - if (m_config.updated) + if (m_config.config_initialized) return; - m_config.updated = true; + m_config.config_initialized = true; m_config.walksat = p.walksat(); m_config.walksat_repick = p.walksat_repick(); m_config.paws_sp = p.paws_sp(); diff --git a/src/ast/sls/sls_bv_lookahead.h b/src/ast/sls/sls_bv_lookahead.h index a4b6ebaa9..1d85ce6f7 100644 --- a/src/ast/sls/sls_bv_lookahead.h +++ b/src/ast/sls/sls_bv_lookahead.h @@ -26,7 +26,7 @@ namespace sls { class bv_lookahead { struct config { - bool updated = false; + bool config_initialized = false; double cb = 2.85; unsigned paws_init = 40; unsigned paws_sp = 52; @@ -181,11 +181,11 @@ namespace sls { void finalize_bool_values(); + void updt_params(params_ref const& p); + public: bv_lookahead(bv_eval& ev); - void updt_params(params_ref const& p); - void start_propagation(); void collect_statistics(statistics& st) const; diff --git a/src/params/sls_params.pyg b/src/params/sls_params.pyg index 5a87bd745..7ac16479e 100644 --- a/src/params/sls_params.pyg +++ b/src/params/sls_params.pyg @@ -25,6 +25,7 @@ def_module_params('sls', ('dt_axiomatic', BOOL, True, 'use axiomatic mode or model reduction for datatype solver'), ('track_unsat', BOOL, 0, 'keep a list of unsat assertions as done in SAT - currently disabled internally'), ('random_seed', UINT, 0, 'random seed'), + ('arith_use_lookahead', BOOL, False, 'use lookahead solver for NIRA'), ('bv_use_top_level_assertions', BOOL, True, 'use top-level assertions for BV lookahead solver'), ('bv_use_lookahead', BOOL, True, 'use lookahead solver for BV'), ('bv_allow_rotation', BOOL, True, 'allow model rotation when repairing literal assignment'),