diff --git a/src/math/polysat/viable.cpp b/src/math/polysat/viable.cpp index 7aef1cbf3..08ef1b9c5 100644 --- a/src/math/polysat/viable.cpp +++ b/src/math/polysat/viable.cpp @@ -277,84 +277,71 @@ namespace polysat { } bool viable::refine_disequal_lin(pvar v, rational const& val) { + // LOG_H2("refine-disequal-lin with v" << v << ", val = " << val); auto* e = m_diseq_lin[v]; if (!e) return true; entry* first = e; rational const& max_value = s.var2pdd(v).max_value(); - rational mod_value = max_value + 1; + rational const mod_value = max_value + 1; do { LOG("refine-disequal-lin for src: " << e->src); - // We have: - // a1*v + b1 > a2*v + b2 if e->src.is_positive() - // a1*v + b1 >= a2*v + b2 if e->src.is_negative() + // We compute an interval if the concrete value 'val' violates the constraint: + // p*val + q > r*val + s if e->src.is_positive() + // p*val + q >= r*val + s if e->src.is_negative() // Note that e->interval is meaningless in this case, - // we just use it to transport the values a1,b1,a2,b2. - rational const& a1 = e->interval.lo_val(); - rational const& b1 = e->interval.lo().val(); - rational const& a2 = e->interval.hi_val(); - rational const& b2 = e->interval.hi().val(); - SASSERT(a1 != a2 && a1 != 0 && a2 != 0); + // we just use it to transport the values p,q,r,s + rational const& p = e->interval.lo_val(); + rational const& q_ = e->interval.lo().val(); + rational const& r = e->interval.hi_val(); + rational const& s_ = e->interval.hi().val(); + SASSERT(p != r && p != 0 && r != 0); - rational lhs = mod(a1 * val + b1, mod_value); - rational rhs = mod(a2 * val + b2, mod_value); + rational const a = mod(p * val + q_, mod_value); + rational const b = mod(r * val + s_, mod_value); + rational const np = mod_value - p; + rational const nr = mod_value - r; + int const corr = e->src.is_negative() ? 1 : 0; auto delta_l = [&](rational const& val) { - rational m1 = ceil((rhs + 1) / a2); - int corr = e->src.is_negative() ? 1 : 0; - rational m3 = (lhs - rhs + corr) / (a1 - a2); - if (m3 <= 0) - m3 = m1; // remove m3 from the minimum - else - m3 = ceil(m3); - - // return std::min(m1, m3) - 1; - return std::min(val, std::min(m1, m3) - 1); + rational num = a - b + corr; + rational l1 = floor(b / r); + rational l2 = val; + if (p > r) + l2 = ceil(num / (p - r)) - 1; + rational l3 = ceil(num / (p + nr)) - 1; + rational l4 = ceil((mod_value - a) / np) - 1; + rational d1 = l3; + rational d2 = std::min(l1, l2); + rational d3 = std::min(l1, l4); + rational d4 = std::min(l2, l4); + rational dmax = std::max(std::max(d1, d2), std::max(d3, d4)); + return std::min(val, dmax); }; auto delta_u = [&](rational const& val) { - rational m1 = ceil((mod_value - lhs) / a1); - rational m2 = mod_value - val; - int corr = e->src.is_negative() ? 1 : 0; - rational m3 = (lhs - rhs + corr) / (a2 - a1); - if (m3 <= 0) - m3 = m2; // remove m3 from the minimum - else - m3 = ceil(m3); - - return std::min(m1, std::min(m2, m3)) - 1; + rational num = a - b + corr; + rational h1 = floor(b / nr); + rational h2 = max_value - val; + if (r > p) + h2 = ceil(num / (r - p)) - 1; + rational h3 = ceil(num / (np + r)) - 1; + rational h4 = ceil((mod_value - a) / p) - 1; + rational d1 = h3; + rational d2 = std::min(h1, h2); + rational d3 = std::min(h1, h4); + rational d4 = std::min(h2, h4); + rational dmax = std::max(std::max(d1, d2), std::max(d3, d4)); + return std::min(max_value - val, dmax); }; - if (lhs > rhs || (e->src.is_negative() && lhs == rhs)) { - rational lo; - rational hi; - - // TODO: extract into separate function - if (e->src.is_negative() && a2.is_one() && b1.is_zero() && b2.is_zero()) { - // special case: v > -a*v for some numeral a - rational const& a = mod(-a1, mod_value); - if (val.is_zero()) { - lo = 0; - hi = ceil( (mod_value + 1) / (a + 1) ); - } else { - rational const y = mod(-a * val, mod_value); - lo = ceil( val + (y - max_value) / a ); - hi = ceil( (y + a*val + 1) / (a + 1) ); - // can always extend to 0 - if (lo.is_one()) - lo = 0; - } - } else { - // general case - lo = val - delta_l(val); - hi = val + delta_u(val) + 1; - // TODO: increase interval - } + if (a > b || (e->src.is_negative() && a == b)) { + rational lo = val - delta_l(val); + rational hi = val + delta_u(val) + 1; LOG("refine-disequal-lin: " << " [" << lo << ", " << hi << "["); SASSERT(0 <= lo && lo <= val); - // SASSERT(val <= hi && hi <= max_value); SASSERT(val <= hi && hi <= mod_value); if (hi == mod_value) hi = 0; pdd lop = s.var2pdd(v).mk_val(lo); diff --git a/src/test/polysat.cpp b/src/test/polysat.cpp index 0d79452d4..c9d737a60 100644 --- a/src/test/polysat.cpp +++ b/src/test/polysat.cpp @@ -1195,23 +1195,33 @@ class test_fi { } public: - static void exhaustive(unsigned bw = 3) { - rational const m = rational::power_of_two(bw); - for (rational a1(1); a1 < m; ++a1) { - for (rational a2(1); a2 < m; ++a2) { - // TODO: remove this to test other cases - if (a1 == a2) - continue; - for (rational b1(0); b1 < m; ++b1) - for (rational b2(0); b2 < m; ++b2) - for (rational val(0); val < m; ++val) - for (bool negated : {true, false}) - check_one(a1, b1, a2, b2, val, negated, bw); + static void exhaustive(unsigned bw = 0) { + if (bw == 0) { + exhaustive(1); + exhaustive(2); + exhaustive(3); + exhaustive(4); + exhaustive(5); + } else { + std::cout << "test_fi::exhaustive for bw=" << bw << std::endl; + rational const m = rational::power_of_two(bw); + for (rational p(1); p < m; ++p) { + for (rational r(1); r < m; ++r) { + // TODO: remove this condition to test the cases other than disequal_lin! (also start p,q from 0) + if (p == r) + continue; + for (rational q(0); q < m; ++q) + for (rational s(0); s < m; ++s) + for (rational v(0); v < m; ++v) + for (bool negated : {true, false}) + check_one(p, q, r, s, v, negated, bw); + } } } } - static void randomized(unsigned num_rounds = 10'000, unsigned bw = 16) { + static void randomized(unsigned num_rounds = 100'000, unsigned bw = 16) { + std::cout << "test_fi::randomized for bw=" << bw << " (" << num_rounds << " rounds)" << std::endl; rational const m = rational::power_of_two(bw); VERIFY(bw <= 32 && "random_gen generates 32-bit numbers"); random_gen rng;