From 18291543d6175ebcb55fbc805ae0f9e1e8b07808 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 1 Feb 2022 13:21:51 -0800 Subject: [PATCH] fixing corner cases for viable intervals --- src/math/polysat/viable.cpp | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/math/polysat/viable.cpp b/src/math/polysat/viable.cpp index d6a5ac5fe..7c24241ef 100644 --- a/src/math/polysat/viable.cpp +++ b/src/math/polysat/viable.cpp @@ -122,6 +122,7 @@ namespace polysat { m_alloc.push_back(ne); return false; } + auto create_entry = [&]() { m_trail.push_back({ v, entry_kind::unit_e, ne }); @@ -136,6 +137,13 @@ namespace polysat { e->remove_from(m_units[v], e); }; + if (ne->interval.is_full()) { + while (m_units[v]) + remove_entry(m_units[v]); + m_units[v] = create_entry(); + return true; + } + if (!e) m_units[v] = create_entry(); else { @@ -257,16 +265,21 @@ namespace polysat { lo = val - lambda_l; increase_hi(hi); } - LOG("forbidden interval " << e->coeff << " * " << e->interval << " [" << lo << ", " << hi << "["); + LOG("forbidden interval v" << v << " " << val << " " << e->coeff << " * " << e->interval << " [" << lo << ", " << hi << "["); SASSERT(hi <= mod_value); - if (hi == mod_value) hi = 0; + bool full = (lo == 0 && hi == mod_value); + if (hi == mod_value) + hi = 0; pdd lop = s.var2pdd(v).mk_val(lo); pdd hip = s.var2pdd(v).mk_val(hi); entry* ne = alloc_entry(); ne->src = e->src; ne->side_cond = e->side_cond; ne->coeff = 1; - ne->interval = eval_interval::proper(lop, lo, hip, hi); + if (full) + ne->interval = eval_interval::full(); + else + ne->interval = eval_interval::proper(lop, lo, hip, hi); intersect(v, ne); return false; } @@ -384,6 +397,8 @@ namespace polysat { entry* first = e; entry* last = e->prev(); + if (e->interval.is_full()) + return false; // quick check: last interval doesn't wrap around, so hi_val // has not been covered if (last->interval.lo_val() < last->interval.hi_val()) @@ -616,7 +631,7 @@ namespace polysat { std::ostream& viable::display(std::ostream& out) const { for (pvar v = 0; v < m_units.size(); ++v) - display(out << "v" << v << ": ", v); + display(out << "v" << v << ": ", v) << "\n"; return out; }