diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index 3df740911..bd32bfed0 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -66,27 +66,31 @@ namespace { namespace polysat { - unsigned slicing::dep_t::to_uint() const { - return std::visit([](auto arg) -> unsigned { + void* slicing::dep_t::encode() const { + void* p = std::visit([](auto arg) -> void* { using T = std::decay_t; if constexpr (std::is_same_v) - return UINT_MAX; + return nullptr; else if constexpr (std::is_same_v) - return (arg.to_uint() << 1); + return box(arg.to_uint(), 1); else if constexpr (std::is_same_v) - return (arg << 1) + 1; + return box(arg, 2); else static_assert(always_false_v, "non-exhaustive visitor!"); }, m_data); + SASSERT(*this == decode(p)); + return p; } - slicing::dep_t slicing::dep_t::from_uint(unsigned x) { - if (x == UINT_MAX) - return dep_t(); - else if ((x & 1) == 0) - return dep_t(sat::to_literal(x >> 1)); + slicing::dep_t slicing::dep_t::decode(void* p) { + if (!p) + return {}; + unsigned tag = get_tag(p); + SASSERT(tag == 1 || tag == 2); + if (tag == 1) + return dep_t(sat::to_literal(unbox(p))); else - return dep_t(static_cast(x >> 1)); + return dep_t(unbox(p)); } std::ostream& slicing::display(std::ostream& out, dep_t d) { @@ -99,16 +103,6 @@ namespace polysat { return out; } - void* slicing::encode_dep(dep_t d) { - void* p = box(d.to_uint()); - SASSERT(d == decode_dep(p)); - return p; - } - - slicing::dep_t slicing::decode_dep(void* p) { - return dep_t::from_uint(unbox(p)); - } - slicing::dep_t slicing::mk_var_dep(pvar v, enode* s) { SASSERT_EQ(m_dep_var.size(), m_dep_slice.size()); unsigned const idx = m_dep_var.size(); @@ -123,7 +117,7 @@ namespace polysat { { reg_decl_plugins(m_ast); m_bv = alloc(bv_util, m_ast); - m_egraph.set_display_justification([&](std::ostream& out, void* d) { display(out, decode_dep(d)); }); + m_egraph.set_display_justification([&](std::ostream& out, void* dp) { display(out, dep_t::decode(dp)); }); m_egraph.set_on_merge([&](enode* root, enode* other) { egraph_on_merge(root, other); }); m_egraph.set_on_propagate([&](enode* lit, enode* ante) { egraph_on_propagate(lit, ante); }); } @@ -271,7 +265,7 @@ namespace polysat { return eqn; auto args = {x, y}; eqn = m_egraph.mk(eq, 0, args.size(), args.begin()); - auto j = euf::justification::external(encode_dep(lit)); + auto j = euf::justification::external(dep_t(lit).encode()); m_egraph.set_value(eqn, l_false, j); SASSERT(eqn->is_equality()); SASSERT_EQ(eqn->value(), l_false); @@ -317,7 +311,7 @@ namespace polysat { slice_info& concat_info = m_info[concat->get_id()]; SASSERT(!concat_info.slice); // not yet set concat_info.slice = s; - m_egraph.merge(s, concat, encode_dep(dep_t())); + m_egraph.merge(s, concat, dep_t().encode()); } void slicing::add_congruence(pvar v) { @@ -411,12 +405,12 @@ namespace polysat { SASSERT(j.is_external()); // cannot be a congruence since the slice wasn't split before. void* j_hi = j.ext(); void* j_lo = j.ext(); - dep_t d = decode_dep(j.ext()); + dep_t d = dep_t::decode(j.ext()); if (d.is_var_idx()) { enode* ds = get_dep_slice(d); SASSERT(ds == n || ds == target); - j_hi = encode_dep(mk_var_dep(get_dep_var(d), sub_hi(ds))); - j_lo = encode_dep(mk_var_dep(get_dep_var(d), sub_lo(ds))); + j_hi = mk_var_dep(get_dep_var(d), sub_hi(ds)).encode(); + j_lo = mk_var_dep(get_dep_var(d), sub_lo(ds)).encode(); } m_egraph.merge(sub_hi(n), sub_hi(target), j_hi); m_egraph.merge(sub_lo(n), sub_lo(target), j_lo); @@ -592,7 +586,7 @@ namespace polysat { pvar x = null_var; enode* sx = nullptr; pvar y = null_var; enode* sy = nullptr; for (void* dp : m_tmp_deps) { - dep_t const d = decode_dep(dp); + dep_t const d = dep_t::decode(dp); if (d.is_null()) continue; if (d.is_lit()) { @@ -756,7 +750,7 @@ namespace polysat { SASSERT(is_value(s2)); dep = mk_var_dep(get_dep_var(dep), s1); } - m_egraph.merge(s1, s2, encode_dep(dep)); + m_egraph.merge(s1, s2, dep.encode()); return !is_conflict(); } diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index 208afe64d..0da6d70ec 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -54,8 +54,8 @@ namespace polysat { unsigned var_idx() const { SASSERT(is_var_idx()); return *std::get_if(&m_data); } bool operator==(dep_t other) const { return m_data == other.m_data; } bool operator!=(dep_t other) const { return !operator==(other); } - unsigned to_uint() const; - static dep_t from_uint(unsigned x); + void* encode() const; + static dep_t decode(void* p); }; using dep_vector = svector; @@ -124,9 +124,6 @@ namespace polysat { // Add s = concat(s1, ..., sn) void add_concat_node(enode* s, enode* concat); - static void* encode_dep(dep_t d); - static dep_t decode_dep(void* d); - slice_info& info(euf::enode* n); slice_info const& info(euf::enode* n) const; diff --git a/src/test/slicing.cpp b/src/test/slicing.cpp index 9b2aa630a..7c6b49ecd 100644 --- a/src/test/slicing.cpp +++ b/src/test/slicing.cpp @@ -37,7 +37,7 @@ namespace polysat { static std::ostream& display_reason(scoped_solver_slicing& s, std::ostream& out, ptr_vector deps) { char const* delim = ""; for (void* dp : deps) { - slicing::dep_t d = slicing::decode_dep(dp); + slicing::dep_t d = slicing::dep_t::decode(dp); if (d.is_null()) continue; s.sl().display(out << delim, d);