diff --git a/src/math/bigfix/u256.cpp b/src/math/bigfix/u256.cpp index b00820099..0a6f626cf 100644 --- a/src/math/bigfix/u256.cpp +++ b/src/math/bigfix/u256.cpp @@ -7,12 +7,25 @@ u256::u256() { } u256::u256(uint64_t n) { + // TBD use instead: bn_from_bytes_be? m_num[0] = n; m_num[1] = m_num[2] = m_num[3] = 0; } +u256::u256(rational const& n) { + uint8_t bytes[32]; + for (unsigned i = 0; i < 32; ++i) + bytes[i] = 0; + for (unsigned i = 0; i < 256; ++i) + bytes[(i / 7)] |= n.get_bit(i) << (i % 8); + auto* v = Hacl_Bignum256_new_bn_from_bytes_be(32, bytes); + std::uninitialized_copy(v, v + 4, m_num); + free(v); +} + + u256::u256(uint64_t const* v) { - std::uninitialized_copy(v, v + sizeof(*this), m_num); + std::uninitialized_copy(v, v + 4, m_num); } u256 u256::operator*(u256 const& other) const { @@ -37,3 +50,52 @@ u256& u256::operator-=(u256 const& other) { Hacl_Bignum256_sub(const_cast(m_num), const_cast(other.m_num), m_num); return *this; } + +u256& u256::inv() { + uint64_t zero[4]; + zero[0] = zero[1] = zero[2] = zero[3] = 0; + Hacl_Bignum256_sub(zero, const_cast(m_num), m_num); + return *this; +} + +u256 u256::mul_inverse() const { + NOT_IMPLEMENTED_YET(); + + /* + Write `a mod n` in `res`. + + The argument a is meant to be a 512-bit bignum, i.e. uint64_t[8]. + The argument n and the outparam res are meant to be 256-bit bignums, i.e. uint64_t[4]. + + The function returns false if any of the following preconditions are violated, + true otherwise. + • 1 < n + • n % 2 = 1 + VERIFY(Hacl_Bignum256_mod(uint64_t *n, uint64_t *a, uint64_t *res)); + */ + + return *this; +} + +unsigned u256::trailing_zeros() const { + unsigned r = 0; + for (unsigned i = 0; i < 3; ++i) { + r += ::trailing_zeros(m_num[i]); + if (r != (i+1)*64) + return r; + } + return r + ::trailing_zeros(m_num[3]); +} + +u256 u256::gcd(u256 const& other) const { + NOT_IMPLEMENTED_YET(); + return *this; +} + +std::ostream& u256::display(std::ostream& out) const { + rational n; + for (unsigned i = 0; i < 4; ++i) + if (m_num[i] != 0) + n += rational(m_num[i], rational::ui64()) * rational::power_of_two(i * 64); + return out << n; +} diff --git a/src/math/bigfix/u256.h b/src/math/bigfix/u256.h index 5ae5eae7c..620910734 100644 --- a/src/math/bigfix/u256.h +++ b/src/math/bigfix/u256.h @@ -1,17 +1,52 @@ #pragma once #include "util/util.h" +#include "util/rational.h" class u256 { uint64_t m_num[4]; + u256(uint64_t const* v); public: u256(); u256(uint64_t n); - u256(uint64_t const* v); + u256(rational const& n); u256 operator*(u256 const& other) const; u256 operator+(u256 const& other) const { u256 r = *this; return r += other; } u256 operator-(u256 const& other) const { u256 r = *this; return r -= other; } + u256 operator-() const { u256 r = *this; return r.inv(); } + + u256 mul_inverse() const; + unsigned trailing_zeros() const; + u256 gcd(u256 const& other) const; + + // updates + void reset() { m_num[0] = m_num[1] = m_num[2] = m_num[3] = 0; } u256& operator+=(u256 const& other); u256& operator*=(u256 const& other); u256& operator-=(u256 const& other); + u256& inv(); /* unary minus */ + + // comparisons + bool operator==(u256 const& other) const; + bool operator!=(u256 const& other) const; + bool operator<(u256 const& other) const; + bool operator<=(u256 const& other) const; + bool operator>(u256 const& other) const; + bool operator>=(u256 const& other) const; + + bool operator<(uint64_t other) const; + bool operator<=(uint64_t other) const; + bool operator>(uint64_t other) const; + bool operator>=(uint64_t other) const; + + bool is_zero() const { return m_num[0] == 0 && m_num[1] == 0 && m_num[2] == 0 && m_num[3] == 0; } + bool is_one() const { return m_num[0] == 1 && m_num[1] == 0 && m_num[2] == 0 && m_num[3] == 0; } + bool is_even() const { return (m_num[0]&1) == 0; } + + std::ostream& display(std::ostream& out) const; + }; + +inline std::ostream& operator<<(std::ostream& out, u256 const& u) { + return u.display(out); +} diff --git a/src/math/interval/mod_interval_def.h b/src/math/interval/mod_interval_def.h index b1bdcfcf7..da3e1b843 100644 --- a/src/math/interval/mod_interval_def.h +++ b/src/math/interval/mod_interval_def.h @@ -123,7 +123,6 @@ Numeral mod_interval::closest_value(Numeral const& n) const { // TBD: correctness and completeness for wrap-around semantics needs to be checked/fixed - template mod_interval& mod_interval::intersect_uge(Numeral const& l) { if (is_empty()) @@ -153,7 +152,7 @@ mod_interval& mod_interval::intersect_ugt(Numeral const& l) { set_empty(); else if (lo < hi) lo = l + 1; - else if (hi < lo && hi <= l + 1 && l <= lo - 1) + else if (hi < lo && hi <= l + 1 && l < lo) hi = 0; else if (hi < lo && lo <= l) hi = 0, lo = l + 1;