3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-15 02:16:16 +00:00

switch parameter to an std::variant

plus fix mem leak & move constructor for zstrings
This commit is contained in:
Nuno Lopes 2021-05-23 13:07:29 +01:00
parent 9eb566b401
commit f8406623b4
2 changed files with 68 additions and 96 deletions

View file

@ -38,39 +38,27 @@ Revision History:
// ----------------------------------- // -----------------------------------
parameter::~parameter() { parameter::~parameter() {
if (m_kind == PARAM_RATIONAL) { if (auto p = std::get_if<rational*>(&m_val)) {
dealloc(m_rational); dealloc(*p);
} }
if (m_kind == PARAM_ZSTRING) { if (auto p = std::get_if<zstring*>(&m_val)) {
dealloc(m_zstring); dealloc(*p);
} }
} }
parameter::parameter(parameter const& other) {
m_kind = PARAM_INT;
m_int = 0;
*this = other;
}
parameter& parameter::operator=(parameter const& other) { parameter& parameter::operator=(parameter const& other) {
if (this == &other) { if (this == &other) {
return *this; return *this;
} }
if (m_kind == PARAM_RATIONAL) {
dealloc(m_rational); this->~parameter();
m_val = other.m_val;
if (auto p = std::get_if<rational*>(&m_val)) {
m_val = alloc(rational, **p);
} }
m_kind = other.m_kind; if (auto p = std::get_if<zstring*>(&m_val)) {
switch(other.m_kind) { m_val = alloc(zstring, **p);
case PARAM_INT: m_int = other.get_int(); break;
case PARAM_AST: m_ast = other.get_ast(); break;
case PARAM_SYMBOL: m_symbol = other.m_symbol; break;
case PARAM_RATIONAL: m_rational = alloc(rational, other.get_rational()); break;
case PARAM_DOUBLE: m_dval = other.m_dval; break;
case PARAM_EXTERNAL: m_ext_id = other.m_ext_id; break;
case PARAM_ZSTRING: m_zstring = alloc(zstring, other.get_zstring()); break;
default:
UNREACHABLE();
break;
} }
return *this; return *this;
} }
@ -95,45 +83,40 @@ void parameter::del_eh(ast_manager & m, family_id fid) {
} }
bool parameter::operator==(parameter const & p) const { bool parameter::operator==(parameter const & p) const {
if (m_kind != p.m_kind) return false; if (get_kind() != p.get_kind()) return false;
switch(m_kind) { switch (get_kind()) {
case PARAM_INT: return m_int == p.m_int;
case PARAM_AST: return m_ast == p.m_ast;
case PARAM_SYMBOL: return get_symbol() == p.get_symbol();
case PARAM_RATIONAL: return get_rational() == p.get_rational(); case PARAM_RATIONAL: return get_rational() == p.get_rational();
case PARAM_DOUBLE: return m_dval == p.m_dval;
case PARAM_EXTERNAL: return m_ext_id == p.m_ext_id;
case PARAM_ZSTRING: return get_zstring() == p.get_zstring(); case PARAM_ZSTRING: return get_zstring() == p.get_zstring();
default: UNREACHABLE(); return false; default: return m_val == p.m_val;
} }
} }
unsigned parameter::hash() const { unsigned parameter::hash() const {
unsigned b = 0; unsigned b = 0;
switch(m_kind) { switch (get_kind()) {
case PARAM_INT: b = m_int; break; case PARAM_INT: b = get_int(); break;
case PARAM_AST: b = m_ast->hash(); break; case PARAM_AST: b = get_ast()->hash(); break;
case PARAM_SYMBOL: b = get_symbol().hash(); break; case PARAM_SYMBOL: b = get_symbol().hash(); break;
case PARAM_RATIONAL: b = get_rational().hash(); break; case PARAM_RATIONAL: b = get_rational().hash(); break;
case PARAM_DOUBLE: b = static_cast<unsigned>(m_dval); break; case PARAM_DOUBLE: b = static_cast<unsigned>(get_double()); break;
case PARAM_ZSTRING: b = get_zstring().hash(); break; case PARAM_ZSTRING: b = get_zstring().hash(); break;
case PARAM_EXTERNAL: b = m_ext_id; break; case PARAM_EXTERNAL: b = get_ext_id(); break;
} }
return (b << 2) | m_kind; return (b << 2) | get_kind();
} }
std::ostream& parameter::display(std::ostream& out) const { std::ostream& parameter::display(std::ostream& out) const {
switch(m_kind) { switch (get_kind()) {
case PARAM_INT: return out << get_int(); case PARAM_INT: return out << get_int();
case PARAM_SYMBOL: return out << get_symbol(); case PARAM_SYMBOL: return out << get_symbol();
case PARAM_RATIONAL: return out << get_rational(); case PARAM_RATIONAL: return out << get_rational();
case PARAM_AST: return out << "#" << get_ast()->get_id(); case PARAM_AST: return out << '#' << get_ast()->get_id();
case PARAM_DOUBLE: return out << m_dval; case PARAM_DOUBLE: return out << get_double();
case PARAM_EXTERNAL: return out << "@" << m_ext_id; case PARAM_EXTERNAL: return out << '@' << get_ext_id();
case PARAM_ZSTRING: return out << get_zstring(); case PARAM_ZSTRING: return out << get_zstring();
default: default:
UNREACHABLE(); UNREACHABLE();
return out << "[invalid parameter]"; return out;
} }
} }

