3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-19 04:13:38 +00:00

Refactor pivot rules

This commit is contained in:
Anh-Dung Phan 2013-11-21 19:05:17 -08:00
parent 97dfb6d521
commit 3b2dd47cd4
5 changed files with 47 additions and 29 deletions

View file

@ -961,7 +961,6 @@ public:
for (; it != end; ++it) { for (; it != end; ++it) {
edge_id e_id = *it; edge_id e_id = *it;
edge & e = m_edges[e_id]; edge & e = m_edges[e_id];
if (!e.is_enabled()) continue;
SASSERT(e.get_source() == current); SASSERT(e.get_source() == current);
dl_var neighbour = e.get_target(); dl_var neighbour = e.get_target();
neighbours.push_back(neighbour); neighbours.push_back(neighbour);
@ -972,7 +971,6 @@ public:
for (; it != end; ++it) { for (; it != end; ++it) {
edge_id e_id = *it; edge_id e_id = *it;
edge & e = m_edges[e_id]; edge & e = m_edges[e_id];
if (!e.is_enabled()) continue;
SASSERT(e.get_target() == current); SASSERT(e.get_target() == current);
dl_var neighbour = e.get_source(); dl_var neighbour = e.get_source();
neighbours.push_back(neighbour); neighbours.push_back(neighbour);
@ -982,27 +980,41 @@ public:
void dfs_undirected(dl_var start, svector<dl_var> & threads) { void dfs_undirected(dl_var start, svector<dl_var> & threads) {
threads.reset(); threads.reset();
threads.resize(get_num_nodes()); threads.resize(get_num_nodes());
uint_set visited; uint_set discovered, explored;
svector<dl_var> nodes; svector<dl_var> nodes;
discovered.insert(start);
nodes.push_back(start); nodes.push_back(start);
dl_var prev = -1; dl_var prev = -1;
while(!nodes.empty()) { while(!nodes.empty()) {
dl_var current = nodes.back(); dl_var current = nodes.back();
nodes.pop_back(); SASSERT(discovered.contains(current) && !explored.contains(current));
visited.insert(current); std::cout << "thread[" << prev << "] --> " << current << std::endl;
if (prev != -1) if (prev != -1) {
threads[prev] = current; threads[prev] = current;
std::cout << "thread[" << prev << "] --> " << current << std::endl;
}
prev = current; prev = current;
svector<dl_var> neighbours; svector<dl_var> neighbours;
get_neighbours_undirected(current, neighbours); get_neighbours_undirected(current, neighbours);
SASSERT(!neighbours.empty());
bool found = false;
for (unsigned i = 0; i < neighbours.size(); ++i) { for (unsigned i = 0; i < neighbours.size(); ++i) {
dl_var next = neighbours[i];
DEBUG_CODE( DEBUG_CODE(
edge_id id; edge_id id;
SASSERT(prev == -1 || get_edge_id(prev, current, id) || get_edge_id(current, prev, id));); SASSERT(get_edge_id(current, next, id) || get_edge_id(next, current, id)););
if (!visited.contains(neighbours[i])) { if (!discovered.contains(next) && !explored.contains(next)) {
nodes.push_back(neighbours[i]); discovered.insert(next);
nodes.push_back(next);
found = true;
break;
} }
} }
SASSERT(!nodes.empty());
if (!found) {
explored.insert(current);
nodes.pop_back();
}
} }
threads[prev] = start; threads[prev] = start;
} }
@ -1022,8 +1034,10 @@ public:
SASSERT(visited.contains(current)); SASSERT(visited.contains(current));
svector<dl_var> neighbours; svector<dl_var> neighbours;
get_neighbours_undirected(current, neighbours); get_neighbours_undirected(current, neighbours);
SASSERT(!neighbours.empty());
for (unsigned i = 0; i < neighbours.size(); ++i) { for (unsigned i = 0; i < neighbours.size(); ++i) {
dl_var & next = neighbours[i]; dl_var next = neighbours[i];
std::cout << "parents[" << next << "] --> " << current << std::endl;
DEBUG_CODE( DEBUG_CODE(
edge_id id; edge_id id;
SASSERT(get_edge_id(current, next, id) || get_edge_id(next, current, id));); SASSERT(get_edge_id(current, next, id) || get_edge_id(next, current, id)););

View file

@ -82,7 +82,7 @@ namespace smt {
bool choose_entering_edge() {return false;}; bool choose_entering_edge() {return false;};
}; };
class first_eligible_pivot : pivot_rule_impl { class first_eligible_pivot : public pivot_rule_impl {
private: private:
edge_id m_next_edge; edge_id m_next_edge;
@ -117,7 +117,7 @@ namespace smt {
}; };
}; };
class best_eligible_pivot : pivot_rule_impl { class best_eligible_pivot : public pivot_rule_impl {
public: public:
best_eligible_pivot(graph & g, vector<numeral> & potentials, best_eligible_pivot(graph & g, vector<numeral> & potentials,
svector<edge_state> & states, edge_id & enter_id) : svector<edge_state> & states, edge_id & enter_id) :
@ -152,7 +152,7 @@ namespace smt {
}; };
}; };
class candidate_list_pivot : pivot_rule_impl { class candidate_list_pivot : public pivot_rule_impl {
private: private:
edge_id m_next_edge; edge_id m_next_edge;
svector<edge_id> m_candidates; svector<edge_id> m_candidates;

View file

@ -170,18 +170,21 @@ namespace smt {
// FIXME: should declare pivot as a pivot_rule_impl and refactor // FIXME: should declare pivot as a pivot_rule_impl and refactor
template<typename Ext> template<typename Ext>
bool network_flow<Ext>::choose_entering_edge(pivot_rule pr) { bool network_flow<Ext>::choose_entering_edge(pivot_rule pr) {
if (pr == FIRST_ELIGIBLE) { pivot_rule_impl * pivot;
first_eligible_pivot pivot(m_graph, m_potentials, m_states, m_enter_id); switch (pr) {
return pivot.choose_entering_edge(); case FIRST_ELIGIBLE:
pivot = alloc(first_eligible_pivot, m_graph, m_potentials, m_states, m_enter_id);
break;
case BEST_ELIGIBLE:
pivot = alloc(best_eligible_pivot, m_graph, m_potentials, m_states, m_enter_id);
break;
case CANDIDATE_LIST:
pivot = alloc(best_eligible_pivot, m_graph, m_potentials, m_states, m_enter_id);
break;
default:
UNREACHABLE();
} }
else if (pr == BEST_ELIGIBLE) { return pivot->choose_entering_edge();
best_eligible_pivot pivot(m_graph, m_potentials, m_states, m_enter_id);
return pivot.choose_entering_edge();
}
else {
candidate_list_pivot pivot(m_graph, m_potentials, m_states, m_enter_id);
return pivot.choose_entering_edge();
}
} }
// Minimize cost flows // Minimize cost flows

View file

@ -74,6 +74,7 @@ namespace smt {
private: private:
graph * m_tree_graph; graph * m_tree_graph;
public: public:
basic_spanning_tree(graph & g); basic_spanning_tree(graph & g);
void initialize(svector<edge_id> const & tree); void initialize(svector<edge_id> const & tree);

View file

@ -420,14 +420,15 @@ namespace smt {
template<typename Ext> template<typename Ext>
void basic_spanning_tree<Ext>::initialize(svector<edge_id> const & tree) { void basic_spanning_tree<Ext>::initialize(svector<edge_id> const & tree) {
unsigned num_nodes = m_graph.get_num_nodes();
m_tree_graph = alloc(graph); m_tree_graph = alloc(graph);
m_tree = tree;
unsigned num_nodes = m_graph.get_num_nodes();
for (unsigned i = 0; i < num_nodes; ++i) { for (unsigned i = 0; i < num_nodes; ++i) {
m_tree_graph->init_var(i); m_tree_graph->init_var(i);
} }
vector<edge> const & es = m_graph.get_all_edges(); vector<edge> const & es = m_graph.get_all_edges();
svector<edge_id>::const_iterator it = tree.begin(), end = tree.end(); svector<edge_id>::const_iterator it = m_tree.begin(), end = m_tree.end();
for(; it != end; ++it) { for(; it != end; ++it) {
edge const & e = es[*it]; edge const & e = es[*it];
m_tree_graph->add_edge(e.get_source(), e.get_target(), e.get_weight(), explanation()); m_tree_graph->add_edge(e.get_source(), e.get_target(), e.get_weight(), explanation());
@ -440,8 +441,7 @@ namespace smt {
template<typename Ext> template<typename Ext>
void basic_spanning_tree<Ext>::update(edge_id enter_id, edge_id leave_id) { void basic_spanning_tree<Ext>::update(edge_id enter_id, edge_id leave_id) {
if (m_tree_graph) if (m_tree_graph) dealloc(m_tree_graph);
dealloc(m_tree_graph);
m_tree_graph = alloc(graph); m_tree_graph = alloc(graph);
unsigned num_nodes = m_graph.get_num_nodes(); unsigned num_nodes = m_graph.get_num_nodes();
for (unsigned i = 0; i < num_nodes; ++i) { for (unsigned i = 0; i < num_nodes; ++i) {