From 49aba844b8b5c0a9f177b956b4dc3f7aa52921a1 Mon Sep 17 00:00:00 2001 From: Anh-Dung Phan Date: Wed, 30 Oct 2013 10:04:56 -0700 Subject: [PATCH] Refactor network_flow Use a template method for pretty printing --- src/smt/network_flow.h | 9 ++---- src/smt/network_flow_def.h | 60 ++++++++++++-------------------------- 2 files changed, 22 insertions(+), 47 deletions(-) diff --git a/src/smt/network_flow.h b/src/smt/network_flow.h index 39574fd55..33f9bf7a6 100644 --- a/src/smt/network_flow.h +++ b/src/smt/network_flow.h @@ -35,11 +35,8 @@ Notes: namespace smt { - template - std::string pp_vector(std::string const & label, svector v, bool has_header = false); - - template - std::string pp_vector(std::string const & label, vector v, bool has_header = false); + template + std::string pp_vector(std::string const & label, TV v, bool has_header = false); // Solve minimum cost flow problem using Network Simplex algorithm template @@ -91,7 +88,7 @@ namespace smt { // Initialize the network with a feasible spanning tree void initialize(); - bool get_edge_id(dl_var source, dl_var target, edge_id & id); + edge_id get_edge_id(dl_var source, dl_var target); void update_potentials(); diff --git a/src/smt/network_flow_def.h b/src/smt/network_flow_def.h index b3be85a2f..f321f3385 100644 --- a/src/smt/network_flow_def.h +++ b/src/smt/network_flow_def.h @@ -24,8 +24,8 @@ Notes: namespace smt { - template - std::string pp_vector(std::string const & label, svector v, bool has_header) { + template + std::string pp_vector(std::string const & label, TV v, bool has_header) { std::ostringstream oss; if (has_header) { oss << "Index "; @@ -42,23 +42,6 @@ namespace smt { return oss.str(); } - template - std::string pp_vector(std::string const & label, vector v, bool has_header) { - std::ostringstream oss; - if (has_header) { - oss << "Index "; - for (unsigned i = 0; i < v.size(); ++i) { - oss << i << " "; - } - oss << std::endl; - } - oss << label << " "; - for (unsigned i = 0; i < v.size(); ++i) { - oss << v[i] << " "; - } - oss << std::endl; - return oss.str(); - } template network_flow::network_flow(graph & g, vector const & balances) : @@ -105,7 +88,7 @@ namespace smt { // Create artificial edges and initialize the spanning tree for (unsigned i = 0; i < num_nodes; ++i) { - m_upwards[i] = m_balances[i] >= fin_numeral::zero(); + m_upwards[i] = !m_balances[i].is_neg(); m_pred[i] = root; m_depth[i] = 1; m_thread[i] = i + 1; @@ -122,8 +105,7 @@ namespace smt { node u = m_thread[root]; while (u != root) { node v = m_pred[u]; - edge_id e_id; - get_edge_id(u, v, e_id); + edge_id e_id = get_edge_id(u, v); if (m_upwards[u]) { m_potentials[u] = m_potentials[v] - m_graph.get_weight(e_id); } @@ -142,16 +124,18 @@ namespace smt { } template - bool network_flow::get_edge_id(dl_var source, dl_var target, edge_id & id) { + edge_id network_flow::get_edge_id(dl_var source, dl_var target) { // m_upwards decides which node is the real source - return m_upwards[source] ? m_graph.get_edge_id(source, target, id) : m_graph.get_edge_id(target, source, id); + edge_id id; + VERIFY(m_upwards[source] ? m_graph.get_edge_id(source, target, id) : m_graph.get_edge_id(target, source, id)); + return id; } template void network_flow::update_potentials() { TRACE("network_flow", tout << "update_potentials...\n";); node src = m_graph.get_source(m_entering_edge); - node tgt = m_graph.get_source(m_entering_edge); + node tgt = m_graph.get_target(m_entering_edge); numeral cost = m_graph.get_weight(m_entering_edge); numeral change = m_upwards[src] ? (cost - m_potentials[src] + m_potentials[tgt]) : (-cost + m_potentials[src] - m_potentials[tgt]); @@ -169,14 +153,12 @@ namespace smt { m_flows[m_entering_edge] += val; node source = m_graph.get_source(m_entering_edge); for (unsigned u = source; u != m_join_node; u = m_pred[u]) { - edge_id e_id; - get_edge_id(u, m_pred[u], e_id); + edge_id e_id = get_edge_id(u, m_pred[u]); m_flows[e_id] += m_upwards[u] ? -val : val; } node target = m_graph.get_target(m_entering_edge); for (unsigned u = target; u != m_join_node; u = m_pred[u]) { - edge_id e_id; - get_edge_id(u, m_pred[u], e_id); + edge_id e_id = get_edge_id(u, m_pred[u]); m_flows[e_id] += m_upwards[u] ? val : -val; } TRACE("network_flow", tout << pp_vector("Flows", m_flows, true);); @@ -189,13 +171,13 @@ namespace smt { for (unsigned int i = 0; i < es.size(); ++i) { edge const & e = es[i]; edge_id e_id; - if (e.is_enabled() && m_graph.get_edge_id(e.get_source(), e.get_target(), e_id) && m_states[e_id] == NON_BASIS) { - node source = e.get_source(); - node target = e.get_target(); + node source = e.get_source(); + node target = e.get_target(); + if (e.is_enabled() && m_graph.get_edge_id(source, target, e_id) && m_states[e_id] == NON_BASIS) { numeral cost = e.get_weight() - m_potentials[source] + m_potentials[target]; // Choose the first negative-cost edge to be the violating edge // TODO: add multiple pivoting strategies - if (cost < numeral::zero()) { + if (cost.is_neg()) { m_entering_edge = e_id; TRACE("network_flow", { tout << "Found entering edge " << e_id << " between node "; @@ -234,8 +216,7 @@ namespace smt { node src, tgt; // Send flows along the path from source to the ancestor for (unsigned u = source; u != m_join_node; u = m_pred[u]) { - edge_id e_id; - get_edge_id(u, m_pred[u], e_id); + edge_id e_id = get_edge_id(u, m_pred[u]); numeral d = m_upwards[u] ? m_flows[e_id] : infty; if (d < m_delta) { m_delta = d; @@ -246,8 +227,7 @@ namespace smt { // Send flows along the path from target to the ancestor for (unsigned u = target; u != m_join_node; u = m_pred[u]) { - edge_id e_id; - get_edge_id(u, m_pred[u], e_id); + edge_id e_id = get_edge_id(u, m_pred[u]); numeral d = m_upwards[u] ? m_flows[e_id] : infty; if (d <= m_delta) { m_delta = d; @@ -257,7 +237,7 @@ namespace smt { } if (m_delta < infty) { - get_edge_id(src, tgt, m_leaving_edge); + m_leaving_edge = get_edge_id(src, tgt); TRACE("network_flow", { tout << "Found leaving edge " << m_leaving_edge; tout << " between node " << src << " and node " << tgt << "...\n"; @@ -391,12 +371,10 @@ namespace smt { typename network_flow::numeral network_flow::get_optimal_solution(vector & result, bool is_dual) { m_objective_value = numeral::zero(); vector const & es = m_graph.get_all_edges(); - fin_numeral cost; for (unsigned i = 0; i < es.size(); ++i) { edge const & e = es[i]; if (e.is_enabled() && m_states[i] == BASIS) { - cost = e.get_weight().get_rational(); - m_objective_value += cost * m_flows[i]; + m_objective_value += e.get_weight().get_rational() * m_flows[i]; } } result.reset();