3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-06 17:44:08 +00:00

add recursive updates to lookahead

This commit is contained in:
Nikolaj Bjorner 2025-01-25 16:10:00 -08:00
parent 57cb988461
commit 7fc59b65ad
3 changed files with 156 additions and 21 deletions

View file

@ -483,9 +483,113 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::add_update(var_t v, num_t delta) {
num_t delta_out;
auto const& vi = m_vars[v];
if (!is_permitted_update(v, delta, delta_out))
return;
m_updates.push_back({ v, delta_out, 0 });
if (vi.m_op == arith_op_kind::OP_NUM)
return;
if (is_add(v) && m_allow_recursive_delta)
add_update_add(m_adds[vi.m_def_idx], delta_out);
else if (is_mul(v) && m_allow_recursive_delta)
add_update_mul(m_muls[vi.m_def_idx], delta_out);
else if (is_op(v) && m_allow_recursive_delta)
add_update(m_ops[vi.m_def_idx], delta_out);
else if (vi.is_if_op() && m_allow_recursive_delta) {
expr* c, * t, * e;
VERIFY(m.is_ite(vi.m_expr, c, t, e));
bool cond = ctx.is_true(c);
if (cond)
add_update(mk_term(t), delta_out);
else
add_update(mk_term(e), delta_out);
}
else {
if (!is_uninterp(vi.m_expr) && m_allow_recursive_delta)
verbose_stream() << mk_bounded_pp(vi.m_expr, m) << " += " << delta_out << "\n";
m_updates.push_back({ v, delta_out, 0 });
}
}
template<typename num_t>
void arith_base<num_t>::add_update(op_def const& od, num_t const& delta) {
switch (od.m_op) {
case arith_op_kind::OP_IDIV:
case arith_op_kind::OP_IDIV0:
add_update_idiv(od, delta);
break;
case arith_op_kind::OP_MOD:
case arith_op_kind::OP_MOD0:
add_update_mod(od, delta);
break;
case arith_op_kind::OP_NUM:
break;
case arith_op_kind::OP_DIV:
case arith_op_kind::OP_DIV0:
case arith_op_kind::OP_POWER:
default:
IF_VERBOSE(1, verbose_stream() << "add-update-op is TBD " << mk_bounded_pp(m_vars[od.m_var].m_expr, m) << " " << od.m_op << " " << delta << "\n");
break;
}
}
template<typename num_t>
void arith_base<num_t>::add_update_idiv(op_def const& od, num_t const& delta) {
num_t arg1 = value(od.m_arg1);
num_t arg2 = value(od.m_arg2);
if (arg2 != 0) {
num_t val = div(arg1, arg2);
if (arg2 > 0)
add_update(od.m_arg1, delta * arg2);
else if (arg2 < 0)
add_update(od.m_arg1, -delta * arg2);
}
}
template<typename num_t>
void arith_base<num_t>::add_update_mod(op_def const& od, num_t const& delta) {
num_t val = value(od.m_var);
num_t arg1 = value(od.m_arg1);
num_t arg2 = value(od.m_arg2);
if (arg1 + delta >= 0 && arg1 + delta < arg2)
add_update(od.m_arg1, delta);
}
template<typename num_t>
void arith_base<num_t>::add_update_add(add_def const& ad, num_t const& delta) {
for (auto const& [coeff, w] : ad.m_args)
add_update(w, divide(w, delta, coeff));
}
template<typename num_t>
void arith_base<num_t>::add_update_mul(mul_def const& md, num_t const& delta) {
auto const& [v, monomial] = md;
auto val = value(v) + delta;
if (val == 0) {
for (auto [x, p] : monomial)
add_update(x, -value(x));
}
else if (val == 1 || val == -1) {
for (auto [x, p] : monomial) {
add_update(x, num_t(1) - value(x));
add_update(x, num_t(-1) - value(x));
}
}
else {
for (auto [x, p] : monomial) {
auto mx = mul_value_without(v, x);
// val / mx = x^p
if (mx == 0)
continue;
auto valmx = divide(x, val, mx);
auto r = root_of(p, valmx);
add_update(x, r - value(x));
if (p % 2 == 0)
add_update(x, -r - value(x));
}
}
}
// flip on the first positive score
@ -978,10 +1082,18 @@ namespace sls {
template<typename num_t>
typename arith_base<num_t>::var_t arith_base<num_t>::mk_var(expr* e) {
var_t v = m_expr2var.get(e->get_id(), UINT_MAX);
if (v == UINT_MAX) {
v = m_vars.size();
m_expr2var.setx(e->get_id(), v, UINT_MAX);
m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL));
if (v != UINT_MAX)
return v;
v = m_vars.size();
m_expr2var.setx(e->get_id(), v, UINT_MAX);
m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL));
expr* c = nullptr, * th = nullptr, * el = nullptr;
if (m.is_ite(e, c, th, el)) {
auto th_v = m_expr2var[th->get_id()];
auto el_v = m_expr2var[el->get_id()];
m_vars[th_v].m_ifs.push_back(v);
m_vars[el_v].m_ifs.push_back(v);
m_vars[v].m_def_idx = UINT_MAX - 1;
}
return v;
}
@ -1152,7 +1264,16 @@ namespace sls {
template<typename num_t>
num_t arith_base<num_t>::value1(var_t v) {
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
if (vi.is_if_op()) {
expr* c = nullptr, * th = nullptr, *el = nullptr;
VERIFY(m.is_ite(vi.m_expr, c, th, el));
if (ctx.is_true(c))
return value(mk_var(th));
else
return value(mk_var(el));
}
if (!vi.is_arith_op())
return value(v);
num_t result, v1, v2;
@ -1224,7 +1345,7 @@ namespace sls {
if (v == UINT_MAX)
return;
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
if (!vi.is_arith_op())
return;
auto new_value = value1(v);
if (!update(v, new_value))
@ -1237,7 +1358,7 @@ namespace sls {
if (v == UINT_MAX)
return false;
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
if (!vi.is_arith_op())
return false;
flet<bool> _tabu(m_use_tabu, false);
TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
@ -2326,7 +2447,7 @@ namespace sls {
template<typename num_t>
bool arith_base<num_t>::eval_is_correct(var_t v) {
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
if (!vi.is_arith_op())
return true;
IF_VERBOSE(10, verbose_stream() << vi.m_op << " repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
@ -2441,8 +2562,7 @@ namespace sls {
for (auto const& [c, v] : i.m_args)
val += c * value(v);
if (val != i.m_args_value) {
verbose_stream() << val << ": " << i << "\n";
display(verbose_stream());
IF_VERBOSE(0, verbose_stream() << val << ": " << i << "\n"; display(verbose_stream()));
TRACE("arith", display(tout << val << ": " << i << "\n"));
}
SASSERT(val == i.m_args_value);
@ -2516,6 +2636,12 @@ namespace sls {
auto old_value = value(v);
IF_VERBOSE(5, verbose_stream() << "update: v" << v << " " << mk_bounded_pp(vi.m_expr, m) << " := " << old_value << " -> " << new_value << "\n");
vi.set_value(new_value);
ctx.new_value_eh(vi.m_expr);
for (auto const& [coeff, bv] : vi.m_linear_occurs) {
auto& ineq = *get_ineq(bv);
ineq.m_args_value += coeff * (new_value - old_value);
}
for (auto const& idx : vi.m_muls) {
auto& [x, monomial] = m_muls[idx];
@ -2536,10 +2662,10 @@ namespace sls {
for (auto const& x : vi.m_ops)
update_args_value(x, value1(x));
for (auto const& [coeff, bv] : vi.m_linear_occurs) {
auto& ineq = *get_ineq(bv);
ineq.m_args_value += coeff * (new_value - old_value);
}
for (auto const& x : vi.m_ifs)
update_args_value(x, value1(x));
}

View file

@ -114,13 +114,16 @@ namespace sls {
vector<std::pair<num_t, sat::bool_var>> m_linear_occurs;
sat::bool_var_vector m_bool_vars_of;
unsigned_vector m_clauses_of;
unsigned_vector m_muls, m_adds, m_ops;
unsigned_vector m_muls, m_adds, m_ops, m_ifs;
optional<bound> m_lo, m_hi;
vector<num_t> m_finite_domain;
num_t const& value() const { return m_value; }
void set_value(num_t const& v) { m_value = v; }
bool is_arith_op() const { return m_def_idx < UINT_MAX - 1; }
bool is_if_op() const { return m_def_idx == UINT_MAX - 1; }
num_t const& best_value() const { return m_best_value; }
void set_best_value(num_t const& v) { m_best_value = v; }
@ -202,6 +205,7 @@ namespace sls {
sat::literal m_last_literal = sat::null_literal;
num_t m_last_delta { 0 };
bool m_use_tabu = true;
bool m_allow_recursive_delta = false;
unsigned m_updates_max_size = 45;
arith_util a;
friend class arith_clausal<num_t>;
@ -241,6 +245,11 @@ namespace sls {
num_t mul_value_without(var_t m, var_t x);
void add_update(var_t v, num_t delta);
void add_update(op_def const& od, num_t const& delta);
void add_update_mod(op_def const& od, num_t const& delta);
void add_update_add(add_def const& ad, num_t const& delta);
void add_update_mul(mul_def const& md, num_t const& delta);
void add_update_idiv(op_def const& od, num_t const& delta);
bool is_permitted_update(var_t v, num_t const& delta, num_t& delta_out);
@ -270,7 +279,8 @@ namespace sls {
bool is_mul(var_t v) const { return m_vars[v].m_op == arith_op_kind::OP_MUL; }
bool is_add(var_t v) const { return m_vars[v].m_op == arith_op_kind::OP_ADD; }
bool is_op(var_t v) const { return m_vars[v].m_op != arith_op_kind::LAST_ARITH_OP && m_vars[v].m_op != arith_op_kind::OP_MUL && m_vars[v].m_op != arith_op_kind::OP_ADD; }
bool is_op(var_t v) const { return 0 <= m_vars[v].m_op && m_vars[v].m_op < arith_op_kind::LAST_ARITH_OP && m_vars[v].m_op != arith_op_kind::OP_MUL && m_vars[v].m_op != arith_op_kind::OP_ADD; }
bool is_if(var_t v) const { return m.is_ite(m_vars[v].m_expr); }
mul_def const& get_mul(var_t v) const { SASSERT(is_mul(v)); return m_muls[m_vars[v].m_def_idx]; }
add_def const& get_add(var_t v) const { SASSERT(is_add(v)); return m_adds[m_vars[v].m_def_idx]; }

View file

@ -137,12 +137,11 @@ namespace sls {
if (!ineq)
return;
num_t na, nb;
flet<bool> _allow_recursive_delta(a.m_allow_recursive_delta, true);
for (auto const& [x, nl] : ineq->m_nonlinear) {
if (a.is_fixed(x))
continue;
if (a.is_add(x) || a.is_mul(x) || a.is_op(x))
;
else if (a.is_linear(x, nl, nb))
if (a.is_linear(x, nl, nb))
a.find_linear_moves(*ineq, x, nb);
else if (a.is_quadratic(x, nl, na, nb))
a.find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value);
@ -229,13 +228,13 @@ namespace sls {
return v;
}
template<typename num_t>
void arith_clausal<num_t>::lookahead(var_t v, num_t const& delta) {
if (v == m_last_var && delta == m_last_delta)
return;
if (delta == 0)
return;
m_last_var = v;
m_last_delta = delta;
if (!a.can_update_num(v, delta))