diff --git a/src/smt/network_flow.h b/src/smt/network_flow.h index 461312985..d845c671a 100644 --- a/src/smt/network_flow.h +++ b/src/smt/network_flow.h @@ -244,7 +244,7 @@ namespace smt { }; graph m_graph; - thread_spanning_tree m_tree; + spanning_tree_base * m_tree; // Denote supply/demand b_i on node i vector m_balances; diff --git a/src/smt/network_flow_def.h b/src/smt/network_flow_def.h index 515d49933..0ec7fe0c0 100644 --- a/src/smt/network_flow_def.h +++ b/src/smt/network_flow_def.h @@ -27,8 +27,7 @@ namespace smt { template network_flow::network_flow(graph & g, vector const & balances) : - m_balances(balances), - m_tree(m_graph) { + m_balances(balances) { // Network flow graph has the edges in the reversed order compared to constraint graph // We only take enabled edges from the original graph for (unsigned i = 0; i < g.get_num_nodes(); ++i) { @@ -42,6 +41,7 @@ namespace smt { } } m_step = 0; + m_tree = alloc(thread_spanning_tree, m_graph); } template @@ -82,7 +82,7 @@ namespace smt { tree.push_back(m_graph.add_edge(src, tgt, numeral::one(), explanation())); } - m_tree.initialize(tree); + m_tree->initialize(tree); TRACE("network_flow", { tout << pp_vector("Potentials", m_potentials, true) << pp_vector("Flows", m_flows); @@ -96,14 +96,14 @@ namespace smt { node src = m_graph.get_source(m_enter_id); node tgt = m_graph.get_target(m_enter_id); numeral cost = m_potentials[src] - m_potentials[tgt] - m_graph.get_weight(m_enter_id); - numeral change = m_tree.in_subtree_t2(tgt) ? cost : -cost; + numeral change = m_tree->in_subtree_t2(tgt) ? cost : -cost; node start = m_graph.get_source(m_leave_id); - if (!m_tree.in_subtree_t2(start)) { + if (!m_tree->in_subtree_t2(start)) { start = m_graph.get_target(m_leave_id);; } TRACE("network_flow", tout << "update_potentials of T_" << start << " with change = " << change << "...\n";); svector descendants; - m_tree.get_descendants(start, descendants); + m_tree->get_descendants(start, descendants); SASSERT(descendants.size() >= 1); for (unsigned i = 0; i < descendants.size(); ++i) { node u = descendants[i]; @@ -120,7 +120,7 @@ namespace smt { node tgt = m_graph.get_target(m_enter_id); svector path; svector against; - m_tree.get_path(src, tgt, path, against); + m_tree->get_path(src, tgt, path, against); SASSERT(path.size() >= 1); for (unsigned i = 0; i < path.size(); ++i) { edge_id e_id = path[i]; @@ -138,7 +138,7 @@ namespace smt { edge_id leave_id; svector path; svector against; - m_tree.get_path(src, tgt, path, against); + m_tree->get_path(src, tgt, path, against); SASSERT(path.size() >= 1); for (unsigned i = 0; i < path.size(); ++i) { edge_id e_id = path[i]; @@ -164,7 +164,7 @@ namespace smt { template void network_flow::update_spanning_tree() { - m_tree.update(m_enter_id, m_leave_id); + m_tree->update(m_enter_id, m_leave_id); } // FIXME: should declare pivot as a pivot_rule_impl and refactor @@ -240,7 +240,7 @@ namespace smt { template bool network_flow::check_well_formed() { - SASSERT(m_tree.check_well_formed()); + SASSERT(m_tree->check_well_formed()); SASSERT(!m_delta || !(*m_delta).is_neg()); // m_flows are zero on non-basic edges diff --git a/src/smt/spanning_tree.h b/src/smt/spanning_tree.h index c591f5485..f5f6b624f 100644 --- a/src/smt/spanning_tree.h +++ b/src/smt/spanning_tree.h @@ -49,7 +49,7 @@ namespace smt { void swap_order(node q, node v); node find_rev_thread(node n) const; - void fix_depth(node start, node end); + void fix_depth(node start, node after_end); node get_final(int start); bool is_preorder_traversal(node start, node end); node get_common_ancestor(node u, node v); diff --git a/src/smt/spanning_tree_base.h b/src/smt/spanning_tree_base.h index e492f80fa..25384d52e 100644 --- a/src/smt/spanning_tree_base.h +++ b/src/smt/spanning_tree_base.h @@ -47,14 +47,14 @@ namespace smt { typedef int node; public: - virtual void initialize(svector const & tree) {}; - virtual void get_descendants(node start, svector & descendants) {}; + virtual void initialize(svector const & tree) = 0; + virtual void get_descendants(node start, svector & descendants) = 0; - virtual void update(edge_id enter_id, edge_id leave_id) {}; - virtual void get_path(node start, node end, svector & path, svector & against) {}; - virtual bool in_subtree_t2(node child) {UNREACHABLE(); return false;}; + virtual void update(edge_id enter_id, edge_id leave_id) = 0; + virtual void get_path(node start, node end, svector & path, svector & against) = 0; + virtual bool in_subtree_t2(node child) = 0; - virtual bool check_well_formed() {UNREACHABLE(); return false;}; + virtual bool check_well_formed() = 0; }; } diff --git a/src/smt/spanning_tree_def.h b/src/smt/spanning_tree_def.h index c1e05ba3b..281a2cf45 100644 --- a/src/smt/spanning_tree_def.h +++ b/src/smt/spanning_tree_def.h @@ -162,11 +162,26 @@ namespace smt { tout << u << ", " << v << ") leaves\n"; }); + // Old threads: alpha -> v -*-> f(v) -> beta | p -*-> f(p) -> gamma + // New threads: alpha -> beta | p -*-> f(p) -> v -*-> f(v) -> gamma + + node f_p = get_final(p); + node f_v = get_final(v); + node alpha = find_rev_thread(v); + node beta = m_thread[f_v]; + node gamma = m_thread[f_p]; + + if (v != gamma) { + m_thread[alpha] = beta; + m_thread[f_p] = v; + m_thread[f_v] = gamma; + } + node old_pred = m_pred[q]; // Update stem nodes from q to v if (q != v) { - for (node n = q; n != u; ) { - SASSERT(old_pred != u || n == v); // the last processed node is v + for (node n = q; n != v; ) { + SASSERT(old_pred != u); // the last processed node is v SASSERT(-1 != m_pred[old_pred]); int next_old_pred = m_pred[old_pred]; swap_order(n, old_pred); @@ -175,34 +190,18 @@ namespace smt { old_pred = next_old_pred; } } - - // Old threads: alpha -> q -*-> f(q) -> beta | p -*-> f(p) -> gamma - // New threads: alpha -> beta | p -*-> f(p) -> q -*-> f(q) -> gamma - - node f_p = get_final(p); - node f_q = get_final(q); - node alpha = find_rev_thread(q); - node beta = m_thread[f_q]; - node gamma = m_thread[f_p]; - - if (q != gamma) { - m_thread[alpha] = beta; - m_thread[f_p] = q; - m_thread[f_q] = gamma; - } - + m_pred[q] = p; m_tree[q] = enter_id; m_root_t2 = q; + node after_final_q = (v == gamma) ? beta : gamma; + fix_depth(q, after_final_q); + SASSERT(!in_subtree_t2(p)); SASSERT(in_subtree_t2(q)); SASSERT(!in_subtree_t2(u)); SASSERT(in_subtree_t2(v)); - - // Update the depth. - - fix_depth(q, get_final(q)); TRACE("network_flow", { tout << pp_vector("Predecessors", m_pred, true) << pp_vector("Threads", m_thread); @@ -210,6 +209,53 @@ namespace smt { }); } + /** + swap v and q in tree. + - fixup m_thread + - fixup m_pred + + Case 1: final(q) == final(v) + ------- + Old thread: prev -> v -*-> alpha -> q -*-> final(q) -> next + New thread: prev -> q -*-> final(q) -> v -*-> alpha -> next + + Case 2: final(q) != final(v) + ------- + Old thread: prev -> v -*-> alpha -> q -*-> final(q) -> beta -*-> final(v) -> next + New thread: prev -> q -*-> final(q) -> v -*-> alpha -> beta -*-> final(v) -> next + + */ + template + void thread_spanning_tree::swap_order(node q, node v) { + SASSERT(q != v); + SASSERT(m_pred[q] == v); + SASSERT(is_preorder_traversal(v, get_final(v))); + node prev = find_rev_thread(v); + node f_q = get_final(q); + node f_v = get_final(v); + node next = m_thread[f_v]; + node alpha = find_rev_thread(q); + + if (f_q == f_v) { + SASSERT(f_q != v && alpha != next); + m_thread[f_q] = v; + m_thread[alpha] = next; + f_q = alpha; + } + else { + node beta = m_thread[f_q]; + SASSERT(f_q != v && alpha != beta); + m_thread[f_q] = v; + m_thread[alpha] = beta; + f_q = f_v; + } + SASSERT(prev != q); + m_thread[prev] = q; + m_pred[v] = q; + // Notes: f_q has to be used since m_depth hasn't been updated yet. + SASSERT(is_preorder_traversal(q, f_q)); + } + /** \brief Check invariants of main data-structures. @@ -311,53 +357,6 @@ namespace smt { roots[y] = x; } - /** - swap v and q in tree. - - fixup m_thread - - fixup m_pred - - Case 1: final(q) == final(v) - ------- - Old thread: prev -> v -*-> alpha -> q -*-> final(q) -> next - New thread: prev -> q -*-> final(q) -> v -*-> alpha -> next - - Case 2: final(q) != final(v) - ------- - Old thread: prev -> v -*-> alpha -> q -*-> final(q) -> beta -*-> final(v) -> next - New thread: prev -> q -*-> final(q) -> v -*-> alpha -> beta -*-> final(v) -> next - - */ - template - void thread_spanning_tree::swap_order(node q, node v) { - SASSERT(q != v); - SASSERT(m_pred[q] == v); - SASSERT(is_preorder_traversal(v, get_final(v))); - node prev = find_rev_thread(v); - node f_q = get_final(q); - node f_v = get_final(v); - node next = m_thread[f_v]; - node alpha = find_rev_thread(q); - - if (f_q == f_v) { - SASSERT(f_q != v && alpha != next); - m_thread[f_q] = v; - m_thread[alpha] = next; - f_q = alpha; - } - else { - node beta = m_thread[f_q]; - SASSERT(f_q != v && alpha != beta); - m_thread[f_q] = v; - m_thread[alpha] = beta; - f_q = f_v; - } - SASSERT(prev != q); - m_thread[prev] = q; - m_pred[v] = q; - // Notes: f_q has to be used since m_depth hasn't been updated yet. - SASSERT(is_preorder_traversal(q, f_q)); - } - /** \brief find node that points to 'n' in m_thread */ @@ -372,12 +371,11 @@ namespace smt { } template - void thread_spanning_tree::fix_depth(node start, node end) { - SASSERT(m_pred[start] != -1); - m_depth[start] = m_depth[m_pred[start]]+1; - while (start != end) { - start = m_thread[start]; + void thread_spanning_tree::fix_depth(node start, node after_end) { + while (start != after_end) { + SASSERT(m_pred[start] != -1); m_depth[start] = m_depth[m_pred[start]]+1; + start = m_thread[start]; } }