diff --git a/src/util/mpz.cpp b/src/util/mpz.cpp index 10c345b0c..d83b74039 100644 --- a/src/util/mpz.cpp +++ b/src/util/mpz.cpp @@ -33,6 +33,11 @@ Revision History: #error No multi-precision library selected. #endif +// Out-of-line definitions for static constexpr members (required for ODR-use in C++14 and earlier) +constexpr int mpz::SMALL_BITS; +constexpr int64_t mpz::SMALL_INT_MAX; +constexpr int64_t mpz::SMALL_INT_MIN; + // Available GCD algorithms // #define EUCLID_GCD // #define BINARY_GCD @@ -214,8 +219,8 @@ void mpz_manager::deallocate(bool is_heap, mpz_cell * ptr) { template mpz_manager::sign_cell::sign_cell(mpz_manager& m, mpz const& a): m_local(reinterpret_cast(m_bytes)), m_a(a) { - m_local.m_ptr->m_capacity = capacity; - m.get_sign_cell(a, m_sign, m_cell, m_local.m_ptr); + m_local.ptr()->m_capacity = capacity; + m.get_sign_cell(a, m_sign, m_cell, m_local.ptr()); } @@ -224,12 +229,11 @@ mpz_manager::sign_cell::sign_cell(mpz_manager& m, mpz const& a): template void mpz_manager::del(mpz_manager* m, mpz & a) { - if (a.m_ptr) { + if (!a.is_small()) { SASSERT(m); - m->deallocate(a.m_owner == mpz_self, a.m_ptr); - a.m_ptr = nullptr; - a.m_kind = mpz_small; - a.m_owner = mpz_self; + mpz::mpz_type* p = a.ptr(); + m->deallocate(!a.is_external(), p); + a.m_value = 0; // reset to small } } @@ -260,43 +264,44 @@ void mpz_manager::sub(mpz const & a, mpz const & b, mpz & c) { template void mpz_manager::set_big_i64(mpz & c, int64_t v) { #ifndef _MP_GMP - if (c.m_ptr == nullptr) { - c.m_ptr = allocate(m_init_cell_capacity); - c.m_owner = mpz_self; + mpz_cell* cell = c.is_small() ? nullptr : c.ptr(); + if (cell == nullptr) { + cell = allocate(m_init_cell_capacity); + c.set_ptr(cell, false, false); // Will update sign below } - c.m_kind = mpz_large; SASSERT(capacity(c) >= m_init_cell_capacity); uint64_t _v; + bool is_negative = false; if (v == std::numeric_limits::min()) { // min-int is even _v = -(v/2); - c.m_val = -1; + is_negative = true; } else if (v < 0) { _v = -v; - c.m_val = -1; + is_negative = true; } else { _v = v; - c.m_val = 1; } + c.set_sign(is_negative ? -1 : 1); if (sizeof(digit_t) == sizeof(uint64_t)) { // 64-bit machine digits(c)[0] = static_cast(_v); - c.m_ptr->m_size = 1; + c.ptr()->m_size = 1; } else { // 32-bit machine digits(c)[0] = static_cast(_v); digits(c)[1] = static_cast(_v >> 32); - c.m_ptr->m_size = digits(c)[1] == 0 ? 1 : 2; + c.ptr()->m_size = digits(c)[1] == 0 ? 1 : 2; } #else - if (c.m_ptr == nullptr) { - c.m_ptr = allocate(); - c.m_owner = mpz_self; + mpz_t* cell = c.is_small() ? nullptr : c.ptr(); + if (cell == nullptr) { + cell = allocate(); + c.set_ptr(cell, false, false); } - c.m_kind = mpz_large; uint64_t _v; bool sign = v < 0; if (v == std::numeric_limits::min()) { @@ -308,14 +313,14 @@ void mpz_manager::set_big_i64(mpz & c, int64_t v) { else { _v = v; } - mpz_set_ui(*c.m_ptr, static_cast(_v)); + mpz_set_ui(*c.ptr(), static_cast(_v)); MPZ_BEGIN_CRITICAL(); mpz_set_ui(m_tmp, static_cast(_v >> 32)); mpz_mul(m_tmp, m_tmp, m_two32); - mpz_add(*c.m_ptr, *c.m_ptr, m_tmp); + mpz_add(*c.ptr(), *c.ptr(), m_tmp); MPZ_END_CRITICAL(); if (sign) - mpz_neg(*c.m_ptr, *c.m_ptr); + mpz_neg(*c.ptr(), *c.ptr()); #endif if (v == std::numeric_limits::min()) { big_add(c, c, c); @@ -325,35 +330,35 @@ void mpz_manager::set_big_i64(mpz & c, int64_t v) { template void mpz_manager::set_big_ui64(mpz & c, uint64_t v) { #ifndef _MP_GMP - if (c.m_ptr == nullptr) { - c.m_ptr = allocate(m_init_cell_capacity); - c.m_owner = mpz_self; + mpz_cell* cell = c.is_small() ? nullptr : c.ptr(); + if (cell == nullptr) { + cell = allocate(m_init_cell_capacity); + c.set_ptr(cell, false, false); // positive, owned } - c.m_kind = mpz_large; SASSERT(capacity(c) >= m_init_cell_capacity); - c.m_val = 1; + c.set_sign(1); // positive if (sizeof(digit_t) == sizeof(uint64_t)) { // 64-bit machine digits(c)[0] = static_cast(v); - c.m_ptr->m_size = 1; + c.ptr()->m_size = 1; } else { // 32-bit machine digits(c)[0] = static_cast(v); digits(c)[1] = static_cast(v >> 32); - c.m_ptr->m_size = digits(c)[1] == 0 ? 1 : 2; + c.ptr()->m_size = digits(c)[1] == 0 ? 1 : 2; } #else - if (c.m_ptr == nullptr) { - c.m_ptr = allocate(); - c.m_owner = mpz_self; + mpz_t* cell = c.is_small() ? nullptr : c.ptr(); + if (cell == nullptr) { + cell = allocate(); + c.set_ptr(cell, false, false); // positive, owned } - c.m_kind = mpz_large; - mpz_set_ui(*c.m_ptr, static_cast(v)); + mpz_set_ui(*c.ptr(), static_cast(v)); MPZ_BEGIN_CRITICAL(); mpz_set_ui(m_tmp, static_cast(v >> 32)); mpz_mul(m_tmp, m_tmp, m_two32); - mpz_add(*c.m_ptr, *c.m_ptr, m_tmp); + mpz_add(*c.ptr(), *c.ptr(), m_tmp); MPZ_END_CRITICAL(); #endif } @@ -365,10 +370,10 @@ mpz_manager::ensure_mpz_t::ensure_mpz_t(mpz const& a) { if (is_small(a)) { m_result = &m_local; mpz_init(m_local); - mpz_set_si(m_local, a.m_val); + mpz_set_si(m_local, a.value()); } else { - m_result = a.m_ptr; + m_result = a.ptr(); } } @@ -394,17 +399,16 @@ void mpz_manager::set(mpz_cell& src, mpz & a, int sign, unsigned sz) { } unsigned d = src.m_digits[0]; - if (i == 1 && d <= INT_MAX) { - // src fits is a fixnum - a.m_val = sign < 0 ? -static_cast(d) : static_cast(d); - a.m_kind = mpz_small; + if (i == 1 && d <= static_cast(mpz::SMALL_INT_MAX)) { + // src fits in small integer range + a.set64(sign < 0 ? -static_cast(d) : static_cast(d)); return; } set_digits(a, i, src.m_digits); - a.m_val = sign; + a.set_sign(sign); - SASSERT(a.m_kind == mpz_large); + SASSERT(!a.is_small()); } #endif @@ -443,15 +447,14 @@ void mpz_manager::set_digits(mpz & target, unsigned sz, digit_t const * d set(target, digits[0]); else { #ifndef _MP_GMP - target.m_val = 1; // number is positive. - if (target.m_ptr == nullptr) { + mpz_cell* cell = target.is_small() ? nullptr : target.ptr(); + if (cell == nullptr) { unsigned c = sz < m_init_cell_capacity ? m_init_cell_capacity : sz; - target.m_ptr = allocate(c); - target.m_ptr->m_size = sz; - target.m_ptr->m_capacity = c; - target.m_kind = mpz_large; - target.m_owner = mpz_self; - memcpy(target.m_ptr->m_digits, digits, sizeof(digit_t) * sz); + cell = allocate(c); + cell->m_size = sz; + cell->m_capacity = c; + target.set_ptr(cell, false, false); // positive, owned + memcpy(cell->m_digits, digits, sizeof(digit_t) * sz); } else if (capacity(target) < sz) { SASSERT(sz > m_init_cell_capacity); @@ -460,29 +463,26 @@ void mpz_manager::set_digits(mpz & target, unsigned sz, digit_t const * d ptr->m_size = sz; ptr->m_capacity = sz; deallocate(target); - target.m_val = 1; - target.m_ptr = ptr; - target.m_kind = mpz_large; - target.m_owner = mpz_self; + target.set_ptr(ptr, false, false); // positive, owned } else { - target.m_ptr->m_size = sz; - if (target.m_ptr->m_digits != digits) - memcpy(target.m_ptr->m_digits, digits, sizeof(digit_t) * sz); - target.m_kind = mpz_large; + target.ptr()->m_size = sz; + if (target.ptr()->m_digits != digits) + memcpy(target.ptr()->m_digits, digits, sizeof(digit_t) * sz); + // already large } #else mk_big(target); // reset - mpz_set_ui(*target.m_ptr, digits[sz - 1]); + mpz_set_ui(*target.ptr(), digits[sz - 1]); SASSERT(sz > 0); unsigned i = sz - 1; MPZ_BEGIN_CRITICAL(); while (i > 0) { --i; - mpz_mul_2exp(*target.m_ptr, *target.m_ptr, 32); + mpz_mul_2exp(*target.ptr(), *target.ptr(), 32); mpz_set_ui(m_tmp, digits[i]); - mpz_add(*target.m_ptr, *target.m_ptr, m_tmp); + mpz_add(*target.ptr(), *target.ptr(), m_tmp); } MPZ_END_CRITICAL(); #endif @@ -568,7 +568,7 @@ void mpz_manager::machine_div(mpz const & a, mpz const & b, mpz & c) { template void mpz_manager::reset(mpz & a) { deallocate(a); - set(a, 0); + a.m_value = 0; // reset to small } template @@ -673,7 +673,7 @@ mpz mpz_manager::mod2k(mpz const & a, unsigned k) { if (rem_bits > 0 && digit_count < ca.cell()->m_size) { is_zero &= (digits(result)[digit_count] = ca.cell()->m_digits[digit_count] & mask) == 0; } - result.m_ptr->m_size = total_digits; + result.ptr()->m_size = total_digits; if (ca.sign() < 0 && !is_zero) { // Negative case: if non-zero, result = 2^k - (|a| mod 2^k) @@ -700,7 +700,7 @@ mpz mpz_manager::mod2k(mpz const & a, unsigned k) { ensure_mpz_t a1(a); mk_big(result); MPZ_BEGIN_CRITICAL(); - mpz_tdiv_r_2exp(*result.m_ptr, a1(), k); + mpz_tdiv_r_2exp(*result.ptr(), a1(), k); MPZ_END_CRITICAL(); #endif return result; @@ -709,19 +709,22 @@ mpz mpz_manager::mod2k(mpz const & a, unsigned k) { template void mpz_manager::neg(mpz & a) { STRACE(mpz, tout << "[mpz] 0 - " << to_string(a) << " == ";); - if (is_small(a) && a.m_val == INT_MIN) { - // neg(INT_MIN) is not a small int - set_big_i64(a, - static_cast(INT_MIN)); - return; + if (is_small(a)) { + int64_t v = a.value(); + if (v == mpz::SMALL_INT_MIN) { + // neg(SMALL_INT_MIN) overflows small range + set_big_i64(a, -v); + return; + } + a.set64(-v); } #ifndef _MP_GMP - a.m_val = -a.m_val; -#else - if (is_small(a)) { - a.m_val = -a.m_val; - } else { - mpz_neg(*a.m_ptr, *a.m_ptr); + a.set_sign(-a.sign()); + } +#else + else { + mpz_neg(*a.ptr(), *a.ptr()); } #endif STRACE(mpz, tout << to_string(a) << "\n";); @@ -730,20 +733,21 @@ void mpz_manager::neg(mpz & a) { template void mpz_manager::abs(mpz & a) { if (is_small(a)) { - if (a.m_val < 0) { - if (a.m_val == INT_MIN) { - // abs(INT_MIN) is not a small int - set_big_i64(a, - static_cast(INT_MIN)); + int64_t v = a.value(); + if (v < 0) { + if (v == mpz::SMALL_INT_MIN) { + // abs(SMALL_INT_MIN) overflows small range + set_big_i64(a, -v); } else - a.m_val = -a.m_val; + a.set64(-v); } } else { #ifndef _MP_GMP - a.m_val = 1; + a.set_sign(1); #else - mpz_abs(*a.m_ptr, *a.m_ptr); + mpz_abs(*a.ptr(), *a.ptr()); #endif } } @@ -765,9 +769,9 @@ void mpz_manager::big_add_sub(mpz const & a, mpz const & b, mpz & c) { allocate_if_needed(tmp, sz); m_mpn_manager.add(ca.cell()->m_digits, ca.cell()->m_size, cb.cell()->m_digits, cb.cell()->m_size, - tmp.m_ptr->m_digits, sz, &real_sz); + tmp.ptr()->m_digits, sz, &real_sz); SASSERT(real_sz <= sz); - set(*tmp.m_ptr, c, ca.sign(), real_sz); + set(*tmp.ptr(), c, ca.sign(), real_sz); } else { digit_t borrow; @@ -784,10 +788,10 @@ void mpz_manager::big_add_sub(mpz const & a, mpz const & b, mpz & c) { cb.cell()->m_size, ca.cell()->m_digits, ca.cell()->m_size, - tmp.m_ptr->m_digits, + tmp.ptr()->m_digits, &borrow); SASSERT(borrow == 0); - set(*tmp.m_ptr, c, sign_b, sz); + set(*tmp.ptr(), c, sign_b, sz); } else { // a > b @@ -797,10 +801,10 @@ void mpz_manager::big_add_sub(mpz const & a, mpz const & b, mpz & c) { ca.cell()->m_size, cb.cell()->m_digits, cb.cell()->m_size, - tmp.m_ptr->m_digits, + tmp.ptr()->m_digits, &borrow); SASSERT(borrow == 0); - set(*tmp.m_ptr, c, ca.sign(), sz); + set(*tmp.ptr(), c, ca.sign(), sz); } } del(tmp); @@ -817,7 +821,7 @@ void mpz_manager::big_add(mpz const & a, mpz const & b, mpz & c) { // GMP version ensure_mpz_t a1(a), b1(b); mk_big(c); - mpz_add(*c.m_ptr, a1(), b1()); + mpz_add(*c.ptr(), a1(), b1()); #endif } @@ -829,7 +833,7 @@ void mpz_manager::big_sub(mpz const & a, mpz const & b, mpz & c) { // GMP version ensure_mpz_t a1(a), b1(b); mk_big(c); - mpz_sub(*c.m_ptr, a1(), b1()); + mpz_sub(*c.ptr(), a1(), b1()); #endif } @@ -845,14 +849,14 @@ void mpz_manager::big_mul(mpz const & a, mpz const & b, mpz & c) { ca.cell()->m_size, cb.cell()->m_digits, cb.cell()->m_size, - tmp.m_ptr->m_digits); - set(*tmp.m_ptr, c, ca.sign() == cb.sign() ? 1 : -1, sz); + tmp.ptr()->m_digits); + set(*tmp.ptr(), c, ca.sign() == cb.sign() ? 1 : -1, sz); del(tmp); #else // GMP version ensure_mpz_t a1(a), b1(b); mk_big(c); - mpz_mul(*c.m_ptr, a1(), b1()); + mpz_mul(*c.ptr(), a1(), b1()); #endif } @@ -866,7 +870,7 @@ void mpz_manager::big_div_rem(mpz const & a, mpz const & b, mpz & q, mpz ensure_mpz_t a1(a), b1(b); mk_big(q); mk_big(r); - mpz_tdiv_qr(*q.m_ptr, *r.m_ptr, a1(), b1()); + mpz_tdiv_qr(*q.ptr(), *r.ptr(), a1(), b1()); #endif } @@ -897,12 +901,12 @@ void mpz_manager::quot_rem_core(mpz const & a, mpz const & b, mpz & q, mp allocate_if_needed(r1, r_sz); m_mpn_manager.div(ca.cell()->m_digits, ca.cell()->m_size, cb.cell()->m_digits, cb.cell()->m_size, - q1.m_ptr->m_digits, - r1.m_ptr->m_digits); + q1.ptr()->m_digits, + r1.ptr()->m_digits); if (MODE == QUOT_ONLY || MODE == QUOT_AND_REM) - set(*q1.m_ptr, q, ca.sign() == cb.sign() ? 1 : -1, q_sz); + set(*q1.ptr(), q, ca.sign() == cb.sign() ? 1 : -1, q_sz); if (MODE == REM_ONLY || MODE == QUOT_AND_REM) - set(*r1.m_ptr, r, ca.sign(), r_sz); + set(*r1.ptr(), r, ca.sign(), r_sz); del(q1); del(r1); } @@ -919,7 +923,7 @@ void mpz_manager::big_div(mpz const & a, mpz const & b, mpz & c) { // GMP version ensure_mpz_t a1(a), b1(b); mk_big(c); - mpz_tdiv_q(*c.m_ptr, a1(), b1()); + mpz_tdiv_q(*c.ptr(), a1(), b1()); #endif } @@ -934,27 +938,31 @@ void mpz_manager::big_rem(mpz const & a, mpz const & b, mpz & c) { // GMP version ensure_mpz_t a1(a), b1(b); mk_big(c); - mpz_tdiv_r(*c.m_ptr, a1(), b1()); + mpz_tdiv_r(*c.ptr(), a1(), b1()); #endif } template void mpz_manager::gcd(mpz const & a, mpz const & b, mpz & c) { - static_assert(sizeof(a.m_val) == sizeof(int), "size mismatch"); static_assert(sizeof(mpz) <= 16, "mpz size overflow"); - if (is_small(a) && is_small(b) && a.m_val != INT_MIN && b.m_val != INT_MIN) { - int _a = a.m_val; - int _b = b.m_val; - if (_a < 0) _a = -_a; - if (_b < 0) _b = -_b; - unsigned r = u_gcd(_a, _b); - set(c, r); + if (is_small(a) && is_small(b)) { + int64_t _a = a.value(); + int64_t _b = b.value(); + // Check if absolute values fit in uint64 (they always do for small integers) + // and won't overflow when negating + if (_a != mpz::SMALL_INT_MIN && _b != mpz::SMALL_INT_MIN) { + if (_a < 0) _a = -_a; + if (_b < 0) _b = -_b; + uint64_t r = u64_gcd(static_cast(_a), static_cast(_b)); + set(c, r); + return; + } } else { #ifdef _MP_GMP ensure_mpz_t a1(a), b1(b); mk_big(c); - mpz_gcd(*c.m_ptr, a1(), b1()); + mpz_gcd(*c.ptr(), a1(), b1()); return; #endif if (is_zero(a)) { @@ -1001,9 +1009,9 @@ void mpz_manager::gcd(mpz const & a, mpz const & b, mpz & c) { // reset least significant bit if (is_small(v)) - v.m_val &= ~1; + v.set64(v.value() & ~1); else - v.m_ptr->m_digits[0] &= ~static_cast(1); + v.ptr()->m_digits[0] &= ~static_cast(1); k_v = power_of_two_multiple(v); } @@ -1118,7 +1126,7 @@ void mpz_manager::gcd(mpz const & a, mpz const & b, mpz & c) { SASSERT(ge(a1, b1)); if (is_small(b1)) { if (is_small(a1)) { - unsigned r = u_gcd(a1.m_val, b1.m_val); + uint64_t r = u64_gcd(static_cast(a1.value()), static_cast(b1.value())); set(c, r); break; } @@ -1135,11 +1143,11 @@ void mpz_manager::gcd(mpz const & a, mpz const & b, mpz & c) { } SASSERT(!is_small(a1)); SASSERT(!is_small(b1)); - a_sz = a1.m_ptr->m_size; - b_sz = b1.m_ptr->m_size; + a_sz = a1.ptr()->m_size; + b_sz = b1.ptr()->m_size; SASSERT(b_sz <= a_sz); - a_hat = a1.m_ptr->m_digits[a_sz - 1]; - b_hat = (b_sz == a_sz) ? b1.m_ptr->m_digits[b_sz - 1] : 0; + a_hat = a1.ptr()->m_digits[a_sz - 1]; + b_hat = (b_sz == a_sz) ? b1.ptr()->m_digits[b_sz - 1] : 0; A = 1; B = 0; C = 0; @@ -1194,8 +1202,8 @@ void mpz_manager::gcd(mpz const & a, mpz const & b, mpz & c) { } } del(a1); del(b1); del(r); del(t); del(tmp); + } #endif // LEHMER_GCD - } } template @@ -1203,9 +1211,9 @@ unsigned mpz_manager::size_info(mpz const & a) { if (is_small(a)) return 1; #ifndef _MP_GMP - return a.m_ptr->m_size + 1; + return a.ptr()->m_size + 1; #else - return mpz_size(*a.m_ptr); + return mpz_size(*a.ptr()); #endif } @@ -1398,8 +1406,7 @@ void mpz_manager::bitwise_or(mpz const & a, mpz const & b, mpz & c) { SASSERT(is_nonneg(b)); TRACE(mpz, tout << "is_small(a): " << is_small(a) << ", is_small(b): " << is_small(b) << "\n";); if (is_small(a) && is_small(b)) { - c.m_val = a.m_val | b.m_val; - c.m_kind = mpz_small; + c.set64(a.value() | b.value()); } else { #ifndef _MP_GMP @@ -1434,7 +1441,7 @@ void mpz_manager::bitwise_or(mpz const & a, mpz const & b, mpz & c) { #else ensure_mpz_t a1(a), b1(b); mk_big(c); - mpz_ior(*c.m_ptr, a1(), b1()); + mpz_ior(*c.ptr(), a1(), b1()); #endif } } @@ -1444,8 +1451,7 @@ void mpz_manager::bitwise_and(mpz const & a, mpz const & b, mpz & c) { SASSERT(is_nonneg(a)); SASSERT(is_nonneg(b)); if (is_small(a) && is_small(b)) { - c.m_val = a.m_val & b.m_val; - c.m_kind = mpz_small; + c.set64(a.value() & b.value()); } else { #ifndef _MP_GMP @@ -1469,7 +1475,7 @@ void mpz_manager::bitwise_and(mpz const & a, mpz const & b, mpz & c) { #else ensure_mpz_t a1(a), b1(b); mk_big(c); - mpz_and(*c.m_ptr, a1(), b1()); + mpz_and(*c.ptr(), a1(), b1()); #endif } } @@ -1511,7 +1517,7 @@ void mpz_manager::bitwise_xor(mpz const & a, mpz const & b, mpz & c) { #else ensure_mpz_t a1(a), b1(b); mk_big(c); - mpz_xor(*c.m_ptr, a1(), b1()); + mpz_xor(*c.ptr(), a1(), b1()); #endif } } @@ -1558,33 +1564,32 @@ void mpz_manager::big_set(mpz & target, mpz const & source) { #ifndef _MP_GMP if (&target == &source) return; - target.m_val = source.m_val; - if (target.m_ptr == nullptr) { - target.m_ptr = allocate(capacity(source)); - target.m_ptr->m_size = size(source); - target.m_ptr->m_capacity = capacity(source); - target.m_kind = mpz_large; - target.m_owner = mpz_self; - memcpy(target.m_ptr->m_digits, source.m_ptr->m_digits, sizeof(digit_t) * size(source)); + int src_sign = source.sign(); + mpz_cell* target_cell = target.is_small() ? nullptr : target.ptr(); + if (target_cell == nullptr) { + mpz_cell* new_cell = allocate(capacity(source)); + new_cell->m_size = size(source); + new_cell->m_capacity = capacity(source); + memcpy(new_cell->m_digits, source.ptr()->m_digits, sizeof(digit_t) * size(source)); + target.set_ptr(new_cell, src_sign < 0, false); } else if (capacity(target) < size(source)) { deallocate(target); - target.m_ptr = allocate(capacity(source)); - target.m_ptr->m_size = size(source); - target.m_ptr->m_capacity = capacity(source); - target.m_kind = mpz_large; - target.m_owner = mpz_self; - memcpy(target.m_ptr->m_digits, source.m_ptr->m_digits, sizeof(digit_t) * size(source)); + mpz_cell* new_cell = allocate(capacity(source)); + new_cell->m_size = size(source); + new_cell->m_capacity = capacity(source); + memcpy(new_cell->m_digits, source.ptr()->m_digits, sizeof(digit_t) * size(source)); + target.set_ptr(new_cell, src_sign < 0, false); } else { - target.m_ptr->m_size = size(source); - memcpy(target.m_ptr->m_digits, source.m_ptr->m_digits, sizeof(digit_t) * size(source)); - target.m_kind = mpz_large; + target.ptr()->m_size = size(source); + memcpy(target.ptr()->m_digits, source.ptr()->m_digits, sizeof(digit_t) * size(source)); + target.set_sign(src_sign); } #else // GMP version mk_big(target); - mpz_set(*target.m_ptr, *source.m_ptr); + mpz_set(*target.ptr(), *source.ptr()); #endif } @@ -1628,10 +1633,10 @@ int mpz_manager::big_compare(mpz const & a, mpz const & b) { template bool mpz_manager::is_uint64(mpz const & a) const { #ifndef _MP_GMP - if (a.m_val < 0) - return false; if (is_small(a)) - return true; + return a.value() >= 0; + if (a.sign() < 0) + return false; if (sizeof(digit_t) == sizeof(uint64_t)) { return size(a) <= 1; } @@ -1641,8 +1646,8 @@ bool mpz_manager::is_uint64(mpz const & a) const { #else // GMP version if (is_small(a)) - return a.m_val >= 0; - return is_nonneg(a) && mpz_cmp(*a.m_ptr, m_uint64_max) <= 0; + return a.value() >= 0; + return is_nonneg(a) && mpz_cmp(*a.ptr(), m_uint64_max) <= 0; #endif } @@ -1656,7 +1661,7 @@ bool mpz_manager::is_int64(mpz const & a) const { uint64_t num = big_abs_to_uint64(a); uint64_t msb = static_cast(1) << 63; uint64_t msb_val = msb & num; - if (a.m_val >= 0) { + if (a.sign() >= 0) { // non-negative number. return (0 == msb_val); } @@ -1669,29 +1674,29 @@ bool mpz_manager::is_int64(mpz const & a) const { } #else // GMP version - return mpz_cmp(m_int64_min, *a.m_ptr) <= 0 && mpz_cmp(*a.m_ptr, m_int64_max) <= 0; + return mpz_cmp(m_int64_min, *a.ptr()) <= 0 && mpz_cmp(*a.ptr(), m_int64_max) <= 0; #endif } template uint64_t mpz_manager::get_uint64(mpz const & a) const { if (is_small(a)) - return static_cast(a.m_val); + return static_cast(a.value()); #ifndef _MP_GMP - SASSERT(a.m_ptr->m_size > 0); + SASSERT(a.ptr()->m_size > 0); return big_abs_to_uint64(a); #else // GMP version if (sizeof(uint64_t) == sizeof(unsigned long)) { - return mpz_get_ui(*a.m_ptr); + return mpz_get_ui(*a.ptr()); } else { MPZ_BEGIN_CRITICAL(); mpz_manager * _this = const_cast(this); - mpz_set(_this->m_tmp, *a.m_ptr); + mpz_set(_this->m_tmp, *a.ptr()); mpz_mod(_this->m_tmp, m_tmp, m_two32); uint64_t r = static_cast(mpz_get_ui(m_tmp)); - mpz_set(_this->m_tmp, *a.m_ptr); + mpz_set(_this->m_tmp, *a.ptr()); mpz_div(_this->m_tmp, m_tmp, m_two32); r += static_cast(mpz_get_ui(m_tmp)) << static_cast(32); MPZ_END_CRITICAL(); @@ -1703,11 +1708,11 @@ uint64_t mpz_manager::get_uint64(mpz const & a) const { template int64_t mpz_manager::get_int64(mpz const & a) const { if (is_small(a)) - return static_cast(a.m_val); + return a.value(); #ifndef _MP_GMP SASSERT(is_int64(a)); uint64_t num = big_abs_to_uint64(a); - if (a.m_val < 0) { + if (a.sign() < 0) { if (num != 0 && (num << 1) == 0) return INT64_MIN; return -static_cast(num); @@ -1715,15 +1720,15 @@ int64_t mpz_manager::get_int64(mpz const & a) const { return static_cast(num); #else // GMP - if (sizeof(int64_t) == sizeof(long) || mpz_fits_slong_p(*a.m_ptr)) { - return mpz_get_si(*a.m_ptr); + if (sizeof(int64_t) == sizeof(long) || mpz_fits_slong_p(*a.ptr())) { + return mpz_get_si(*a.ptr()); } else { MPZ_BEGIN_CRITICAL(); mpz_manager * _this = const_cast(this); - mpz_mod(_this->m_tmp, *a.m_ptr, m_two32); + mpz_mod(_this->m_tmp, *a.ptr(), m_two32); int64_t r = static_cast(mpz_get_ui(m_tmp)); - mpz_div(_this->m_tmp, *a.m_ptr, m_two32); + mpz_div(_this->m_tmp, *a.ptr(), m_two32); r += static_cast(mpz_get_si(m_tmp)) << static_cast(32); MPZ_END_CRITICAL(); return r; @@ -1734,7 +1739,7 @@ int64_t mpz_manager::get_int64(mpz const & a) const { template double mpz_manager::get_double(mpz const & a) const { if (is_small(a)) - return static_cast(a.m_val); + return static_cast(a.value()); #ifndef _MP_GMP double r = 0.0; double d = 1.0; @@ -1749,30 +1754,34 @@ double mpz_manager::get_double(mpz const & a) const { if (!(r >= 0.0)) { r = static_cast(UINT64_MAX); // some large number } - return a.m_val < 0 ? -r : r; + return a.sign() < 0 ? -r : r; #else - return mpz_get_d(*a.m_ptr); + return mpz_get_d(*a.ptr()); #endif } template void mpz_manager::display(std::ostream & out, mpz const & a) const { if (is_small(a)) { - out << a.m_val; + out << a.value(); } else { #ifndef _MP_GMP - if (a.m_val < 0) - out << '-'; - - auto sz = sizeof(digit_t) == 4 ? 11 : 21; - sbuffer buffer(sz * size(a), 0); - out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size()); + if (a.sign() < 0) + out << "-"; + if (sizeof(digit_t) == 4) { + sbuffer buffer(11*size(a), 0); + out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size()); + } + else { + sbuffer buffer(21*size(a), 0); + out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size()); + } #else // GMP version - size_t sz = mpz_sizeinbase(*a.m_ptr, 10) + 2; + size_t sz = mpz_sizeinbase(*a.ptr(), 10) + 2; sbuffer buffer(sz, 0); - mpz_get_str(buffer.data(), 10, *a.m_ptr); + mpz_get_str(buffer.data(), 10, *a.ptr()); out << buffer.data(); #endif } @@ -1828,11 +1837,11 @@ void mpz_manager::display_hex(std::ostream & out, mpz const & a, unsigned } #else // GMP version - size_t sz = mpz_sizeinbase(*(a.m_ptr), 16); + size_t sz = mpz_sizeinbase(*(a.ptr()), 16); unsigned requiredLength = num_bits / 4; unsigned padding = requiredLength > sz ? requiredLength - sz : 0; sbuffer buffer(sz, 0); - mpz_get_str(buffer.data(), 16, *(a.m_ptr)); + mpz_get_str(buffer.data(), 16, *(a.ptr())); for (unsigned i = 0; i < padding; ++i) { out << "0"; } @@ -1883,10 +1892,10 @@ void mpz_manager::display_bin(std::ostream & out, mpz const & a, unsigned } #else // GMP version - size_t sz = mpz_sizeinbase(*(a.m_ptr), 2); + size_t sz = mpz_sizeinbase(*(a.ptr()), 2); unsigned padding = num_bits > sz ? num_bits - sz : 0; sbuffer buffer(sz, 0); - mpz_get_str(buffer.data(), 2, *(a.m_ptr)); + mpz_get_str(buffer.data(), 2, *(a.ptr())); for (unsigned i = 0; i < padding; ++i) { out << "0"; } @@ -1905,14 +1914,14 @@ std::string mpz_manager::to_string(mpz const & a) const { template unsigned mpz_manager::hash(mpz const & a) { if (is_small(a)) - return ::abs(a.m_val); + return ::abs(a.value()); #ifndef _MP_GMP unsigned sz = size(a); if (sz == 1) return static_cast(digits(a)[0]); return string_hash(std::string_view(reinterpret_cast(digits(a)), sz * sizeof(digit_t)), 17); #else - return mpz_get_si(*a.m_ptr); + return mpz_get_si(*a.ptr()); #endif } @@ -1921,38 +1930,37 @@ void mpz_manager::power(mpz const & a, unsigned p, mpz & b) { #ifdef _MP_GMP if (!is_small(a)) { mk_big(b); - mpz_pow_ui(*b.m_ptr, *a.m_ptr, p); + mpz_pow_ui(*b.ptr(), *a.ptr(), p); return; } #endif #ifndef _MP_GMP if (is_small(a)) { - if (a.m_val == 2) { + if (a.value() == 2) { if (p < 8 * sizeof(int) - 1) { - b.m_val = 1 << p; - b.m_kind = mpz_small; + b.set(1 << p); } else { unsigned sz = p/(8 * sizeof(digit_t)) + 1; unsigned shift = p%(8 * sizeof(digit_t)); SASSERT(sz > 0); allocate_if_needed(b, sz); - SASSERT(b.m_ptr->m_capacity >= sz); - b.m_ptr->m_size = sz; + SASSERT(b.ptr()->m_capacity >= sz); + b.ptr()->m_size = sz; for (unsigned i = 0; i < sz - 1; ++i) - b.m_ptr->m_digits[i] = 0; - b.m_ptr->m_digits[sz-1] = 1 << shift; - b.m_val = 1; - b.m_kind = mpz_large; + b.ptr()->m_digits[i] = 0; + b.ptr()->m_digits[sz-1] = 1 << shift; + // b is already large after allocate_if_needed, just ensure sign is positive + b.set_sign(1); } return; } - if (a.m_val == 0) { + if (a.value() == 0) { SASSERT(p != 0); set(b, 0); return; } - if (a.m_val == 1) { + if (a.value() == 1) { set(b, 1); return; } @@ -1983,8 +1991,9 @@ bool mpz_manager::is_power_of_two(mpz const & a, unsigned & shift) { if (is_nonpos(a)) return false; if (is_small(a)) { - if (::is_power_of_two(a.m_val)) { - shift = ::log2((unsigned)a.m_val); + int64_t v = a.value(); + if (v > 0 && (v & (v - 1)) == 0) { // Check if power of 2 + shift = uint64_log2(static_cast(v)); return true; } else { @@ -1992,7 +2001,7 @@ bool mpz_manager::is_power_of_two(mpz const & a, unsigned & shift) { } } #ifndef _MP_GMP - mpz_cell * c = a.m_ptr; + mpz_cell * c = a.ptr(); unsigned sz = c->m_size; digit_t * ds = c->m_digits; for (unsigned i = 0; i < sz - 1; ++i) { @@ -2008,7 +2017,7 @@ bool mpz_manager::is_power_of_two(mpz const & a, unsigned & shift) { return false; } #else - if (mpz_popcount(*a.m_ptr) == 1) { + if (mpz_popcount(*a.ptr()) == 1) { shift = log2(a); return true; } @@ -2028,45 +2037,69 @@ void mpz_manager::ensure_capacity(mpz & a, unsigned capacity) { capacity = m_init_cell_capacity; if (is_small(a)) { - int val = a.m_val; + int64_t val = a.value(); allocate_if_needed(a, capacity); - a.m_kind = mpz_large; - SASSERT(a.m_ptr->m_capacity >= capacity); - if (val == INT_MIN) { - unsigned intmin_sz = m_int_min.m_ptr->m_size; - for (unsigned i = 0; i < intmin_sz; ++i) - a.m_ptr->m_digits[i] = m_int_min.m_ptr->m_digits[i]; - a.m_val = -1; - a.m_ptr->m_size = m_int_min.m_ptr->m_size; + SASSERT(a.ptr()->m_capacity >= capacity); + // Check if this is SMALL_INT_MIN which needs special handling + if (val == mpz::SMALL_INT_MIN) { + // For 32-bit: SMALL_INT_MIN = -2^30, so -val = 2^30 fits in unsigned + // For 64-bit: SMALL_INT_MIN = -2^62, so -val = 2^62 fits in uint64_t + uint64_t abs_val = static_cast(-val); + if (sizeof(digit_t) == sizeof(uint64_t)) { + // 64-bit machine + a.ptr()->m_digits[0] = static_cast(abs_val); + a.ptr()->m_size = 1; + } + else { + // 32-bit machine + a.ptr()->m_digits[0] = static_cast(abs_val); + a.ptr()->m_digits[1] = static_cast(abs_val >> 32); + a.ptr()->m_size = (abs_val >> 32) == 0 ? 1 : 2; + } + a.set_sign(-1); } else if (val < 0) { - a.m_ptr->m_digits[0] = -val; - a.m_val = -1; - a.m_ptr->m_size = 1; + uint64_t abs_val = static_cast(-val); + if (sizeof(digit_t) == sizeof(uint64_t)) { + a.ptr()->m_digits[0] = static_cast(abs_val); + a.ptr()->m_size = 1; + } + else { + a.ptr()->m_digits[0] = static_cast(abs_val); + a.ptr()->m_digits[1] = static_cast(abs_val >> 32); + a.ptr()->m_size = (abs_val >> 32) == 0 ? 1 : 2; + } + a.set_sign(-1); } else { - a.m_ptr->m_digits[0] = val; - a.m_val = 1; - a.m_ptr->m_size = 1; + if (sizeof(digit_t) == sizeof(uint64_t)) { + a.ptr()->m_digits[0] = static_cast(val); + a.ptr()->m_size = 1; + } + else { + a.ptr()->m_digits[0] = static_cast(val); + a.ptr()->m_digits[1] = static_cast(val >> 32); + a.ptr()->m_size = (val >> 32) == 0 ? 1 : 2; + } + a.set_sign(1); } } - else if (a.m_ptr->m_capacity < capacity) { + else if (a.ptr()->m_capacity < capacity) { mpz_cell * new_cell = allocate(capacity); SASSERT(new_cell->m_capacity == capacity); - unsigned old_sz = a.m_ptr->m_size; + unsigned old_sz = a.ptr()->m_size; new_cell->m_size = old_sz; for (unsigned i = 0; i < old_sz; ++i) - new_cell->m_digits[i] = a.m_ptr->m_digits[i]; + new_cell->m_digits[i] = a.ptr()->m_digits[i]; + bool is_neg = a.sign() < 0; deallocate(a); - a.m_ptr = new_cell; - a.m_owner = mpz_self; - a.m_kind = mpz_large; + a.set_ptr(new_cell, is_neg, false); } } template void mpz_manager::normalize(mpz & a) { - mpz_cell * c = a.m_ptr; + mpz_cell * c = a.ptr(); digit_t * ds = c->m_digits; unsigned i = c->m_size; for (; i > 0; --i) { @@ -2080,11 +2113,10 @@ void mpz_manager::normalize(mpz & a) { return; } - if (i == 1 && ds[0] <= INT_MAX) { - // a is small - int val = a.m_val < 0 ? -static_cast(ds[0]) : static_cast(ds[0]); - a.m_val = val; - a.m_kind = mpz_small; + if (i == 1 && ds[0] <= static_cast(mpz::SMALL_INT_MAX)) { + // a fits in small integer range + int64_t val = a.sign() < 0 ? -static_cast(ds[0]) : static_cast(ds[0]); + a.set64(val); return; } // adjust size @@ -2099,17 +2131,25 @@ void mpz_manager::machine_div2k(mpz & a, unsigned k) { if (is_small(a)) { if (k < 32) { int64_t twok = 1ull << ((int64_t)k); - int64_t val = a.m_val; - a.m_val = (int)(val/twok); + int64_t val = a.value(); + int64_t result = val / twok; + // Division by power of 2 should keep us in small range + SASSERT(mpz::fits_in_small(result)); + a.set64(result); + } + else if (k < 64) { + int64_t twok = 1ull << ((int64_t)k); + int64_t val = a.value(); + a.set64(val/twok); } else { - a.m_val = 0; + a.set(0); } return; } #ifndef _MP_GMP unsigned digit_shift = k / (8 * sizeof(digit_t)); - mpz_cell * c = a.m_ptr; + mpz_cell * c = a.ptr(); unsigned sz = c->m_size; if (digit_shift >= sz) { set(a, 0); @@ -2157,7 +2197,7 @@ void mpz_manager::machine_div2k(mpz & a, unsigned k) { MPZ_BEGIN_CRITICAL(); mpz_tdiv_q_2exp(m_tmp, a1(), k); mk_big(a); - mpz_swap(*a.m_ptr, m_tmp); + mpz_swap(*a.ptr(), m_tmp); MPZ_END_CRITICAL(); #endif } @@ -2174,14 +2214,13 @@ void mpz_manager::mul2k(mpz & a, unsigned k) { TRACE(mpz_mul2k, tout << "mul2k\na: " << to_string(a) << "\nk: " << k << "\n";); unsigned word_shift = k / (8 * sizeof(digit_t)); unsigned bit_shift = k % (8 * sizeof(digit_t)); - unsigned old_sz = is_small(a) ? 1 : a.m_ptr->m_size; + unsigned old_sz = is_small(a) ? 1 : a.ptr()->m_size; unsigned new_sz = old_sz + word_shift + 1; ensure_capacity(a, new_sz); TRACE(mpz_mul2k, tout << "word_shift: " << word_shift << "\nbit_shift: " << bit_shift << "\nold_sz: " << old_sz << "\nnew_sz: " << new_sz - << "\na after ensure capacity:\n" << to_string(a) << "\n"; - tout << a.m_kind << "\n";); + << "\na after ensure capacity:\n" << to_string(a) << "\n";); SASSERT(!is_small(a)); - mpz_cell * cell_a = a.m_ptr; + mpz_cell * cell_a = a.ptr(); old_sz = cell_a->m_size; digit_t * ds = cell_a->m_digits; for (unsigned i = old_sz; i < new_sz; ++i) @@ -2220,7 +2259,7 @@ void mpz_manager::mul2k(mpz & a, unsigned k) { #else ensure_mpz_t a1(a); mk_big(a); - mpz_mul_2exp(*a.m_ptr, a1(), k); + mpz_mul_2exp(*a.ptr(), a1(), k); #endif } @@ -2234,32 +2273,50 @@ unsigned mpz_manager::power_of_two_multiple(mpz const & a) { return 0; if (is_small(a)) { unsigned r = 0; - int v = a.m_val; -#define COUNT_DIGIT_RIGHT_ZEROS() \ - if (v % (1 << 16) == 0) { \ - r += 16; \ - v /= (1 << 16); \ - } \ - if (v % (1 << 8) == 0) { \ - r += 8; \ - v /= (1 << 8); \ - } \ - if (v % (1 << 4) == 0) { \ - r += 4; \ - v /= (1 << 4); \ - } \ - if (v % (1 << 2) == 0) { \ - r += 2; \ - v /= (1 << 2); \ - } \ - if (v % 2 == 0) { \ - r++; \ + int64_t val = a.value(); + // Count trailing zeros in 64-bit value + if (val == 0) return 0; + + // Work with absolute value for counting trailing zeros + // Handle SMALL_INT_MIN specially to avoid overflow + uint64_t v; + if (val == mpz::SMALL_INT_MIN) { + // SMALL_INT_MIN = -2^(SMALL_BITS-1), so it has (SMALL_BITS-1) trailing zeros + // On 32-bit: return 30, on 64-bit: return 62 + return (sizeof(uintptr_t) * 8 - 1) - 1; + } else if (val < 0) { + v = static_cast(-val); + } else { + v = static_cast(val); + } + + if ((v & 0xFFFFFFFF) == 0) { + r += 32; + v >>= 32; + } + if ((v & 0xFFFF) == 0) { + r += 16; + v >>= 16; + } + if ((v & 0xFF) == 0) { + r += 8; + v >>= 8; + } + if ((v & 0xF) == 0) { + r += 4; + v >>= 4; + } + if ((v & 0x3) == 0) { + r += 2; + v >>= 2; + } + if ((v & 0x1) == 0) { + r++; } - COUNT_DIGIT_RIGHT_ZEROS(); return r; } #ifndef _MP_GMP - mpz_cell * c = a.m_ptr; + mpz_cell * c = a.ptr(); unsigned sz = c->m_size; unsigned r = 0; digit_t * source = c->m_digits; @@ -2274,14 +2331,33 @@ unsigned mpz_manager::power_of_two_multiple(mpz const & a) { v = static_cast(static_cast(v) / (static_cast(1) << 32)); } } - COUNT_DIGIT_RIGHT_ZEROS(); + // Count trailing zeros in digit_t + if (v % (1 << 16) == 0) { + r += 16; + v /= (1 << 16); + } + if (v % (1 << 8) == 0) { + r += 8; + v /= (1 << 8); + } + if (v % (1 << 4) == 0) { + r += 4; + v /= (1 << 4); + } + if (v % (1 << 2) == 0) { + r += 2; + v /= (1 << 2); + } + if (v % 2 == 0) { + r++; + } return r; } r += (8 * sizeof(digit_t)); } return r; #else - return mpz_scan1(*a.m_ptr, 0); + return mpz_scan1(*a.ptr(), 0); #endif } @@ -2289,11 +2365,13 @@ template unsigned mpz_manager::log2(mpz const & a) { if (is_nonpos(a)) return 0; - if (is_small(a)) - return ::log2((unsigned)a.m_val); + if (is_small(a)) { + int64_t v = a.value(); + return uint64_log2(static_cast(v)); + } #ifndef _MP_GMP static_assert(sizeof(digit_t) == 8 || sizeof(digit_t) == 4, ""); - mpz_cell * c = a.m_ptr; + mpz_cell * c = a.ptr(); unsigned sz = c->m_size; digit_t * ds = c->m_digits; if (sizeof(digit_t) == 8) @@ -2301,7 +2379,7 @@ unsigned mpz_manager::log2(mpz const & a) { else return (sz - 1)*32 + ::log2(static_cast(ds[sz-1])); #else - unsigned r = mpz_sizeinbase(*a.m_ptr, 2); + unsigned r = mpz_sizeinbase(*a.ptr(), 2); SASSERT(r > 0); return r - 1; #endif @@ -2311,14 +2389,19 @@ template unsigned mpz_manager::mlog2(mpz const & a) { if (is_nonneg(a)) return 0; - if (is_small(a) && a.m_val == INT_MIN) - return ::log2((unsigned)a.m_val); - - if (is_small(a)) - return ::log2((unsigned)-a.m_val); + if (is_small(a)) { + int64_t v = a.value(); + if (v == mpz::SMALL_INT_MIN) { + // Special case: negating SMALL_INT_MIN would overflow + // For 32-bit: SMALL_INT_MIN = -2^30, so log2(2^30) = 30 + // For 64-bit: SMALL_INT_MIN = -2^62, so log2(2^62) = 62 + return (sizeof(uintptr_t) * 8 - 1) - 1; + } + return uint64_log2(static_cast(-v)); + } #ifndef _MP_GMP static_assert(sizeof(digit_t) == 8 || sizeof(digit_t) == 4, ""); - mpz_cell * c = a.m_ptr; + mpz_cell * c = a.ptr(); unsigned sz = c->m_size; digit_t * ds = c->m_digits; if (sizeof(digit_t) == 8) @@ -2327,7 +2410,7 @@ unsigned mpz_manager::mlog2(mpz const & a) { return (sz - 1)*32 + ::log2(static_cast(ds[sz-1])); #else MPZ_BEGIN_CRITICAL(); - mpz_neg(m_tmp, *a.m_ptr); + mpz_neg(m_tmp, *a.ptr()); unsigned r = mpz_sizeinbase(m_tmp, 2); MPZ_END_CRITICAL(); SASSERT(r > 0); @@ -2534,15 +2617,15 @@ template digit_t mpz_manager::get_least_significant(mpz const& a) { SASSERT(!is_neg(a)); if (is_small(a)) - return std::abs(a.m_val); + return std::abs(a.value()); #ifndef _MP_GMP - mpz_cell* cell_a = a.m_ptr; + mpz_cell* cell_a = a.ptr(); unsigned sz = cell_a->m_size; if (sz == 0) return 0; return cell_a->m_digits[0]; #else - return mpz_get_ui(*a.m_ptr); + return mpz_get_ui(*a.ptr()); #endif } @@ -2550,27 +2633,34 @@ template bool mpz_manager::decompose(mpz const & a, svector & digits) { digits.reset(); if (is_small(a)) { - if (a.m_val < 0) { - digits.push_back(-a.m_val); - return true; - } - else { - digits.push_back(a.m_val); - return false; + int64_t v = a.value(); + bool is_neg = v < 0; + uint64_t abs_v = is_neg ? static_cast(-v) : static_cast(v); + + // Decompose absolute value into digits + if (sizeof(digit_t) == sizeof(uint64_t)) { + digits.push_back(static_cast(abs_v)); + } else { + // digit_t is 32-bit, need to split 64-bit value + digits.push_back(static_cast(abs_v)); + if (abs_v >> 32) { + digits.push_back(static_cast(abs_v >> 32)); + } } + return is_neg; } else { #ifndef _MP_GMP - mpz_cell * cell_a = a.m_ptr; + mpz_cell * cell_a = a.ptr(); unsigned sz = cell_a->m_size; for (unsigned i = 0; i < sz; ++i) { digits.push_back(cell_a->m_digits[i]); } - return a.m_val < 0; + return a.sign() < 0; #else bool r = is_neg(a); MPZ_BEGIN_CRITICAL(); - mpz_set(m_tmp, *a.m_ptr); + mpz_set(m_tmp, *a.ptr()); mpz_abs(m_tmp, m_tmp); while (mpz_sgn(m_tmp) != 0) { mpz_tdiv_r_2exp(m_tmp2, m_tmp, 32); @@ -2587,16 +2677,17 @@ bool mpz_manager::decompose(mpz const & a, svector & digits) { template bool mpz_manager::get_bit(mpz const & a, unsigned index) { if (is_small(a)) { - SASSERT(a.m_val >= 0); - if (index >= 8*sizeof(digit_t)) + int64_t v = a.value(); + SASSERT(v >= 0); + if (index >= 64) return false; - return 0 != (a.m_val & (1ull << (digit_t)index)); + return 0 != (v & (1ull << index)); } unsigned i = index / (sizeof(digit_t)*8); unsigned o = index % (sizeof(digit_t)*8); #ifndef _MP_GMP - mpz_cell * cell_a = a.m_ptr; + mpz_cell * cell_a = a.ptr(); unsigned sz = cell_a->m_size; if (sz*sizeof(digit_t)*8 <= index) return false; diff --git a/src/util/mpz.h b/src/util/mpz.h index 505bb177e..55896bbdc 100644 --- a/src/util/mpz.h +++ b/src/util/mpz.h @@ -68,15 +68,15 @@ class mpz_cell { /** \brief Multi-precision integer. - - If m_kind == mpz_small, it is a small number and the value is stored in m_val. - If m_kind == mpz_large, the value is stored in m_ptr and m_ptr != nullptr. - m_val contains the sign (-1 negative, 1 positive) - under winodws, m_ptr points to a mpz_cell that store the value. -*/ -enum mpz_kind { mpz_small = 0, mpz_large = 1}; -enum mpz_owner { mpz_self = 0, mpz_ext = 1}; + m_value encodes either a small integer (if the least significant bit is 1) + or a pointer to a mpz_cell structure (if the least significant bit is 0). + The last 3 bits of pointers are always 0 due to alignment, so we use them + to store additional information: + - bit 0: small/large info (0 = small, 1 = large) + - bit 1: sign bit (0 = non-negative, 1 = negative) + - bit 2: owner info (0 = owned, 1 = external) +*/ class mpz { #ifndef _MP_GMP @@ -84,11 +84,74 @@ class mpz { #else typedef mpz_t mpz_type; #endif +private: + uintptr_t m_value = 0; + + static constexpr uintptr_t LARGE_BIT = 0x1; + static constexpr uintptr_t SIGN_BIT = 0x2; + static constexpr uintptr_t OWNER_BIT = 0x4; + static constexpr uintptr_t MPZ_PTR_MASK = ~static_cast(0x7); + + // Small integers are stored shifted left by 1, so we have (sizeof(uintptr_t)*8 - 1) bits available + // This gives us: + // - On 32-bit platforms: 31 bits, range [-2^30, 2^30-1] + // - On 64-bit platforms: 63 bits, range [-2^62, 2^62-1] + static constexpr int SMALL_BITS = sizeof(uintptr_t) * 8 - 1; + + // Maximum and minimum values that can be stored as small integers + static constexpr int64_t SMALL_INT_MAX = (static_cast(1) << (SMALL_BITS - 1)) - 1; + static constexpr int64_t SMALL_INT_MIN = -(static_cast(1) << (SMALL_BITS - 1)); + + static bool fits_in_small(int64_t v) { + return v >= SMALL_INT_MIN && v <= SMALL_INT_MAX; + } + + static bool fits_in_small(uint64_t v) { + return v <= static_cast(SMALL_INT_MAX); + } + + static bool fits_in_small(int v) { + return fits_in_small(static_cast(v)); + } + + static bool fits_in_small(unsigned int v) { + return fits_in_small(static_cast(v)); + } + + mpz_type * ptr() const { + SASSERT(!is_small()); + return reinterpret_cast(m_value & MPZ_PTR_MASK); + } + + void set_ptr(mpz_type* p, bool is_negative, bool is_external) { + SASSERT(is_small()); + SASSERT((reinterpret_cast(p) & 0x7) == 0); // Check alignment + m_value = reinterpret_cast(p) | LARGE_BIT; + if (is_negative) + m_value |= SIGN_BIT; + if (is_external) + m_value |= OWNER_BIT; + } + + int get_sign() const { + SASSERT(!is_small()); + return (m_value & SIGN_BIT) ? -1 : 1; + } + + void set_sign(int s) { + SASSERT(!is_small()); + if (s < 0) + m_value |= SIGN_BIT; + else + m_value &= ~SIGN_BIT; + } + + bool is_external() const { + SASSERT(!is_small()); + return (m_value & OWNER_BIT) != 0; + } + protected: - int m_val; - unsigned m_kind:1; - unsigned m_owner:1; - mpz_type * m_ptr; friend class mpz_manager; friend class mpz_manager; friend class mpq_manager; @@ -98,42 +161,65 @@ protected: friend class mpbq_manager; friend class mpz_stack; public: - mpz(int v = 0) noexcept : m_val(v), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {} - mpz(mpz_type* ptr) noexcept : m_val(0), m_kind(mpz_small), m_owner(mpz_ext), m_ptr(ptr) { SASSERT(ptr); } - mpz(mpz && other) noexcept : mpz() { swap(other); } + mpz(int v = 0) noexcept : m_value(static_cast(static_cast(v)) << 1) { + // On 32-bit platforms, INT_MIN doesn't fit in 31 bits. This constructor should only be used + // with values that fit, or the caller should use set_big_i64. + SASSERT(fits_in_small(v)); + } + + mpz(mpz_type* ptr) noexcept { + SASSERT(ptr); + set_ptr(ptr, false, true); // external pointer, non-negative + } + + mpz(mpz && other) noexcept : m_value(other.m_value) { + other.m_value = 0; // reset other to small + } mpz& operator=(mpz const& other) = delete; + mpz& operator=(mpz &&other) noexcept { - swap(other); + std::swap(m_value, other.m_value); return *this; } - void swap(mpz & other) noexcept { - std::swap(m_val, other.m_val); - std::swap(m_ptr, other.m_ptr); - unsigned o = m_owner; m_owner = other.m_owner; other.m_owner = o; - unsigned k = m_kind; m_kind = other.m_kind; other.m_kind = k; - } - void set(int v) { - m_val = v; - m_kind = mpz_small; + SASSERT(is_small()); + m_value = static_cast(static_cast(v)) << 1; } - inline bool is_small() const { return m_kind == mpz_small; } + void set64(int64_t v) { + SASSERT(fits_in_small(v)); + m_value = static_cast(static_cast(v)) << 1; + } - inline int value() const { SASSERT(is_small()); return m_val; } + void swap(mpz & other) noexcept { + std::swap(m_value, other.m_value); + } - inline int sign() const { SASSERT(!is_small()); return m_val; } + inline bool is_small() const { + return (m_value & LARGE_BIT) == 0; + } + + inline int64_t value() const { + SASSERT(is_small()); + // Decode small integer: shift right by 1 (arithmetic shift to preserve sign) + return static_cast(static_cast(m_value) >> 1); + } + + inline int sign() const { + SASSERT(!is_small()); + return get_sign(); + } }; #ifndef _MP_GMP class mpz_stack : public mpz { static const unsigned capacity = 8; - unsigned char m_bytes[sizeof(mpz_cell) + sizeof(digit_t) * capacity]; + alignas(8) unsigned char m_bytes[sizeof(mpz_cell) + sizeof(digit_t) * capacity]; public: mpz_stack():mpz(reinterpret_cast(m_bytes)) { - m_ptr->m_capacity = capacity; + ptr()->m_capacity = capacity; } }; #else @@ -169,16 +255,11 @@ class mpz_manager { // make sure that n is a big number and has capacity equal to at least c. void allocate_if_needed(mpz & n, unsigned c) { if (m_init_cell_capacity > c) c = m_init_cell_capacity; - if (n.m_ptr == nullptr || capacity(n) < c) { + if (n.is_small() || n.ptr() == nullptr || capacity(n) < c) { deallocate(n); - n.m_val = 1; - n.m_kind = mpz_large; - n.m_owner = mpz_self; - n.m_ptr = allocate(c); - } - else { - n.m_kind = mpz_large; + n.set_ptr(allocate(c), false, false); // positive, owned } + // else already has enough capacity, keep as large } void deallocate(bool is_heap, mpz_cell * ptr); @@ -230,27 +311,29 @@ class mpz_manager { } } - void clear(mpz& n) { if (n.m_ptr) { mpz_clear(*n.m_ptr); }} + void clear(mpz& n) { if (!n.is_small() && n.ptr()) { mpz_clear(*n.ptr()); }} #endif void deallocate(mpz& n) { - if (n.m_ptr) { - deallocate(n.m_owner == mpz_self, n.m_ptr); - n.m_ptr = nullptr; - n.m_kind = mpz_small; + if (!n.is_small()) { + auto* p = n.ptr(); + if (p) { + deallocate(!n.is_external(), p); + n.m_value = 0; // reset to small + } } } mpz m_two64; - static int64_t i64(mpz const & a) { return static_cast(a.value()); } + static int64_t i64(mpz const & a) { return a.value(); } void set_big_i64(mpz & c, int64_t v); void set_i64(mpz & c, int64_t v) { - if (v >= INT_MIN && v <= INT_MAX) { - c.set(static_cast(v)); + if (mpz::fits_in_small(v) && is_small(c)) { + c.set64(v); } else { set_big_i64(c, v); @@ -262,11 +345,20 @@ class mpz_manager { #ifndef _MP_GMP - static unsigned capacity(mpz const & c) { return c.m_ptr->m_capacity; } + static unsigned capacity(mpz const & c) { + SASSERT(!c.is_small()); + return c.ptr()->m_capacity; + } - static unsigned size(mpz const & c) { return c.m_ptr->m_size; } + static unsigned size(mpz const & c) { + SASSERT(!c.is_small()); + return c.ptr()->m_size; + } - static digit_t * digits(mpz const & c) { return c.m_ptr->m_digits; } + static digit_t * digits(mpz const & c) { + SASSERT(!c.is_small()); + return c.ptr()->m_digits; + } // Return true if the absolute value fits in a UINT64 static bool is_abs_uint64(mpz const & a) { @@ -282,7 +374,7 @@ class mpz_manager { static uint64_t big_abs_to_uint64(mpz const & a) { SASSERT(is_abs_uint64(a)); SASSERT(!is_small(a)); - if (a.m_ptr->m_size == 1) + if (a.ptr()->m_size == 1) return digits(a)[0]; if (sizeof(digit_t) == sizeof(uint64_t)) // 64-bit machine @@ -307,26 +399,37 @@ class mpz_manager { void get_sign_cell(mpz const & a, int & sign, mpz_cell * & cell, mpz_cell* reserve) { if (is_small(a)) { - if (a.value() == INT_MIN) { + int64_t val = a.value(); + cell = reserve; + if (val < 0) { sign = -1; - cell = m_int_min.m_ptr; - } - else { - cell = reserve; - cell->m_size = 1; - if (a.value() < 0) { - sign = -1; - cell->m_digits[0] = -a.value(); + uint64_t abs_val = static_cast(-val); + if (sizeof(digit_t) == sizeof(uint64_t)) { + cell->m_size = 1; + cell->m_digits[0] = static_cast(abs_val); } else { - sign = 1; - cell->m_digits[0] = a.value(); + cell->m_digits[0] = static_cast(abs_val); + cell->m_digits[1] = static_cast(abs_val >> 32); + cell->m_size = (abs_val >> 32) == 0 ? 1 : 2; + } + } + else { + sign = 1; + if (sizeof(digit_t) == sizeof(uint64_t)) { + cell->m_size = 1; + cell->m_digits[0] = static_cast(val); + } + else { + cell->m_digits[0] = static_cast(val); + cell->m_digits[1] = static_cast(val >> 32); + cell->m_size = (val >> 32) == 0 ? 1 : 2; } } } else { sign = a.sign(); - cell = a.m_ptr; + cell = a.ptr(); } } @@ -343,12 +446,10 @@ class mpz_manager { }; void mk_big(mpz & a) { - if (a.m_ptr == nullptr) { - a.m_val = 0; - a.m_ptr = allocate(); - a.m_owner = mpz_self; + if (a.is_small()) { + a.set_ptr(allocate(), false, false); // positive, owned } - a.m_kind = mpz_large; + // else already large with valid pointer } @@ -448,13 +549,15 @@ public: static bool is_zero(mpz const & a) { return sign(a) == 0; } static int sign(mpz const & a) { -#ifndef _MP_GMP - return a.m_val; -#else - if (is_small(a)) - return a.m_val; + if (is_small(a)) { + int v = a.value(); + return (v > 0) - (v < 0); // Returns -1, 0, or 1 + } else - return mpz_sgn(*a.m_ptr); +#ifndef _MP_GMP + return a.sign(); +#else + return mpz_sgn(*a.ptr()); #endif } @@ -537,14 +640,22 @@ public: } void set(mpz & a, int val) { - a.set(val); + // On 32-bit platforms, int can be outside small range + if (mpz::fits_in_small(val) && is_small(a)) { + a.set(val); + } + else { + set_i64(a, val); + } } void set(mpz & a, unsigned val) { - if (val <= INT_MAX) - set(a, static_cast(val)); - else - set(a, static_cast(static_cast(val))); + if (mpz::fits_in_small(val) && is_small(a)) { + a.set(static_cast(val)); + } + else { + set_i64(a, static_cast(val)); + } } void set(mpz & a, char const * val); @@ -554,8 +665,8 @@ public: } void set(mpz & a, uint64_t val) { - if (val < INT_MAX) { - a.set(static_cast(val)); + if (mpz::fits_in_small(val) && is_small(a)) { + a.set64(static_cast(val)); } else { set_big_ui64(a, val); @@ -625,7 +736,7 @@ public: #else if (is_small(a)) return a.value() == 1; - return mpz_cmp_si(*a.m_ptr, 1) == 0; + return mpz_cmp_si(*a.ptr(), 1) == 0; #endif } @@ -635,7 +746,7 @@ public: #else if (is_small(a)) return a.value() == -1; - return mpz_cmp_si(*a.m_ptr, -1) == 0; + return mpz_cmp_si(*a.ptr(), -1) == 0; #endif } @@ -713,7 +824,7 @@ public: #ifndef _MP_GMP return !(0x1 & digits(a)[0]); #else - return mpz_even_p(*a.m_ptr); + return mpz_even_p(*a.ptr()); #endif }