From 5166d9111bfa354df01ac9a8b09847112340b4fe Mon Sep 17 00:00:00 2001 From: Lev Nachmanson Date: Fri, 19 Dec 2025 16:09:02 -1000 Subject: [PATCH] better sort of root functions Signed-off-by: Lev Nachmanson --- src/nlsat/levelwise.cpp | 140 +++++++++++++++++----------------------- 1 file changed, 58 insertions(+), 82 deletions(-) diff --git a/src/nlsat/levelwise.cpp b/src/nlsat/levelwise.cpp index 1d58acdfc..160845035 100644 --- a/src/nlsat/levelwise.cpp +++ b/src/nlsat/levelwise.cpp @@ -154,17 +154,12 @@ namespace nlsat { void clear() { m_pairs.clear(); m_rfunc.clear(); - m_l_start = m_l_end = m_u_start = m_u_end = -1; } // the indices point te the m_rfunc vector - unsigned m_l_start = -1; - unsigned m_l_end = -1; - unsigned m_u_start = -1; - unsigned m_u_end = -1; void add_pair(unsigned j, unsigned k) { m_pairs.emplace_back(j, k);} }; relation_E m_rel; - relation_mode m_relation_mode = biggest_cell; // there are other choices as well + relation_mode m_relation_mode = chain; // there are other choices as well assignment const & sample() const { return m_solver.sample();} assignment & sample() { return m_solver.sample(); } polynomial::cache & m_cache; @@ -345,67 +340,7 @@ namespace nlsat { } - // Compute root function interval from sorted roots. - void compute_interval_from_sorted_roots() { - root_function_interval & I = m_I[m_level]; - // default: whole line sector (-inf, +inf) - I.section = false; - I.l = nullptr; I.u = nullptr; - if (m_rel.empty()) return; - if (!sample().is_assigned(m_level)) return; - anum const& y_val = sample().value(m_level); - TRACE(lws, tout << "sample val:"; m_am.display_decimal(tout, y_val); tout << "\n";); - - // find first index where roots[idx].val >= y_val - const auto & rfs = m_rel.m_rfunc; - unsigned idx = 0; - while (idx < rfs.size() && m_am.compare(rfs[idx].val, y_val) < 0) { - TRACE(lws, tout << "idx:" << idx << ", val:"; m_am.display_decimal(tout, rfs[idx].val); tout << "\n";); - ++idx; - } - if (idx < rfs.size() && m_am.compare(rfs[idx].val, y_val) == 0) { - TRACE(lws, tout << "exact match at idx:" << idx << ", it's a section\n";); - auto const& ire = rfs[idx].ire; - I.section = true; - I.l = ire.p; I.l_index = ire.i; - I.u = nullptr; I.u_index = -1; // the section is defined by the I.l - TRACE(lws, tout << "section bound -> p:"; if (I.l) m_pm.display(tout, I.l); tout << ", index:" << I.l_index << "\n";); - m_rel.m_l_start = m_rel.m_l_end = idx; - while (++idx < rfs.size() && m_am.compare(rfs[idx].val, y_val) == 0) { - m_rel.m_l_end = idx; - TRACE(lws, tout << "idx:" << idx << ", val:"; m_am.display_decimal(tout, rfs[idx].val); tout << "\n";); - } - TRACE(lws, display_relation(tout);); - return; - } - // sector: lower bound is last root with val < y, upper bound is first root with val > y - if (idx > 0) { - // find start,end of equal-valued group for lower bound - unsigned start = idx - 1; - m_rel.m_l_end = start; - while (start > 0 && m_am.compare(rfs[start-1].val, rfs[start].val) == 0) { - --start; - TRACE(lws, tout << "start:" << start << ", val:"; m_am.display_decimal(tout, rfs[start].val); tout << "\n";); - } - m_rel.m_l_start = start; - auto const& ire = rfs[start].ire; - I.l = ire.p; I.l_index = ire.i; - } - if (idx < rfs.size()) { - // find start, end of equal-valued group for upper bound - unsigned start = idx; - m_rel.m_u_start = idx; - while (start + 1 < rfs.size() && m_am.compare(rfs[start].val, rfs[start + 1].val) == 0) { - ++start; - TRACE(lws, tout << "start:" << start << ", val:"; m_am.display_decimal(tout, rfs[start].val); tout << "\n";); - } - auto const& ire = rfs[start].ire; - m_rel.m_u_end = start; - I.u = ire.p; I.u_index = ire.i; - } - TRACE(lws, display_relation(tout) << std::endl;); - } property pop(p_q_plus& q) { return q.pop(); @@ -427,11 +362,42 @@ namespace nlsat { // Part B of construct_interval: build (I, E, ≼) representation for level i void build_representation() { collect_E(); - // todo: this order needs to be abstracted: it does not have to be linear. - // We need a boolean function E_rel(a, b) - std::sort(m_rel.m_rfunc.begin(), m_rel.m_rfunc.end(), [&](root_function const& a, root_function const& b){ + if (m_rel.m_rfunc.size() == 0) + return; + anum const& v = sample().value(m_level); + auto cmp = [&](root_function const& a, root_function const& b) { + if (a.ire.p == b.ire.p) + return a.ire.i < b.ire.i; return m_am.lt(a.val, b.val); - }); + }; + auto &rfs = m_rel.m_rfunc; + auto mid = std::partition(rfs.begin(), rfs.end(), [&](root_function const& f) { return m_am.compare(f.val, v) <= 0; }); + std::sort(rfs.begin(), mid, cmp); + std::sort(mid, rfs.end(), cmp); + auto & I = m_I[m_level]; + unsigned l_index = -1, u_index = -1; + SASSERT(mid == rfs.end() || m_am.lt(v, mid->val)); + if (mid != rfs.begin()) { + auto& r = *(mid - 1); + if (m_am.eq(r.val, v)) { + l_index = mid - rfs.begin() - 1; + I.section = true; + I.l = r.ire.p; I.l_index = r.ire.i; + } else { + SASSERT( m_am.lt(r.val, v)); + l_index = mid - rfs.begin() - 1; + I.l = r.ire.p; I.l_index = r.ire.i; + if (mid != rfs.end()) { + u_index = l_index + 1; + I.u = mid->ire.p; I.u_index = mid->ire.i; + } + } + } else { // mid == rfs.begin() + auto & r = *mid; + I.u = r.ire.p; I.u_index = r.ire.i; + } + + fill_relation_pairs(l_index, u_index); TRACE(lws, if (m_rel.empty()) tout << "E is empty\n"; else { tout << "E:\n"; @@ -445,18 +411,14 @@ namespace nlsat { display(tout, m_rel.m_rfunc[pair.first]) << "<<<" ; display(tout, m_rel.m_rfunc[pair.second])<< "\n"; } }); - compute_interval_from_sorted_roots(); - fill_relation_pairs(); - TRACE(lws, display(tout << "m_I[" << m_level << "]:", m_I[m_level]) << std::endl;); +TRACE(lws, display(tout << "m_I[" << m_level << "]:", m_I[m_level]) << std::endl;); } - void fill_relation_with_biggest_cell_heuristic() { - unsigned l = m_rel.m_l_end; + void fill_relation_with_biggest_cell_heuristic(unsigned l, unsigned u) { if (is_set(l)) for (unsigned j = 0; j < l; j++) m_rel.add_pair(j, l); - unsigned u = m_rel.m_u_start; if (is_set(u)) for (unsigned j = u + 1; j < m_rel.m_rfunc.size(); j++) m_rel.add_pair(u, j); @@ -467,10 +429,27 @@ namespace nlsat { } } - void fill_relation_pairs() { + void fill_relation_with_chain_heuristic(unsigned l, unsigned u) { + if (is_set(l)) + for (unsigned j = 0; j < l; j++) + m_rel.add_pair(j, j+1); + + if (is_set(u)) + for (unsigned j = u + 1; j < m_rel.m_rfunc.size(); j++) + m_rel.add_pair(j - 1, j); + + if (is_set(l) && is_set(u)) { + SASSERT(l + 1 == u); + m_rel.add_pair(l, u); + } + } + + void fill_relation_pairs(unsigned l, unsigned u) { if (m_relation_mode == biggest_cell) - fill_relation_with_biggest_cell_heuristic(); - else + fill_relation_with_biggest_cell_heuristic(l, u); + else if (m_relation_mode == chain) + fill_relation_with_chain_heuristic(l, u); + else NOT_IMPLEMENTED_YET(); } @@ -1044,9 +1023,6 @@ namespace nlsat { for (const auto& pair : m_rel.m_pairs) { out << " (" << pair.first << ", " << pair.second << ")\n"; } - out << " Indices:\n"; - out << " m_l_start:" << (int)m_rel.m_l_start << ", m_l_end:" << (int)m_rel.m_l_end << "\n"; - out << " m_u_start:" << (int)m_rel.m_u_start << ", m_u_end:" << (int)m_rel.m_u_end << "\n"; return out; }