3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-08-24 03:57:51 +00:00

add recursive updates to lookahead

This commit is contained in:
Nikolaj Bjorner 2025-01-25 16:10:00 -08:00
parent 57cb988461
commit 7fc59b65ad
3 changed files with 156 additions and 21 deletions

View file

@ -483,9 +483,113 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::add_update(var_t v, num_t delta) {
num_t delta_out;
auto const& vi = m_vars[v];
if (!is_permitted_update(v, delta, delta_out))
return;
m_updates.push_back({ v, delta_out, 0 });
if (vi.m_op == arith_op_kind::OP_NUM)
return;
if (is_add(v) && m_allow_recursive_delta)
add_update_add(m_adds[vi.m_def_idx], delta_out);
else if (is_mul(v) && m_allow_recursive_delta)
add_update_mul(m_muls[vi.m_def_idx], delta_out);
else if (is_op(v) && m_allow_recursive_delta)
add_update(m_ops[vi.m_def_idx], delta_out);
else if (vi.is_if_op() && m_allow_recursive_delta) {
expr* c, * t, * e;
VERIFY(m.is_ite(vi.m_expr, c, t, e));
bool cond = ctx.is_true(c);
if (cond)
add_update(mk_term(t), delta_out);
else
add_update(mk_term(e), delta_out);
}
else {
if (!is_uninterp(vi.m_expr) && m_allow_recursive_delta)
verbose_stream() << mk_bounded_pp(vi.m_expr, m) << " += " << delta_out << "\n";
m_updates.push_back({ v, delta_out, 0 });
}
}
template<typename num_t>
void arith_base<num_t>::add_update(op_def const& od, num_t const& delta) {
switch (od.m_op) {
case arith_op_kind::OP_IDIV:
case arith_op_kind::OP_IDIV0:
add_update_idiv(od, delta);
break;
case arith_op_kind::OP_MOD:
case arith_op_kind::OP_MOD0:
add_update_mod(od, delta);
break;
case arith_op_kind::OP_NUM:
break;
case arith_op_kind::OP_DIV:
case arith_op_kind::OP_DIV0:
case arith_op_kind::OP_POWER:
default:
IF_VERBOSE(1, verbose_stream() << "add-update-op is TBD " << mk_bounded_pp(m_vars[od.m_var].m_expr, m) << " " << od.m_op << " " << delta << "\n");
break;
}
}
template<typename num_t>
void arith_base<num_t>::add_update_idiv(op_def const& od, num_t const& delta) {
num_t arg1 = value(od.m_arg1);
num_t arg2 = value(od.m_arg2);
if (arg2 != 0) {
num_t val = div(arg1, arg2);
if (arg2 > 0)
add_update(od.m_arg1, delta * arg2);
else if (arg2 < 0)
add_update(od.m_arg1, -delta * arg2);
}
}
template<typename num_t>
void arith_base<num_t>::add_update_mod(op_def const& od, num_t const& delta) {
num_t val = value(od.m_var);
num_t arg1 = value(od.m_arg1);
num_t arg2 = value(od.m_arg2);
if (arg1 + delta >= 0 && arg1 + delta < arg2)
add_update(od.m_arg1, delta);
}
template<typename num_t>
void arith_base<num_t>::add_update_add(add_def const& ad, num_t const& delta) {
for (auto const& [coeff, w] : ad.m_args)
add_update(w, divide(w, delta, coeff));
}
template<typename num_t>
void arith_base<num_t>::add_update_mul(mul_def const& md, num_t const& delta) {
auto const& [v, monomial] = md;
auto val = value(v) + delta;
if (val == 0) {
for (auto [x, p] : monomial)
add_update(x, -value(x));
}
else if (val == 1 || val == -1) {
for (auto [x, p] : monomial) {
add_update(x, num_t(1) - value(x));
add_update(x, num_t(-1) - value(x));
}
}
else {
for (auto [x, p] : monomial) {
auto mx = mul_value_without(v, x);
// val / mx = x^p
if (mx == 0)
continue;
auto valmx = divide(x, val, mx);
auto r = root_of(p, valmx);
add_update(x, r - value(x));
if (p % 2 == 0)
add_update(x, -r - value(x));
}
}
}
// flip on the first positive score
@ -978,10 +1082,18 @@ namespace sls {
template<typename num_t>
typename arith_base<num_t>::var_t arith_base<num_t>::mk_var(expr* e) {
var_t v = m_expr2var.get(e->get_id(), UINT_MAX);
if (v == UINT_MAX) {
v = m_vars.size();
m_expr2var.setx(e->get_id(), v, UINT_MAX);
m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL));
if (v != UINT_MAX)
return v;
v = m_vars.size();
m_expr2var.setx(e->get_id(), v, UINT_MAX);
m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL));
expr* c = nullptr, * th = nullptr, * el = nullptr;
if (m.is_ite(e, c, th, el)) {
auto th_v = m_expr2var[th->get_id()];
auto el_v = m_expr2var[el->get_id()];
m_vars[th_v].m_ifs.push_back(v);
m_vars[el_v].m_ifs.push_back(v);
m_vars[v].m_def_idx = UINT_MAX - 1;
}
return v;
}
@ -1152,7 +1264,16 @@ namespace sls {
template<typename num_t>
num_t arith_base<num_t>::value1(var_t v) {
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
if (vi.is_if_op()) {
expr* c = nullptr, * th = nullptr, *el = nullptr;
VERIFY(m.is_ite(vi.m_expr, c, th, el));
if (ctx.is_true(c))
return value(mk_var(th));
else
return value(mk_var(el));
}
if (!vi.is_arith_op())
return value(v);
num_t result, v1, v2;
@ -1224,7 +1345,7 @@ namespace sls {
if (v == UINT_MAX)
return;
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
if (!vi.is_arith_op())
return;
auto new_value = value1(v);
if (!update(v, new_value))
@ -1237,7 +1358,7 @@ namespace sls {
if (v == UINT_MAX)
return false;
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
if (!vi.is_arith_op())
return false;
flet<bool> _tabu(m_use_tabu, false);
TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
@ -2326,7 +2447,7 @@ namespace sls {
template<typename num_t>
bool arith_base<num_t>::eval_is_correct(var_t v) {
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
if (!vi.is_arith_op())
return true;
IF_VERBOSE(10, verbose_stream() << vi.m_op << " repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
@ -2441,8 +2562,7 @@ namespace sls {
for (auto const& [c, v] : i.m_args)
val += c * value(v);
if (val != i.m_args_value) {
verbose_stream() << val << ": " << i << "\n";
display(verbose_stream());
IF_VERBOSE(0, verbose_stream() << val << ": " << i << "\n"; display(verbose_stream()));
TRACE("arith", display(tout << val << ": " << i << "\n"));
}
SASSERT(val == i.m_args_value);
@ -2516,6 +2636,12 @@ namespace sls {
auto old_value = value(v);
IF_VERBOSE(5, verbose_stream() << "update: v" << v << " " << mk_bounded_pp(vi.m_expr, m) << " := " << old_value << " -> " << new_value << "\n");
vi.set_value(new_value);
ctx.new_value_eh(vi.m_expr);
for (auto const& [coeff, bv] : vi.m_linear_occurs) {
auto& ineq = *get_ineq(bv);
ineq.m_args_value += coeff * (new_value - old_value);
}
for (auto const& idx : vi.m_muls) {
auto& [x, monomial] = m_muls[idx];
@ -2536,10 +2662,10 @@ namespace sls {
for (auto const& x : vi.m_ops)
update_args_value(x, value1(x));
for (auto const& [coeff, bv] : vi.m_linear_occurs) {
auto& ineq = *get_ineq(bv);
ineq.m_args_value += coeff * (new_value - old_value);
}
for (auto const& x : vi.m_ifs)
update_args_value(x, value1(x));
}