diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index b7e0c2f2e..ac5df6056 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -475,124 +475,6 @@ namespace sls { m_updates.push_back({ v, delta, compute_score(v, delta) }); } - template - bool arith_base::solve_eq_pairs(var_t v, ineq const& ineq) { - SASSERT(ineq.m_op == ineq_kind::EQ); - if (is_fixed(v)) - return false; - auto argsv = ineq.m_args_value; - num_t a(0); - for (auto const& [c, w] : ineq.m_args) - if (v == w) { - a = c; - argsv -= value(v) * c; - } - if (abs(a) == 1) - return false; - IF_VERBOSE(3, verbose_stream() << "solve_eq_pairs " << ineq << " for v" << v << "\n"); - SASSERT(a != 0); - unsigned start = ctx.rand(); - for (unsigned i = 0; i < ineq.m_args.size(); ++i) { - unsigned j = (start + i) % ineq.m_args.size(); - auto const& [b, w] = ineq.m_args[j]; - if (w == v) - continue; - if (b == 1 || b == -1) - continue; - argsv -= value(w) * b; - solve_eq_pairs(a, v, b, w, - argsv); - argsv += value(w) * b; - } - return false; - } - - // ax0 + by0 = r - // (x, y) = (x0 - k*b/g, y0 + k*a/g) - // find the min x1 >= x0 satisfying progression and where x1 >= lo(x) - // k*ab/g - k*ab/g = 0 - template - bool arith_base::solve_eq_pairs(num_t const& _a, var_t x, num_t const& _b, var_t y, num_t const& r) { - if (is_fixed(y)) - return false; - num_t x0, y0; - std::cout << "solve_eq_pairs " << _a << " v" << x << " " << _b << " v" << y << " " << r << "\n"; - num_t a = _a, b = _b; - num_t g = gcd(a, b, x0, y0); - SASSERT(g >= 1); - SASSERT(g == a * x0 + b * y0); - if (!divides(g, r)) - return false; - //verbose_stream() << g << " == " << a << "*" << x0 << " + " << b << "*" << y0 << "\n"; - x0 *= div(r, g); - y0 *= div(r, g); - - auto adjust_lo = [&](num_t& x0, num_t& y0, num_t a, num_t b, optional const& lo, optional const& hi) { - if (!lo || lo->value <= x0) - return true; - // x0 + k*b/g >= lo - // k*(b/g) >= lo - x0 - // k >= (lo - x0)/(b/g) - // x1 := x0 + k*b/g - auto delta = lo->value - x0; - auto bg = abs(div(b, g)); - verbose_stream() << g << " " << bg << " " << " " << delta << "\n"; - auto k = divide(x, delta, bg); - auto x1 = x0 + k * bg; - if (hi && hi->value < x1) - return false; - x0 = x1; - y0 = y0 + k * (div(b, g) > 0 ? -div(a, g) : div(a, g)); - SASSERT(r == a * x0 + b * y0); - return true; - }; - auto adjust_hi = [&](num_t& x0, num_t& y0, num_t a, num_t b, optional const& lo, optional const& hi) { - if (!hi || hi->value >= x0) - return true; - // x0 + k*b/g <= hi - // k <= (x0 - hi)/(b/g) - auto delta = x0 - hi->value; - auto bg = abs(div(b, g)); - auto k = div(delta, bg); - auto x1 = x0 - k * bg; - if (lo && lo->value < x1) - return false; - x0 = x1; - y0 = y0 - k * (div(b, g) > 0 ? -div(a, g) : div(a, g)); - SASSERT(r == a * x0 + b * y0); - return true; - }; - auto const& lo_x = m_vars[x].m_lo; - auto const& hi_x = m_vars[x].m_hi; - - if (!adjust_lo(x0, y0, a, b, lo_x, hi_x)) - return false; - if (!adjust_hi(x0, y0, a, b, lo_x, hi_x)) - return false; - - auto const& lo_y = m_vars[y].m_lo; - auto const& hi_y = m_vars[y].m_hi; - - if (!adjust_lo(y0, x0, b, a, lo_y, hi_y)) - return false; - if (!adjust_hi(y0, x0, b, a, lo_y, hi_y)) - return false; - - if (lo_x && lo_x->value > x0) - return false; - if (hi_x && hi_x->value < x0) - return false; - - if (x0 == value(x)) - return false; - if (abs(value(x)) * 2 < abs(x0)) - return false; - if (abs(value(y)) * 2 < abs(y0)) - return false; - add_update(x, x0 - value(x)); - // add_update(y, y0 - value(y)); add pairwise update? - return true; - } - // flip on the first positive score // it could be changed to flip on maximal positive score // or flip on maximal non-negative score @@ -1351,10 +1233,7 @@ namespace sls { } break; case ineq_kind::EQ: - if (lit.sign()) { - verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n"; - } - else { + if (!lit.sign()) { if (c == -1) { add_ge(v, ineq->m_coeff); add_le(v, ineq->m_coeff); diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index 3500ab343..515046d41 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -258,8 +258,6 @@ namespace sls { void find_quadratic_moves(ineq const& i, var_t x, num_t const& a, num_t const& b, num_t const& sum); double compute_score(var_t x, num_t const& delta); void save_best_values(); - bool solve_eq_pairs(var_t v, ineq const& ineq); - bool solve_eq_pairs(num_t const& a, var_t x, num_t const& b, var_t y, num_t const& r); var_t mk_var(expr* e); var_t mk_term(expr* e); diff --git a/src/tactic/sls/sls_tactic.cpp b/src/tactic/sls/sls_tactic.cpp index cfe03b6da..d78fb30f8 100644 --- a/src/tactic/sls/sls_tactic.cpp +++ b/src/tactic/sls/sls_tactic.cpp @@ -73,9 +73,16 @@ public: m_sls->assert_expr(g->form(i)); - lbool res = m_sls->check(); m_st.reset(); - m_sls->collect_statistics(m_st); + lbool res = l_undef; + try { + res = m_sls->check(); + } + catch (z3_exception& ex) { + m_sls->collect_statistics(m_st); + throw; + } + // report_tactic_progress("Number of flips:", m_sls->get_num_moves()); IF_VERBOSE(10, verbose_stream() << res << "\n"); IF_VERBOSE(10, m_sls->display(verbose_stream())); diff --git a/src/util/checked_int64.h b/src/util/checked_int64.h index b1cb6e448..37f7c5526 100644 --- a/src/util/checked_int64.h +++ b/src/util/checked_int64.h @@ -27,6 +27,7 @@ Revision History: #include "util/rational.h" #include "util/mpn.h" + class overflow_exception : public z3_exception { char const* msg() const override { return "checked_int64 overflow/underflow"; } }; @@ -120,8 +121,10 @@ public: uint64_t x = static_cast(m_value); uint64_t y = static_cast(other.m_value); int64_t r = static_cast(x + y); - if (m_value > 0 && other.m_value > 0 && r <= 0) throw overflow_exception(); - if (m_value < 0 && other.m_value < 0 && r >= 0) throw overflow_exception(); + if (m_value > 0 && other.m_value > 0 && r <= 0) + throw overflow_exception(); + if (m_value < 0 && other.m_value < 0 && r >= 0) + throw overflow_exception(); m_value = r; } else { @@ -152,12 +155,22 @@ public: if (INT_MIN < m_value && m_value <= INT_MAX && INT_MIN < other.m_value && other.m_value <= INT_MAX) { m_value *= other.m_value; } + else if (m_value == 0 || other.m_value == 0 || m_value == 1 || other.m_value == 1) { + m_value *= other.m_value; + } + else if (m_value == INT64_MIN || other.m_value == INT64_MIN) + throw overflow_exception(); else { - uint64_t x = m_value, y = other.m_value; - uint64_t z = x * y; - if (y != 0 && z / y != x) + uint64_t x = m_value < 0 ? -m_value : m_value; + uint64_t y = other.m_value < 0 ? -other.m_value : other.m_value; + uint64_t r = x * y; + if ((y != 0 && r / y != x) || r > INT64_MAX) throw overflow_exception(); - m_value = z; + m_value = r; + if (m_value < 0 && other.m_value > 0) + m_value = -m_value; + if (m_value > 0 && other.m_value < 0) + m_value = -m_value; } } else {