mirror of
https://github.com/Z3Prover/z3
synced 2025-04-08 18:31:49 +00:00
355 lines
11 KiB
C++
355 lines
11 KiB
C++
/*++
|
|
Copyright (c) 2006 Microsoft Corporation
|
|
|
|
Module Name:
|
|
|
|
smt_cg_table.h
|
|
|
|
Abstract:
|
|
|
|
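Congruence table: detects enodes that are congruent modulo the current equivalence classes (optionally modulo commutativity of binary symbols).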
|
|
|
|
Author:
|
|
|
|
Leonardo de Moura (leonardo) 2008-02-19.
|
|
|
|
Revision History:
|
|
|
|
--*/
|
|
#ifndef _SMT_CG_TABLE_H_
|
|
#define _SMT_CG_TABLE_H_
|
|
|
|
#include"smt_enode.h"
|
|
#include"hashtable.h"
|
|
#include"chashtable.h"
|
|
|
|
namespace smt {
|
|
|
|
// Result of cg_table::insert: the enode that ends up in the table (either the
// argument itself or a previously inserted congruent enode), paired with a flag
// that is true when the congruence was found only modulo commutativity.
typedef std::pair<enode*, bool> enode_bool_pair;
|
|
|
|
#if 0
|
|
/**
|
|
\brief Congruence table.
|
|
*/
|
|
class cg_table {
|
|
struct cg_khasher {
|
|
unsigned operator()(enode const * n) const { return n->get_decl_id(); }
|
|
};
|
|
|
|
struct cg_chasher {
|
|
unsigned operator()(enode const * n, unsigned idx) const {
|
|
return n->get_arg(idx)->get_root()->hash();
|
|
}
|
|
};
|
|
|
|
struct cg_hash {
|
|
cg_khasher m_khasher;
|
|
cg_chasher m_chasher;
|
|
public:
|
|
unsigned operator()(enode * n) const;
|
|
};
|
|
|
|
struct cg_eq {
|
|
bool & m_commutativity;
|
|
cg_eq(bool & comm):m_commutativity(comm) {}
|
|
bool operator()(enode * n1, enode * n2) const;
|
|
};
|
|
|
|
typedef ptr_hashtable<enode, cg_hash, cg_eq> table;
|
|
|
|
ast_manager & m_manager;
|
|
bool m_commutativity; //!< true if the last found congruence used commutativity
|
|
table m_table;
|
|
public:
|
|
cg_table(ast_manager & m);
|
|
|
|
/**
|
|
\brief Try to insert n into the table. If the table already
|
|
contains an element n' congruent to n, then do nothing and
|
|
return n' and a boolean indicating whether n and n' are congruence
|
|
modulo commutativity, otherwise insert n and return (n,false).
|
|
*/
|
|
enode_bool_pair insert(enode * n) {
|
|
// it doesn't make sense to insert a constant.
|
|
SASSERT(n->get_num_args() > 0);
|
|
m_commutativity = false;
|
|
enode * n_prime = m_table.insert_if_not_there(n);
|
|
SASSERT(contains(n));
|
|
return enode_bool_pair(n_prime, m_commutativity);
|
|
}
|
|
|
|
void erase(enode * n) {
|
|
SASSERT(n->get_num_args() > 0);
|
|
m_table.erase(n);
|
|
SASSERT(!contains(n));
|
|
}
|
|
|
|
bool contains(enode * n) const {
|
|
return m_table.contains(n);
|
|
}
|
|
|
|
enode * find(enode * n) const {
|
|
enode * r = 0;
|
|
return m_table.find(n, r) ? r : 0;
|
|
}
|
|
|
|
bool contains_ptr(enode * n) const {
|
|
enode * n_prime;
|
|
return m_table.find(n, n_prime) && n == n_prime;
|
|
}
|
|
|
|
void reset() {
|
|
m_table.reset();
|
|
}
|
|
|
|
void display(std::ostream & out) const;
|
|
|
|
void display_compact(std::ostream & out) const;
|
|
#ifdef Z3DEBUG
|
|
bool check_invariant() const;
|
|
#endif
|
|
};
|
|
#else
|
|
// one table per function symbol
|
|
|
|
/**
|
|
\brief Congruence table.
|
|
*/
|
|
class cg_table {
|
|
struct cg_unary_hash {
|
|
unsigned operator()(enode * n) const {
|
|
SASSERT(n->get_num_args() == 1);
|
|
return n->get_arg(0)->get_root()->hash();
|
|
}
|
|
};
|
|
|
|
struct cg_unary_eq {
|
|
bool operator()(enode * n1, enode * n2) const {
|
|
SASSERT(n1->get_num_args() == 1);
|
|
SASSERT(n2->get_num_args() == 1);
|
|
SASSERT(n1->get_decl() == n2->get_decl());
|
|
return n1->get_arg(0)->get_root() == n2->get_arg(0)->get_root();
|
|
}
|
|
};
|
|
|
|
typedef chashtable<enode *, cg_unary_hash, cg_unary_eq> unary_table;
|
|
|
|
struct cg_binary_hash {
|
|
unsigned operator()(enode * n) const {
|
|
SASSERT(n->get_num_args() == 2);
|
|
// too many collisions
|
|
// unsigned r = 17 + n->get_arg(0)->get_root()->hash();
|
|
// return r * 31 + n->get_arg(1)->get_root()->hash();
|
|
return combine_hash(n->get_arg(0)->get_root()->hash(), n->get_arg(1)->get_root()->hash());
|
|
}
|
|
};
|
|
|
|
struct cg_binary_eq {
|
|
bool operator()(enode * n1, enode * n2) const {
|
|
SASSERT(n1->get_num_args() == 2);
|
|
SASSERT(n2->get_num_args() == 2);
|
|
SASSERT(n1->get_decl() == n2->get_decl());
|
|
#if 1
|
|
return
|
|
n1->get_arg(0)->get_root() == n2->get_arg(0)->get_root() &&
|
|
n1->get_arg(1)->get_root() == n2->get_arg(1)->get_root();
|
|
#else
|
|
bool r =
|
|
n1->get_arg(0)->get_root() == n2->get_arg(0)->get_root() &&
|
|
n1->get_arg(1)->get_root() == n2->get_arg(1)->get_root();
|
|
static unsigned counter = 0;
|
|
static unsigned failed = 0;
|
|
if (!r)
|
|
failed++;
|
|
counter++;
|
|
if (counter % 100000 == 0)
|
|
std::cerr << "[cg_eq] " << counter << " " << failed << "\n";
|
|
return r;
|
|
#endif
|
|
}
|
|
};
|
|
|
|
typedef chashtable<enode*, cg_binary_hash, cg_binary_eq> binary_table;
|
|
|
|
struct cg_comm_hash {
|
|
unsigned operator()(enode * n) const {
|
|
SASSERT(n->get_num_args() == 2);
|
|
unsigned h1 = n->get_arg(0)->get_root()->hash();
|
|
unsigned h2 = n->get_arg(1)->get_root()->hash();
|
|
if (h1 > h2)
|
|
std::swap(h1, h2);
|
|
return hash_u((h1 << 16) | (h2 & 0xFFFF));
|
|
}
|
|
};
|
|
|
|
struct cg_comm_eq {
|
|
bool & m_commutativity;
|
|
cg_comm_eq(bool & c):m_commutativity(c) {}
|
|
bool operator()(enode * n1, enode * n2) const {
|
|
SASSERT(n1->get_num_args() == 2);
|
|
SASSERT(n2->get_num_args() == 2);
|
|
SASSERT(n1->get_decl() == n2->get_decl());
|
|
enode * c1_1 = n1->get_arg(0)->get_root();
|
|
enode * c1_2 = n1->get_arg(1)->get_root();
|
|
enode * c2_1 = n2->get_arg(0)->get_root();
|
|
enode * c2_2 = n2->get_arg(1)->get_root();
|
|
if (c1_1 == c2_1 && c1_2 == c2_2) {
|
|
return true;
|
|
}
|
|
if (c1_1 == c2_2 && c1_2 == c2_1) {
|
|
m_commutativity = true;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
};
|
|
|
|
typedef chashtable<enode*, cg_comm_hash, cg_comm_eq> comm_table;
|
|
|
|
struct cg_hash {
|
|
unsigned operator()(enode * n) const;
|
|
};
|
|
|
|
struct cg_eq {
|
|
bool operator()(enode * n1, enode * n2) const;
|
|
};
|
|
|
|
typedef chashtable<enode*, cg_hash, cg_eq> table;
|
|
|
|
ast_manager & m_manager;
|
|
bool m_commutativity; //!< true if the last found congruence used commutativity
|
|
ptr_vector<void> m_tables;
|
|
obj_map<func_decl, unsigned> m_func_decl2id;
|
|
|
|
enum table_kind {
|
|
UNARY,
|
|
BINARY,
|
|
BINARY_COMM,
|
|
NARY
|
|
};
|
|
|
|
void * mk_table_for(func_decl * d);
|
|
unsigned set_func_decl_id(enode * n);
|
|
|
|
void * get_table(enode * n) {
|
|
unsigned tid = n->get_func_decl_id();
|
|
if (tid == UINT_MAX)
|
|
tid = set_func_decl_id(n);
|
|
SASSERT(tid < m_tables.size());
|
|
return m_tables[tid];
|
|
}
|
|
|
|
public:
|
|
cg_table(ast_manager & m);
|
|
~cg_table();
|
|
|
|
/**
|
|
\brief Try to insert n into the table. If the table already
|
|
contains an element n' congruent to n, then do nothing and
|
|
return n' and a boolean indicating whether n and n' are congruence
|
|
modulo commutativity, otherwise insert n and return (n,false).
|
|
*/
|
|
enode_bool_pair insert(enode * n) {
|
|
// it doesn't make sense to insert a constant.
|
|
SASSERT(n->get_num_args() > 0);
|
|
enode * n_prime;
|
|
void * t = get_table(n);
|
|
switch (static_cast<table_kind>(GET_TAG(t))) {
|
|
case UNARY:
|
|
n_prime = UNTAG(unary_table*, t)->insert_if_not_there(n);
|
|
return enode_bool_pair(n_prime, false);
|
|
case BINARY:
|
|
n_prime = UNTAG(binary_table*, t)->insert_if_not_there(n);
|
|
return enode_bool_pair(n_prime, false);
|
|
case BINARY_COMM:
|
|
m_commutativity = false;
|
|
n_prime = UNTAG(comm_table*, t)->insert_if_not_there(n);
|
|
return enode_bool_pair(n_prime, m_commutativity);
|
|
default:
|
|
n_prime = UNTAG(table*, t)->insert_if_not_there(n);
|
|
return enode_bool_pair(n_prime, false);
|
|
}
|
|
}
|
|
|
|
void erase(enode * n) {
|
|
SASSERT(n->get_num_args() > 0);
|
|
void * t = get_table(n);
|
|
switch (static_cast<table_kind>(GET_TAG(t))) {
|
|
case UNARY:
|
|
UNTAG(unary_table*, t)->erase(n);
|
|
break;
|
|
case BINARY:
|
|
UNTAG(binary_table*, t)->erase(n);
|
|
break;
|
|
case BINARY_COMM:
|
|
UNTAG(comm_table*, t)->erase(n);
|
|
break;
|
|
default:
|
|
UNTAG(table*, t)->erase(n);
|
|
break;
|
|
}
|
|
}
|
|
|
|
bool contains(enode * n) const {
|
|
SASSERT(n->get_num_args() > 0);
|
|
void * t = const_cast<cg_table*>(this)->get_table(n);
|
|
switch (static_cast<table_kind>(GET_TAG(t))) {
|
|
case UNARY:
|
|
return UNTAG(unary_table*, t)->contains(n);
|
|
case BINARY:
|
|
return UNTAG(binary_table*, t)->contains(n);
|
|
case BINARY_COMM:
|
|
return UNTAG(comm_table*, t)->contains(n);
|
|
default:
|
|
return UNTAG(table*, t)->contains(n);
|
|
}
|
|
}
|
|
|
|
enode * find(enode * n) const {
|
|
SASSERT(n->get_num_args() > 0);
|
|
enode * r = 0;
|
|
void * t = const_cast<cg_table*>(this)->get_table(n);
|
|
switch (static_cast<table_kind>(GET_TAG(t))) {
|
|
case UNARY:
|
|
return UNTAG(unary_table*, t)->find(n, r) ? r : 0;
|
|
case BINARY:
|
|
return UNTAG(binary_table*, t)->find(n, r) ? r : 0;
|
|
case BINARY_COMM:
|
|
return UNTAG(comm_table*, t)->find(n, r) ? r : 0;
|
|
default:
|
|
return UNTAG(table*, t)->find(n, r) ? r : 0;
|
|
}
|
|
}
|
|
|
|
bool contains_ptr(enode * n) const {
|
|
enode * r;
|
|
SASSERT(n->get_num_args() > 0);
|
|
void * t = const_cast<cg_table*>(this)->get_table(n);
|
|
switch (static_cast<table_kind>(GET_TAG(t))) {
|
|
case UNARY:
|
|
return UNTAG(unary_table*, t)->find(n, r) && n == r;
|
|
case BINARY:
|
|
return UNTAG(binary_table*, t)->find(n, r) && n == r;
|
|
case BINARY_COMM:
|
|
return UNTAG(comm_table*, t)->find(n, r) && n == r;
|
|
default:
|
|
return UNTAG(table*, t)->find(n, r) && n == r;
|
|
}
|
|
}
|
|
|
|
void reset();
|
|
|
|
void display(std::ostream & out) const;
|
|
|
|
void display_compact(std::ostream & out) const;
|
|
#ifdef Z3DEBUG
|
|
bool check_invariant() const;
|
|
#endif
|
|
};
|
|
|
|
#endif
|
|
};
|
|
|
|
#endif /* _SMT_CG_TABLE_H_ */
|
|
|