/*++ Copyright (c) 2006 Microsoft Corporation Module Name: union_find.h Abstract: Author: Leonardo de Moura (leonardo) 2008-05-31. Revision History: --*/ #ifndef _UNION_FIND_H_ #define _UNION_FIND_H_ #include "trail.h" class union_find_default_ctx { public: typedef trail_stack _trail_stack; union_find_default_ctx() : m_stack(*this) {} void unmerge_eh(unsigned, unsigned) {} void merge_eh(unsigned, unsigned, unsigned, unsigned) {} void after_merge_eh(unsigned, unsigned, unsigned, unsigned) {} _trail_stack& get_trail_stack() { return m_stack; } private: _trail_stack m_stack; }; template class union_find { Ctx & m_ctx; trail_stack & m_trail_stack; svector m_find; svector m_size; svector m_next; class mk_var_trail; friend class mk_var_trail; class mk_var_trail : public trail { union_find & m_owner; public: mk_var_trail(union_find & o):m_owner(o) {} virtual ~mk_var_trail() {} virtual void undo(Ctx & ctx) { m_owner.m_find.pop_back(); m_owner.m_size.pop_back(); m_owner.m_next.pop_back(); } }; mk_var_trail m_mk_var_trail; class merge_trail; friend class merge_trail; class merge_trail : public trail { union_find & m_owner; unsigned m_r1; public: merge_trail(union_find & o, unsigned r1):m_owner(o), m_r1(r1) {} virtual ~merge_trail() {} virtual void undo(Ctx & ctx) { m_owner.unmerge(m_r1); } }; void unmerge(unsigned r1) { unsigned r2 = m_find[r1]; TRACE("union_find", tout << "unmerging " << r1 << " " << r2 << "\n";); SASSERT(find(r2) == r2); m_size[r2] -= m_size[r1]; m_find[r1] = r1; std::swap(m_next[r1], m_next[r2]); m_ctx.unmerge_eh(r2, r1); CASSERT("union_find", check_invariant()); } public: union_find(Ctx & ctx):m_ctx(ctx), m_trail_stack(ctx.get_trail_stack()), m_mk_var_trail(*this) {} unsigned mk_var() { unsigned r = m_find.size(); m_find.push_back(r); m_size.push_back(1); m_next.push_back(r); m_trail_stack.push_ptr(&m_mk_var_trail); return r; } unsigned get_num_vars() const { return m_find.size(); } unsigned find(unsigned v) const { while (true) { unsigned new_v = m_find[v]; if (new_v == v) return v; v = new_v; } } unsigned next(unsigned v) const { return m_next[v]; } bool is_root(unsigned v) const { return m_find[v] == v; } void merge(unsigned v1, unsigned v2) { unsigned r1 = find(v1); unsigned r2 = find(v2); TRACE("union_find", tout << "merging " << r1 << " " << r2 << "\n";); if (r1 == r2) return; if (m_size[r1] > m_size[r2]) std::swap(r1, r2); m_ctx.merge_eh(r2, r1, v2, v1); m_find[r1] = r2; m_size[r2] += m_size[r1]; std::swap(m_next[r1], m_next[r2]); m_trail_stack.push(merge_trail(*this, r1)); m_ctx.after_merge_eh(r2, r1, v2, v1); CASSERT("union_find", check_invariant()); } void display(std::ostream & out) const { unsigned num = get_num_vars(); for (unsigned v = 0; v < num; v++) { out << "v" << v << " --> v" << m_find[v] << "\n"; } } #ifdef Z3DEBUG bool check_invariant() const { unsigned num = get_num_vars(); for (unsigned v = 0; v < num; v++) { if (is_root(v)) { unsigned curr = v; unsigned sz = 0; do { SASSERT(find(curr) == v); sz++; curr = next(curr); } while (curr != v); SASSERT(m_size[v] == sz); } } return true; } #endif }; #endif /* _UNION_FIND_H_ */