mirror of
https://github.com/Z3Prover/z3
synced 2026-06-03 07:37:54 +00:00
optimize has_sign_bit and mod2k to not compute powers of two
this is very useful for bitvectors of large bitwidths
This commit is contained in:
parent
9771839005
commit
6c3f9a3540
6 changed files with 97 additions and 34 deletions
|
|
@ -836,9 +836,7 @@ rational bv_recognizers::norm(rational const & val, unsigned bv_size, bool is_si
|
||||||
|
|
||||||
bool bv_recognizers::has_sign_bit(rational const & n, unsigned bv_size) const {
|
bool bv_recognizers::has_sign_bit(rational const & n, unsigned bv_size) const {
|
||||||
SASSERT(bv_size > 0);
|
SASSERT(bv_size > 0);
|
||||||
rational m = norm(n, bv_size, false);
|
return numerator(n).get_bit(bv_size - 1) == 1;
|
||||||
rational p = rational::power_of_two(bv_size - 1);
|
|
||||||
return m >= p;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bv_recognizers::is_bv_sort(sort const * s) const {
|
bool bv_recognizers::is_bv_sort(sort const * s) const {
|
||||||
|
|
|
||||||
|
|
@ -461,8 +461,7 @@ public:
|
||||||
MATCH_UNARY(is_int2bv);
|
MATCH_UNARY(is_int2bv);
|
||||||
bool is_bit2bool(expr* e, expr*& bv, unsigned& idx) const;
|
bool is_bit2bool(expr* e, expr*& bv, unsigned& idx) const;
|
||||||
|
|
||||||
rational norm(rational const & val, unsigned bv_size, bool is_signed) const ;
|
rational norm(rational const & val, unsigned bv_size, bool is_signed = false) const ;
|
||||||
rational norm(rational const & val, unsigned bv_size) const { return norm(val, bv_size, false); }
|
|
||||||
bool has_sign_bit(rational const & n, unsigned bv_size) const;
|
bool has_sign_bit(rational const & n, unsigned bv_size) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ public:
|
||||||
mpq(int v) : m_num(v) {}
|
mpq(int v) : m_num(v) {}
|
||||||
mpq() = default;
|
mpq() = default;
|
||||||
mpq(mpq &&) noexcept = default;
|
mpq(mpq &&) noexcept = default;
|
||||||
|
mpq(mpz && n) noexcept : m_num(std::move(n)) {}
|
||||||
mpq & operator=(mpq&&) = default;
|
mpq & operator=(mpq&&) = default;
|
||||||
mpq & operator=(mpq const&) = delete;
|
mpq & operator=(mpq const&) = delete;
|
||||||
mpz const & numerator() const { return m_num; }
|
mpz const & numerator() const { return m_num; }
|
||||||
|
|
@ -558,6 +559,8 @@ public:
|
||||||
mod(a.m_num, b.m_num, c);
|
mod(a.m_num, b.m_num, c);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mpz mod2k(mpz const & a, unsigned k) { return mpz_manager<SYNCH>::mod2k(a, k); }
|
||||||
|
|
||||||
static unsigned hash(mpz const & a) { return mpz_manager<SYNCH>::hash(a); }
|
static unsigned hash(mpz const & a) { return mpz_manager<SYNCH>::hash(a); }
|
||||||
|
|
||||||
static unsigned hash(mpq const & a) { return hash(a.m_num) + 3*hash(a.m_den); }
|
static unsigned hash(mpq const & a) { return hash(a.m_num) + 3*hash(a.m_den); }
|
||||||
|
|
|
||||||
100
src/util/mpz.cpp
100
src/util/mpz.cpp
|
|
@ -633,6 +633,79 @@ void mpz_manager<SYNCH>::mod(mpz const & a, mpz const & b, mpz & c) {
|
||||||
STRACE(mpz, tout << to_string(c) << "\n";);
|
STRACE(mpz, tout << to_string(c) << "\n";);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<bool SYNCH>
|
||||||
|
mpz mpz_manager<SYNCH>::mod2k(mpz const & a, unsigned k) {
|
||||||
|
if (is_zero(a))
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
mpz result;
|
||||||
|
|
||||||
|
if (is_small(a) && k < 64) {
|
||||||
|
uint64_t mask = ((1ULL << k) - 1);
|
||||||
|
uint64_t uval = static_cast<uint64_t>(i64(a));
|
||||||
|
set_i64(result, static_cast<int64_t>(uval & mask));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_nonneg(a) && bitsize(a) <= k) {
|
||||||
|
return dup(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef _MP_GMP
|
||||||
|
sign_cell ca(*this, a);
|
||||||
|
unsigned digit_size = sizeof(digit_t) * 8;
|
||||||
|
unsigned digit_count = k / digit_size;
|
||||||
|
unsigned rem_bits = k % digit_size;
|
||||||
|
unsigned total_digits = digit_count + (rem_bits > 0);
|
||||||
|
digit_t mask = (1ULL << rem_bits) - 1;
|
||||||
|
bool is_zero = true;
|
||||||
|
|
||||||
|
allocate_if_needed(result, total_digits);
|
||||||
|
|
||||||
|
// compute |a| mod 2^k-
|
||||||
|
for (unsigned i = 0, e = std::min(digit_count, ca.cell()->m_size); i < e; ++i) {
|
||||||
|
is_zero &= (digits(result)[i] = ca.cell()->m_digits[i]) == 0;
|
||||||
|
}
|
||||||
|
for (unsigned i = ca.cell()->m_size; i < total_digits; ++i) {
|
||||||
|
digits(result)[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rem_bits > 0 && digit_count < ca.cell()->m_size) {
|
||||||
|
is_zero &= (digits(result)[digit_count] = ca.cell()->m_digits[digit_count] & mask) == 0;
|
||||||
|
}
|
||||||
|
result.m_ptr->m_size = total_digits;
|
||||||
|
|
||||||
|
if (ca.sign() < 0 && !is_zero) {
|
||||||
|
// Negative case: if non-zero, result = 2^k - (|a| mod 2^k)
|
||||||
|
// which boils down to computing ~result + 1
|
||||||
|
for (unsigned i = 0; i < total_digits; ++i) {
|
||||||
|
digits(result)[i] = ~digits(result)[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment result
|
||||||
|
digit_t carry = 1;
|
||||||
|
for (unsigned i = 0; i < total_digits && carry; ++i) {
|
||||||
|
digit_t sum = digits(result)[i] + carry;
|
||||||
|
carry = sum < digits(result)[i];
|
||||||
|
digits(result)[i] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clamp to k bits
|
||||||
|
if (rem_bits != 0) {
|
||||||
|
digits(result)[digit_count] &= mask;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
normalize(result);
|
||||||
|
#else
|
||||||
|
ensure_mpz_t a1(a);
|
||||||
|
mk_big(result);
|
||||||
|
MPZ_BEGIN_CRITICAL();
|
||||||
|
mpz_tdiv_r_2exp(*result.m_ptr, a1(), k);
|
||||||
|
MPZ_END_CRITICAL();
|
||||||
|
#endif
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
template<bool SYNCH>
|
template<bool SYNCH>
|
||||||
void mpz_manager<SYNCH>::neg(mpz & a) {
|
void mpz_manager<SYNCH>::neg(mpz & a) {
|
||||||
STRACE(mpz, tout << "[mpz] 0 - " << to_string(a) << " == ";);
|
STRACE(mpz, tout << "[mpz] 0 - " << to_string(a) << " == ";);
|
||||||
|
|
@ -1140,13 +1213,9 @@ unsigned mpz_manager<SYNCH>::size_info(mpz const & a) {
|
||||||
|
|
||||||
template<bool SYNCH>
|
template<bool SYNCH>
|
||||||
struct mpz_manager<SYNCH>::sz_lt {
|
struct mpz_manager<SYNCH>::sz_lt {
|
||||||
mpz_manager<SYNCH> & m;
|
mpz const * m_as;
|
||||||
mpz const * m_as;
|
|
||||||
|
|
||||||
sz_lt(mpz_manager<SYNCH> & _m, mpz const * as):m(_m), m_as(as) {}
|
|
||||||
|
|
||||||
bool operator()(unsigned p1, unsigned p2) {
|
bool operator()(unsigned p1, unsigned p2) {
|
||||||
return m.size_info(m_as[p1]) < m.size_info(m_as[p2]);
|
return size_info(m_as[p1]) < size_info(m_as[p2]);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -1180,8 +1249,7 @@ void mpz_manager<SYNCH>::gcd(unsigned sz, mpz const * as, mpz & g) {
|
||||||
sbuffer<unsigned, 1024> p;
|
sbuffer<unsigned, 1024> p;
|
||||||
for (i = 0; i < sz; ++i)
|
for (i = 0; i < sz; ++i)
|
||||||
p.push_back(i);
|
p.push_back(i);
|
||||||
sz_lt lt(*this, as);
|
std::sort(p.begin(), p.end(), sz_lt{as});
|
||||||
std::sort(p.begin(), p.end(), lt);
|
|
||||||
TRACE(mpz_gcd, for (unsigned i = 0; i < sz; ++i) tout << p[i] << ":" << size_info(as[p[i]]) << " "; tout << "\n";);
|
TRACE(mpz_gcd, for (unsigned i = 0; i < sz; ++i) tout << p[i] << ":" << size_info(as[p[i]]) << " "; tout << "\n";);
|
||||||
gcd(as[p[0]], as[p[1]], g);
|
gcd(as[p[0]], as[p[1]], g);
|
||||||
for (i = 2; i < sz; ++i) {
|
for (i = 2; i < sz; ++i) {
|
||||||
|
|
@ -1695,15 +1763,11 @@ void mpz_manager<SYNCH>::display(std::ostream & out, mpz const & a) const {
|
||||||
else {
|
else {
|
||||||
#ifndef _MP_GMP
|
#ifndef _MP_GMP
|
||||||
if (a.m_val < 0)
|
if (a.m_val < 0)
|
||||||
out << "-";
|
out << '-';
|
||||||
if (sizeof(digit_t) == 4) {
|
|
||||||
sbuffer<char, 1024> buffer(11*size(a), 0);
|
auto sz = sizeof(digit_t) == 4 ? 11 : 21;
|
||||||
out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size());
|
sbuffer<char, 1024> buffer(sz * size(a), 0);
|
||||||
}
|
out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size());
|
||||||
else {
|
|
||||||
sbuffer<char, 1024> buffer(21*size(a), 0);
|
|
||||||
out << m_mpn_manager.to_string(digits(a), size(a), buffer.begin(), buffer.size());
|
|
||||||
}
|
|
||||||
#else
|
#else
|
||||||
// GMP version
|
// GMP version
|
||||||
size_t sz = mpz_sizeinbase(*a.m_ptr, 10) + 2;
|
size_t sz = mpz_sizeinbase(*a.m_ptr, 10) + 2;
|
||||||
|
|
@ -1835,7 +1899,7 @@ template<bool SYNCH>
|
||||||
std::string mpz_manager<SYNCH>::to_string(mpz const & a) const {
|
std::string mpz_manager<SYNCH>::to_string(mpz const & a) const {
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
display(buffer, a);
|
display(buffer, a);
|
||||||
return buffer.str();
|
return std::move(buffer).str();
|
||||||
}
|
}
|
||||||
|
|
||||||
template<bool SYNCH>
|
template<bool SYNCH>
|
||||||
|
|
|
||||||
|
|
@ -98,12 +98,9 @@ protected:
|
||||||
friend class mpbq_manager;
|
friend class mpbq_manager;
|
||||||
friend class mpz_stack;
|
friend class mpz_stack;
|
||||||
public:
|
public:
|
||||||
mpz(int v):m_val(v), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {}
|
mpz(int v = 0) noexcept : m_val(v), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {}
|
||||||
mpz():m_val(0), m_kind(mpz_small), m_owner(mpz_self), m_ptr(nullptr) {}
|
mpz(mpz_type* ptr) noexcept : m_val(0), m_kind(mpz_small), m_owner(mpz_ext), m_ptr(ptr) { SASSERT(ptr); }
|
||||||
mpz(mpz_type* ptr): m_val(0), m_kind(mpz_small), m_owner(mpz_ext), m_ptr(ptr) { SASSERT(ptr);}
|
mpz(mpz && other) noexcept : mpz() { swap(other); }
|
||||||
mpz(mpz && other) noexcept : m_val(other.m_val), m_kind(other.m_kind), m_owner(other.m_owner), m_ptr(nullptr) {
|
|
||||||
std::swap(m_ptr, other.m_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
mpz& operator=(mpz const& other) = delete;
|
mpz& operator=(mpz const& other) = delete;
|
||||||
mpz& operator=(mpz &&other) noexcept {
|
mpz& operator=(mpz &&other) noexcept {
|
||||||
|
|
@ -390,7 +387,7 @@ class mpz_manager {
|
||||||
int big_compare(mpz const & a, mpz const & b);
|
int big_compare(mpz const & a, mpz const & b);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
unsigned size_info(mpz const & a);
|
static unsigned size_info(mpz const & a);
|
||||||
struct sz_lt;
|
struct sz_lt;
|
||||||
|
|
||||||
static bool precise() { return true; }
|
static bool precise() { return true; }
|
||||||
|
|
@ -438,6 +435,8 @@ public:
|
||||||
|
|
||||||
void mod(mpz const & a, mpz const & b, mpz & c);
|
void mod(mpz const & a, mpz const & b, mpz & c);
|
||||||
|
|
||||||
|
mpz mod2k(mpz const & a, unsigned k);
|
||||||
|
|
||||||
void neg(mpz & a);
|
void neg(mpz & a);
|
||||||
|
|
||||||
void abs(mpz & a);
|
void abs(mpz & a);
|
||||||
|
|
|
||||||
|
|
@ -51,10 +51,10 @@ public:
|
||||||
explicit rational(unsigned n) { m().set(m_val, n); }
|
explicit rational(unsigned n) { m().set(m_val, n); }
|
||||||
|
|
||||||
rational(int n, int d) { m().set(m_val, n, d); }
|
rational(int n, int d) { m().set(m_val, n, d); }
|
||||||
|
|
||||||
rational(mpq const & q) { m().set(m_val, q); }
|
rational(mpq const & q) { m().set(m_val, q); }
|
||||||
|
rational(mpq && q) noexcept : m_val(std::move(q)) {}
|
||||||
rational(mpz const & z) { m().set(m_val, z); }
|
rational(mpz const & z) { m().set(m_val, z); }
|
||||||
|
rational(mpz && z) noexcept : m_val(std::move(z)) {}
|
||||||
|
|
||||||
explicit rational(double z) { UNREACHABLE(); }
|
explicit rational(double z) { UNREACHABLE(); }
|
||||||
|
|
||||||
|
|
@ -274,8 +274,8 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
friend inline rational mod2k(rational const & a, unsigned k) {
|
friend inline rational mod2k(rational const & a, unsigned k) {
|
||||||
if (a.is_nonneg() && a.is_int() && a.bitsize() <= k)
|
if (a.is_int())
|
||||||
return a;
|
return rational::m().mod2k(a.m_val.numerator(), k);
|
||||||
return mod(a, power_of_two(k));
|
return mod(a, power_of_two(k));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue