From cd92b386975eae59ec93a0e6767658beac3a8b1e Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 26 Aug 2024 09:21:38 -0700 Subject: [PATCH] avoid negative reward Signed-off-by: Nikolaj Bjorner --- src/ast/sls/sat_ddfw.h | 2 +- src/ast/sls/sls_arith_base.cpp | 136 +++++++++++++++++++++------------ src/ast/sls/sls_arith_base.h | 2 + 3 files changed, 92 insertions(+), 48 deletions(-) diff --git a/src/ast/sls/sat_ddfw.h b/src/ast/sls/sat_ddfw.h index 206289708..81b5cae20 100644 --- a/src/ast/sls/sat_ddfw.h +++ b/src/ast/sls/sat_ddfw.h @@ -155,7 +155,7 @@ namespace sat { inline void inc_reward(literal lit, double w) { reward(lit.var()) += w; } - inline void dec_reward(literal lit, double w) { reward(lit.var()) -= w; } + inline void dec_reward(literal lit, double w) { if (reward(lit.var()) >= w) reward(lit.var()) -= w; } void check_with_plugin(); void check_without_plugin(); diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 4dd6b62e6..56b1af73e 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -1159,6 +1159,61 @@ namespace sls { return false; } + + template + num_t arith_base::value1(var_t v) { + auto const& vi = m_vars[v]; + if (vi.m_def_idx == UINT_MAX) + return value(v); + + num_t result, v1, v2; + switch (vi.m_op) { + case LAST_ARITH_OP: + break; + case OP_ADD: { + auto const& ad = m_adds[vi.m_def_idx]; + auto const& args = ad.m_args; + result = ad.m_coeff; + for (auto [c, w] : args) + result += c * value(w); + break; + } + case OP_MUL: { + auto const& [w, monomial] = m_muls[vi.m_def_idx]; + result = num_t(1); + for (auto [w, p] : monomial) + result *= power_of(value(w), p); + break; + } + case OP_MOD: + v1 = value(m_ops[vi.m_def_idx].m_arg1); + v2 = value(m_ops[vi.m_def_idx].m_arg2); + result = v2 == 0 ? num_t(0) : mod(v1, v2); + break; + case OP_DIV: + v1 = value(m_ops[vi.m_def_idx].m_arg1); + v2 = value(m_ops[vi.m_def_idx].m_arg2); + result = v2 == 0 ? num_t(0) : v1 / v2; + break; + case OP_IDIV: + v1 = value(m_ops[vi.m_def_idx].m_arg1); + v2 = value(m_ops[vi.m_def_idx].m_arg2); + result = v2 == 0 ? num_t(0) : div(v1, v2); + break; + case OP_REM: + v1 = value(m_ops[vi.m_def_idx].m_arg1); + v2 = value(m_ops[vi.m_def_idx].m_arg2); + result = v2 == 0 ? num_t(0) : v1 %= v2; + break; + case OP_ABS: + result = abs(value(m_ops[vi.m_def_idx].m_arg1)); + break; + default: + NOT_IMPLEMENTED_YET(); + } + return result; + } + template void arith_base::repair_up(app* e) { if (m.is_bool(e)) { @@ -1174,53 +1229,7 @@ namespace sls { auto const& vi = m_vars[v]; if (vi.m_def_idx == UINT_MAX) return; - - num_t new_value, v1, v2; - switch (vi.m_op) { - case LAST_ARITH_OP: - break; - case OP_ADD: { - auto const& ad = m_adds[vi.m_def_idx]; - auto const& args = ad.m_args; - new_value = ad.m_coeff; - for (auto [c, w] : args) - new_value += c * value(w); - break; - } - case OP_MUL: { - auto const& [w, monomial] = m_muls[vi.m_def_idx]; - new_value = num_t(1); - for (auto [w, p] : monomial) - new_value *= power_of(value(w), p); - break; - } - case OP_MOD: - v1 = value(m_ops[vi.m_def_idx].m_arg1); - v2 = value(m_ops[vi.m_def_idx].m_arg2); - new_value = v2 == 0 ? num_t(0) : mod(v1, v2); - break; - case OP_DIV: - v1 = value(m_ops[vi.m_def_idx].m_arg1); - v2 = value(m_ops[vi.m_def_idx].m_arg2); - new_value = v2 == 0 ? num_t(0) : v1 / v2; - break; - case OP_IDIV: - v1 = value(m_ops[vi.m_def_idx].m_arg1); - v2 = value(m_ops[vi.m_def_idx].m_arg2); - new_value = v2 == 0 ? num_t(0) : div(v1, v2); - break; - case OP_REM: - v1 = value(m_ops[vi.m_def_idx].m_arg1); - v2 = value(m_ops[vi.m_def_idx].m_arg2); - new_value = v2 == 0 ? num_t(0) : v1 %= v2; - break; - case OP_ABS: - new_value = abs(value(m_ops[vi.m_def_idx].m_arg1)); - break; - default: - NOT_IMPLEMENTED_YET(); - } - + auto new_value = value1(v); if (!update(v, new_value)) ctx.new_value_eh(e); } @@ -1921,6 +1930,39 @@ namespace sls { template void arith_base::on_restart() { +#if 0 + for (var_t v = 0; v < m_vars.size(); ++v) { + auto& vi = m_vars[v]; + num_t new_value; + if (vi.m_def_idx == UINT_MAX) { + auto val = value(v); + + if (ctx.rand(10) != 0) { + new_value = num_t((int)ctx.rand(2)); + if (!in_bounds(v, new_value)) + new_value = val; + } + else + new_value = val; + //verbose_stream() << v << " " << vi.m_value << " -> " << new_value << "\n"; + vi.m_value = new_value; + } + else { + vi.m_value = value1(v); + } + ctx.new_value_eh(vi.m_expr); + } + + for (sat::bool_var v = 0; v < ctx.num_bool_vars(); ++v) { + auto* ineq = atom(v); + if (!ineq) + continue; + ineq->m_args_value = ineq->m_coeff; + for (auto const& [coeff, w] : ineq->m_args) + ineq->m_args_value += coeff * value(w); + init_bool_var(v); + } +#endif } template diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index 7fed38786..1b689435a 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -191,6 +191,8 @@ namespace sls { bool check_update(var_t v, num_t new_value); void apply_checked_update(); + num_t value1(var_t v); + vector m_factors; vector const& factor(num_t n); num_t root_of(unsigned n, num_t a);