View file

@ -47,6 +47,7 @@ Revision History:
#include "util/z3_exception.h" #include "util/z3_exception.h"
#include "util/dependency.h" #include "util/dependency.h"
#include "util/rlimit.h" #include "util/rlimit.h"
#include <variant>
#define RECYCLE_FREE_AST_INDICES #define RECYCLE_FREE_AST_INDICES
@ -97,6 +98,7 @@ const family_id arith_family_id = 5;
*/ */
class parameter { class parameter {
public: public:
// NOTE: these must be in the same order as the entries in the variant below
enum kind_t { enum kind_t {
PARAM_INT, PARAM_INT,
PARAM_AST, PARAM_AST,
@ -113,63 +115,50 @@ public:
PARAM_EXTERNAL PARAM_EXTERNAL
}; };
private: private:
kind_t m_kind;
// It is not possible to use tag pointers, since symbols are already tagged. // It is not possible to use tag pointers, since symbols are already tagged.
union { std::variant<
int m_int; // for PARAM_INT int, // for PARAM_INT
ast* m_ast; // for PARAM_AST ast*, // for PARAM_AST
symbol m_symbol; // for PARAM_SYMBOL symbol, // for PARAM_SYMBOL
rational* m_rational; // for PARAM_RATIONAL zstring*, // for PARAM_ZSTRING
zstring* m_zstring; // for PARAM_ZSTRING rational*, // for PARAM_RATIONAL
double m_dval; // for PARAM_DOUBLE (remark: this is not used in float_decl_plugin) double, // for PARAM_DOUBLE (remark: this is not used in float_decl_plugin)
unsigned m_ext_id; // for PARAM_EXTERNAL unsigned // for PARAM_EXTERNAL
}; > m_val;
public: public:
parameter(): m_kind(PARAM_INT), m_int(0) {} parameter() : m_val(0) {}
explicit parameter(int val): m_kind(PARAM_INT), m_int(val) {} explicit parameter(int val): m_val(val) {}
explicit parameter(unsigned val): m_kind(PARAM_INT), m_int(val) {} explicit parameter(unsigned val): m_val((int)val) {}
explicit parameter(ast * p): m_kind(PARAM_AST), m_ast(p) {} explicit parameter(ast * p): m_val(p) {}
explicit parameter(symbol const & s): m_kind(PARAM_SYMBOL), m_symbol(s) {} explicit parameter(symbol const & s): m_val(s) {}
explicit parameter(rational const & r): m_kind(PARAM_RATIONAL), m_rational(alloc(rational, r)) {} explicit parameter(rational const & r): m_val(alloc(rational, r)) {}
explicit parameter(rational && r) : m_kind(PARAM_RATIONAL), m_rational(alloc(rational, std::move(r))) {} explicit parameter(rational && r) : m_val(alloc(rational, std::move(r))) {}
explicit parameter(zstring const& s): m_kind(PARAM_ZSTRING), m_zstring(alloc(zstring, s)) {} explicit parameter(zstring const& s): m_val(alloc(zstring, s)) {}
explicit parameter(zstring && s): m_kind(PARAM_ZSTRING), m_zstring(alloc(zstring, std::move(s))) {} explicit parameter(zstring && s): m_val(alloc(zstring, std::move(s))) {}
explicit parameter(double d):m_kind(PARAM_DOUBLE), m_dval(d) {} explicit parameter(double d): m_val(d) {}
explicit parameter(const char *s):m_kind(PARAM_SYMBOL), m_symbol(symbol(s)) {} explicit parameter(const char *s): m_val(symbol(s)) {}
explicit parameter(const std::string &s):m_kind(PARAM_SYMBOL), m_symbol(symbol(s)) {} explicit parameter(const std::string &s): m_val(symbol(s)) {}
explicit parameter(unsigned ext_id, bool):m_kind(PARAM_EXTERNAL), m_ext_id(ext_id) {} explicit parameter(unsigned ext_id, bool): m_val(ext_id) {}
parameter(parameter const&); parameter(parameter const& other) { *this = other; }
parameter(parameter && other) noexcept : m_kind(other.m_kind) { parameter(parameter && other) noexcept : m_val(std::move(other.m_val)) {
switch (other.m_kind) { other.m_val = 0;
case PARAM_INT: m_int = other.get_int(); break;
case PARAM_AST: m_ast = other.get_ast(); break;
case PARAM_SYMBOL: m_symbol = other.m_symbol; break;
case PARAM_RATIONAL: m_rational = nullptr; std::swap(m_rational, other.m_rational); break;
case PARAM_DOUBLE: m_dval = other.m_dval; break;
case PARAM_EXTERNAL: m_ext_id = other.m_ext_id; break;
case PARAM_ZSTRING: m_zstring = other.m_zstring; break;
default:
UNREACHABLE();
break;
}
} }
~parameter(); ~parameter();
parameter& operator=(parameter const& other); parameter& operator=(parameter const& other);
kind_t get_kind() const { return m_kind; } kind_t get_kind() const { return static_cast<kind_t>(m_val.index()); }
bool is_int() const { return m_kind == PARAM_INT; } bool is_int() const { return get_kind() == PARAM_INT; }
bool is_ast() const { return m_kind == PARAM_AST; } bool is_ast() const { return get_kind() == PARAM_AST; }
bool is_symbol() const { return m_kind == PARAM_SYMBOL; } bool is_symbol() const { return get_kind() == PARAM_SYMBOL; }
bool is_rational() const { return m_kind == PARAM_RATIONAL; } bool is_rational() const { return get_kind() == PARAM_RATIONAL; }
bool is_double() const { return m_kind == PARAM_DOUBLE; } bool is_double() const { return get_kind() == PARAM_DOUBLE; }
bool is_external() const { return m_kind == PARAM_EXTERNAL; } bool is_external() const { return get_kind() == PARAM_EXTERNAL; }
bool is_zstring() const { return m_kind == PARAM_ZSTRING; } bool is_zstring() const { return get_kind() == PARAM_ZSTRING; }
bool is_int(int & i) const { return is_int() && (i = get_int(), true); } bool is_int(int & i) const { return is_int() && (i = get_int(), true); }
bool is_ast(ast * & a) const { return is_ast() && (a = get_ast(), true); } bool is_ast(ast * & a) const { return is_ast() && (a = get_ast(), true); }
@ -191,13 +180,13 @@ public:
*/ */
void del_eh(ast_manager & m, family_id fid); void del_eh(ast_manager & m, family_id fid);
int get_int() const { SASSERT(is_int()); return m_int; } int get_int() const { return std::get<int>(m_val); }
ast * get_ast() const { SASSERT(is_ast()); return m_ast; } ast * get_ast() const { return std::get<ast*>(m_val); }
symbol get_symbol() const { SASSERT(is_symbol()); return m_symbol; } symbol get_symbol() const { return std::get<symbol>(m_val); }
rational const & get_rational() const { SASSERT(is_rational()); return *m_rational; } rational const & get_rational() const { return *std::get<rational*>(m_val); }
zstring const& get_zstring() const { SASSERT(is_zstring()); return *m_zstring; } zstring const& get_zstring() const { return *std::get<zstring*>(m_val); }
double get_double() const { SASSERT(is_double()); return m_dval; } double get_double() const { return std::get<double>(m_val); }
unsigned get_ext_id() const { SASSERT(is_external()); return m_ext_id; } unsigned get_ext_id() const { return std::get<unsigned>(m_val); }
bool operator==(parameter const & p) const; bool operator==(parameter const & p) const;
bool operator!=(parameter const & p) const { return !operator==(p); } bool operator!=(parameter const & p) const { return !operator==(p); }