3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 00:26:38 +00:00

test / fix wrap-around for mod-interval

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2021-07-03 10:49:22 -07:00
parent 1355ea432a
commit ff717a9db1
5 changed files with 120 additions and 51 deletions

View file

@ -274,6 +274,7 @@ void Hacl_Bignum256_sqr(uint64_t *a, uint64_t *res)
res[i0 + i0] = r;
}
uint64_t c0 = Hacl_Bignum_Addition_bn_add_eq_len_u64(resLen, res, res, res);
(void)c0;
KRML_CHECK_SIZE(sizeof (uint64_t), resLen);
uint64_t *tmp = alloca(resLen * sizeof (uint64_t));
memset(tmp, 0U, resLen * sizeof (uint64_t));
@ -286,6 +287,7 @@ void Hacl_Bignum256_sqr(uint64_t *a, uint64_t *res)
tmp[(uint32_t)2U * i + (uint32_t)1U] = hi;
}
uint64_t c1 = Hacl_Bignum_Addition_bn_add_eq_len_u64(resLen, res, tmp, res);
(void)c1;
}
static inline void precompr2(uint32_t nBits, uint64_t *n, uint64_t *res)
@ -414,6 +416,7 @@ static inline void areduction(uint64_t *n, uint64_t nInv, uint64_t *c, uint64_t
uint64_t c00 = c0;
uint64_t tmp[4U] = { 0U };
uint64_t c1 = Hacl_Bignum256_sub(res, n, tmp);
(void)c1;
uint64_t m = (uint64_t)0U - c00;
for (uint32_t i = (uint32_t)0U; i < (uint32_t)4U; i++)
{
@ -497,6 +500,7 @@ static inline void amont_sqr(uint64_t *n, uint64_t nInv_u64, uint64_t *aM, uint6
c[i0 + i0] = r;
}
uint64_t c0 = Hacl_Bignum_Addition_bn_add_eq_len_u64(resLen, c, c, c);
(void)c0;
KRML_CHECK_SIZE(sizeof (uint64_t), resLen);
uint64_t *tmp = alloca(resLen * sizeof (uint64_t));
memset(tmp, 0U, resLen * sizeof (uint64_t));

View file

@ -21,35 +21,19 @@ u256 u256::operator*(u256 const& other) const {
return u256(result);
}
u256 u256::operator+(u256 const& other) const {
u256 result;
Hacl_Bignum256_add(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), result.m_num);
return result;
}
u256 u256::operator-(u256 const& other) const {
u256 result;
Hacl_Bignum256_sub(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), result.m_num);
return result;
}
u256& u256::operator*=(u256 const& other) {
uint64_t result[8];
Hacl_Bignum256_add(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), result);
Hacl_Bignum256_mul(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), result);
std::uninitialized_copy(m_num, m_num + sizeof(*this), result);
return *this;
}
u256& u256::operator+=(u256 const& other) {
uint64_t result[4];
Hacl_Bignum256_add(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), result);
std::uninitialized_copy(m_num, m_num + sizeof(*this), result);
Hacl_Bignum256_add(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), m_num);
return *this;
}
u256& u256::operator-=(u256 const& other) {
uint64_t result[4];
Hacl_Bignum256_sub(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), result);
std::uninitialized_copy(m_num, m_num + sizeof(*this), result);
Hacl_Bignum256_sub(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), m_num);
return *this;
}

View file

@ -9,8 +9,8 @@ public:
u256(uint64_t n);
u256(uint64_t const* v);
u256 operator*(u256 const& other) const;
u256 operator+(u256 const& other) const;
u256 operator-(u256 const& other) const;
u256 operator+(u256 const& other) const { u256 r = *this; return r += other; }
u256 operator-(u256 const& other) const { u256 r = *this; return r -= other; }
u256& operator+=(u256 const& other);
u256& operator*=(u256 const& other);
u256& operator-=(u256 const& other);

View file

