mirror of
https://github.com/Z3Prover/z3
synced 2025-04-22 00:26:38 +00:00
test / fix wrap-around for mod-interval
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
parent
1355ea432a
commit
ff717a9db1
5 changed files with 120 additions and 51 deletions
|
@ -274,6 +274,7 @@ void Hacl_Bignum256_sqr(uint64_t *a, uint64_t *res)
|
|||
res[i0 + i0] = r;
|
||||
}
|
||||
uint64_t c0 = Hacl_Bignum_Addition_bn_add_eq_len_u64(resLen, res, res, res);
|
||||
(void)c0;
|
||||
KRML_CHECK_SIZE(sizeof (uint64_t), resLen);
|
||||
uint64_t *tmp = alloca(resLen * sizeof (uint64_t));
|
||||
memset(tmp, 0U, resLen * sizeof (uint64_t));
|
||||
|
@ -286,6 +287,7 @@ void Hacl_Bignum256_sqr(uint64_t *a, uint64_t *res)
|
|||
tmp[(uint32_t)2U * i + (uint32_t)1U] = hi;
|
||||
}
|
||||
uint64_t c1 = Hacl_Bignum_Addition_bn_add_eq_len_u64(resLen, res, tmp, res);
|
||||
(void)c1;
|
||||
}
|
||||
|
||||
static inline void precompr2(uint32_t nBits, uint64_t *n, uint64_t *res)
|
||||
|
@ -414,6 +416,7 @@ static inline void areduction(uint64_t *n, uint64_t nInv, uint64_t *c, uint64_t
|
|||
uint64_t c00 = c0;
|
||||
uint64_t tmp[4U] = { 0U };
|
||||
uint64_t c1 = Hacl_Bignum256_sub(res, n, tmp);
|
||||
(void)c1;
|
||||
uint64_t m = (uint64_t)0U - c00;
|
||||
for (uint32_t i = (uint32_t)0U; i < (uint32_t)4U; i++)
|
||||
{
|
||||
|
@ -497,6 +500,7 @@ static inline void amont_sqr(uint64_t *n, uint64_t nInv_u64, uint64_t *aM, uint6
|
|||
c[i0 + i0] = r;
|
||||
}
|
||||
uint64_t c0 = Hacl_Bignum_Addition_bn_add_eq_len_u64(resLen, c, c, c);
|
||||
(void)c0;
|
||||
KRML_CHECK_SIZE(sizeof (uint64_t), resLen);
|
||||
uint64_t *tmp = alloca(resLen * sizeof (uint64_t));
|
||||
memset(tmp, 0U, resLen * sizeof (uint64_t));
|
||||
|
|
|
@ -21,35 +21,19 @@ u256 u256::operator*(u256 const& other) const {
|
|||
return u256(result);
|
||||
}
|
||||
|
||||
u256 u256::operator+(u256 const& other) const {
|
||||
u256 result;
|
||||
Hacl_Bignum256_add(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), result.m_num);
|
||||
return result;
|
||||
}
|
||||
|
||||
u256 u256::operator-(u256 const& other) const {
|
||||
u256 result;
|
||||
Hacl_Bignum256_sub(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), result.m_num);
|
||||
return result;
|
||||
}
|
||||
|
||||
u256& u256::operator*=(u256 const& other) {
|
||||
uint64_t result[8];
|
||||
Hacl_Bignum256_add(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), result);
|
||||
Hacl_Bignum256_mul(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), result);
|
||||
std::uninitialized_copy(m_num, m_num + sizeof(*this), result);
|
||||
return *this;
|
||||
}
|
||||
|
||||
u256& u256::operator+=(u256 const& other) {
|
||||
uint64_t result[4];
|
||||
Hacl_Bignum256_add(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), result);
|
||||
std::uninitialized_copy(m_num, m_num + sizeof(*this), result);
|
||||
Hacl_Bignum256_add(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), m_num);
|
||||
return *this;
|
||||
}
|
||||
|
||||
u256& u256::operator-=(u256 const& other) {
|
||||
uint64_t result[4];
|
||||
Hacl_Bignum256_sub(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), result);
|
||||
std::uninitialized_copy(m_num, m_num + sizeof(*this), result);
|
||||
Hacl_Bignum256_sub(const_cast<uint64_t*>(m_num), const_cast<uint64_t*>(other.m_num), m_num);
|
||||
return *this;
|
||||
}
|
||||
|
|
|
@ -9,8 +9,8 @@ public:
|
|||
u256(uint64_t n);
|
||||
u256(uint64_t const* v);
|
||||
u256 operator*(u256 const& other) const;
|
||||
u256 operator+(u256 const& other) const;
|
||||
u256 operator-(u256 const& other) const;
|
||||
u256 operator+(u256 const& other) const { u256 r = *this; return r += other; }
|
||||
u256 operator-(u256 const& other) const { u256 r = *this; return r -= other; }
|
||||
u256& operator+=(u256 const& other);
|
||||
u256& operator*=(u256 const& other);
|
||||
u256& operator-=(u256 const& other);
|
||||
|
|
|
@ -123,20 +123,6 @@ Numeral mod_interval<Numeral>::closest_value(Numeral const& n) const {
|
|||
|
||||
// TBD: correctness and completeness for wrap-around semantics needs to be checked/fixed
|
||||
|
||||
template<typename Numeral>
|
||||
mod_interval<Numeral>& mod_interval<Numeral>::intersect_ule(Numeral const& h) {
|
||||
if (is_empty())
|
||||
return *this;
|
||||
if (is_max(h))
|
||||
return *this;
|
||||
else if (is_free())
|
||||
lo = 0, hi = h + 1;
|
||||
else if (hi > lo && lo > h)
|
||||
set_empty();
|
||||
else if (hi == 0 || h + 1 < hi)
|
||||
hi = h + 1;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename Numeral>
|
||||
mod_interval<Numeral>& mod_interval<Numeral>::intersect_uge(Numeral const& l) {
|
||||
|
@ -148,6 +134,50 @@ mod_interval<Numeral>& mod_interval<Numeral>::intersect_uge(Numeral const& l) {
|
|||
lo = l, hi = 0;
|
||||
else if ((lo < hi || hi == 0) && lo < l)
|
||||
lo = l;
|
||||
else if (hi < lo && hi <= l && l <= lo)
|
||||
hi = 0;
|
||||
else if (hi < lo && lo < l)
|
||||
hi = 0, lo = l;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename Numeral>
|
||||
mod_interval<Numeral>& mod_interval<Numeral>::intersect_ugt(Numeral const& l) {
|
||||
if (is_empty())
|
||||
return *this;
|
||||
if (is_max(l))
|
||||
set_empty();
|
||||
else if (is_free())
|
||||
lo = l + 1, hi = 0;
|
||||
else if (lo < hi && hi <= l)
|
||||
set_empty();
|
||||
else if (lo < hi)
|
||||
lo = l + 1;
|
||||
else if (hi < lo && hi <= l + 1 && l <= lo - 1)
|
||||
hi = 0;
|
||||
else if (hi < lo && lo <= l)
|
||||
hi = 0, lo = l + 1;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename Numeral>
|
||||
mod_interval<Numeral>& mod_interval<Numeral>::intersect_ule(Numeral const& h) {
|
||||
if (is_empty())
|
||||
return *this;
|
||||
if (is_max(h))
|
||||
return *this;
|
||||
else if (is_free())
|
||||
lo = 0, hi = h + 1;
|
||||
else if (hi > lo && lo > h)
|
||||
set_empty();
|
||||
else if (hi == 0 && h >= lo)
|
||||
hi = h + 1;
|
||||
else if (lo <= h && h + 1 < hi)
|
||||
hi = h + 1;
|
||||
else if (hi < lo && h < hi)
|
||||
hi = h + 1, lo = 0;
|
||||
else if (hi <= h && h < lo)
|
||||
lo = 0;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -163,25 +193,13 @@ mod_interval<Numeral>& mod_interval<Numeral>::intersect_ult(Numeral const& h) {
|
|||
set_empty();
|
||||
else if (hi > lo && h < hi)
|
||||
hi = h;
|
||||
else if (hi < lo && h <= hi)
|
||||
hi = h, lo = 0;
|
||||
else if (hi < h && h <= lo)
|
||||
lo = 0;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename Numeral>
|
||||
mod_interval<Numeral>& mod_interval<Numeral>::intersect_ugt(Numeral const& l) {
|
||||
if (is_empty())
|
||||
return *this;
|
||||
if (is_max(l))
|
||||
set_empty();
|
||||
else if (is_free())
|
||||
lo = l + 1, hi = 0;
|
||||
else if (lo > l)
|
||||
return *this;
|
||||
else if (lo < hi && hi <= l)
|
||||
set_empty();
|
||||
else if (lo < hi)
|
||||
lo = l + 1;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename Numeral>
|
||||
mod_interval<Numeral>& mod_interval<Numeral>::intersect_fixed(Numeral const& a) {
|
||||
|
|
|
@ -30,6 +30,69 @@ static void test_interval2() {
|
|||
SASSERT(i.is_empty());
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << "test-wrap: " << i << "\n";
|
||||
std::cout << " >= 0: " << i.intersect_uge(0) << "\n";
|
||||
std::cout << " >= 2: " << i.intersect_uge(2) << "\n";
|
||||
std::cout << " >= 11: " << i.intersect_uge(11) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " >= 10: " << i << " -> " << i.intersect_uge(10) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " >= 499: " << i << " -> " << i.intersect_uge(499) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " >= 500: " << i << " -> " << i.intersect_uge(500) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " >= 501: " << i << " -> " << i.intersect_uge(501) << "\n";
|
||||
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " > 0: " << i.intersect_ugt(0) << "\n";
|
||||
std::cout << " > 2: " << i.intersect_ugt(2) << "\n";
|
||||
std::cout << " > 10: " << i.intersect_ugt(10) << "\n";
|
||||
std::cout << " > 11: " << i.intersect_ugt(11) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " > 10: " << i << " -> " << i.intersect_ugt(10) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " > 499: " << i << " -> " << i.intersect_ugt(499) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " > 500: " << i << " -> " << i.intersect_ugt(500) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " > 501: " << i << " -> " << i.intersect_ugt(501) << "\n";
|
||||
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " <= 0: " << i.intersect_ule(0) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " <= 2: " << i.intersect_ule(2) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " <= 9: " << i.intersect_ule(9) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " <= 10: " << i.intersect_ule(10) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " <= 11: " << i.intersect_ule(11) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " <= 499: " << i << " -> " << i.intersect_ule(499) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " <= 500: " << i << " -> " << i.intersect_ule(500) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " <= 501: " << i << " -> " << i.intersect_ule(501) << "\n";
|
||||
|
||||
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " < 0: " << i.intersect_ult(0) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " < 2: " << i.intersect_ult(2) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " < 9: " << i.intersect_ult(9) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " < 10: " << i.intersect_ult(10) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " < 11: " << i.intersect_ult(11) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " < 499: " << i << " -> " << i.intersect_ult(499) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " < 500: " << i << " -> " << i.intersect_ult(500) << "\n";
|
||||
i = mod_interval<uint32_t>(500, 10);
|
||||
std::cout << " < 501: " << i << " -> " << i.intersect_ult(501) << "\n";
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
void tst_mod_interval() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue