mirror of
https://github.com/Z3Prover/z3
synced 2025-04-16 13:58:45 +00:00
Refactor network_flow
Use a template method for pretty printing
This commit is contained in:
parent
42cbbe830e
commit
49aba844b8
|
@ -35,11 +35,8 @@ Notes:
|
|||
|
||||
namespace smt {
|
||||
|
||||
template<typename T>
|
||||
std::string pp_vector(std::string const & label, svector<T> v, bool has_header = false);
|
||||
|
||||
template<typename T>
|
||||
std::string pp_vector(std::string const & label, vector<T> v, bool has_header = false);
|
||||
template<typename TV>
|
||||
std::string pp_vector(std::string const & label, TV v, bool has_header = false);
|
||||
|
||||
// Solve minimum cost flow problem using Network Simplex algorithm
|
||||
template<typename Ext>
|
||||
|
@ -91,7 +88,7 @@ namespace smt {
|
|||
// Initialize the network with a feasible spanning tree
|
||||
void initialize();
|
||||
|
||||
bool get_edge_id(dl_var source, dl_var target, edge_id & id);
|
||||
edge_id get_edge_id(dl_var source, dl_var target);
|
||||
|
||||
|
||||
void update_potentials();
|
||||
|
|
|
@ -24,8 +24,8 @@ Notes:
|
|||
|
||||
namespace smt {
|
||||
|
||||
template<typename T>
|
||||
std::string pp_vector(std::string const & label, svector<T> v, bool has_header) {
|
||||
template<typename TV>
|
||||
std::string pp_vector(std::string const & label, TV v, bool has_header) {
|
||||
std::ostringstream oss;
|
||||
if (has_header) {
|
||||
oss << "Index ";
|
||||
|
@ -42,23 +42,6 @@ namespace smt {
|
|||
return oss.str();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
std::string pp_vector(std::string const & label, vector<T> v, bool has_header) {
|
||||
std::ostringstream oss;
|
||||
if (has_header) {
|
||||
oss << "Index ";
|
||||
for (unsigned i = 0; i < v.size(); ++i) {
|
||||
oss << i << " ";
|
||||
}
|
||||
oss << std::endl;
|
||||
}
|
||||
oss << label << " ";
|
||||
for (unsigned i = 0; i < v.size(); ++i) {
|
||||
oss << v[i] << " ";
|
||||
}
|
||||
oss << std::endl;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template<typename Ext>
|
||||
network_flow<Ext>::network_flow(graph & g, vector<fin_numeral> const & balances) :
|
||||
|
@ -105,7 +88,7 @@ namespace smt {
|
|||
|
||||
// Create artificial edges and initialize the spanning tree
|
||||
for (unsigned i = 0; i < num_nodes; ++i) {
|
||||
m_upwards[i] = m_balances[i] >= fin_numeral::zero();
|
||||
m_upwards[i] = !m_balances[i].is_neg();
|
||||
m_pred[i] = root;
|
||||
m_depth[i] = 1;
|
||||
m_thread[i] = i + 1;
|
||||
|
@ -122,8 +105,7 @@ namespace smt {
|
|||
node u = m_thread[root];
|
||||
while (u != root) {
|
||||
node v = m_pred[u];
|
||||
edge_id e_id;
|
||||
get_edge_id(u, v, e_id);
|
||||
edge_id e_id = get_edge_id(u, v);
|
||||
if (m_upwards[u]) {
|
||||
m_potentials[u] = m_potentials[v] - m_graph.get_weight(e_id);
|
||||
}
|
||||
|
@ -142,16 +124,18 @@ namespace smt {
|
|||
}
|
||||
|
||||
template<typename Ext>
|
||||
bool network_flow<Ext>::get_edge_id(dl_var source, dl_var target, edge_id & id) {
|
||||
edge_id network_flow<Ext>::get_edge_id(dl_var source, dl_var target) {
|
||||
// m_upwards decides which node is the real source
|
||||
return m_upwards[source] ? m_graph.get_edge_id(source, target, id) : m_graph.get_edge_id(target, source, id);
|
||||
edge_id id;
|
||||
VERIFY(m_upwards[source] ? m_graph.get_edge_id(source, target, id) : m_graph.get_edge_id(target, source, id));
|
||||
return id;
|
||||
}
|
||||
|
||||
template<typename Ext>
|
||||
void network_flow<Ext>::update_potentials() {
|
||||
TRACE("network_flow", tout << "update_potentials...\n";);
|
||||
node src = m_graph.get_source(m_entering_edge);
|
||||
node tgt = m_graph.get_source(m_entering_edge);
|
||||
node tgt = m_graph.get_target(m_entering_edge);
|
||||
numeral cost = m_graph.get_weight(m_entering_edge);
|
||||
numeral change = m_upwards[src] ? (cost - m_potentials[src] + m_potentials[tgt]) :
|
||||
(-cost + m_potentials[src] - m_potentials[tgt]);
|
||||
|
@ -169,14 +153,12 @@ namespace smt {
|
|||
m_flows[m_entering_edge] += val;
|
||||
node source = m_graph.get_source(m_entering_edge);
|
||||
for (unsigned u = source; u != m_join_node; u = m_pred[u]) {
|
||||
edge_id e_id;
|
||||
get_edge_id(u, m_pred[u], e_id);
|
||||
edge_id e_id = get_edge_id(u, m_pred[u]);
|
||||
m_flows[e_id] += m_upwards[u] ? -val : val;
|
||||
}
|
||||
node target = m_graph.get_target(m_entering_edge);
|
||||
for (unsigned u = target; u != m_join_node; u = m_pred[u]) {
|
||||
edge_id e_id;
|
||||
get_edge_id(u, m_pred[u], e_id);
|
||||
edge_id e_id = get_edge_id(u, m_pred[u]);
|
||||
m_flows[e_id] += m_upwards[u] ? val : -val;
|
||||
}
|
||||
TRACE("network_flow", tout << pp_vector("Flows", m_flows, true););
|
||||
|
@ -189,13 +171,13 @@ namespace smt {
|
|||
for (unsigned int i = 0; i < es.size(); ++i) {
|
||||
edge const & e = es[i];
|
||||
edge_id e_id;
|
||||
if (e.is_enabled() && m_graph.get_edge_id(e.get_source(), e.get_target(), e_id) && m_states[e_id] == NON_BASIS) {
|
||||
node source = e.get_source();
|
||||
node target = e.get_target();
|
||||
node source = e.get_source();
|
||||
node target = e.get_target();
|
||||
if (e.is_enabled() && m_graph.get_edge_id(source, target, e_id) && m_states[e_id] == NON_BASIS) {
|
||||
numeral cost = e.get_weight() - m_potentials[source] + m_potentials[target];
|
||||
// Choose the first negative-cost edge to be the violating edge
|
||||
// TODO: add multiple pivoting strategies
|
||||
if (cost < numeral::zero()) {
|
||||
if (cost.is_neg()) {
|
||||
m_entering_edge = e_id;
|
||||
TRACE("network_flow", {
|
||||
tout << "Found entering edge " << e_id << " between node ";
|
||||
|
@ -234,8 +216,7 @@ namespace smt {
|
|||
node src, tgt;
|
||||
// Send flows along the path from source to the ancestor
|
||||
for (unsigned u = source; u != m_join_node; u = m_pred[u]) {
|
||||
edge_id e_id;
|
||||
get_edge_id(u, m_pred[u], e_id);
|
||||
edge_id e_id = get_edge_id(u, m_pred[u]);
|
||||
numeral d = m_upwards[u] ? m_flows[e_id] : infty;
|
||||
if (d < m_delta) {
|
||||
m_delta = d;
|
||||
|
@ -246,8 +227,7 @@ namespace smt {
|
|||
|
||||
// Send flows along the path from target to the ancestor
|
||||
for (unsigned u = target; u != m_join_node; u = m_pred[u]) {
|
||||
edge_id e_id;
|
||||
get_edge_id(u, m_pred[u], e_id);
|
||||
edge_id e_id = get_edge_id(u, m_pred[u]);
|
||||
numeral d = m_upwards[u] ? m_flows[e_id] : infty;
|
||||
if (d <= m_delta) {
|
||||
m_delta = d;
|
||||
|
@ -257,7 +237,7 @@ namespace smt {
|
|||
}
|
||||
|
||||
if (m_delta < infty) {
|
||||
get_edge_id(src, tgt, m_leaving_edge);
|
||||
m_leaving_edge = get_edge_id(src, tgt);
|
||||
TRACE("network_flow", {
|
||||
tout << "Found leaving edge " << m_leaving_edge;
|
||||
tout << " between node " << src << " and node " << tgt << "...\n";
|
||||
|
@ -391,12 +371,10 @@ namespace smt {
|
|||
typename network_flow<Ext>::numeral network_flow<Ext>::get_optimal_solution(vector<numeral> & result, bool is_dual) {
|
||||
m_objective_value = numeral::zero();
|
||||
vector<edge> const & es = m_graph.get_all_edges();
|
||||
fin_numeral cost;
|
||||
for (unsigned i = 0; i < es.size(); ++i) {
|
||||
edge const & e = es[i];
|
||||
if (e.is_enabled() && m_states[i] == BASIS) {
|
||||
cost = e.get_weight().get_rational();
|
||||
m_objective_value += cost * m_flows[i];
|
||||
m_objective_value += e.get_weight().get_rational() * m_flows[i];
|
||||
}
|
||||
}
|
||||
result.reset();
|
||||
|
|
Loading…
Reference in a new issue