/*++ Copyright (c) 2006 Microsoft Corporation Module Name: smt_cg_table.h Abstract: Author: Leonardo de Moura (leonardo) 2008-02-19. Revision History: --*/ #ifndef _SMT_CG_TABLE_H_ #define _SMT_CG_TABLE_H_ #include"smt_enode.h" #include"hashtable.h" #include"chashtable.h" namespace smt { typedef std::pair enode_bool_pair; #if 0 /** \brief Congruence table. */ class cg_table { struct cg_khasher { unsigned operator()(enode const * n) const { return n->get_decl_id(); } }; struct cg_chasher { unsigned operator()(enode const * n, unsigned idx) const { return n->get_arg(idx)->get_root()->hash(); } }; struct cg_hash { cg_khasher m_khasher; cg_chasher m_chasher; public: unsigned operator()(enode * n) const; }; struct cg_eq { bool & m_commutativity; cg_eq(bool & comm):m_commutativity(comm) {} bool operator()(enode * n1, enode * n2) const; }; typedef ptr_hashtable table; ast_manager & m_manager; bool m_commutativity; //!< true if the last found congruence used commutativity table m_table; public: cg_table(ast_manager & m); /** \brief Try to insert n into the table. If the table already contains an element n' congruent to n, then do nothing and return n' and a boolean indicating whether n and n' are congruence modulo commutativity, otherwise insert n and return (n,false). */ enode_bool_pair insert(enode * n) { // it doesn't make sense to insert a constant. SASSERT(n->get_num_args() > 0); m_commutativity = false; enode * n_prime = m_table.insert_if_not_there(n); SASSERT(contains(n)); return enode_bool_pair(n_prime, m_commutativity); } void erase(enode * n) { SASSERT(n->get_num_args() > 0); m_table.erase(n); SASSERT(!contains(n)); } bool contains(enode * n) const { return m_table.contains(n); } enode * find(enode * n) const { enode * r = 0; return m_table.find(n, r) ? r : 0; } bool contains_ptr(enode * n) const { enode * n_prime; return m_table.find(n, n_prime) && n == n_prime; } void reset() { m_table.reset(); } void display(std::ostream & out) const; void display_compact(std::ostream & out) const; #ifdef Z3DEBUG bool check_invariant() const; #endif }; #else // one table per function symbol /** \brief Congruence table. */ class cg_table { struct cg_unary_hash { unsigned operator()(enode * n) const { SASSERT(n->get_num_args() == 1); return n->get_arg(0)->get_root()->hash(); } }; struct cg_unary_eq { bool operator()(enode * n1, enode * n2) const { SASSERT(n1->get_num_args() == 1); SASSERT(n2->get_num_args() == 1); SASSERT(n1->get_decl() == n2->get_decl()); return n1->get_arg(0)->get_root() == n2->get_arg(0)->get_root(); } }; typedef chashtable unary_table; struct cg_binary_hash { unsigned operator()(enode * n) const { SASSERT(n->get_num_args() == 2); // too many collisions // unsigned r = 17 + n->get_arg(0)->get_root()->hash(); // return r * 31 + n->get_arg(1)->get_root()->hash(); return combine_hash(n->get_arg(0)->get_root()->hash(), n->get_arg(1)->get_root()->hash()); } }; struct cg_binary_eq { bool operator()(enode * n1, enode * n2) const { SASSERT(n1->get_num_args() == 2); SASSERT(n2->get_num_args() == 2); SASSERT(n1->get_decl() == n2->get_decl()); #if 1 return n1->get_arg(0)->get_root() == n2->get_arg(0)->get_root() && n1->get_arg(1)->get_root() == n2->get_arg(1)->get_root(); #else bool r = n1->get_arg(0)->get_root() == n2->get_arg(0)->get_root() && n1->get_arg(1)->get_root() == n2->get_arg(1)->get_root(); static unsigned counter = 0; static unsigned failed = 0; if (!r) failed++; counter++; if (counter % 100000 == 0) std::cerr << "[cg_eq] " << counter << " " << failed << "\n"; return r; #endif } }; typedef chashtable binary_table; struct cg_comm_hash { unsigned operator()(enode * n) const { SASSERT(n->get_num_args() == 2); unsigned h1 = n->get_arg(0)->get_root()->hash(); unsigned h2 = n->get_arg(1)->get_root()->hash(); if (h1 > h2) std::swap(h1, h2); return hash_u((h1 << 16) | (h2 & 0xFFFF)); } }; struct cg_comm_eq { bool & m_commutativity; cg_comm_eq(bool & c):m_commutativity(c) {} bool operator()(enode * n1, enode * n2) const { SASSERT(n1->get_num_args() == 2); SASSERT(n2->get_num_args() == 2); SASSERT(n1->get_decl() == n2->get_decl()); enode * c1_1 = n1->get_arg(0)->get_root(); enode * c1_2 = n1->get_arg(1)->get_root(); enode * c2_1 = n2->get_arg(0)->get_root(); enode * c2_2 = n2->get_arg(1)->get_root(); if (c1_1 == c2_1 && c1_2 == c2_2) { return true; } if (c1_1 == c2_2 && c1_2 == c2_1) { m_commutativity = true; return true; } return false; } }; typedef chashtable comm_table; struct cg_hash { unsigned operator()(enode * n) const; }; struct cg_eq { bool operator()(enode * n1, enode * n2) const; }; typedef chashtable table; ast_manager & m_manager; bool m_commutativity; //!< true if the last found congruence used commutativity ptr_vector m_tables; obj_map m_func_decl2id; enum table_kind { UNARY, BINARY, BINARY_COMM, NARY }; void * mk_table_for(func_decl * d); unsigned set_func_decl_id(enode * n); void * get_table(enode * n) { unsigned tid = n->get_func_decl_id(); if (tid == UINT_MAX) tid = set_func_decl_id(n); SASSERT(tid < m_tables.size()); return m_tables[tid]; } public: cg_table(ast_manager & m); ~cg_table(); /** \brief Try to insert n into the table. If the table already contains an element n' congruent to n, then do nothing and return n' and a boolean indicating whether n and n' are congruence modulo commutativity, otherwise insert n and return (n,false). */ enode_bool_pair insert(enode * n) { // it doesn't make sense to insert a constant. SASSERT(n->get_num_args() > 0); enode * n_prime; void * t = get_table(n); switch (static_cast(GET_TAG(t))) { case UNARY: n_prime = UNTAG(unary_table*, t)->insert_if_not_there(n); return enode_bool_pair(n_prime, false); case BINARY: n_prime = UNTAG(binary_table*, t)->insert_if_not_there(n); return enode_bool_pair(n_prime, false); case BINARY_COMM: m_commutativity = false; n_prime = UNTAG(comm_table*, t)->insert_if_not_there(n); return enode_bool_pair(n_prime, m_commutativity); default: n_prime = UNTAG(table*, t)->insert_if_not_there(n); return enode_bool_pair(n_prime, false); } } void erase(enode * n) { SASSERT(n->get_num_args() > 0); void * t = get_table(n); switch (static_cast(GET_TAG(t))) { case UNARY: UNTAG(unary_table*, t)->erase(n); break; case BINARY: UNTAG(binary_table*, t)->erase(n); break; case BINARY_COMM: UNTAG(comm_table*, t)->erase(n); break; default: UNTAG(table*, t)->erase(n); break; } } bool contains(enode * n) const { SASSERT(n->get_num_args() > 0); void * t = const_cast(this)->get_table(n); switch (static_cast(GET_TAG(t))) { case UNARY: return UNTAG(unary_table*, t)->contains(n); case BINARY: return UNTAG(binary_table*, t)->contains(n); case BINARY_COMM: return UNTAG(comm_table*, t)->contains(n); default: return UNTAG(table*, t)->contains(n); } } enode * find(enode * n) const { SASSERT(n->get_num_args() > 0); enode * r = 0; void * t = const_cast(this)->get_table(n); switch (static_cast(GET_TAG(t))) { case UNARY: return UNTAG(unary_table*, t)->find(n, r) ? r : 0; case BINARY: return UNTAG(binary_table*, t)->find(n, r) ? r : 0; case BINARY_COMM: return UNTAG(comm_table*, t)->find(n, r) ? r : 0; default: return UNTAG(table*, t)->find(n, r) ? r : 0; } } bool contains_ptr(enode * n) const { enode * r; SASSERT(n->get_num_args() > 0); void * t = const_cast(this)->get_table(n); switch (static_cast(GET_TAG(t))) { case UNARY: return UNTAG(unary_table*, t)->find(n, r) && n == r; case BINARY: return UNTAG(binary_table*, t)->find(n, r) && n == r; case BINARY_COMM: return UNTAG(comm_table*, t)->find(n, r) && n == r; default: return UNTAG(table*, t)->find(n, r) && n == r; } } void reset(); void display(std::ostream & out) const; void display_compact(std::ostream & out) const; #ifdef Z3DEBUG bool check_invariant() const; #endif }; #endif }; #endif /* _SMT_CG_TABLE_H_ */