From cd492c3e9c39876ec3f85b0538a508cc7f7cb4e6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Feb 2026 15:38:31 +0000 Subject: [PATCH] Implement proper bounds checking for small integers based on platform pointer size Co-authored-by: nunoplopes <2998477+nunoplopes@users.noreply.github.com> --- src/util/mpz.cpp | 545 +++++++++++++++++++++++++---------------------- src/util/mpz.h | 131 +++++++++--- 2 files changed, 393 insertions(+), 283 deletions(-) diff --git a/src/util/mpz.cpp b/src/util/mpz.cpp index d74ce46ee..136a95132 100644 --- a/src/util/mpz.cpp +++ b/src/util/mpz.cpp @@ -394,9 +394,9 @@ void mpz_manager::set(mpz_cell& src, mpz & a, int sign, unsigned sz) { } unsigned d = src.m_digits[0]; - if (i == 1 && d <= INT_MAX) { - // src fits is a fixnum - a.set(sign < 0 ? -static_cast(d) : static_cast(d)); + if (i == 1 && d <= static_cast(mpz::SMALL_INT_MAX)) { + // src fits in small integer range + a.set64(sign < 0 ? -static_cast(d) : static_cast(d)); return; } @@ -704,22 +704,20 @@ mpz mpz_manager::mod2k(mpz const & a, unsigned k) { template void mpz_manager::neg(mpz & a) { STRACE(mpz, tout << "[mpz] 0 - " << to_string(a) << " == ";); - if (is_small(a) && a.value() == INT_MIN) { - // neg(INT_MIN) is not a small int - set_big_i64(a, - static_cast(INT_MIN)); - return; + if (is_small(a)) { + int64_t v = a.value64(); + if (v == mpz::SMALL_INT_MIN) { + // neg(SMALL_INT_MIN) overflows small range + set_big_i64(a, -v); + return; + } + a.set64(-v); } #ifndef _MP_GMP - if (is_small(a)) { - a.set(-a.value()); - } else { a.set_sign(-a.sign()); } #else - if (is_small(a)) { - a.set(-a.value()); - } else { mpz_neg(*a.ptr(), *a.ptr()); } @@ -730,14 +728,14 @@ void mpz_manager::neg(mpz & a) { template void mpz_manager::abs(mpz & a) { if (is_small(a)) { - int v = a.value(); + int64_t v = a.value64(); if (v < 0) { - if (v == INT_MIN) { - // abs(INT_MIN) is not a small int - set_big_i64(a, - static_cast(INT_MIN)); + if (v == mpz::SMALL_INT_MIN) { + // abs(SMALL_INT_MIN) overflows small range + set_big_i64(a, -v); } else - a.set(-v); + a.set64(-v); } } else { @@ -943,260 +941,263 @@ template void mpz_manager::gcd(mpz const & a, mpz const & b, mpz & c) { static_assert(sizeof(int) == sizeof(int), "size mismatch"); static_assert(sizeof(mpz) <= 16, "mpz size overflow"); - if (is_small(a) && is_small(b) && a.value() != INT_MIN && b.value() != INT_MIN) { - int _a = a.value(); - int _b = b.value(); - if (_a < 0) _a = -_a; - if (_b < 0) _b = -_b; - unsigned r = u_gcd(_a, _b); - set(c, r); + if (is_small(a) && is_small(b)) { + int64_t _a = a.value64(); + int64_t _b = b.value64(); + // Check if absolute values fit in uint64 (they always do for small integers) + // and won't overflow when negating + if (_a != mpz::SMALL_INT_MIN && _b != mpz::SMALL_INT_MIN) { + if (_a < 0) _a = -_a; + if (_b < 0) _b = -_b; + uint64_t r = u64_gcd(static_cast(_a), static_cast(_b)); + set(c, r); + return; + } } - else { #ifdef _MP_GMP - ensure_mpz_t a1(a), b1(b); - mk_big(c); - mpz_gcd(*c.ptr(), a1(), b1()); - return; + ensure_mpz_t a1(a), b1(b); + mk_big(c); + mpz_gcd(*c.ptr(), a1(), b1()); + return; #endif - if (is_zero(a)) { - set(c, b); - abs(c); - return; - } - if (is_zero(b)) { - set(c, a); - abs(c); - return; - } + if (is_zero(a)) { + set(c, b); + abs(c); + return; + } + if (is_zero(b)) { + set(c, a); + abs(c); + return; + } #ifdef BINARY_GCD - // Binary GCD for big numbers - // - It doesn't use division - // - The initial experiments, don't show any performance improvement - // - It only works with _MP_INTERNAL - mpz u, v, diff; - set(u, a); - set(v, b); - abs(u); - abs(v); + // Binary GCD for big numbers + // - It doesn't use division + // - The initial experiments, don't show any performance improvement + // - It only works with _MP_INTERNAL + mpz u, v, diff; + set(u, a); + set(v, b); + abs(u); + abs(v); - unsigned k_u = power_of_two_multiple(u); - unsigned k_v = power_of_two_multiple(v); - unsigned k = k_u < k_v ? k_u : k_v; + unsigned k_u = power_of_two_multiple(u); + unsigned k_v = power_of_two_multiple(v); + unsigned k = k_u < k_v ? k_u : k_v; - machine_div2k(u, k_u); + machine_div2k(u, k_u); - while (true) { - machine_div2k(v, k_v); + while (true) { + machine_div2k(v, k_v); - if (lt(u, v)) { - sub(v, u, v); - } - else { - sub(u, v, diff); - swap(u, v); - swap(v, diff); - } - - if (is_zero(v) || is_one(v)) - break; - - // reset least significant bit - if (is_small(v)) - v.set(v.value() & ~1); - else - v.ptr()->m_digits[0] &= ~static_cast(1); - k_v = power_of_two_multiple(v); + if (lt(u, v)) { + sub(v, u, v); + } + else { + sub(u, v, diff); + swap(u, v); + swap(v, diff); } + + if (is_zero(v) || is_one(v)) + break; + + // reset least significant bit + if (is_small(v)) + v.set(v.value() & ~1); + else + v.ptr()->m_digits[0] &= ~static_cast(1); + k_v = power_of_two_multiple(v); + } - mul2k(u, k, c); - del(u); del(v); del(diff); + mul2k(u, k, c); + del(u); del(v); del(diff); #endif // BINARY_GCD #ifdef EUCLID_GCD - mpz tmp1; - mpz tmp2; - mpz aux; - set(tmp1, a); - set(tmp2, b); - abs(tmp1); - abs(tmp2); - if (lt(tmp1, tmp2)) - swap(tmp1, tmp2); - if (is_zero(tmp2)) { - swap(c, tmp1); - } - else { - while (true) { - if (is_uint64(tmp1) && is_uint64(tmp2)) { - set(c, u64_gcd(get_uint64(tmp1), get_uint64(tmp2))); - break; - } - rem(tmp1, tmp2, aux); - if (is_zero(aux)) { - swap(c, tmp2); - break; - } - swap(tmp1, tmp2); - swap(tmp2, aux); + mpz tmp1; + mpz tmp2; + mpz aux; + set(tmp1, a); + set(tmp2, b); + abs(tmp1); + abs(tmp2); + if (lt(tmp1, tmp2)) + swap(tmp1, tmp2); + if (is_zero(tmp2)) { + swap(c, tmp1); + } + else { + while (true) { + if (is_uint64(tmp1) && is_uint64(tmp2)) { + set(c, u64_gcd(get_uint64(tmp1), get_uint64(tmp2))); + break; } + rem(tmp1, tmp2, aux); + if (is_zero(aux)) { + swap(c, tmp2); + break; + } + swap(tmp1, tmp2); + swap(tmp2, aux); } - del(tmp1); del(tmp2); del(aux); + } + del(tmp1); del(tmp2); del(aux); #endif // EUCLID_GCD #ifdef LS_BINARY_GCD - mpz u, v, t, u1, u2; - set(u, a); - set(v, b); - abs(u); - abs(v); - if (lt(u, v)) - swap(u, v); - while (!is_zero(v)) { - // Basic idea: - // compute t = 2^e*v such that t <= u < 2t - // u := min{u - t, 2t - u} - // - // The assignment u := min{u - t, 2t - u} - // can be replaced with u := u - t - // - // Since u and v are positive, we have: - // 2^{log2(u)} <= u < 2^{(log2(u) + 1)} - // 2^{log2(v)} <= v < 2^{(log2(v) + 1)} - // --> - // 2^{log2(v)}*2^{log2(u)-log2(v)} <= v*2^{log2(u)-log2(v)} < 2^{log2(v) + 1}*2^{log2(u)-log2(v)} - // --> - // 2^{log2(u)} <= v*2^{log2(u)-log2(v)} < 2^{log2(u) + 1} - // - // Now, let t be v*2^{log2(u)-log2(v)} - // If t <= u, then we found t - // Otherwise t = t div 2 - unsigned k_u = log2(u); - unsigned k_v = log2(v); - SASSERT(k_v <= k_u); - unsigned e = k_u - k_v; - mul2k(v, e, t); - sub(u, t, u1); - if (is_neg(u1)) { - // t is too big - machine_div2k(t, 1); - // Now, u1 contains u - 2t - neg(u1); - // Now, u1 contains 2t - u - sub(u, t, u2); // u2 := u - t - } - else { - // u1 contains u - t - mul2k(t, 1); - sub(t, u, u2); - // u2 contains 2t - u - } - SASSERT(is_nonneg(u1)); - SASSERT(is_nonneg(u2)); - if (lt(u1, u2)) - swap(u, u1); - else - swap(u, u2); - if (lt(u, v)) - swap(u,v); + mpz u, v, t, u1, u2; + set(u, a); + set(v, b); + abs(u); + abs(v); + if (lt(u, v)) + swap(u, v); + while (!is_zero(v)) { + // Basic idea: + // compute t = 2^e*v such that t <= u < 2t + // u := min{u - t, 2t - u} + // + // The assignment u := min{u - t, 2t - u} + // can be replaced with u := u - t + // + // Since u and v are positive, we have: + // 2^{log2(u)} <= u < 2^{(log2(u) + 1)} + // 2^{log2(v)} <= v < 2^{(log2(v) + 1)} + // --> + // 2^{log2(v)}*2^{log2(u)-log2(v)} <= v*2^{log2(u)-log2(v)} < 2^{log2(v) + 1}*2^{log2(u)-log2(v)} + // --> + // 2^{log2(u)} <= v*2^{log2(u)-log2(v)} < 2^{log2(u) + 1} + // + // Now, let t be v*2^{log2(u)-log2(v)} + // If t <= u, then we found t + // Otherwise t = t div 2 + unsigned k_u = log2(u); + unsigned k_v = log2(v); + SASSERT(k_v <= k_u); + unsigned e = k_u - k_v; + mul2k(v, e, t); + sub(u, t, u1); + if (is_neg(u1)) { + // t is too big + machine_div2k(t, 1); + // Now, u1 contains u - 2t + neg(u1); + // Now, u1 contains 2t - u + sub(u, t, u2); // u2 := u - t } - swap(u, c); - del(u); del(v); del(t); del(u1); del(u2); + else { + // u1 contains u - t + mul2k(t, 1); + sub(t, u, u2); + // u2 contains 2t - u + } + SASSERT(is_nonneg(u1)); + SASSERT(is_nonneg(u2)); + if (lt(u1, u2)) + swap(u, u1); + else + swap(u, u2); + if (lt(u, v)) + swap(u,v); + } + swap(u, c); + del(u); del(v); del(t); del(u1); del(u2); #endif // LS_BINARY_GCD #ifdef LEHMER_GCD - // For now, it only works if sizeof(digit_t) == sizeof(unsigned) - static_assert(sizeof(digit_t) == sizeof(unsigned), ""); - - int64_t a_hat, b_hat, A, B, C, D, T, q, a_sz, b_sz; - mpz a1, b1, t, r, tmp; - set(a1, a); - set(b1, b); - abs(a1); - abs(b1); - if (lt(a1, b1)) - swap(a1, b1); - while (true) { - SASSERT(ge(a1, b1)); - if (is_small(b1)) { - if (is_small(a1)) { - unsigned r = u_gcd(a1.value(), b1.value()); - set(c, r); - break; - } - else { - while (!is_zero(b1)) { - SASSERT(ge(a1, b1)); - rem(a1, b1, tmp); - swap(a1, b1); - swap(b1, tmp); - } - swap(c, a1); - break; - } - } - SASSERT(!is_small(a1)); - SASSERT(!is_small(b1)); - a_sz = a1.ptr()->m_size; - b_sz = b1.ptr()->m_size; - SASSERT(b_sz <= a_sz); - a_hat = a1.ptr()->m_digits[a_sz - 1]; - b_hat = (b_sz == a_sz) ? b1.ptr()->m_digits[b_sz - 1] : 0; - A = 1; - B = 0; - C = 0; - D = 1; - while (true) { - // Loop invariants - SASSERT(a_hat + A <= static_cast(UINT_MAX) + 1); - SASSERT(a_hat + B < static_cast(UINT_MAX) + 1); - SASSERT(b_hat + C < static_cast(UINT_MAX) + 1); - SASSERT(b_hat + D <= static_cast(UINT_MAX) + 1); - // overflows can't happen since I'm using int64 - if (b_hat + C == 0 || b_hat + D == 0) - break; - q = (a_hat + A)/(b_hat + C); - if (q != (a_hat + B)/(b_hat + D)) - break; - T = A - q*C; - A = C; - C = T; - T = B - q*D; - B = D; - D = T; - T = a_hat - q*b_hat; - a_hat = b_hat; - b_hat = T; - } - SASSERT(ge(a1, b1)); - if (B == 0) { - rem(a1, b1, t); - swap(a1, b1); - swap(b1, t); - SASSERT(ge(a1, b1)); + // For now, it only works if sizeof(digit_t) == sizeof(unsigned) + static_assert(sizeof(digit_t) == sizeof(unsigned), ""); + + int64_t a_hat, b_hat, A, B, C, D, T, q, a_sz, b_sz; + mpz a1, b1, t, r, tmp; + set(a1, a); + set(b1, b); + abs(a1); + abs(b1); + if (lt(a1, b1)) + swap(a1, b1); + while (true) { + SASSERT(ge(a1, b1)); + if (is_small(b1)) { + if (is_small(a1)) { + unsigned r = u_gcd(a1.value(), b1.value()); + set(c, r); + break; } else { - // t <- A*a1 - set(tmp, A); - mul(a1, tmp, t); - // t <- t + B*b1 - set(tmp, B); - addmul(t, tmp, b1, t); - // r <- C*a1 - set(tmp, C); - mul(a1, tmp, r); - // r <- r + D*b1 - set(tmp, D); - addmul(r, tmp, b1, r); - // a <- t - swap(a1, t); - // b <- r - swap(b1, r); - SASSERT(ge(a1, b1)); + while (!is_zero(b1)) { + SASSERT(ge(a1, b1)); + rem(a1, b1, tmp); + swap(a1, b1); + swap(b1, tmp); + } + swap(c, a1); + break; } } - del(a1); del(b1); del(r); del(t); del(tmp); -#endif // LEHMER_GCD + SASSERT(!is_small(a1)); + SASSERT(!is_small(b1)); + a_sz = a1.ptr()->m_size; + b_sz = b1.ptr()->m_size; + SASSERT(b_sz <= a_sz); + a_hat = a1.ptr()->m_digits[a_sz - 1]; + b_hat = (b_sz == a_sz) ? b1.ptr()->m_digits[b_sz - 1] : 0; + A = 1; + B = 0; + C = 0; + D = 1; + while (true) { + // Loop invariants + SASSERT(a_hat + A <= static_cast(UINT_MAX) + 1); + SASSERT(a_hat + B < static_cast(UINT_MAX) + 1); + SASSERT(b_hat + C < static_cast(UINT_MAX) + 1); + SASSERT(b_hat + D <= static_cast(UINT_MAX) + 1); + // overflows can't happen since I'm using int64 + if (b_hat + C == 0 || b_hat + D == 0) + break; + q = (a_hat + A)/(b_hat + C); + if (q != (a_hat + B)/(b_hat + D)) + break; + T = A - q*C; + A = C; + C = T; + T = B - q*D; + B = D; + D = T; + T = a_hat - q*b_hat; + a_hat = b_hat; + b_hat = T; + } + SASSERT(ge(a1, b1)); + if (B == 0) { + rem(a1, b1, t); + swap(a1, b1); + swap(b1, t); + SASSERT(ge(a1, b1)); + } + else { + // t <- A*a1 + set(tmp, A); + mul(a1, tmp, t); + // t <- t + B*b1 + set(tmp, B); + addmul(t, tmp, b1, t); + // r <- C*a1 + set(tmp, C); + mul(a1, tmp, r); + // r <- r + D*b1 + set(tmp, D); + addmul(r, tmp, b1, r); + // a <- t + swap(a1, t); + // b <- r + swap(b1, r); + SASSERT(ge(a1, b1)); + } } + del(a1); del(b1); del(r); del(t); del(tmp); +#endif // LEHMER_GCD } template @@ -2029,25 +2030,51 @@ void mpz_manager::ensure_capacity(mpz & a, unsigned capacity) { capacity = m_init_cell_capacity; if (is_small(a)) { - int val = a.value(); + int64_t val = a.value64(); allocate_if_needed(a, capacity); SASSERT(a.ptr()->m_capacity >= capacity); - if (val == INT_MIN) { - unsigned intmin_sz = m_int_min.ptr()->m_size; - for (unsigned i = 0; i < intmin_sz; ++i) - a.ptr()->m_digits[i] = m_int_min.ptr()->m_digits[i]; + // Check if this is SMALL_INT_MIN which needs special handling + if (val == mpz::SMALL_INT_MIN) { + // For 32-bit: SMALL_INT_MIN = -2^30, so -val = 2^30 fits in unsigned + // For 64-bit: SMALL_INT_MIN = -2^62, so -val = 2^62 fits in uint64_t + uint64_t abs_val = static_cast(-val); + if (sizeof(digit_t) == sizeof(uint64_t)) { + // 64-bit machine + a.ptr()->m_digits[0] = static_cast(abs_val); + a.ptr()->m_size = 1; + } + else { + // 32-bit machine + a.ptr()->m_digits[0] = static_cast(abs_val); + a.ptr()->m_digits[1] = static_cast(abs_val >> 32); + a.ptr()->m_size = (abs_val >> 32) == 0 ? 1 : 2; + } a.set_sign(-1); - a.ptr()->m_size = m_int_min.ptr()->m_size; } else if (val < 0) { - a.ptr()->m_digits[0] = -val; + uint64_t abs_val = static_cast(-val); + if (sizeof(digit_t) == sizeof(uint64_t)) { + a.ptr()->m_digits[0] = static_cast(abs_val); + a.ptr()->m_size = 1; + } + else { + a.ptr()->m_digits[0] = static_cast(abs_val); + a.ptr()->m_digits[1] = static_cast(abs_val >> 32); + a.ptr()->m_size = (abs_val >> 32) == 0 ? 1 : 2; + } a.set_sign(-1); - a.ptr()->m_size = 1; } else { - a.ptr()->m_digits[0] = val; + if (sizeof(digit_t) == sizeof(uint64_t)) { + a.ptr()->m_digits[0] = static_cast(val); + a.ptr()->m_size = 1; + } + else { + a.ptr()->m_digits[0] = static_cast(val); + a.ptr()->m_digits[1] = static_cast(val >> 32); + a.ptr()->m_size = (val >> 32) == 0 ? 1 : 2; + } a.set_sign(1); - a.ptr()->m_size = 1; } } else if (a.ptr()->m_capacity < capacity) { @@ -2079,10 +2106,10 @@ void mpz_manager::normalize(mpz & a) { return; } - if (i == 1 && ds[0] <= INT_MAX) { - // a is small - int val = a.sign() < 0 ? -static_cast(ds[0]) : static_cast(ds[0]); - a.set(val); + if (i == 1 && ds[0] <= static_cast(mpz::SMALL_INT_MAX)) { + // a fits in small integer range + int64_t val = a.sign() < 0 ? -static_cast(ds[0]) : static_cast(ds[0]); + a.set64(val); return; } // adjust size diff --git a/src/util/mpz.h b/src/util/mpz.h index 64e1587cb..490c85993 100644 --- a/src/util/mpz.h +++ b/src/util/mpz.h @@ -92,6 +92,32 @@ private: static constexpr uintptr_t OWNER_BIT = 0x4; static constexpr uintptr_t MPZ_PTR_MASK = ~static_cast(0x7); + // Small integers are stored shifted left by 1, so we have (sizeof(uintptr_t)*8 - 1) bits available + // This gives us: + // - On 32-bit platforms: 31 bits, range [-2^30, 2^30-1] + // - On 64-bit platforms: 63 bits, range [-2^62, 2^62-1] + static constexpr int SMALL_BITS = sizeof(uintptr_t) * 8 - 1; + + // Maximum and minimum values that can be stored as small integers + static constexpr int64_t SMALL_INT_MAX = (static_cast(1) << (SMALL_BITS - 1)) - 1; + static constexpr int64_t SMALL_INT_MIN = -(static_cast(1) << (SMALL_BITS - 1)); + + static bool fits_in_small(int64_t v) { + return v >= SMALL_INT_MIN && v <= SMALL_INT_MAX; + } + + static bool fits_in_small(uint64_t v) { + return v <= static_cast(SMALL_INT_MAX); + } + + static bool fits_in_small(int v) { + return fits_in_small(static_cast(v)); + } + + static bool fits_in_small(unsigned int v) { + return fits_in_small(static_cast(v)); + } + mpz_type * ptr() const { SASSERT(!is_small()); return reinterpret_cast(m_value & MPZ_PTR_MASK); @@ -135,7 +161,11 @@ protected: friend class mpbq_manager; friend class mpz_stack; public: - mpz(int v = 0) noexcept : m_value(static_cast(static_cast(v)) << 1) {} + mpz(int v = 0) noexcept : m_value(static_cast(static_cast(v)) << 1) { + // On 32-bit platforms, INT_MIN doesn't fit in 31 bits. This constructor should only be used + // with values that fit, or the caller should use set_big_i64. + SASSERT(fits_in_small(v)); + } mpz(mpz_type* ptr) noexcept { SASSERT(ptr); @@ -157,6 +187,11 @@ public: m_value = static_cast(static_cast(v)) << 1; } + void set64(int64_t v) { + SASSERT(fits_in_small(v)); + m_value = static_cast(static_cast(v)) << 1; + } + void swap(mpz & other) noexcept { std::swap(m_value, other.m_value); } @@ -168,9 +203,16 @@ public: inline int value() const { SASSERT(is_small()); // Decode small integer: shift right by 1 (arithmetic shift to preserve sign) + // Note: On 64-bit platforms, this may truncate if the value doesn't fit in int return static_cast(static_cast(m_value) >> 1); } + inline int64_t value64() const { + SASSERT(is_small()); + // Decode small integer: shift right by 1 (arithmetic shift to preserve sign) + return static_cast(static_cast(m_value) >> 1); + } + inline int sign() const { SASSERT(!is_small()); return get_sign(); @@ -291,13 +333,16 @@ class mpz_manager { mpz m_two64; - static int64_t i64(mpz const & a) { return static_cast(a.value()); } + static int64_t i64(mpz const & a) { return a.value64(); } void set_big_i64(mpz & c, int64_t v); void set_i64(mpz & c, int64_t v) { - if (v >= INT_MIN && v <= INT_MAX) { - c.set(static_cast(v)); + if (mpz::fits_in_small(v)) { + if (!is_small(c)) { + deallocate(c); + } + c.set64(v); } else { set_big_i64(c, v); @@ -363,20 +408,44 @@ class mpz_manager { void get_sign_cell(mpz const & a, int & sign, mpz_cell * & cell, mpz_cell* reserve) { if (is_small(a)) { - if (a.value() == INT_MIN) { + int64_t val = a.value64(); + cell = reserve; + if (val == mpz::SMALL_INT_MIN) { sign = -1; - cell = m_int_min.ptr(); - } - else { - cell = reserve; - cell->m_size = 1; - if (a.value() < 0) { - sign = -1; - cell->m_digits[0] = -a.value(); + uint64_t abs_val = static_cast(-val); + if (sizeof(digit_t) == sizeof(uint64_t)) { + cell->m_size = 1; + cell->m_digits[0] = static_cast(abs_val); } else { - sign = 1; - cell->m_digits[0] = a.value(); + cell->m_digits[0] = static_cast(abs_val); + cell->m_digits[1] = static_cast(abs_val >> 32); + cell->m_size = (abs_val >> 32) == 0 ? 1 : 2; + } + } + else if (val < 0) { + sign = -1; + uint64_t abs_val = static_cast(-val); + if (sizeof(digit_t) == sizeof(uint64_t)) { + cell->m_size = 1; + cell->m_digits[0] = static_cast(abs_val); + } + else { + cell->m_digits[0] = static_cast(abs_val); + cell->m_digits[1] = static_cast(abs_val >> 32); + cell->m_size = (abs_val >> 32) == 0 ? 1 : 2; + } + } + else { + sign = 1; + if (sizeof(digit_t) == sizeof(uint64_t)) { + cell->m_size = 1; + cell->m_digits[0] = static_cast(val); + } + else { + cell->m_digits[0] = static_cast(val); + cell->m_digits[1] = static_cast(val >> 32); + cell->m_size = (val >> 32) == 0 ? 1 : 2; } } } @@ -593,17 +662,28 @@ public: } void set(mpz & a, int val) { - if (!is_small(a)) { - deallocate(a); + // On 32-bit platforms, int can be outside small range + if (mpz::fits_in_small(val)) { + if (!is_small(a)) { + deallocate(a); + } + a.set(val); + } + else { + set_i64(a, val); } - a.set(val); } void set(mpz & a, unsigned val) { - if (val <= INT_MAX) - set(a, static_cast(val)); - else - set(a, static_cast(static_cast(val))); + if (mpz::fits_in_small(val)) { + if (!is_small(a)) { + deallocate(a); + } + a.set(static_cast(val)); + } + else { + set_i64(a, static_cast(val)); + } } void set(mpz & a, char const * val); @@ -613,8 +693,11 @@ public: } void set(mpz & a, uint64_t val) { - if (val < INT_MAX) { - a.set(static_cast(val)); + if (mpz::fits_in_small(val)) { + if (!is_small(a)) { + deallocate(a); + } + a.set64(static_cast(val)); } else { set_big_ui64(a, val);