diff --git a/src/math/interval/mod_interval.h b/src/math/interval/mod_interval.h index ab4bdfb6c..88dfff485 100644 --- a/src/math/interval/mod_interval.h +++ b/src/math/interval/mod_interval.h @@ -45,7 +45,7 @@ inline std::ostream& operator<<(std::ostream& out, pp const& p) { template class mod_interval { - bool emp { false }; + bool emp = false; public: Numeral lo { 0 }; Numeral hi { 0 }; @@ -59,7 +59,7 @@ public: bool is_empty() const { return emp; } bool is_singleton() const { return !is_empty() && (lo + 1 == hi || (hi == 0 && is_max(lo))); } bool contains(Numeral const& n) const; - virtual bool is_max(Numeral const& n) const { return n + 1 == 0; } + virtual bool is_max(Numeral const& n) const { return (Numeral)(n + 1) == 0; } void set_free() { lo = hi = 0; emp = false; } void set_bounds(Numeral const& l, Numeral const& h) { lo = l; hi = h; } @@ -89,6 +89,16 @@ public: return out << "[" << pp(lo) << ", " << pp(hi) << "["; } Numeral closest_value(Numeral const& n) const; + bool operator==(mod_interval const& other) const { + if (is_empty()) + return other.is_empty(); + if (is_free()) + return other.is_free(); + return lo == other.lo && hi == other.hi; + } + bool operator!=(mod_interval const& other) const { + return !(*this == other); + } }; template diff --git a/src/math/interval/mod_interval_def.h b/src/math/interval/mod_interval_def.h index 6182c1b93..7e15c9741 100644 --- a/src/math/interval/mod_interval_def.h +++ b/src/math/interval/mod_interval_def.h @@ -143,7 +143,7 @@ Numeral mod_interval::closest_value(Numeral const& n) const { return n; if (is_empty()) return n; - if (lo - n < n - hi) + if ((Numeral)(lo - n) < (Numeral)(n - hi)) return lo; return hi - 1; } @@ -159,11 +159,13 @@ mod_interval& mod_interval::intersect_uge(Numeral const& l) { else if (is_free()) lo = l, hi = 0; else if ((lo < hi || hi == 0) && lo < l) - lo = l; + lo = l; else if (hi < lo && hi <= l && l <= lo) hi = 0; else if (hi < lo && lo < l) hi = 0, lo = l; + else if (0 < l && l < hi && hi < lo) + lo = l, hi = 0; return *this; } @@ -175,14 +177,18 @@ mod_interval& mod_interval::intersect_ugt(Numeral const& l) { set_empty(); else if (is_free()) lo = l + 1, hi = 0; - else if (lo < hi && hi <= l) + else if (lo < hi && hi - 1 <= l) set_empty(); - else if (lo < hi) + else if (lo < hi && l < lo) + return *this; + else if (lo < hi) lo = l + 1; else if (hi < lo && hi <= l + 1 && l < lo) hi = 0; else if (hi < lo && lo <= l) hi = 0, lo = l + 1; + else if (l <= hi && hi < lo) + lo = l + 1, hi = 0; return *this; } @@ -194,16 +200,22 @@ mod_interval& mod_interval::intersect_ule(Numeral const& h) { return *this; else if (is_free()) lo = 0, hi = h + 1; - else if (hi > lo && lo > h) + else if (h < lo && (lo < hi || hi == 0)) set_empty(); else if (hi == 0 && h >= lo) hi = h + 1; else if (lo <= h && h + 1 < hi) hi = h + 1; - else if (hi < lo && h < hi) + else if (h < hi && hi < lo) hi = h + 1, lo = 0; else if (hi <= h && h < lo) lo = 0; + else if (hi == 0 && hi == h && hi < lo) + set_empty(); + else if (0 < hi && hi == h && hi < lo) + lo = 0; + else if (0 < hi && hi < h && hi < lo) + lo = 0, hi = h + 1; return *this; } @@ -215,14 +227,16 @@ mod_interval& mod_interval::intersect_ult(Numeral const& h) { set_empty(); else if (is_free()) lo = 0, hi = h; - else if (hi > lo && lo >= h) + else if (h <= lo && (lo < hi || hi == 0)) set_empty(); - else if (hi > lo && h < hi) + else if (h > lo && (h < hi || hi == 0)) hi = h; else if (hi < lo && h <= hi) hi = h, lo = 0; else if (hi < h && h <= lo) lo = 0; + else if (0 < hi && hi < lo && hi + 1 <= h) + lo = 0, hi = h; return *this; } diff --git a/src/test/mod_interval.cpp b/src/test/mod_interval.cpp index db560aee9..92aa03371 100644 --- a/src/test/mod_interval.cpp +++ b/src/test/mod_interval.cpp @@ -134,8 +134,75 @@ static void test_interval_intersect() { test_interval_intersect(bounds[i], bounds[j], bounds[k], bounds[l]); } +static void test_interval_intersect2(unsigned i, unsigned j, uint8_t k) { + if (i == j && i != 0) + return; + mod_interval x0(i, j); + + auto validate = [&](char const* t, mod_interval const& y, mod_interval const& z) { + if (y == z) + return; + std::cout << t << "(" << (unsigned)k << ") " << x0 << " -> " << y << " " << z << "\n"; + SASSERT(false); + }; + + { + mod_interval x = x0; + auto uge2 = x & mod_interval(k, 0); + auto uge1 = x.intersect_uge(k); + validate("uge", uge1, uge2); + } + + { + mod_interval x = x0; + auto ule1 = x.intersect_ule(k); + if ((uint8_t)(k + 1) != 0) { + auto ule2 = x0 & mod_interval(0, k + 1); + validate("ule", ule1, ule2); + } + else { + validate("ule", ule1, x0); + } + } + + { + mod_interval x = x0; + auto ult1 = x.intersect_ult(k); + if (k != 0) { + auto ult2 = x0 & mod_interval(0, k); + validate("ult", ult1, ult2); + } + else { + validate("ult", ult1, mod_interval::empty()); + } + } + { + mod_interval x = x0; + auto ugt1 = x.intersect_ugt(k); + + if ((uint8_t)(k + 1) != 0) { + auto ugt2 = x0 & mod_interval(k + 1, 0); + validate("ugt", ugt1, ugt2); + } + else { + validate("ugt", ugt1, mod_interval::empty()); + } + } +} + + +static void test_interval_intersect2() { + unsigned bounds[8] = { 0, 1, 2, 3, 252, 253, 254, 255 }; + for (unsigned i = 0; i < 8; ++i) + for (unsigned j = 0; j < 8; ++j) + for (unsigned k = 0; k < 8; ++k) + test_interval_intersect2(bounds[i], bounds[j], bounds[k]); +} + + void tst_mod_interval() { test_interval_intersect(); + test_interval_intersect2(); test_interval1(); test_interval2(); }