mirror of
https://github.com/Z3Prover/z3
synced 2026-02-01 14:57:57 +00:00
optimize has_sign_bit and mod2k to not compute powers of two
this is very useful for bitvectors of large bitwidths
This commit is contained in:
parent
74cbd6de32
commit
09370d7782
6 changed files with 97 additions and 34 deletions
|
|
@ -836,9 +836,7 @@ rational bv_recognizers::norm(rational const & val, unsigned bv_size, bool is_si
|
|||
|
||||
bool bv_recognizers::has_sign_bit(rational const & n, unsigned bv_size) const {
|
||||
SASSERT(bv_size > 0);
|
||||
rational m = norm(n, bv_size, false);
|
||||
rational p = rational::power_of_two(bv_size - 1);
|
||||
return m >= p;
|
||||
return numerator(n).get_bit(bv_size - 1) == 1;
|
||||
}
|
||||
|
||||
bool bv_recognizers::is_bv_sort(sort const * s) const {
|
||||
|
|
|
|||
|
|
@ -461,8 +461,7 @@ public:
|
|||
MATCH_UNARY(is_int2bv);
|
||||
bool is_bit2bool(expr* e, expr*& bv, unsigned& idx) const;
|
||||
|
||||
rational norm(rational const & val, unsigned bv_size, bool is_signed) const ;
|
||||
rational norm(rational const & val, unsigned bv_size) const { return norm(val, bv_size, false); }
|
||||
rational norm(rational const & val, unsigned bv_size, bool is_signed = false) const ;
|
||||
bool has_sign_bit(rational const & n, unsigned bv_size) const;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ public:
|
|||
mpq(int v) : m_num(v) {}
|
||||
mpq() = default;
|
||||
mpq(mpq &&) noexcept = default;
|
||||
mpq(mpz && n) noexcept : m_num(std::move(n)) {}
|
||||
mpq & operator=(mpq&&) = default;
|
||||
mpq & operator=(mpq const&) = delete;
|
||||
mpz const & numerator() const { return m_num; }
|
||||
|
|
@ -558,6 +559,8 @@ public:
|
|||
mod(a.m_num, b.m_num, c);
|
||||
}
|
||||
|
||||
mpz mod2k(mpz const & a, unsigned k) { return mpz_manager<SYNCH>::mod2k(a, k); }
|
||||
|
||||
static unsigned hash(mpz const & a) { return mpz_manager<SYNCH>::hash(a); }
|
||||
|
||||
static unsigned hash(mpq const & a) { return hash(a.m_num) + 3*hash(a.m_den); }
|
||||
|
|
|
|||
100
src/util/mpz.cpp
100
src/util/mpz.cpp
|
|
@ -633,6 +633,79 @@ void mpz_manager<SYNCH>::mod(mpz const & a, mpz const & b, mpz & c) {
|
|||
STRACE(mpz, tout << to_string(c) << "\n";);
|
||||
}
|
||||
|
||||
template<bool SYNCH>
|
||||
mpz mpz_manager<SYNCH>::mod2k(mpz const & a, unsigned k) {
|
||||
if (is_zero(a))
|
||||
return 0;
|
||||
|
||||
mpz result;
|
||||
|
||||
if (is_small(a) && k < 64) {
|
||||
uint64_t mask = ((1ULL << k) - 1);
|
||||
uint64_t uval = static_cast<uint64_t>(i64(a));
|
||||
set_i64(result, static_cast<int64_t>(uval & mask));
|
||||
return result;
|
||||
}
|
||||
|
||||
if (is_nonneg(a) && bitsize(a) <= k) {
|
||||
return dup(a);
|
||||
}
|
||||
|
||||
#ifndef _MP_GMP
|
||||
sign_cell ca(*this, a);
|
||||
unsigned digit_size = sizeof(digit_t) * 8;
|
||||
unsigned digit_count = k / digit_size;
|
||||
unsigned rem_bits = k % digit_size;
|
||||
unsigned total_digits = digit_count + (rem_bits > 0);
|
||||
digit_t mask = (1ULL << rem_bits) - 1;
|
||||
bool is_zero = true;
|
||||
|
||||
allocate_if_needed(result, total_digits);
|
||||
|
||||
// compute |a| mod 2^k-
|
||||
for (unsigned i = 0, e = std::min(digit_count, ca.cell()->m_size); i < e; ++i) {
|
||||
is_zero &= (digits(result)[i] = ca.cell()->m_digits[i]) == 0;
|
||||
}
|
||||
for (unsigned i = ca.cell()->m_size; i < total_digits; ++i) {
|
||||
digits(result)[i] = 0;
|
||||
}
|
||||
|
||||
if (rem_bits > 0 && digit_count < ca.cell()->m_size) {
|
||||
is_zero &= (digits(result)[digit_count] = ca.cell()->m_digits[digit_count] & mask) == 0;
|
||||
}
|
||||
result.m_ptr->m_size = total_digits;
|
||||
|
||||
if (ca.sign() < 0 && !is_zero) {
|
||||
// Negative case: if non-zero, result = 2^k - (|a| mod 2^k)
|
||||
// which boils down to computing ~result + 1
|
||||
for (unsigned i = 0; i < total_digits; ++i) {
|
||||
digits(result)[i] = ~digits(result)[i];
|
||||
}
|
||||
|
||||
// Increment result
|
||||
digit_t carry = 1;
|
||||
for (unsigned i = 0; i < total_digits && carry; ++i) {
|
||||
digit_t sum = digits(result)[i] + carry;
|
||||
carry = sum < digits(result)[i];
|
||||
digits(result)[i] = sum;
|
||||
}
|
||||
|
||||
// Clamp to k bits
|
||||
if (rem_bits != 0) {
|
||||
digits(result)[digit_count] &= mask;
|
||||
}
|
||||
}
|
||||
normalize(result);
|
||||
#else
|
||||
ensure_mpz_t a1(a);
|
||||
mk_big(result);
|
||||
MPZ_BEGIN_CRITICAL();
|
||||
mpz_tdiv_r_2exp(*result.m_ptr, a1(), k);
|
||||
MPZ_END_CRITICAL();
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
template<bool SYNCH>
|
||||
void mpz_manager<SYNCH>::neg(mpz & a) {
|
||||
STRACE(mpz, tout << "[mpz] 0 - " << to_string(a) << " == ";);
|
||||
|
|
@ -1140,13 +1213,9 @@ unsigned mpz_manager<SYNCH>::size_info(mpz const & a) {
|
|||
|
||||
template<bool SYNCH>
|
||||
struct mpz_manager<SYNCH>::sz_lt {
|
||||
mpz_manager<SYNCH> & m;
|
||||
mpz const * m_as;
|
||||
|
||||
sz_lt(mpz_manager<SYNCH> & _m, mpz const * as):m(_m), m_as(as) {}
|
||||
|
||||
mpz const * m_as;
|
||||
bool operator()(unsigned p1, unsigned p2) {
|
||||
return m.size_info(m_as[p1]) < m.size_info(m_as[p2]);
|
||||
return size_info(m_as[p1]) < size_info(m_as[p2]);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -1180,8 +1249,7 @@ void mpz_manager<SYNCH>::gcd(unsigned sz, mpz const * as, mpz & g) {
|
|||
sbuffer<unsigned, 1024> p;
|
||||
for (i = 0; i < sz; ++i)
|
||||
p.push_back(i);
|
||||
sz_lt lt(*this, as);
|
||||
std::sort(p.begin(), p.end(), lt);
|
||||
std::sort(p.begin(), p.end(), sz_lt{as});
|
||||
TRACE(mpz_gcd, for (unsigned i = 0; i < sz; ++i) tout << p[i] << ":" << size_info(as[p[i]]) << " "; tout << "\n";);
|
||||
gcd(as[p[0]], as[p[1]], g);
|
||||
for (i = 2; i < sz; ++i) {
|
||||
|
|
@ -1695,15 +1763,11 @@ void mpz_manager<SYNCH>::display(std::ostream & out, mpz const & a) const {
|
|||
else {
|
||||
#ifndef _MP_GMP
|
||||
if (a.m_val < 0)
|
||||
out << "-";
|
||||
if (sizeof(digit_t) == 4) {
|
||||
sbuffer<char, 1024> buffer(11*size(a), 0);
|
||||
out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size());
|
||||
}
|
||||
else {
|
||||
sbuffer<char, 1024> buffer(21*size(a), 0);
|
||||
out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size());
|
||||
}
|
||||
out << '-';
|
||||
|
||||
auto sz = sizeof(digit_t) == 4 ? 11 : 21;
|
||||
sbuffer<char, 1024> buffer(sz * size(a), 0);
|
||||
out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size());
|
||||
#else
|
||||
// GMP version
|
||||
size_t sz = mpz_sizeinbase(*a.m_ptr, 10) + 2;
|
||||
|
|
@ -1835,7 +1899,7 @@ template<bool SYNCH>
|
|||
std::string mpz_manager<SYNCH>::to_string(mpz const & a) const {
|
||||
std::ostringstream buffer;
|
||||
display(buffer, a);
|
||||
return buffer.str();
|
||||
return std::move(buffer).str();
|
||||
}
|
||||
|
||||
template<bool SYNCH>
|
||||
|
|
|
|||
|
|
@ -98,12 +98,9 @@ protected:
|
|||
friend class mpbq_manager;
|
||||
friend class mpz_stack;
|
||||
public:
|
||||
mpz(int v):m_val(v), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {}
|
||||
mpz():m_val(0), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {}
|
||||
mpz(mpz_type* ptr): m_val(0), m_kind(mpz_small), m_owner(mpz_ext), m_ptr(ptr) { SASSERT(ptr);}
|
||||
mpz(mpz && other) noexcept : m_val(other.m_val), m_kind(other.m_kind), m_owner(other.m_owner), m_ptr(nullptr) {
|
||||
std::swap(m_ptr, other.m_ptr);
|
||||
}
|
||||
mpz(int v = 0) noexcept : m_val(v), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {}
|
||||
mpz(mpz_type* ptr) noexcept : m_val(0), m_kind(mpz_small), m_owner(mpz_ext), m_ptr(ptr) { SASSERT(ptr); }
|
||||
mpz(mpz && other) noexcept : mpz() { swap(other); }
|
||||
|
||||
mpz& operator=(mpz const& other) = delete;
|
||||
mpz& operator=(mpz &&other) noexcept {
|
||||
|
|
@ -390,7 +387,7 @@ class mpz_manager {
|
|||
int big_compare(mpz const & a, mpz const & b);
|
||||
|
||||
public:
|
||||
unsigned size_info(mpz const & a);
|
||||
static unsigned size_info(mpz const & a);
|
||||
struct sz_lt;
|
||||
|
||||
static bool precise() { return true; }
|
||||
|
|
@ -438,6 +435,8 @@ public:
|
|||
|
||||
void mod(mpz const & a, mpz const & b, mpz & c);
|
||||
|
||||
mpz mod2k(mpz const & a, unsigned k);
|
||||
|
||||
void neg(mpz & a);
|
||||
|
||||
void abs(mpz & a);
|
||||
|
|
|
|||
|
|
@ -51,10 +51,10 @@ public:
|
|||
explicit rational(unsigned n) { m().set(m_val, n); }
|
||||
|
||||
rational(int n, int d) { m().set(m_val, n, d); }
|
||||
|
||||
rational(mpq const & q) { m().set(m_val, q); }
|
||||
|
||||
rational(mpq && q) noexcept : m_val(std::move(q)) {}
|
||||
rational(mpz const & z) { m().set(m_val, z); }
|
||||
rational(mpz && z) noexcept : m_val(std::move(z)) {}
|
||||
|
||||
explicit rational(double z) { UNREACHABLE(); }
|
||||
|
||||
|
|
@ -274,8 +274,8 @@ public:
|
|||
}
|
||||
|
||||
friend inline rational mod2k(rational const & a, unsigned k) {
|
||||
if (a.is_nonneg() && a.is_int() && a.bitsize() <= k)
|
||||
return a;
|
||||
if (a.is_int())
|
||||
return rational::m().mod2k(a.m_val.numerator(), k);
|
||||
return mod(a, power_of_two(k));
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue