From afef727b8880e979ea2e0fb2f19500c196351879 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 13 Aug 2024 14:50:17 -0700 Subject: [PATCH] bug fixes --- src/ast/sls/sls_arith_base.cpp | 157 +++++++++++++++++++++++++-------- src/ast/sls/sls_arith_base.h | 26 +++++- src/ast/sls/sls_context.cpp | 10 ++- src/ast/sls/sls_smt_solver.cpp | 6 ++ 4 files changed, 154 insertions(+), 45 deletions(-) diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index b5bf7c72d..b7e0c2f2e 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -171,8 +171,10 @@ namespace sls { num_t arith_base::divide_floor(var_t v, num_t const& a, num_t const& b) { if (!is_int(v)) return a / b; - if (b > 0) + if (b > 0 && a >= 0) return div(a, b); + else if (b > 0) + return -div(-a + b - 1, b); else if (a > 0) return -div(a - b - 1, -b); else @@ -183,8 +185,10 @@ namespace sls { num_t arith_base::divide_ceil(var_t v, num_t const& a, num_t const& b) { if (!is_int(v)) return a / b; - if (b > 0) + if (b > 0 && a >= 0) return div(a + b - 1, b); + else if (b > 0) + return -div(-a, b); else if (a > 0) return -div(a, -b); else @@ -256,19 +260,6 @@ namespace sls { lh += eps; if (a * rl * rl + b * rl + c <= 0) rl -= eps; - if (is_square && a * lh * lh + b * lh + c <= 0) { - num_t ll = divide_floor(x, -b - root, 2 * a); - num_t lh = divide_ceil(x, -b - root, 2 * a); - num_t rl = divide_floor(x, -b + root, 2 * a); - num_t rh = divide_ceil(x, -b + root, 2 * a); - - verbose_stream() << a << " " << b << " " << c << "\n"; - verbose_stream() << (-b - root) << " " << (2 * a) << " " << ll << " " << lh << "\n"; - verbose_stream() << (-b + root) << " " << (2 * a) << " " << rl << " " << rh << "\n"; - verbose_stream() << "root " << root << "\n"; - UNREACHABLE(); - } - SASSERT(!is_square || a * lh * lh + b * lh + c > 0); SASSERT(!is_square || a * rl * rl + b * rl + c > 0); add_update(x, lh - value(x)); @@ -420,9 +411,7 @@ namespace sls { num_t delta = sum; SASSERT(sum != 0); delta = sum < 0 ? divide(v, abs(sum), coeff) : -divide(v, sum, coeff); - if (sum + coeff * delta != 0) - solve_eq_pairs(v, ineq); - else + if (sum + coeff * delta == 0) add_update(v, delta); break; } @@ -440,7 +429,7 @@ namespace sls { if (m_last_var == v && m_last_delta == -delta) return false; - if (m_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) + if (false && m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) return false; auto old_value = value(v); @@ -448,7 +437,7 @@ namespace sls { if (!vi.in_range(new_value)) return false; - if (!in_bounds(v, new_value) && in_bounds(v, old_value)) { + if (m_use_tabu && !in_bounds(v, new_value) && in_bounds(v, old_value)) { auto const& lo = m_vars[v].m_lo; auto const& hi = m_vars[v].m_hi; if (lo && (lo->is_strict ? lo->value >= new_value : lo->value > new_value)) { @@ -492,7 +481,7 @@ namespace sls { if (is_fixed(v)) return false; auto argsv = ineq.m_args_value; - num_t a; + num_t a(0); for (auto const& [c, w] : ineq.m_args) if (v == w) { a = c; @@ -501,6 +490,7 @@ namespace sls { if (abs(a) == 1) return false; IF_VERBOSE(3, verbose_stream() << "solve_eq_pairs " << ineq << " for v" << v << "\n"); + SASSERT(a != 0); unsigned start = ctx.rand(); for (unsigned i = 0; i < ineq.m_args.size(); ++i) { unsigned j = (start + i) % ineq.m_args.size(); @@ -525,6 +515,7 @@ namespace sls { if (is_fixed(y)) return false; num_t x0, y0; + std::cout << "solve_eq_pairs " << _a << " v" << x << " " << _b << " v" << y << " " << r << "\n"; num_t a = _a, b = _b; num_t g = gcd(a, b, x0, y0); SASSERT(g >= 1); @@ -752,6 +743,14 @@ namespace sls { IF_VERBOSE(10, display(verbose_stream(), v) << " := " << new_value << "\n"); + + +#if 0 + if (!check_update(v, new_value)) + return false; + apply_checked_update(); +#else + for (auto const& [coeff, bv] : vi.m_bool_vars) { auto& ineq = *atom(bv); bool old_sign = sign(bv); @@ -771,25 +770,106 @@ namespace sls { for (auto idx : vi.m_muls) { auto const& [w, coeff, monomial] = m_muls[idx]; + ctx.new_value_eh(m_vars[w].m_expr); num_t prod(coeff); - for (auto [w, p] : monomial) - prod *= power_of(value(w), p); + try { + for (auto [w, p] : monomial) + prod *= power_of(value(w), p); + } + catch (overflow_exception const&) { + return false; + } if (value(w) != prod && !update(w, prod)) return false; + } for (auto idx : vi.m_adds) { auto const& ad = m_adds[idx]; + auto w = ad.m_var; + ctx.new_value_eh(m_vars[w].m_expr); num_t sum(ad.m_coeff); for (auto const& [coeff, w] : ad.m_args) sum += coeff * value(w); if (!update(ad.m_var, sum)) return false; } +#endif return true; } + template + bool arith_base::check_update(var_t v, num_t new_value) { + + ++m_update_timestamp; + if (m_update_timestamp == 0) { + for (auto& vi : m_vars) + vi.set_update_value(num_t(0), 0); + ++m_update_timestamp; + } + auto& vi = m_vars[v]; + m_update_trail.reset(); + m_update_trail.push_back(v); + vi.set_update_value(new_value, m_update_timestamp); + + for (unsigned i = 0; i < m_update_trail.size(); ++i) { + auto v = m_update_trail[i]; + auto& vi = m_vars[v]; + for (auto idx : vi.m_muls) { + auto const& [w, coeff, monomial] = m_muls[idx]; + num_t prod(coeff); + try { + for (auto [w, p] : monomial) + prod *= power_of(get_update_value(w), p); + } + catch (overflow_exception const&) { + return false; + } + if (get_update_value(w) != prod && !is_permitted_update(w, prod - value(w))) + return false; + m_update_trail.push_back(w); + m_vars[w].set_update_value(prod, m_update_timestamp); + } + + for (auto idx : vi.m_adds) { + auto const& ad = m_adds[idx]; + auto w = ad.m_var; + num_t sum(ad.m_coeff); + for (auto const& [coeff, w] : ad.m_args) + sum += coeff * get_update_value(w); + if (get_update_value(v) != sum && !is_permitted_update(w, sum - value(w))) + return false; + m_update_trail.push_back(w); + m_vars[w].set_update_value(sum, m_update_timestamp); + } + } + return true; + } + + template + void arith_base::apply_checked_update() { + for (auto v : m_update_trail) { + auto & vi = m_vars[v]; + auto old_value = vi.m_value; + vi.m_value = vi.get_update_value(m_update_timestamp); + auto new_value = vi.m_value; + ctx.new_value_eh(vi.m_expr); + for (auto const& [coeff, bv] : vi.m_bool_vars) { + auto& ineq = *atom(bv); + bool old_sign = sign(bv); + sat::literal lit(bv, old_sign); + SASSERT(ctx.is_true(lit)); + ineq.m_args_value += coeff * (new_value - old_value); + num_t dtt_new = dtt(old_sign, ineq); + // verbose_stream() << "dtt " << lit << " " << ineq << " " << dtt_new << "\n"; + if (dtt_new != 0) + ctx.flip(bv); + SASSERT(dtt(sign(bv), ineq) == 0); + } + } + } + template typename arith_base::ineq& arith_base::new_ineq(ineq_kind op, num_t const& coeff) { auto* i = alloc(ineq); @@ -1144,53 +1224,55 @@ namespace sls { auto const& vi = m_vars[v]; if (vi.m_def_idx == UINT_MAX) return; - num_t v1, v2; + + num_t new_value, v1, v2; switch (vi.m_op) { case LAST_ARITH_OP: break; case OP_ADD: { auto const& ad = m_adds[vi.m_def_idx]; auto const& args = ad.m_args; - num_t sum(ad.m_coeff); + new_value = ad.m_coeff; for (auto [c, w] : args) - sum += c * value(w); - update(v, sum); + new_value += c * value(w); break; } case OP_MUL: { auto const& [w, coeff, monomial] = m_muls[vi.m_def_idx]; - num_t prod(coeff); + new_value = coeff; for (auto [w, p] : monomial) - prod *= power_of(value(w), p); - update(v, prod); + new_value *= power_of(value(w), p); break; } case OP_MOD: v1 = value(m_ops[vi.m_def_idx].m_arg1); v2 = value(m_ops[vi.m_def_idx].m_arg2); - update(v, v2 == 0 ? num_t(0) : mod(v1, v2)); + new_value = v2 == 0 ? num_t(0) : mod(v1, v2); break; case OP_DIV: v1 = value(m_ops[vi.m_def_idx].m_arg1); v2 = value(m_ops[vi.m_def_idx].m_arg2); - update(v, v2 == 0 ? num_t(0) : v1 / v2); + new_value = v2 == 0 ? num_t(0) : v1 / v2; break; case OP_IDIV: v1 = value(m_ops[vi.m_def_idx].m_arg1); v2 = value(m_ops[vi.m_def_idx].m_arg2); - update(v, v2 == 0 ? num_t(0) : div(v1, v2)); + new_value = v2 == 0 ? num_t(0) : div(v1, v2); break; case OP_REM: v1 = value(m_ops[vi.m_def_idx].m_arg1); v2 = value(m_ops[vi.m_def_idx].m_arg2); - update(v, v2 == 0 ? num_t(0) : v1 %= v2); + new_value = v2 == 0 ? num_t(0) : v1 %= v2; break; case OP_ABS: - update(v, abs(value(m_ops[vi.m_def_idx].m_arg1))); + new_value = abs(value(m_ops[vi.m_def_idx].m_arg1)); break; default: NOT_IMPLEMENTED_YET(); } + + if (!update(v, new_value)) + ctx.new_value_eh(e); } template @@ -1201,6 +1283,7 @@ namespace sls { auto const& vi = m_vars[v]; if (vi.m_def_idx == UINT_MAX) return false; + flet _tabu(m_use_tabu, false); TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); switch (vi.m_op) { case arith_op_kind::LAST_ARITH_OP: @@ -1234,11 +1317,11 @@ namespace sls { template void arith_base::initialize() { for (auto lit : ctx.unit_literals()) - initialize(lit); + initialize_unit(lit); } template - void arith_base::initialize(sat::literal lit) { + void arith_base::initialize_unit(sat::literal lit) { init_bool_var(lit.var()); auto* ineq = atom(lit.var()); if (!ineq) diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index c304e2c40..3500ab343 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -69,7 +69,11 @@ namespace sls { }; private: - struct var_info { + class var_info { + num_t m_range{ 100000000 }; + num_t m_update_value{ 0 }; + unsigned m_update_timestamp = 0; + public: var_info(expr* e, var_sort k): m_expr(e), m_sort(k) {} expr* m_expr; num_t m_value{ 0 }; @@ -81,7 +85,16 @@ namespace sls { unsigned_vector m_muls; unsigned_vector m_adds; optional m_lo, m_hi; - num_t m_range{ 100000000 }; + + // retrieve temporary value during an update. + void set_update_value(num_t const& v, unsigned timestamp) { + m_update_value = v; + m_update_timestamp = timestamp; + } + num_t const& get_update_value(unsigned ts) const { + return ts == m_update_timestamp ? m_update_value : m_value; + } + bool in_range(num_t const& n) const { if (-m_range < n && n < m_range) return true; @@ -139,7 +152,7 @@ namespace sls { vector m_updates; var_t m_last_var = 0; num_t m_last_delta { 0 }; - bool m_tabu = false; + bool m_use_tabu = true; arith_util a; void invariant(); @@ -172,6 +185,10 @@ namespace sls { void add_update(var_t v, num_t delta); bool is_permitted_update(var_t v, num_t& delta); + unsigned m_update_timestamp = 0; + svector m_update_trail; + bool check_update(var_t v, num_t new_value); + void apply_checked_update(); vector m_factors; vector const& factor(num_t n); @@ -260,11 +277,12 @@ namespace sls { bool is_int(var_t v) const { return m_vars[v].m_sort == var_sort::INT; } num_t value(var_t v) const { return m_vars[v].m_value; } + num_t const& get_update_value(var_t v) const { return m_vars[v].get_update_value(m_update_timestamp); } bool is_num(expr* e, num_t& i); expr_ref from_num(sort* s, num_t const& n); void check_ineqs(); void init_bool_var(sat::bool_var bv); - void initialize(sat::literal lit); + void initialize_unit(sat::literal lit); void add_le(var_t v, num_t const& n); void add_ge(var_t v, num_t const& n); void add_lt(var_t v, num_t const& n); diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 64fe69ca9..ec26737e3 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -101,8 +101,8 @@ namespace sls { propagate_literal(lit); } - while (!m_new_constraint && (!m_repair_up.empty() || !m_repair_down.empty())) { - while (!m_repair_down.empty() && !m_new_constraint) { + while (!m_new_constraint && m.inc() && (!m_repair_up.empty() || !m_repair_down.empty())) { + while (!m_repair_down.empty() && !m_new_constraint && m.inc()) { auto id = m_repair_down.erase_min(); expr* e = term(id); TRACE("sls", tout << "repair down " << mk_bounded_pp(e, m) << "\n"); @@ -114,7 +114,7 @@ namespace sls { } } } - while (!m_repair_up.empty() && !m_new_constraint) { + while (!m_repair_up.empty() && !m_new_constraint && m.inc()) { auto id = m_repair_up.erase_min(); expr* e = term(id); TRACE("sls", tout << "repair up " << mk_bounded_pp(e, m) << "\n"); @@ -308,12 +308,14 @@ namespace sls { } } ); - // verbose_stream() << "new value " << mk_bounded_pp(e, m) << " " << mk_bounded_pp(get_value(e), m) << "\n"; + m_repair_down.reserve(e->get_id() + 1); + m_repair_up.reserve(e->get_id() + 1); if (!m_repair_down.contains(e->get_id())) m_repair_down.insert(e->get_id()); for (auto p : parents(e)) { m_repair_up.reserve(p->get_id() + 1); + m_repair_down.reserve(p->get_id() + 1); if (!m_repair_up.contains(p->get_id())) m_repair_up.insert(p->get_id()); } diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp index 15dc55f93..a72e95e72 100644 --- a/src/ast/sls/sls_smt_solver.cpp +++ b/src/ast/sls/sls_smt_solver.cpp @@ -46,7 +46,11 @@ namespace sls { void on_restart() override {} + bool m_on_save_model = false; void on_save_model() override { + if (m_on_save_model) + return; + flet _on_save_model(m_on_save_model, true); TRACE("sls", display(tout)); while (unsat().empty()) { m_context.check(); @@ -185,6 +189,8 @@ namespace sls { } else { sat::literal lit = mk_literal(f); + if (sign) + lit.neg(); m_solver_ctx->add_clause(1, &lit); } }