From 6466345755817618e60fb89b29b17c570c9eac56 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 26 Dec 2023 14:10:43 -0800 Subject: [PATCH] viable revisit v1 Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/CMakeLists.txt | 1 + src/sat/smt/polysat/number.h | 30 +++ src/sat/smt/polysat/refine.cpp | 213 ++++++++++++++++ src/sat/smt/polysat/refine.h | 46 ++++ src/sat/smt/polysat/viable.cpp | 382 +++++++++++++++++++++++++---- src/sat/smt/polysat/viable.h | 9 +- 6 files changed, 638 insertions(+), 43 deletions(-) create mode 100644 src/sat/smt/polysat/number.h create mode 100644 src/sat/smt/polysat/refine.cpp create mode 100644 src/sat/smt/polysat/refine.h diff --git a/src/sat/smt/polysat/CMakeLists.txt b/src/sat/smt/polysat/CMakeLists.txt index 70e0f9592..80168e713 100644 --- a/src/sat/smt/polysat/CMakeLists.txt +++ b/src/sat/smt/polysat/CMakeLists.txt @@ -7,6 +7,7 @@ z3_add_component(polysat forbidden_intervals.cpp inequality.cpp op_constraint.cpp + refine.cpp saturation.cpp ule_constraint.cpp umul_ovfl_constraint.cpp diff --git a/src/sat/smt/polysat/number.h b/src/sat/smt/polysat/number.h new file mode 100644 index 000000000..56c66b2f4 --- /dev/null +++ b/src/sat/smt/polysat/number.h @@ -0,0 +1,30 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat numbers + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once +#include "sat/smt/polysat/types.h" + +namespace polysat { + + inline unsigned get_parity(rational const& val, unsigned num_bits) { + if (val.is_zero()) + return num_bits; + return val.trailing_zeros(); + }; + + /** Return val with the lower k bits set to zero. */ + inline rational clear_lower_bits(rational const& val, unsigned k) { + return val - mod(val, rational::power_of_two(k)); + } + +} diff --git a/src/sat/smt/polysat/refine.cpp b/src/sat/smt/polysat/refine.cpp new file mode 100644 index 000000000..c706f643d --- /dev/null +++ b/src/sat/smt/polysat/refine.cpp @@ -0,0 +1,213 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + helpers for refining intervals + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +Notes: + +--*/ + + +#include "util/debug.h" +#include "sat/smt/polysat/refine.h" +#include "sat/smt/polysat/number.h" + +namespace { + rational div_floor(rational const& a, rational const& b) { + return floor(a / b); + } + + rational div_ceil(rational const& a, rational const& b) { + return ceil(a / b); + } +} + +namespace polysat { + + rational refine_equal::compute_y_max(rational const& y0, rational const& a, rational const& lo0, rational const& hi, rational const& M) { + // verbose_stream() << "y0=" << y0 << " a=" << a << " lo0=" << lo0 << " hi=" << hi << " M=" << M << std::endl; + // SASSERT(0 <= y0 && y0 < M); // not required + SASSERT(1 <= a && a < M); + SASSERT(0 <= lo0 && lo0 < M); + SASSERT(0 <= hi && hi < M); + + if (lo0 <= hi) { + SASSERT(lo0 <= mod(a*y0, M) && mod(a*y0, M) <= hi); + } + else { + SASSERT(mod(a*y0, M) <= hi || mod(a*y0, M) >= lo0); + } + + // wrapping intervals are handled by replacing the lower bound lo by lo - M + rational const lo = lo0 > hi ? (lo0 - M) : lo0; + + // the length of the interval is now hi - lo + 1. + // full intervals shouldn't go through this computation. + SASSERT(hi - lo + 1 < M); + + auto contained = [&lo, &hi] (rational const& a_y) -> bool { + return lo <= a_y && a_y <= hi; + }; + + auto delta_h = [&a, &lo, &hi] (rational const& a_y) -> rational { + SASSERT(lo <= a_y && a_y <= hi); + (void)lo; // avoid warning in release mode + return div_floor(hi - a_y, a); + }; + + // minimal k such that lo <= a*y0 + k*M + rational const k = div_ceil(lo - a * y0, M); + rational const kM = k*M; + rational const a_y0 = a*y0 + kM; + SASSERT(contained(a_y0)); + + // maximal y for [lo;hi]-interval around a*y0 + // rational const y0h = y0 + div_floor(hi - a_y0, a); + rational const delta0 = delta_h(a_y0); + rational const y0h = y0 + delta0; + rational const a_y0h = a_y0 + a*delta0; + SASSERT(y0 <= y0h); + SASSERT(contained(a_y0h)); + + // Check the first [lo;hi]-interval after a*y0 + rational const y1l = y0h + 1; + rational const a_y1l = a_y0h + a - M; + if (!contained(a_y1l)) + return y0h; + rational const delta1 = delta_h(a_y1l); + rational const y1h = y1l + delta1; + rational const a_y1h = a_y1l + a*delta1; + SASSERT(y1l <= y1h); + SASSERT(contained(a_y1h)); + + // Check the second [lo;hi]-interval after a*y0 + rational const y2l = y1h + 1; + rational const a_y2l = a_y1h + a - M; + if (!contained(a_y2l)) + return y1h; + SASSERT(contained(a_y2l)); + + // At this point, [y1l;y1h] must be a full y-interval that can be extended to the right. + // Extending the interval can only be possible if the part not covered by [lo;hi] is smaller than the coefficient a. + // The size of the gap is (lo + M) - (hi + 1). + SASSERT(lo + M - hi - 1 < a); + + // The points a*[y0l;y0h] + k*M fall into the interval [lo;hi]. + // After the first overflow, the points a*[y1l .. y1h] + (k - 1)*M fall into [lo;hi]. + // With each overflow, these points drift by some offset alpha. + rational const step = y1h - y0h; + rational const alpha = a * step - M; + + if (alpha == 0) { + // the points do not drift after overflow + // => y_max is infinite + return y0 + M; + } + + rational const i = + alpha < 0 + // alpha < 0: + // With each overflow to the right, the points drift to the left. + // We can continue overflowing while a * yil >= lo, where yil = y1l + i * step. + ? div_floor(lo - a_y1l, alpha) + // alpha > 0: + // With each overflow to the right, the points drift to the right. + // We can continue overflowing while a * yih <= hi, where yih = y1h + i * step. + : div_floor(hi - a_y1h, alpha); + + // i is the number of overflows to the right + SASSERT(i >= 0); + + // a * [yil;yih] is the right-most y-interval that is completely in [lo;hi]. + rational const yih = y1h + i * step; + rational const a_yih = a_y1h + i * alpha; + SASSERT_EQ(a_yih, a*yih + (k - i - 1)*M); + SASSERT(contained(a_yih)); + + // The next interval to the right may contain a few more values if alpha > 0 + // (because only the upper end moved out of the interval) + rational const y_next = yih + 1; + rational const a_y_next = a_yih + a - M; + if (contained(a_y_next)) + return y_next + delta_h(a_y_next); + else + return yih; + } + + rational refine_equal::compute_y_min(rational const& y0, rational const& a, rational const& lo, rational const& hi, rational const& M) { + // verbose_stream() << "y0=" << y0 << " a=" << a << " lo=" << lo << " hi=" << hi << " M=" << M << std::endl; + // SASSERT(0 <= y0 && y0 < M); // not required + SASSERT(1 <= a && a < M); + SASSERT(0 <= lo && lo < M); + SASSERT(0 <= hi && hi < M); + + auto const negateM = [&M] (rational const& x) -> rational { + if (x.is_zero()) + return x; + else + return M - x; + }; + + rational y_min = -compute_y_max(-y0, a, negateM(hi), negateM(lo), M); + while (y_min > y0) + y_min -= M; + return y_min; + } + + std::pair refine_equal::compute_y_bounds(rational const& y0, rational const& a, rational const& lo, rational const& hi, rational const& M) { + // verbose_stream() << "y0=" << y0 << " a=" << a << " lo=" << lo << " hi=" << hi << " M=" << M << std::endl; + SASSERT(0 <= y0 && y0 < M); + SASSERT(1 <= a && a < M); + SASSERT(0 <= lo && lo < M); + SASSERT(0 <= hi && hi < M); + + auto const is_valid = [&] (rational const& y) -> bool { + rational const a_y = mod(a * y, M); + if (lo <= hi) + return lo <= a_y && a_y <= hi; + else + return a_y <= hi || lo <= a_y; + }; + + unsigned const max_refinements = 100; + unsigned i = 0; + rational const y_max_max = y0 + M - 1; + rational y_max = compute_y_max(y0, a, lo, hi, M); + while (y_max < y_max_max && is_valid(y_max + 1)) { + y_max = compute_y_max(y_max + 1, a, lo, hi, M); + if (++i == max_refinements) { + // verbose_stream() << "y0=" << y0 << ", a=" << a << ", lo=" << lo << ", hi=" << hi << "\n"; + // verbose_stream() << "refined y_max: " << y_max << "\n"; + break; + } + } + + i = 0; + rational const y_min_min = y_max - M + 1; + rational y_min = y0; + while (y_min > y_min_min && is_valid(y_min - 1)) { + y_min = compute_y_min(y_min - 1, a, lo, hi, M); + if (++i == max_refinements) { + // verbose_stream() << "y0=" << y0 << ", a=" << a << ", lo=" << lo << ", hi=" << hi << "\n"; + // verbose_stream() << "refined y_min: " << y_min << "\n"; + break; + } + } + + SASSERT(y_min <= y0 && y0 <= y_max); + rational const len = y_max - y_min + 1; + if (len >= M) + // full + return { rational::zero(), M - 1 }; + else + return { mod(y_min, M), mod(y_max, M) }; + } + +} diff --git a/src/sat/smt/polysat/refine.h b/src/sat/smt/polysat/refine.h new file mode 100644 index 000000000..d7154470e --- /dev/null +++ b/src/sat/smt/polysat/refine.h @@ -0,0 +1,46 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + helpers for refining intervals + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once + +#include "sat/smt/polysat/types.h" + +namespace polysat { + + namespace refine_equal { + + /** + * Given a*y0 mod M \in [lo;hi], try to find the largest y_max >= y0 such that for all y \in [y0;y_max] . a*y mod M \in [lo;hi]. + * Result may not be optimal. + * NOTE: upper bound is inclusive. + */ + rational compute_y_max(rational const& y0, rational const& a, rational const& lo0, rational const& hi, rational const& M); + + /** + * Given a*y0 mod M \in [lo;hi], try to find the smallest y_min <= y0 such that for all y \in [y_min;y0] . a*y mod M \in [lo;hi]. + * Result may not be optimal. + * NOTE: upper bound is inclusive. + */ + rational compute_y_min(rational const& y0, rational const& a, rational const& lo, rational const& hi, rational const& M); + + /** + * Given a*y0 mod M \in [lo;hi], + * find the largest interval [y_min;y_max] around y0 such that for all y \in [y_min;y_max] . a*y mod M \in [lo;hi]. + * Result may not be optimal. + * NOTE: upper bounds are inclusive. + */ + std::pair compute_y_bounds(rational const& y0, rational const& a, rational const& lo, rational const& hi, rational const& M); + + } + +} diff --git a/src/sat/smt/polysat/viable.cpp b/src/sat/smt/polysat/viable.cpp index a5d6ca839..12945c9ee 100644 --- a/src/sat/smt/polysat/viable.cpp +++ b/src/sat/smt/polysat/viable.cpp @@ -20,6 +20,8 @@ Notes: #include "util/log.h" #include "sat/smt/polysat/viable.h" #include "sat/smt/polysat/core.h" +#include "sat/smt/polysat/number.h" +#include "sat/smt/polysat/refine.h" #include "sat/smt/polysat/ule_constraint.h" namespace polysat { @@ -134,20 +136,21 @@ namespace polysat { // Refine? // lbool viable::next_viable(rational& val) { + unsigned rounds = 0; do { + if (rounds > 10) + return l_undef; + ++rounds; rational val0 = val; auto r = next_viable_unit(val); if (r != l_true) return r; - if (val0 != val) - continue; if (!m_fixed_bits.next(val)) return l_false; - if (val0 != val) + if (refine_equal_lin(m_var, val)) + continue; + if (refine_disequal_lin(m_var, val)) continue; - r = next_viable_non_unit(val); - if (r != l_true) - return r; if (val0 != val) continue; } @@ -157,10 +160,9 @@ namespace polysat { // // - // from smallest overlap [w] to largest - // from smallest layer [bit_width, entries] to largest - // check if val is allowed by entries - // + // from smallest size(w) overlap [w] to largest + // from smallest bit_width layer [bit_width, entries] to largest + // check if val is allowed by entries or advance val to next allowed value // lbool viable::next_viable_unit(rational& val) { for (auto const& [w, offset] : m_overlaps) { @@ -172,7 +174,7 @@ namespace polysat { } lbool viable::next_viable_overlap(pvar w, rational& val) { - for (auto const& layer : m_units[w].get_layers()) { + for (auto& layer : m_units[w].get_layers()) { auto r = next_viable_layer(w, layer, val); if (r != l_true) return r; @@ -180,43 +182,341 @@ namespace polysat { return l_true; } - lbool viable::next_viable_layer(pvar w, layer const& layer, rational& val) { - unsigned num_bits_w = c.size(w); - unsigned num_bits_l = layer.bit_width; + /* + * v in [lo, hi[: + * - hi >= v: forward(v) := hi + * - hi < v: l_false + * 2^k v in [lo, hi[: + * - hi > 2^k v: forward(v) := hi//2^k + 2^k(v//2^k) + * - hi <= 2^k v: forward(v) := hi//2^k + 2^k(v//2^k + 1) unless it overflows. + * w is a suffix of v of width w.width <= v.width with forbidden 2^l w not in [lo, hi[ and 2^l v[w.width-1:0] in [lo, hi[. + * - set k := l + v.width - w.width, lo' := 2^{v.width-w.width} lo, hi' := 2^{v.width-w.width} hi. + */ + lbool viable::next_viable_layer(pvar w, layer& layer, rational& val) { + unsigned v_width = m_num_bits; + unsigned w_width = c.size(w); + unsigned l = w_width - layer.bit_width; + SASSERT(v_width >= w_width); + SASSERT(layer.bit_width <= w_width); - auto is_before = [&](entry* e) { - return false; - }; - - auto is_after = [&](entry* e) { - return false; - }; - - auto is_conflicting = [&](entry* e) { - return false; - }; - - auto increase_value = [&](entry* e) { - - }; - - auto first = layer.entries, e = first; - do { - if (is_conflicting(e)) - increase_value(e); - if (is_before(e)) - e = e->next(); - else if (is_after(e)) + bool is_zero = val.is_zero(), wrapped = false; + rational val1 = val; + rational const& p2l = rational::power_of_two(l); + rational const& p2w = rational::power_of_two(w_width); + + while (true) { + if (l > 0) + val1 *= p2l; + if (w_width < v_width || l > 0) + val1 = mod(val1, p2w); + auto e = find_overlap(val1, layer.entries); + if (!e) { + if (l > 0) + val1 /= p2l; break; + } + // TODO check if admitted: layer.entries = e; + m_explain.push_back(e); + auto hi = e->interval.hi_val(); + if (hi < val1) { + if (is_zero) + return l_false; + if (w_width == v_width && l == 0) + return l_false; + // start from 0 and find the next viable value within this layer. + val1 = 0; + wrapped = true; + } + val1 = hi; + SASSERT(val1 < p2w); + // p2l * x = val1 = hi + if (l > 0) + val1 = hi / p2l; + SASSERT(val1.is_int()); + } + SASSERT(val1 < p2w); + if (w_width < v_width) { + if (l > 0) + NOT_IMPLEMENTED_YET(); + rational val2 = val; + if (wrapped) { + val2 = mod(div(val2, p2w) + 1, p2w) * p2w; + if (val2 == 0) + return l_false; + } else - return l_false; - } while (e != first); + val2 = clear_lower_bits(val2, w_width); + val = val1 + val2; + } + else if (l > 0) { + NOT_IMPLEMENTED_YET(); + } + else + val = val1; + return l_true; } + // Find interval that contains 'val', or, if no such interval exists, null. + viable::entry* viable::find_overlap(rational const& val, entry* entries) { + SASSERT(entries); + // display_all(std::cerr << "entries:\n\t", 0, entries, "\n\t"); + entry* const first = entries; + entry* e = entries; + do { + auto const& i = e->interval; + if (i.currently_contains(val)) + return e; + entry* const n = e->next(); + // there is only one interval, and it does not contain 'val' + if (e == n) + return nullptr; + // check whether 'val' is contained in the gap between e and n + bool const overlapping = e->interval.currently_contains(n->interval.lo_val()); + if (!overlapping && r_interval::contains(e->interval.hi_val(), n->interval.lo_val(), val)) + return nullptr; + e = n; + } + while (e != first); + UNREACHABLE(); + return nullptr; + } - lbool viable::next_viable_non_unit(rational& val) { - return l_undef; + bool viable::refine_equal_lin(pvar v, rational const& val) { + // LOG_H2("refine-equal-lin with v" << v << ", val = " << val); + entry const* e = m_equal_lin[v]; + if (!e) + return true; + entry const* first = e; + auto& m = c.var2pdd(v); + unsigned const N = m.power_of_2(); + rational const& max_value = m.max_value(); + rational const& mod_value = m.two_to_N(); + SASSERT(0 <= val && val <= max_value); + + // Rotate the 'first' entry, to prevent getting stuck in a refinement loop + // with an early entry when a later entry could give a better interval. + m_equal_lin[v] = m_equal_lin[v]->next(); + + do { + rational coeff_val = mod(e->coeff * val, mod_value); + if (e->interval.currently_contains(coeff_val)) { + // IF_LOGGING( + // verbose_stream() << "refine-equal-lin for v" << v << " in src: "; + // for (const auto& src : e->src) + // verbose_stream() << lit_pp(s, src) << "\n"; + // ); + // LOG("forbidden interval v" << v << " " << num_pp(s, v, val) << " " << num_pp(s, v, e->coeff, true) << " * " << e->interval); + + if (mod(e->interval.hi_val() + 1, mod_value) == e->interval.lo_val()) { + // We have an equation: a * v == b + rational const a = e->coeff; + rational const b = e->interval.hi_val(); + LOG("refine-equal-lin: equation detected: " << dd::val_pp(m, a, true) << " * v" << v << " == " << dd::val_pp(m, b, false)); + unsigned const parity_a = get_parity(a, N); + unsigned const parity_b = get_parity(b, N); + if (parity_a > parity_b) { + // No solution + LOG("refined: no solution due to parity"); + entry* ne = alloc_entry(v, e->constraint_index); + ne->refined = true; + ne->src = e->src; + ne->side_cond = e->side_cond; + ne->coeff = 1; + ne->bit_width = e->bit_width; + ne->interval = eval_interval::full(); + intersect(v, ne); + return false; + } + if (parity_a == 0) { + // "fast path" for odd a + rational a_inv; + VERIFY(a.mult_inverse(N, a_inv)); + rational const hi = mod(a_inv * b, mod_value); + rational const lo = mod(hi + 1, mod_value); + // LOG("refined to [" << num_pp(c, v, lo) << ", " << num_pp(c, v, hi) << "["); + SASSERT_EQ(mod(a * hi, mod_value), b); // hi is the solution + entry* ne = alloc_entry(v, e->constraint_index); + ne->refined = true; + ne->src = e->src; + ne->side_cond = e->side_cond; + ne->coeff = 1; + ne->bit_width = e->bit_width; + ne->interval = eval_interval::proper(m.mk_val(lo), lo, m.mk_val(hi), hi); + SASSERT(ne->interval.currently_contains(val)); + intersect(v, ne); + return false; + } + // 2^k * v == a_inv * b + // 2^k solutions because only the lower N-k bits of v are fixed. + // + // Smallest solution is v0 == a_inv * (b >> k) + // Solutions are of the form v_i = v0 + 2^(N-k) * i for i in { 0, 1, ..., 2^k - 1 }. + // Forbidden intervals: [v_i + 1; v_{i+1}[ == [ v_i + 1; v_i + 2^(N-k) [ + // We need the interval that covers val: + // v_i + 1 <= val < v_i + 2^(N-k) + // + // TODO: create one interval for v[N-k:] instead of 2^k intervals for v. + unsigned const k = parity_a; + rational const a_inv = a.pseudo_inverse(N); + unsigned const N_minus_k = N - k; + rational const two_to_N_minus_k = rational::power_of_two(N_minus_k); + rational const v0 = mod(a_inv * machine_div2k(b, k), two_to_N_minus_k); + SASSERT(mod(val, two_to_N_minus_k) != v0); // val is not a solution + rational const vi = v0 + clear_lower_bits(mod(val - v0, mod_value), N_minus_k); + rational const lo = mod(vi + 1, mod_value); + rational const hi = mod(vi + two_to_N_minus_k, mod_value); + // LOG("refined to [" << num_pp(c, v, lo) << ", " << num_pp(c, v, hi) << "["); + SASSERT_EQ(mod(a * (lo - 1), mod_value), b); // lo-1 is a solution + SASSERT_EQ(mod(a * hi, mod_value), b); // hi is a solution + entry* ne = alloc_entry(v, e->constraint_index); + ne->refined = true; + ne->src = e->src; + ne->side_cond = e->side_cond; + ne->coeff = 1; + ne->bit_width = e->bit_width; + ne->interval = eval_interval::proper(m.mk_val(lo), lo, m.mk_val(hi), hi); + SASSERT(ne->interval.currently_contains(val)); + intersect(v, ne); + return false; + } + + // TODO: special handling for the even factors of e->coeff = 2^k * a', a' odd + // (create one interval for v[N-k:] instead of 2^k intervals for v) + + // TODO: often, the intervals alternate between short forbidden and short allowed intervals. + // we should be able to handle this similarly to compute_y_bounds, + // and be able to represent such periodic intervals (inside safe bounds). + // + // compute_y_bounds calculates with inclusive upper bound, so we need to adjust argument and result accordingly. + rational const hi_val_incl = e->interval.hi_val().is_zero() ? max_value : (e->interval.hi_val() - 1); + auto [lo, hi] = refine_equal::compute_y_bounds(val, e->coeff, e->interval.lo_val(), hi_val_incl, mod_value); + hi += 1; + //LOG("refined to [" << num_pp(c, v, lo) << ", " << num_pp(c, v, hi) << "["); + // verbose_stream() << "lo=" << lo << " val=" << val << " hi=" << hi << "\n"; + if (lo <= hi) { + SASSERT(0 <= lo && lo <= val && val < hi && hi <= mod_value); + } + else { + SASSERT(0 < hi && hi < lo && lo < mod_value && (val < hi || lo <= val)); + } + bool full = (lo == 0 && hi == mod_value); + if (hi == mod_value) + hi = 0; + entry* ne = alloc_entry(v, e->constraint_index); + ne->refined = true; + ne->src = e->src; + ne->side_cond = e->side_cond; + ne->coeff = 1; + ne->bit_width = e->bit_width; + if (full) + ne->interval = eval_interval::full(); + else + ne->interval = eval_interval::proper(m.mk_val(lo), lo, m.mk_val(hi), hi); + SASSERT(ne->interval.currently_contains(val)); + intersect(v, ne); + return false; + } + e = e->next(); + } while (e != first); + return true; + } + + bool viable::refine_disequal_lin(pvar v, rational const& val) { + // LOG_H2("refine-disequal-lin with v" << v << ", val = " << val); + entry const* e = m_diseq_lin[v]; + if (!e) + return true; + entry const* first = e; + auto& m = c.var2pdd(v); + rational const& max_value = m.max_value(); + rational const& mod_value = m.two_to_N(); + SASSERT(0 <= val && val <= max_value); + + // Rotate the 'first' entry, to prevent getting stuck in a refinement loop + // with an early entry when a later entry could give a better interval. + m_diseq_lin[v] = m_diseq_lin[v]->next(); + + do { + // IF_LOGGING( + // verbose_stream() << "refine-disequal-lin for v" << v << " in src: "; + // for (const auto& src : e->src) + // verbose_stream() << lit_pp(s, src) << "\n"; + // ); + + // We compute an interval if the concrete value 'val' violates the constraint: + // p*val + q > r*val + s if e->src.is_positive() + // p*val + q >= r*val + s if e->src.is_negative() + // Note that e->interval is meaningless in this case, + // we just use it to transport the values p,q,r,s + rational const& p = e->interval.lo_val(); + rational const& q_ = e->interval.lo().val(); + rational const& r = e->interval.hi_val(); + rational const& s_ = e->interval.hi().val(); + SASSERT(p != r && p != 0 && r != 0); + SASSERT(e->src.size() == 1); + + rational const a = mod(p * val + q_, mod_value); + rational const b = mod(r * val + s_, mod_value); + rational const np = mod_value - p; + rational const nr = mod_value - r; + int const corr = e->src[0].is_negative() ? 1 : 0; + + auto delta_l = [&](rational const& val) { + rational num = a - b + corr; + rational l1 = floor(b / r); + rational l2 = val; + if (p > r) + l2 = ceil(num / (p - r)) - 1; + rational l3 = ceil(num / (p + nr)) - 1; + rational l4 = ceil((mod_value - a) / np) - 1; + rational d1 = l3; + rational d2 = std::min(l1, l2); + rational d3 = std::min(l1, l4); + rational d4 = std::min(l2, l4); + rational dmax = std::max(std::max(d1, d2), std::max(d3, d4)); + return std::min(val, dmax); + }; + auto delta_u = [&](rational const& val) { + rational num = a - b + corr; + rational h1 = floor(b / nr); + rational h2 = max_value - val; + if (r > p) + h2 = ceil(num / (r - p)) - 1; + rational h3 = ceil(num / (np + r)) - 1; + rational h4 = ceil((mod_value - a) / p) - 1; + rational d1 = h3; + rational d2 = std::min(h1, h2); + rational d3 = std::min(h1, h4); + rational d4 = std::min(h2, h4); + rational dmax = std::max(std::max(d1, d2), std::max(d3, d4)); + return std::min(max_value - val, dmax); + }; + + if (a > b || (e->src[0].is_negative() && a == b)) { + rational lo = val - delta_l(val); + rational hi = val + delta_u(val) + 1; + + LOG("refine-disequal-lin: " << " [" << lo << ", " << hi << "["); + + SASSERT(0 <= lo && lo <= val); + SASSERT(val <= hi && hi <= mod_value); + if (hi == mod_value) hi = 0; + pdd lop = c.var2pdd(v).mk_val(lo); + pdd hip = c.var2pdd(v).mk_val(hi); + entry* ne = alloc_entry(v, e->constraint_index); + ne->refined = true; + ne->src = e->src; + ne->side_cond = e->side_cond; + ne->coeff = 1; + ne->bit_width = e->bit_width; + ne->interval = eval_interval::proper(lop, lo, hip, hi); + intersect(v, ne); + return false; + } + e = e->next(); + } while (e != first); + return true; } /* diff --git a/src/sat/smt/polysat/viable.h b/src/sat/smt/polysat/viable.h index 25cebdb57..f1ae50c2c 100644 --- a/src/sat/smt/polysat/viable.h +++ b/src/sat/smt/polysat/viable.h @@ -78,6 +78,7 @@ namespace polysat { svector m_layers; public: svector const& get_layers() const { return m_layers; } + svector& get_layers() { return m_layers; } layer& ensure_layer(unsigned bit_width); layer* get_layer(unsigned bit_width); layer* get_layer(entry* e) { return get_layer(e->bit_width); } @@ -118,9 +119,13 @@ namespace polysat { lbool next_viable_overlap(pvar w, rational& val); - lbool next_viable_layer(pvar w, layer const& l, rational& val); + lbool next_viable_layer(pvar w, layer& l, rational& val); - lbool next_viable_non_unit(rational& val); + viable::entry* find_overlap(rational const& val, entry* entries); + + bool refine_disequal_lin(pvar v, rational const& val); + + bool refine_equal_lin(pvar v, rational const& val); pvar m_var = null_var; unsigned m_num_bits = 0;