diff --git a/src/math/polysat/viable.cpp b/src/math/polysat/viable.cpp index 987ced644..fa10299af 100644 --- a/src/math/polysat/viable.cpp +++ b/src/math/polysat/viable.cpp @@ -75,18 +75,18 @@ namespace polysat { } void viable::pop_viable() { - auto& [v, k, e] = m_trail.back(); + auto const& [v, k, e] = m_trail.back(); SASSERT(well_formed(m_units[v])); switch (k) { case entry_kind::unit_e: - e->remove_from(m_units[v], e); + entry::remove_from(m_units[v], e); SASSERT(well_formed(m_units[v])); break; case entry_kind::equal_e: - e->remove_from(m_equal_lin[v], e); + entry::remove_from(m_equal_lin[v], e); break; case entry_kind::diseq_e: - e->remove_from(m_diseq_lin[v], e); + entry::remove_from(m_diseq_lin[v], e); break; default: UNREACHABLE(); @@ -104,7 +104,9 @@ namespace polysat { (void)k; SASSERT(well_formed(m_units[v])); if (e->prev() != e) { - e->prev()->insert_after(e); + entry* pos = e->prev(); + e->init(e); + pos->insert_after(e); if (e->interval.lo_val() < m_units[v]->interval.lo_val()) m_units[v] = e; } @@ -180,6 +182,7 @@ namespace polysat { entries[v] = e; else e->insert_after(entries[v]); + SASSERT(entries[v]->invariant()); SASSERT(well_formed(m_units[v])); } @@ -272,10 +275,10 @@ namespace polysat { */ bool viable::refine_equal_lin(pvar v, rational const& val) { // LOG_H2("refine-equal-lin with v" << v << ", val = " << val); - auto* e = m_equal_lin[v]; + entry const* e = m_equal_lin[v]; if (!e) return true; - entry* first = e; + entry const* first = e; rational const& max_value = s.var2pdd(v).max_value(); rational mod_value = max_value + 1; @@ -380,10 +383,10 @@ namespace polysat { bool viable::refine_disequal_lin(pvar v, rational const& val) { // LOG_H2("refine-disequal-lin with v" << v << ", val = " << val); - auto* e = m_diseq_lin[v]; + entry const* e = m_diseq_lin[v]; if (!e) return true; - entry* first = e; + entry const* first = e; rational const& max_value = s.var2pdd(v).max_value(); rational const mod_value = max_value + 1; @@ -632,9 +635,9 @@ namespace polysat { bool viable::resolve(pvar v, conflict& core) { if (has_viable(v)) return false; - auto* e = m_units[v]; + entry const* e = m_units[v]; // TODO: in the forbidden interval paper, they start with the longest interval. We should also try that at some point. - entry* first = e; + entry const* first = e; SASSERT(e); // If there is a full interval, all others would have been removed SASSERT(!e->interval.is_full() || e->next() == e); @@ -642,7 +645,7 @@ namespace polysat { do { // Build constraint: upper bound of each interval is not contained in the next interval, // using the equivalence: t \in [l;h[ <=> t-l < h-l - entry* n = e->next(); + entry const* n = e->next(); // Choose the next interval which furthest extends the covered region. // Example: @@ -666,7 +669,7 @@ namespace polysat { // // The interval 'first' is always part of the lemma. If we reach first again here, we have covered the complete domain. while (n != first) { - entry* n1 = n->next(); + entry const* n1 = n->next(); // Check if n1 is eligible; if yes, then n1 is better than n. // // Case 1, n1 overlaps e (unless n1 == e): diff --git a/src/math/polysat/viable.h b/src/math/polysat/viable.h index 4507323a8..036d80491 100644 --- a/src/math/polysat/viable.h +++ b/src/math/polysat/viable.h @@ -38,7 +38,7 @@ namespace polysat { solver& s; forbidden_intervals m_forbidden_intervals; - struct entry : public dll_base, public fi_record {}; + struct entry final : public dll_base, public fi_record {}; enum class entry_kind { unit_e, equal_e, diseq_e }; ptr_vector m_alloc; diff --git a/src/util/debug.h b/src/util/debug.h index 795976eac..cd4634ae2 100644 --- a/src/util/debug.h +++ b/src/util/debug.h @@ -19,6 +19,7 @@ Revision History: #pragma once #include +#include void enable_assertions(bool f); bool assertions_enabled(); diff --git a/src/util/dlist.h b/src/util/dlist.h index e8be4cda5..50ba9cec3 100644 --- a/src/util/dlist.h +++ b/src/util/dlist.h @@ -18,14 +18,25 @@ Revision History: --*/ #pragma once #include +#include "util/debug.h" +#include "util/util.h" template class dll_iterator; template class dll_base { - T* m_next { nullptr }; - T* m_prev { nullptr }; + T* m_next = nullptr; + T* m_prev = nullptr; + +protected: + dll_base() = default; + ~dll_base() = default; + public: + dll_base(dll_base const&) = delete; + dll_base(dll_base&&) = delete; + dll_base& operator=(dll_base const&) = delete; + dll_base& operator=(dll_base&&) = delete; T* prev() { return m_prev; } T* next() { return m_next; } @@ -35,6 +46,7 @@ public: void init(T* t) { m_next = t; m_prev = t; + SASSERT(invariant()); } static T* pop(T*& list) { @@ -45,23 +57,61 @@ public: return head; } - void insert_after(T* elem) { + void insert_after(T* other) { +#ifndef NDEBUG + SASSERT(other); + SASSERT(invariant()); + SASSERT(other->invariant()); + unsigned const old_sz1 = count_if(*static_cast(this), [](T const&) { return true; }); + unsigned const old_sz2 = count_if(*other, [](T const&) { return true; }); +#endif + // have: this -> next -> ... + // insert: other -> ... -> other_end + // result: this -> other -> ... -> other_end -> next -> ... T* next = this->m_next; - elem->m_prev = next->m_prev; - elem->m_next = next; - this->m_next = elem; - next->m_prev = elem; + T* other_end = other->m_prev; + this->m_next = other; + other->m_prev = static_cast(this); + other_end->m_next = next; + next->m_prev = other_end; +#ifndef NDEBUG + SASSERT(invariant()); + SASSERT(other->invariant()); + unsigned const new_sz = count_if(*static_cast(this), [](T const&) { return true; }); + SASSERT_EQ(new_sz, old_sz1 + old_sz2); +#endif } - void insert_before(T* elem) { + void insert_before(T* other) { +#ifndef NDEBUG + SASSERT(other); + SASSERT(invariant()); + SASSERT(other->invariant()); + unsigned const old_sz1 = count_if(*static_cast(this), [](T const&) { return true; }); + unsigned const old_sz2 = count_if(*other, [](T const&) { return true; }); +#endif + // have: prev -> this -> ... + // insert: other -> ... -> other_end + // result: prev -> other -> ... -> other_end -> this -> ... T* prev = this->m_prev; - elem->m_next = prev->m_next; - elem->m_prev = prev; - prev->m_next = elem; - this->m_prev = elem; + T* other_end = other->m_prev; + prev->m_next = other; + other->m_prev = prev; + other_end->m_next = static_cast(this); + this->m_prev = other_end; +#ifndef NDEBUG + SASSERT(invariant()); + SASSERT(other->invariant()); + unsigned const new_sz = count_if(*static_cast(this), [](T const&) { return true; }); + SASSERT_EQ(new_sz, old_sz1 + old_sz2); +#endif } static void remove_from(T*& list, T* elem) { + SASSERT(list); + SASSERT(elem); + SASSERT(list->invariant()); + SASSERT(elem->invariant()); if (list->m_next == list) { SASSERT(elem == list); list = nullptr; @@ -73,6 +123,7 @@ public: auto* prev = elem->m_prev; prev->m_next = next; next->m_prev = prev; + SASSERT(list->invariant()); } static void push_to_front(T*& list, T* elem) { @@ -141,10 +192,11 @@ public: return {elem, false}; } - // using value_type = T; - // using pointer = T const*; - // using reference = T const&; - // using iterator_category = std::input_iterator_tag; + using value_type = T; + using pointer = T const*; + using reference = T const&; + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; dll_iterator& operator++() { m_elem = m_elem->next(); diff --git a/src/util/util.h b/src/util/util.h index d2a4771e9..21f879517 100644 --- a/src/util/util.h +++ b/src/util/util.h @@ -379,10 +379,10 @@ inline size_t megabytes_to_bytes(unsigned mb) { /** Compact version of std::all_of */ template -bool all_of(Container const& c, Predicate f) +bool all_of(Container const& c, Predicate p) { using std::begin, std::end; // allows begin(c) to also find c.begin() - return std::all_of(begin(c), end(c), std::forward(f)); + return std::all_of(begin(c), end(c), std::forward(p)); } /** Compact version of std::count */ @@ -392,3 +392,11 @@ std::size_t count(Container const& c, Item x) using std::begin, std::end; // allows begin(c) to also find c.begin() return std::count(begin(c), end(c), std::forward(x)); } + +/** Compact version of std::count_if */ +template +std::size_t count_if(Container const& c, Predicate p) +{ + using std::begin, std::end; // allows begin(c) to also find c.begin() + return std::count_if(begin(c), end(c), std::forward(p)); +}