diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 82f0d0c86..e7ea51981 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -38,39 +38,27 @@ Revision History: // ----------------------------------- parameter::~parameter() { - if (m_kind == PARAM_RATIONAL) { - dealloc(m_rational); + if (auto p = std::get_if(&m_val)) { + dealloc(*p); } - if (m_kind == PARAM_ZSTRING) { - dealloc(m_zstring); + if (auto p = std::get_if(&m_val)) { + dealloc(*p); } } -parameter::parameter(parameter const& other) { - m_kind = PARAM_INT; - m_int = 0; - *this = other; -} - parameter& parameter::operator=(parameter const& other) { if (this == &other) { return *this; } - if (m_kind == PARAM_RATIONAL) { - dealloc(m_rational); + + this->~parameter(); + m_val = other.m_val; + + if (auto p = std::get_if(&m_val)) { + m_val = alloc(rational, **p); } - m_kind = other.m_kind; - switch(other.m_kind) { - case PARAM_INT: m_int = other.get_int(); break; - case PARAM_AST: m_ast = other.get_ast(); break; - case PARAM_SYMBOL: m_symbol = other.m_symbol; break; - case PARAM_RATIONAL: m_rational = alloc(rational, other.get_rational()); break; - case PARAM_DOUBLE: m_dval = other.m_dval; break; - case PARAM_EXTERNAL: m_ext_id = other.m_ext_id; break; - case PARAM_ZSTRING: m_zstring = alloc(zstring, other.get_zstring()); break; - default: - UNREACHABLE(); - break; + if (auto p = std::get_if(&m_val)) { + m_val = alloc(zstring, **p); } return *this; } @@ -95,45 +83,40 @@ void parameter::del_eh(ast_manager & m, family_id fid) { } bool parameter::operator==(parameter const & p) const { - if (m_kind != p.m_kind) return false; - switch(m_kind) { - case PARAM_INT: return m_int == p.m_int; - case PARAM_AST: return m_ast == p.m_ast; - case PARAM_SYMBOL: return get_symbol() == p.get_symbol(); + if (get_kind() != p.get_kind()) return false; + switch (get_kind()) { case PARAM_RATIONAL: return get_rational() == p.get_rational(); - case PARAM_DOUBLE: return m_dval == p.m_dval; - case PARAM_EXTERNAL: return m_ext_id == p.m_ext_id; case PARAM_ZSTRING: return get_zstring() == p.get_zstring(); - default: UNREACHABLE(); return false; + default: return m_val == p.m_val; } } unsigned parameter::hash() const { unsigned b = 0; - switch(m_kind) { - case PARAM_INT: b = m_int; break; - case PARAM_AST: b = m_ast->hash(); break; + switch (get_kind()) { + case PARAM_INT: b = get_int(); break; + case PARAM_AST: b = get_ast()->hash(); break; case PARAM_SYMBOL: b = get_symbol().hash(); break; case PARAM_RATIONAL: b = get_rational().hash(); break; - case PARAM_DOUBLE: b = static_cast(m_dval); break; + case PARAM_DOUBLE: b = static_cast(get_double()); break; case PARAM_ZSTRING: b = get_zstring().hash(); break; - case PARAM_EXTERNAL: b = m_ext_id; break; + case PARAM_EXTERNAL: b = get_ext_id(); break; } - return (b << 2) | m_kind; + return (b << 2) | get_kind(); } std::ostream& parameter::display(std::ostream& out) const { - switch(m_kind) { + switch (get_kind()) { case PARAM_INT: return out << get_int(); case PARAM_SYMBOL: return out << get_symbol(); case PARAM_RATIONAL: return out << get_rational(); - case PARAM_AST: return out << "#" << get_ast()->get_id(); - case PARAM_DOUBLE: return out << m_dval; - case PARAM_EXTERNAL: return out << "@" << m_ext_id; + case PARAM_AST: return out << '#' << get_ast()->get_id(); + case PARAM_DOUBLE: return out << get_double(); + case PARAM_EXTERNAL: return out << '@' << get_ext_id(); case PARAM_ZSTRING: return out << get_zstring(); default: UNREACHABLE(); - return out << "[invalid parameter]"; + return out; } } diff --git a/src/ast/ast.h b/src/ast/ast.h index cdf53d4b2..07c3d1027 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -47,6 +47,7 @@ Revision History: #include "util/z3_exception.h" #include "util/dependency.h" #include "util/rlimit.h" +#include #define RECYCLE_FREE_AST_INDICES @@ -97,6 +98,7 @@ const family_id arith_family_id = 5; */ class parameter { public: + // NOTE: these must be in the same order as the entries in the variant below enum kind_t { PARAM_INT, PARAM_AST, @@ -113,63 +115,50 @@ public: PARAM_EXTERNAL }; private: - kind_t m_kind; - // It is not possible to use tag pointers, since symbols are already tagged. - union { - int m_int; // for PARAM_INT - ast* m_ast; // for PARAM_AST - symbol m_symbol; // for PARAM_SYMBOL - rational* m_rational; // for PARAM_RATIONAL - zstring* m_zstring; // for PARAM_ZSTRING - double m_dval; // for PARAM_DOUBLE (remark: this is not used in float_decl_plugin) - unsigned m_ext_id; // for PARAM_EXTERNAL - }; + std::variant< + int, // for PARAM_INT + ast*, // for PARAM_AST + symbol, // for PARAM_SYMBOL + zstring*, // for PARAM_ZSTRING + rational*, // for PARAM_RATIONAL + double, // for PARAM_DOUBLE (remark: this is not used in float_decl_plugin) + unsigned // for PARAM_EXTERNAL + > m_val; public: - parameter(): m_kind(PARAM_INT), m_int(0) {} - explicit parameter(int val): m_kind(PARAM_INT), m_int(val) {} - explicit parameter(unsigned val): m_kind(PARAM_INT), m_int(val) {} - explicit parameter(ast * p): m_kind(PARAM_AST), m_ast(p) {} - explicit parameter(symbol const & s): m_kind(PARAM_SYMBOL), m_symbol(s) {} - explicit parameter(rational const & r): m_kind(PARAM_RATIONAL), m_rational(alloc(rational, r)) {} - explicit parameter(rational && r) : m_kind(PARAM_RATIONAL), m_rational(alloc(rational, std::move(r))) {} - explicit parameter(zstring const& s): m_kind(PARAM_ZSTRING), m_zstring(alloc(zstring, s)) {} - explicit parameter(zstring && s): m_kind(PARAM_ZSTRING), m_zstring(alloc(zstring, std::move(s))) {} - explicit parameter(double d):m_kind(PARAM_DOUBLE), m_dval(d) {} - explicit parameter(const char *s):m_kind(PARAM_SYMBOL), m_symbol(symbol(s)) {} - explicit parameter(const std::string &s):m_kind(PARAM_SYMBOL), m_symbol(symbol(s)) {} - explicit parameter(unsigned ext_id, bool):m_kind(PARAM_EXTERNAL), m_ext_id(ext_id) {} - parameter(parameter const&); + parameter() : m_val(0) {} + explicit parameter(int val): m_val(val) {} + explicit parameter(unsigned val): m_val((int)val) {} + explicit parameter(ast * p): m_val(p) {} + explicit parameter(symbol const & s): m_val(s) {} + explicit parameter(rational const & r): m_val(alloc(rational, r)) {} + explicit parameter(rational && r) : m_val(alloc(rational, std::move(r))) {} + explicit parameter(zstring const& s): m_val(alloc(zstring, s)) {} + explicit parameter(zstring && s): m_val(alloc(zstring, std::move(s))) {} + explicit parameter(double d): m_val(d) {} + explicit parameter(const char *s): m_val(symbol(s)) {} + explicit parameter(const std::string &s): m_val(symbol(s)) {} + explicit parameter(unsigned ext_id, bool): m_val(ext_id) {} + parameter(parameter const& other) { *this = other; } - parameter(parameter && other) noexcept : m_kind(other.m_kind) { - switch (other.m_kind) { - case PARAM_INT: m_int = other.get_int(); break; - case PARAM_AST: m_ast = other.get_ast(); break; - case PARAM_SYMBOL: m_symbol = other.m_symbol; break; - case PARAM_RATIONAL: m_rational = nullptr; std::swap(m_rational, other.m_rational); break; - case PARAM_DOUBLE: m_dval = other.m_dval; break; - case PARAM_EXTERNAL: m_ext_id = other.m_ext_id; break; - case PARAM_ZSTRING: m_zstring = other.m_zstring; break; - default: - UNREACHABLE(); - break; - } + parameter(parameter && other) noexcept : m_val(std::move(other.m_val)) { + other.m_val = 0; } ~parameter(); parameter& operator=(parameter const& other); - kind_t get_kind() const { return m_kind; } - bool is_int() const { return m_kind == PARAM_INT; } - bool is_ast() const { return m_kind == PARAM_AST; } - bool is_symbol() const { return m_kind == PARAM_SYMBOL; } - bool is_rational() const { return m_kind == PARAM_RATIONAL; } - bool is_double() const { return m_kind == PARAM_DOUBLE; } - bool is_external() const { return m_kind == PARAM_EXTERNAL; } - bool is_zstring() const { return m_kind == PARAM_ZSTRING; } + kind_t get_kind() const { return static_cast(m_val.index()); } + bool is_int() const { return get_kind() == PARAM_INT; } + bool is_ast() const { return get_kind() == PARAM_AST; } + bool is_symbol() const { return get_kind() == PARAM_SYMBOL; } + bool is_rational() const { return get_kind() == PARAM_RATIONAL; } + bool is_double() const { return get_kind() == PARAM_DOUBLE; } + bool is_external() const { return get_kind() == PARAM_EXTERNAL; } + bool is_zstring() const { return get_kind() == PARAM_ZSTRING; } bool is_int(int & i) const { return is_int() && (i = get_int(), true); } bool is_ast(ast * & a) const { return is_ast() && (a = get_ast(), true); } @@ -191,13 +180,13 @@ public: */ void del_eh(ast_manager & m, family_id fid); - int get_int() const { SASSERT(is_int()); return m_int; } - ast * get_ast() const { SASSERT(is_ast()); return m_ast; } - symbol get_symbol() const { SASSERT(is_symbol()); return m_symbol; } - rational const & get_rational() const { SASSERT(is_rational()); return *m_rational; } - zstring const& get_zstring() const { SASSERT(is_zstring()); return *m_zstring; } - double get_double() const { SASSERT(is_double()); return m_dval; } - unsigned get_ext_id() const { SASSERT(is_external()); return m_ext_id; } + int get_int() const { return std::get(m_val); } + ast * get_ast() const { return std::get(m_val); } + symbol get_symbol() const { return std::get(m_val); } + rational const & get_rational() const { return *std::get(m_val); } + zstring const& get_zstring() const { return *std::get(m_val); } + double get_double() const { return std::get(m_val); } + unsigned get_ext_id() const { return std::get(m_val); } bool operator==(parameter const & p) const; bool operator!=(parameter const & p) const { return !operator==(p); }