3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-06-08 18:10:57 +00:00

Implement proper bounds checking for small integers based on platform pointer size

Co-authored-by: nunoplopes <2998477+nunoplopes@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2026-02-09 15:38:31 +00:00
parent 6e545ac56a
commit cd492c3e9c
2 changed files with 393 additions and 283 deletions

View file

@ -394,9 +394,9 @@ void mpz_manager<SYNCH>::set(mpz_cell& src, mpz & a, int sign, unsigned sz) {
} }
unsigned d = src.m_digits[0]; unsigned d = src.m_digits[0];
if (i == 1 && d <= INT_MAX) { if (i == 1 && d <= static_cast<unsigned>(mpz::SMALL_INT_MAX)) {
// src fits is a fixnum // src fits in small integer range
a.set(sign < 0 ? -static_cast<int>(d) : static_cast<int>(d)); a.set64(sign < 0 ? -static_cast<int64_t>(d) : static_cast<int64_t>(d));
return; return;
} }
@ -704,22 +704,20 @@ mpz mpz_manager<SYNCH>::mod2k(mpz const & a, unsigned k) {
template<bool SYNCH> template<bool SYNCH>
void mpz_manager<SYNCH>::neg(mpz & a) { void mpz_manager<SYNCH>::neg(mpz & a) {
STRACE(mpz, tout << "[mpz] 0 - " << to_string(a) << " == ";); STRACE(mpz, tout << "[mpz] 0 - " << to_string(a) << " == ";);
if (is_small(a) && a.value() == INT_MIN) { if (is_small(a)) {
// neg(INT_MIN) is not a small int int64_t v = a.value64();
set_big_i64(a, - static_cast<long long>(INT_MIN)); if (v == mpz::SMALL_INT_MIN) {
// neg(SMALL_INT_MIN) overflows small range
set_big_i64(a, -v);
return; return;
} }
#ifndef _MP_GMP a.set64(-v);
if (is_small(a)) {
a.set(-a.value());
} }
#ifndef _MP_GMP
else { else {
a.set_sign(-a.sign()); a.set_sign(-a.sign());
} }
#else #else
if (is_small(a)) {
a.set(-a.value());
}
else { else {
mpz_neg(*a.ptr(), *a.ptr()); mpz_neg(*a.ptr(), *a.ptr());
} }
@ -730,14 +728,14 @@ void mpz_manager<SYNCH>::neg(mpz & a) {
template<bool SYNCH> template<bool SYNCH>
void mpz_manager<SYNCH>::abs(mpz & a) { void mpz_manager<SYNCH>::abs(mpz & a) {
if (is_small(a)) { if (is_small(a)) {
int v = a.value(); int64_t v = a.value64();
if (v < 0) { if (v < 0) {
if (v == INT_MIN) { if (v == mpz::SMALL_INT_MIN) {
// abs(INT_MIN) is not a small int // abs(SMALL_INT_MIN) overflows small range
set_big_i64(a, - static_cast<long long>(INT_MIN)); set_big_i64(a, -v);
} }
else else
a.set(-v); a.set64(-v);
} }
} }
else { else {
@ -943,15 +941,19 @@ template<bool SYNCH>
void mpz_manager<SYNCH>::gcd(mpz const & a, mpz const & b, mpz & c) { void mpz_manager<SYNCH>::gcd(mpz const & a, mpz const & b, mpz & c) {
static_assert(sizeof(int) == sizeof(int), "size mismatch"); static_assert(sizeof(int) == sizeof(int), "size mismatch");
static_assert(sizeof(mpz) <= 16, "mpz size overflow"); static_assert(sizeof(mpz) <= 16, "mpz size overflow");
if (is_small(a) && is_small(b) && a.value() != INT_MIN && b.value() != INT_MIN) { if (is_small(a) && is_small(b)) {
int _a = a.value(); int64_t _a = a.value64();
int _b = b.value(); int64_t _b = b.value64();
// Check if absolute values fit in uint64 (they always do for small integers)
// and won't overflow when negating
if (_a != mpz::SMALL_INT_MIN && _b != mpz::SMALL_INT_MIN) {
if (_a < 0) _a = -_a; if (_a < 0) _a = -_a;
if (_b < 0) _b = -_b; if (_b < 0) _b = -_b;
unsigned r = u_gcd(_a, _b); uint64_t r = u64_gcd(static_cast<uint64_t>(_a), static_cast<uint64_t>(_b));
set(c, r); set(c, r);
return;
}
} }
else {
#ifdef _MP_GMP #ifdef _MP_GMP
ensure_mpz_t a1(a), b1(b); ensure_mpz_t a1(a), b1(b);
mk_big(c); mk_big(c);
@ -1197,7 +1199,6 @@ void mpz_manager<SYNCH>::gcd(mpz const & a, mpz const & b, mpz & c) {
del(a1); del(b1); del(r); del(t); del(tmp); del(a1); del(b1); del(r); del(t); del(tmp);
#endif // LEHMER_GCD #endif // LEHMER_GCD
} }
}
template<bool SYNCH> template<bool SYNCH>
unsigned mpz_manager<SYNCH>::size_info(mpz const & a) { unsigned mpz_manager<SYNCH>::size_info(mpz const & a) {
@ -2029,26 +2030,52 @@ void mpz_manager<SYNCH>::ensure_capacity(mpz & a, unsigned capacity) {
capacity = m_init_cell_capacity; capacity = m_init_cell_capacity;
if (is_small(a)) { if (is_small(a)) {
int val = a.value(); int64_t val = a.value64();
allocate_if_needed(a, capacity); allocate_if_needed(a, capacity);
SASSERT(a.ptr()->m_capacity >= capacity); SASSERT(a.ptr()->m_capacity >= capacity);
if (val == INT_MIN) { // Check if this is SMALL_INT_MIN which needs special handling
unsigned intmin_sz = m_int_min.ptr()->m_size; if (val == mpz::SMALL_INT_MIN) {
for (unsigned i = 0; i < intmin_sz; ++i) // For 32-bit: SMALL_INT_MIN = -2^30, so -val = 2^30 fits in unsigned
a.ptr()->m_digits[i] = m_int_min.ptr()->m_digits[i]; // For 64-bit: SMALL_INT_MIN = -2^62, so -val = 2^62 fits in uint64_t
a.set_sign(-1); uint64_t abs_val = static_cast<uint64_t>(-val);
a.ptr()->m_size = m_int_min.ptr()->m_size; if (sizeof(digit_t) == sizeof(uint64_t)) {
} // 64-bit machine
else if (val < 0) { a.ptr()->m_digits[0] = static_cast<digit_t>(abs_val);
a.ptr()->m_digits[0] = -val;
a.set_sign(-1);
a.ptr()->m_size = 1; a.ptr()->m_size = 1;
} }
else { else {
a.ptr()->m_digits[0] = val; // 32-bit machine
a.set_sign(1); a.ptr()->m_digits[0] = static_cast<unsigned>(abs_val);
a.ptr()->m_digits[1] = static_cast<unsigned>(abs_val >> 32);
a.ptr()->m_size = (abs_val >> 32) == 0 ? 1 : 2;
}
a.set_sign(-1);
}
else if (val < 0) {
uint64_t abs_val = static_cast<uint64_t>(-val);
if (sizeof(digit_t) == sizeof(uint64_t)) {
a.ptr()->m_digits[0] = static_cast<digit_t>(abs_val);
a.ptr()->m_size = 1; a.ptr()->m_size = 1;
} }
else {
a.ptr()->m_digits[0] = static_cast<unsigned>(abs_val);
a.ptr()->m_digits[1] = static_cast<unsigned>(abs_val >> 32);
a.ptr()->m_size = (abs_val >> 32) == 0 ? 1 : 2;
}
a.set_sign(-1);
}
else {
if (sizeof(digit_t) == sizeof(uint64_t)) {
a.ptr()->m_digits[0] = static_cast<digit_t>(val);
a.ptr()->m_size = 1;
}
else {
a.ptr()->m_digits[0] = static_cast<unsigned>(val);
a.ptr()->m_digits[1] = static_cast<unsigned>(val >> 32);
a.ptr()->m_size = (val >> 32) == 0 ? 1 : 2;
}
a.set_sign(1);
}
} }
else if (a.ptr()->m_capacity < capacity) { else if (a.ptr()->m_capacity < capacity) {
mpz_cell * new_cell = allocate(capacity); mpz_cell * new_cell = allocate(capacity);
@ -2079,10 +2106,10 @@ void mpz_manager<SYNCH>::normalize(mpz & a) {
return; return;
} }
if (i == 1 && ds[0] <= INT_MAX) { if (i == 1 && ds[0] <= static_cast<unsigned>(mpz::SMALL_INT_MAX)) {
// a is small // a fits in small integer range
int val = a.sign() < 0 ? -static_cast<int>(ds[0]) : static_cast<int>(ds[0]); int64_t val = a.sign() < 0 ? -static_cast<int64_t>(ds[0]) : static_cast<int64_t>(ds[0]);
a.set(val); a.set64(val);
return; return;
} }
// adjust size // adjust size

View file

@ -92,6 +92,32 @@ private:
static constexpr uintptr_t OWNER_BIT = 0x4; static constexpr uintptr_t OWNER_BIT = 0x4;
static constexpr uintptr_t MPZ_PTR_MASK = ~static_cast<uintptr_t>(0x7); static constexpr uintptr_t MPZ_PTR_MASK = ~static_cast<uintptr_t>(0x7);
// Small integers are stored shifted left by 1, so we have (sizeof(uintptr_t)*8 - 1) bits available
// This gives us:
// - On 32-bit platforms: 31 bits, range [-2^30, 2^30-1]
// - On 64-bit platforms: 63 bits, range [-2^62, 2^62-1]
static constexpr int SMALL_BITS = sizeof(uintptr_t) * 8 - 1;
// Maximum and minimum values that can be stored as small integers
static constexpr int64_t SMALL_INT_MAX = (static_cast<int64_t>(1) << (SMALL_BITS - 1)) - 1;
static constexpr int64_t SMALL_INT_MIN = -(static_cast<int64_t>(1) << (SMALL_BITS - 1));
static bool fits_in_small(int64_t v) {
return v >= SMALL_INT_MIN && v <= SMALL_INT_MAX;
}
static bool fits_in_small(uint64_t v) {
return v <= static_cast<uint64_t>(SMALL_INT_MAX);
}
static bool fits_in_small(int v) {
return fits_in_small(static_cast<int64_t>(v));
}
static bool fits_in_small(unsigned int v) {
return fits_in_small(static_cast<uint64_t>(v));
}
mpz_type * ptr() const { mpz_type * ptr() const {
SASSERT(!is_small()); SASSERT(!is_small());
return reinterpret_cast<mpz_type*>(m_value & MPZ_PTR_MASK); return reinterpret_cast<mpz_type*>(m_value & MPZ_PTR_MASK);
@ -135,7 +161,11 @@ protected:
friend class mpbq_manager; friend class mpbq_manager;
friend class mpz_stack; friend class mpz_stack;
public: public:
mpz(int v = 0) noexcept : m_value(static_cast<uintptr_t>(static_cast<intptr_t>(v)) << 1) {} mpz(int v = 0) noexcept : m_value(static_cast<uintptr_t>(static_cast<intptr_t>(v)) << 1) {
// On 32-bit platforms, INT_MIN doesn't fit in 31 bits. This constructor should only be used
// with values that fit, or the caller should use set_big_i64.
SASSERT(fits_in_small(v));
}
mpz(mpz_type* ptr) noexcept { mpz(mpz_type* ptr) noexcept {
SASSERT(ptr); SASSERT(ptr);
@ -157,6 +187,11 @@ public:
m_value = static_cast<uintptr_t>(static_cast<intptr_t>(v)) << 1; m_value = static_cast<uintptr_t>(static_cast<intptr_t>(v)) << 1;
} }
void set64(int64_t v) {
SASSERT(fits_in_small(v));
m_value = static_cast<uintptr_t>(static_cast<intptr_t>(v)) << 1;
}
void swap(mpz & other) noexcept { void swap(mpz & other) noexcept {
std::swap(m_value, other.m_value); std::swap(m_value, other.m_value);
} }
@ -168,9 +203,16 @@ public:
inline int value() const { inline int value() const {
SASSERT(is_small()); SASSERT(is_small());
// Decode small integer: shift right by 1 (arithmetic shift to preserve sign) // Decode small integer: shift right by 1 (arithmetic shift to preserve sign)
// Note: On 64-bit platforms, this may truncate if the value doesn't fit in int
return static_cast<int>(static_cast<intptr_t>(m_value) >> 1); return static_cast<int>(static_cast<intptr_t>(m_value) >> 1);
} }
inline int64_t value64() const {
SASSERT(is_small());
// Decode small integer: shift right by 1 (arithmetic shift to preserve sign)
return static_cast<int64_t>(static_cast<intptr_t>(m_value) >> 1);
}
inline int sign() const { inline int sign() const {
SASSERT(!is_small()); SASSERT(!is_small());
return get_sign(); return get_sign();
@ -291,13 +333,16 @@ class mpz_manager {
mpz m_two64; mpz m_two64;
static int64_t i64(mpz const & a) { return static_cast<int64_t>(a.value()); } static int64_t i64(mpz const & a) { return a.value64(); }
void set_big_i64(mpz & c, int64_t v); void set_big_i64(mpz & c, int64_t v);
void set_i64(mpz & c, int64_t v) { void set_i64(mpz & c, int64_t v) {
if (v >= INT_MIN && v <= INT_MAX) { if (mpz::fits_in_small(v)) {
c.set(static_cast<int>(v)); if (!is_small(c)) {
deallocate(c);
}
c.set64(v);
} }
else { else {
set_big_i64(c, v); set_big_i64(c, v);
@ -363,20 +408,44 @@ class mpz_manager {
void get_sign_cell(mpz const & a, int & sign, mpz_cell * & cell, mpz_cell* reserve) { void get_sign_cell(mpz const & a, int & sign, mpz_cell * & cell, mpz_cell* reserve) {
if (is_small(a)) { if (is_small(a)) {
if (a.value() == INT_MIN) { int64_t val = a.value64();
cell = reserve;
if (val == mpz::SMALL_INT_MIN) {
sign = -1; sign = -1;
cell = m_int_min.ptr(); uint64_t abs_val = static_cast<uint64_t>(-val);
if (sizeof(digit_t) == sizeof(uint64_t)) {
cell->m_size = 1;
cell->m_digits[0] = static_cast<digit_t>(abs_val);
} }
else { else {
cell = reserve; cell->m_digits[0] = static_cast<unsigned>(abs_val);
cell->m_size = 1; cell->m_digits[1] = static_cast<unsigned>(abs_val >> 32);
if (a.value() < 0) { cell->m_size = (abs_val >> 32) == 0 ? 1 : 2;
}
}
else if (val < 0) {
sign = -1; sign = -1;
cell->m_digits[0] = -a.value(); uint64_t abs_val = static_cast<uint64_t>(-val);
if (sizeof(digit_t) == sizeof(uint64_t)) {
cell->m_size = 1;
cell->m_digits[0] = static_cast<digit_t>(abs_val);
}
else {
cell->m_digits[0] = static_cast<unsigned>(abs_val);
cell->m_digits[1] = static_cast<unsigned>(abs_val >> 32);
cell->m_size = (abs_val >> 32) == 0 ? 1 : 2;
}
} }
else { else {
sign = 1; sign = 1;
cell->m_digits[0] = a.value(); if (sizeof(digit_t) == sizeof(uint64_t)) {
cell->m_size = 1;
cell->m_digits[0] = static_cast<digit_t>(val);
}
else {
cell->m_digits[0] = static_cast<unsigned>(val);
cell->m_digits[1] = static_cast<unsigned>(val >> 32);
cell->m_size = (val >> 32) == 0 ? 1 : 2;
} }
} }
} }
@ -593,17 +662,28 @@ public:
} }
void set(mpz & a, int val) { void set(mpz & a, int val) {
// On 32-bit platforms, int can be outside small range
if (mpz::fits_in_small(val)) {
if (!is_small(a)) { if (!is_small(a)) {
deallocate(a); deallocate(a);
} }
a.set(val); a.set(val);
} }
else {
set_i64(a, val);
}
}
void set(mpz & a, unsigned val) { void set(mpz & a, unsigned val) {
if (val <= INT_MAX) if (mpz::fits_in_small(val)) {
set(a, static_cast<int>(val)); if (!is_small(a)) {
else deallocate(a);
set(a, static_cast<int64_t>(static_cast<uint64_t>(val))); }
a.set(static_cast<int>(val));
}
else {
set_i64(a, static_cast<int64_t>(val));
}
} }
void set(mpz & a, char const * val); void set(mpz & a, char const * val);
@ -613,8 +693,11 @@ public:
} }
void set(mpz & a, uint64_t val) { void set(mpz & a, uint64_t val) {
if (val < INT_MAX) { if (mpz::fits_in_small(val)) {
a.set(static_cast<int>(val)); if (!is_small(a)) {
deallocate(a);
}
a.set64(static_cast<int64_t>(val));
} }
else { else {
set_big_ui64(a, val); set_big_ui64(a, val);