From 09370d7782bf6c68902876b713c10993c12f9c18 Mon Sep 17 00:00:00 2001 From: Nuno Lopes Date: Sat, 31 Jan 2026 10:36:43 +0000 Subject: [PATCH] optimize has_sign_bit and mod2k to not compute powers of two this is very useful for bitvectors of large bitwidths --- src/ast/bv_decl_plugin.cpp | 4 +- src/ast/bv_decl_plugin.h | 3 +- src/util/mpq.h | 3 ++ src/util/mpz.cpp | 100 ++++++++++++++++++++++++++++++------- src/util/mpz.h | 13 +++-- src/util/rational.h | 8 +-- 6 files changed, 97 insertions(+), 34 deletions(-) diff --git a/src/ast/bv_decl_plugin.cpp b/src/ast/bv_decl_plugin.cpp index 28fbb9fbb..558870ca8 100644 --- a/src/ast/bv_decl_plugin.cpp +++ b/src/ast/bv_decl_plugin.cpp @@ -836,9 +836,7 @@ rational bv_recognizers::norm(rational const & val, unsigned bv_size, bool is_si bool bv_recognizers::has_sign_bit(rational const & n, unsigned bv_size) const { SASSERT(bv_size > 0); - rational m = norm(n, bv_size, false); - rational p = rational::power_of_two(bv_size - 1); - return m >= p; + return numerator(n).get_bit(bv_size - 1) == 1; } bool bv_recognizers::is_bv_sort(sort const * s) const { diff --git a/src/ast/bv_decl_plugin.h b/src/ast/bv_decl_plugin.h index 916910087..6e4858c20 100644 --- a/src/ast/bv_decl_plugin.h +++ b/src/ast/bv_decl_plugin.h @@ -461,8 +461,7 @@ public: MATCH_UNARY(is_int2bv); bool is_bit2bool(expr* e, expr*& bv, unsigned& idx) const; - rational norm(rational const & val, unsigned bv_size, bool is_signed) const ; - rational norm(rational const & val, unsigned bv_size) const { return norm(val, bv_size, false); } + rational norm(rational const & val, unsigned bv_size, bool is_signed = false) const ; bool has_sign_bit(rational const & n, unsigned bv_size) const; }; diff --git a/src/util/mpq.h b/src/util/mpq.h index 212fa1c97..03a51d7d1 100644 --- a/src/util/mpq.h +++ b/src/util/mpq.h @@ -30,6 +30,7 @@ public: mpq(int v) : m_num(v) {} mpq() = default; mpq(mpq &&) noexcept = default; + mpq(mpz && n) noexcept : m_num(std::move(n)) {} mpq & operator=(mpq&&) = default; mpq & operator=(mpq const&) = delete; mpz const & numerator() const { return m_num; } @@ -558,6 +559,8 @@ public: mod(a.m_num, b.m_num, c); } + mpz mod2k(mpz const & a, unsigned k) { return mpz_manager::mod2k(a, k); } + static unsigned hash(mpz const & a) { return mpz_manager::hash(a); } static unsigned hash(mpq const & a) { return hash(a.m_num) + 3*hash(a.m_den); } diff --git a/src/util/mpz.cpp b/src/util/mpz.cpp index 6ee634a43..10c345b0c 100644 --- a/src/util/mpz.cpp +++ b/src/util/mpz.cpp @@ -633,6 +633,79 @@ void mpz_manager::mod(mpz const & a, mpz const & b, mpz & c) { STRACE(mpz, tout << to_string(c) << "\n";); } +template +mpz mpz_manager::mod2k(mpz const & a, unsigned k) { + if (is_zero(a)) + return 0; + + mpz result; + + if (is_small(a) && k < 64) { + uint64_t mask = ((1ULL << k) - 1); + uint64_t uval = static_cast(i64(a)); + set_i64(result, static_cast(uval & mask)); + return result; + } + + if (is_nonneg(a) && bitsize(a) <= k) { + return dup(a); + } + +#ifndef _MP_GMP + sign_cell ca(*this, a); + unsigned digit_size = sizeof(digit_t) * 8; + unsigned digit_count = k / digit_size; + unsigned rem_bits = k % digit_size; + unsigned total_digits = digit_count + (rem_bits > 0); + digit_t mask = (1ULL << rem_bits) - 1; + bool is_zero = true; + + allocate_if_needed(result, total_digits); + + // compute |a| mod 2^k- + for (unsigned i = 0, e = std::min(digit_count, ca.cell()->m_size); i < e; ++i) { + is_zero &= (digits(result)[i] = ca.cell()->m_digits[i]) == 0; + } + for (unsigned i = ca.cell()->m_size; i < total_digits; ++i) { + digits(result)[i] = 0; + } + + if (rem_bits > 0 && digit_count < ca.cell()->m_size) { + is_zero &= (digits(result)[digit_count] = ca.cell()->m_digits[digit_count] & mask) == 0; + } + result.m_ptr->m_size = total_digits; + + if (ca.sign() < 0 && !is_zero) { + // Negative case: if non-zero, result = 2^k - (|a| mod 2^k) + // which boils down to computing ~result + 1 + for (unsigned i = 0; i < total_digits; ++i) { + digits(result)[i] = ~digits(result)[i]; + } + + // Increment result + digit_t carry = 1; + for (unsigned i = 0; i < total_digits && carry; ++i) { + digit_t sum = digits(result)[i] + carry; + carry = sum < digits(result)[i]; + digits(result)[i] = sum; + } + + // Clamp to k bits + if (rem_bits != 0) { + digits(result)[digit_count] &= mask; + } + } + normalize(result); +#else + ensure_mpz_t a1(a); + mk_big(result); + MPZ_BEGIN_CRITICAL(); + mpz_tdiv_r_2exp(*result.m_ptr, a1(), k); + MPZ_END_CRITICAL(); +#endif + return result; +} + template void mpz_manager::neg(mpz & a) { STRACE(mpz, tout << "[mpz] 0 - " << to_string(a) << " == ";); @@ -1140,13 +1213,9 @@ unsigned mpz_manager::size_info(mpz const & a) { template struct mpz_manager::sz_lt { - mpz_manager & m; - mpz const * m_as; - - sz_lt(mpz_manager & _m, mpz const * as):m(_m), m_as(as) {} - + mpz const * m_as; bool operator()(unsigned p1, unsigned p2) { - return m.size_info(m_as[p1]) < m.size_info(m_as[p2]); + return size_info(m_as[p1]) < size_info(m_as[p2]); } }; @@ -1180,8 +1249,7 @@ void mpz_manager::gcd(unsigned sz, mpz const * as, mpz & g) { sbuffer p; for (i = 0; i < sz; ++i) p.push_back(i); - sz_lt lt(*this, as); - std::sort(p.begin(), p.end(), lt); + std::sort(p.begin(), p.end(), sz_lt{as}); TRACE(mpz_gcd, for (unsigned i = 0; i < sz; ++i) tout << p[i] << ":" << size_info(as[p[i]]) << " "; tout << "\n";); gcd(as[p[0]], as[p[1]], g); for (i = 2; i < sz; ++i) { @@ -1695,15 +1763,11 @@ void mpz_manager::display(std::ostream & out, mpz const & a) const { else { #ifndef _MP_GMP if (a.m_val < 0) - out << "-"; - if (sizeof(digit_t) == 4) { - sbuffer buffer(11*size(a), 0); - out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size()); - } - else { - sbuffer buffer(21*size(a), 0); - out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size()); - } + out << '-'; + + auto sz = sizeof(digit_t) == 4 ? 11 : 21; + sbuffer buffer(sz * size(a), 0); + out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size()); #else // GMP version size_t sz = mpz_sizeinbase(*a.m_ptr, 10) + 2; @@ -1835,7 +1899,7 @@ template std::string mpz_manager::to_string(mpz const & a) const { std::ostringstream buffer; display(buffer, a); - return buffer.str(); + return std::move(buffer).str(); } template diff --git a/src/util/mpz.h b/src/util/mpz.h index 88ee35e15..505bb177e 100644 --- a/src/util/mpz.h +++ b/src/util/mpz.h @@ -98,12 +98,9 @@ protected: friend class mpbq_manager; friend class mpz_stack; public: - mpz(int v):m_val(v), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {} - mpz():m_val(0), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {} - mpz(mpz_type* ptr): m_val(0), m_kind(mpz_small), m_owner(mpz_ext), m_ptr(ptr) { SASSERT(ptr);} - mpz(mpz && other) noexcept : m_val(other.m_val), m_kind(other.m_kind), m_owner(other.m_owner), m_ptr(nullptr) { - std::swap(m_ptr, other.m_ptr); - } + mpz(int v = 0) noexcept : m_val(v), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {} + mpz(mpz_type* ptr) noexcept : m_val(0), m_kind(mpz_small), m_owner(mpz_ext), m_ptr(ptr) { SASSERT(ptr); } + mpz(mpz && other) noexcept : mpz() { swap(other); } mpz& operator=(mpz const& other) = delete; mpz& operator=(mpz &&other) noexcept { @@ -390,7 +387,7 @@ class mpz_manager { int big_compare(mpz const & a, mpz const & b); public: - unsigned size_info(mpz const & a); + static unsigned size_info(mpz const & a); struct sz_lt; static bool precise() { return true; } @@ -438,6 +435,8 @@ public: void mod(mpz const & a, mpz const & b, mpz & c); + mpz mod2k(mpz const & a, unsigned k); + void neg(mpz & a); void abs(mpz & a); diff --git a/src/util/rational.h b/src/util/rational.h index 61569f40c..3649a9848 100644 --- a/src/util/rational.h +++ b/src/util/rational.h @@ -51,10 +51,10 @@ public: explicit rational(unsigned n) { m().set(m_val, n); } rational(int n, int d) { m().set(m_val, n, d); } - rational(mpq const & q) { m().set(m_val, q); } - + rational(mpq && q) noexcept : m_val(std::move(q)) {} rational(mpz const & z) { m().set(m_val, z); } + rational(mpz && z) noexcept : m_val(std::move(z)) {} explicit rational(double z) { UNREACHABLE(); } @@ -274,8 +274,8 @@ public: } friend inline rational mod2k(rational const & a, unsigned k) { - if (a.is_nonneg() && a.is_int() && a.bitsize() <= k) - return a; + if (a.is_int()) + return rational::m().mod2k(a.m_val.numerator(), k); return mod(a, power_of_two(k)); }