From 849385c6a1c23c9be19bd8e5ee316e7c0f9abcea Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 12 Aug 2024 17:42:52 -0700 Subject: [PATCH] bugfixes in sls-arith --- src/ast/sls/sat_ddfw.cpp | 15 +- src/ast/sls/sat_ddfw.h | 4 +- src/ast/sls/sls_arith_base.cpp | 1199 ++++++++++++++++++++---------- src/ast/sls/sls_arith_base.h | 84 ++- src/ast/sls/sls_arith_plugin.cpp | 75 +- src/ast/sls/sls_arith_plugin.h | 5 +- src/ast/sls/sls_basic_plugin.cpp | 36 +- src/ast/sls/sls_basic_plugin.h | 6 +- src/ast/sls/sls_bv_plugin.cpp | 6 +- src/ast/sls/sls_bv_plugin.h | 4 +- src/ast/sls/sls_context.cpp | 31 +- src/ast/sls/sls_context.h | 10 +- src/ast/sls/sls_euf_plugin.h | 5 +- src/ast/sls/sls_smt_solver.cpp | 92 ++- src/ast/sls/sls_smt_solver.h | 4 +- src/util/checked_int64.h | 18 +- src/util/vector.h | 4 + 17 files changed, 1094 insertions(+), 504 deletions(-) diff --git a/src/ast/sls/sat_ddfw.cpp b/src/ast/sls/sat_ddfw.cpp index f86800079..c6f3591ff 100644 --- a/src/ast/sls/sat_ddfw.cpp +++ b/src/ast/sls/sat_ddfw.cpp @@ -374,7 +374,6 @@ namespace sat { } void ddfw::do_restart() { - verbose_stream() << "restart\n"; reinit_values(); init_clause_data(); m_restart_next += m_config.m_restart_base*get_luby(++m_restart_count); @@ -624,6 +623,20 @@ namespace sat { m_config.m_reinit_base = p.ddfw_reinit_base(); m_config.m_restart_base = p.ddfw_restart_base(); } + + void ddfw::collect_statistics(statistics& st) const { + st.update("sls-ddfw-flips", (double)m_flips); + st.update("sls-ddfw-restarts", m_restart_count); + st.update("sls-ddfw-reinits", m_reinit_count); + st.update("sls-ddfw-shifts", (double)m_shifts); + } + void ddfw::reset_statistics() { + m_flips = 0; + m_restart_count = 0; + m_reinit_count = 0; + m_shifts = 0; + } + } diff --git a/src/ast/sls/sat_ddfw.h b/src/ast/sls/sat_ddfw.h index 65970cbd4..36ec30b27 100644 --- a/src/ast/sls/sat_ddfw.h +++ b/src/ast/sls/sat_ddfw.h @@ -233,7 +233,9 @@ namespace sat { // for parallel integration unsigned num_non_binary_clauses() const { return m_num_non_binary_clauses; } - void collect_statistics(statistics& st) const {} + void collect_statistics(statistics& st) const; + + void reset_statistics(); double get_priority(bool_var v) const { return m_probs[v]; } diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 5543ad58d..b5bf7c72d 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -13,6 +13,31 @@ Author: Nikolaj Bjorner (nbjorner) 2023-02-07 + Uses quadratic solver method from nia_ls in hybrid-smt + (with a bug fix for when order of roots are swapped) + Other features from nia_ls are also used as a starting point, + such as tabu and fallbacks. + +Todo: + +- add fairness for which variable to flip and direction (by age fifo). + - maintain age per variable, per sign + +- include more general tabu measure + - + +- random walk when there is no applicable update + - repair_down can fail repeatedely. Then allow a mode to reset arguments similar to + repair of literals. + +- avoid overflow for nested products + +Done: +- add tabu for flipping variable back to the same value. + - remember last variable/delta and block -delta = last_delta && last_variable = current_variable +- include measures for bounded updates + - per variable maintain increasing range + --*/ #include "ast/sls/sls_arith_base.h" @@ -24,11 +49,11 @@ namespace sls { bool arith_base::ineq::is_true() const { switch (m_op) { case ineq_kind::LE: - return m_args_value + this->m_coeff <= 0; + return m_args_value <= 0; case ineq_kind::EQ: - return m_args_value + this->m_coeff == 0; + return m_args_value== 0; default: - return m_args_value + this->m_coeff < 0; + return m_args_value < 0; } } @@ -43,12 +68,25 @@ namespace sls { out << " + " << this->m_coeff; switch (m_op) { case ineq_kind::LE: - return out << " <= " << 0 << "(" << m_args_value + this->m_coeff << ")"; + out << " <= " << 0 << "(" << m_args_value << ")"; + break; case ineq_kind::EQ: - return out << " == " << 0 << "(" << m_args_value + this->m_coeff << ")"; + out << " == " << 0 << "(" << m_args_value << ")"; + break; default: - return out << " < " << 0 << "(" << m_args_value + this->m_coeff << ")"; + out << " < " << 0 << "(" << m_args_value << ")"; + break; } + for (auto const& [x, nl] : this->m_nonlinear) { + if (nl.size() == 1 && nl[0].v == x) + continue; + for (auto const& [v, c, p] : nl) { + out << " v" << x; + if (p > 1) out << "^" << p; + out << " in " << c << " * v" << v; + } + } + return out; } template @@ -117,16 +155,8 @@ namespace sls { } template - num_t arith_base::dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& old_value, num_t const& new_value) const { - return dtt(sign, ineq.m_args_value + coeff * (new_value - old_value), ineq); - } - - template - bool arith_base::cm(ineq const& ineq, var_t v, num_t& new_value) { - for (auto const& [coeff, w] : ineq.m_args) - if (w == v) - return cm(ineq, v, coeff, new_value); - return false; + num_t arith_base::dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& delta) const { + return dtt(sign, ineq.m_args_value + coeff * delta, ineq); } template @@ -138,77 +168,238 @@ namespace sls { } template - bool arith_base::cm(ineq const& ineq, var_t v, num_t const& coeff, num_t& new_value) { - auto bound = -ineq.m_coeff; - auto argsv = ineq.m_args_value; - bool solved = false; - num_t delta = argsv - bound; - auto const& lo = m_vars[v].m_lo; - auto const& hi = m_vars[v].m_hi; + num_t arith_base::divide_floor(var_t v, num_t const& a, num_t const& b) { + if (!is_int(v)) + return a / b; + if (b > 0) + return div(a, b); + else if (a > 0) + return -div(a - b - 1, -b); + else + return div(-a, -b); + } - if (is_fixed(v)) - return false; + template + num_t arith_base::divide_ceil(var_t v, num_t const& a, num_t const& b) { + if (!is_int(v)) + return a / b; + if (b > 0) + return div(a + b - 1, b); + else if (a > 0) + return -div(a, -b); + else + return div(-a - b - 1, -b); + } - auto well_formed = [&]() { - num_t new_args = argsv + coeff * (new_value - value(v)); - if (ineq.is_true()) { - switch (ineq.m_op) { - case ineq_kind::LE: return new_args > bound; - case ineq_kind::LT: return new_args >= bound; - case ineq_kind::EQ: return new_args != bound; - } - } - else { - switch (ineq.m_op) { - case ineq_kind::LE: return new_args <= bound; - case ineq_kind::LT: return new_args < bound; - case ineq_kind::EQ: return new_args == bound; - } - } - return false; - }; + // + // i = 1, 3, 5, 7, 9, ... + // d, d - 1, d - 4, d - 9, d - 16, + // + template + static num_t sqrt(num_t d) { + if (d <= 1) + return d; + auto sq = 2*sqrt(div(d, num_t(4))) + 1; + if (sq * sq <= d) + return sq; + return sq - 1; + } - auto move_to_bounds = [&]() { - VERIFY(well_formed()); - if (!in_bounds(v, value(v))) - return true; - if (in_bounds(v, new_value)) - return true; - if (lo && lo->value > new_value) { - new_value = lo->value; - if (!well_formed()) - new_value += 1; - } - if (hi && hi->value < new_value) { - new_value = hi->value; - if (!well_formed()) - new_value -= 1; - } - return well_formed() && in_bounds(v, new_value); - }; + // + // a*x^2 + b*x + c = sum + // + template + void arith_base::find_quadratic_moves(ineq const& ineq, var_t x, num_t const& a, num_t const& b, num_t const& sum) { + num_t c, d; + try { + c = sum - a * value(x) * value(x) - b * value(x); + d = b * b - 4 * a * c; + } + catch (overflow_exception const&) { + return; + } + if (d < 0) + return; + num_t root = sqrt(d); + bool is_square = root * root == d; + num_t ll = divide_floor(x, -b - root, 2 * a); + num_t lh = divide_ceil(x, -b - root, 2 * a); + num_t rl = divide_floor(x, -b + root, 2 * a); + num_t rh = divide_ceil(x, -b + root, 2 * a); + if (lh > rl) { + std::swap(ll, rl); + std::swap(lh, rh); + } + num_t eps(1); + if (!is_int(x) && abs(rh - lh) <= eps) + eps = abs(rh - lh) / num_t(2); +// verbose_stream() << a << " " << b << " " << c << "\n"; +// verbose_stream() << (-b - root) << " " << (2 * a) << " " << ll << " " << lh << "\n"; +// verbose_stream() << (-b + root) << " " << (2 * a) << " " << rl << " " << rh << "\n"; + SASSERT(ll <= lh && ll + 1 >= lh); + SASSERT(rl <= rh && rl + 1 >= rh); + SASSERT(!is_square || ll != lh || a * ll * ll + b * ll + c == 0); + SASSERT(!is_square || rl != rh || a * rl * rl + b * rl + c == 0); + if (d > 0 && lh == rh) + return; + if (d == 0 && ll != lh) + return; if (ineq.is_true()) { switch (ineq.m_op) { case ineq_kind::LE: - // args <= bound -> args > bound - SASSERT(argsv <= bound); - SASSERT(delta <= 0); - delta -= 1; - new_value = value(v) + divide(v, abs(delta - ctx.rand(3)), coeff); - return move_to_bounds(); + SASSERT(sum <= 0); + if (d == 0) + break; + if (a < 0) { + if (a * lh * lh + b * lh + c <= 0) + lh += eps; + if (a * rl * rl + b * rl + c <= 0) + rl -= eps; + if (is_square && a * lh * lh + b * lh + c <= 0) { + num_t ll = divide_floor(x, -b - root, 2 * a); + num_t lh = divide_ceil(x, -b - root, 2 * a); + num_t rl = divide_floor(x, -b + root, 2 * a); + num_t rh = divide_ceil(x, -b + root, 2 * a); + + verbose_stream() << a << " " << b << " " << c << "\n"; + verbose_stream() << (-b - root) << " " << (2 * a) << " " << ll << " " << lh << "\n"; + verbose_stream() << (-b + root) << " " << (2 * a) << " " << rl << " " << rh << "\n"; + verbose_stream() << "root " << root << "\n"; + UNREACHABLE(); + } + + SASSERT(!is_square || a * lh * lh + b * lh + c > 0); + SASSERT(!is_square || a * rl * rl + b * rl + c > 0); + add_update(x, lh - value(x)); + add_update(x, rl - value(x)); + } + else { + if (a * ll * ll + b * ll + c <= 0) + ll -= eps; + if (a * rh * rh + b * rh + c <= 0) + rh += eps; + SASSERT(!is_square || a * ll * ll + b * ll + c > 0); + SASSERT(!is_square || a * rh * rh + b * rh + c > 0); + add_update(x, ll - value(x)); + add_update(x, rh - value(x)); + } + break; + case ineq_kind::LT: + SASSERT(sum < 0); + SASSERT(!is_int(x)); + SASSERT(ll == lh); + SASSERT(rl == rh); + if (d == 0) + break; + + if (a > 0) { + SASSERT(!is_square || a * (ll + eps) * (ll + eps) + b * (ll + eps) + c >= 0); + SASSERT(!is_square || a * (rl - eps) * (rl - eps) + b * (rl - eps) + c >= 0); + add_update(x, lh - value(x) + eps); + if (ll != rl) + add_update(x, rh - value(x) - eps); + } + else { + SASSERT(!is_square || a * (ll - eps) * (ll - eps) + b * (ll - eps) + c >= 0); + SASSERT(!is_square || a * (rl + eps) * (rl + eps) + b * (rl + eps) + c >= 0); + add_update(x, ll - value(x) - eps); + if (ll != rl) + add_update(x, rl - value(x) + eps); + } + break; + case ineq_kind::EQ: + SASSERT(sum == 0); + SASSERT(!is_square || a * (value(x) + 1) * (value(x) + 1) + b * (value(x) + 1) + c != 0); + SASSERT(!is_square || a * (value(x) - 1) * (value(x) - 1) + b * (value(x) - 1) + c != 0); + add_update(x, num_t(1)); + add_update(x, num_t(-1)); + break; + } + } + else { + switch (ineq.m_op) { + case ineq_kind::LE: + SASSERT(sum > 0); + if (d == 0) { + SASSERT(!is_square || !is_int(x) || a <= 0 || ll != lh || a * ll * ll + b * ll + c <= 0); + if (a > 0 && ll == lh) + add_update(x, ll - value(x)); + break; + } + SASSERT(d > 0); + if (a > 0) { + if (a * lh * lh + b * lh + c > 0) + lh += eps; + if (a * rl * rl + b * rl + c > 0) + rl -= eps; + SASSERT(!is_square || a * lh * lh + b * lh + c <= 0); + SASSERT(!is_square || a * rl * rl + b * rl + c <= 0); + add_update(x, lh - value(x)); + add_update(x, rl - value(x)); + } + else { + if (a * ll * ll + b * ll + c > 0) + ll += eps; + if (a * rh * rh + b * rh + c > 0) + rh -= eps; + SASSERT(!is_square || a * ll * ll + b * ll + c <= 0); + SASSERT(!is_square || a * rh * rh + b * rh + c <= 0); + add_update(x, ll - value(x)); + add_update(x, rh - value(x)); + } + break; + case ineq_kind::LT: + SASSERT(sum >= 0); + SASSERT(!is_int(x)); + if (d == 0) + break; + SASSERT(d > 0); + if (a > 0) { + SASSERT(!is_square || a * (ll - eps) * (ll - eps) + b * (ll - eps) + c < 0); + SASSERT(!is_square || a * (rl + eps) * (rl + eps) + b * (rl + eps) + c < 0); + add_update(x, lh - value(x) - eps); + if (ll != rl) + add_update(x, rh - value(x) + eps); + } + else { + SASSERT(!is_square || a* (ll + eps)* (ll + eps) + b * (ll + eps) + c < 0); + SASSERT(!is_square || a* (rl - eps)* (rl - eps) + b * (rl - eps) + c < 0); + add_update(x, ll - value(x) + eps); + if (ll != rl) + add_update(x, rl - value(x) - eps); + } + break; + case ineq_kind::EQ: + SASSERT(sum != 0); + if (!is_square) + break; + if (ll == lh) + add_update(x, ll - value(x)); + if (rl == rh && lh != rh) + add_update(x, rl - value(x)); + break; + } + } + } + + template + void arith_base::find_linear_moves(ineq const& ineq, var_t v, num_t const& coeff, num_t const& sum) { + if (ineq.is_true()) { + switch (ineq.m_op) { + case ineq_kind::LE: + SASSERT(sum <= 0); + add_update(v, divide(v, -sum + 1, coeff)); + break; case ineq_kind::LT: - // args < bound -> args >= bound - SASSERT(argsv <= bound); - SASSERT(delta <= 0); - delta = abs(delta); - new_value = value(v) + divide(v, delta + ctx.rand(3), coeff); - VERIFY(argsv + coeff * (new_value - value(v)) >= bound); - return move_to_bounds(); + SASSERT(sum < 0); + add_update(v, divide(v, -sum, coeff)); + break; case ineq_kind::EQ: { - delta = abs(delta) + 1 + ctx.rand(10); - int sign = ctx.rand(2) == 0 ? 1 : -1; - new_value = value(v) + sign * divide(v, abs(delta), coeff); - return move_to_bounds(); + SASSERT(sum == 0); + add_update(v, num_t(1)); + add_update(v, num_t(- 1)); + break; } default: UNREACHABLE(); @@ -218,40 +409,88 @@ namespace sls { else { switch (ineq.m_op) { case ineq_kind::LE: - SASSERT(argsv > bound); - SASSERT(delta > 0); - delta += ctx.rand(10); - new_value = value(v) - divide(v, delta + ctx.rand(3), coeff); - return move_to_bounds(); + SASSERT(sum > 0); + add_update(v, - divide(v, sum, coeff)); + break; case ineq_kind::LT: - SASSERT(argsv >= bound); - SASSERT(delta >= 0); - delta += 1 + ctx.rand(10); - new_value = value(v) - divide(v, delta + ctx.rand(3), coeff); - return move_to_bounds(); - case ineq_kind::EQ: - SASSERT(delta != 0); - if (delta < 0) - new_value = value(v) + divide(v, abs(delta), coeff); + SASSERT(sum >= 0); + add_update(v, - divide(v, sum + 1, coeff)); + break; + case ineq_kind::EQ: { + num_t delta = sum; + SASSERT(sum != 0); + delta = sum < 0 ? divide(v, abs(sum), coeff) : -divide(v, sum, coeff); + if (sum + coeff * delta != 0) + solve_eq_pairs(v, ineq); else - new_value = value(v) - divide(v, delta, coeff); - solved = argsv + coeff * (new_value - value(v)) == bound; - return solved && move_to_bounds(); + add_update(v, delta); + break; + } default: UNREACHABLE(); break; } } - return false; } template - bool arith_base::solve_eq_pairs(ineq const& ineq) { + bool arith_base::is_permitted_update(var_t v, num_t& delta) { + auto& vi = m_vars[v]; + + if (m_last_var == v && m_last_delta == -delta) + return false; + + if (m_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) + return false; + + auto old_value = value(v); + auto new_value = old_value + delta; + if (!vi.in_range(new_value)) + return false; + + if (!in_bounds(v, new_value) && in_bounds(v, old_value)) { + auto const& lo = m_vars[v].m_lo; + auto const& hi = m_vars[v].m_hi; + if (lo && (lo->is_strict ? lo->value >= new_value : lo->value > new_value)) { + if (lo->is_strict && delta < 0 && lo->value <= old_value) { + num_t eps(1); + if (hi && hi->value - lo->value <= eps) + eps = (hi->value - lo->value) / num_t(2); + delta = lo->value - old_value + eps; + } + else if (!lo->is_strict && delta < 0 && lo->value < old_value) + delta = lo->value - old_value; + else + return false; + } + if (hi && (hi->is_strict ? hi->value <= new_value : hi->value < new_value)) { + if (hi->is_strict && delta >= 0 && hi->value >= old_value) { + num_t eps(1); + if (lo && hi->value - lo->value <= eps) + eps = (hi->value - lo->value) / num_t(2); + delta = hi->value - old_value - eps; + } + else if (!hi->is_strict && delta > 0 && hi->value > old_value) + delta = hi->value - old_value; + else + return false; + } + } + return delta != 0; + } + + template + void arith_base::add_update(var_t v, num_t delta) { + if (!is_permitted_update(v, delta)) + return; + m_updates.push_back({ v, delta, compute_score(v, delta) }); + } + + template + bool arith_base::solve_eq_pairs(var_t v, ineq const& ineq) { SASSERT(ineq.m_op == ineq_kind::EQ); - auto v = ineq.m_var_to_flip; if (is_fixed(v)) return false; - auto bound = -ineq.m_coeff; auto argsv = ineq.m_args_value; num_t a; for (auto const& [c, w] : ineq.m_args) @@ -261,7 +500,7 @@ namespace sls { } if (abs(a) == 1) return false; - verbose_stream() << "solve_eq_pairs " << ineq << " for v" << v << "\n"; + IF_VERBOSE(3, verbose_stream() << "solve_eq_pairs " << ineq << " for v" << v << "\n"); unsigned start = ctx.rand(); for (unsigned i = 0; i < ineq.m_args.size(); ++i) { unsigned j = (start + i) % ineq.m_args.size(); @@ -271,8 +510,7 @@ namespace sls { if (b == 1 || b == -1) continue; argsv -= value(w) * b; - if (solve_eq_pairs(a, v, b, w, bound - argsv)) - return true; + solve_eq_pairs(a, v, b, w, - argsv); argsv += value(w) * b; } return false; @@ -297,10 +535,6 @@ namespace sls { x0 *= div(r, g); y0 *= div(r, g); - //verbose_stream() << r << " == " << a << "*" << x0 << " + " << b << "*" << y0 << "\n"; - - - auto adjust_lo = [&](num_t& x0, num_t& y0, num_t a, num_t b, optional const& lo, optional const& hi) { if (!lo || lo->value <= x0) return true; @@ -363,9 +597,8 @@ namespace sls { return false; if (abs(value(y)) * 2 < abs(y0)) return false; - update(x, x0); - update(y, y0); - + add_update(x, x0 - value(x)); + // add_update(y, y0 - value(y)); add pairwise update? return true; } @@ -373,124 +606,59 @@ namespace sls { // it could be changed to flip on maximal positive score // or flip on maximal non-negative score // or flip on first non-negative score - template - void arith_base::repair(sat::literal lit, ineq const& ineq) { - num_t new_value, old_value; - dtt_reward(lit); - auto v = ineq.m_var_to_flip; - - if (v == UINT_MAX) { - IF_VERBOSE(0, verbose_stream() << "no var to flip\n"); - return; - } - - if (repair_eq(lit, ineq)) - return; - - if (!cm(ineq, v, new_value)) { - display(verbose_stream(), v) << "\n"; - IF_VERBOSE(0, verbose_stream() << "no critical move for " << v << "\n"); - if (dtt(!ctx.is_true(lit), ineq) != 0) - ctx.flip(lit.var()); - return; - } - verbose_stream() << "repair " << lit << ": " << ineq << " var: v" << v << " := " << value(v) << " -> " << new_value << "\n"; - //for (auto const& [coeff, w] : ineq.m_args) - // display(verbose_stream(), w) << "\n"; - update(v, new_value); - invariant(ineq); - if (dtt(!ctx.is_true(lit), ineq) != 0) - ctx.flip(lit.var()); - } - - template - bool arith_base::repair_eq(sat::literal lit, ineq const& ineq) { - if (lit.sign() || ineq.m_op != ineq_kind::EQ) - return false; - auto v = ineq.m_var_to_flip; - num_t new_value; - verbose_stream() << ineq << "\n"; - for (auto const& [coeff, w] : ineq.m_args) - display(verbose_stream(), w) << "\n"; - if (ctx.rand(10) == 0 && solve_eq_pairs(ineq)) { - verbose_stream() << ineq << "\n"; - for (auto const& [coeff, w] : ineq.m_args) - display(verbose_stream(), w) << "\n"; - } - else if (cm(ineq, v, new_value) && update(v, new_value)) - ; - else if (solve_eq_pairs(ineq)) { - verbose_stream() << ineq << "\n"; - for (auto const& [coeff, w] : ineq.m_args) - display(verbose_stream(), w) << "\n"; - } - else - return false; - SASSERT(dtt(!ctx.is_true(lit), ineq) == 0); - if (dtt(!ctx.is_true(lit), ineq) != 0) - ctx.flip(lit.var()); - return true; - } - - // - // dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c) - // TODO - use cached dts instead of computed dts - // cached dts has to be updated when the score of literals are updated. + // prefer maximal score + // prefer v/delta with oldest occurrence with same direction // + template - double arith_base::dscore(var_t v, num_t const& new_value) const { - double score = 0; - auto const& vi = m_vars[v]; - for (auto const& [coeff, bv] : vi.m_bool_vars) { - sat::literal lit(bv, false); - for (auto cl : ctx.get_use_list(lit)) - score += (compute_dts(cl) - dts(cl, v, new_value)).get_int64() * ctx.get_weight(cl); - for (auto cl : ctx.get_use_list(~lit)) - score += (compute_dts(cl) - dts(cl, v, new_value)).get_int64() * ctx.get_weight(cl); + bool arith_base::apply_update() { + double sum_score = 0; + + for (auto const& [v, delta, score] : m_updates) + sum_score += score; + + while (!m_updates.empty()) { + + unsigned i = m_updates.size(); + double lim = sum_score * ((double)ctx.rand() / random_gen().max_value()); + do { + lim -= m_updates[--i].m_score; + } while (lim >= 0 && i > 0); + + auto [v, delta, score] = m_updates[i]; + + num_t new_value = value(v) + delta; + + IF_VERBOSE(10, verbose_stream() << "repair: v" << v << " := " << value(v) << " -> " << new_value << "\n"); + if (update(v, new_value)) { + m_last_var = v; + m_last_delta = delta; + m_stats.m_num_steps++; + m_vars[v].set_step(m_stats.m_num_steps, m_stats.m_num_steps + 3 + ctx.rand(10), delta); + return true; + } + sum_score -= score; + m_updates[i] = m_updates.back(); + m_updates.pop_back(); } - return score; + return false; } - // - // cm_score is costly. It involves several cache misses. - // Note that - // - get_use_list(lit).size() is "often" 1 or 2 - // - dtt_old can be saved - // template - int arith_base::cm_score(var_t v, num_t const& new_value) { - int score = 0; - auto& vi = m_vars[v]; - num_t old_value = vi.m_value; - for (auto const& [coeff, bv] : vi.m_bool_vars) { - auto const& ineq = *atom(bv); - bool old_sign = sign(bv); - num_t dtt_old = dtt(old_sign, ineq); - num_t dtt_new = dtt(old_sign, ineq, coeff, old_value, new_value); - if ((dtt_old == 0) == (dtt_new == 0)) - continue; - sat::literal lit(bv, old_sign); - if (dtt_old == 0) - // flip from true to false - lit.neg(); + bool arith_base::repair(sat::literal lit) { + + find_moves(lit); + + if (apply_update()) + return true; - // lit flips form false to true: + find_reset_moves(lit); - for (auto cl : ctx.get_use_list(lit)) { - auto const& clause = ctx.get_clause(cl); - if (!clause.is_true()) - ++score; - } + if (apply_update()) + return true; - // ignore the situation where clause contains multiple literals using v - for (auto cl : ctx.get_use_list(~lit)) { - auto const& clause = ctx.get_clause(cl); - if (clause.m_num_trues == 1) - --score; - } - } - return score; + return false; } template @@ -560,34 +728,30 @@ namespace sls { bool arith_base::update(var_t v, num_t const& new_value) { auto& vi = m_vars[v]; expr* e = vi.m_expr; + SASSERT(!m.is_value(e)); auto old_value = vi.m_value; if (old_value == new_value) return true; - display(verbose_stream(), v) << " := " << new_value << "\n"; - if (!in_bounds(v, new_value)) { - auto const& lo = vi.m_lo; - auto const& hi = vi.m_hi; - if (is_int(v) && lo && !lo->is_strict && new_value < lo->value) { - if (lo->value != old_value) - return update(v, lo->value); - if (in_bounds(v, old_value + 1)) - return update(v, old_value + 1); - else - return false; + if (!vi.in_range(new_value)) + return false; + if (!in_bounds(v, new_value) && in_bounds(v, old_value)) + return false; + + // check for overflow + try { + for (auto idx : vi.m_muls) { + auto const& [w, coeff, monomial] = m_muls[idx]; + num_t prod(coeff); + for (auto [w, p] : monomial) + prod *= power_of(v == w ? new_value : value(w), p); } - if (is_int(v) && hi && !hi->is_strict && new_value > hi->value) { - if (hi->value != old_value) - return update(v, hi->value); - else if (in_bounds(v, old_value - 1)) - return update(v, old_value - 1); - else - return false; - } - verbose_stream() << "out of bounds old value " << old_value << "\n"; - display(verbose_stream(), v) << "\n"; - SASSERT(false); + } + catch (overflow_exception const&) { return false; } + + IF_VERBOSE(10, display(verbose_stream(), v) << " := " << new_value << "\n"); + for (auto const& [coeff, bv] : vi.m_bool_vars) { auto& ineq = *atom(bv); bool old_sign = sign(bv); @@ -601,27 +765,26 @@ namespace sls { SASSERT(dtt(sign(bv), ineq) == 0); } vi.m_value = new_value; - - - SASSERT(!m.is_value(e)); - verbose_stream() << "new value eh " << mk_bounded_pp(e, m) << "\n"; ctx.new_value_eh(e); + + IF_VERBOSE(10, verbose_stream() << "new value eh " << mk_bounded_pp(e, m) << "\n"); + for (auto idx : vi.m_muls) { auto const& [w, coeff, monomial] = m_muls[idx]; num_t prod(coeff); - for (auto [w, p]:monomial) + for (auto [w, p] : monomial) prod *= power_of(value(w), p); - if (value(w) != prod) - update(w, prod); + if (value(w) != prod && !update(w, prod)) + return false; } + for (auto idx : vi.m_adds) { auto const& ad = m_adds[idx]; num_t sum(ad.m_coeff); for (auto const& [coeff, w] : ad.m_args) sum += coeff * value(w); - if (sum != ad.m_coeff) - update(ad.m_var, sum); - + if (!update(ad.m_var, sum)) + return false; } return true; @@ -828,7 +991,7 @@ namespace sls { if (v == UINT_MAX) { v = m_vars.size(); m_expr2var.setx(e->get_id(), v, UINT_MAX); - m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL)); + m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL)); } return v; } @@ -885,11 +1048,59 @@ namespace sls { template void arith_base::init_ineq(sat::bool_var bv, ineq& i) { - i.m_args_value = 0; + + // ensure that variables are unique in the linear term: + std::stable_sort(i.m_args.begin(), i.m_args.end(), [&](auto const& a, auto const& b) { return a.second < b.second; }); + unsigned k = 0; + for (unsigned j = 0; j < i.m_args.size(); ++j) { + if (j > k && i.m_args[k].second == i.m_args[j].second) + i.m_args[k].first += i.m_args[j].first; + else + i.m_args[k++] = i.m_args[j]; + } + i.m_args.shrink(k); + // compute the value of the linear term, and accumulate non-linear sub-terms + i.m_args_value = i.m_coeff; for (auto const& [coeff, v] : i.m_args) { m_vars[v].m_bool_vars.push_back({ coeff, bv }); i.m_args_value += coeff * value(v); + if (is_mul(v)) { + auto const& [w, c, monomial] = get_mul(v); + for (auto [w, p] : monomial) + i.m_nonlinear.push_back({ w, { {v, coeff, p} } }); + } + else + i.m_nonlinear.push_back({ v, { { v, coeff, 1 } } }); } + std::stable_sort(i.m_nonlinear.begin(), i.m_nonlinear.end(), [&](auto const& a, auto const& b) { return a.first < b.first; }); + + // ensure that non-linear terms are have a unique summary. + k = 0; + for (unsigned j = 0; j < i.m_nonlinear.size(); ++j) { + if (j > k && i.m_nonlinear[k].first == i.m_nonlinear[j].first) + i.m_nonlinear[k].second.append(i.m_nonlinear[j].second); + else + i.m_nonlinear[k++] = i.m_nonlinear[j]; + } + i.m_nonlinear.shrink(k); + + // Ensure that non-linear term occurrences are sorted, and + // that terms with the same variable are combined. + for (auto& [x, nl] : i.m_nonlinear) { + if (nl.size() == 1) + continue; + std::stable_sort(nl.begin(), nl.end(), [&](auto const& a, auto const& b) { return a.p < b.p; }); + k = 0; + for (unsigned j = 0; j < nl.size(); ++j) { + if (j > k && nl[k].v == nl[j].v) + nl[k].coeff += nl[j].coeff; + else + nl[k++] = nl[j]; + } + nl.shrink(k); + } + + // attach i to bv m_bool_vars.set(bv, &i); } @@ -909,7 +1120,7 @@ namespace sls { return; if (ineq->is_true() != lit.sign()) return; - repair(lit, *ineq); + repair(lit); } template @@ -1128,26 +1339,40 @@ namespace sls { template bool arith_base::repair_add(add_def const& ad) { auto v = ad.m_var; + auto old_value = value(v); auto const& coeffs = ad.m_args; num_t sum(ad.m_coeff); - num_t val = value(v); - - verbose_stream() << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << value(v) << "\n"; for (auto const& [c, w] : coeffs) sum += c * value(w); - if (val == sum) + + if (old_value == sum) return true; - if (ctx.rand(20) == 0) - return update(v, sum); - else { - auto const& [c, w] = coeffs[ctx.rand(coeffs.size())]; - num_t delta = sum - val; - bool is_real = m_vars[w].m_sort == var_sort::REAL; - bool round_down = ctx.rand(2) == 0; - num_t new_value = value(w) + (is_real ? delta / c : round_down ? div(delta, c) : div(delta + c - 1, c)); - return update(w, new_value); - } + + m_updates.reset(); +// display(verbose_stream(), v) << " "; +// verbose_stream() << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << old_value << " " << sum << "\n"; + + for (auto const& [coeff, w] : coeffs) + add_update(v, divide(w, old_value - sum, coeff)); + + if (apply_update()) + return eval_is_correct(v); + + m_updates.reset(); + for (auto const& [coeff, w] : coeffs) + if (is_mul(w)) { + auto const& [w1, c, monomial] = get_mul(w); + for (auto [w1, p] : monomial) + add_reset_update(w1); + } + else + add_reset_update(w); + + if (apply_update()) + return eval_is_correct(v); + + return update(v, sum); } template @@ -1155,7 +1380,9 @@ namespace sls { auto const& [v, coeff, monomial] = md; if (!is_int(v)) return false; - for (auto [c, v, p] : monomial_iterator(md)) { + for (auto [c, v, p] : monomial_iterator(md)) { + if (c == 0) + continue; num_t val1 = div(value(v), c); if (val1 < 0 && p % 2 == 0) continue; @@ -1290,7 +1517,7 @@ namespace sls { product *= power_of(value(v), p); if (product == val) return true; - verbose_stream() << "repair mul " << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << val << "(product: " << product << ")\n"; + IF_VERBOSE(10, verbose_stream() << "repair mul " << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << val << "(product: " << product << ")\n"); unsigned sz = monomial.size(); if (ctx.rand(20) == 0) return update(v, product); @@ -1308,8 +1535,13 @@ namespace sls { return true; else if (repair_mul_factors(md)) return true; + else if (repair_mul_one(md)) + return true; else { - NOT_IMPLEMENTED_YET(); + IF_VERBOSE(0, verbose_stream() << "repair mul " << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << val << "(product: " << product << ")\n"); + + //NOT_IMPLEMENTED_YET(); + return false; } return false; } @@ -1323,6 +1555,8 @@ namespace sls { num_t n = div(val, coeff); if (!divides(coeff, val) && ctx.rand(2) == 0) n = div(val + coeff - 1, coeff); + if (n == 0) + return false; auto const& fs = factor(abs(n)); unsigned sz = monomial.size(); vector coeffs(sz, num_t(1)); @@ -1463,73 +1697,141 @@ namespace sls { } template - double arith_base::reward(sat::literal lit) { - if (m_dscore_mode) - return dscore_reward(lit.var()); + double arith_base::compute_score(var_t x, num_t const& delta) { + int result = 0; + for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) { + bool old_sign = sign(bv); + auto dtt_old = dtt(old_sign, *atom(bv)); + auto dtt_new = dtt(old_sign, *atom(bv), coeff, delta); + if (dtt_new == 0 && dtt_old != 0) + result += 1; + if (dtt_new != 0 && dtt_old == 0) + result -= 1; + } + + if (result < 0) + return 0.1; + else if (result == 0) + return 0.2; else - return dtt_reward(lit); + return result; } template - double arith_base::dtt_reward(sat::literal lit) { - auto* ineq = atom(lit.var()); - if (!ineq) - return -1; - num_t new_value; - double max_result = -100; - unsigned n = 0, mult = 2; - double sum_prob = 0; - unsigned i = 0; - m_probs.reserve(ineq->m_args.size()); - for (auto const& [coeff, x] : ineq->m_args) { - double result = 0; - double prob = 0; - if (is_fixed(x)) - prob = 0; - else if (!cm(*ineq, x, coeff, new_value)) - prob = 0.5; - else { + num_t arith_base::mul_value_without(var_t m, var_t x) { + auto const& vi = m_vars[m]; + auto const& [w, coeff, monomial] = m_muls[vi.m_def_idx]; + SASSERT(m == w); + num_t r(coeff); + for (auto [y, p] : monomial) + if (x != y) + for (unsigned i = 0; i < p; ++i) + r *= value(y); + return r; + } - auto old_value = m_vars[x].m_value; - for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) { - bool old_sign = sign(bv); - auto dtt_old = dtt(old_sign, *atom(bv)); - auto dtt_new = dtt(old_sign, *atom(bv), coeff, old_value, new_value); - if (dtt_new == 0 && dtt_old != 0) - result += 1; - if (dtt_new != 0 && dtt_old == 0) - result -= 1; - } - - if (result > max_result || max_result == -100 || (result == max_result && (ctx.rand(++n) == 0))) - max_result = result; - - if (result < 0) - prob = 0.1; - else if (result == 0) - prob = 0.2; - else - prob = result; - - } - // verbose_stream() << "prob v" << x << " " << prob << "\n"; - m_probs[i++] = prob; - sum_prob += prob; + template + bool arith_base::is_linear(var_t x, vector const& nl, num_t& b) { + if (nl.size() == 1 && nl[0].v == x) { + b = nl[0].coeff; + return true; } - double lim = sum_prob * ((double)ctx.rand() / random_gen().max_value()); - do { - lim -= m_probs[--i]; - } - while (lim >= 0 && i > 0); + b = 0; + for (auto const& [v, c, p] : nl) { + if (p > 1) + return false; + if (x == v) + b += c; + else + b += c * mul_value_without(v, x); + } + return b != 0; + } - ineq->m_var_to_flip = ineq->m_args[i].second; + template + bool arith_base::is_quadratic(var_t x, vector const& nl, num_t& a, num_t& b) { + a = 0; + b = 0; + for (auto const& [v, c, p] : nl) { + if (p == 1) { + if (x == v) + b += c; + else + b += c * mul_value_without(v, x); + } + else if (p == 2) { + SASSERT(v != x); + a += c * mul_value_without(v, x); + } + else + return false; + } + return a != 0 || b != 0; + } - return max_result; + template + void arith_base::find_moves(sat::literal lit) { + m_updates.reset(); + auto* ineq = atom(lit.var()); + num_t a, b; + if (!ineq) + return; + for (auto const& [x, nl] : ineq->m_nonlinear) { + if (is_fixed(x)) + continue; + if (is_linear(x, nl, b)) + find_linear_moves(*ineq, x, b, ineq->m_args_value); + else if (is_quadratic(x, nl, a, b)) + find_quadratic_moves(*ineq, x, a, b, ineq->m_args_value); + else + ; + } + } + + template + void arith_base::add_reset_update(var_t x) { + m_last_delta = 0; + if (is_fixed(x)) + return; + auto const& vi = m_vars[x]; + auto const& lo = vi.m_lo; + auto const& hi = vi.m_hi; + auto new_value = num_t(ctx.rand(5) - 2); + if (lo && lo->value > new_value) + new_value = lo->value; + else if (hi && hi->value < new_value) + new_value = hi->value; + if (new_value != value(x)) + add_update(x, new_value - value(x)); + else { + add_update(x, num_t(1)); + add_update(x, -num_t(1)); + } + } + + template + void arith_base::find_reset_moves(sat::literal lit) { + auto* ineq = atom(lit.var()); + num_t a, b; + if (!ineq) + return; + for (auto const& [x, nl] : ineq->m_nonlinear) + add_reset_update(x); + + IF_VERBOSE(10, + if (m_updates.empty()) { + verbose_stream() << *ineq << "\n"; + for (auto const& [x, nl] : ineq->m_nonlinear) { + auto const& vi = m_vars[x]; + display(verbose_stream() << "v" << x << "\n", x) << "\n"; + } + } + verbose_stream() << "RESET moves num updates: " << lit << " " << m_updates.size() << "\n"); } template num_t arith_base::power_of(num_t x, unsigned k) { - num_t r = x; + num_t r(1); while (k > 1) { if (k % 2 == 1) { r = x * r; @@ -1565,6 +1867,8 @@ namespace sls { template vector const& arith_base::factor(num_t n) { m_factors.reset(); + if (n == 0) + return m_factors; for (auto d : { 2, 3, 5 }) { while (mod(n, num_t(d)) == 0) { m_factors.push_back(num_t(d)); @@ -1586,31 +1890,6 @@ namespace sls { return m_factors; } - - template - double arith_base::dscore_reward(sat::bool_var bv) { - m_dscore_mode = false; - bool old_sign = sign(bv); - sat::literal litv(bv, old_sign); - auto* ineq = atom(bv); - if (!ineq) - return 0; - SASSERT(ineq->is_true() != old_sign); - num_t new_value; - - for (auto const& [coeff, v] : ineq->m_args) { - double result = 0; - if (cm(*ineq, v, coeff, new_value)) - result = dscore(v, new_value); - // just pick first positive, or pick a max? - if (result > 0) { - ineq->m_var_to_flip = v; - return result; - } - } - return 0; - } - // switch to dscore mode template void arith_base::on_rescale() { @@ -1626,19 +1905,17 @@ namespace sls { template void arith_base::check_ineqs() { - auto check_bool_var = [&](sat::bool_var bv) { + for (unsigned bv = 0; bv < ctx.num_bool_vars(); ++bv) { auto const* ineq = atom(bv); if (!ineq) - return; + continue; num_t d = dtt(sign(bv), *ineq); sat::literal lit(bv, sign(bv)); if (ctx.is_true(lit) != (d == 0)) { verbose_stream() << "invalid assignment " << bv << " " << *ineq << "\n"; } VERIFY(ctx.is_true(lit) == (d == 0)); - }; - for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) - check_bool_var(v); + } } template @@ -1656,20 +1933,20 @@ namespace sls { } template - void arith_base::set_value(expr* e, expr* v) { + bool arith_base::set_value(expr* e, expr* v) { if (!a.is_int_real(e)) - return; + return false; var_t w = m_expr2var.get(e->get_id(), UINT_MAX); if (w == UINT_MAX) w = mk_term(e); num_t n; if (!is_num(v, n)) - return; + return false; // verbose_stream() << "set value " << w << " " << mk_bounded_pp(e, m) << " " << n << " " << value(w) << "\n"; if (n == value(w)) - return; - update(w, n); + return true; + return update(w, n); } template @@ -1719,6 +1996,34 @@ namespace sls { return true; } + template + std::ostream& arith_base::display(std::ostream& out, add_def const& ad) const { + bool first = true; + for (auto [c, w] : ad.m_args) { + if (first && c == 1) + ; + else if (first && c == -1) + out << "-"; + else if (first) + out << c << "*"; + else if (c == 1) + out << " + "; + else if (c == - 1) + out << " - "; + else if (c > 0) + out << " + " << c << "*"; + else + out << " - " << -c << "*"; + first = false; + out << "v" << w; + } + if (ad.m_coeff > 0) + out << " + " << ad.m_coeff; + else if (ad.m_coeff < 0) + out << " - " << -ad.m_coeff; + return out; + } + template std::ostream& arith_base::display(std::ostream& out, var_t v) const { auto const& vi = m_vars[v]; @@ -1737,9 +2042,28 @@ namespace sls { out << ")"; out << " "; } - out << mk_bounded_pp(vi.m_expr, m) << " : "; + out << mk_bounded_pp(vi.m_expr, m) << " "; + if (is_add(v)) + display(out << "add: ", get_add(v)) << " "; + + if (!vi.m_adds.empty()) { + out << " adds: "; + for (auto v : vi.m_adds) + out << "v" << m_adds[v].m_var << " "; + out << " "; + } + + if (!vi.m_muls.empty()) { + out << " muls: "; + for (auto v : vi.m_muls) + out << "v" << m_muls[v].m_var << " "; + out << " "; + } + + if (!vi.m_bool_vars.empty()) + out << " bool: "; for (auto [c, bv] : vi.m_bool_vars) - out << c << "@" << bv << " "; + out << c << "@" << bv << " "; return out; } @@ -1764,15 +2088,9 @@ namespace sls { out << "\n"; } - for (auto ad : m_adds) { - out << "v" << ad.m_var << " := "; - bool first = true; - for (auto [c, w] : ad.m_args) - out << (first?"":" + ") << c << "* v" << w; - if (ad.m_coeff != 0) - out << " + " << ad.m_coeff; - out << "\n"; - } + for (auto ad : m_adds) + display(out, ad) << "\n"; + for (auto od : m_ops) { out << "v" << od.m_var << " := "; out << "v" << od.m_arg1 << " op-" << od.m_op << " v" << od.m_arg2 << "\n"; @@ -1780,6 +2098,71 @@ namespace sls { return out; } + template + bool arith_base::eval_is_correct(var_t v) { + auto const& vi = m_vars[v]; + if (vi.m_def_idx == UINT_MAX) + return true; + TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); + switch (vi.m_op) { + case arith_op_kind::LAST_ARITH_OP: + break; + case arith_op_kind::OP_ADD: { + auto ad = m_adds[vi.m_def_idx]; + num_t sum(ad.m_coeff); + for (auto [c, w] : ad.m_args) + sum += c * value(w); + return sum == value(v); + } + case arith_op_kind::OP_MUL: { + auto md = m_muls[vi.m_def_idx]; + num_t prod(md.m_coeff); + for (auto [w, p] : md.m_monomial) + prod *= power_of(value(w), p); + return prod == value(v); + } + case arith_op_kind::OP_MOD: { + auto od = m_ops[vi.m_def_idx]; + return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : mod(value(od.m_arg1), value(od.m_arg2))); + } + case arith_op_kind::OP_REM: { + auto od = m_ops[vi.m_def_idx]; + return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : mod(value(od.m_arg1), value(od.m_arg2))); + } + case arith_op_kind::OP_POWER: { + auto od = m_ops[vi.m_def_idx]; + NOT_IMPLEMENTED_YET(); + break; + } + case arith_op_kind::OP_IDIV: { + auto od = m_ops[vi.m_def_idx]; + return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : div(value(od.m_arg1), value(od.m_arg2))); + } + case arith_op_kind::OP_DIV: { + auto od = m_ops[vi.m_def_idx]; + return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : value(od.m_arg1) / value(od.m_arg2)); + } + case arith_op_kind::OP_ABS: { + auto od = m_ops[vi.m_def_idx]; + return value(v) == abs(value(od.m_arg1)); + } + case arith_op_kind::OP_TO_INT: { + auto od = m_ops[vi.m_def_idx]; + NOT_IMPLEMENTED_YET(); + break; + } + case arith_op_kind::OP_TO_REAL: { + auto od = m_ops[vi.m_def_idx]; + NOT_IMPLEMENTED_YET(); + break; + } + default: { + NOT_IMPLEMENTED_YET(); + break; + } + } + return true; + } template void arith_base::invariant() { @@ -1789,49 +2172,43 @@ namespace sls { invariant(*ineq); } auto& out = verbose_stream(); - for (auto md : m_muls) { - auto const& [w, coeff, monomial] = md; - num_t prod(coeff); - for (auto [v,p] : monomial) - prod *= power_of(value(v), p); - //verbose_stream() << "check " << w << " " << monomial << "\n"; - if (prod != value(w)) { - out << prod << " " << value(w) << "\n"; - out << "v" << w << " := "; - for (auto [w, p] : monomial) { - out << "v" << w; - if (p > 1) - out << "^" << p; - out << " "; + for (var_t v = 0; v < m_vars.size(); ++v) { + if (!eval_is_correct(v)) { + + display(out); + display(out, v) << "\n"; + out << mk_bounded_pp(m_vars[v].m_expr, m) << "\n"; + + if (is_mul(v)) { + auto const& [w, coeff, monomial] = get_mul(v); + num_t prod(coeff); + for (auto [v, p] : monomial) + prod *= power_of(value(v), p); + out << "product " << prod << " value " << value(w) << "\n"; + out << "coeff " << coeff << "\n"; + out << "v" << w << " := "; + for (auto [w, p] : monomial) { + out << "(v" << w; + if (p > 1) + out << "^" << p; + out << " := " << value(w); + out << ") "; + } + out << "\n"; } - out << "\n"; + else if (is_add(v)) { + auto const& ad = get_add(v); + out << "v" << ad.m_var << " := "; + display(out, ad) << "\n"; + } + UNREACHABLE(); } - SASSERT(prod == value(w)); - - } - for (auto ad : m_adds) { - //out << "check add " << ad.m_var << "\n"; - num_t sum(ad.m_coeff); - for (auto [c, w] : ad.m_args) - sum += c * value(w); - if (sum != value(ad.m_var)) { - - - out << "v" << ad.m_var << " := "; - bool first = true; - for (auto [c, w] : ad.m_args) - out << (first ? "" : " + ") << c << "* v" << w; - if (ad.m_coeff != 0) - out << " + " << ad.m_coeff; - out << "\n"; - } - SASSERT(sum == value(ad.m_var)); } } template void arith_base::invariant(ineq const& i) { - num_t val(0); + num_t val = i.m_coeff; for (auto const& [c, v] : i.m_args) val += c * value(v); //verbose_stream() << "invariant " << i << "\n"; @@ -1843,6 +2220,16 @@ namespace sls { template void arith_base::mk_model(model& mdl) { } + + template + void arith_base::collect_statistics(statistics& st) const { + st.update("sls-arith-flips", m_stats.m_num_steps); + } + + template + void arith_base::reset_statistics() { + m_stats.m_num_steps = 0; + } } template class sls::arith_base>; diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index a3dba2791..c304e2c40 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -45,7 +45,7 @@ namespace sls { }; struct stats { - unsigned m_num_flips = 0; + unsigned m_num_steps = 0; }; public: @@ -53,11 +53,16 @@ namespace sls { vector> m_args; num_t m_coeff{ 0 }; }; + struct nonlinear_coeff { + var_t v; // variable or multiplier containing x + num_t coeff; // coeff of v in inequality + unsigned p; // power + }; // encode args <= bound, args = bound, args < bound - struct ineq : public linear_term { + struct ineq : public linear_term { + vector>> m_nonlinear; ineq_kind m_op = ineq_kind::LE; num_t m_args_value; - unsigned m_var_to_flip = UINT_MAX; bool is_true() const; std::ostream& display(std::ostream& out) const; @@ -76,6 +81,27 @@ namespace sls { unsigned_vector m_muls; unsigned_vector m_adds; optional m_lo, m_hi; + num_t m_range{ 100000000 }; + bool in_range(num_t const& n) const { + if (-m_range < n && n < m_range) + return true; + if (m_lo && !m_hi) + return n < m_lo->value + m_range; + if (!m_lo && m_hi) + return n > m_hi->value - m_range; + return false; + } + unsigned m_tabu_pos = 0, m_tabu_neg = 0; + unsigned m_last_pos = 0, m_last_neg = 0; + bool is_tabu(unsigned step, num_t const& delta) { + return (delta > 0 ? m_tabu_pos : m_tabu_neg) > step; + } + void set_step(unsigned step, unsigned tabu_step, num_t const& delta) { + if (delta > 0) + m_tabu_pos = tabu_step, m_last_pos = step; + else + m_tabu_neg = tabu_step, m_last_neg = step; + } }; struct mul_def { @@ -93,6 +119,12 @@ namespace sls { arith_op_kind m_op = LAST_ARITH_OP; unsigned m_arg1, m_arg2; }; + + struct var_change { + unsigned m_var; + num_t m_delta; + double m_score; + }; stats m_stats; config m_config; @@ -104,6 +136,10 @@ namespace sls { unsigned_vector m_expr2var; svector m_probs; bool m_dscore_mode = false; + vector m_updates; + var_t m_last_var = 0; + num_t m_last_delta { 0 }; + bool m_tabu = false; arith_util a; void invariant(); @@ -111,6 +147,7 @@ namespace sls { unsigned get_num_vars() const { return m_vars.size(); } + bool eval_is_correct(var_t v); bool repair_mul_one(mul_def const& md); bool repair_power(mul_def const& md); bool repair_mul_factors(mul_def const& md); @@ -126,10 +163,15 @@ namespace sls { bool repair_abs(op_def const& od); bool repair_to_int(op_def const& od); bool repair_to_real(op_def const& od); - void repair(sat::literal lit, ineq const& ineq); - bool repair_eq(sat::literal lit, ineq const& ineq); + bool repair(sat::literal lit); bool in_bounds(var_t v, num_t const& value); bool is_fixed(var_t v); + bool is_linear(var_t x, vector const& nlc, num_t& b); + bool is_quadratic(var_t x, vector const& nlc, num_t& a, num_t& b); + num_t mul_value_without(var_t m, var_t x); + + void add_update(var_t v, num_t delta); + bool is_permitted_update(var_t v, num_t& delta); vector m_factors; vector const& factor(num_t n); @@ -174,25 +216,32 @@ namespace sls { monomials monomial_iterator(mul_def const& md) { return monomials(*this, md); } - double reward(sat::literal lit); + // double reward(sat::literal lit); bool sign(sat::bool_var v) const { return !ctx.is_true(sat::literal(v, false)); } ineq* atom(sat::bool_var bv) const { return m_bool_vars.get(bv, nullptr); } num_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); } num_t dtt(bool sign, num_t const& args_value, ineq const& ineq) const; num_t dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const; - num_t dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& old_value, num_t const& new_value) const; + num_t dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& delta) const; num_t dts(unsigned cl, var_t v, num_t const& new_value) const; num_t compute_dts(unsigned cl) const; - bool cm(ineq const& ineq, var_t v, num_t& new_value); - bool cm(ineq const& ineq, var_t v, num_t const& coeff, num_t& new_value); - int cm_score(var_t v, num_t const& new_value); + + bool is_mul(var_t v) const { return m_vars[v].m_op == arith_op_kind::OP_MUL; } + bool is_add(var_t v) const { return m_vars[v].m_op == arith_op_kind::OP_ADD; } + mul_def const& get_mul(var_t v) const { SASSERT(is_mul(v)); return m_muls[m_vars[v].m_def_idx]; } + add_def const& get_add(var_t v) const { SASSERT(is_add(v)); return m_adds[m_vars[v].m_def_idx]; } + bool update(var_t v, num_t const& new_value); - double dscore_reward(sat::bool_var v); - double dtt_reward(sat::literal lit); - double dscore(var_t v, num_t const& new_value) const; + bool apply_update(); + void find_moves(sat::literal lit); + void find_reset_moves(sat::literal lit); + void add_reset_update(var_t v); + void find_linear_moves(ineq const& i, var_t x, num_t const& coeff, num_t const& sum); + void find_quadratic_moves(ineq const& i, var_t x, num_t const& a, num_t const& b, num_t const& sum); + double compute_score(var_t x, num_t const& delta); void save_best_values(); - bool solve_eq_pairs(ineq const& ineq); + bool solve_eq_pairs(var_t v, ineq const& ineq); bool solve_eq_pairs(num_t const& a, var_t x, num_t const& b, var_t y, num_t const& r); var_t mk_var(expr* e); @@ -203,6 +252,8 @@ namespace sls { ineq& new_ineq(ineq_kind op, num_t const& bound); void init_ineq(sat::bool_var bv, ineq& i); num_t divide(var_t v, num_t const& delta, num_t const& coeff); + num_t divide_floor(var_t v, num_t const& a, num_t const& b); + num_t divide_ceil(var_t v, num_t const& a, num_t const& b); void init_bool_var_assignment(sat::bool_var v); @@ -219,11 +270,12 @@ namespace sls { void add_lt(var_t v, num_t const& n); void add_gt(var_t v, num_t const& n); std::ostream& display(std::ostream& out, var_t v) const; + std::ostream& display(std::ostream& out, add_def const& ad) const; public: arith_base(context& ctx); ~arith_base() override {} void register_term(expr* e) override; - void set_value(expr* e, expr* v) override; + bool set_value(expr* e, expr* v) override; expr_ref get_value(expr* e) override; void initialize() override; void propagate_literal(sat::literal lit) override; @@ -236,6 +288,8 @@ namespace sls { void on_restart() override; std::ostream& display(std::ostream& out) const override; void mk_model(model& mdl) override; + void collect_statistics(statistics& st) const override; + void reset_statistics() override; }; diff --git a/src/ast/sls/sls_arith_plugin.cpp b/src/ast/sls/sls_arith_plugin.cpp index 92a94ac37..120930c02 100644 --- a/src/ast/sls/sls_arith_plugin.cpp +++ b/src/ast/sls/sls_arith_plugin.cpp @@ -22,30 +22,42 @@ Author: namespace sls { #define WITH_FALLBACK(_fn_) \ - if (!m_arith) { \ + if (m_arith64) { \ try {\ return m_arith64->_fn_;\ }\ catch (overflow_exception&) {\ + throw;\ init_backup();\ }\ }\ return m_arith->_fn_; \ +#define APPLY_BOTH(_fn_) \ + if (m_arith64) { \ + try {\ + m_arith64->_fn_;\ + }\ + catch (overflow_exception&) {\ + throw;\ + init_backup();\ + }\ + }\ + m_arith->_fn_; \ + arith_plugin::arith_plugin(context& ctx) : plugin(ctx), m_shared(ctx.get_manager()) { m_arith64 = alloc(arith_base>, ctx); - m_fid = m_arith64->fid(); - init_backup(); + m_arith = alloc(arith_base, ctx); + m_fid = m_arith->fid(); } void arith_plugin::init_backup() { - m_arith = alloc(arith_base, ctx); - m_arith->initialize(); + m_arith64 = nullptr; } void arith_plugin::register_term(expr* e) { - WITH_FALLBACK(register_term(e)); + APPLY_BOTH(register_term(e)); } expr_ref arith_plugin::get_value(expr* e) { @@ -53,10 +65,7 @@ namespace sls { } void arith_plugin::initialize() { - if (m_arith) - m_arith->initialize(); - else - m_arith64->initialize(); + APPLY_BOTH(initialize()); } void arith_plugin::propagate_literal(sat::literal lit) { @@ -68,38 +77,26 @@ namespace sls { } bool arith_plugin::is_sat() { - if (m_arith) - return m_arith->is_sat(); - else - return m_arith64->is_sat(); + WITH_FALLBACK(is_sat()); } void arith_plugin::on_rescale() { - if (m_arith) - m_arith->on_rescale(); - else - m_arith64->on_rescale(); + APPLY_BOTH(on_rescale()); } void arith_plugin::on_restart() { - if (m_arith) - m_arith->on_restart(); - else - m_arith64->on_restart(); + WITH_FALLBACK(on_restart()); } - std::ostream& arith_plugin::display(std::ostream& out) const { - if (m_arith) - return m_arith->display(out); + std::ostream& arith_plugin::display(std::ostream& out) const { + if (m_arith64) + return m_arith64->display(out); else - return m_arith64->display(out); + return m_arith->display(out); } void arith_plugin::mk_model(model& mdl) { - if (m_arith) - m_arith->mk_model(mdl); - else - m_arith64->mk_model(mdl); + WITH_FALLBACK(mk_model(mdl)); } bool arith_plugin::repair_down(app* e) { @@ -114,7 +111,21 @@ namespace sls { WITH_FALLBACK(repair_literal(lit)); } - void arith_plugin::set_value(expr* e, expr* v) { - WITH_FALLBACK(set_value(e, v)); + bool arith_plugin::set_value(expr* e, expr* v) { + WITH_FALLBACK(set_value(e, v)); + } + + void arith_plugin::collect_statistics(statistics& st) const { + if (m_arith64) + m_arith64->collect_statistics(st); + else + m_arith->collect_statistics(st); + } + + void arith_plugin::reset_statistics() { + if (m_arith) + m_arith->reset_statistics(); + if (m_arith64) + m_arith64->reset_statistics(); } } diff --git a/src/ast/sls/sls_arith_plugin.h b/src/ast/sls/sls_arith_plugin.h index 99951a22d..5c8d4b245 100644 --- a/src/ast/sls/sls_arith_plugin.h +++ b/src/ast/sls/sls_arith_plugin.h @@ -44,7 +44,10 @@ namespace sls { void on_restart() override; std::ostream& display(std::ostream& out) const override; void mk_model(model& mdl) override; - void set_value(expr* e, expr* v) override; + bool set_value(expr* e, expr* v) override; + + void collect_statistics(statistics& st) const override; + void reset_statistics() override; }; } diff --git a/src/ast/sls/sls_basic_plugin.cpp b/src/ast/sls/sls_basic_plugin.cpp index d966ee7b3..cbb77dad0 100644 --- a/src/ast/sls/sls_basic_plugin.cpp +++ b/src/ast/sls/sls_basic_plugin.cpp @@ -86,11 +86,11 @@ namespace sls { return out; } - void basic_plugin::set_value(expr* e, expr* v) { + bool basic_plugin::set_value(expr* e, expr* v) { if (!is_basic(e)) - return; + return false; SASSERT(m.is_true(v) || m.is_false(v)); - set_value(e, m.is_true(v)); + return set_value(e, m.is_true(v)); } bool basic_plugin::bval1(app* e) const { @@ -229,6 +229,7 @@ namespace sls { case OP_XOR: NOT_IMPLEMENTED_YET(); case OP_ITE: + NOT_IMPLEMENTED_YET(); case OP_DISTINCT: NOT_IMPLEMENTED_YET(); @@ -278,8 +279,35 @@ namespace sls { } bool basic_plugin::try_repair_ite(app* e, unsigned i) { - if (!m.is_bool(e)) + if (m.is_bool(e)) + return try_repair_ite_bool(e, i); + else + return try_repair_ite_nonbool(e, i); + } + + bool basic_plugin::try_repair_ite_nonbool(app* e, unsigned i) { + auto child = e->get_arg(i); + auto cond = e->get_arg(0); + bool c = bval0(cond); + + if (i == 0) { + auto eval = ctx.get_value(e); + auto eval1 = ctx.get_value(e->get_arg(1)); + auto eval2 = ctx.get_value(e->get_arg(2)); + if (eval == eval1 && eval == eval2) + return true; + if (eval == eval1) + return set_value(cond, true); + if (eval == eval2) + return set_value(cond, false); return false; + } + if (c != (i == 1)) + return false; + return ctx.set_value(child, ctx.get_value(e)); + } + + bool basic_plugin::try_repair_ite_bool(app* e, unsigned i) { auto child = e->get_arg(i); auto cond = e->get_arg(0); bool c = bval0(cond); diff --git a/src/ast/sls/sls_basic_plugin.h b/src/ast/sls/sls_basic_plugin.h index 82a7835c7..fc36ad629 100644 --- a/src/ast/sls/sls_basic_plugin.h +++ b/src/ast/sls/sls_basic_plugin.h @@ -29,6 +29,8 @@ namespace sls { bool try_repair_eq(app* e, unsigned i); bool try_repair_xor(app* e, unsigned i); bool try_repair_ite(app* e, unsigned i); + bool try_repair_ite_nonbool(app* e, unsigned i); + bool try_repair_ite_bool(app* e, unsigned i); bool try_repair_implies(app* e, unsigned i); bool try_repair_distinct(app* e, unsigned i); bool set_value(expr* e, bool b); @@ -55,7 +57,9 @@ namespace sls { void on_restart() override {} std::ostream& display(std::ostream& out) const override; void mk_model(model& mdl) override {} - void set_value(expr* e, expr* v) override; + bool set_value(expr* e, expr* v) override; + void collect_statistics(statistics& st) const override {} + void reset_statistics() override {} }; } diff --git a/src/ast/sls/sls_bv_plugin.cpp b/src/ast/sls/sls_bv_plugin.cpp index 9fc4d4b62..1f3be03e9 100644 --- a/src/ast/sls/sls_bv_plugin.cpp +++ b/src/ast/sls/sls_bv_plugin.cpp @@ -89,14 +89,14 @@ namespace sls { return m_eval.display(out); } - void bv_plugin::set_value(expr* e, expr* v) { + bool bv_plugin::set_value(expr* e, expr* v) { if (!bv.is_bv(e)) - return; + return false; rational val; VERIFY(bv.is_numeral(v, val)); auto& w = m_eval.eval(to_app(e)); w.set_value(w.eval, val); - w.commit_eval(); + return w.commit_eval(); } bool bv_plugin::repair_down(app* e) { diff --git a/src/ast/sls/sls_bv_plugin.h b/src/ast/sls/sls_bv_plugin.h index ead5d7d8d..9cee67ac4 100644 --- a/src/ast/sls/sls_bv_plugin.h +++ b/src/ast/sls/sls_bv_plugin.h @@ -52,7 +52,9 @@ namespace sls { void on_restart() override {} std::ostream& display(std::ostream& out) const override; void mk_model(model& mdl) override {} - void set_value(expr* e, expr* v) override; + bool set_value(expr* e, expr* v) override; + void collect_statistics(statistics& st) const override {} + void reset_statistics() override {} }; } diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 0525bfb39..64fe69ca9 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -59,13 +59,13 @@ namespace sls { // Use timestamps to make it incremental. // init(); - verbose_stream() << "check " << unsat().size() << "\n"; - while (unsat().empty()) { + //verbose_stream() << "check " << unsat().size() << "\n"; + while (unsat().empty() && m.inc()) { propagate_boolean_assignment(); - verbose_stream() << "propagate " << unsat().size() << " " << m_new_constraint << "\n"; + // verbose_stream() << "propagate " << unsat().size() << " " << m_new_constraint << "\n"; // display(verbose_stream()); @@ -73,7 +73,7 @@ namespace sls { if (m_new_constraint || !unsat().empty()) return l_undef; - verbose_stream() << unsat().size() << " " << m_new_constraint << "\n"; + //verbose_stream() << unsat().size() << " " << m_new_constraint << "\n"; if (all_of(m_plugins, [&](auto* p) { return !p || p->is_sat(); })) { model_ref mdl = alloc(model, m); @@ -84,7 +84,7 @@ namespace sls { if (p) p->mk_model(*mdl); s.on_model(mdl); - verbose_stream() << *mdl << "\n"; + // verbose_stream() << *mdl << "\n"; TRACE("sls", display(tout)); return l_true; } @@ -109,7 +109,7 @@ namespace sls { if (is_app(e)) { auto p = m_plugins.get(get_fid(e), nullptr); if (p && !p->repair_down(to_app(e)) && !m_repair_up.contains(e->get_id())) { - IF_VERBOSE(0, verbose_stream() << "revert repair: " << mk_bounded_pp(e, m) << "\n"); + IF_VERBOSE(3, verbose_stream() << "revert repair: " << mk_bounded_pp(e, m) << "\n"); m_repair_up.insert(e->get_id()); } } @@ -192,10 +192,11 @@ namespace sls { } - void context::set_value(expr * e, expr * v) { + bool context::set_value(expr * e, expr * v) { for (auto p : m_plugins) - if (p) - p->set_value(e, v); + if (p && p->set_value(e, v)) + return true; + return false; } bool context::is_relevant(expr* e) { @@ -377,4 +378,16 @@ namespace sls { return out; } + + void context::collect_statistics(statistics& st) const { + for (auto p : m_plugins) + if (p) + p->collect_statistics(st); + } + + void context::reset_statistics() { + for (auto p : m_plugins) + if (p) + p->reset_statistics(); + } } diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h index ec6658954..6a3bebacc 100644 --- a/src/ast/sls/sls_context.h +++ b/src/ast/sls/sls_context.h @@ -18,6 +18,7 @@ Author: #include "util/sat_literal.h" #include "util/sat_sls.h" +#include "util/statistics.h" #include "ast/ast.h" #include "model/model.h" #include "util/scoped_ptr_vector.h" @@ -50,7 +51,9 @@ namespace sls { virtual void on_restart() {}; virtual std::ostream& display(std::ostream& out) const = 0; virtual void mk_model(model& mdl) = 0; - virtual void set_value(expr* e, expr* v) = 0; + virtual bool set_value(expr* e, expr* v) = 0; + virtual void collect_statistics(statistics& st) const = 0; + virtual void reset_statistics() = 0; }; using clause = ptr_iterator; @@ -156,7 +159,7 @@ namespace sls { // Between plugin solvers expr_ref get_value(expr* e); - void set_value(expr* e, expr* v); + bool set_value(expr* e, expr* v); void new_value_eh(expr* e); bool is_true(expr* e); bool is_fixed(expr* e); @@ -166,5 +169,8 @@ namespace sls { ast_manager& get_manager() { return m; } std::ostream& display(std::ostream& out) const; + void collect_statistics(statistics& st) const; + void reset_statistics(); + }; } diff --git a/src/ast/sls/sls_euf_plugin.h b/src/ast/sls/sls_euf_plugin.h index e0871bb77..fd0aa7266 100644 --- a/src/ast/sls/sls_euf_plugin.h +++ b/src/ast/sls/sls_euf_plugin.h @@ -46,11 +46,14 @@ namespace sls { void register_term(expr* e) override; std::ostream& display(std::ostream& out) const override; void mk_model(model& mdl) override; - void set_value(expr* e, expr* v) override {} + bool set_value(expr* e, expr* v) override { return false; } void repair_up(app* e) override {} bool repair_down(app* e) override { return false; } void repair_literal(sat::literal lit) override {} + + void collect_statistics(statistics& st) const override {} + void reset_statistics() override {} }; } diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp index 00f5f2a2a..15dc55f93 100644 --- a/src/ast/sls/sls_smt_solver.cpp +++ b/src/ast/sls/sls_smt_solver.cpp @@ -88,6 +88,16 @@ namespace sls { m_new_clause_added = true; } model_ref get_model() { return m_model; } + + void collect_statistics(statistics& st) { + m_ddfw.collect_statistics(st); + m_context.collect_statistics(st); + } + + void reset_statistics() { + m_ddfw.reset_statistics(); + m_context.reset_statistics(); + } }; smt_solver::smt_solver(ast_manager& m, params_ref const& p): @@ -118,48 +128,60 @@ namespace sls { add_clause(f); IF_VERBOSE(10, m_solver_ctx->display(verbose_stream())); - return m_ddfw.check(0, nullptr); + auto r = m_ddfw.check(0, nullptr); + + return r; } void smt_solver::add_clause(expr* f) { - expr* g, * h; + expr* g, * h, * k; sat::literal_vector clause; if (m.is_not(f, g) && m.is_not(g, g)) { add_clause(g); + return; } - if (m.is_or(f)) { + bool sign = m.is_not(f, f); + if (!sign && m.is_or(f)) { clause.reset(); for (auto arg : *to_app(f)) clause.push_back(mk_literal(arg)); m_solver_ctx->add_clause(clause.size(), clause.data()); } - else if (m.is_and(f)) { + else if (!sign && m.is_and(f)) { for (auto arg : *to_app(f)) add_clause(arg); } - else if (m.is_not(f, g) && m.is_or(g)) { - for (auto arg : *to_app(g)) { + else if (sign && m.is_or(f)) { + for (auto arg : *to_app(f)) { expr_ref fml(m.mk_not(arg), m);; add_clause(fml); } } - else if (m.is_not(f, g) && m.is_and(g)) { + else if (sign && m.is_and(f)) { clause.reset(); - for (auto arg : *to_app(g)) + for (auto arg : *to_app(f)) clause.push_back(~mk_literal(arg)); m_solver_ctx->add_clause(clause.size(), clause.data()); } - else if (m.is_eq(f, g, h) && m.is_bool(g)) { + else if (m.is_iff(f, g, h)) { auto lit1 = mk_literal(g); auto lit2 = mk_literal(h); - clause.reset(); - clause.push_back(~lit1); - clause.push_back(lit2); - m_solver_ctx->add_clause(clause.size(), clause.data()); - clause.reset(); - clause.push_back(lit1); - clause.push_back(~lit2); - m_solver_ctx->add_clause(clause.size(), clause.data()); + sat::literal cls1[2] = { sign ? lit1 :~lit1, lit2 }; + sat::literal cls2[2] = { sign ? ~lit1 : lit1, ~lit2 }; + m_solver_ctx->add_clause(2, cls1); + m_solver_ctx->add_clause(2, cls2); + } + else if (m.is_ite(f, g, h, k)) { + auto lit1 = mk_literal(g); + auto lit2 = mk_literal(h); + auto lit3 = mk_literal(k); + // (g -> h) & (~g -> k) + // (g & h) | (~g & k) + // negated: (g -> ~h) & (g -> ~k) + sat::literal cls1[2] = { ~lit1, sign ? ~lit2 : lit2 }; + sat::literal cls2[2] = { lit1, sign ? ~lit3 : lit3 }; + m_solver_ctx->add_clause(2, cls1); + m_solver_ctx->add_clause(2, cls2); } else { sat::literal lit = mk_literal(f); @@ -170,6 +192,7 @@ namespace sls { sat::literal smt_solver::mk_literal(expr* e) { sat::literal lit; bool neg = false; + expr* a, * b, * c; while (m.is_not(e,e)) neg = !neg; if (m_expr2lit.find(e, lit)) @@ -197,6 +220,33 @@ namespace sls { clause.push_back(~lit); m_solver_ctx->add_clause(clause.size(), clause.data()); } + else if (m.is_iff(e, a, b)) { + lit = mk_literal(); + auto lit1 = mk_literal(a); + auto lit2 = mk_literal(b); + sat::literal cls1[3] = { ~lit, ~lit1, lit2 }; + sat::literal cls2[3] = { ~lit, lit1, ~lit2 }; + sat::literal cls3[3] = { lit, lit1, lit2 }; + sat::literal cls4[3] = { lit, ~lit1, ~lit2 }; + m_solver_ctx->add_clause(3, cls1); + m_solver_ctx->add_clause(3, cls2); + m_solver_ctx->add_clause(3, cls3); + m_solver_ctx->add_clause(3, cls4); + } + else if (m.is_ite(e, a, b, c)) { + lit = mk_literal(); + auto lit1 = mk_literal(a); + auto lit2 = mk_literal(b); + auto lit3 = mk_literal(c); + sat::literal cls1[3] = { ~lit, ~lit1, lit2 }; + sat::literal cls2[3] = { ~lit, lit1, lit3 }; + sat::literal cls3[3] = { lit, ~lit1, ~lit2 }; + sat::literal cls4[3] = { lit, lit1, ~lit3 }; + m_solver_ctx->add_clause(3, cls1); + m_solver_ctx->add_clause(3, cls2); + m_solver_ctx->add_clause(3, cls3); + m_solver_ctx->add_clause(3, cls4); + } else { sat::bool_var v = m_num_vars++; lit = sat::literal(v, false); @@ -218,4 +268,12 @@ namespace sls { std::ostream& smt_solver::display(std::ostream& out) { return m_solver_ctx->display(out); } + + void smt_solver::collect_statistics(statistics& st) { + m_solver_ctx->collect_statistics(st); + } + + void smt_solver::reset_statistics() { + m_solver_ctx->reset_statistics(); + } } diff --git a/src/ast/sls/sls_smt_solver.h b/src/ast/sls/sls_smt_solver.h index 0f70da4b2..5b7e0d62a 100644 --- a/src/ast/sls/sls_smt_solver.h +++ b/src/ast/sls/sls_smt_solver.h @@ -42,8 +42,8 @@ namespace sls { lbool check(); model_ref get_model(); void updt_params(params_ref& p) {} - void collect_statistics(statistics& st) { st.copy(m_st); } + void collect_statistics(statistics& st); std::ostream& display(std::ostream& out); - void reset_statistics() { m_st.reset(); } + void reset_statistics(); }; } diff --git a/src/util/checked_int64.h b/src/util/checked_int64.h index deb02f38d..e202f6769 100644 --- a/src/util/checked_int64.h +++ b/src/util/checked_int64.h @@ -25,6 +25,7 @@ Revision History: #include "util/z3_exception.h" #include "util/rational.h" +#include "util/mpn.h" class overflow_exception : public z3_exception { char const* msg() const override { return "checked_int64 overflow/underflow"; } @@ -106,7 +107,7 @@ public: checked_int64 & operator--() { if (CHECK && m_value == INT64_MIN) { - throw overflow_exception(); + throw overflo9w_exception(); } --m_value; return *this; @@ -134,8 +135,10 @@ public: uint64_t x = static_cast(m_value); uint64_t y = static_cast(other.m_value); int64_t r = static_cast(x - y); - if (m_value > 0 && other.m_value < 0 && r <= 0) throw overflow_exception(); - if (m_value < 0 && other.m_value > 0 && r >= 0) throw overflow_exception(); + if (m_value > 0 && other.m_value < 0 && r <= 0) + throw overflow_exception(); + if (m_value < 0 && other.m_value > 0 && r >= 0) + throw overflow_exception(); m_value = r; } else { @@ -149,13 +152,12 @@ public: if (INT_MIN < m_value && m_value <= INT_MAX && INT_MIN < other.m_value && other.m_value <= INT_MAX) { m_value *= other.m_value; } - // TBD: could be tuned by using known techniques or 128-bit arithmetic. else { - rational r(r64(m_value) * r64(other.m_value)); - if (!r.is_int64()) { + uint64_t x = m_value, y = other.m_value; + uint64_t z = x * y; + if (y != 0 && z / y != x) throw overflow_exception(); - } - m_value = r.get_int64(); + m_value = z; } } else { diff --git a/src/util/vector.h b/src/util/vector.h index d684f43eb..d552daea1 100644 --- a/src/util/vector.h +++ b/src/util/vector.h @@ -290,6 +290,10 @@ public: } } + vector(std::initializer_list const& l) { + for (auto const& t : l) + push_back(t); + } ~vector() { destroy();