diff --git a/src/smt/network_flow.h b/src/smt/network_flow.h index bb9c0ed07..6b7081b89 100644 --- a/src/smt/network_flow.h +++ b/src/smt/network_flow.h @@ -112,6 +112,8 @@ namespace smt { std::string display_spanning_tree(); + bool check_well_formed(); + public: network_flow(graph & g, vector const & balances); diff --git a/src/smt/network_flow_def.h b/src/smt/network_flow_def.h index 134046316..6ca6e5d1d 100644 --- a/src/smt/network_flow_def.h +++ b/src/smt/network_flow_def.h @@ -133,6 +133,7 @@ namespace smt { tout << pp_vector("Potentials", m_potentials) << pp_vector("Flows", m_flows); }); TRACE("network_flow", tout << "Spanning tree:\n" << display_spanning_tree();); + SASSERT(check_well_formed()); } template @@ -395,6 +396,7 @@ namespace smt { for (unsigned i = 0; i < m_thread.size(); ++i) { m_rev_thread[m_thread[i]] = i; } + SASSERT(check_well_formed()); TRACE("network_flow", { tout << pp_vector("Predecessors", m_pred, true) << pp_vector("Threads", m_thread); @@ -439,6 +441,7 @@ namespace smt { bool network_flow::min_cost() { initialize(); while (choose_entering_edge()) { + SASSERT(check_well_formed()); bool bounded = choose_leaving_edge(); if (!bounded) return false; update_flows(); @@ -477,6 +480,70 @@ namespace smt { } return m_objective_value; } + + static unsigned find(svector& roots, unsigned x) { + unsigned old_x = x; + while (roots[x] >= 0) { + x = roots[x]; + } + roots[old_x] = x; + return x; + } + + static void merge(svector& roots, unsigned x, unsigned y) { + x = find(roots, x); + y = find(roots, y); + SASSERT(roots[x] < 0 && roots[y] < 0); + if (x == y) { + return; + } + if (roots[x] > roots[y]) { + std::swap(x, y); + } + SASSERT(roots[x] <= roots[y]); + roots[y] = x; + roots[x] += roots[y]; + } + + template + bool network_flow::check_well_formed() { + // m_thread is depth-first stack + // m_pred is predecessor link + // m_depth depth counting from a root note. + // m_graph + + node root = m_pred.size()-1; + for (unsigned i = 0; i < m_upwards.size(); ++i) { + if (m_upwards[i]) { + node p = m_pred[i]; + edge_id e = get_edge_id(i, p); + // we are either the root or the predecessor points up. + SASSERT(p == root || m_upwards[p]); + } + } + + // m_thread forms a spanning tree over [0..root] + // union-find structure: + svector roots(root+1, -1); + +#if 0 + for (unsigned i = 0; i < m_thread.size(); ++i) { + if (m_states[i] == BASIS) { + node x = m_thread[i]; + node y = i; + // we are now going to check the edge between x and y: + SASSERT(find(roots, x) != find(roots, y)); + merge(roots, x, y); + } + else { + // ? LOWER, UPPER + } + } +#endif + + return true; + } + } #endif