mirror of
https://github.com/Z3Prover/z3
synced 2025-04-29 11:55:51 +00:00
Add a vector of edges to handle spanning trees
This commit is contained in:
parent
9f53a4aa18
commit
5a27c035e4
5 changed files with 129 additions and 109 deletions
|
@ -27,7 +27,8 @@ namespace smt {
|
|||
|
||||
template<typename Ext>
|
||||
network_flow<Ext>::network_flow(graph & g, vector<fin_numeral> const & balances) :
|
||||
m_balances(balances) {
|
||||
m_balances(balances),
|
||||
m_tree(m_graph) {
|
||||
// Network flow graph has the edges in the reversed order compared to constraint graph
|
||||
// We only take enabled edges from the original graph
|
||||
for (unsigned i = 0; i < g.get_num_nodes(); ++i) {
|
||||
|
@ -40,14 +41,13 @@ namespace smt {
|
|||
m_graph.add_edge(e.get_target(), e.get_source(), e.get_weight(), explanation());
|
||||
}
|
||||
}
|
||||
m_tree = thread_spanning_tree<Ext>(m_graph);
|
||||
m_step = 0;
|
||||
}
|
||||
|
||||
template<typename Ext>
|
||||
void network_flow<Ext>::initialize() {
|
||||
TRACE("network_flow", tout << "initialize...\n";);
|
||||
// Create an artificial root node to construct initial spanning m_tree
|
||||
// Create an artificial root node to construct initial spanning tree
|
||||
unsigned num_nodes = m_graph.get_num_nodes();
|
||||
unsigned num_edges = m_graph.get_num_edges();
|
||||
|
||||
|
@ -69,39 +69,44 @@ namespace smt {
|
|||
m_states.resize(num_nodes + num_edges);
|
||||
m_states.fill(LOWER);
|
||||
|
||||
// Create artificial edges from/to root node to/from other nodes and initialize the spanning m_tree
|
||||
svector<bool> upwards(num_nodes, false);
|
||||
// Create artificial edges from/to root node to/from other nodes and initialize the spanning tree
|
||||
svector<edge_id> tree;
|
||||
bool is_forward;
|
||||
for (unsigned i = 0; i < num_nodes; ++i) {
|
||||
upwards[i] = !m_balances[i].is_neg();
|
||||
is_forward = !m_balances[i].is_neg();
|
||||
m_states[num_edges + i] = BASIS;
|
||||
node src = upwards[i] ? i : root;
|
||||
node tgt = upwards[i] ? root : i;
|
||||
m_flows[num_edges + i] = upwards[i] ? m_balances[i] : -m_balances[i];
|
||||
m_potentials[i] = upwards[i] ? numeral::one() : -numeral::one();
|
||||
m_graph.add_edge(src, tgt, numeral::one(), explanation());
|
||||
node src = is_forward ? i : root;
|
||||
node tgt = is_forward ? root : i;
|
||||
m_flows[num_edges + i] = is_forward ? m_balances[i] : -m_balances[i];
|
||||
m_potentials[i] = is_forward ? numeral::one() : -numeral::one();
|
||||
tree.push_back(m_graph.add_edge(src, tgt, numeral::one(), explanation()));
|
||||
}
|
||||
|
||||
m_tree.initialize(upwards);
|
||||
m_tree.initialize(tree);
|
||||
|
||||
TRACE("network_flow", {
|
||||
tout << pp_vector("Potentials", m_potentials, true) << pp_vector("Flows", m_flows);
|
||||
});
|
||||
TRACE("network_flow", tout << "Spanning m_tree:\n" << display_spanning_tree(););
|
||||
TRACE("network_flow", tout << "Spanning tree:\n" << display_spanning_tree(););
|
||||
SASSERT(check_well_formed());
|
||||
}
|
||||
|
||||
template<typename Ext>
|
||||
void network_flow<Ext>::update_potentials() {
|
||||
node src = m_graph.get_source(m_enter_id);
|
||||
node tgt = m_graph.get_target(m_enter_id);
|
||||
node tgt = m_graph.get_target(m_enter_id);
|
||||
numeral cost = m_potentials[src] - m_potentials[tgt] - m_graph.get_weight(m_enter_id);
|
||||
numeral change = m_is_swap_leave ? -cost : cost;
|
||||
numeral change = m_tree.in_subtree_t2(tgt) ? cost : -cost;
|
||||
node start = m_graph.get_source(m_leave_id);
|
||||
if (!m_tree.in_subtree_t2(start)) {
|
||||
start = m_graph.get_target(m_leave_id);;
|
||||
}
|
||||
TRACE("network_flow", tout << "update_potentials of T_" << start << " with change = " << change << "...\n";);
|
||||
svector<node> descendants;
|
||||
node start = m_is_swap_enter ? src : tgt;
|
||||
TRACE("network_flow", tout << "update_potentials of T_" << start << " with delta = " << change << "...\n";);
|
||||
m_tree.get_descendants(start, descendants);
|
||||
SASSERT(descendants.size() >= 1);
|
||||
for (unsigned i = 0; i < descendants.size(); ++i) {
|
||||
node u = descendants[i];
|
||||
node u = descendants[i];
|
||||
m_potentials[u] += change;
|
||||
}
|
||||
TRACE("network_flow", tout << pp_vector("Potentials", m_potentials, true););
|
||||
|
@ -110,25 +115,25 @@ namespace smt {
|
|||
template<typename Ext>
|
||||
void network_flow<Ext>::update_flows() {
|
||||
TRACE("network_flow", tout << "update_flows...\n";);
|
||||
numeral val = *m_delta;
|
||||
m_flows[m_enter_id] += val;
|
||||
m_flows[m_enter_id] += *m_delta;
|
||||
node src = m_graph.get_source(m_enter_id);
|
||||
node tgt = m_graph.get_target(m_enter_id);
|
||||
svector<edge_id> path;
|
||||
svector<bool> against;
|
||||
m_tree.get_path(src, tgt, path, against);
|
||||
SASSERT(path.size() >= 1);
|
||||
for (unsigned i = 0; i < path.size(); ++i) {
|
||||
edge_id e_id = path[i];
|
||||
m_flows[e_id] += against[i] ? -val : val;
|
||||
m_flows[e_id] += against[i] ? - *m_delta : *m_delta;
|
||||
}
|
||||
TRACE("network_flow", tout << pp_vector("Flows", m_flows, true););
|
||||
}
|
||||
|
||||
template<typename Ext>
|
||||
bool network_flow<Ext>::choose_entering_edge() {
|
||||
TRACE("network_flow", tout << "choose_entering_edge...\n";);
|
||||
vector<edge> const & es = m_graph.get_all_edges();
|
||||
for (unsigned i = 0; i < es.size(); ++i) {
|
||||
TRACE("network_flow", tout << "choose_entering_edge...\n";);
|
||||
unsigned num_edges = m_graph.get_num_edges();
|
||||
for (unsigned i = 0; i < num_edges; ++i) {
|
||||
node src = m_graph.get_source(i);
|
||||
node tgt = m_graph.get_target(i);
|
||||
if (m_states[i] != BASIS) {
|
||||
|
@ -138,7 +143,7 @@ namespace smt {
|
|||
m_enter_id = i;
|
||||
TRACE("network_flow", {
|
||||
tout << "Found entering edge " << i << " between node ";
|
||||
tout << src << " and node " << tgt << "...\n";
|
||||
tout << src << " and node " << tgt << " with reduced cost = " << cost << "...\n";
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
@ -158,9 +163,10 @@ namespace smt {
|
|||
svector<edge_id> path;
|
||||
svector<bool> against;
|
||||
m_tree.get_path(src, tgt, path, against);
|
||||
SASSERT(path.size() >= 1);
|
||||
for (unsigned i = 0; i < path.size(); ++i) {
|
||||
edge_id e_id = path[i];
|
||||
if (against[i] && (!m_delta || m_flows[e_id] <= *m_delta)) {
|
||||
if (against[i] && (!m_delta || m_flows[e_id] < *m_delta)) {
|
||||
m_delta = m_flows[e_id];
|
||||
leave_id = e_id;
|
||||
}
|
||||
|
@ -182,7 +188,7 @@ namespace smt {
|
|||
|
||||
template<typename Ext>
|
||||
void network_flow<Ext>::update_spanning_tree() {
|
||||
m_tree.update(m_enter_id, m_leave_id, m_is_swap_enter, m_is_swap_leave);
|
||||
m_tree.update(m_enter_id, m_leave_id);
|
||||
}
|
||||
|
||||
// Minimize cost flows
|
||||
|
@ -201,7 +207,7 @@ namespace smt {
|
|||
m_states[m_leave_id] = (m_flows[m_leave_id].is_zero()) ? LOWER : UPPER;
|
||||
update_spanning_tree();
|
||||
update_potentials();
|
||||
TRACE("network_flow", tout << "Spanning m_tree:\n" << display_spanning_tree(););
|
||||
TRACE("network_flow", tout << "Spanning tree:\n" << display_spanning_tree(););
|
||||
SASSERT(check_well_formed());
|
||||
}
|
||||
else {
|
||||
|
@ -217,12 +223,11 @@ namespace smt {
|
|||
template<typename Ext>
|
||||
typename network_flow<Ext>::numeral network_flow<Ext>::get_optimal_solution(vector<numeral> & result, bool is_dual) {
|
||||
numeral objective_value = numeral::zero();
|
||||
vector<edge> const & es = m_graph.get_all_edges();
|
||||
for (unsigned i = 0; i < es.size(); ++i) {
|
||||
edge const & e = es[i];
|
||||
unsigned num_edges = m_graph.get_num_edges();
|
||||
for (unsigned i = 0; i < num_edges; ++i) {
|
||||
if (m_states[i] == BASIS)
|
||||
{
|
||||
objective_value += e.get_weight().get_rational() * m_flows[i];
|
||||
objective_value += m_graph.get_weight(i).get_rational() * m_flows[i];
|
||||
}
|
||||
}
|
||||
result.reset();
|
||||
|
@ -243,6 +248,7 @@ namespace smt {
|
|||
template<typename Ext>
|
||||
bool network_flow<Ext>::check_well_formed() {
|
||||
SASSERT(m_tree.check_well_formed());
|
||||
SASSERT(!m_delta || !(*m_delta).is_neg());
|
||||
|
||||
// m_flows are zero on non-basic edges
|
||||
for (unsigned i = 0; i < m_flows.size(); ++i) {
|
||||
|
@ -250,11 +256,10 @@ namespace smt {
|
|||
SASSERT(m_states[i] == BASIS || m_flows[i].is_zero());
|
||||
}
|
||||
|
||||
vector<edge> const & es = m_graph.get_all_edges();
|
||||
for (unsigned i = 0; i < es.size(); ++i) {
|
||||
edge const & e = es[i];
|
||||
unsigned num_edges = m_graph.get_num_edges();
|
||||
for (unsigned i = 0; i < num_edges; ++i) {
|
||||
if (m_states[i] == BASIS) {
|
||||
SASSERT(m_potentials[e.get_source()] - m_potentials[e.get_target()] == e.get_weight());
|
||||
SASSERT(m_potentials[m_graph.get_source(i)] - m_potentials[m_graph.get_target(i)] == m_graph.get_weight(i));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -264,11 +269,10 @@ namespace smt {
|
|||
template<typename Ext>
|
||||
bool network_flow<Ext>::check_optimal() {
|
||||
numeral total_cost = numeral::zero();
|
||||
vector<edge> const & es = m_graph.get_all_edges();
|
||||
for (unsigned i = 0; i < es.size(); ++i) {
|
||||
edge const & e = es[i];
|
||||
unsigned num_edges = m_graph.get_num_edges();
|
||||
for (unsigned i = 0; i < num_edges; ++i) {
|
||||
if (m_states[i] == BASIS) {
|
||||
total_cost += e.get_weight().get_rational() * m_flows[i];
|
||||
total_cost += m_graph.get_weight(i).get_rational() * m_flows[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -299,15 +303,14 @@ namespace smt {
|
|||
oss << prefix << root << "[shape=doublecircle,label=\"" << prefix << root << " [";
|
||||
oss << m_potentials[root] << "/" << m_balances[root] << "]\"];\n";
|
||||
|
||||
vector<edge> const & es = m_graph.get_all_edges();
|
||||
for (unsigned i = 0; i < es.size(); ++i) {
|
||||
edge const & e = es[i];
|
||||
oss << prefix << e.get_source() << " -> " << prefix << e.get_target();
|
||||
unsigned num_edges = m_graph.get_num_edges();
|
||||
for (unsigned i = 0; i < num_edges; ++i) {
|
||||
oss << prefix << m_graph.get_source(i) << " -> " << prefix << m_graph.get_target(i);
|
||||
if (m_states[i] == BASIS) {
|
||||
oss << "[color=red,penwidth=3.0,label=\"" << m_flows[i] << "/" << e.get_weight() << "\"];\n";
|
||||
oss << "[color=red,penwidth=3.0,label=\"" << m_flows[i] << "/" << m_graph.get_weight(i) << "\"];\n";
|
||||
}
|
||||
else {
|
||||
oss << "[label=\"" << m_flows[i] << "/" << e.get_weight() << "\"];\n";
|
||||
oss << "[label=\"" << m_flows[i] << "/" << m_graph.get_weight(i) << "\"];\n";
|
||||
}
|
||||
}
|
||||
oss << std::endl;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue