3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-24 01:25:31 +00:00
This commit is contained in:
Jakob Rath 2022-12-21 16:05:27 +01:00
parent ec158845fc
commit d51031f19b

View file

@ -359,20 +359,6 @@ namespace polysat {
LOG("refine-equal-lin for src: " << lit_pp(s, e->src));
LOG("forbidden interval v" << v << " " << num_pp(s, v, val) << " " << num_pp(s, v, e->coeff, true) << " * " << e->interval);
if (e->interval.lo_val().is_one() && e->interval.hi_val().is_zero() && e->coeff.is_odd()) {
rational lo(1);
rational hi(0);
LOG("refine-equal-lin: " << " [" << lo << ", " << hi << "[");
entry* ne = alloc_entry();
ne->refined = e;
ne->src = e->src;
ne->side_cond = e->side_cond;
ne->coeff = 1;
ne->interval = eval_interval::proper(m.mk_val(lo), lo, m.mk_val(hi), hi);
intersect(v, ne);
return false;
}
if (mod(e->interval.hi_val() + 1, mod_value) == e->interval.lo_val()) {
// We have an equation: a * v == b
rational const a = e->coeff;
@ -380,6 +366,8 @@ namespace polysat {
LOG("refine-equal-lin: equation detected: " << dd::val_pp(m, a, true) << " * v" << v << " == " << dd::val_pp(m, b, false));
unsigned const parity_a = get_parity(a, N);
unsigned const parity_b = get_parity(b, N);
// LOG("a " << a << " parity " << parity_a);
// LOG("b " << b << " parity " << parity_b);
if (parity_a > parity_b) {
// No solution
LOG("refined: no solution due to parity");
@ -392,20 +380,42 @@ namespace polysat {
intersect(v, ne);
return false;
}
if (parity_a == 0) {
// "fast path" for odd a
rational a_inv;
VERIFY(a.mult_inverse(N, a_inv));
rational const hi = mod(a_inv * b, mod_value);
rational const lo = mod(hi + 1, mod_value);
LOG("refined to [" << num_pp(s, v, lo) << ", " << num_pp(s, v, hi) << "[");
SASSERT_EQ(mod(a * hi, mod_value), b); // hi is the solution
entry* ne = alloc_entry();
ne->refined = e;
ne->src = e->src;
ne->side_cond = e->side_cond;
ne->coeff = 1;
ne->interval = eval_interval::proper(m.mk_val(lo), lo, m.mk_val(hi), hi);
SASSERT(ne->interval.currently_contains(val));
intersect(v, ne);
return false;
}
// 2^k * v == a_inv * b
// 2^k solutions because only the lower N-k bits of v are fixed.
//
// Smallest solution is v0 == a_inv * (b >> k)
// Solutions are of the form v_i = v0 + 2^(N-k) * i for i in { 0, 1, ..., 2^k - 1 }.
// Forbidden intervals: [v_i + 1; v_{i+1}[ == [ v_i + 1; v_i + 2^(N-k) [
// We need the interval that covers val:
// v_i + 1 <= val < v_i + 2^(N-k)
//
// TODO: create one interval for v[N-k:] instead of 2^k intervals for v.
unsigned const k = parity_a;
rational const a_inv = a.pseudo_inverse(N);
unsigned const N_minus_k = N - parity_a;
unsigned const N_minus_k = N - k;
rational const two_to_N_minus_k = rational::power_of_two(N_minus_k);
rational const v0 = mod(a_inv * b, two_to_N_minus_k);
rational const v0 = mod(a_inv * machine_div2k(b, k), two_to_N_minus_k);
SASSERT(mod(val, two_to_N_minus_k) != v0); // val is not a solution
rational const vi = v0 + clear_lower_bits(mod(val - v0, mod_value), N_minus_k);
rational const lo = vi + 1;
rational const lo = mod(vi + 1, mod_value);
rational const hi = mod(vi + two_to_N_minus_k, mod_value);
LOG("refined to [" << num_pp(s, v, lo) << ", " << num_pp(s, v, hi) << "[");
SASSERT_EQ(mod(a * (lo - 1), mod_value), b); // lo-1 is a solution