mirror of
https://github.com/Z3Prover/z3
synced 2025-04-06 17:44:08 +00:00
add recursive updates to lookahead
This commit is contained in:
parent
57cb988461
commit
7fc59b65ad
|
@ -483,9 +483,113 @@ namespace sls {
|
|||
template<typename num_t>
|
||||
void arith_base<num_t>::add_update(var_t v, num_t delta) {
|
||||
num_t delta_out;
|
||||
auto const& vi = m_vars[v];
|
||||
if (!is_permitted_update(v, delta, delta_out))
|
||||
return;
|
||||
m_updates.push_back({ v, delta_out, 0 });
|
||||
if (vi.m_op == arith_op_kind::OP_NUM)
|
||||
return;
|
||||
if (is_add(v) && m_allow_recursive_delta)
|
||||
add_update_add(m_adds[vi.m_def_idx], delta_out);
|
||||
else if (is_mul(v) && m_allow_recursive_delta)
|
||||
add_update_mul(m_muls[vi.m_def_idx], delta_out);
|
||||
else if (is_op(v) && m_allow_recursive_delta)
|
||||
add_update(m_ops[vi.m_def_idx], delta_out);
|
||||
else if (vi.is_if_op() && m_allow_recursive_delta) {
|
||||
expr* c, * t, * e;
|
||||
VERIFY(m.is_ite(vi.m_expr, c, t, e));
|
||||
bool cond = ctx.is_true(c);
|
||||
if (cond)
|
||||
add_update(mk_term(t), delta_out);
|
||||
else
|
||||
add_update(mk_term(e), delta_out);
|
||||
}
|
||||
else {
|
||||
if (!is_uninterp(vi.m_expr) && m_allow_recursive_delta)
|
||||
verbose_stream() << mk_bounded_pp(vi.m_expr, m) << " += " << delta_out << "\n";
|
||||
m_updates.push_back({ v, delta_out, 0 });
|
||||
}
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::add_update(op_def const& od, num_t const& delta) {
|
||||
switch (od.m_op) {
|
||||
case arith_op_kind::OP_IDIV:
|
||||
case arith_op_kind::OP_IDIV0:
|
||||
add_update_idiv(od, delta);
|
||||
break;
|
||||
case arith_op_kind::OP_MOD:
|
||||
case arith_op_kind::OP_MOD0:
|
||||
add_update_mod(od, delta);
|
||||
break;
|
||||
case arith_op_kind::OP_NUM:
|
||||
break;
|
||||
case arith_op_kind::OP_DIV:
|
||||
case arith_op_kind::OP_DIV0:
|
||||
case arith_op_kind::OP_POWER:
|
||||
default:
|
||||
IF_VERBOSE(1, verbose_stream() << "add-update-op is TBD " << mk_bounded_pp(m_vars[od.m_var].m_expr, m) << " " << od.m_op << " " << delta << "\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::add_update_idiv(op_def const& od, num_t const& delta) {
|
||||
num_t arg1 = value(od.m_arg1);
|
||||
num_t arg2 = value(od.m_arg2);
|
||||
|
||||
if (arg2 != 0) {
|
||||
num_t val = div(arg1, arg2);
|
||||
if (arg2 > 0)
|
||||
add_update(od.m_arg1, delta * arg2);
|
||||
else if (arg2 < 0)
|
||||
add_update(od.m_arg1, -delta * arg2);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::add_update_mod(op_def const& od, num_t const& delta) {
|
||||
num_t val = value(od.m_var);
|
||||
num_t arg1 = value(od.m_arg1);
|
||||
num_t arg2 = value(od.m_arg2);
|
||||
if (arg1 + delta >= 0 && arg1 + delta < arg2)
|
||||
add_update(od.m_arg1, delta);
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::add_update_add(add_def const& ad, num_t const& delta) {
|
||||
for (auto const& [coeff, w] : ad.m_args)
|
||||
add_update(w, divide(w, delta, coeff));
|
||||
}
|
||||
|
||||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::add_update_mul(mul_def const& md, num_t const& delta) {
|
||||
auto const& [v, monomial] = md;
|
||||
auto val = value(v) + delta;
|
||||
|
||||
if (val == 0) {
|
||||
for (auto [x, p] : monomial)
|
||||
add_update(x, -value(x));
|
||||
}
|
||||
else if (val == 1 || val == -1) {
|
||||
for (auto [x, p] : monomial) {
|
||||
add_update(x, num_t(1) - value(x));
|
||||
add_update(x, num_t(-1) - value(x));
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (auto [x, p] : monomial) {
|
||||
auto mx = mul_value_without(v, x);
|
||||
// val / mx = x^p
|
||||
if (mx == 0)
|
||||
continue;
|
||||
auto valmx = divide(x, val, mx);
|
||||
auto r = root_of(p, valmx);
|
||||
add_update(x, r - value(x));
|
||||
if (p % 2 == 0)
|
||||
add_update(x, -r - value(x));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flip on the first positive score
|
||||
|
@ -978,10 +1082,18 @@ namespace sls {
|
|||
template<typename num_t>
|
||||
typename arith_base<num_t>::var_t arith_base<num_t>::mk_var(expr* e) {
|
||||
var_t v = m_expr2var.get(e->get_id(), UINT_MAX);
|
||||
if (v == UINT_MAX) {
|
||||
v = m_vars.size();
|
||||
m_expr2var.setx(e->get_id(), v, UINT_MAX);
|
||||
m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL));
|
||||
if (v != UINT_MAX)
|
||||
return v;
|
||||
v = m_vars.size();
|
||||
m_expr2var.setx(e->get_id(), v, UINT_MAX);
|
||||
m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL));
|
||||
expr* c = nullptr, * th = nullptr, * el = nullptr;
|
||||
if (m.is_ite(e, c, th, el)) {
|
||||
auto th_v = m_expr2var[th->get_id()];
|
||||
auto el_v = m_expr2var[el->get_id()];
|
||||
m_vars[th_v].m_ifs.push_back(v);
|
||||
m_vars[el_v].m_ifs.push_back(v);
|
||||
m_vars[v].m_def_idx = UINT_MAX - 1;
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
@ -1152,7 +1264,16 @@ namespace sls {
|
|||
template<typename num_t>
|
||||
num_t arith_base<num_t>::value1(var_t v) {
|
||||
auto const& vi = m_vars[v];
|
||||
if (vi.m_def_idx == UINT_MAX)
|
||||
|
||||
if (vi.is_if_op()) {
|
||||
expr* c = nullptr, * th = nullptr, *el = nullptr;
|
||||
VERIFY(m.is_ite(vi.m_expr, c, th, el));
|
||||
if (ctx.is_true(c))
|
||||
return value(mk_var(th));
|
||||
else
|
||||
return value(mk_var(el));
|
||||
}
|
||||
if (!vi.is_arith_op())
|
||||
return value(v);
|
||||
|
||||
num_t result, v1, v2;
|
||||
|
@ -1224,7 +1345,7 @@ namespace sls {
|
|||
if (v == UINT_MAX)
|
||||
return;
|
||||
auto const& vi = m_vars[v];
|
||||
if (vi.m_def_idx == UINT_MAX)
|
||||
if (!vi.is_arith_op())
|
||||
return;
|
||||
auto new_value = value1(v);
|
||||
if (!update(v, new_value))
|
||||
|
@ -1237,7 +1358,7 @@ namespace sls {
|
|||
if (v == UINT_MAX)
|
||||
return false;
|
||||
auto const& vi = m_vars[v];
|
||||
if (vi.m_def_idx == UINT_MAX)
|
||||
if (!vi.is_arith_op())
|
||||
return false;
|
||||
flet<bool> _tabu(m_use_tabu, false);
|
||||
TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
|
||||
|
@ -2326,7 +2447,7 @@ namespace sls {
|
|||
template<typename num_t>
|
||||
bool arith_base<num_t>::eval_is_correct(var_t v) {
|
||||
auto const& vi = m_vars[v];
|
||||
if (vi.m_def_idx == UINT_MAX)
|
||||
if (!vi.is_arith_op())
|
||||
return true;
|
||||
IF_VERBOSE(10, verbose_stream() << vi.m_op << " repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
|
||||
TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
|
||||
|
@ -2441,8 +2562,7 @@ namespace sls {
|
|||
for (auto const& [c, v] : i.m_args)
|
||||
val += c * value(v);
|
||||
if (val != i.m_args_value) {
|
||||
verbose_stream() << val << ": " << i << "\n";
|
||||
display(verbose_stream());
|
||||
IF_VERBOSE(0, verbose_stream() << val << ": " << i << "\n"; display(verbose_stream()));
|
||||
TRACE("arith", display(tout << val << ": " << i << "\n"));
|
||||
}
|
||||
SASSERT(val == i.m_args_value);
|
||||
|
@ -2516,6 +2636,12 @@ namespace sls {
|
|||
auto old_value = value(v);
|
||||
IF_VERBOSE(5, verbose_stream() << "update: v" << v << " " << mk_bounded_pp(vi.m_expr, m) << " := " << old_value << " -> " << new_value << "\n");
|
||||
vi.set_value(new_value);
|
||||
ctx.new_value_eh(vi.m_expr);
|
||||
|
||||
for (auto const& [coeff, bv] : vi.m_linear_occurs) {
|
||||
auto& ineq = *get_ineq(bv);
|
||||
ineq.m_args_value += coeff * (new_value - old_value);
|
||||
}
|
||||
|
||||
for (auto const& idx : vi.m_muls) {
|
||||
auto& [x, monomial] = m_muls[idx];
|
||||
|
@ -2536,10 +2662,10 @@ namespace sls {
|
|||
for (auto const& x : vi.m_ops)
|
||||
update_args_value(x, value1(x));
|
||||
|
||||
for (auto const& [coeff, bv] : vi.m_linear_occurs) {
|
||||
auto& ineq = *get_ineq(bv);
|
||||
ineq.m_args_value += coeff * (new_value - old_value);
|
||||
}
|
||||
for (auto const& x : vi.m_ifs)
|
||||
update_args_value(x, value1(x));
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -114,13 +114,16 @@ namespace sls {
|
|||
vector<std::pair<num_t, sat::bool_var>> m_linear_occurs;
|
||||
sat::bool_var_vector m_bool_vars_of;
|
||||
unsigned_vector m_clauses_of;
|
||||
unsigned_vector m_muls, m_adds, m_ops;
|
||||
unsigned_vector m_muls, m_adds, m_ops, m_ifs;
|
||||
optional<bound> m_lo, m_hi;
|
||||
vector<num_t> m_finite_domain;
|
||||
|
||||
num_t const& value() const { return m_value; }
|
||||
void set_value(num_t const& v) { m_value = v; }
|
||||
|
||||
bool is_arith_op() const { return m_def_idx < UINT_MAX - 1; }
|
||||
bool is_if_op() const { return m_def_idx == UINT_MAX - 1; }
|
||||
|
||||
num_t const& best_value() const { return m_best_value; }
|
||||
void set_best_value(num_t const& v) { m_best_value = v; }
|
||||
|
||||
|
@ -202,6 +205,7 @@ namespace sls {
|
|||
sat::literal m_last_literal = sat::null_literal;
|
||||
num_t m_last_delta { 0 };
|
||||
bool m_use_tabu = true;
|
||||
bool m_allow_recursive_delta = false;
|
||||
unsigned m_updates_max_size = 45;
|
||||
arith_util a;
|
||||
friend class arith_clausal<num_t>;
|
||||
|
@ -241,6 +245,11 @@ namespace sls {
|
|||
num_t mul_value_without(var_t m, var_t x);
|
||||
|
||||
void add_update(var_t v, num_t delta);
|
||||
void add_update(op_def const& od, num_t const& delta);
|
||||
void add_update_mod(op_def const& od, num_t const& delta);
|
||||
void add_update_add(add_def const& ad, num_t const& delta);
|
||||
void add_update_mul(mul_def const& md, num_t const& delta);
|
||||
void add_update_idiv(op_def const& od, num_t const& delta);
|
||||
bool is_permitted_update(var_t v, num_t const& delta, num_t& delta_out);
|
||||
|
||||
|
||||
|
@ -270,7 +279,8 @@ namespace sls {
|
|||
|
||||
bool is_mul(var_t v) const { return m_vars[v].m_op == arith_op_kind::OP_MUL; }
|
||||
bool is_add(var_t v) const { return m_vars[v].m_op == arith_op_kind::OP_ADD; }
|
||||
bool is_op(var_t v) const { return m_vars[v].m_op != arith_op_kind::LAST_ARITH_OP && m_vars[v].m_op != arith_op_kind::OP_MUL && m_vars[v].m_op != arith_op_kind::OP_ADD; }
|
||||
bool is_op(var_t v) const { return 0 <= m_vars[v].m_op && m_vars[v].m_op < arith_op_kind::LAST_ARITH_OP && m_vars[v].m_op != arith_op_kind::OP_MUL && m_vars[v].m_op != arith_op_kind::OP_ADD; }
|
||||
bool is_if(var_t v) const { return m.is_ite(m_vars[v].m_expr); }
|
||||
mul_def const& get_mul(var_t v) const { SASSERT(is_mul(v)); return m_muls[m_vars[v].m_def_idx]; }
|
||||
add_def const& get_add(var_t v) const { SASSERT(is_add(v)); return m_adds[m_vars[v].m_def_idx]; }
|
||||
|
||||
|
|
|
@ -137,12 +137,11 @@ namespace sls {
|
|||
if (!ineq)
|
||||
return;
|
||||
num_t na, nb;
|
||||
flet<bool> _allow_recursive_delta(a.m_allow_recursive_delta, true);
|
||||
for (auto const& [x, nl] : ineq->m_nonlinear) {
|
||||
if (a.is_fixed(x))
|
||||
continue;
|
||||
if (a.is_add(x) || a.is_mul(x) || a.is_op(x))
|
||||
;
|
||||
else if (a.is_linear(x, nl, nb))
|
||||
if (a.is_linear(x, nl, nb))
|
||||
a.find_linear_moves(*ineq, x, nb);
|
||||
else if (a.is_quadratic(x, nl, na, nb))
|
||||
a.find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value);
|
||||
|
@ -229,13 +228,13 @@ namespace sls {
|
|||
return v;
|
||||
}
|
||||
|
||||
|
||||
template<typename num_t>
|
||||
void arith_clausal<num_t>::lookahead(var_t v, num_t const& delta) {
|
||||
if (v == m_last_var && delta == m_last_delta)
|
||||
return;
|
||||
if (delta == 0)
|
||||
return;
|
||||
|
||||
m_last_var = v;
|
||||
m_last_delta = delta;
|
||||
if (!a.can_update_num(v, delta))
|
||||
|
|
Loading…
Reference in a new issue