diff --git a/src/util/mpz.cpp b/src/util/mpz.cpp index 8445ec134..eca7436b7 100644 --- a/src/util/mpz.cpp +++ b/src/util/mpz.cpp @@ -225,7 +225,7 @@ mpz_manager::sign_cell::sign_cell(mpz_manager& m, mpz const& a): template void mpz_manager::del(mpz_manager* m, mpz & a) { - if (!a.is_small()) { + if (a.has_ptr()) { SASSERT(m); mpz::mpz_type* p = a.ptr(); m->deallocate(!a.is_external(), p); @@ -272,7 +272,7 @@ void mpz_manager::set_big_i64(mpz & c, int64_t v) { _v = v; } #ifndef _MP_GMP - if (c.is_small()) { + if (!c.has_ptr()) { c.set_ptr(allocate(m_init_cell_capacity), sign, false); } else { c.set_sign(sign ? -1 : 1); @@ -290,7 +290,7 @@ void mpz_manager::set_big_i64(mpz & c, int64_t v) { c.ptr()->m_size = digits(c)[1] == 0 ? 1 : 2; } #else - if (c.is_small()) { + if (!c.has_ptr()) { c.set_ptr(allocate(), false, false); } mpz_set_ui(*c.ptr(), static_cast(_v)); @@ -310,7 +310,7 @@ void mpz_manager::set_big_i64(mpz & c, int64_t v) { template void mpz_manager::set_big_ui64(mpz & c, uint64_t v) { #ifndef _MP_GMP - if (c.is_small()) { + if (!c.has_ptr()) { c.set_ptr(allocate(m_init_cell_capacity), false, false); // positive, owned } else { c.set_sign(1); // positive @@ -328,7 +328,7 @@ void mpz_manager::set_big_ui64(mpz & c, uint64_t v) { c.ptr()->m_size = digits(c)[1] == 0 ? 1 : 2; } #else - if (c.is_small()) { + if (!c.has_ptr()) { c.set_ptr(allocate(), false, false); // positive, owned } mpz_set_ui(*c.ptr(), static_cast(v)); @@ -344,7 +344,7 @@ void mpz_manager::set_big_ui64(mpz & c, uint64_t v) { template mpz_manager::ensure_mpz_t::ensure_mpz_t(mpz const& a) { - if (is_small(a)) { + if (!a.has_ptr()) { m_result = &m_local; mpz_init(m_local); mpz_set_si(m_local, a.value()); @@ -377,15 +377,15 @@ void mpz_manager::set(mpz_cell& src, mpz & a, int sign, unsigned sz) { unsigned d = src.m_digits[0]; int64_t val = sign < 0 ? -static_cast(d) : static_cast(d); - if (i == 1 && mpz::fits_in_small(val) && a.is_small()) { - a.set(val); + if (i == 1 && mpz::fits_in_small(val) && !a.has_ptr()) { + set(a, val); return; } set_digits(a, i, src.m_digits); a.set_sign(sign); - SASSERT(!a.is_small()); + SASSERT(a.has_ptr()); } #endif @@ -668,7 +668,7 @@ template void mpz_manager::neg(mpz & a) { STRACE(mpz, tout << "[mpz] 0 - " << to_string(a) << " == ";); if (is_small(a)) { - a.set(-a.value()); + set(a, -a.value()); } #ifndef _MP_GMP else { @@ -687,7 +687,7 @@ void mpz_manager::abs(mpz & a) { if (is_small(a)) { int64_t v = a.value(); if (v < 0) { - a.set(-v); + set(a, -v); } } else { @@ -947,7 +947,7 @@ void mpz_manager::gcd(mpz const & a, mpz const & b, mpz & c) { // reset least significant bit if (is_small(v)) - v.set(v.value() & ~1); + set(v, v.value() & ~1); else v.ptr()->m_digits[0] &= ~static_cast(1); k_v = power_of_two_multiple(v); @@ -1079,7 +1079,7 @@ void mpz_manager::gcd(mpz const & a, mpz const & b, mpz & c) { } } sign_cell ca(*this, a1); - SASSERT(!is_small(b1)); + SASSERT(b1.has_ptr()); a_sz = ca.cell()->m_size; b_sz = b1.ptr()->m_size; SASSERT(b_sz <= a_sz); @@ -1145,7 +1145,7 @@ void mpz_manager::gcd(mpz const & a, mpz const & b, mpz & c) { template unsigned mpz_manager::size_info(mpz const & a) { - if (is_small(a)) + if (!a.has_ptr()) return 1; #ifndef _MP_GMP return a.ptr()->m_size + 1; @@ -1343,7 +1343,7 @@ void mpz_manager::bitwise_or(mpz const & a, mpz const & b, mpz & c) { SASSERT(is_nonneg(b)); TRACE(mpz, tout << "is_small(a): " << is_small(a) << ", is_small(b): " << is_small(b) << "\n";); if (is_small(a) && is_small(b)) { - c.set(a.value() | b.value()); + set(c, a.value() | b.value()); } else { #ifndef _MP_GMP @@ -1388,7 +1388,7 @@ void mpz_manager::bitwise_and(mpz const & a, mpz const & b, mpz & c) { SASSERT(is_nonneg(a)); SASSERT(is_nonneg(b)); if (is_small(a) && is_small(b)) { - c.set(a.value() & b.value()); + set(c, a.value() & b.value()); } else { #ifndef _MP_GMP @@ -1535,16 +1535,13 @@ int mpz_manager::big_compare(mpz const & a, mpz const & b) { template bool mpz_manager::is_uint64(mpz const & a) const { -#ifndef _MP_GMP if (is_small(a)) return a.value() >= 0; +#ifndef _MP_GMP if (a.sign() < 0) return false; return size(a) <= (sizeof(digit_t) == sizeof(uint64_t) ? 1 : 2); #else - // GMP version - if (is_small(a)) - return a.value() >= 0; return is_nonneg(a) && mpz_cmp(*a.ptr(), m_uint64_max) <= 0; #endif } @@ -1831,7 +1828,7 @@ void mpz_manager::power(mpz const & a, unsigned p, mpz & b) { if (is_small(a)) { if (a.value() == 2) { if (p < 8 * sizeof(int) - 1) { - b.set(1 << p); + set(b, 1 << p); } else { unsigned sz = p/(8 * sizeof(digit_t)) + 1; @@ -1927,9 +1924,9 @@ void mpz_manager::ensure_capacity(mpz & a, unsigned capacity) { if (capacity < m_init_cell_capacity) capacity = m_init_cell_capacity; - if (is_small(a)) { + if (!a.has_ptr()) { int64_t val = a.value(); - uint64_t abs_val = static_cast(-val); + uint64_t abs_val = static_cast(val < 0 ? -val : val); allocate_if_needed(a, capacity); if (sizeof(digit_t) == sizeof(uint64_t)) { a.ptr()->m_digits[0] = static_cast(abs_val); @@ -1974,10 +1971,10 @@ void mpz_manager::machine_div2k(mpz & a, unsigned k) { if (k < 64) { int64_t twok = 1ull << ((int64_t)k); int64_t val = a.value(); - a.set(val/twok); + set(a, val/twok); } else { - a.set(0); + set(a, 0); } return; } @@ -2040,20 +2037,22 @@ template void mpz_manager::mul2k(mpz & a, unsigned k) { if (k == 0 || is_zero(a)) return; - if (is_small(a) && k < 32) { - set(a, a.value() * (static_cast(1) << k)); + + int64_t result; + if (is_small(a) && k < 64 && !mul_overflows(a.value(), static_cast(1) << k, result)) { + set(a, result); return; } #ifndef _MP_GMP TRACE(mpz_mul2k, tout << "mul2k\na: " << to_string(a) << "\nk: " << k << "\n";); unsigned word_shift = k / (8 * sizeof(digit_t)); unsigned bit_shift = k % (8 * sizeof(digit_t)); - unsigned old_sz = is_small(a) ? 1 : a.ptr()->m_size; + unsigned old_sz = a.has_ptr() ? size(a) : (sizeof(int64_t) / sizeof(digit_t)); unsigned new_sz = old_sz + word_shift + 1; ensure_capacity(a, new_sz); TRACE(mpz_mul2k, tout << "word_shift: " << word_shift << "\nbit_shift: " << bit_shift << "\nold_sz: " << old_sz << "\nnew_sz: " << new_sz << "\na after ensure capacity:\n" << to_string(a) << "\n";); - SASSERT(!is_small(a)); + SASSERT(a.has_ptr()); mpz_cell * cell_a = a.ptr(); old_sz = cell_a->m_size; digit_t * ds = cell_a->m_digits; diff --git a/src/util/mpz.h b/src/util/mpz.h index e74da812a..d1746bf35 100644 --- a/src/util/mpz.h +++ b/src/util/mpz.h @@ -57,6 +57,7 @@ class mpz_cell { unsigned m_size; unsigned m_capacity; digit_t m_digits[0]; + friend class mpz; friend class mpz_manager; friend class mpz_manager; friend class mpz_stack; @@ -110,12 +111,12 @@ private: } mpz_type * ptr() const { - SASSERT(!is_small()); + SASSERT(has_ptr()); return reinterpret_cast(m_value & MPZ_PTR_MASK); } void set_ptr(mpz_type* p, bool is_negative, bool is_external) { - SASSERT(is_small()); + SASSERT(!has_ptr()); SASSERT((reinterpret_cast(p) & 0x7) == 0); // Check alignment m_value = reinterpret_cast(p) | LARGE_BIT; if (is_negative) @@ -124,13 +125,8 @@ private: m_value |= EXTERNAL_BIT; } - int get_sign() const { - SASSERT(!is_small()); - return (m_value & SIGN_BIT) ? -1 : 1; - } - void set_sign(int s) { - SASSERT(!is_small()); + SASSERT(has_ptr()); if (s < 0) m_value |= SIGN_BIT; else @@ -138,7 +134,7 @@ private: } bool is_external() const { - SASSERT(!is_small()); + SASSERT(has_ptr()); return (m_value & EXTERNAL_BIT) != 0; } @@ -180,24 +176,33 @@ public: } void set(int64_t v) { - SASSERT(is_small()); + SASSERT(!has_ptr()); SASSERT(fits_in_small(v)); m_value = static_cast(v) << 1; } + inline bool has_ptr() const { + return (m_value & LARGE_BIT) != 0; + } + inline bool is_small() const { - return (m_value & LARGE_BIT) == 0; + return !has_ptr() || ptr()->m_size == 1; } inline int64_t value() const { SASSERT(is_small()); + if (has_ptr()) { + // Small value stored in a single digit + int64_t v = static_cast(ptr()->m_digits[0]); + return sign() < 0 ? -v : v; + } // Decode small integer: shift right by 1 (arithmetic shift to preserve sign) return static_cast(static_cast(m_value) >> 1); } inline int sign() const { - SASSERT(!is_small()); - return get_sign(); + SASSERT(has_ptr()); + return (m_value & SIGN_BIT) ? -1 : 1; } }; @@ -243,7 +248,7 @@ class mpz_manager { // make sure that n is a big number and has capacity equal to at least c. void allocate_if_needed(mpz & n, unsigned c) { if (m_init_cell_capacity > c) c = m_init_cell_capacity; - if (n.is_small() || capacity(n) < c) { + if (!n.has_ptr() || capacity(n) < c) { deallocate(n); n.set_ptr(allocate(c), false, false); // positive, owned } @@ -298,11 +303,11 @@ class mpz_manager { } } - void clear(mpz& n) { if (!n.is_small()) { mpz_clear(*n.ptr()); }} + void clear(mpz& n) { if (n.has_ptr()) { mpz_clear(*n.ptr()); }} #endif void deallocate(mpz& n) { - if (!n.is_small()) { + if (n.has_ptr()) { deallocate(!n.is_external(), n.ptr()); n.m_value = 0; // reset to small } @@ -317,17 +322,17 @@ class mpz_manager { #ifndef _MP_GMP static unsigned capacity(mpz const & c) { - SASSERT(!c.is_small()); + SASSERT(c.has_ptr()); return c.ptr()->m_capacity; } static unsigned size(mpz const & c) { - SASSERT(!c.is_small()); + SASSERT(c.has_ptr()); return c.ptr()->m_size; } static digit_t * digits(mpz const & c) { - SASSERT(!c.is_small()); + SASSERT(c.has_ptr()); return c.ptr()->m_digits; } @@ -369,7 +374,7 @@ class mpz_manager { }; void get_sign_cell(mpz const & a, int & sign, mpz_cell * & cell, mpz_cell* reserve) { - if (is_small(a)) { + if (!a.has_ptr()) { int64_t val = a.value(); bool neg = val < 0; uint64_t abs_val = static_cast(neg ? -val : val); @@ -405,7 +410,7 @@ class mpz_manager { }; void mk_big(mpz & a) { - if (a.is_small()) { + if (!a.has_ptr()) { a.set_ptr(allocate(), false, false); // positive, owned } } @@ -503,13 +508,11 @@ public: static bool is_neg(mpz const & a) { return sign(a) < 0; } static bool is_zero(mpz const & a) { - if (a.is_small()) - return a.value() == 0; - return size(a) == 1 && digits(a)[0] == 0; + return a.is_small() && a.value() == 0; } static int sign(mpz const & a) { - if (is_small(a)) { + if (a.is_small()) { int64_t v = a.value(); return (v > 0) - (v < 0); // Returns -1, 0, or 1 } @@ -605,7 +608,7 @@ public: void set(mpz & a, unsigned val) { set(a, (uint64_t)val); } void set(mpz & a, int64_t val) { - if (mpz::fits_in_small(val) && is_small(a)) { + if (mpz::fits_in_small(val) && !a.has_ptr()) { a.set(val); } else { @@ -614,7 +617,7 @@ public: } void set(mpz & a, uint64_t val) { - if (mpz::fits_in_small(val) && is_small(a)) { + if (mpz::fits_in_small(val) && !a.has_ptr()) { a.set(static_cast(val)); } else { @@ -683,17 +686,19 @@ public: if (is_small(a)) return a.value() == 1; #ifndef _MP_GMP - return size(a) == 1 && digits(a)[0] == 1 && a.sign() > 0; + return false; #else return mpz_cmp_si(*a.ptr(), 1) == 0; #endif } + // best effort static bool is_minus_one(mpz const & a) { if (is_small(a)) return a.value() == -1; #ifndef _MP_GMP - return size(a) == 1 && digits(a)[0] == 1 && a.sign() < 0; + //return eq(a, mpz(-1)); + return false; #else return mpz_cmp_si(*a.ptr(), -1) == 0; #endif