diff --git a/src/math/bigfix/u256.cpp b/src/math/bigfix/u256.cpp index 79ebb1844..6b956b844 100644 --- a/src/math/bigfix/u256.cpp +++ b/src/math/bigfix/u256.cpp @@ -1,8 +1,52 @@ #include "math/bigfix/u256.h" #include "math/bigfix/Hacl_Bignum256.h" +u256::u256(uint64_t n) { + m_num[0] = n; + m_num[1] = 0; + m_num[2] = 0; + m_num[3] = 0; +} + +u256::u256(uint64_t const* v) { + std::uninitialized_copy(v, v + sizeof(*this), m_num); +} + u256 u256::operator*(u256 const& other) const { uint64_t result[8]; Hacl_Bignum256_mul(const_cast(m_num), const_cast(other.m_num), result); return u256(result); } + +u256 u256::operator+(u256 const& other) const { + u256 result; + Hacl_Bignum256_add(const_cast(m_num), const_cast(other.m_num), result.m_num); + return result; +} + +u256 u256::operator-(u256 const& other) const { + u256 result; + Hacl_Bignum256_sub(const_cast(m_num), const_cast(other.m_num), result.m_num); + return result; +} + +u256& u256::operator*=(u256 const& other) { + uint64_t result[8]; + Hacl_Bignum256_add(const_cast(m_num), const_cast(other.m_num), result); + std::uninitialized_copy(m_num, m_num + sizeof(*this), result); + return *this; +} + +u256& u256::operator+=(u256 const& other) { + uint64_t result[4]; + Hacl_Bignum256_add(const_cast(m_num), const_cast(other.m_num), result); + std::uninitialized_copy(m_num, m_num + sizeof(*this), result); + return *this; +} + +u256& u256::operator-=(u256 const& other) { + uint64_t result[4]; + Hacl_Bignum256_sub(const_cast(m_num), const_cast(other.m_num), result); + std::uninitialized_copy(m_num, m_num + sizeof(*this), result); + return *this; +} diff --git a/src/math/bigfix/u256.h b/src/math/bigfix/u256.h index 7d6e11991..42d6a3cd8 100644 --- a/src/math/bigfix/u256.h +++ b/src/math/bigfix/u256.h @@ -7,9 +7,11 @@ class u256 { public: u256() { memset(this, 0, sizeof(*this)); } u256(uint64_t n); - u256(uint64_t const* v) { memcpy(m_num, v, sizeof(*this)); } + u256(uint64_t const* v); u256 operator*(u256 const& other) const; u256 operator+(u256 const& other) const; u256 operator-(u256 const& other) const; - + u256& operator+=(u256 const& other); + u256& operator*=(u256 const& other); + u256& operator-=(u256 const& other); }; diff --git a/src/math/interval/mod_interval_def.h b/src/math/interval/mod_interval_def.h index 0e1b07d95..9fc1f6e06 100644 --- a/src/math/interval/mod_interval_def.h +++ b/src/math/interval/mod_interval_def.h @@ -133,7 +133,7 @@ mod_interval& mod_interval::intersect_ule(Numeral const& h) { lo = 0, hi = h + 1; else if (hi > lo && lo > h) set_empty(); - else if (hi != 0 || h + 1 < hi) + else if (hi == 0 || h + 1 < hi) hi = h + 1; return *this; } @@ -146,7 +146,7 @@ mod_interval& mod_interval::intersect_uge(Numeral const& l) { set_empty(); else if (is_free()) lo = l, hi = 0; - else if (lo < hi && lo < l) + else if ((lo < hi || hi == 0) && lo < l) lo = l; return *this; } @@ -175,7 +175,7 @@ mod_interval& mod_interval::intersect_ugt(Numeral const& l) { else if (is_free()) lo = l + 1, hi = 0; else if (lo > l) - return; + return *this; else if (lo < hi && hi <= l) set_empty(); else if (lo < hi) diff --git a/src/test/mod_interval.cpp b/src/test/mod_interval.cpp index 70e3a6eea..a743e34ba 100644 --- a/src/test/mod_interval.cpp +++ b/src/test/mod_interval.cpp @@ -15,7 +15,21 @@ static void test_interval2() { mod_interval i; std::cout << " >= 0: " << i.intersect_uge(0) << "\n"; std::cout << " >= 1: " << i.intersect_uge(1) << "\n"; - + std::cout << " >= 2: " << i.intersect_uge(2) << "\n"; + SASSERT(i.lo == 2 && i.hi == 0); + std::cout << " <= 10: " << i.intersect_ule(10) << "\n"; + std::cout << " > 2: " << i.intersect_ugt(2) << "\n"; + std::cout << " > 2: " << i.intersect_ugt(2) << "\n"; + std::cout << " <= 10: " << i.intersect_ule(10) << "\n"; + std::cout << " <= 11: " << i.intersect_ule(11) << "\n"; + std::cout << " <= 9: " << i.intersect_ule(9) << "\n"; + std::cout << " <= 2: " << i.intersect_ule(2) << "\n"; + SASSERT(i.is_empty()); + i = mod_interval(2, 10); + std::cout << " >= 10: " << i.intersect_uge(10) << "\n"; + SASSERT(i.is_empty()); + i = mod_interval(500, 10); + std::cout << "test-wrap: " << i << "\n"; } void tst_mod_interval() {