3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-08-06 03:10:25 +00:00

Refactor network_flow

Use a template method for pretty printing
This commit is contained in:
Anh-Dung Phan 2013-10-30 10:04:56 -07:00
parent 42cbbe830e
commit 49aba844b8
2 changed files with 22 additions and 47 deletions

View file

@ -35,11 +35,8 @@ Notes:
namespace smt { namespace smt {
template<typename T> template<typename TV>
std::string pp_vector(std::string const & label, svector<T> v, bool has_header = false); std::string pp_vector(std::string const & label, TV v, bool has_header = false);
template<typename T>
std::string pp_vector(std::string const & label, vector<T> v, bool has_header = false);
// Solve minimum cost flow problem using Network Simplex algorithm // Solve minimum cost flow problem using Network Simplex algorithm
template<typename Ext> template<typename Ext>
@ -91,7 +88,7 @@ namespace smt {
// Initialize the network with a feasible spanning tree // Initialize the network with a feasible spanning tree
void initialize(); void initialize();
bool get_edge_id(dl_var source, dl_var target, edge_id & id); edge_id get_edge_id(dl_var source, dl_var target);
void update_potentials(); void update_potentials();

View file

@ -24,8 +24,8 @@ Notes:
namespace smt { namespace smt {
template<typename T> template<typename TV>
std::string pp_vector(std::string const & label, svector<T> v, bool has_header) { std::string pp_vector(std::string const & label, TV v, bool has_header) {
std::ostringstream oss; std::ostringstream oss;
if (has_header) { if (has_header) {
oss << "Index "; oss << "Index ";
@ -42,23 +42,6 @@ namespace smt {
return oss.str(); return oss.str();
} }
template<typename T>
std::string pp_vector(std::string const & label, vector<T> v, bool has_header) {
std::ostringstream oss;
if (has_header) {
oss << "Index ";
for (unsigned i = 0; i < v.size(); ++i) {
oss << i << " ";
}
oss << std::endl;
}
oss << label << " ";
for (unsigned i = 0; i < v.size(); ++i) {
oss << v[i] << " ";
}
oss << std::endl;
return oss.str();
}
template<typename Ext> template<typename Ext>
network_flow<Ext>::network_flow(graph & g, vector<fin_numeral> const & balances) : network_flow<Ext>::network_flow(graph & g, vector<fin_numeral> const & balances) :
@ -105,7 +88,7 @@ namespace smt {
// Create artificial edges and initialize the spanning tree // Create artificial edges and initialize the spanning tree
for (unsigned i = 0; i < num_nodes; ++i) { for (unsigned i = 0; i < num_nodes; ++i) {
m_upwards[i] = m_balances[i] >= fin_numeral::zero(); m_upwards[i] = !m_balances[i].is_neg();
m_pred[i] = root; m_pred[i] = root;
m_depth[i] = 1; m_depth[i] = 1;
m_thread[i] = i + 1; m_thread[i] = i + 1;
@ -122,8 +105,7 @@ namespace smt {
node u = m_thread[root]; node u = m_thread[root];
while (u != root) { while (u != root) {
node v = m_pred[u]; node v = m_pred[u];
edge_id e_id; edge_id e_id = get_edge_id(u, v);
get_edge_id(u, v, e_id);
if (m_upwards[u]) { if (m_upwards[u]) {
m_potentials[u] = m_potentials[v] - m_graph.get_weight(e_id); m_potentials[u] = m_potentials[v] - m_graph.get_weight(e_id);
} }
@ -142,16 +124,18 @@ namespace smt {
} }
template<typename Ext> template<typename Ext>
bool network_flow<Ext>::get_edge_id(dl_var source, dl_var target, edge_id & id) { edge_id network_flow<Ext>::get_edge_id(dl_var source, dl_var target) {
// m_upwards decides which node is the real source // m_upwards decides which node is the real source
return m_upwards[source] ? m_graph.get_edge_id(source, target, id) : m_graph.get_edge_id(target, source, id); edge_id id;
VERIFY(m_upwards[source] ? m_graph.get_edge_id(source, target, id) : m_graph.get_edge_id(target, source, id));
return id;
} }
template<typename Ext> template<typename Ext>
void network_flow<Ext>::update_potentials() { void network_flow<Ext>::update_potentials() {
TRACE("network_flow", tout << "update_potentials...\n";); TRACE("network_flow", tout << "update_potentials...\n";);
node src = m_graph.get_source(m_entering_edge); node src = m_graph.get_source(m_entering_edge);
node tgt = m_graph.get_source(m_entering_edge); node tgt = m_graph.get_target(m_entering_edge);
numeral cost = m_graph.get_weight(m_entering_edge); numeral cost = m_graph.get_weight(m_entering_edge);
numeral change = m_upwards[src] ? (cost - m_potentials[src] + m_potentials[tgt]) : numeral change = m_upwards[src] ? (cost - m_potentials[src] + m_potentials[tgt]) :
(-cost + m_potentials[src] - m_potentials[tgt]); (-cost + m_potentials[src] - m_potentials[tgt]);
@ -169,14 +153,12 @@ namespace smt {
m_flows[m_entering_edge] += val; m_flows[m_entering_edge] += val;
node source = m_graph.get_source(m_entering_edge); node source = m_graph.get_source(m_entering_edge);
for (unsigned u = source; u != m_join_node; u = m_pred[u]) { for (unsigned u = source; u != m_join_node; u = m_pred[u]) {
edge_id e_id; edge_id e_id = get_edge_id(u, m_pred[u]);
get_edge_id(u, m_pred[u], e_id);
m_flows[e_id] += m_upwards[u] ? -val : val; m_flows[e_id] += m_upwards[u] ? -val : val;
} }
node target = m_graph.get_target(m_entering_edge); node target = m_graph.get_target(m_entering_edge);
for (unsigned u = target; u != m_join_node; u = m_pred[u]) { for (unsigned u = target; u != m_join_node; u = m_pred[u]) {
edge_id e_id; edge_id e_id = get_edge_id(u, m_pred[u]);
get_edge_id(u, m_pred[u], e_id);
m_flows[e_id] += m_upwards[u] ? val : -val; m_flows[e_id] += m_upwards[u] ? val : -val;
} }
TRACE("network_flow", tout << pp_vector("Flows", m_flows, true);); TRACE("network_flow", tout << pp_vector("Flows", m_flows, true););
@ -189,13 +171,13 @@ namespace smt {
for (unsigned int i = 0; i < es.size(); ++i) { for (unsigned int i = 0; i < es.size(); ++i) {
edge const & e = es[i]; edge const & e = es[i];
edge_id e_id; edge_id e_id;
if (e.is_enabled() && m_graph.get_edge_id(e.get_source(), e.get_target(), e_id) && m_states[e_id] == NON_BASIS) { node source = e.get_source();
node source = e.get_source(); node target = e.get_target();
node target = e.get_target(); if (e.is_enabled() && m_graph.get_edge_id(source, target, e_id) && m_states[e_id] == NON_BASIS) {
numeral cost = e.get_weight() - m_potentials[source] + m_potentials[target]; numeral cost = e.get_weight() - m_potentials[source] + m_potentials[target];
// Choose the first negative-cost edge to be the violating edge // Choose the first negative-cost edge to be the violating edge
// TODO: add multiple pivoting strategies // TODO: add multiple pivoting strategies
if (cost < numeral::zero()) { if (cost.is_neg()) {
m_entering_edge = e_id; m_entering_edge = e_id;
TRACE("network_flow", { TRACE("network_flow", {
tout << "Found entering edge " << e_id << " between node "; tout << "Found entering edge " << e_id << " between node ";
@ -234,8 +216,7 @@ namespace smt {
node src, tgt; node src, tgt;
// Send flows along the path from source to the ancestor // Send flows along the path from source to the ancestor
for (unsigned u = source; u != m_join_node; u = m_pred[u]) { for (unsigned u = source; u != m_join_node; u = m_pred[u]) {
edge_id e_id; edge_id e_id = get_edge_id(u, m_pred[u]);
get_edge_id(u, m_pred[u], e_id);
numeral d = m_upwards[u] ? m_flows[e_id] : infty; numeral d = m_upwards[u] ? m_flows[e_id] : infty;
if (d < m_delta) { if (d < m_delta) {
m_delta = d; m_delta = d;
@ -246,8 +227,7 @@ namespace smt {
// Send flows along the path from target to the ancestor // Send flows along the path from target to the ancestor
for (unsigned u = target; u != m_join_node; u = m_pred[u]) { for (unsigned u = target; u != m_join_node; u = m_pred[u]) {
edge_id e_id; edge_id e_id = get_edge_id(u, m_pred[u]);
get_edge_id(u, m_pred[u], e_id);
numeral d = m_upwards[u] ? m_flows[e_id] : infty; numeral d = m_upwards[u] ? m_flows[e_id] : infty;
if (d <= m_delta) { if (d <= m_delta) {
m_delta = d; m_delta = d;
@ -257,7 +237,7 @@ namespace smt {
} }
if (m_delta < infty) { if (m_delta < infty) {
get_edge_id(src, tgt, m_leaving_edge); m_leaving_edge = get_edge_id(src, tgt);
TRACE("network_flow", { TRACE("network_flow", {
tout << "Found leaving edge " << m_leaving_edge; tout << "Found leaving edge " << m_leaving_edge;
tout << " between node " << src << " and node " << tgt << "...\n"; tout << " between node " << src << " and node " << tgt << "...\n";
@ -391,12 +371,10 @@ namespace smt {
typename network_flow<Ext>::numeral network_flow<Ext>::get_optimal_solution(vector<numeral> & result, bool is_dual) { typename network_flow<Ext>::numeral network_flow<Ext>::get_optimal_solution(vector<numeral> & result, bool is_dual) {
m_objective_value = numeral::zero(); m_objective_value = numeral::zero();
vector<edge> const & es = m_graph.get_all_edges(); vector<edge> const & es = m_graph.get_all_edges();
fin_numeral cost;
for (unsigned i = 0; i < es.size(); ++i) { for (unsigned i = 0; i < es.size(); ++i) {
edge const & e = es[i]; edge const & e = es[i];
if (e.is_enabled() && m_states[i] == BASIS) { if (e.is_enabled() && m_states[i] == BASIS) {
cost = e.get_weight().get_rational(); m_objective_value += e.get_weight().get_rational() * m_flows[i];
m_objective_value += cost * m_flows[i];
} }
} }
result.reset(); result.reset();