From fce21981c6d466ea47ac9c7a3ff64e919527e055 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 27 Jul 2024 03:29:54 +0200 Subject: [PATCH] fixes to sls --- src/ast/sls/sat_ddfw.cpp | 50 +++- src/ast/sls/sat_ddfw.h | 3 - src/ast/sls/sls_arith_base.cpp | 455 ++++++++++++++++++++++++++----- src/ast/sls/sls_arith_base.h | 9 + src/ast/sls/sls_arith_plugin.cpp | 4 + src/ast/sls/sls_arith_plugin.h | 1 + src/ast/sls/sls_basic_plugin.cpp | 10 + src/ast/sls/sls_basic_plugin.h | 1 + src/ast/sls/sls_bv_plugin.cpp | 9 + src/ast/sls/sls_bv_plugin.h | 1 + src/ast/sls/sls_context.cpp | 13 + src/ast/sls/sls_context.h | 2 + src/ast/sls/sls_euf_plugin.h | 3 +- src/ast/sls/sls_smt_solver.cpp | 14 +- src/util/checked_int64.h | 25 ++ src/util/mpz.cpp | 1 + 16 files changed, 521 insertions(+), 80 deletions(-) diff --git a/src/ast/sls/sat_ddfw.cpp b/src/ast/sls/sat_ddfw.cpp index e9a0bb30e..f86800079 100644 --- a/src/ast/sls/sat_ddfw.cpp +++ b/src/ast/sls/sat_ddfw.cpp @@ -68,8 +68,9 @@ namespace sat { else if (should_restart()) do_restart(), m_plugin->on_restart(); else if (do_flip()); else shift_weights(), m_plugin->on_rescale(); - verbose_stream() << "steps: " << steps << " min_sz: " << m_min_sz << " unsat: " << m_unsat.size() << "\n"; + //verbose_stream() << "steps: " << steps << " min_sz: " << m_min_sz << " unsat: " << m_unsat.size() << "\n"; ++steps; + SASSERT(m_unsat.size() >= m_min_sz); } } catch (z3_exception& ex) { @@ -114,7 +115,7 @@ namespace sat { if (reward > 0 || (reward == 0 && m_rand(100) <= m_config.m_use_reward_zero_pct)) { flip(v); - if (m_unsat.size() <= m_min_sz) + if (m_unsat.size() <= m_min_sz) save_best_values(); return true; } @@ -124,32 +125,46 @@ namespace sat { template bool_var ddfw::pick_var(double& r) { double sum_pos = 0; - unsigned n = 1; + unsigned n = 1, m = 1; bool_var v0 = null_bool_var; + bool_var v1 = null_bool_var; + if (m_unsat_vars.empty()) + return null_bool_var; for (bool_var v : m_unsat_vars) { - r = uses_plugin ? plugin_reward(v) : reward(v); + r = reward(v); if (r > 0.0) sum_pos += score(r); else if (r == 0.0 && sum_pos == 0 && (m_rand() % (n++)) == 0) v0 = v; + else if (m_rand(m++) == 0) + v1 = v; } + + + if (v0 != null_bool_var && m_rand(20) == 0) + return v0; + + if (v1 != null_bool_var && m_rand(20) == 0) + return v1; + if (sum_pos > 0) { - double lim_pos = ((double) m_rand() / (1.0 + m_rand.max_value())) * sum_pos; + double lim_pos = ((double) m_rand() / (1.0 + m_rand.max_value())) * sum_pos; for (bool_var v : m_unsat_vars) { - r = uses_plugin && is_external(v) ? m_vars[v].m_last_reward : reward(v); + r = reward(v); if (r > 0) { lim_pos -= score(r); - if (lim_pos <= 0) - return v; + if (lim_pos <= 0) { + return v; + } } } } r = 0; - if (v0 != null_bool_var) - return v0; - if (m_unsat_vars.empty()) - return null_bool_var; - return m_unsat_vars.elem_at(m_rand(m_unsat_vars.size())); + + if (v0 == null_bool_var) + v0 = m_unsat_vars.elem_at(m_rand(m_unsat_vars.size())); + + return v0; } void ddfw::add(unsigned n, literal const* c) { @@ -351,13 +366,15 @@ namespace sat { break; } } + save_best_values(); } bool ddfw::should_restart() { return m_flips >= m_restart_next; } - void ddfw::do_restart() { + void ddfw::do_restart() { + verbose_stream() << "restart\n"; reinit_values(); init_clause_data(); m_restart_next += m_config.m_restart_base*get_luby(++m_restart_count); @@ -403,6 +420,11 @@ namespace sat { if (m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11) save_model(); } +#if 0 + if (m_unsat.size() <= m_min_sz) { + verbose_stream() << "unsat " << m_clauses[m_unsat[0]] << "\n"; + } +#endif if (m_unsat.size() < m_min_sz) { m_models.reset(); // skip saving the first model. diff --git a/src/ast/sls/sat_ddfw.h b/src/ast/sls/sat_ddfw.h index 7338be515..65970cbd4 100644 --- a/src/ast/sls/sat_ddfw.h +++ b/src/ast/sls/sat_ddfw.h @@ -38,7 +38,6 @@ namespace sat { virtual ~local_search_plugin() {} virtual void init_search() = 0; virtual void finish_search() = 0; - virtual double reward(bool_var v) = 0; virtual void on_rescale() = 0; virtual void on_save_model() = 0; virtual void on_restart() = 0; @@ -124,8 +123,6 @@ namespace sat { inline double& reward(bool_var v) { return m_vars[v].m_reward; } - inline double plugin_reward(bool_var v) { return is_external(v) ? (m_vars[v].m_last_reward = m_plugin->reward(v)) : reward(v); } - void set_external(bool_var v) { m_vars[v].m_external = true; } inline bool is_external(bool_var v) const { return m_vars[v].m_external; } diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 60f765386..2f1045ccc 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -32,6 +32,8 @@ namespace sls { } } + + template std::ostream& arith_base::ineq::display(std::ostream& out) const { bool first = true; @@ -141,6 +143,49 @@ namespace sls { auto argsv = ineq.m_args_value; bool solved = false; num_t delta = argsv - bound; + auto const& lo = m_vars[v].m_lo; + auto const& hi = m_vars[v].m_hi; + + if (is_fixed(v)) + return false; + + auto well_formed = [&]() { + num_t new_args = argsv + coeff * (new_value - value(v)); + if (ineq.is_true()) { + switch (ineq.m_op) { + case ineq_kind::LE: return new_args > bound; + case ineq_kind::LT: return new_args >= bound; + case ineq_kind::EQ: return new_args != bound; + } + } + else { + switch (ineq.m_op) { + case ineq_kind::LE: return new_args <= bound; + case ineq_kind::LT: return new_args < bound; + case ineq_kind::EQ: return new_args == bound; + } + } + return false; + }; + + auto move_to_bounds = [&]() { + VERIFY(well_formed()); + if (!in_bounds(v, value(v))) + return true; + if (in_bounds(v, new_value)) + return true; + if (lo && lo->value > new_value) { + new_value = lo->value; + if (!well_formed()) + new_value += 1; + } + if (hi && hi->value < new_value) { + new_value = hi->value; + if (!well_formed()) + new_value -= 1; + } + return well_formed() && in_bounds(v, new_value); + }; if (ineq.is_true()) { switch (ineq.m_op) { @@ -148,24 +193,22 @@ namespace sls { // args <= bound -> args > bound SASSERT(argsv <= bound); SASSERT(delta <= 0); - delta -= 1 + (ctx.rand(10)); - new_value = value(v) + divide(v, abs(delta), coeff); - VERIFY(argsv + coeff * (new_value - value(v)) > bound); - return true; + delta -= 1; + new_value = value(v) + divide(v, abs(delta - ctx.rand(3)), coeff); + return move_to_bounds(); case ineq_kind::LT: // args < bound -> args >= bound SASSERT(argsv <= bound); SASSERT(delta <= 0); - delta = abs(delta) + ctx.rand(10); - new_value = value(v) + divide(v, delta, coeff); + delta = abs(delta); + new_value = value(v) + divide(v, delta + ctx.rand(3), coeff); VERIFY(argsv + coeff * (new_value - value(v)) >= bound); - return true; + return move_to_bounds(); case ineq_kind::EQ: { delta = abs(delta) + 1 + ctx.rand(10); int sign = ctx.rand(2) == 0 ? 1 : -1; new_value = value(v) + sign * divide(v, abs(delta), coeff); - VERIFY(argsv + coeff * (new_value - value(v)) != bound); - return true; + return move_to_bounds(); } default: UNREACHABLE(); @@ -178,16 +221,14 @@ namespace sls { SASSERT(argsv > bound); SASSERT(delta > 0); delta += ctx.rand(10); - new_value = value(v) - divide(v, delta, coeff); - VERIFY(argsv + coeff * (new_value - value(v)) <= bound); - return true; + new_value = value(v) - divide(v, delta + ctx.rand(3), coeff); + return move_to_bounds(); case ineq_kind::LT: SASSERT(argsv >= bound); SASSERT(delta >= 0); delta += 1 + ctx.rand(10); - new_value = value(v) - divide(v, delta, coeff); - VERIFY(argsv + coeff * (new_value - value(v)) < bound); - return true; + new_value = value(v) - divide(v, delta + ctx.rand(3), coeff); + return move_to_bounds(); case ineq_kind::EQ: SASSERT(delta != 0); if (delta < 0) @@ -195,12 +236,7 @@ namespace sls { else new_value = value(v) - divide(v, delta, coeff); solved = argsv + coeff * (new_value - value(v)) == bound; - if (!solved && abs(coeff) == 1) { - verbose_stream() << "did not solve equality " << ineq << " for " << v << "\n"; - verbose_stream() << new_value << " " << value(v) << " delta " << delta << " lhs " << (argsv + coeff * (new_value - value(v))) << " bound " << bound << "\n"; - UNREACHABLE(); - } - return solved; + return solved && move_to_bounds(); default: UNREACHABLE(); break; @@ -209,6 +245,130 @@ namespace sls { return false; } + template + bool arith_base::solve_eq_pairs(ineq const& ineq) { + SASSERT(ineq.m_op == ineq_kind::EQ); + auto v = ineq.m_var_to_flip; + if (is_fixed(v)) + return false; + auto bound = -ineq.m_coeff; + auto argsv = ineq.m_args_value; + num_t a; + for (auto const& [c, w] : ineq.m_args) + if (v == w) { + a = c; + argsv -= value(v) * c; + } + if (abs(a) == 1) + return false; + verbose_stream() << "solve_eq_pairs " << ineq << " for v" << v << "\n"; + unsigned start = ctx.rand(); + for (unsigned i = 0; i < ineq.m_args.size(); ++i) { + unsigned j = (start + i) % ineq.m_args.size(); + auto const& [b, w] = ineq.m_args[j]; + if (w == v) + continue; + if (b == 1 || b == -1) + continue; + argsv -= value(w) * b; + if (solve_eq_pairs(a, v, b, w, bound - argsv)) + return true; + argsv += value(w) * b; + } + return false; + } + + // ax0 + by0 = r + // (x, y) = (x0 - k*b/g, y0 + k*a/g) + // find the min x1 >= x0 satisfying progression and where x1 >= lo(x) + // k*ab/g - k*ab/g = 0 + template + bool arith_base::solve_eq_pairs(num_t const& _a, var_t x, num_t const& _b, var_t y, num_t const& r) { + if (is_fixed(y)) + return false; + num_t x0, y0; + num_t a = _a, b = _b; + num_t g = gcd(a, b, x0, y0); + SASSERT(g >= 1); + SASSERT(g == a * x0 + b * y0); + if (!divides(g, r)) + return false; + //verbose_stream() << g << " == " << a << "*" << x0 << " + " << b << "*" << y0 << "\n"; + x0 *= div(r, g); + y0 *= div(r, g); + + //verbose_stream() << r << " == " << a << "*" << x0 << " + " << b << "*" << y0 << "\n"; + + + + auto adjust_lo = [&](num_t& x0, num_t& y0, num_t a, num_t b, optional const& lo, optional const& hi) { + if (!lo || lo->value <= x0) + return true; + // x0 + k*b/g >= lo + // k*(b/g) >= lo - x0 + // k >= (lo - x0)/(b/g) + // x1 := x0 + k*b/g + auto delta = lo->value - x0; + auto bg = abs(div(b, g)); + verbose_stream() << g << " " << bg << " " << " " << delta << "\n"; + auto k = divide(x, delta, bg); + auto x1 = x0 + k * bg; + if (hi && hi->value < x1) + return false; + x0 = x1; + y0 = y0 + k * (div(b, g) > 0 ? -div(a, g) : div(a, g)); + SASSERT(r == a * x0 + b * y0); + return true; + }; + auto adjust_hi = [&](num_t& x0, num_t& y0, num_t a, num_t b, optional const& lo, optional const& hi) { + if (!hi || hi->value >= x0) + return true; + // x0 + k*b/g <= hi + // k <= (x0 - hi)/(b/g) + auto delta = x0 - hi->value; + auto bg = abs(div(b, g)); + auto k = div(delta, bg); + auto x1 = x0 - k * bg; + if (lo && lo->value < x1) + return false; + x0 = x1; + y0 = y0 - k * (div(b, g) > 0 ? -div(a, g) : div(a, g)); + SASSERT(r == a * x0 + b * y0); + return true; + }; + auto const& lo_x = m_vars[x].m_lo; + auto const& hi_x = m_vars[x].m_hi; + + if (!adjust_lo(x0, y0, a, b, lo_x, hi_x)) + return false; + if (!adjust_hi(x0, y0, a, b, lo_x, hi_x)) + return false; + + auto const& lo_y = m_vars[y].m_lo; + auto const& hi_y = m_vars[y].m_hi; + + if (!adjust_lo(y0, x0, b, a, lo_y, hi_y)) + return false; + if (!adjust_hi(y0, x0, b, a, lo_y, hi_y)) + return false; + + if (lo_x && lo_x->value > x0) + return false; + if (hi_x && hi_x->value < x0) + return false; + + if (x0 == value(x)) + return false; + if (abs(value(x)) * 2 < abs(x0)) + return false; + if (abs(value(y)) * 2 < abs(y0)) + return false; + update(x, x0); + update(y, y0); + + return true; + } + // flip on the first positive score // it could be changed to flip on maximal positive score // or flip on maximal non-negative score @@ -216,24 +376,63 @@ namespace sls { template void arith_base::repair(sat::literal lit, ineq const& ineq) { num_t new_value, old_value; - if (UINT_MAX == ineq.m_var_to_flip) - dtt_reward(lit); + dtt_reward(lit); + auto v = ineq.m_var_to_flip; if (v == UINT_MAX) { IF_VERBOSE(0, verbose_stream() << "no var to flip\n"); return; } + + if (repair_eq(lit, ineq)) + return; + if (!cm(ineq, v, new_value)) { + display(verbose_stream(), v) << "\n"; IF_VERBOSE(0, verbose_stream() << "no critical move for " << v << "\n"); + if (dtt(!ctx.is_true(lit), ineq) != 0) + ctx.flip(lit.var()); return; } verbose_stream() << "repair " << lit << ": " << ineq << " var: v" << v << " := " << value(v) << " -> " << new_value << "\n"; + //for (auto const& [coeff, w] : ineq.m_args) + // display(verbose_stream(), w) << "\n"; update(v, new_value); - if (dtt(lit.sign(), ineq) != 0) + invariant(ineq); + if (dtt(!ctx.is_true(lit), ineq) != 0) ctx.flip(lit.var()); } + template + bool arith_base::repair_eq(sat::literal lit, ineq const& ineq) { + if (lit.sign() || ineq.m_op != ineq_kind::EQ) + return false; + auto v = ineq.m_var_to_flip; + num_t new_value; + verbose_stream() << ineq << "\n"; + for (auto const& [coeff, w] : ineq.m_args) + display(verbose_stream(), w) << "\n"; + if (ctx.rand(10) == 0 && solve_eq_pairs(ineq)) { + verbose_stream() << ineq << "\n"; + for (auto const& [coeff, w] : ineq.m_args) + display(verbose_stream(), w) << "\n"; + } + else if (cm(ineq, v, new_value) && update(v, new_value)) + ; + else if (solve_eq_pairs(ineq)) { + verbose_stream() << ineq << "\n"; + for (auto const& [coeff, w] : ineq.m_args) + display(verbose_stream(), w) << "\n"; + } + else + return false; + SASSERT(dtt(!ctx.is_true(lit), ineq) == 0); + if (dtt(!ctx.is_true(lit), ineq) != 0) + ctx.flip(lit.var()); + return true; + } + // // dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c) // TODO - use cached dts instead of computed dts @@ -349,6 +548,14 @@ namespace sls { return true; } + template + bool arith_base::is_fixed(var_t v) { + auto const& vi = m_vars[v]; + auto const& lo = vi.m_lo; + auto const& hi = vi.m_hi; + return lo && hi && lo->value == hi->value && lo->value == value(v); + } + template bool arith_base::update(var_t v, num_t const& new_value) { auto& vi = m_vars[v]; @@ -394,20 +601,29 @@ namespace sls { SASSERT(dtt(sign(bv), ineq) == 0); } vi.m_value = new_value; + + + SASSERT(!m.is_value(e)); + verbose_stream() << "new value eh " << mk_bounded_pp(e, m) << "\n"; + ctx.new_value_eh(e); for (auto idx : vi.m_muls) { auto const& [w, coeff, monomial] = m_muls[idx]; - ctx.new_value_eh(m_vars[w].m_expr); + num_t prod(coeff); + for (auto w : monomial) + prod *= value(w); + if (value(w) != prod) + update(w, prod); } for (auto idx : vi.m_adds) { auto const& ad = m_adds[idx]; - ctx.new_value_eh(m_vars[ad.m_var].m_expr); + num_t sum(ad.m_coeff); + for (auto const& [coeff, w] : ad.m_args) + sum += coeff * value(w); + if (sum != ad.m_coeff) + update(ad.m_var, sum); + } - if (m.is_value(e)) { - display(verbose_stream()); - } - SASSERT(!m.is_value(e)); - ctx.new_value_eh(e); return true; } @@ -421,7 +637,8 @@ namespace sls { template void arith_base::add_arg(linear_term& ineq, num_t const& c, var_t v) { - ineq.m_args.push_back({ c, v }); + if (c != 0) + ineq.m_args.push_back({ c, v }); } template<> @@ -686,6 +903,14 @@ namespace sls { repair(lit, *ineq); } + template + void arith_base::repair_literal(sat::literal lit) { + auto v = lit.var(); + auto const* ineq = atom(v); + if (ineq && ineq->is_true() != ctx.is_true(v)) + ctx.flip(v); + } + template bool arith_base::propagate() { return false; @@ -960,8 +1185,8 @@ namespace sls { product *= value(v); if (product == 0 || !divides(product, val)) continue; - update(w, div(val, product)); - return true; + if (update(w, div(val, product))) + return true; } return false; } @@ -976,7 +1201,7 @@ namespace sls { product *= value(v); if (product == val) return true; -// verbose_stream() << "repair mul " << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << val << "(product: " << product << ")\n"; + verbose_stream() << "repair mul " << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << val << "(product: " << product << ")\n"; unsigned sz = monomial.size(); if (ctx.rand(20) == 0) return update(v, product); @@ -1006,28 +1231,39 @@ namespace sls { if (!divides(coeff, val) && ctx.rand(2) == 0) n = div(val + coeff - 1, coeff); auto const& fs = factor(abs(n)); - vector coeffs(sz, num_t(ctx.rand(2) == 0 ? 1 : -1)); + vector coeffs(sz, num_t(1)); vector gcds(sz, num_t(0)); num_t sign(1); for (auto c : coeffs) sign *= c; - unsigned i = 0; + unsigned i = 0; for (auto w : monomial) { for (auto idx : m_vars[w].m_muls) { auto const& [w1, coeff1, monomial1] = m_muls[idx]; gcds[i] = gcd(gcds[i], abs(value(w1))); } + auto const& vi = m_vars[w]; + if (vi.m_lo && vi.m_lo->value >= 0) + coeffs[i] = 1; + else if (vi.m_hi && vi.m_hi->value < 0) + coeffs[i] = -1; + else + coeffs[i] = num_t(ctx.rand(2) == 0 ? 1 : -1); ++i; } for (auto f : fs) coeffs[ctx.rand(sz)] *= f; if ((sign == 0) != (n == 0)) coeffs[ctx.rand(sz)] *= -1; -// verbose_stream() << "value " << val << " coeff: " << coeff << " coeffs: " << coeffs << " factors: " << fs << "\n"; + verbose_stream() << "value " << val << " coeff: " << coeff << " coeffs: " << coeffs << " factors: " << fs << "\n"; i = 0; - for (auto w : monomial) - if (!update(w, coeffs[i++])) + for (auto w : monomial) { + if (!update(w, coeffs[i++])) { + verbose_stream() << "failed to update v" << w << " to " << coeffs[i - 1] << "\n"; return false; + } + } + verbose_stream() << "all updated for v" << v << " := " << value(v) << "\n"; return true; } else { @@ -1151,31 +1387,72 @@ namespace sls { if (!ineq) return -1; num_t new_value; - double max_result = -1; - unsigned n = 0; + double max_result = -100; + unsigned n = 0, mult = 2; + double sum_prob = 0; + unsigned i = 0; + m_probs.reserve(ineq->m_args.size()); for (auto const& [coeff, x] : ineq->m_args) { - if (!cm(*ineq, x, coeff, new_value)) - continue; double result = 0; - // auto old_value = m_vars[x].m_value; - for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) { - result += ctx.reward(bv); -#if 0 - bool old_sign = sign(bv); - auto dtt_old = dtt(old_sign, *atom(bv)); - auto dtt_new = dtt(old_sign, *atom(bv), coeff, old_value, new_value); - if ((dtt_new == 0) != (dtt_old == 0)) - result += ctx.reward(bv); -#endif - } - if (result > max_result || max_result == -1 || (result == max_result && (ctx.rand(++n) == 0))) { - max_result = result; - ineq->m_var_to_flip = x; + double prob = 0; + if (is_fixed(x)) + prob = 0; + else if (!cm(*ineq, x, coeff, new_value)) + prob = 0.5; + else { + + auto old_value = m_vars[x].m_value; + for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) { + bool old_sign = sign(bv); + auto dtt_old = dtt(old_sign, *atom(bv)); + auto dtt_new = dtt(old_sign, *atom(bv), coeff, old_value, new_value); + if (dtt_new == 0 && dtt_old != 0) + result += 1; + if (dtt_new != 0 && dtt_old == 0) + result -= 1; + } + + if (result > max_result || max_result == -100 || (result == max_result && (ctx.rand(++n) == 0))) + max_result = result; + + if (result < 0) + prob = 0.1; + else if (result == 0) + prob = 0.2; + else + prob = result; + } + // verbose_stream() << "prob v" << x << " " << prob << "\n"; + m_probs[i++] = prob; + sum_prob += prob; } + double lim = sum_prob * ((double)ctx.rand() / random_gen().max_value()); + do { + lim -= m_probs[--i]; + } + while (lim >= 0 && i > 0); + + ineq->m_var_to_flip = ineq->m_args[i].second; + return max_result; } +#if 0 + double sum_prob = 0; + unsigned i = 0; + clause const& c = get_clause(cls_idx); + for (literal lit : c) { + double prob = m_prob_break[m_breaks[lit.var()]]; + m_probs[i++] = prob; + sum_prob += prob; + } + double lim = sum_prob * ((double)m_rand() / m_rand.max_value()); + do { + lim -= m_probs[--i]; + } while (lim >= 0 && i > 0); +#endif + // Newton function for integer square root. template num_t arith_base::sqrt(num_t n) { @@ -1203,14 +1480,14 @@ namespace sls { } } static int increments[8] = { 4, 2, 4, 2, 4, 6, 2, 6 }; - unsigned i = 0; - for (auto d = num_t(7); d * d <= n; d += num_t(increments[i++])) { + unsigned i = 0, j = 0; + for (auto d = num_t(7); d * d <= n && j < 3; d += num_t(increments[i++]), ++j) { while (mod(n, d) == 0) { m_factors.push_back(d); n = div(n, d); } if (i == 8) - i = 0; + i = 0; } if (n > 1) m_factors.push_back(n); @@ -1314,6 +1591,7 @@ namespace sls { template bool arith_base::is_sat() { + invariant(); for (auto const& clause : ctx.clauses()) { bool sat = false; for (auto lit : clause.m_clause) { @@ -1405,6 +1683,61 @@ namespace sls { return out; } + template + void arith_base::invariant() { + for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) { + auto ineq = atom(v); + if (ineq) + invariant(*ineq); + } + auto& out = verbose_stream(); + for (auto md : m_muls) { + auto const& [w, coeff, monomial] = md; + num_t prod(coeff); + for (auto v : monomial) + prod *= value(v); + //verbose_stream() << "check " << w << " " << monomial << "\n"; + if (prod != value(w)) { + out << prod << " " << value(w) << "\n"; + out << "v" << w << " := "; + for (auto w : monomial) + out << "v" << w << " "; + out << "\n"; + } + SASSERT(prod == value(w)); + + } + for (auto ad : m_adds) { + //out << "check add " << ad.m_var << "\n"; + num_t sum(ad.m_coeff); + for (auto [c, w] : ad.m_args) + sum += c * value(w); + if (sum != value(ad.m_var)) { + + + out << "v" << ad.m_var << " := "; + bool first = true; + for (auto [c, w] : ad.m_args) + out << (first ? "" : " + ") << c << "* v" << w; + if (ad.m_coeff != 0) + out << " + " << ad.m_coeff; + out << "\n"; + } + SASSERT(sum == value(ad.m_var)); + } + } + + template + void arith_base::invariant(ineq const& i) { + num_t val(0); + for (auto const& [c, v] : i.m_args) + val += c * value(v); + //verbose_stream() << "invariant " << i << "\n"; + if (val != i.m_args_value) + verbose_stream() << i << "\n"; + SASSERT(val == i.m_args_value); + } + template void arith_base::mk_model(model& mdl) { } diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index e24059135..488101f51 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -102,9 +102,13 @@ namespace sls { vector m_adds; vector m_ops; unsigned_vector m_expr2var; + svector m_probs; bool m_dscore_mode = false; arith_util a; + void invariant(); + void invariant(ineq const& i); + unsigned get_num_vars() const { return m_vars.size(); } bool repair_mul1(mul_def const& md); @@ -120,7 +124,9 @@ namespace sls { bool repair_to_int(op_def const& od); bool repair_to_real(op_def const& od); void repair(sat::literal lit, ineq const& ineq); + bool repair_eq(sat::literal lit, ineq const& ineq); bool in_bounds(var_t v, num_t const& value); + bool is_fixed(var_t v); vector m_factors; vector const& factor(num_t n); @@ -144,6 +150,8 @@ namespace sls { double dtt_reward(sat::literal lit); double dscore(var_t v, num_t const& new_value) const; void save_best_values(); + bool solve_eq_pairs(ineq const& ineq); + bool solve_eq_pairs(num_t const& a, var_t x, num_t const& b, var_t y, num_t const& r); var_t mk_var(expr* e); var_t mk_term(expr* e); @@ -180,6 +188,7 @@ namespace sls { bool propagate() override; void repair_up(app* e) override; bool repair_down(app* e) override; + void repair_literal(sat::literal lit) override; bool is_sat() override; void on_rescale() override; void on_restart() override; diff --git a/src/ast/sls/sls_arith_plugin.cpp b/src/ast/sls/sls_arith_plugin.cpp index aefcbb0ee..92a94ac37 100644 --- a/src/ast/sls/sls_arith_plugin.cpp +++ b/src/ast/sls/sls_arith_plugin.cpp @@ -110,6 +110,10 @@ namespace sls { WITH_FALLBACK(repair_up(e)); } + void arith_plugin::repair_literal(sat::literal lit) { + WITH_FALLBACK(repair_literal(lit)); + } + void arith_plugin::set_value(expr* e, expr* v) { WITH_FALLBACK(set_value(e, v)); } diff --git a/src/ast/sls/sls_arith_plugin.h b/src/ast/sls/sls_arith_plugin.h index 70ec38565..99951a22d 100644 --- a/src/ast/sls/sls_arith_plugin.h +++ b/src/ast/sls/sls_arith_plugin.h @@ -37,6 +37,7 @@ namespace sls { bool propagate() override; bool repair_down(app* e) override; void repair_up(app* e) override; + void repair_literal(sat::literal lit) override; bool is_sat() override; void on_rescale() override; diff --git a/src/ast/sls/sls_basic_plugin.cpp b/src/ast/sls/sls_basic_plugin.cpp index 0600e1df5..d966ee7b3 100644 --- a/src/ast/sls/sls_basic_plugin.cpp +++ b/src/ast/sls/sls_basic_plugin.cpp @@ -335,8 +335,17 @@ namespace sls { set_value(e, b); } + void basic_plugin::repair_literal(sat::literal lit) { + auto a = ctx.atom(lit.var()); + if (!is_basic(a)) + return; + if (bval1(to_app(a)) != bval0(to_app(a))) + ctx.flip(lit.var()); + } + bool basic_plugin::repair_down(app* e) { SASSERT(m.is_bool(e)); + unsigned n = e->get_num_args(); if (!is_basic(e)) return false; @@ -345,6 +354,7 @@ namespace sls { if (bval0(e) == bval1(e)) return true; + verbose_stream() << "basic repair down " << mk_bounded_pp(e, m) << "\n"; unsigned s = ctx.rand(n); for (unsigned i = 0; i < n; ++i) { auto j = (i + s) % n; diff --git a/src/ast/sls/sls_basic_plugin.h b/src/ast/sls/sls_basic_plugin.h index 64890f81e..82a7835c7 100644 --- a/src/ast/sls/sls_basic_plugin.h +++ b/src/ast/sls/sls_basic_plugin.h @@ -48,6 +48,7 @@ namespace sls { bool propagate() override; bool repair_down(app* e) override; void repair_up(app* e) override; + void repair_literal(sat::literal lit) override; bool is_sat() override; void on_rescale() override {} diff --git a/src/ast/sls/sls_bv_plugin.cpp b/src/ast/sls/sls_bv_plugin.cpp index 032ff397d..9fc4d4b62 100644 --- a/src/ast/sls/sls_bv_plugin.cpp +++ b/src/ast/sls/sls_bv_plugin.cpp @@ -143,6 +143,15 @@ namespace sls { } } + void bv_plugin::repair_literal(sat::literal lit) { + SASSERT(ctx.is_true(lit)); + auto a = ctx.atom(lit.var()); + if (!a || !is_app(a)) + return; + if (!m_eval.eval_is_correct(to_app(a))) + ctx.flip(lit.var()); + } + std::ostream& bv_plugin::trace_repair(bool down, expr* e) { verbose_stream() << (down ? "d #" : "u #") << e->get_id() << ": " diff --git a/src/ast/sls/sls_bv_plugin.h b/src/ast/sls/sls_bv_plugin.h index e675f04cc..ead5d7d8d 100644 --- a/src/ast/sls/sls_bv_plugin.h +++ b/src/ast/sls/sls_bv_plugin.h @@ -45,6 +45,7 @@ namespace sls { bool propagate() override; bool repair_down(app* e) override; void repair_up(app* e) override; + void repair_literal(sat::literal lit) override; bool is_sat() override; void on_rescale() override {} diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 2b802b90f..0525bfb39 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -64,7 +64,10 @@ namespace sls { propagate_boolean_assignment(); + verbose_stream() << "propagate " << unsat().size() << " " << m_new_constraint << "\n"; + + // display(verbose_stream()); if (m_new_constraint || !unsat().empty()) @@ -129,6 +132,16 @@ namespace sls { for (auto p : m_plugins) propagated |= p && !m_new_constraint && p->propagate(); } + + for (sat::bool_var v = 0; v < s.num_vars(); ++v) { + auto a = atom(v); + if (!a) + continue; + sat::literal lit(v, !is_true(v)); + auto p = m_plugins.get(get_fid(a), nullptr); + if (p) + p->repair_literal(lit); + } } family_id context::get_fid(expr* e) const { diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h index 7b481d3a8..ec6658954 100644 --- a/src/ast/sls/sls_context.h +++ b/src/ast/sls/sls_context.h @@ -42,6 +42,7 @@ namespace sls { virtual void initialize() = 0; virtual bool propagate() = 0; virtual void propagate_literal(sat::literal lit) = 0; + virtual void repair_literal(sat::literal lit) = 0; virtual bool repair_down(app* e) = 0; virtual void repair_up(app* e) = 0; virtual bool is_sat() = 0; @@ -69,6 +70,7 @@ namespace sls { virtual void on_model(model_ref& mdl) = 0; virtual sat::bool_var add_var() = 0; virtual void add_clause(unsigned n, sat::literal const* lits) = 0; + virtual std::ostream& display(std::ostream& out) = 0; }; class context { diff --git a/src/ast/sls/sls_euf_plugin.h b/src/ast/sls/sls_euf_plugin.h index 6578e0a3b..e0871bb77 100644 --- a/src/ast/sls/sls_euf_plugin.h +++ b/src/ast/sls/sls_euf_plugin.h @@ -41,7 +41,7 @@ namespace sls { expr_ref get_value(expr* e) override; void initialize() override {} void propagate_literal(sat::literal lit) override {} - bool propagate() override; + bool propagate() override; bool is_sat() override; void register_term(expr* e) override; std::ostream& display(std::ostream& out) const override; @@ -50,6 +50,7 @@ namespace sls { void repair_up(app* e) override {} bool repair_down(app* e) override { return false; } + void repair_literal(sat::literal lit) override {} }; } diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp index 511682350..00f5f2a2a 100644 --- a/src/ast/sls/sls_smt_solver.cpp +++ b/src/ast/sls/sls_smt_solver.cpp @@ -122,7 +122,7 @@ namespace sls { } void smt_solver::add_clause(expr* f) { - expr* g; + expr* g, * h; sat::literal_vector clause; if (m.is_not(f, g) && m.is_not(g, g)) { add_clause(g); @@ -149,6 +149,18 @@ namespace sls { clause.push_back(~mk_literal(arg)); m_solver_ctx->add_clause(clause.size(), clause.data()); } + else if (m.is_eq(f, g, h) && m.is_bool(g)) { + auto lit1 = mk_literal(g); + auto lit2 = mk_literal(h); + clause.reset(); + clause.push_back(~lit1); + clause.push_back(lit2); + m_solver_ctx->add_clause(clause.size(), clause.data()); + clause.reset(); + clause.push_back(lit1); + clause.push_back(~lit2); + m_solver_ctx->add_clause(clause.size(), clause.data()); + } else { sat::literal lit = mk_literal(f); m_solver_ctx->add_clause(1, &lit); diff --git a/src/util/checked_int64.h b/src/util/checked_int64.h index 174b7b10a..deb02f38d 100644 --- a/src/util/checked_int64.h +++ b/src/util/checked_int64.h @@ -323,4 +323,29 @@ inline checked_int64 gcd(checked_int64 const& a, checked_int64 +inline checked_int64 gcd(checked_int64 const& a, checked_int64 const& b, + checked_int64& x, checked_int64& y) { + checked_int64 _a = a; + checked_int64 _b = b; + x = 0; + y = 0; + checked_int64 lastx = 1; + checked_int64 lasty = 0; + while (_b != 0) { + checked_int64 q = div(_a, _b); + checked_int64 r = mod(_a, _b); + _a = _b; + _b = r; + checked_int64 temp = x; + x = lastx - q * x; + lastx = temp; + temp = y; + y = lasty - q * y; + lasty = temp; + } + return _a; } \ No newline at end of file diff --git a/src/util/mpz.cpp b/src/util/mpz.cpp index 296b4426e..011b62f7f 100644 --- a/src/util/mpz.cpp +++ b/src/util/mpz.cpp @@ -609,6 +609,7 @@ void mpz_manager::div_gcd(mpz const& a, mpz const& b, mpz & c) { template void mpz_manager::div(mpz const & a, mpz const & b, mpz & c) { STRACE("mpz", tout << "[mpz-ext] div(" << to_string(a) << ", " << to_string(b) << ") == ";); + SASSERT(!is_zero(b)); if (is_one(b)) { set(c, a); }