diff --git a/src/smt/network_flow_def.h b/src/smt/network_flow_def.h index e4fb13f37..fe7d323f6 100644 --- a/src/smt/network_flow_def.h +++ b/src/smt/network_flow_def.h @@ -42,7 +42,6 @@ namespace smt { return oss.str(); } - template network_flow::network_flow(graph & g, vector const & balances) : m_balances(balances) { @@ -98,6 +97,7 @@ namespace smt { m_balances[root] = -sum_supply; m_flows.resize(num_nodes + num_edges); + m_flows.fill(numeral::zero()); m_states.resize(num_nodes + num_edges); m_states.fill(LOWER); @@ -269,15 +269,11 @@ namespace smt { node v = m_graph.get_target(m_leaving_edge); // v is parent of u so T_u does not contain root node if (m_pred[u] == v) { - node temp = u; - u = v; - v = temp; + std::swap(u, v); } if ((m_states[m_entering_edge] == UPPER) == m_in_edge_dir) { // q should be in T_v so swap p and q - node temp = p; - p = q; - q = temp; + std::swap(p, q); } TRACE("network_flow", { @@ -329,7 +325,14 @@ namespace smt { gamma = m_thread[m_final[v]]; // Check that f(u) is not in T_v - node delta = m_final[u] != m_final[v] ? m_final[u] : phi; + bool found_final_u = false; + for (node n = v; n == m_final[v]; n = m_thread[n]) { + if (n == m_final[u]) { + found_final_u = true; + break; + } + } + node delta = found_final_u ? phi : m_final[u]; n = u; last = m_pred[gamma]; while (n != last && n != -1) { @@ -379,14 +382,14 @@ namespace smt { for (unsigned i = 0; i < m_thread.size(); ++i) { m_rev_thread[m_thread[i]] = i; - } - SASSERT(check_well_formed()); + } TRACE("network_flow", { tout << pp_vector("Predecessors", m_pred, true) << pp_vector("Threads", m_thread); tout << pp_vector("Reverse Threads", m_rev_thread) << pp_vector("Last Successors", m_final); tout << pp_vector("Depths", m_depth) << pp_vector("Upwards", m_upwards); }); + SASSERT(check_well_formed()); } template @@ -424,7 +427,7 @@ namespace smt { template bool network_flow::min_cost() { initialize(); - while (choose_entering_edge()) { + while (choose_entering_edge()) { SASSERT(check_well_formed()); bool bounded = choose_leaving_edge(); if (!bounded) return false; @@ -438,7 +441,7 @@ namespace smt { } else { m_states[m_leaving_edge] = m_states[m_leaving_edge] == LOWER ? UPPER : LOWER; - } + } } TRACE("network_flow", tout << "Found optimal solution.\n";); return true; @@ -470,7 +473,10 @@ namespace smt { while (roots[x] >= 0) { x = roots[x]; } - roots[old_x] = x; + SASSERT(roots[x] < 0); + if (old_x != x) { + roots[old_x] = x; + } return x; } @@ -485,45 +491,65 @@ namespace smt { std::swap(x, y); } SASSERT(roots[x] <= roots[y]); - roots[y] = x; roots[x] += roots[y]; + roots[y] = x; + } + + static int get_final(int root, svector const & thread, svector const & depth) { + int n = root; + while (depth[thread[n]] > depth[root]) { + n = thread[n]; + } + return n; } template bool network_flow::check_well_formed() { - // m_thread is depth-first stack - // m_pred is predecessor link - // m_depth depth counting from a root note. - // m_graph -#if 0 node root = m_pred.size()-1; + + // m_upwards show correct direction for (unsigned i = 0; i < m_upwards.size(); ++i) { - if (m_upwards[i]) { - node p = m_pred[i]; - edge_id e = get_edge_id(i, p); - // we are either the root or the predecessor points up. - SASSERT(p == root || m_upwards[p]); - } + node p = m_pred[i]; + edge_id id; + SASSERT(m_upwards[i] == m_graph.get_edge_id(i, p, id)); + } + + // m_depth[x] denotes distance from x to the root node + for (node x = m_thread[root]; x != root; x = m_thread[x]) { + SASSERT(m_depth[x] == m_depth[m_pred[x]] + 1); + } + + // m_final of a node denotes the last node with a bigger depth + for (unsigned i = 0; i < m_final.size(); ++i) { + SASSERT(m_final[i] == get_final(i, m_thread, m_depth)); } // m_thread forms a spanning tree over [0..root] - // union-find structure: - svector roots(root+1, -1); - - for (unsigned i = 0; i < m_thread.size(); ++i) { - if (m_states[i] == BASIS) { - node x = m_thread[i]; - node y = i; - // we are now going to check the edge between x and y: - SASSERT(find(roots, x) != find(roots, y)); - merge(roots, x, y); - } - else { - // ? LOWER, UPPER - } + // Union-find structure + svector roots(m_pred.size(), -1); + + for (node x = m_thread[root]; x != root; x = m_thread[x]) { + node y = m_pred[x]; + // We are now going to check the edge between x and y + SASSERT(find(roots, x) != find(roots, y)); + merge(roots, x, y); + } + + std::cout << "roots" << std::endl; + for (unsigned i = 0; i < roots.size(); ++i) { + std::cout << i << " |-> " << roots[i] << std::endl; } -#endif + // All nodes belong to the same spanning tree + for (unsigned i = 0; i < roots.size(); ++i) { + SASSERT(i == 0 ? roots[i] + roots.size() == 0 : roots[i] == 0); + } + + // m_flows are zero on non-basic edges + for (unsigned i = 0; i < m_flows.size(); ++i) { + SASSERT(m_states[i] == BASIS || m_flows[i].is_zero()); + } + return true; }