3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-02-01 14:57:57 +00:00

optimize has_sign_bit and mod2k to not compute powers of two

this is very useful for bitvectors of large bitwidths
This commit is contained in:
Nuno Lopes 2026-01-31 10:36:43 +00:00
parent 74cbd6de32
commit 09370d7782
6 changed files with 97 additions and 34 deletions

View file

@ -836,9 +836,7 @@ rational bv_recognizers::norm(rational const & val, unsigned bv_size, bool is_si
bool bv_recognizers::has_sign_bit(rational const & n, unsigned bv_size) const {
SASSERT(bv_size > 0);
rational m = norm(n, bv_size, false);
rational p = rational::power_of_two(bv_size - 1);
return m >= p;
return numerator(n).get_bit(bv_size - 1) == 1;
}
bool bv_recognizers::is_bv_sort(sort const * s) const {

View file

@ -461,8 +461,7 @@ public:
MATCH_UNARY(is_int2bv);
bool is_bit2bool(expr* e, expr*& bv, unsigned& idx) const;
rational norm(rational const & val, unsigned bv_size, bool is_signed) const ;
rational norm(rational const & val, unsigned bv_size) const { return norm(val, bv_size, false); }
rational norm(rational const & val, unsigned bv_size, bool is_signed = false) const ;
bool has_sign_bit(rational const & n, unsigned bv_size) const;
};

View file

@ -30,6 +30,7 @@ public:
mpq(int v) : m_num(v) {}
mpq() = default;
mpq(mpq &&) noexcept = default;
mpq(mpz && n) noexcept : m_num(std::move(n)) {}
mpq & operator=(mpq&&) = default;
mpq & operator=(mpq const&) = delete;
mpz const & numerator() const { return m_num; }
@ -558,6 +559,8 @@ public:
mod(a.m_num, b.m_num, c);
}
mpz mod2k(mpz const & a, unsigned k) { return mpz_manager<SYNCH>::mod2k(a, k); }
static unsigned hash(mpz const & a) { return mpz_manager<SYNCH>::hash(a); }
static unsigned hash(mpq const & a) { return hash(a.m_num) + 3*hash(a.m_den); }

View file

