3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 09:05:31 +00:00

avoid negative reward

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2024-08-26 09:21:38 -07:00
parent ace3472a96
commit cd92b38697
3 changed files with 92 additions and 48 deletions

View file

@ -155,7 +155,7 @@ namespace sat {
inline void inc_reward(literal lit, double w) { reward(lit.var()) += w; }
inline void dec_reward(literal lit, double w) { reward(lit.var()) -= w; }
inline void dec_reward(literal lit, double w) { if (reward(lit.var()) >= w) reward(lit.var()) -= w; }
void check_with_plugin();
void check_without_plugin();

View file

@ -1159,6 +1159,61 @@ namespace sls {
return false;
}
template<typename num_t>
num_t arith_base<num_t>::value1(var_t v) {
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
return value(v);
num_t result, v1, v2;
switch (vi.m_op) {
case LAST_ARITH_OP:
break;
case OP_ADD: {
auto const& ad = m_adds[vi.m_def_idx];
auto const& args = ad.m_args;
result = ad.m_coeff;
for (auto [c, w] : args)
result += c * value(w);
break;
}
case OP_MUL: {
auto const& [w, monomial] = m_muls[vi.m_def_idx];
result = num_t(1);
for (auto [w, p] : monomial)
result *= power_of(value(w), p);
break;
}
case OP_MOD:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
result = v2 == 0 ? num_t(0) : mod(v1, v2);
break;
case OP_DIV:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
result = v2 == 0 ? num_t(0) : v1 / v2;
break;
case OP_IDIV:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
result = v2 == 0 ? num_t(0) : div(v1, v2);
break;
case OP_REM:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
result = v2 == 0 ? num_t(0) : v1 %= v2;
break;
case OP_ABS:
result = abs(value(m_ops[vi.m_def_idx].m_arg1));
break;
default:
NOT_IMPLEMENTED_YET();
}
return result;
}
template<typename num_t>
void arith_base<num_t>::repair_up(app* e) {
if (m.is_bool(e)) {
@ -1174,53 +1229,7 @@ namespace sls {
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
return;
num_t new_value, v1, v2;
switch (vi.m_op) {
case LAST_ARITH_OP:
break;
case OP_ADD: {
auto const& ad = m_adds[vi.m_def_idx];
auto const& args = ad.m_args;
new_value = ad.m_coeff;
for (auto [c, w] : args)
new_value += c * value(w);
break;
}
case OP_MUL: {
auto const& [w, monomial] = m_muls[vi.m_def_idx];
new_value = num_t(1);
for (auto [w, p] : monomial)
new_value *= power_of(value(w), p);
break;
}
case OP_MOD:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
new_value = v2 == 0 ? num_t(0) : mod(v1, v2);
break;
case OP_DIV:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
new_value = v2 == 0 ? num_t(0) : v1 / v2;
break;
case OP_IDIV:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
new_value = v2 == 0 ? num_t(0) : div(v1, v2);
break;
case OP_REM:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
new_value = v2 == 0 ? num_t(0) : v1 %= v2;
break;
case OP_ABS:
new_value = abs(value(m_ops[vi.m_def_idx].m_arg1));
break;
default:
NOT_IMPLEMENTED_YET();
}
auto new_value = value1(v);
if (!update(v, new_value))
ctx.new_value_eh(e);
}
@ -1921,6 +1930,39 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::on_restart() {
#if 0
for (var_t v = 0; v < m_vars.size(); ++v) {
auto& vi = m_vars[v];
num_t new_value;
if (vi.m_def_idx == UINT_MAX) {
auto val = value(v);
if (ctx.rand(10) != 0) {
new_value = num_t((int)ctx.rand(2));
if (!in_bounds(v, new_value))
new_value = val;
}
else
new_value = val;
//verbose_stream() << v << " " << vi.m_value << " -> " << new_value << "\n";
vi.m_value = new_value;
}
else {
vi.m_value = value1(v);
}
ctx.new_value_eh(vi.m_expr);
}
for (sat::bool_var v = 0; v < ctx.num_bool_vars(); ++v) {
auto* ineq = atom(v);
if (!ineq)
continue;
ineq->m_args_value = ineq->m_coeff;
for (auto const& [coeff, w] : ineq->m_args)
ineq->m_args_value += coeff * value(w);
init_bool_var(v);
}
#endif
}
template<typename num_t>

View file

@ -191,6 +191,8 @@ namespace sls {
bool check_update(var_t v, num_t new_value);
void apply_checked_update();
num_t value1(var_t v);
vector<num_t> m_factors;
vector<num_t> const& factor(num_t n);
num_t root_of(unsigned n, num_t a);