diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index dd6b3571d..79f60d763 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -423,9 +423,11 @@ namespace sls { } template - bool arith_base::is_permitted_update(var_t v, num_t& delta) { + bool arith_base::is_permitted_update(var_t v, num_t const& delta, num_t & delta_out) { auto& vi = m_vars[v]; + delta_out = delta; + if (m_last_var == v && m_last_delta == -delta) return false; @@ -441,38 +443,39 @@ namespace sls { auto const& lo = m_vars[v].m_lo; auto const& hi = m_vars[v].m_hi; if (lo && (lo->is_strict ? lo->value >= new_value : lo->value > new_value)) { - if (lo->is_strict && delta < 0 && lo->value <= old_value) { + if (lo->is_strict && delta_out < 0 && lo->value <= old_value) { num_t eps(1); if (hi && hi->value - lo->value <= eps) eps = (hi->value - lo->value) / num_t(2); - delta = lo->value - old_value + eps; + delta_out = lo->value - old_value + eps; } - else if (!lo->is_strict && delta < 0 && lo->value < old_value) - delta = lo->value - old_value; + else if (!lo->is_strict && delta_out < 0 && lo->value < old_value) + delta_out = lo->value - old_value; else return false; } if (hi && (hi->is_strict ? hi->value <= new_value : hi->value < new_value)) { - if (hi->is_strict && delta >= 0 && hi->value >= old_value) { + if (hi->is_strict && delta_out >= 0 && hi->value >= old_value) { num_t eps(1); if (lo && hi->value - lo->value <= eps) eps = (hi->value - lo->value) / num_t(2); - delta = hi->value - old_value - eps; + delta_out = hi->value - old_value - eps; } - else if (!hi->is_strict && delta > 0 && hi->value > old_value) - delta = hi->value - old_value; + else if (!hi->is_strict && delta_out > 0 && hi->value > old_value) + delta_out = hi->value - old_value; else return false; } } - return delta != 0; + return delta_out != 0; } template void arith_base::add_update(var_t v, num_t delta) { - if (!is_permitted_update(v, delta)) + num_t delta_out; + if (!is_permitted_update(v, delta, delta_out)) return; - m_updates.push_back({ v, delta, compute_score(v, delta) }); + m_updates.push_back({ v, delta_out, compute_score(v, delta_out) }); } // flip on the first positive score @@ -698,6 +701,7 @@ namespace sls { m_update_trail.push_back(v); vi.set_update_value(new_value, m_update_timestamp); + num_t delta; for (unsigned i = 0; i < m_update_trail.size(); ++i) { auto v = m_update_trail[i]; auto& vi = m_vars[v]; @@ -711,7 +715,7 @@ namespace sls { catch (overflow_exception const&) { return false; } - if (get_update_value(w) != prod && !is_permitted_update(w, prod - value(w))) + if (get_update_value(w) != prod && (!is_permitted_update(w, prod - value(w), delta) || prod - value(w) != delta)) return false; m_update_trail.push_back(w); m_vars[w].set_update_value(prod, m_update_timestamp); @@ -723,7 +727,7 @@ namespace sls { num_t sum(ad.m_coeff); for (auto const& [coeff, w] : ad.m_args) sum += coeff * get_update_value(w); - if (get_update_value(v) != sum && !is_permitted_update(w, sum - value(w))) + if (get_update_value(v) != sum && !(is_permitted_update(w, sum - value(w), delta) || sum - value(w) != delta)) return false; m_update_trail.push_back(w); m_vars[w].set_update_value(sum, m_update_timestamp); @@ -1202,6 +1206,86 @@ namespace sls { void arith_base::initialize() { for (auto lit : ctx.unit_literals()) initialize_unit(lit); + for (unsigned v = 0; v < m_vars.size(); ++v) { + auto const& vi = m_vars[v]; + if (vi.m_lo || vi.m_hi) + continue; + if (is_add(v)) { + auto const& ad = get_add(v); + num_t lo(ad.m_coeff), hi(ad.m_coeff); + bool lo_valid = true, hi_valid = true; + bool lo_strict = false, hi_strict = false; + for (auto const& [c, w] : ad.m_args) { + if (!lo_valid && !hi_valid) + break; + auto const& wi = m_vars[w]; + if (lo_valid) { + if (c > 0 && wi.m_lo) + lo += c * wi.m_lo->value, + lo_strict |= wi.m_lo->is_strict; + else if (c < 0 && wi.m_hi) + lo += c * wi.m_hi->value, + lo_strict |= wi.m_hi->is_strict; + else + lo_valid = false; + } + if (hi_valid) { + if (c > 0 && wi.m_hi) + hi += c * wi.m_hi->value, + hi_strict |= wi.m_hi->is_strict; + else if (c < 0 && wi.m_lo) + hi += c * wi.m_lo->value, + hi_strict |= wi.m_lo->is_strict; + else + hi_valid = false; + } + } + if (lo_valid) { + if (lo_strict) + add_gt(v, lo); + else + add_ge(v, lo); + } + if (hi_valid) { + if (hi_strict) + add_lt(v, hi); + else + add_le(v, hi); + } + } + if (is_mul(v)) { + auto const& [w, c, monomial] = get_mul(v); + num_t lo(c), hi(c); + bool lo_valid = true, hi_valid = true; + bool lo_strict = false, hi_strict = false; + for (auto [w, p] : monomial) { + if (!lo_valid && !hi_valid) + break; + auto const& wi = m_vars[w]; + if (lo_valid) { + // TODO + lo_valid = false; + } + if (hi_valid) { + // TODO + hi_valid = false; + } + } + if (lo_valid) { + if (lo_strict) + add_gt(v, lo); + else + add_ge(v, lo); + } + if (hi_valid) { + if (hi_strict) + add_lt(v, hi); + else + add_le(v, hi); + } + } + // TBD: can also do with other operators. + } } template diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index 67fddfd0a..5c0988a99 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -179,7 +179,7 @@ namespace sls { num_t mul_value_without(var_t m, var_t x); void add_update(var_t v, num_t delta); - bool is_permitted_update(var_t v, num_t& delta); + bool is_permitted_update(var_t v, num_t const& delta, num_t& delta_out); unsigned m_update_timestamp = 0; svector m_update_trail; bool check_update(var_t v, num_t new_value);