diff --git a/src/smt/network_flow.h b/src/smt/network_flow.h index 489cf351e..3303e2618 100644 --- a/src/smt/network_flow.h +++ b/src/smt/network_flow.h @@ -59,9 +59,6 @@ namespace smt { // Duals of flows which are convenient to compute dual solutions vector m_potentials; - // Keep optimal solution of the min cost flow problem - numeral m_objective_value; - // Basic feasible flows vector m_flows; @@ -100,6 +97,7 @@ namespace smt { bool edge_in_tree(node src, node dst) const; bool check_well_formed(); + bool check_optimal(); public: diff --git a/src/smt/network_flow_def.h b/src/smt/network_flow_def.h index 834cfc075..d576e3704 100644 --- a/src/smt/network_flow_def.h +++ b/src/smt/network_flow_def.h @@ -48,6 +48,7 @@ namespace smt { m_potentials.resize(num_nodes); tree = thread_spanning_tree(); + m_step = 0; } template @@ -79,7 +80,7 @@ namespace smt { m_states[num_edges + i] = BASIS; node src = upwards[i] ? i : root; node tgt = upwards[i] ? root : i; - m_flows[num_edges + i] = upwards[i] ? m_balances[i] : -m_balances[i]; + m_flows[num_edges + i] = upwards[i] ? m_balances[i] : -m_balances[i]; m_graph.add_edge(src, tgt, numeral::one(), explanation()); } @@ -117,7 +118,7 @@ namespace smt { node src = m_graph.get_source(m_entering_edge); node tgt = m_graph.get_target(m_entering_edge); numeral cost = m_graph.get_weight(m_entering_edge); - numeral change = tree.get_arc_direction(src) ? (-cost - m_potentials[src] + m_potentials[tgt]) : (cost + m_potentials[src] - m_potentials[tgt]); + numeral change = m_potentials[tgt] - m_potentials[src] + (tree.get_arc_direction(src) ? -cost : cost); svector descendants; tree.get_descendants(src, descendants); for (unsigned i = 0; i < descendants.size(); ++i) { @@ -142,7 +143,7 @@ namespace smt { } node target = m_graph.get_target(m_entering_edge); - tree.get_ancestors(target,ancestors); + tree.get_ancestors(target, ancestors); for (unsigned i = 0; i < ancestors.size() && ancestors[i] != m_join_node; ++i) { node u = ancestors[i]; edge_id e_id = get_edge_id(u, tree.get_parent(u)); @@ -266,18 +267,20 @@ namespace smt { } } TRACE("network_flow", tout << "Found optimal solution.\n";); + SASSERT(check_optimal()); return true; } // Get the optimal solution template typename network_flow::numeral network_flow::get_optimal_solution(vector & result, bool is_dual) { - m_objective_value = numeral::zero(); + numeral objective_value = numeral::zero(); vector const & es = m_graph.get_all_edges(); for (unsigned i = 0; i < es.size(); ++i) { edge const & e = es[i]; - if (m_states[i] == BASIS) { - m_objective_value += e.get_weight().get_rational() * m_flows[i]; + if (m_states[i] == BASIS) + { + objective_value += e.get_weight().get_rational() * m_flows[i]; } } result.reset(); @@ -287,7 +290,7 @@ namespace smt { else { result.append(m_flows); } - return m_objective_value; + return objective_value; } template @@ -297,7 +300,7 @@ namespace smt { template bool network_flow::edge_in_tree(node src, node dst) const { - return edge_in_tree(get_edge_id(src,dst)); + return edge_in_tree(get_edge_id(src, dst)); } @@ -310,7 +313,6 @@ namespace smt { SASSERT(m_states[i] == BASIS || m_flows[i].is_zero()); } - // m_upwards show correct direction for (unsigned i = 0; i < m_potentials.size(); ++i) { node p = tree.get_parent(i); @@ -322,7 +324,31 @@ namespace smt { } template - std::string network_flow:: display_spanning_tree() { + bool network_flow::check_optimal() { + numeral total_cost = numeral::zero(); + vector const & es = m_graph.get_all_edges(); + for (unsigned i = 0; i < es.size(); ++i) { + edge const & e = es[i]; + if (m_states[i] == BASIS) + { + total_cost += e.get_weight().get_rational() * m_flows[i]; + } + } + + // m_flows are zero on non-basic edges + for (unsigned i = 0; i < m_flows.size(); ++i) { + SASSERT(m_states[i] == BASIS || m_flows[i].is_zero()); + } + numeral total_balance = numeral::zero(); + for (unsigned i = 0; i < m_potentials.size(); ++i) { + total_balance += m_balances[i] * m_potentials[i]; + } + std::cout << "Total balance: " << total_balance << ", total cost: " << total_cost << std::endl; + return total_cost == total_balance; + } + + template + std::string network_flow::display_spanning_tree() { ++m_step;; std::ostringstream oss; std::string prefix = "T";