3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 00:55:31 +00:00

Accelerate interval widening in refine_equal_lin

This commit is contained in:
Jakob Rath 2023-01-13 15:41:28 +01:00
parent 057e115bbc
commit caf624589e

View file

@ -315,15 +315,195 @@ namespace polysat {
return refine_equal_lin(v, val) && refine_disequal_lin(v, val);
}
namespace {
rational div_floor(rational const& a, rational const& b) {
return floor(a / b);
}
rational div_ceil(rational const& a, rational const& b) {
return ceil(a / b);
}
/**
* Traverse all interval constraints with coefficients to check whether current value 'val' for
* 'v' is feasible. If not, extract a (maximal) interval to block 'v' from being assigned val.
*
* To investigate:
* - side conditions are stronger than for unit intervals. They constrain the lower and upper bounds to
* be precisely the assigned values. This is to ensure that lo/hi that are computed based on lo_val
* and division with coeff are valid. Is there a more relaxed scheme?
*/
* Given a*y0 mod M \in [lo;hi], try to find the largest y_max >= y0 such that for all y \in [y0;y_max] . a*y mod M \in [lo;hi].
* Result may not be optimal.
* NOTE: upper bound is inclusive.
*/
rational compute_y_max(rational const& y0, rational const& a, rational const& lo0, rational const& hi, rational const& M) {
SASSERT(0 <= y0 && y0 < M);
SASSERT(1 <= a && a < M);
SASSERT(0 <= lo0 && lo0 < M);
SASSERT(0 <= hi && hi < M);
if (lo0 <= hi) {
SASSERT(lo0 <= mod(a*y0, M) && mod(a*y0, M) <= hi);
}
else {
SASSERT(mod(a*y0, M) <= hi || mod(a*y0, M) >= lo0);
}
// wrapping intervals are handled by replacing the lower bound lo by lo - M
rational const lo = lo0 > hi ? (lo0 - M) : lo0;
auto contained = [&lo, &hi] (rational const& a_y) -> bool {
return lo <= a_y && a_y <= hi;
};
auto delta_h = [&a, &lo, &hi] (rational const& a_y) -> rational {
SASSERT(lo <= a_y && a_y <= hi);
return div_floor(hi - a_y, a);
};
// minimal k such that lo <= a*y0 + k*M
rational const k = div_ceil(lo - a * y0, M);
rational const kM = k*M;
rational const a_y0 = a*y0 + kM;
SASSERT(contained(a_y0));
// maximal y for [lo;hi]-interval around a*y0
// rational const y0h = y0 + div_floor(hi - a_y0, a);
rational const delta0 = delta_h(a_y0);
rational const y0h = y0 + delta0;
rational const a_y0h = a_y0 + a*delta0;
SASSERT(y0 <= y0h);
SASSERT(contained(a_y0h));
// Check the first [lo;hi]-interval after a*y0
rational const y1l = y0h + 1;
rational const a_y1l = a_y0h + a - M;
if (!contained(a_y1l))
return y0h;
rational const delta1 = delta_h(a_y1l);
rational const y1h = y1l + delta1;
rational const a_y1h = a_y1l + a*delta1;
SASSERT(y1l <= y1h);
SASSERT(contained(a_y1h));
// Check the second [lo;hi]-interval after a*y0
rational const y2l = y1h + 1;
rational const a_y2l = a_y1h + a - M;
if (!contained(a_y2l))
return y1h;
SASSERT(contained(a_y2l));
// At this point, [y1l;y1h] must be a full y-interval that can be extended to the right.
// Extending the interval can only be possible if the part not covered by [lo;hi] is smaller than the coefficient a.
SASSERT(lo + M - hi < a);
// The points a*[y0l;y0h] + k*M fall into the interval [lo;hi].
// After the first overflow, the points a*[y1l .. y1h] + (k - 1)*M fall into [lo;hi].
// With each overflow, these points drift by some offset alpha.
rational const step = y1h - y0h;
rational const alpha = a * step - M;
if (alpha == 0) {
// the points do not drift after overflow
// => y_max is infinite
return y0 + M;
}
rational const i =
alpha < 0
// alpha < 0:
// With each overflow to the right, the points drift to the left.
// We can continue overflowing while a * yil >= lo, where yil = y1l + i * step.
? div_floor(lo - a_y1l, alpha)
// alpha > 0:
// With each overflow to the right, the points drift to the right.
// We can continue overflowing while a * yih <= hi, where yih = y1h + i * step.
: div_floor(hi - a_y1h, alpha);
// i is the number of overflows to the right
SASSERT(i >= 0);
// a * [yil;yih] is the right-most y-interval that is completely in [lo;hi].
rational const yih = y1h + i * step;
rational const a_yih = a*yih + (k - i - 1)*M;
SASSERT(contained(a_yih));
// The next interval to the right may contain a few more values if alpha > 0
// (because only the upper end moved out of the interval)
rational const y_next = yih + 1;
rational const a_y_next = a_yih + a - M;
if (contained(a_y_next))
return y_next + delta_h(a_y_next);
else
return yih;
}
/**
* Given a*y0 mod M \in [lo;hi], try to find the smallest y_min <= y0 such that for all y \in [y_min;y0] . a*y mod M \in [lo;hi].
* Result may not be optimal.
* NOTE: upper bound is inclusive.
*/
rational compute_y_min(rational const& y0, rational const& a, rational const& lo, rational const& hi, rational const& M) {
SASSERT(0 <= y0 && y0 < M);
SASSERT(1 <= a && a < M);
SASSERT(0 <= lo && lo < M);
SASSERT(0 <= hi && hi < M);
auto const negateM = [&M] (rational const& x) -> rational {
if (x.is_zero())
return x;
else
return M - x;
};
rational y_min = negateM(compute_y_max(negateM(y0), a, negateM(hi), negateM(lo), M));
while (y_min > y0)
y_min -= M;
return y_min;
}
/**
* Given a*y0 mod M \in [lo;hi],
* find the largest interval [y_min;y_max] around y0 such that for all y \in [y_min;y_max] . a*y mod M \in [lo;hi].
* The result is optimal.
* NOTE: upper bounds are inclusive.
*/
std::pair<rational, rational> compute_y_bounds(rational const& y0, rational const& a, rational const& lo, rational const& hi, rational const& M) {
SASSERT(0 <= y0 && y0 < M);
SASSERT(1 <= a && a < M);
SASSERT(0 <= lo && lo < M);
SASSERT(0 <= hi && hi < M);
auto const is_valid = [&] (rational const& y) -> bool {
rational const a_y = mod(a * y, M);
if (lo <= hi)
return lo <= a_y && a_y <= hi;
else
return a_y <= hi || lo <= a_y;
};
rational const y_max_max = y0 + M - 1;
rational y_max = compute_y_max(y0, a, lo, hi, M);
while (y_max < y_max_max && is_valid(y_max + 1))
y_max = compute_y_max(y_max + 1, a, lo, hi, M);
rational const y_min_min = y_max - M + 1;
rational y_min = y0;
while (y_min > y_min_min && is_valid(y_min - 1))
y_min = compute_y_min(y_min - 1, a, lo, hi, M);
SASSERT(y_min <= y0 && y0 <= y_max);
rational const len = y_max - y_min + 1;
if (len >= M)
// full
return { rational::zero(), M - 1 };
else
return { mod(y_min, M), mod(y_max, M) };
}
}
/**
* Traverse all interval constraints with coefficients to check whether current value 'val' for
* 'v' is feasible. If not, extract a (maximal) interval to block 'v' from being assigned val.
*
* To investigate:
* - side conditions are stronger than for unit intervals. They constrain the lower and upper bounds to
* be precisely the assigned values. This is to ensure that lo/hi that are computed based on lo_val
* and division with coeff are valid. Is there a more relaxed scheme?
*/
bool viable::refine_equal_lin(pvar v, rational const& val) {
// LOG_H2("refine-equal-lin with v" << v << ", val = " << val);
entry const* e = m_equal_lin[v];
@ -339,45 +519,10 @@ namespace polysat {
// with an early entry when a later entry could give a better interval.
m_equal_lin[v] = m_equal_lin[v]->next();
auto delta_l = [&](rational const& coeff_val) {
return floor((coeff_val - e->interval.lo_val()) / e->coeff);
};
auto delta_u = [&](rational const& coeff_val) {
return floor((e->interval.hi_val() - coeff_val - 1) / e->coeff);
};
// naive widening. TODO: can we accelerate this?
// condition e->interval.hi_val() < coeff_val is
// to ensure that widening is performed on the same interval
// similar for e->interval.lo_val() > coeff_val
// needs a proof.
auto increase_hi = [&](rational& hi) {
while (hi < max_value) {
rational coeff_val = mod(e->coeff * hi, mod_value);
if (!e->interval.currently_contains(coeff_val))
break;
if (e->interval.hi_val() < coeff_val)
break;
hi += delta_u(coeff_val) + 1;
}
};
auto decrease_lo = [&](rational& lo) {
while (lo > 1) {
rational coeff_val = mod(e->coeff * (lo - 1), mod_value);
if (!e->interval.currently_contains(coeff_val))
break;
if (e->interval.lo_val() > coeff_val)
break;
rational d = delta_l(coeff_val);
if (d.is_zero())
break;
lo -= d;
}
};
do {
rational coeff_val = mod(e->coeff * val, mod_value);
if (e->interval.currently_contains(coeff_val)) {
LOG("refine-equal-lin for src: " << lit_pp(s, e->src));
LOG("refine-equal-lin for v" << v << " in src: " << lit_pp(s, e->src));
LOG("forbidden interval v" << v << " " << num_pp(s, v, val) << " " << num_pp(s, v, e->coeff, true) << " * " << e->interval);
if (mod(e->interval.hi_val() + 1, mod_value) == e->interval.lo_val()) {
@ -387,8 +532,6 @@ namespace polysat {
LOG("refine-equal-lin: equation detected: " << dd::val_pp(m, a, true) << " * v" << v << " == " << dd::val_pp(m, b, false));
unsigned const parity_a = get_parity(a, N);
unsigned const parity_b = get_parity(b, N);
// LOG("a " << a << " parity " << parity_a);
// LOG("b " << b << " parity " << parity_b);
if (parity_a > parity_b) {
// No solution
LOG("refined: no solution due to parity");
@ -452,28 +595,13 @@ namespace polysat {
return false;
}
rational lo = val - delta_l(coeff_val);
rational hi = val + delta_u(coeff_val) + 1;
// TODO: special handling for the even factors of e->coeff = 2^k * a', a' odd
// (create one interval for v[N-k:] instead of 2^k intervals for v)
if (e->interval.lo_val() < e->interval.hi_val()) {
increase_hi(hi);
decrease_lo(lo);
}
else if (e->interval.lo_val() <= coeff_val) {
rational lambda_u = floor((max_value - coeff_val) / e->coeff);
hi = val + lambda_u + 1;
if (hi > max_value)
hi = 0;
decrease_lo(lo);
}
else {
SASSERT(coeff_val < e->interval.hi_val());
rational lambda_l = floor(coeff_val / e->coeff);
lo = val - lambda_l;
increase_hi(hi);
}
auto [lo, hi] = compute_y_bounds(val, e->coeff, e->interval.lo_val(), e->interval.hi_val() - 1, mod_value);
hi += 1; // compute_y_bounds calculates with inclusive upper bound; correct this here.
LOG("refined to [" << num_pp(s, v, lo) << ", " << num_pp(s, v, hi) << "[");
SASSERT(hi <= mod_value);
SASSERT(0 <= lo && lo <= val && val <= hi && hi <= mod_value);
bool full = (lo == 0 && hi == mod_value);
if (hi == mod_value)
hi = 0;
@ -763,10 +891,8 @@ namespace polysat {
}
if (e == n) {
SASSERT_EQ(e, e0);
// VERIFY(false);
return false;
}
if (e == e0) {
out_lo = n->interval.lo_val();
if (!n->interval.lo().is_val())