diff --git a/src/sat/smt/polysat/constraints.cpp b/src/sat/smt/polysat/constraints.cpp index 8217f68a7..fe577ef8c 100644 --- a/src/sat/smt/polysat/constraints.cpp +++ b/src/sat/smt/polysat/constraints.cpp @@ -92,6 +92,22 @@ namespace polysat { return ~parity_at_least(p, k + 1); } + signed_constraint constraints::msb_ge(pdd const& p, unsigned k) { + if (k == 0) + return uge(p, 0); + if (k > p.manager().power_of_2()) + return ult(p, 0); + return uge(p, rational::power_of_two(k - 1)); + } + + signed_constraint constraints::msb_le(pdd const& p, unsigned k) { + if (k == 0) + return eq(p); + if (k >= p.manager().power_of_2()) + return uge(p, 0); + return ult(p, rational::power_of_two(k)); + } + // 2^{N-i-1}* p >= 2^{N-1} signed_constraint constraints::bit(pdd const& p, unsigned i) { unsigned N = p.manager().power_of_2(); diff --git a/src/sat/smt/polysat/constraints.h b/src/sat/smt/polysat/constraints.h index 7e54126eb..9ccead6b8 100644 --- a/src/sat/smt/polysat/constraints.h +++ b/src/sat/smt/polysat/constraints.h @@ -171,6 +171,12 @@ namespace polysat { signed_constraint parity_at_least(pdd const& p, unsigned k); signed_constraint parity_at_most(pdd const& p, unsigned k); + // most significant bit-position counting least significant bit as position 1, most as N. + // msb(x) >= k <=> x >= 2^{k-1}, for N >= k > 0 + // msb(x) <= k <=> x < 2^{k}, for 0 <= k < N + signed_constraint msb_ge(pdd const& x, unsigned k); + signed_constraint msb_le(pdd const& x, unsigned k); + signed_constraint lshr(pdd const& a, pdd const& b, pdd const& r); signed_constraint ashr(pdd const& a, pdd const& b, pdd const& r); signed_constraint shl(pdd const& a, pdd const& b, pdd const& r); diff --git a/src/sat/smt/polysat/saturation.cpp b/src/sat/smt/polysat/saturation.cpp index 8f1c6f0c5..496f93b8a 100644 --- a/src/sat/smt/polysat/saturation.cpp +++ b/src/sat/smt/polysat/saturation.cpp @@ -209,12 +209,12 @@ namespace polysat { /** * Expand the following axioms: * Ovfl(x, y) & x <= y => y >= 2^{(N + 1) div 2} - * Ovfl(x, y) & msb(x) <= k => msb(y) >= N - k - 1 - * Ovfl(x, y) & msb(x) <= k & msb(y) <= N - k - 1 => 0x * 0y >= 2^{N-1} + * Ovfl(x, y) & msb(x) <= k => msb(y) >= N - k + 1 + * Ovfl(x, y) & msb(x) <= k & msb(y) <= N - k + 1 => 0x * 0y >= 2^N * * ~Ovfl(x, y) & x <= y => x < 2^{(N + 1) div 2} - * ~Ovfl(x,y) & msb(x) >= k => msb(y) <= N - k - 1 - * ~Ovfl(x,y) & msb(x) >= k & msb(y) >= N - k - 1 => 0x * 0y < 2^{N-1} + * ~Ovfl(x,y) & msb(x) >= k => msb(y) <= N - k + 1 + * ~Ovfl(x,y) & msb(x) >= k & msb(y) >= N - k + 1 => 0x * 0y < 2^N */ void saturation::try_umul_blast(umul_ovfl const& sc) { auto x = sc.p(); @@ -237,48 +237,27 @@ namespace polysat { } // Keep in mind that - // num-bits(0) = 1 + // num-bits(0) = 1 - handled as special case // num-bits(1) = 1 // num-bits(2) = 2 // num-bits(4) = 3 - // msb(0) = 0 - // msb(1) = 0 - // msb(2) = 1 - // msb(3) = 1 - // msb(2^N - 1) = N-1 - // msb(x) >= k <=> x >= 2^k, for k > 0 - // msb(x) <= k <=> x < 2^{k+1}, for k + 1 < N - - auto msb_ge = [&](pdd const& x, unsigned k) { - SASSERT(k > 0); - return C.uge(x, rational::power_of_two(k)); - }; - - auto msb_le = [&](pdd const& x, unsigned k) { - SASSERT(k + 1 < N); - return C.ult(x, rational::power_of_two(k + 1)); - }; if (sc.sign()) { // Case ~Ovfl(x,y) is asserted by current assignment x * y is overflow SASSERT(bx > 1 && by > 1); SASSERT(bx + by >= N + 1); - auto k = bx - 1; if (bx > (N + 1) / 2) { add_clause("~Ovfl(x, y) & x <= y => x < 2^{(N + 1) div 2}", - { d, ~C.ule(x, y), C.ult(x, rational::power_of_two((N + 1) / 2)) }, - true); + { d, ~C.ule(x, y), C.ult(x, rational::power_of_two((N + 1) / 2)) }, true); } else if (bx + by > N + 1) - add_clause("~Ovfl(x, y) & msb(x) >= k => msb(y) <= N - k - 1", - {d, ~msb_ge(x, k), msb_le(y, N - k - 1)}, - true); + add_clause("~Ovfl(x, y) & msb(x) >= k => msb(y) <= N - k + 1", + {d, ~C.msb_ge(x, bx), C.msb_le(y, N - bx + 1)}, true); else { auto x1 = c.mk_zero_extend(1, x); auto y1 = c.mk_zero_extend(1, y); - add_clause("~Ovfl(x, y) => 0x * 0y < 2^{N}", - { d, C.ult(x1 * y1, rational::power_of_two(N)) }, - true); + add_clause("~Ovfl(x, y) => 0x * 0y < 2^N", + { d, C.ult(x1 * y1, rational::power_of_two(N)) }, true); } } else { @@ -288,22 +267,18 @@ namespace polysat { } else if (bx < (N + 1) / 2) { add_clause("Ovfl(x, y) & x <= y => y >= 2^{(N + 1) div 2}", - { d, ~C.ule(x, y), C.uge(x, rational::power_of_two((N + 1) / 2)) }, - true); + { d, ~C.ule(x, y), C.uge(x, rational::power_of_two((N + 1) / 2)) }, true); } else if (bx + by < N + 1) { SASSERT(bx <= by); - auto k = bx - 1; - add_clause("Ovfl(x, y) & msb(x) <= k => msb(y) >= N - k - 1", - { d, ~msb_le(x, k), msb_ge(y, N - k - 1) }, true); + add_clause("Ovfl(x, y) & msb(x) <= k => msb(y) >= N - k", + { d, ~C.msb_le(x, bx), C.msb_ge(y, N - bx + 1) }, true); } else { - auto k = bx - 1; auto x1 = c.mk_zero_extend(1, x); auto y1 = c.mk_zero_extend(1, y); - add_clause("Ovfl(x, y) & msb(x) <= k & msb(y) <= N - k - 1 => 0x * 0y >= 2 ^ {N - 1}", - { d, ~msb_le(x, k), ~msb_le(y, N - k - 1), C.uge(x1 * y1, rational::power_of_two(N - 1)) }, - true); + add_clause("Ovfl(x, y) & msb(x) <= k & msb(y) <= N - k + 1 => 0x * 0y >= 2 ^ N", + { d, ~C.msb_le(x, bx), ~C.msb_le(y, N - bx + 1), C.uge(x1 * y1, rational::power_of_two(N)) }, true); } } }