@ -123,20 +123,6 @@ Numeral mod_interval<Numeral>::closest_value(Numeral const& n) const {
// TBD: correctness and completeness for wrap-around semantics needs to be checked/fixed
template<typename Numeral>
mod_interval<Numeral>& mod_interval<Numeral>::intersect_ule(Numeral const& h) {
if (is_empty())
return *this;
if (is_max(h))
return *this;
else if (is_free())
lo = 0, hi = h + 1;
else if (hi > lo && lo > h)
set_empty();
else if (hi == 0 || h + 1 < hi)
hi = h + 1;
return *this;
}
template<typename Numeral>
mod_interval<Numeral>& mod_interval<Numeral>::intersect_uge(Numeral const& l) {
@ -148,6 +134,50 @@ mod_interval<Numeral>& mod_interval<Numeral>::intersect_uge(Numeral const& l) {
lo = l, hi = 0;
else if ((lo < hi || hi == 0) && lo < l)
lo = l;
else if (hi < lo && hi <= l && l <= lo)
hi = 0;
else if (hi < lo && lo < l)
hi = 0, lo = l;
return *this;
}
template<typename Numeral>
mod_interval<Numeral>& mod_interval<Numeral>::intersect_ugt(Numeral const& l) {
if (is_empty())
return *this;
if (is_max(l))
set_empty();
else if (is_free())
lo = l + 1, hi = 0;
else if (lo < hi && hi <= l)
set_empty();
else if (lo < hi)
lo = l + 1;
else if (hi < lo && hi <= l + 1 && l <= lo - 1)
hi = 0;
else if (hi < lo && lo <= l)
hi = 0, lo = l + 1;
return *this;
}
template<typename Numeral>
mod_interval<Numeral>& mod_interval<Numeral>::intersect_ule(Numeral const& h) {
if (is_empty())
return *this;
if (is_max(h))
return *this;
else if (is_free())
lo = 0, hi = h + 1;
else if (hi > lo && lo > h)
set_empty();
else if (hi == 0 && h >= lo)
hi = h + 1;
else if (lo <= h && h + 1 < hi)
hi = h + 1;
else if (hi < lo && h < hi)
hi = h + 1, lo = 0;
else if (hi <= h && h < lo)
lo = 0;
return *this;
}
@ -163,25 +193,13 @@ mod_interval<Numeral>& mod_interval<Numeral>::intersect_ult(Numeral const& h) {
set_empty();
else if (hi > lo && h < hi)
hi = h;
else if (hi < lo && h <= hi)
hi = h, lo = 0;
else if (hi < h && h <= lo)
lo = 0;
return *this;
}
template<typename Numeral>
mod_interval<Numeral>& mod_interval<Numeral>::intersect_ugt(Numeral const& l) {
if (is_empty())
return *this;
if (is_max(l))
set_empty();
else if (is_free())
lo = l + 1, hi = 0;
else if (lo > l)
return *this;
else if (lo < hi && hi <= l)
set_empty();
else if (lo < hi)
lo = l + 1;
return *this;
}
template<typename Numeral>
mod_interval<Numeral>& mod_interval<Numeral>::intersect_fixed(Numeral const& a) {

View file

@ -30,6 +30,69 @@ static void test_interval2() {
SASSERT(i.is_empty());
i = mod_interval<uint32_t>(500, 10);
std::cout << "test-wrap: " << i << "\n";
std::cout << " >= 0: " << i.intersect_uge(0) << "\n";
std::cout << " >= 2: " << i.intersect_uge(2) << "\n";
std::cout << " >= 11: " << i.intersect_uge(11) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " >= 10: " << i << " -> " << i.intersect_uge(10) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " >= 499: " << i << " -> " << i.intersect_uge(499) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " >= 500: " << i << " -> " << i.intersect_uge(500) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " >= 501: " << i << " -> " << i.intersect_uge(501) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " > 0: " << i.intersect_ugt(0) << "\n";
std::cout << " > 2: " << i.intersect_ugt(2) << "\n";
std::cout << " > 10: " << i.intersect_ugt(10) << "\n";
std::cout << " > 11: " << i.intersect_ugt(11) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " > 10: " << i << " -> " << i.intersect_ugt(10) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " > 499: " << i << " -> " << i.intersect_ugt(499) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " > 500: " << i << " -> " << i.intersect_ugt(500) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " > 501: " << i << " -> " << i.intersect_ugt(501) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " <= 0: " << i.intersect_ule(0) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " <= 2: " << i.intersect_ule(2) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " <= 9: " << i.intersect_ule(9) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " <= 10: " << i.intersect_ule(10) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " <= 11: " << i.intersect_ule(11) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " <= 499: " << i << " -> " << i.intersect_ule(499) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " <= 500: " << i << " -> " << i.intersect_ule(500) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " <= 501: " << i << " -> " << i.intersect_ule(501) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " < 0: " << i.intersect_ult(0) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " < 2: " << i.intersect_ult(2) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " < 9: " << i.intersect_ult(9) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " < 10: " << i.intersect_ult(10) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " < 11: " << i.intersect_ult(11) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " < 499: " << i << " -> " << i.intersect_ult(499) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " < 500: " << i << " -> " << i.intersect_ult(500) << "\n";
i = mod_interval<uint32_t>(500, 10);
std::cout << " < 501: " << i << " -> " << i.intersect_ult(501) << "\n";
}
void tst_mod_interval() {