diff --git a/src/math/polysat/forbidden_intervals.cpp b/src/math/polysat/forbidden_intervals.cpp index 78414ae46..c1a246215 100644 --- a/src/math/polysat/forbidden_intervals.cpp +++ b/src/math/polysat/forbidden_intervals.cpp @@ -200,12 +200,12 @@ namespace polysat { if (a1 != a2 && !a1.is_zero() && !a2.is_zero()) return false; SASSERT(b1.is_val()); - SASSERT(b2.is_val()); - - LOG("values " << a1 << " " << a2); + SASSERT(b2.is_val()); _backtrack.released = true; + // LOG("add " << c << " " << a1 << " " << b1 << " " << a2 << " " << b2); + if (match_linear1(c, a1, b1, e1, a2, b2, e2, out_interval, out_side_cond)) return true; if (match_linear2(c, a1, b1, e1, a2, b2, e2, out_interval, out_side_cond)) diff --git a/src/math/polysat/solver.h b/src/math/polysat/solver.h index 74a39e5ed..784fc6601 100644 --- a/src/math/polysat/solver.h +++ b/src/math/polysat/solver.h @@ -66,6 +66,7 @@ namespace polysat { friend class ex_polynomial_superposition; friend class inf_saturate; friend class constraint_manager; + friend class scoped_solverv; reslimit& m_lim; params_ref m_params; diff --git a/src/math/polysat/viable2.cpp b/src/math/polysat/viable2.cpp index 011a9b068..67d7b853b 100644 --- a/src/math/polysat/viable2.cpp +++ b/src/math/polysat/viable2.cpp @@ -85,6 +85,8 @@ namespace polysat { e->remove_from(m_viable[v], e); }; + //LOG("intersect " << ne->interval); + if (!e) m_viable[v] = create_entry(); else { @@ -107,6 +109,10 @@ namespace polysat { } SASSERT(e->interval.lo_val() != ne->interval.lo_val()); if (e->interval.lo_val() > ne->interval.lo_val()) { + if (first->prev()->interval.contains(ne->interval)) { + m_alloc.push_back(ne); + return; + } e->insert_before(create_entry()); if (e == first) m_viable[v] = e->prev(); @@ -155,7 +161,7 @@ namespace polysat { for (; e != last; e = e->next()) { if (e->interval.currently_contains(val)) return false; - if (e->interval.lo_val() < val) + if (val < e->interval.lo_val()) return true; } return true; @@ -294,11 +300,13 @@ namespace polysat { while (true) { if (e->interval.is_full()) return e->next() == e; - if (e->interval.is_currently_empty()) + if (e->interval.is_currently_empty()) return false; + auto* n = e->next(); - if (n != e && e->interval.contains(n->interval)) + if (n != e && e->interval.contains(n->interval)) return false; + if (n == first) break; if (e->interval.lo_val() >= n->interval.lo_val()) diff --git a/src/test/viable.cpp b/src/test/viable.cpp index c3368b3dd..352f0c76f 100644 --- a/src/test/viable.cpp +++ b/src/test/viable.cpp @@ -8,9 +8,24 @@ namespace polysat { reslimit lim; }; - struct scoped_solverv : public solver_scopev, public solver { + class scoped_solverv : public solver_scopev, public solver { + public: viable2 v; scoped_solverv(): solver(lim), v(*this) {} + ~scoped_solverv() { + for (unsigned i = m_trail.size(); i-- > 0;) { + switch (m_trail[i]) { + case trail_instr_t::viable_add_i: + v.pop_viable(); + break; + case trail_instr_t::viable_rem_i: + v.push_viable(); + break; + default: + break; + } + } + } }; static void test1() { @@ -37,8 +52,66 @@ namespace polysat { std::cout << s.v.find_viable(xv, val) << " " << val << "\n"; } - static void test2() { + static void add_interval(scoped_solverv& s, pvar xv, pdd x, unsigned lo, unsigned len) { + if (lo == 0) + s.v.intersect(xv, s.ule(x, len - 1)); + else if (lo + len == 8) + s.v.intersect(xv, s.ule(lo, x)); + else + s.v.intersect(xv, ~s.ule(x - ((lo + len)%8), x - lo)); + } + static bool member(unsigned i, unsigned lo, unsigned len) { + return (lo <= i && i < lo + len) || + (lo + len >= 8 && i < ((lo + len) % 8)); + } + + static bool member(unsigned i, vector> const& intervals) { + for (auto [lo, len] : intervals) + if (!member(i, lo, len)) + return false; + return true; + } + + static void validate_intervals(scoped_solverv& s, pvar xv, vector> const& intervals) { + for (unsigned i = 0; i < 8; ++i) { + bool mem1 = member(i, intervals); + bool mem2 = s.v.is_viable(xv, rational(i)); + if (mem1 != mem2) + std::cout << "test violation: " << i << " member of all intervals " << mem1 << " viable: " << mem2 << "\n"; + SASSERT(mem1 == mem2); + } + } + + static void test_intervals(vector> const& intervals) { + scoped_solverv s; + auto xv = s.add_var(3); + auto x = s.var(xv); + s.v.push(3); + for (auto const& [lo, len] : intervals) + add_interval(s, xv, x, lo, len); + std::cout << intervals << "\n"; + //std::cout << s.v << "\n"; + validate_intervals(s, xv, intervals); + } + + static void test_intervals(unsigned count, vector>& intervals) { + if (count == 0) { + test_intervals(intervals); + return; + } + for (unsigned lo1 = 0; lo1 < 8; ++lo1) { + for (unsigned len1 = 1; len1 <= 8; ++len1) { + intervals.push_back(std::make_pair(lo1, len1)); + test_intervals(count - 1, intervals); + intervals.pop_back(); + } + } + } + + static void test2() { + vector> intervals; + test_intervals(3, intervals); } }