diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 6dc1f2eeb..fafd336a0 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -483,9 +483,113 @@ namespace sls { template void arith_base::add_update(var_t v, num_t delta) { num_t delta_out; + auto const& vi = m_vars[v]; if (!is_permitted_update(v, delta, delta_out)) return; - m_updates.push_back({ v, delta_out, 0 }); + if (vi.m_op == arith_op_kind::OP_NUM) + return; + if (is_add(v) && m_allow_recursive_delta) + add_update_add(m_adds[vi.m_def_idx], delta_out); + else if (is_mul(v) && m_allow_recursive_delta) + add_update_mul(m_muls[vi.m_def_idx], delta_out); + else if (is_op(v) && m_allow_recursive_delta) + add_update(m_ops[vi.m_def_idx], delta_out); + else if (vi.is_if_op() && m_allow_recursive_delta) { + expr* c, * t, * e; + VERIFY(m.is_ite(vi.m_expr, c, t, e)); + bool cond = ctx.is_true(c); + if (cond) + add_update(mk_term(t), delta_out); + else + add_update(mk_term(e), delta_out); + } + else { + if (!is_uninterp(vi.m_expr) && m_allow_recursive_delta) + verbose_stream() << mk_bounded_pp(vi.m_expr, m) << " += " << delta_out << "\n"; + m_updates.push_back({ v, delta_out, 0 }); + } + } + + template + void arith_base::add_update(op_def const& od, num_t const& delta) { + switch (od.m_op) { + case arith_op_kind::OP_IDIV: + case arith_op_kind::OP_IDIV0: + add_update_idiv(od, delta); + break; + case arith_op_kind::OP_MOD: + case arith_op_kind::OP_MOD0: + add_update_mod(od, delta); + break; + case arith_op_kind::OP_NUM: + break; + case arith_op_kind::OP_DIV: + case arith_op_kind::OP_DIV0: + case arith_op_kind::OP_POWER: + default: + IF_VERBOSE(1, verbose_stream() << "add-update-op is TBD " << mk_bounded_pp(m_vars[od.m_var].m_expr, m) << " " << od.m_op << " " << delta << "\n"); + break; + } + } + + template + void arith_base::add_update_idiv(op_def const& od, num_t const& delta) { + num_t arg1 = value(od.m_arg1); + num_t arg2 = value(od.m_arg2); + + if (arg2 != 0) { + num_t val = div(arg1, arg2); + if (arg2 > 0) + add_update(od.m_arg1, delta * arg2); + else if (arg2 < 0) + add_update(od.m_arg1, -delta * arg2); + } + } + + template + void arith_base::add_update_mod(op_def const& od, num_t const& delta) { + num_t val = value(od.m_var); + num_t arg1 = value(od.m_arg1); + num_t arg2 = value(od.m_arg2); + if (arg1 + delta >= 0 && arg1 + delta < arg2) + add_update(od.m_arg1, delta); + } + + template + void arith_base::add_update_add(add_def const& ad, num_t const& delta) { + for (auto const& [coeff, w] : ad.m_args) + add_update(w, divide(w, delta, coeff)); + } + + + template + void arith_base::add_update_mul(mul_def const& md, num_t const& delta) { + auto const& [v, monomial] = md; + auto val = value(v) + delta; + + if (val == 0) { + for (auto [x, p] : monomial) + add_update(x, -value(x)); + } + else if (val == 1 || val == -1) { + for (auto [x, p] : monomial) { + add_update(x, num_t(1) - value(x)); + add_update(x, num_t(-1) - value(x)); + } + } + else { + for (auto [x, p] : monomial) { + auto mx = mul_value_without(v, x); + // val / mx = x^p + if (mx == 0) + continue; + auto valmx = divide(x, val, mx); + auto r = root_of(p, valmx); + add_update(x, r - value(x)); + if (p % 2 == 0) + add_update(x, -r - value(x)); + } + } } // flip on the first positive score @@ -978,10 +1082,18 @@ namespace sls { template typename arith_base::var_t arith_base::mk_var(expr* e) { var_t v = m_expr2var.get(e->get_id(), UINT_MAX); - if (v == UINT_MAX) { - v = m_vars.size(); - m_expr2var.setx(e->get_id(), v, UINT_MAX); - m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL)); + if (v != UINT_MAX) + return v; + v = m_vars.size(); + m_expr2var.setx(e->get_id(), v, UINT_MAX); + m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL)); + expr* c = nullptr, * th = nullptr, * el = nullptr; + if (m.is_ite(e, c, th, el)) { + auto th_v = m_expr2var[th->get_id()]; + auto el_v = m_expr2var[el->get_id()]; + m_vars[th_v].m_ifs.push_back(v); + m_vars[el_v].m_ifs.push_back(v); + m_vars[v].m_def_idx = UINT_MAX - 1; } return v; } @@ -1152,7 +1264,16 @@ namespace sls { template num_t arith_base::value1(var_t v) { auto const& vi = m_vars[v]; - if (vi.m_def_idx == UINT_MAX) + + if (vi.is_if_op()) { + expr* c = nullptr, * th = nullptr, *el = nullptr; + VERIFY(m.is_ite(vi.m_expr, c, th, el)); + if (ctx.is_true(c)) + return value(mk_var(th)); + else + return value(mk_var(el)); + } + if (!vi.is_arith_op()) return value(v); num_t result, v1, v2; @@ -1224,7 +1345,7 @@ namespace sls { if (v == UINT_MAX) return; auto const& vi = m_vars[v]; - if (vi.m_def_idx == UINT_MAX) + if (!vi.is_arith_op()) return; auto new_value = value1(v); if (!update(v, new_value)) @@ -1237,7 +1358,7 @@ namespace sls { if (v == UINT_MAX) return false; auto const& vi = m_vars[v]; - if (vi.m_def_idx == UINT_MAX) + if (!vi.is_arith_op()) return false; flet _tabu(m_use_tabu, false); TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); @@ -2326,7 +2447,7 @@ namespace sls { template bool arith_base::eval_is_correct(var_t v) { auto const& vi = m_vars[v]; - if (vi.m_def_idx == UINT_MAX) + if (!vi.is_arith_op()) return true; IF_VERBOSE(10, verbose_stream() << vi.m_op << " repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); @@ -2441,8 +2562,7 @@ namespace sls { for (auto const& [c, v] : i.m_args) val += c * value(v); if (val != i.m_args_value) { - verbose_stream() << val << ": " << i << "\n"; - display(verbose_stream()); + IF_VERBOSE(0, verbose_stream() << val << ": " << i << "\n"; display(verbose_stream())); TRACE("arith", display(tout << val << ": " << i << "\n")); } SASSERT(val == i.m_args_value); @@ -2516,6 +2636,12 @@ namespace sls { auto old_value = value(v); IF_VERBOSE(5, verbose_stream() << "update: v" << v << " " << mk_bounded_pp(vi.m_expr, m) << " := " << old_value << " -> " << new_value << "\n"); vi.set_value(new_value); + ctx.new_value_eh(vi.m_expr); + + for (auto const& [coeff, bv] : vi.m_linear_occurs) { + auto& ineq = *get_ineq(bv); + ineq.m_args_value += coeff * (new_value - old_value); + } for (auto const& idx : vi.m_muls) { auto& [x, monomial] = m_muls[idx]; @@ -2536,10 +2662,10 @@ namespace sls { for (auto const& x : vi.m_ops) update_args_value(x, value1(x)); - for (auto const& [coeff, bv] : vi.m_linear_occurs) { - auto& ineq = *get_ineq(bv); - ineq.m_args_value += coeff * (new_value - old_value); - } + for (auto const& x : vi.m_ifs) + update_args_value(x, value1(x)); + + } diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index e537019c3..86d45afd0 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -114,13 +114,16 @@ namespace sls { vector> m_linear_occurs; sat::bool_var_vector m_bool_vars_of; unsigned_vector m_clauses_of; - unsigned_vector m_muls, m_adds, m_ops; + unsigned_vector m_muls, m_adds, m_ops, m_ifs; optional m_lo, m_hi; vector m_finite_domain; num_t const& value() const { return m_value; } void set_value(num_t const& v) { m_value = v; } + bool is_arith_op() const { return m_def_idx < UINT_MAX - 1; } + bool is_if_op() const { return m_def_idx == UINT_MAX - 1; } + num_t const& best_value() const { return m_best_value; } void set_best_value(num_t const& v) { m_best_value = v; } @@ -202,6 +205,7 @@ namespace sls { sat::literal m_last_literal = sat::null_literal; num_t m_last_delta { 0 }; bool m_use_tabu = true; + bool m_allow_recursive_delta = false; unsigned m_updates_max_size = 45; arith_util a; friend class arith_clausal; @@ -241,6 +245,11 @@ namespace sls { num_t mul_value_without(var_t m, var_t x); void add_update(var_t v, num_t delta); + void add_update(op_def const& od, num_t const& delta); + void add_update_mod(op_def const& od, num_t const& delta); + void add_update_add(add_def const& ad, num_t const& delta); + void add_update_mul(mul_def const& md, num_t const& delta); + void add_update_idiv(op_def const& od, num_t const& delta); bool is_permitted_update(var_t v, num_t const& delta, num_t& delta_out); @@ -270,7 +279,8 @@ namespace sls { bool is_mul(var_t v) const { return m_vars[v].m_op == arith_op_kind::OP_MUL; } bool is_add(var_t v) const { return m_vars[v].m_op == arith_op_kind::OP_ADD; } - bool is_op(var_t v) const { return m_vars[v].m_op != arith_op_kind::LAST_ARITH_OP && m_vars[v].m_op != arith_op_kind::OP_MUL && m_vars[v].m_op != arith_op_kind::OP_ADD; } + bool is_op(var_t v) const { return 0 <= m_vars[v].m_op && m_vars[v].m_op < arith_op_kind::LAST_ARITH_OP && m_vars[v].m_op != arith_op_kind::OP_MUL && m_vars[v].m_op != arith_op_kind::OP_ADD; } + bool is_if(var_t v) const { return m.is_ite(m_vars[v].m_expr); } mul_def const& get_mul(var_t v) const { SASSERT(is_mul(v)); return m_muls[m_vars[v].m_def_idx]; } add_def const& get_add(var_t v) const { SASSERT(is_add(v)); return m_adds[m_vars[v].m_def_idx]; } diff --git a/src/ast/sls/sls_arith_clausal.cpp b/src/ast/sls/sls_arith_clausal.cpp index 4d506f2bb..1f6157787 100644 --- a/src/ast/sls/sls_arith_clausal.cpp +++ b/src/ast/sls/sls_arith_clausal.cpp @@ -137,12 +137,11 @@ namespace sls { if (!ineq) return; num_t na, nb; + flet _allow_recursive_delta(a.m_allow_recursive_delta, true); for (auto const& [x, nl] : ineq->m_nonlinear) { if (a.is_fixed(x)) continue; - if (a.is_add(x) || a.is_mul(x) || a.is_op(x)) - ; - else if (a.is_linear(x, nl, nb)) + if (a.is_linear(x, nl, nb)) a.find_linear_moves(*ineq, x, nb); else if (a.is_quadratic(x, nl, na, nb)) a.find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value); @@ -229,13 +228,13 @@ namespace sls { return v; } - template void arith_clausal::lookahead(var_t v, num_t const& delta) { if (v == m_last_var && delta == m_last_delta) return; if (delta == 0) return; + m_last_var = v; m_last_delta = delta; if (!a.can_update_num(v, delta))