/*++ Copyright (c) 2006 Microsoft Corporation Module Name: mpz.h Abstract: Author: Leonardo de Moura (leonardo) 2010-06-17. Revision History: --*/ #ifndef _MPZ_H_ #define _MPZ_H_ #include #include #include"util.h" #include"small_object_allocator.h" #include"trace.h" #include"scoped_numeral.h" #include"scoped_numeral_vector.h" #include"z3_omp.h" unsigned u_gcd(unsigned u, unsigned v); uint64 u64_gcd(uint64 u, uint64 v); #ifdef _MP_GMP typedef unsigned digit_t; #endif #ifdef _MSC_VER #pragma warning(disable : 4200) #endif template class mpz_manager; template class mpq_manager; #if !defined(_MP_GMP) && !defined(_MP_MSBIGNUM) && !defined(_MP_INTERNAL) #ifdef _WINDOWS #define _MP_INTERNAL #else #define _MP_GMP #endif #endif #if defined(_MP_MSBIGNUM) typedef size_t digit_t; #elif defined(_MP_INTERNAL) typedef unsigned int digit_t; #endif #ifndef _MP_GMP class mpz_cell { unsigned m_size; unsigned m_capacity; digit_t m_digits[0]; friend class mpz_manager; friend class mpz_manager; }; #else #include #endif /** \brief Multi-precision integer. If m_ptr == 0, the it is a small number and the value is stored at m_val. Otherwise, m_val contains the sign (-1 negative, 1 positive), and m_ptr points to a mpz_cell that store the value. <<< This last statement is true only in Windows. */ class mpz { int m_val; #ifndef _MP_GMP mpz_cell * m_ptr; #else mpz_t * m_ptr; #endif friend class mpz_manager; friend class mpz_manager; friend class mpq_manager; friend class mpq_manager; friend class mpq; friend class mpbq; friend class mpbq_manager; mpz & operator=(mpz const & other) { UNREACHABLE(); return *this; } public: mpz(int v):m_val(v), m_ptr(0) {} mpz():m_val(0), m_ptr(0) {} void swap(mpz & other) { std::swap(m_val, other.m_val); std::swap(m_ptr, other.m_ptr); } }; inline void swap(mpz & m1, mpz & m2) { m1.swap(m2); } template class mpz_manager { small_object_allocator m_allocator; omp_nest_lock_t m_lock; #define MPZ_BEGIN_CRITICAL() if (SYNCH) omp_set_nest_lock(&m_lock); #define MPZ_END_CRITICAL() if (SYNCH) omp_unset_nest_lock(&m_lock); #ifndef _MP_GMP unsigned m_init_cell_capacity; mpz_cell * m_tmp[2]; mpz_cell * m_arg[2]; mpz m_int_min; static unsigned cell_size(unsigned capacity) { return sizeof(mpz_cell) + sizeof(digit_t) * capacity; } mpz_cell * allocate(unsigned capacity) { SASSERT(capacity >= m_init_cell_capacity); mpz_cell * cell = reinterpret_cast(m_allocator.allocate(cell_size(capacity))); cell->m_capacity = capacity; return cell; } // make sure that n is a big number and has capacity equal to at least c. void allocate_if_needed(mpz & n, unsigned c) { if (c < m_init_cell_capacity) c = m_init_cell_capacity; if (is_small(n)) { n.m_val = 1; n.m_ptr = allocate(c); n.m_ptr->m_capacity = c; } else if (capacity(n) < c) { deallocate(n.m_ptr); n.m_val = 1; n.m_ptr = allocate(c); n.m_ptr->m_capacity = c; } } void deallocate(mpz_cell * ptr) { m_allocator.deallocate(cell_size(ptr->m_capacity), ptr); } /** \brief Make sure that m_tmp[IDX] can hold the given number of digits */ template void ensure_tmp_capacity(unsigned capacity) { if (m_tmp[IDX]->m_capacity >= capacity) return; deallocate(m_tmp[IDX]); unsigned new_capacity = (3 * capacity + 1) >> 1; m_tmp[IDX] = allocate(new_capacity); SASSERT(m_tmp[IDX]->m_capacity >= capacity); } // Expand capacity of a while preserving its content. void ensure_capacity(mpz & a, unsigned sz); void normalize(mpz & a); #else // GMP code mpz_t m_tmp, m_tmp2; mpz_t m_two32; mpz_t * m_arg[2]; mpz_t m_uint64_max; mpz_t m_int64_max; mpz_t * allocate() { mpz_t * cell = reinterpret_cast(m_allocator.allocate(sizeof(mpz_t))); mpz_init(*cell); return cell; } void deallocate(mpz_t * ptr) { mpz_clear(*ptr); m_allocator.deallocate(sizeof(mpz_t), ptr); } #endif mpz m_two64; /** \brief Set \c a with the value stored at m_tmp[IDX], and the given sign. \c sz is an overapproximation of the the size of the number stored at \c tmp. */ template void set(mpz & a, int sign, unsigned sz); static int64 i64(mpz const & a) { return static_cast(a.m_val); } void set_big_i64(mpz & c, int64 v); void set_i64(mpz & c, int64 v) { if (v >= INT_MIN && v <= INT_MAX) { del(c); c.m_val = static_cast(v); } else { MPZ_BEGIN_CRITICAL(); set_big_i64(c, v); MPZ_END_CRITICAL(); } } void set_big_ui64(mpz & c, uint64 v); #ifndef _MP_GMP static unsigned capacity(mpz const & c) { return c.m_ptr->m_capacity; } static unsigned size(mpz const & c) { return c.m_ptr->m_size; } static digit_t * digits(mpz const & c) { return c.m_ptr->m_digits; } template void get_sign_cell(mpz const & a, int & sign, mpz_cell * & cell) { if (is_small(a)) { if (a.m_val == INT_MIN) { sign = -1; cell = m_int_min.m_ptr; } else { cell = m_arg[IDX]; SASSERT(cell->m_size == 1); if (a.m_val < 0) { sign = -1; cell->m_digits[0] = -a.m_val; } else { sign = 1; cell->m_digits[0] = a.m_val; } } } else { sign = a.m_val; cell = a.m_ptr; } } #else // GMP code template void get_arg(mpz const & a, mpz_t * & result) { if (is_small(a)) { result = m_arg[IDX]; mpz_set_si(*result, a.m_val); } else { result = a.m_ptr; } } void mk_big(mpz & a) { if (a.m_ptr == 0) a.m_ptr = allocate(); } #endif #ifndef _MP_GMP template void big_add_sub(mpz const & a, mpz const & b, mpz & c); #endif void big_add(mpz const & a, mpz const & b, mpz & c); void big_sub(mpz const & a, mpz const & b, mpz & c); void big_mul(mpz const & a, mpz const & b, mpz & c); void big_set(mpz & target, mpz const & source); #ifndef _MP_GMP #define QUOT_ONLY 0 #define REM_ONLY 1 #define QUOT_AND_REM 2 #define qr_mode int template void quot_rem_core(mpz const & a, mpz const & b, mpz & q, mpz & r); #endif void big_div_rem(mpz const & a, mpz const & b, mpz & q, mpz & r); void big_div(mpz const & a, mpz const & b, mpz & c); void big_rem(mpz const & a, mpz const & b, mpz & c); int big_compare(mpz const & a, mpz const & b); unsigned size_info(mpz const & a); struct sz_lt; public: static bool precise() { return true; } static bool field() { return false; } typedef mpz numeral; mpz_manager(); ~mpz_manager(); static bool is_small(mpz const & a) { return a.m_ptr == 0; } static mpz mk_z(int val) { return mpz(val); } void del(mpz & a) { if (a.m_ptr != 0) { MPZ_BEGIN_CRITICAL(); deallocate(a.m_ptr); MPZ_END_CRITICAL(); a.m_ptr = 0; } } void add(mpz const & a, mpz const & b, mpz & c) { STRACE("mpz", tout << "[mpz] " << to_string(a) << " + " << to_string(b) << " == ";); if (is_small(a) && is_small(b)) { set_i64(c, i64(a) + i64(b)); } else { MPZ_BEGIN_CRITICAL(); big_add(a, b, c); MPZ_END_CRITICAL(); } STRACE("mpz", tout << to_string(c) << "\n";); } void sub(mpz const & a, mpz const & b, mpz & c) { STRACE("mpz", tout << "[mpz] " << to_string(a) << " - " << to_string(b) << " == ";); if (is_small(a) && is_small(b)) { set_i64(c, i64(a) - i64(b)); } else { MPZ_BEGIN_CRITICAL(); big_sub(a, b, c); MPZ_END_CRITICAL(); } STRACE("mpz", tout << to_string(c) << "\n";); } void inc(mpz & a) { add(a, mpz(1), a); } void dec(mpz & a) { add(a, mpz(-1), a); } void mul(mpz const & a, mpz const & b, mpz & c) { STRACE("mpz", tout << "[mpz] " << to_string(a) << " * " << to_string(b) << " == ";); if (is_small(a) && is_small(b)) { set_i64(c, i64(a) * i64(b)); } else { MPZ_BEGIN_CRITICAL(); big_mul(a, b, c); MPZ_END_CRITICAL(); } STRACE("mpz", tout << to_string(c) << "\n";); } // d <- a + b*c void addmul(mpz const & a, mpz const & b, mpz const & c, mpz & d) { if (is_one(b)) { add(a, c, d); } else if (is_minus_one(b)) { sub(a, c, d); } else { mpz tmp; mul(b,c,tmp); add(a,tmp,d); del(tmp); } } // d <- a + b*c void submul(mpz const & a, mpz const & b, mpz const & c, mpz & d) { if (is_one(b)) { sub(a, c, d); } else if (is_minus_one(b)) { add(a, c, d); } else { mpz tmp; mul(b,c,tmp); sub(a,tmp,d); del(tmp); } } void machine_div_rem(mpz const & a, mpz const & b, mpz & q, mpz & r) { STRACE("mpz", tout << "[mpz-ext] divrem(" << to_string(a) << ", " << to_string(b) << ") == ";); if (is_small(a) && is_small(b)) { int64 _a = i64(a); int64 _b = i64(b); set_i64(q, _a / _b); set_i64(r, _a % _b); } else { MPZ_BEGIN_CRITICAL(); big_div_rem(a, b, q, r); MPZ_END_CRITICAL(); } STRACE("mpz", tout << "(" << to_string(q) << ", " << to_string(r) << ")\n";); } void machine_div(mpz const & a, mpz const & b, mpz & c) { STRACE("mpz", tout << "[mpz-ext] machine-div(" << to_string(a) << ", " << to_string(b) << ") == ";); if (is_small(a) && is_small(b)) { set_i64(c, i64(a) / i64(b)); } else { MPZ_BEGIN_CRITICAL(); big_div(a, b, c); MPZ_END_CRITICAL(); } STRACE("mpz", tout << to_string(c) << "\n";); } void rem(mpz const & a, mpz const & b, mpz & c) { STRACE("mpz", tout << "[mpz-ext] rem(" << to_string(a) << ", " << to_string(b) << ") == ";); if (is_small(a) && is_small(b)) { set_i64(c, i64(a) % i64(b)); } else { MPZ_BEGIN_CRITICAL(); big_rem(a, b, c); MPZ_END_CRITICAL(); } STRACE("mpz", tout << to_string(c) << "\n";); } void div(mpz const & a, mpz const & b, mpz & c) { STRACE("mpz", tout << "[mpz-ext] div(" << to_string(a) << ", " << to_string(b) << ") == ";); if (is_neg(a)) { mpz tmp; machine_div_rem(a, b, c, tmp); if (!is_zero(tmp)) { if (is_neg(b)) add(c, mk_z(1), c); else sub(c, mk_z(1), c); } del(tmp); } else { machine_div(a, b, c); } STRACE("mpz", tout << to_string(c) << "\n";); } void mod(mpz const & a, mpz const & b, mpz & c) { STRACE("mpz", tout << "[mpz-ext] mod(" << to_string(a) << ", " << to_string(b) << ") == ";); rem(a, b, c); if (is_neg(c)) { if (is_pos(b)) add(c, b, c); else sub(c, b, c); } STRACE("mpz", tout << to_string(c) << "\n";); } void neg(mpz & a) { STRACE("mpz", tout << "[mpz] 0 - " << to_string(a) << " == ";); if (is_small(a) && a.m_val == INT_MIN) { // neg(INT_MIN) is not a small int set_big_i64(a, - static_cast(INT_MIN)); return; } #ifndef _MP_GMP a.m_val = -a.m_val; #else if (is_small(a)) { a.m_val = -a.m_val; } else { mpz_neg(*a.m_ptr, *a.m_ptr); } #endif STRACE("mpz", tout << to_string(a) << "\n";); } void abs(mpz & a) { if (is_small(a)) { if (a.m_val < 0) { if (a.m_val == INT_MIN) { // abs(INT_MIN) is not a small int set_big_i64(a, - static_cast(INT_MIN)); } else a.m_val = -a.m_val; } } else { #ifndef _MP_GMP a.m_val = 1; #else mpz_abs(*a.m_ptr, *a.m_ptr); #endif } } static bool is_pos(mpz const & a) { #ifndef _MP_GMP return a.m_val > 0; #else if (is_small(a)) return a.m_val > 0; else return mpz_sgn(*a.m_ptr) > 0; #endif } static bool is_neg(mpz const & a) { #ifndef _MP_GMP return a.m_val < 0; #else if (is_small(a)) return a.m_val < 0; else return mpz_sgn(*a.m_ptr) < 0; #endif } static bool is_zero(mpz const & a) { #ifndef _MP_GMP return a.m_val == 0; #else if (is_small(a)) return a.m_val == 0; else return mpz_sgn(*a.m_ptr) == 0; #endif } static int sign(mpz const & a) { #ifndef _MP_GMP return a.m_val; #else if (is_small(a)) return a.m_val; else return mpz_sgn(*a.m_ptr); #endif } static bool is_nonpos(mpz const & a) { return !is_pos(a); } static bool is_nonneg(mpz const & a) { return !is_neg(a); } bool eq(mpz const & a, mpz const & b) { if (is_small(a) && is_small(b)) { return a.m_val == b.m_val; } else { MPZ_BEGIN_CRITICAL(); bool res = big_compare(a, b) == 0; MPZ_END_CRITICAL(); return res; } } bool lt(mpz const & a, mpz const & b) { if (is_small(a) && is_small(b)) { return a.m_val < b.m_val; } else { MPZ_BEGIN_CRITICAL(); bool res = big_compare(a, b) < 0; MPZ_END_CRITICAL(); return res; } } bool neq(mpz const & a, mpz const & b) { return !eq(a, b); } bool gt(mpz const & a, mpz const & b) { return lt(b, a); } bool ge(mpz const & a, mpz const & b) { return !lt(a, b); } bool le(mpz const & a, mpz const & b) { return !lt(b, a); } void gcd(mpz const & a, mpz const & b, mpz & c); void gcd(unsigned sz, mpz const * as, mpz & g); /** \brief Extended Euclid: r1*a + r2*b = g */ void gcd(mpz const & r1, mpz const & r2, mpz & a, mpz & b, mpz & g); void lcm(mpz const & a, mpz const & b, mpz & c); /** \brief Return true if a | b */ bool divides(mpz const & a, mpz const & b); // not a field void inv(mpz & a) { SASSERT(false); } void bitwise_or(mpz const & a, mpz const & b, mpz & c); void bitwise_and(mpz const & a, mpz const & b, mpz & c); void bitwise_xor(mpz const & a, mpz const & b, mpz & c); void bitwise_not(unsigned sz, mpz const & a, mpz & c); void set(mpz & target, mpz const & source) { if (is_small(source)) { del(target); target.m_val = source.m_val; } else { MPZ_BEGIN_CRITICAL(); big_set(target, source); MPZ_END_CRITICAL(); } } void set(mpz & a, int val) { del(a); a.m_val = val; } void set(mpz & a, unsigned val) { if (val <= INT_MAX) set(a, static_cast(val)); else set(a, static_cast(static_cast(val))); } void set(mpz & a, char const * val); void set(mpz & a, int64 val) { set_i64(a, val); } void set(mpz & a, uint64 val) { if (val < INT_MAX) { del(a); a.m_val = static_cast(val); } else { MPZ_BEGIN_CRITICAL(); set_big_ui64(a, val); MPZ_END_CRITICAL(); } } void set(mpz & target, unsigned sz, digit_t const * digits); void reset(mpz & a) { del(a); a.m_val = 0; } void swap(mpz & a, mpz & b) { std::swap(a.m_val, b.m_val); std::swap(a.m_ptr, b.m_ptr); } bool is_uint64(mpz const & a) const; bool is_int64(mpz const & a) const; uint64 get_uint64(mpz const & a) const; int64 get_int64(mpz const & a) const; double get_double(mpz const & a) const; std::string to_string(mpz const & a) const; void display(std::ostream & out, mpz const & a) const; /** \brief Display mpz number in SMT 2.0 format. If decimal == true, then ".0" is appended. */ void display_smt2(std::ostream & out, mpz const & a, bool decimal) const; static unsigned hash(mpz const & a); static bool is_one(mpz const & a) { #ifndef _MP_GMP return is_small(a) && a.m_val == 1; #else if (is_small(a)) return a.m_val == 1; return mpz_cmp_si(*a.m_ptr, 1) == 0; #endif } static bool is_minus_one(mpz const & a) { #ifndef _MP_GMP return is_small(a) && a.m_val == -1; #else if (is_small(a)) return a.m_val == -1; return mpz_cmp_si(*a.m_ptr, -1) == 0; #endif } void power(mpz const & a, unsigned p, mpz & b); bool is_power_of_two(mpz const & a); bool is_power_of_two(mpz const & a, unsigned & shift); void machine_div2k(mpz & a, unsigned k); void machine_div2k(mpz const & a, unsigned k, mpz & r) { set(r, a); machine_div2k(r, k); } void mul2k(mpz & a, unsigned k); void mul2k(mpz const & a, unsigned k, mpz & r) { set(r, a); mul2k(r, k); } /** \brief Return largest k s.t. n is a multiple of 2^k */ unsigned power_of_two_multiple(mpz const & n); /** \brief Return the position of the most significant bit. Return 0 if the number is negative */ unsigned log2(mpz const & n); /** \brief log2(-n) Return 0 if the number is nonegative */ unsigned mlog2(mpz const & n); /** \brief Return the bit-size of n. This method is mainly used for collecting statistics. */ unsigned bitsize(mpz const & n); /** \brief Return true if the number is a perfect square, and store the square root in 'root'. If the number n is positive and the result is false, then root will contain the smallest integer r such that r*r > n. */ bool is_perfect_square(mpz const & a, mpz & root); /** \brief Return the biggest k s.t. 2^k <= a. \remark Return 0 if a is not positive. */ unsigned prev_power_of_two(mpz const & a) { return log2(a); } /** \brief Return true if a^{1/n} is an integer, and store the result in a. Otherwise return false, and update a with the smallest integer r such that r*r > n. \remark This method assumes that if n is even, then a is nonegative */ bool root(mpz & a, unsigned n); bool root(mpz const & a, unsigned n, mpz & r) { set(r, a); return root(r, n); } bool is_even(mpz const & a) { if (is_small(a)) return !(a.m_val & 0x1); #ifndef _MP_GMP return !(0x1 & digits(a)[0]); #else return mpz_even_p(*a.m_ptr); #endif } bool is_odd(mpz const & n) { return !is_even(n); } // Store the digits of n into digits, and return the sign. bool decompose(mpz const & n, svector & digits); }; typedef mpz_manager synch_mpz_manager; typedef mpz_manager unsynch_mpz_manager; typedef _scoped_numeral scoped_mpz; typedef _scoped_numeral scoped_synch_mpz; typedef _scoped_numeral_vector scoped_mpz_vector; #endif /* _MPZ_H_ */