diff --git a/src/util/mpz.cpp b/src/util/mpz.cpp index 136a95132..2d9768758 100644 --- a/src/util/mpz.cpp +++ b/src/util/mpz.cpp @@ -954,249 +954,251 @@ void mpz_manager::gcd(mpz const & a, mpz const & b, mpz & c) { return; } } + else { #ifdef _MP_GMP - ensure_mpz_t a1(a), b1(b); - mk_big(c); - mpz_gcd(*c.ptr(), a1(), b1()); - return; + ensure_mpz_t a1(a), b1(b); + mk_big(c); + mpz_gcd(*c.ptr(), a1(), b1()); + return; #endif - if (is_zero(a)) { - set(c, b); - abs(c); - return; - } - if (is_zero(b)) { - set(c, a); - abs(c); - return; - } -#ifdef BINARY_GCD - // Binary GCD for big numbers - // - It doesn't use division - // - The initial experiments, don't show any performance improvement - // - It only works with _MP_INTERNAL - mpz u, v, diff; - set(u, a); - set(v, b); - abs(u); - abs(v); - - unsigned k_u = power_of_two_multiple(u); - unsigned k_v = power_of_two_multiple(v); - unsigned k = k_u < k_v ? k_u : k_v; - - machine_div2k(u, k_u); - - while (true) { - machine_div2k(v, k_v); - - if (lt(u, v)) { - sub(v, u, v); - } - else { - sub(u, v, diff); - swap(u, v); - swap(v, diff); + if (is_zero(a)) { + set(c, b); + abs(c); + return; } - - if (is_zero(v) || is_one(v)) - break; - - // reset least significant bit - if (is_small(v)) - v.set(v.value() & ~1); - else - v.ptr()->m_digits[0] &= ~static_cast(1); - k_v = power_of_two_multiple(v); - } + if (is_zero(b)) { + set(c, a); + abs(c); + return; + } +#ifdef BINARY_GCD + // Binary GCD for big numbers + // - It doesn't use division + // - The initial experiments, don't show any performance improvement + // - It only works with _MP_INTERNAL + mpz u, v, diff; + set(u, a); + set(v, b); + abs(u); + abs(v); - mul2k(u, k, c); - del(u); del(v); del(diff); + unsigned k_u = power_of_two_multiple(u); + unsigned k_v = power_of_two_multiple(v); + unsigned k = k_u < k_v ? k_u : k_v; + + machine_div2k(u, k_u); + + while (true) { + machine_div2k(v, k_v); + + if (lt(u, v)) { + sub(v, u, v); + } + else { + sub(u, v, diff); + swap(u, v); + swap(v, diff); + } + + if (is_zero(v) || is_one(v)) + break; + + // reset least significant bit + if (is_small(v)) + v.set(v.value() & ~1); + else + v.ptr()->m_digits[0] &= ~static_cast(1); + k_v = power_of_two_multiple(v); + } + + mul2k(u, k, c); + del(u); del(v); del(diff); #endif // BINARY_GCD #ifdef EUCLID_GCD - mpz tmp1; - mpz tmp2; - mpz aux; - set(tmp1, a); - set(tmp2, b); - abs(tmp1); - abs(tmp2); - if (lt(tmp1, tmp2)) - swap(tmp1, tmp2); - if (is_zero(tmp2)) { - swap(c, tmp1); - } - else { - while (true) { - if (is_uint64(tmp1) && is_uint64(tmp2)) { - set(c, u64_gcd(get_uint64(tmp1), get_uint64(tmp2))); - break; - } - rem(tmp1, tmp2, aux); - if (is_zero(aux)) { - swap(c, tmp2); - break; - } + mpz tmp1; + mpz tmp2; + mpz aux; + set(tmp1, a); + set(tmp2, b); + abs(tmp1); + abs(tmp2); + if (lt(tmp1, tmp2)) swap(tmp1, tmp2); - swap(tmp2, aux); + if (is_zero(tmp2)) { + swap(c, tmp1); } - } - del(tmp1); del(tmp2); del(aux); + else { + while (true) { + if (is_uint64(tmp1) && is_uint64(tmp2)) { + set(c, u64_gcd(get_uint64(tmp1), get_uint64(tmp2))); + break; + } + rem(tmp1, tmp2, aux); + if (is_zero(aux)) { + swap(c, tmp2); + break; + } + swap(tmp1, tmp2); + swap(tmp2, aux); + } + } + del(tmp1); del(tmp2); del(aux); #endif // EUCLID_GCD #ifdef LS_BINARY_GCD - mpz u, v, t, u1, u2; - set(u, a); - set(v, b); - abs(u); - abs(v); - if (lt(u, v)) - swap(u, v); - while (!is_zero(v)) { - // Basic idea: - // compute t = 2^e*v such that t <= u < 2t - // u := min{u - t, 2t - u} - // - // The assignment u := min{u - t, 2t - u} - // can be replaced with u := u - t - // - // Since u and v are positive, we have: - // 2^{log2(u)} <= u < 2^{(log2(u) + 1)} - // 2^{log2(v)} <= v < 2^{(log2(v) + 1)} - // --> - // 2^{log2(v)}*2^{log2(u)-log2(v)} <= v*2^{log2(u)-log2(v)} < 2^{log2(v) + 1}*2^{log2(u)-log2(v)} - // --> - // 2^{log2(u)} <= v*2^{log2(u)-log2(v)} < 2^{log2(u) + 1} - // - // Now, let t be v*2^{log2(u)-log2(v)} - // If t <= u, then we found t - // Otherwise t = t div 2 - unsigned k_u = log2(u); - unsigned k_v = log2(v); - SASSERT(k_v <= k_u); - unsigned e = k_u - k_v; - mul2k(v, e, t); - sub(u, t, u1); - if (is_neg(u1)) { - // t is too big - machine_div2k(t, 1); - // Now, u1 contains u - 2t - neg(u1); - // Now, u1 contains 2t - u - sub(u, t, u2); // u2 := u - t - } - else { - // u1 contains u - t - mul2k(t, 1); - sub(t, u, u2); - // u2 contains 2t - u - } - SASSERT(is_nonneg(u1)); - SASSERT(is_nonneg(u2)); - if (lt(u1, u2)) - swap(u, u1); - else - swap(u, u2); + mpz u, v, t, u1, u2; + set(u, a); + set(v, b); + abs(u); + abs(v); if (lt(u, v)) - swap(u,v); - } - swap(u, c); - del(u); del(v); del(t); del(u1); del(u2); + swap(u, v); + while (!is_zero(v)) { + // Basic idea: + // compute t = 2^e*v such that t <= u < 2t + // u := min{u - t, 2t - u} + // + // The assignment u := min{u - t, 2t - u} + // can be replaced with u := u - t + // + // Since u and v are positive, we have: + // 2^{log2(u)} <= u < 2^{(log2(u) + 1)} + // 2^{log2(v)} <= v < 2^{(log2(v) + 1)} + // --> + // 2^{log2(v)}*2^{log2(u)-log2(v)} <= v*2^{log2(u)-log2(v)} < 2^{log2(v) + 1}*2^{log2(u)-log2(v)} + // --> + // 2^{log2(u)} <= v*2^{log2(u)-log2(v)} < 2^{log2(u) + 1} + // + // Now, let t be v*2^{log2(u)-log2(v)} + // If t <= u, then we found t + // Otherwise t = t div 2 + unsigned k_u = log2(u); + unsigned k_v = log2(v); + SASSERT(k_v <= k_u); + unsigned e = k_u - k_v; + mul2k(v, e, t); + sub(u, t, u1); + if (is_neg(u1)) { + // t is too big + machine_div2k(t, 1); + // Now, u1 contains u - 2t + neg(u1); + // Now, u1 contains 2t - u + sub(u, t, u2); // u2 := u - t + } + else { + // u1 contains u - t + mul2k(t, 1); + sub(t, u, u2); + // u2 contains 2t - u + } + SASSERT(is_nonneg(u1)); + SASSERT(is_nonneg(u2)); + if (lt(u1, u2)) + swap(u, u1); + else + swap(u, u2); + if (lt(u, v)) + swap(u,v); + } + swap(u, c); + del(u); del(v); del(t); del(u1); del(u2); #endif // LS_BINARY_GCD #ifdef LEHMER_GCD - // For now, it only works if sizeof(digit_t) == sizeof(unsigned) - static_assert(sizeof(digit_t) == sizeof(unsigned), ""); - - int64_t a_hat, b_hat, A, B, C, D, T, q, a_sz, b_sz; - mpz a1, b1, t, r, tmp; - set(a1, a); - set(b1, b); - abs(a1); - abs(b1); - if (lt(a1, b1)) - swap(a1, b1); - while (true) { - SASSERT(ge(a1, b1)); - if (is_small(b1)) { - if (is_small(a1)) { - unsigned r = u_gcd(a1.value(), b1.value()); - set(c, r); - break; + // For now, it only works if sizeof(digit_t) == sizeof(unsigned) + static_assert(sizeof(digit_t) == sizeof(unsigned), ""); + + int64_t a_hat, b_hat, A, B, C, D, T, q, a_sz, b_sz; + mpz a1, b1, t, r, tmp; + set(a1, a); + set(b1, b); + abs(a1); + abs(b1); + if (lt(a1, b1)) + swap(a1, b1); + while (true) { + SASSERT(ge(a1, b1)); + if (is_small(b1)) { + if (is_small(a1)) { + unsigned r = u_gcd(a1.value(), b1.value()); + set(c, r); + break; + } + else { + while (!is_zero(b1)) { + SASSERT(ge(a1, b1)); + rem(a1, b1, tmp); + swap(a1, b1); + swap(b1, tmp); + } + swap(c, a1); + break; + } + } + SASSERT(!is_small(a1)); + SASSERT(!is_small(b1)); + a_sz = a1.ptr()->m_size; + b_sz = b1.ptr()->m_size; + SASSERT(b_sz <= a_sz); + a_hat = a1.ptr()->m_digits[a_sz - 1]; + b_hat = (b_sz == a_sz) ? b1.ptr()->m_digits[b_sz - 1] : 0; + A = 1; + B = 0; + C = 0; + D = 1; + while (true) { + // Loop invariants + SASSERT(a_hat + A <= static_cast(UINT_MAX) + 1); + SASSERT(a_hat + B < static_cast(UINT_MAX) + 1); + SASSERT(b_hat + C < static_cast(UINT_MAX) + 1); + SASSERT(b_hat + D <= static_cast(UINT_MAX) + 1); + // overflows can't happen since I'm using int64 + if (b_hat + C == 0 || b_hat + D == 0) + break; + q = (a_hat + A)/(b_hat + C); + if (q != (a_hat + B)/(b_hat + D)) + break; + T = A - q*C; + A = C; + C = T; + T = B - q*D; + B = D; + D = T; + T = a_hat - q*b_hat; + a_hat = b_hat; + b_hat = T; + } + SASSERT(ge(a1, b1)); + if (B == 0) { + rem(a1, b1, t); + swap(a1, b1); + swap(b1, t); + SASSERT(ge(a1, b1)); } else { - while (!is_zero(b1)) { - SASSERT(ge(a1, b1)); - rem(a1, b1, tmp); - swap(a1, b1); - swap(b1, tmp); - } - swap(c, a1); - break; + // t <- A*a1 + set(tmp, A); + mul(a1, tmp, t); + // t <- t + B*b1 + set(tmp, B); + addmul(t, tmp, b1, t); + // r <- C*a1 + set(tmp, C); + mul(a1, tmp, r); + // r <- r + D*b1 + set(tmp, D); + addmul(r, tmp, b1, r); + // a <- t + swap(a1, t); + // b <- r + swap(b1, r); + SASSERT(ge(a1, b1)); } } - SASSERT(!is_small(a1)); - SASSERT(!is_small(b1)); - a_sz = a1.ptr()->m_size; - b_sz = b1.ptr()->m_size; - SASSERT(b_sz <= a_sz); - a_hat = a1.ptr()->m_digits[a_sz - 1]; - b_hat = (b_sz == a_sz) ? b1.ptr()->m_digits[b_sz - 1] : 0; - A = 1; - B = 0; - C = 0; - D = 1; - while (true) { - // Loop invariants - SASSERT(a_hat + A <= static_cast(UINT_MAX) + 1); - SASSERT(a_hat + B < static_cast(UINT_MAX) + 1); - SASSERT(b_hat + C < static_cast(UINT_MAX) + 1); - SASSERT(b_hat + D <= static_cast(UINT_MAX) + 1); - // overflows can't happen since I'm using int64 - if (b_hat + C == 0 || b_hat + D == 0) - break; - q = (a_hat + A)/(b_hat + C); - if (q != (a_hat + B)/(b_hat + D)) - break; - T = A - q*C; - A = C; - C = T; - T = B - q*D; - B = D; - D = T; - T = a_hat - q*b_hat; - a_hat = b_hat; - b_hat = T; + del(a1); del(b1); del(r); del(t); del(tmp); } - SASSERT(ge(a1, b1)); - if (B == 0) { - rem(a1, b1, t); - swap(a1, b1); - swap(b1, t); - SASSERT(ge(a1, b1)); - } - else { - // t <- A*a1 - set(tmp, A); - mul(a1, tmp, t); - // t <- t + B*b1 - set(tmp, B); - addmul(t, tmp, b1, t); - // r <- C*a1 - set(tmp, C); - mul(a1, tmp, r); - // r <- r + D*b1 - set(tmp, D); - addmul(r, tmp, b1, r); - // a <- t - swap(a1, t); - // b <- r - swap(b1, r); - SASSERT(ge(a1, b1)); - } - } - del(a1); del(b1); del(r); del(t); del(tmp); #endif // LEHMER_GCD } diff --git a/src/util/mpz.h b/src/util/mpz.h index 490c85993..555b5e1de 100644 --- a/src/util/mpz.h +++ b/src/util/mpz.h @@ -338,10 +338,7 @@ class mpz_manager { void set_big_i64(mpz & c, int64_t v); void set_i64(mpz & c, int64_t v) { - if (mpz::fits_in_small(v)) { - if (!is_small(c)) { - deallocate(c); - } + if (mpz::fits_in_small(v) && is_small(c)) { c.set64(v); } else { @@ -663,10 +660,7 @@ public: void set(mpz & a, int val) { // On 32-bit platforms, int can be outside small range - if (mpz::fits_in_small(val)) { - if (!is_small(a)) { - deallocate(a); - } + if (mpz::fits_in_small(val) && is_small(a)) { a.set(val); } else { @@ -675,10 +669,7 @@ public: } void set(mpz & a, unsigned val) { - if (mpz::fits_in_small(val)) { - if (!is_small(a)) { - deallocate(a); - } + if (mpz::fits_in_small(val) && is_small(a)) { a.set(static_cast(val)); } else { @@ -693,10 +684,7 @@ public: } void set(mpz & a, uint64_t val) { - if (mpz::fits_in_small(val)) { - if (!is_small(a)) { - deallocate(a); - } + if (mpz::fits_in_small(val) && is_small(a)) { a.set64(static_cast(val)); } else {