diff --git a/src/api/api_context.cpp b/src/api/api_context.cpp index 82d65a5dd..e2e6a9fab 100644 --- a/src/api/api_context.cpp +++ b/src/api/api_context.cpp @@ -225,9 +225,8 @@ namespace api { e = m_bv_util.mk_numeral(n, s); } else if (fid == get_datalog_fid() && n.is_uint64()) { - uint64_t sz; - if (m_datalog_util.try_get_size(s, sz) && - sz <= n.get_uint64()) { + if (auto size_opt = m_datalog_util.try_get_size(s); + size_opt.has_value() && *size_opt <= n.get_uint64()) { invoke_error_handler(Z3_INVALID_ARG); } e = m_datalog_util.mk_numeral(n.get_uint64(), s); diff --git a/src/api/api_datalog.cpp b/src/api/api_datalog.cpp index 0b3eb989d..61d4aa9fd 100644 --- a/src/api/api_datalog.cpp +++ b/src/api/api_datalog.cpp @@ -213,8 +213,11 @@ extern "C" { // must start logging here, since function uses Z3_get_sort_kind above LOG_Z3_get_finite_domain_sort_size(c, s, out); RESET_ERROR_CODE(); - VERIFY(mk_c(c)->datalog_util().try_get_size(to_sort(s), *out)); - return true; + if (auto size = mk_c(c)->datalog_util().try_get_size(to_sort(s)); size) { + *out = *size; + return true; + } + return false; Z3_CATCH_RETURN(false); } diff --git a/src/ast/dl_decl_plugin.cpp b/src/ast/dl_decl_plugin.cpp index 19ae67fd5..af4d30add 100644 --- a/src/ast/dl_decl_plugin.cpp +++ b/src/ast/dl_decl_plugin.cpp @@ -659,8 +659,7 @@ namespace datalog { app* dl_decl_util::mk_numeral(uint64_t value, sort* s) { if (is_finite_sort(s)) { - uint64_t sz = 0; - if (try_get_size(s, sz) && sz <= value) { + if (auto sz = try_get_size(s); sz.has_value() && *sz <= value) { m.raise_exception("value is out of bounds"); } parameter params[2] = { parameter(rational(value, rational::ui64())), parameter(s) }; @@ -758,13 +757,12 @@ namespace datalog { return m.mk_sort(get_family_id(), DL_FINITE_SORT, 2, params); } - bool dl_decl_util::try_get_size(const sort * s, uint64_t& size) const { + std::optional dl_decl_util::try_get_size(const sort * s) const { sort_size sz = s->get_info()->get_num_elements(); if (sz.is_finite()) { - size = sz.size(); - return true; + return sz.size(); } - return false; + return std::nullopt; } app* dl_decl_util::mk_lt(expr* a, expr* b) { diff --git a/src/ast/dl_decl_plugin.h b/src/ast/dl_decl_plugin.h index c1cf08719..850d18126 100644 --- a/src/ast/dl_decl_plugin.h +++ b/src/ast/dl_decl_plugin.h @@ -18,6 +18,7 @@ Revision History: --*/ #pragma once +#include #include "ast/ast.h" #include "ast/arith_decl_plugin.h" #include "ast/bv_decl_plugin.h" @@ -172,7 +173,7 @@ namespace datalog { sort* mk_sort(const symbol& name, uint64_t domain_size); - bool try_get_size(const sort *, uint64_t& size) const; + std::optional try_get_size(const sort *) const; bool is_finite_sort(sort* s) const { return is_sort_of(s, get_family_id(), DL_FINITE_SORT); diff --git a/src/math/lp/lp_utils.h b/src/math/lp/lp_utils.h index fca2cff32..874bf1539 100644 --- a/src/math/lp/lp_utils.h +++ b/src/math/lp/lp_utils.h @@ -19,6 +19,7 @@ Revision History: --*/ #pragma once #include +#include #include "math/lp/numeric_pair.h" #include "math/lp/lp_types.h" #include "util/debug.h" @@ -52,11 +53,11 @@ std::ostream& print_vector(const C * t, unsigned size, std::ostream & out) { template -bool try_get_value(const std::unordered_map & map, const A& key, B & val) { +std::optional try_get_value(const std::unordered_map & map, const A& key) { const auto it = map.find(key); - if (it == map.end()) return false; - val = it->second; - return true; + if (it == map.end()) + return std::nullopt; + return it->second; } template diff --git a/src/muz/base/dl_context.cpp b/src/muz/base/dl_context.cpp index 0a828db7a..bf84d7be7 100644 --- a/src/muz/base/dl_context.cpp +++ b/src/muz/base/dl_context.cpp @@ -50,7 +50,10 @@ namespace datalog { sort_domain(sort_kind k, context & ctx, sort * s) : m_kind(k), m_sort(s, ctx.get_manager()) { - m_limited_size = ctx.get_decl_util().try_get_size(s, m_size); + auto opt_size = ctx.get_decl_util().try_get_size(s); + m_limited_size = opt_size.has_value(); + if (m_limited_size) + m_size = *opt_size; } public: virtual ~sort_domain() = default; diff --git a/src/muz/fp/datalog_parser.cpp b/src/muz/fp/datalog_parser.cpp index 192ccc547..bac981392 100644 --- a/src/muz/fp/datalog_parser.cpp +++ b/src/muz/fp/datalog_parser.cpp @@ -1125,13 +1125,12 @@ protected: \brief Make a constant for DK_SYMBOL sort out of an integer */ app* mk_symbol_const(uint64_t el, sort* s) { - uint64_t sz = 0; if (m_arith.is_int(s)) return m_arith.mk_numeral(rational(el, rational::ui64()), s); - else if (m_decl_util.try_get_size(s, sz)) { - if (el >= sz) { + else if (auto sz = m_decl_util.try_get_size(s)) { + if (el >= *sz) { std::ostringstream ous; - ous << "numeric value " << el << " is out of bounds of domain size " << sz; + ous << "numeric value " << el << " is out of bounds of domain size " << *sz; throw default_exception(ous.str()); } return m_decl_util.mk_numeral(el, s); diff --git a/src/muz/rel/dl_relation_manager.cpp b/src/muz/rel/dl_relation_manager.cpp index ed74edd2e..dda270462 100644 --- a/src/muz/rel/dl_relation_manager.cpp +++ b/src/muz/rel/dl_relation_manager.cpp @@ -400,7 +400,11 @@ namespace datalog { } bool relation_manager::relation_sort_to_table(const relation_sort & from, table_sort & to) { - return get_context().get_decl_util().try_get_size(from, to); + if (auto size = get_context().get_decl_util().try_get_size(from)) { + to = *size; + return true; + } + return false; } void relation_manager::from_predicate(func_decl * pred, unsigned arg_index, relation_sort & result) { diff --git a/src/muz/rel/udoc_relation.cpp b/src/muz/rel/udoc_relation.cpp index 3d98b25ab..fb4d12a02 100644 --- a/src/muz/rel/udoc_relation.cpp +++ b/src/muz/rel/udoc_relation.cpp @@ -248,6 +248,17 @@ namespace datalog { SASSERT(dl.is_finite_sort(s)); return dl.mk_numeral(r.get_uint64(), s); } + + // Helper function to count bits needed to represent a size + unsigned udoc_plugin::count_bits_for_size(uint64_t size) const { + unsigned num_bits = 0; + while (size > 0) { + ++num_bits; + size /= 2; + } + return num_bits; + } + bool udoc_plugin::is_numeral(expr* e, rational& r, unsigned& num_bits) { if (bv.is_numeral(e, r, num_bits)) return true; if (m.is_true(e)) { @@ -260,12 +271,13 @@ namespace datalog { num_bits = 1; return true; } - uint64_t n, sz; - if (dl.is_numeral(e, n) && dl.try_get_size(e->get_sort(), sz)) { - num_bits = 0; - while (sz > 0) ++num_bits, sz = sz/2; - r = rational(n, rational::ui64()); - return true; + uint64_t n; + if (dl.is_numeral(e, n)) { + if (auto sz = dl.try_get_size(e->get_sort())) { + num_bits = count_bits_for_size(*sz); + r = rational(n, rational::ui64()); + return true; + } } return false; } @@ -275,10 +287,8 @@ namespace datalog { return bv.get_bv_size(s); if (m.is_bool(s)) return 1; - uint64_t sz; - if (dl.try_get_size(s, sz)) { - while (sz > 0) ++num_bits, sz /= 2; - return num_bits; + if (auto sz = dl.try_get_size(s)) { + return count_bits_for_size(*sz); } UNREACHABLE(); return 0; diff --git a/src/muz/rel/udoc_relation.h b/src/muz/rel/udoc_relation.h index 54c0c580a..f7a47f0e0 100644 --- a/src/muz/rel/udoc_relation.h +++ b/src/muz/rel/udoc_relation.h @@ -105,6 +105,7 @@ namespace datalog { static udoc_relation const & get(relation_base const& r); void mk_union(doc_manager& dm, udoc& dst, udoc const& src, udoc* delta); bool is_numeral(expr* e, rational& r, unsigned& num_bits); + unsigned count_bits_for_size(uint64_t size) const; unsigned num_sort_bits(expr* e) const { return num_sort_bits(e->get_sort()); } unsigned num_sort_bits(sort* s) const; bool is_finite_sort(sort* s) const; diff --git a/src/qe/qe_dl_plugin.cpp b/src/qe/qe_dl_plugin.cpp index 3161903a3..515fd1882 100644 --- a/src/qe/qe_dl_plugin.cpp +++ b/src/qe/qe_dl_plugin.cpp @@ -116,7 +116,9 @@ namespace qe { private: bool is_small_domain(contains_app& x, eq_atoms& eqs, uint64_t& domain_size) { - VERIFY(m_util.try_get_size(x.x()->get_sort(), domain_size)); + auto opt_size = m_util.try_get_size(x.x()->get_sort()); + VERIFY(opt_size); + domain_size = *opt_size; return domain_size < eqs.num_eqs() + eqs.num_neqs(); } diff --git a/src/smt/theory_dl.cpp b/src/smt/theory_dl.cpp index d94c9595a..ee8c94d9a 100644 --- a/src/smt/theory_dl.cpp +++ b/src/smt/theory_dl.cpp @@ -240,10 +240,10 @@ namespace smt { } app* max_value(sort* s) { - uint64_t sz; - VERIFY(u().try_get_size(s, sz)); - SASSERT(sz > 0); - return mk_bv_constant(sz-1, s); + auto sz = u().try_get_size(s); + VERIFY(sz); + SASSERT(*sz > 0); + return mk_bv_constant(*sz-1, s); } void mk_lt(app* x, app* y) { diff --git a/src/test/dl_query.cpp b/src/test/dl_query.cpp index 71caa3d78..fe96881a9 100644 --- a/src/test/dl_query.cpp +++ b/src/test/dl_query.cpp @@ -96,16 +96,17 @@ void dl_query_test(ast_manager & m, smt_params & fparams, params_ref& params, f_b.reset(); f_q.reset(); for(unsigned col=0; col +#include static void tst1() { std::optional v; @@ -66,9 +69,32 @@ static void tst3() { ENSURE(*(*v) == 10); } +static void tst_try_get_value() { + std::unordered_map map; + map[1] = "one"; + map[2] = "two"; + map[3] = "three"; + + // Test successful retrieval + auto result1 = try_get_value(map, 1); + ENSURE(result1.has_value()); + ENSURE(*result1 == "one"); + + auto result2 = try_get_value(map, 2); + ENSURE(result2.has_value()); + ENSURE(*result2 == "two"); + + // Test unsuccessful retrieval + auto result_missing = try_get_value(map, 999); + ENSURE(!result_missing.has_value()); + + TRACE(optional, tout << "try_get_value tests passed\n";); +} + void tst_optional() { tst1(); tst2(); tst3(); + tst_try_get_value(); }