diff --git a/src/smt/network_flow.h b/src/smt/network_flow.h index 3303e2618..2d40edbd2 100644 --- a/src/smt/network_flow.h +++ b/src/smt/network_flow.h @@ -32,7 +32,7 @@ Notes: #include"inf_rational.h" #include"diff_logic.h" -#include"spanning_tree.h" +#include"spanning_tree_def.h" namespace smt { @@ -51,7 +51,7 @@ namespace smt { typedef typename Ext::fin_numeral fin_numeral; graph m_graph; - thread_spanning_tree tree; + thread_spanning_tree tree; // Denote supply/demand b_i on node i vector m_balances; diff --git a/src/smt/network_flow_def.h b/src/smt/network_flow_def.h index d576e3704..f697acaac 100644 --- a/src/smt/network_flow_def.h +++ b/src/smt/network_flow_def.h @@ -47,7 +47,7 @@ namespace smt { m_balances.resize(num_nodes); m_potentials.resize(num_nodes); - tree = thread_spanning_tree(); + tree = thread_spanning_tree(); m_step = 0; } @@ -302,7 +302,6 @@ namespace smt { bool network_flow::edge_in_tree(node src, node dst) const { return edge_in_tree(get_edge_id(src, dst)); } - template bool network_flow::check_well_formed() { diff --git a/src/smt/spanning_tree.h b/src/smt/spanning_tree.h index 6769a6a18..41264245e 100644 --- a/src/smt/spanning_tree.h +++ b/src/smt/spanning_tree.h @@ -19,13 +19,20 @@ Notes: #ifndef _SPANNING_TREE_H_ #define _SPANNING_TREE_H_ +#include "diff_logic.h" #include "spanning_tree_base.h" namespace smt { - - class thread_spanning_tree : virtual public spanning_tree_base { + template + class thread_spanning_tree : public spanning_tree_base, private Ext { private: + typedef dl_var node; + typedef dl_edge edge; + typedef dl_graph graph; + typedef typename Ext::numeral numeral; + typedef typename Ext::fin_numeral fin_numeral; + // Store the parent of a node i in the spanning tree svector m_pred; // Store the number of edge on the path from node i to the root @@ -47,8 +54,8 @@ namespace smt { public: void initialize(svector const & upwards, int num_nodes); - void get_descendants(node start, svector& descendants); - void get_ancestors(node start, svector& ancestors); + void get_descendants(node start, svector & descendants); + void get_ancestors(node start, svector & ancestors); node get_common_ancestor(node u, node v); void update(node p, node q, node u, node v); bool check_well_formed(); diff --git a/src/smt/spanning_tree_base.h b/src/smt/spanning_tree_base.h index 6be090711..065a4b042 100644 --- a/src/smt/spanning_tree_base.h +++ b/src/smt/spanning_tree_base.h @@ -24,8 +24,6 @@ Notes: #include "vector.h" namespace smt { - typedef int node; - template inline std::string pp_vector(std::string const & label, TV v, bool has_header = false) { std::ostringstream oss; @@ -44,27 +42,29 @@ namespace smt { return oss.str(); } - class spanning_tree_base { - public: - spanning_tree_base() {}; - virtual ~spanning_tree_base() {}; + class spanning_tree_base { + private: + typedef int node; + + public: + virtual void initialize(svector const & upwards, int num_nodes) {}; - virtual void initialize(svector const & upwards, int num_nodes) = 0; /** \brief Get all descendants of a node including itself */ - virtual void get_descendants(node start, svector& descendants) = 0; + virtual void get_descendants(node start, svector & descendants) {}; /** \brief Get all ancestors of a node including itself */ - virtual void get_ancestors(node start, svector& ancestors) = 0; - virtual node get_common_ancestor(node u, node v) = 0; - virtual void update(node p, node q, node u, node v) = 0; - virtual bool check_well_formed() = 0; + virtual void get_ancestors(node start, svector & ancestors) {}; + + virtual node get_common_ancestor(node u, node v) {UNREACHABLE(); return -1;}; + virtual void update(node p, node q, node u, node v) {}; + virtual bool check_well_formed() {UNREACHABLE(); return false;}; // TODO: remove these two unnatural functions - virtual bool get_arc_direction(node start) const = 0; - virtual node get_parent(node start) = 0; + virtual bool get_arc_direction(node start) const {UNREACHABLE(); return false;}; + virtual node get_parent(node start) {UNREACHABLE(); return -1;}; }; } diff --git a/src/smt/spanning_tree.cpp b/src/smt/spanning_tree_def.h similarity index 85% rename from src/smt/spanning_tree.cpp rename to src/smt/spanning_tree_def.h index 21c3fcfb4..aac40e1b7 100644 --- a/src/smt/spanning_tree.cpp +++ b/src/smt/spanning_tree_def.h @@ -3,7 +3,7 @@ Copyright (c) 2013 Microsoft Corporation Module Name: - spanning_tree.cpp + spanning_tree_def.h Abstract: @@ -15,16 +15,13 @@ Author: Notes: --*/ -#include + +#ifndef _SPANNING_TREE_DEF_H_ +#define _SPANNING_TREE_DEF_H_ + #include "spanning_tree.h" -#include "debug.h" -#include "vector.h" -#include "uint_set.h" -#include "trace.h" namespace smt { - - /** swap v and q in tree. - fixup m_thread @@ -41,7 +38,8 @@ namespace smt { New thread: prev -> q -*-> final(q) -> v -*-> alpha -> beta -*-> final(v) -> next */ - void thread_spanning_tree::swap_order(node q, node v) { + template + void thread_spanning_tree::swap_order(node q, node v) { SASSERT(q != v); SASSERT(m_pred[q] == v); SASSERT(is_preorder_traversal(v, get_final(v))); @@ -68,7 +66,8 @@ namespace smt { /** \brief find node that points to 'n' in m_thread */ - node thread_spanning_tree::find_rev_thread(node n) const { + template + typename thread_spanning_tree::node thread_spanning_tree::find_rev_thread(node n) const { node ancestor = m_pred[n]; SASSERT(ancestor != -1); while (m_thread[ancestor] != n) { @@ -77,7 +76,8 @@ namespace smt { return ancestor; } - void thread_spanning_tree::fix_depth(node start, node end) { + template + void thread_spanning_tree::fix_depth(node start, node end) { SASSERT(m_pred[start] != -1); m_depth[start] = m_depth[m_pred[start]]+1; while (start != end) { @@ -86,7 +86,8 @@ namespace smt { } } - node thread_spanning_tree::get_final(int start) { + template + typename thread_spanning_tree::node thread_spanning_tree::get_final(int start) { int n = start; while (m_depth[m_thread[n]] > m_depth[start]) { n = m_thread[n]; @@ -94,7 +95,8 @@ namespace smt { return n; } - bool thread_spanning_tree::is_preorder_traversal(node start, node end) { + template + bool thread_spanning_tree::is_preorder_traversal(node start, node end) { // get children of start uint_set children; children.insert(start); @@ -118,7 +120,8 @@ namespace smt { return true; } - bool thread_spanning_tree::is_ancestor_of(node ancestor, node child) { + template + bool thread_spanning_tree::is_ancestor_of(node ancestor, node child) { for (node n = child; n != -1; n = m_pred[n]) { if (n == ancestor) { return true; @@ -154,7 +157,8 @@ namespace smt { roots[y] = x; } - void thread_spanning_tree::initialize(svector const & upwards, int num_nodes) { + template + void thread_spanning_tree::initialize(svector const & upwards, int num_nodes) { m_pred.resize(num_nodes + 1); m_depth.resize(num_nodes + 1); m_thread.resize(num_nodes + 1); @@ -179,7 +183,8 @@ namespace smt { }); } - node thread_spanning_tree::get_common_ancestor(node u, node v) { + template + typename thread_spanning_tree::node thread_spanning_tree::get_common_ancestor(node u, node v) { while (u != v) { if (m_depth[u] > m_depth[v]) u = m_pred[u]; @@ -189,7 +194,8 @@ namespace smt { return u; } - void thread_spanning_tree::get_descendants(node start, svector& descendants) { + template + void thread_spanning_tree::get_descendants(node start, svector& descendants) { descendants.reset(); node u = start; while (m_depth[m_thread[u]] > m_depth[start]) { @@ -198,7 +204,8 @@ namespace smt { } } - void thread_spanning_tree::get_ancestors(node start, svector& ancestors) { + template + void thread_spanning_tree::get_ancestors(node start, svector& ancestors) { ancestors.reset(); while (m_pred[start] != -1) { ancestors.push_back(start); @@ -222,7 +229,8 @@ namespace smt { \ \ / q q */ - void thread_spanning_tree::update(node p, node q, node u, node v) { + template + void thread_spanning_tree::update(node p, node q, node u, node v) { bool q_upwards = false; // v is parent of u so T_u does not contain root node @@ -299,7 +307,8 @@ namespace smt { m_upwards direction of edge from i to m_pred[i] m_graph */ - bool thread_spanning_tree::check_well_formed() { + template + bool thread_spanning_tree::check_well_formed() { node root = m_pred.size()-1; // Check that m_thread traverses each node. @@ -346,11 +355,15 @@ namespace smt { return true; } - bool thread_spanning_tree::get_arc_direction(node start) const { + template + bool thread_spanning_tree::get_arc_direction(node start) const { return m_upwards[start]; } - node thread_spanning_tree::get_parent(node start) { + template + typename thread_spanning_tree::node thread_spanning_tree::get_parent(node start) { return m_pred[start]; } -} \ No newline at end of file +} + +#endif \ No newline at end of file