diff --git a/src/math/bigfix/Hacl_Bignum256.c b/src/math/bigfix/Hacl_Bignum256.c index 372baa30f..6894f8bd8 100644 --- a/src/math/bigfix/Hacl_Bignum256.c +++ b/src/math/bigfix/Hacl_Bignum256.c @@ -274,6 +274,7 @@ void Hacl_Bignum256_sqr(uint64_t *a, uint64_t *res) res[i0 + i0] = r; } uint64_t c0 = Hacl_Bignum_Addition_bn_add_eq_len_u64(resLen, res, res, res); + (void)c0; KRML_CHECK_SIZE(sizeof (uint64_t), resLen); uint64_t *tmp = alloca(resLen * sizeof (uint64_t)); memset(tmp, 0U, resLen * sizeof (uint64_t)); @@ -286,6 +287,7 @@ void Hacl_Bignum256_sqr(uint64_t *a, uint64_t *res) tmp[(uint32_t)2U * i + (uint32_t)1U] = hi; } uint64_t c1 = Hacl_Bignum_Addition_bn_add_eq_len_u64(resLen, res, tmp, res); + (void)c1; } static inline void precompr2(uint32_t nBits, uint64_t *n, uint64_t *res) @@ -414,6 +416,7 @@ static inline void areduction(uint64_t *n, uint64_t nInv, uint64_t *c, uint64_t uint64_t c00 = c0; uint64_t tmp[4U] = { 0U }; uint64_t c1 = Hacl_Bignum256_sub(res, n, tmp); + (void)c1; uint64_t m = (uint64_t)0U - c00; for (uint32_t i = (uint32_t)0U; i < (uint32_t)4U; i++) { @@ -497,6 +500,7 @@ static inline void amont_sqr(uint64_t *n, uint64_t nInv_u64, uint64_t *aM, uint6 c[i0 + i0] = r; } uint64_t c0 = Hacl_Bignum_Addition_bn_add_eq_len_u64(resLen, c, c, c); + (void)c0; KRML_CHECK_SIZE(sizeof (uint64_t), resLen); uint64_t *tmp = alloca(resLen * sizeof (uint64_t)); memset(tmp, 0U, resLen * sizeof (uint64_t)); diff --git a/src/math/bigfix/u256.cpp b/src/math/bigfix/u256.cpp index 65c7144dc..b00820099 100644 --- a/src/math/bigfix/u256.cpp +++ b/src/math/bigfix/u256.cpp @@ -21,35 +21,19 @@ u256 u256::operator*(u256 const& other) const { return u256(result); } -u256 u256::operator+(u256 const& other) const { - u256 result; - Hacl_Bignum256_add(const_cast(m_num), const_cast(other.m_num), result.m_num); - return result; -} - -u256 u256::operator-(u256 const& other) const { - u256 result; - Hacl_Bignum256_sub(const_cast(m_num), const_cast(other.m_num), result.m_num); - return result; -} - u256& u256::operator*=(u256 const& other) { uint64_t result[8]; - Hacl_Bignum256_add(const_cast(m_num), const_cast(other.m_num), result); + Hacl_Bignum256_mul(const_cast(m_num), const_cast(other.m_num), result); std::uninitialized_copy(m_num, m_num + sizeof(*this), result); return *this; } u256& u256::operator+=(u256 const& other) { - uint64_t result[4]; - Hacl_Bignum256_add(const_cast(m_num), const_cast(other.m_num), result); - std::uninitialized_copy(m_num, m_num + sizeof(*this), result); + Hacl_Bignum256_add(const_cast(m_num), const_cast(other.m_num), m_num); return *this; } u256& u256::operator-=(u256 const& other) { - uint64_t result[4]; - Hacl_Bignum256_sub(const_cast(m_num), const_cast(other.m_num), result); - std::uninitialized_copy(m_num, m_num + sizeof(*this), result); + Hacl_Bignum256_sub(const_cast(m_num), const_cast(other.m_num), m_num); return *this; } diff --git a/src/math/bigfix/u256.h b/src/math/bigfix/u256.h index 52214d06d..5ae5eae7c 100644 --- a/src/math/bigfix/u256.h +++ b/src/math/bigfix/u256.h @@ -9,8 +9,8 @@ public: u256(uint64_t n); u256(uint64_t const* v); u256 operator*(u256 const& other) const; - u256 operator+(u256 const& other) const; - u256 operator-(u256 const& other) const; + u256 operator+(u256 const& other) const { u256 r = *this; return r += other; } + u256 operator-(u256 const& other) const { u256 r = *this; return r -= other; } u256& operator+=(u256 const& other); u256& operator*=(u256 const& other); u256& operator-=(u256 const& other); diff --git a/src/math/interval/mod_interval_def.h b/src/math/interval/mod_interval_def.h index 9fc1f6e06..b1bdcfcf7 100644 --- a/src/math/interval/mod_interval_def.h +++ b/src/math/interval/mod_interval_def.h @@ -123,20 +123,6 @@ Numeral mod_interval::closest_value(Numeral const& n) const { // TBD: correctness and completeness for wrap-around semantics needs to be checked/fixed -template -mod_interval& mod_interval::intersect_ule(Numeral const& h) { - if (is_empty()) - return *this; - if (is_max(h)) - return *this; - else if (is_free()) - lo = 0, hi = h + 1; - else if (hi > lo && lo > h) - set_empty(); - else if (hi == 0 || h + 1 < hi) - hi = h + 1; - return *this; -} template mod_interval& mod_interval::intersect_uge(Numeral const& l) { @@ -148,6 +134,50 @@ mod_interval& mod_interval::intersect_uge(Numeral const& l) { lo = l, hi = 0; else if ((lo < hi || hi == 0) && lo < l) lo = l; + else if (hi < lo && hi <= l && l <= lo) + hi = 0; + else if (hi < lo && lo < l) + hi = 0, lo = l; + return *this; +} + +template +mod_interval& mod_interval::intersect_ugt(Numeral const& l) { + if (is_empty()) + return *this; + if (is_max(l)) + set_empty(); + else if (is_free()) + lo = l + 1, hi = 0; + else if (lo < hi && hi <= l) + set_empty(); + else if (lo < hi) + lo = l + 1; + else if (hi < lo && hi <= l + 1 && l <= lo - 1) + hi = 0; + else if (hi < lo && lo <= l) + hi = 0, lo = l + 1; + return *this; +} + +template +mod_interval& mod_interval::intersect_ule(Numeral const& h) { + if (is_empty()) + return *this; + if (is_max(h)) + return *this; + else if (is_free()) + lo = 0, hi = h + 1; + else if (hi > lo && lo > h) + set_empty(); + else if (hi == 0 && h >= lo) + hi = h + 1; + else if (lo <= h && h + 1 < hi) + hi = h + 1; + else if (hi < lo && h < hi) + hi = h + 1, lo = 0; + else if (hi <= h && h < lo) + lo = 0; return *this; } @@ -163,25 +193,13 @@ mod_interval& mod_interval::intersect_ult(Numeral const& h) { set_empty(); else if (hi > lo && h < hi) hi = h; + else if (hi < lo && h <= hi) + hi = h, lo = 0; + else if (hi < h && h <= lo) + lo = 0; return *this; } -template -mod_interval& mod_interval::intersect_ugt(Numeral const& l) { - if (is_empty()) - return *this; - if (is_max(l)) - set_empty(); - else if (is_free()) - lo = l + 1, hi = 0; - else if (lo > l) - return *this; - else if (lo < hi && hi <= l) - set_empty(); - else if (lo < hi) - lo = l + 1; - return *this; -} template mod_interval& mod_interval::intersect_fixed(Numeral const& a) { diff --git a/src/test/mod_interval.cpp b/src/test/mod_interval.cpp index a743e34ba..69a39220c 100644 --- a/src/test/mod_interval.cpp +++ b/src/test/mod_interval.cpp @@ -30,6 +30,69 @@ static void test_interval2() { SASSERT(i.is_empty()); i = mod_interval(500, 10); std::cout << "test-wrap: " << i << "\n"; + std::cout << " >= 0: " << i.intersect_uge(0) << "\n"; + std::cout << " >= 2: " << i.intersect_uge(2) << "\n"; + std::cout << " >= 11: " << i.intersect_uge(11) << "\n"; + i = mod_interval(500, 10); + std::cout << " >= 10: " << i << " -> " << i.intersect_uge(10) << "\n"; + i = mod_interval(500, 10); + std::cout << " >= 499: " << i << " -> " << i.intersect_uge(499) << "\n"; + i = mod_interval(500, 10); + std::cout << " >= 500: " << i << " -> " << i.intersect_uge(500) << "\n"; + i = mod_interval(500, 10); + std::cout << " >= 501: " << i << " -> " << i.intersect_uge(501) << "\n"; + + i = mod_interval(500, 10); + std::cout << " > 0: " << i.intersect_ugt(0) << "\n"; + std::cout << " > 2: " << i.intersect_ugt(2) << "\n"; + std::cout << " > 10: " << i.intersect_ugt(10) << "\n"; + std::cout << " > 11: " << i.intersect_ugt(11) << "\n"; + i = mod_interval(500, 10); + std::cout << " > 10: " << i << " -> " << i.intersect_ugt(10) << "\n"; + i = mod_interval(500, 10); + std::cout << " > 499: " << i << " -> " << i.intersect_ugt(499) << "\n"; + i = mod_interval(500, 10); + std::cout << " > 500: " << i << " -> " << i.intersect_ugt(500) << "\n"; + i = mod_interval(500, 10); + std::cout << " > 501: " << i << " -> " << i.intersect_ugt(501) << "\n"; + + i = mod_interval(500, 10); + std::cout << " <= 0: " << i.intersect_ule(0) << "\n"; + i = mod_interval(500, 10); + std::cout << " <= 2: " << i.intersect_ule(2) << "\n"; + i = mod_interval(500, 10); + std::cout << " <= 9: " << i.intersect_ule(9) << "\n"; + i = mod_interval(500, 10); + std::cout << " <= 10: " << i.intersect_ule(10) << "\n"; + i = mod_interval(500, 10); + std::cout << " <= 11: " << i.intersect_ule(11) << "\n"; + i = mod_interval(500, 10); + std::cout << " <= 499: " << i << " -> " << i.intersect_ule(499) << "\n"; + i = mod_interval(500, 10); + std::cout << " <= 500: " << i << " -> " << i.intersect_ule(500) << "\n"; + i = mod_interval(500, 10); + std::cout << " <= 501: " << i << " -> " << i.intersect_ule(501) << "\n"; + + + i = mod_interval(500, 10); + std::cout << " < 0: " << i.intersect_ult(0) << "\n"; + i = mod_interval(500, 10); + std::cout << " < 2: " << i.intersect_ult(2) << "\n"; + i = mod_interval(500, 10); + std::cout << " < 9: " << i.intersect_ult(9) << "\n"; + i = mod_interval(500, 10); + std::cout << " < 10: " << i.intersect_ult(10) << "\n"; + i = mod_interval(500, 10); + std::cout << " < 11: " << i.intersect_ult(11) << "\n"; + i = mod_interval(500, 10); + std::cout << " < 499: " << i << " -> " << i.intersect_ult(499) << "\n"; + i = mod_interval(500, 10); + std::cout << " < 500: " << i << " -> " << i.intersect_ult(500) << "\n"; + i = mod_interval(500, 10); + std::cout << " < 501: " << i << " -> " << i.intersect_ult(501) << "\n"; + + + } void tst_mod_interval() {