3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-16 13:58:45 +00:00

Refactor network_flow

Use a template method for pretty printing
This commit is contained in:
Anh-Dung Phan 2013-10-30 10:04:56 -07:00
parent 42cbbe830e
commit 49aba844b8
2 changed files with 22 additions and 47 deletions

View file

@ -35,11 +35,8 @@ Notes:
namespace smt {
template<typename T>
std::string pp_vector(std::string const & label, svector<T> v, bool has_header = false);
template<typename T>
std::string pp_vector(std::string const & label, vector<T> v, bool has_header = false);
template<typename TV>
std::string pp_vector(std::string const & label, TV v, bool has_header = false);
// Solve minimum cost flow problem using Network Simplex algorithm
template<typename Ext>
@ -91,7 +88,7 @@ namespace smt {
// Initialize the network with a feasible spanning tree
void initialize();
bool get_edge_id(dl_var source, dl_var target, edge_id & id);
edge_id get_edge_id(dl_var source, dl_var target);
void update_potentials();

View file

@ -24,8 +24,8 @@ Notes:
namespace smt {
template<typename T>
std::string pp_vector(std::string const & label, svector<T> v, bool has_header) {
template<typename TV>
std::string pp_vector(std::string const & label, TV v, bool has_header) {
std::ostringstream oss;
if (has_header) {
oss << "Index ";
@ -42,23 +42,6 @@ namespace smt {
return oss.str();
}
template<typename T>
std::string pp_vector(std::string const & label, vector<T> v, bool has_header) {
std::ostringstream oss;
if (has_header) {
oss << "Index ";
for (unsigned i = 0; i < v.size(); ++i) {
oss << i << " ";
}
oss << std::endl;
}
oss << label << " ";
for (unsigned i = 0; i < v.size(); ++i) {
oss << v[i] << " ";
}
oss << std::endl;
return oss.str();
}
template<typename Ext>
network_flow<Ext>::network_flow(graph & g, vector<fin_numeral> const & balances) :
@ -105,7 +88,7 @@ namespace smt {
// Create artificial edges and initialize the spanning tree
for (unsigned i = 0; i < num_nodes; ++i) {
m_upwards[i] = m_balances[i] >= fin_numeral::zero();
m_upwards[i] = !m_balances[i].is_neg();
m_pred[i] = root;
m_depth[i] = 1;
m_thread[i] = i + 1;
@ -122,8 +105,7 @@ namespace smt {
node u = m_thread[root];
while (u != root) {
node v = m_pred[u];
edge_id e_id;
get_edge_id(u, v, e_id);
edge_id e_id = get_edge_id(u, v);
if (m_upwards[u]) {
m_potentials[u] = m_potentials[v] - m_graph.get_weight(e_id);
}
@ -142,16 +124,18 @@ namespace smt {
}
template<typename Ext>
bool network_flow<Ext>::get_edge_id(dl_var source, dl_var target, edge_id & id) {
edge_id network_flow<Ext>::get_edge_id(dl_var source, dl_var target) {
// m_upwards decides which node is the real source
return m_upwards[source] ? m_graph.get_edge_id(source, target, id) : m_graph.get_edge_id(target, source, id);
edge_id id;
VERIFY(m_upwards[source] ? m_graph.get_edge_id(source, target, id) : m_graph.get_edge_id(target, source, id));
return id;
}
template<typename Ext>
void network_flow<Ext>::update_potentials() {
TRACE("network_flow", tout << "update_potentials...\n";);
node src = m_graph.get_source(m_entering_edge);
node tgt = m_graph.get_source(m_entering_edge);
node tgt = m_graph.get_target(m_entering_edge);
numeral cost = m_graph.get_weight(m_entering_edge);
numeral change = m_upwards[src] ? (cost - m_potentials[src] + m_potentials[tgt]) :
(-cost + m_potentials[src] - m_potentials[tgt]);
@ -169,14 +153,12 @@ namespace smt {
m_flows[m_entering_edge] += val;
node source = m_graph.get_source(m_entering_edge);
for (unsigned u = source; u != m_join_node; u = m_pred[u]) {
edge_id e_id;
get_edge_id(u, m_pred[u], e_id);
edge_id e_id = get_edge_id(u, m_pred[u]);
m_flows[e_id] += m_upwards[u] ? -val : val;
}
node target = m_graph.get_target(m_entering_edge);
for (unsigned u = target; u != m_join_node; u = m_pred[u]) {
edge_id e_id;
get_edge_id(u, m_pred[u], e_id);
edge_id e_id = get_edge_id(u, m_pred[u]);
m_flows[e_id] += m_upwards[u] ? val : -val;
}
TRACE("network_flow", tout << pp_vector("Flows", m_flows, true););
@ -189,13 +171,13 @@ namespace smt {
for (unsigned int i = 0; i < es.size(); ++i) {
edge const & e = es[i];
edge_id e_id;
if (e.is_enabled() && m_graph.get_edge_id(e.get_source(), e.get_target(), e_id) && m_states[e_id] == NON_BASIS) {
node source = e.get_source();
node target = e.get_target();
node source = e.get_source();
node target = e.get_target();
if (e.is_enabled() && m_graph.get_edge_id(source, target, e_id) && m_states[e_id] == NON_BASIS) {
numeral cost = e.get_weight() - m_potentials[source] + m_potentials[target];
// Choose the first negative-cost edge to be the violating edge
// TODO: add multiple pivoting strategies
if (cost < numeral::zero()) {
if (cost.is_neg()) {
m_entering_edge = e_id;
TRACE("network_flow", {
tout << "Found entering edge " << e_id << " between node ";
@ -234,8 +216,7 @@ namespace smt {
node src, tgt;
// Send flows along the path from source to the ancestor
for (unsigned u = source; u != m_join_node; u = m_pred[u]) {
edge_id e_id;
get_edge_id(u, m_pred[u], e_id);
edge_id e_id = get_edge_id(u, m_pred[u]);
numeral d = m_upwards[u] ? m_flows[e_id] : infty;
if (d < m_delta) {
m_delta = d;
@ -246,8 +227,7 @@ namespace smt {
// Send flows along the path from target to the ancestor
for (unsigned u = target; u != m_join_node; u = m_pred[u]) {
edge_id e_id;
get_edge_id(u, m_pred[u], e_id);
edge_id e_id = get_edge_id(u, m_pred[u]);
numeral d = m_upwards[u] ? m_flows[e_id] : infty;
if (d <= m_delta) {
m_delta = d;
@ -257,7 +237,7 @@ namespace smt {
}
if (m_delta < infty) {
get_edge_id(src, tgt, m_leaving_edge);
m_leaving_edge = get_edge_id(src, tgt);
TRACE("network_flow", {
tout << "Found leaving edge " << m_leaving_edge;
tout << " between node " << src << " and node " << tgt << "...\n";
@ -391,12 +371,10 @@ namespace smt {
typename network_flow<Ext>::numeral network_flow<Ext>::get_optimal_solution(vector<numeral> & result, bool is_dual) {
m_objective_value = numeral::zero();
vector<edge> const & es = m_graph.get_all_edges();
fin_numeral cost;
for (unsigned i = 0; i < es.size(); ++i) {
edge const & e = es[i];
if (e.is_enabled() && m_states[i] == BASIS) {
cost = e.get_weight().get_rational();
m_objective_value += cost * m_flows[i];
m_objective_value += e.get_weight().get_rational() * m_flows[i];
}
}
result.reset();