@ -633,6 +633,79 @@ void mpz_manager<SYNCH>::mod(mpz const & a, mpz const & b, mpz & c) {
STRACE(mpz, tout << to_string(c) << "\n";);
}
template<bool SYNCH>
mpz mpz_manager<SYNCH>::mod2k(mpz const & a, unsigned k) {
if (is_zero(a))
return 0;
mpz result;
if (is_small(a) && k < 64) {
uint64_t mask = ((1ULL << k) - 1);
uint64_t uval = static_cast<uint64_t>(i64(a));
set_i64(result, static_cast<int64_t>(uval & mask));
return result;
}
if (is_nonneg(a) && bitsize(a) <= k) {
return dup(a);
}
#ifndef _MP_GMP
sign_cell ca(*this, a);
unsigned digit_size = sizeof(digit_t) * 8;
unsigned digit_count = k / digit_size;
unsigned rem_bits = k % digit_size;
unsigned total_digits = digit_count + (rem_bits > 0);
digit_t mask = (1ULL << rem_bits) - 1;
bool is_zero = true;
allocate_if_needed(result, total_digits);
// compute |a| mod 2^k-
for (unsigned i = 0, e = std::min(digit_count, ca.cell()->m_size); i < e; ++i) {
is_zero &= (digits(result)[i] = ca.cell()->m_digits[i]) == 0;
}
for (unsigned i = ca.cell()->m_size; i < total_digits; ++i) {
digits(result)[i] = 0;
}
if (rem_bits > 0 && digit_count < ca.cell()->m_size) {
is_zero &= (digits(result)[digit_count] = ca.cell()->m_digits[digit_count] & mask) == 0;
}
result.m_ptr->m_size = total_digits;
if (ca.sign() < 0 && !is_zero) {
// Negative case: if non-zero, result = 2^k - (|a| mod 2^k)
// which boils down to computing ~result + 1
for (unsigned i = 0; i < total_digits; ++i) {
digits(result)[i] = ~digits(result)[i];
}
// Increment result
digit_t carry = 1;
for (unsigned i = 0; i < total_digits && carry; ++i) {
digit_t sum = digits(result)[i] + carry;
carry = sum < digits(result)[i];
digits(result)[i] = sum;
}
// Clamp to k bits
if (rem_bits != 0) {
digits(result)[digit_count] &= mask;
}
}
normalize(result);
#else
ensure_mpz_t a1(a);
mk_big(result);
MPZ_BEGIN_CRITICAL();
mpz_tdiv_r_2exp(*result.m_ptr, a1(), k);
MPZ_END_CRITICAL();
#endif
return result;
}
template<bool SYNCH>
void mpz_manager<SYNCH>::neg(mpz & a) {
STRACE(mpz, tout << "[mpz] 0 - " << to_string(a) << " == ";);
@ -1140,13 +1213,9 @@ unsigned mpz_manager<SYNCH>::size_info(mpz const & a) {
template<bool SYNCH>
struct mpz_manager<SYNCH>::sz_lt {
mpz_manager<SYNCH> & m;
mpz const * m_as;
sz_lt(mpz_manager<SYNCH> & _m, mpz const * as):m(_m), m_as(as) {}
mpz const * m_as;
bool operator()(unsigned p1, unsigned p2) {
return m.size_info(m_as[p1]) < m.size_info(m_as[p2]);
return size_info(m_as[p1]) < size_info(m_as[p2]);
}
};
@ -1180,8 +1249,7 @@ void mpz_manager<SYNCH>::gcd(unsigned sz, mpz const * as, mpz & g) {
sbuffer<unsigned, 1024> p;
for (i = 0; i < sz; ++i)
p.push_back(i);
sz_lt lt(*this, as);
std::sort(p.begin(), p.end(), lt);
std::sort(p.begin(), p.end(), sz_lt{as});
TRACE(mpz_gcd, for (unsigned i = 0; i < sz; ++i) tout << p[i] << ":" << size_info(as[p[i]]) << " "; tout << "\n";);
gcd(as[p[0]], as[p[1]], g);
for (i = 2; i < sz; ++i) {
@ -1695,15 +1763,11 @@ void mpz_manager<SYNCH>::display(std::ostream & out, mpz const & a) const {
else {
#ifndef _MP_GMP
if (a.m_val < 0)
out << "-";
if (sizeof(digit_t) == 4) {
sbuffer<char, 1024> buffer(11*size(a), 0);
out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size());
}
else {
sbuffer<char, 1024> buffer(21*size(a), 0);
out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size());
}
out << '-';
auto sz = sizeof(digit_t) == 4 ? 11 : 21;
sbuffer<char, 1024> buffer(sz * size(a), 0);
out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size());
#else
// GMP version
size_t sz = mpz_sizeinbase(*a.m_ptr, 10) + 2;
@ -1835,7 +1899,7 @@ template<bool SYNCH>
std::string mpz_manager<SYNCH>::to_string(mpz const & a) const {
std::ostringstream buffer;
display(buffer, a);
return buffer.str();
return std::move(buffer).str();
}
template<bool SYNCH>

View file

@ -98,12 +98,9 @@ protected:
friend class mpbq_manager;
friend class mpz_stack;
public:
mpz(int v):m_val(v), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {}
mpz():m_val(0), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {}
mpz(mpz_type* ptr): m_val(0), m_kind(mpz_small), m_owner(mpz_ext), m_ptr(ptr) { SASSERT(ptr);}
mpz(mpz && other) noexcept : m_val(other.m_val), m_kind(other.m_kind), m_owner(other.m_owner), m_ptr(nullptr) {
std::swap(m_ptr, other.m_ptr);
}
mpz(int v = 0) noexcept : m_val(v), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {}
mpz(mpz_type* ptr) noexcept : m_val(0), m_kind(mpz_small), m_owner(mpz_ext), m_ptr(ptr) { SASSERT(ptr); }
mpz(mpz && other) noexcept : mpz() { swap(other); }
mpz& operator=(mpz const& other) = delete;
mpz& operator=(mpz &&other) noexcept {
@ -390,7 +387,7 @@ class mpz_manager {
int big_compare(mpz const & a, mpz const & b);
public:
unsigned size_info(mpz const & a);
static unsigned size_info(mpz const & a);
struct sz_lt;
static bool precise() { return true; }
@ -438,6 +435,8 @@ public:
void mod(mpz const & a, mpz const & b, mpz & c);
mpz mod2k(mpz const & a, unsigned k);
void neg(mpz & a);
void abs(mpz & a);

View file

@ -51,10 +51,10 @@ public:
explicit rational(unsigned n) { m().set(m_val, n); }
rational(int n, int d) { m().set(m_val, n, d); }
rational(mpq const & q) { m().set(m_val, q); }
rational(mpq && q) noexcept : m_val(std::move(q)) {}
rational(mpz const & z) { m().set(m_val, z); }
rational(mpz && z) noexcept : m_val(std::move(z)) {}
explicit rational(double z) { UNREACHABLE(); }
@ -274,8 +274,8 @@ public:
}
friend inline rational mod2k(rational const & a, unsigned k) {
if (a.is_nonneg() && a.is_int() && a.bitsize() <= k)
return a;
if (a.is_int())
return rational::m().mod2k(a.m_val.numerator(), k);
return mod(a, power_of_two(k));
}