3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-01-23 18:44:02 +00:00

Adopt std::optional for try_get_value and try_get_size functions (#8268)

* Initial plan

* Convert try_get_value and try_get_size to use std::optional

Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com>

* Add unit tests for std::optional conversions and fix compilation error

Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com>

* Address code review comments - improve readability and reduce code duplication

Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com>
This commit is contained in:
Copilot 2026-01-21 12:41:50 -08:00 committed by GitHub
parent 2e7b700769
commit 1bb471447e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 92 additions and 44 deletions

View file

@ -225,9 +225,8 @@ namespace api {
e = m_bv_util.mk_numeral(n, s);
}
else if (fid == get_datalog_fid() && n.is_uint64()) {
uint64_t sz;
if (m_datalog_util.try_get_size(s, sz) &&
sz <= n.get_uint64()) {
if (auto size_opt = m_datalog_util.try_get_size(s);
size_opt.has_value() && *size_opt <= n.get_uint64()) {
invoke_error_handler(Z3_INVALID_ARG);
}
e = m_datalog_util.mk_numeral(n.get_uint64(), s);

View file

@ -213,8 +213,11 @@ extern "C" {
// must start logging here, since function uses Z3_get_sort_kind above
LOG_Z3_get_finite_domain_sort_size(c, s, out);
RESET_ERROR_CODE();
VERIFY(mk_c(c)->datalog_util().try_get_size(to_sort(s), *out));
return true;
if (auto size = mk_c(c)->datalog_util().try_get_size(to_sort(s)); size) {
*out = *size;
return true;
}
return false;
Z3_CATCH_RETURN(false);
}

View file

@ -659,8 +659,7 @@ namespace datalog {
app* dl_decl_util::mk_numeral(uint64_t value, sort* s) {
if (is_finite_sort(s)) {
uint64_t sz = 0;
if (try_get_size(s, sz) && sz <= value) {
if (auto sz = try_get_size(s); sz.has_value() && *sz <= value) {
m.raise_exception("value is out of bounds");
}
parameter params[2] = { parameter(rational(value, rational::ui64())), parameter(s) };
@ -758,13 +757,12 @@ namespace datalog {
return m.mk_sort(get_family_id(), DL_FINITE_SORT, 2, params);
}
bool dl_decl_util::try_get_size(const sort * s, uint64_t& size) const {
std::optional<uint64_t> dl_decl_util::try_get_size(const sort * s) const {
sort_size sz = s->get_info()->get_num_elements();
if (sz.is_finite()) {
size = sz.size();
return true;
return sz.size();
}
return false;
return std::nullopt;
}
app* dl_decl_util::mk_lt(expr* a, expr* b) {

View file

@ -18,6 +18,7 @@ Revision History:
--*/
#pragma once
#include <optional>
#include "ast/ast.h"
#include "ast/arith_decl_plugin.h"
#include "ast/bv_decl_plugin.h"
@ -172,7 +173,7 @@ namespace datalog {
sort* mk_sort(const symbol& name, uint64_t domain_size);
bool try_get_size(const sort *, uint64_t& size) const;
std::optional<uint64_t> try_get_size(const sort *) const;
bool is_finite_sort(sort* s) const {
return is_sort_of(s, get_family_id(), DL_FINITE_SORT);

View file

@ -19,6 +19,7 @@ Revision History:
--*/
#pragma once
#include <string>
#include <optional>
#include "math/lp/numeric_pair.h"
#include "math/lp/lp_types.h"
#include "util/debug.h"
@ -52,11 +53,11 @@ std::ostream& print_vector(const C * t, unsigned size, std::ostream & out) {
template <typename A, typename B>
bool try_get_value(const std::unordered_map<A,B> & map, const A& key, B & val) {
std::optional<B> try_get_value(const std::unordered_map<A,B> & map, const A& key) {
const auto it = map.find(key);
if (it == map.end()) return false;
val = it->second;
return true;
if (it == map.end())
return std::nullopt;
return it->second;
}
template <typename A, typename B>

View file

@ -50,7 +50,10 @@ namespace datalog {
sort_domain(sort_kind k, context & ctx, sort * s)
: m_kind(k), m_sort(s, ctx.get_manager()) {
m_limited_size = ctx.get_decl_util().try_get_size(s, m_size);
auto opt_size = ctx.get_decl_util().try_get_size(s);
m_limited_size = opt_size.has_value();
if (m_limited_size)
m_size = *opt_size;
}
public:
virtual ~sort_domain() = default;

View file

@ -1125,13 +1125,12 @@ protected:
\brief Make a constant for DK_SYMBOL sort out of an integer
*/
app* mk_symbol_const(uint64_t el, sort* s) {
uint64_t sz = 0;
if (m_arith.is_int(s))
return m_arith.mk_numeral(rational(el, rational::ui64()), s);
else if (m_decl_util.try_get_size(s, sz)) {
if (el >= sz) {
else if (auto sz = m_decl_util.try_get_size(s)) {
if (el >= *sz) {
std::ostringstream ous;
ous << "numeric value " << el << " is out of bounds of domain size " << sz;
ous << "numeric value " << el << " is out of bounds of domain size " << *sz;
throw default_exception(ous.str());
}
return m_decl_util.mk_numeral(el, s);

View file

@ -400,7 +400,11 @@ namespace datalog {
}
bool relation_manager::relation_sort_to_table(const relation_sort & from, table_sort & to) {
return get_context().get_decl_util().try_get_size(from, to);
if (auto size = get_context().get_decl_util().try_get_size(from)) {
to = *size;
return true;
}
return false;
}
void relation_manager::from_predicate(func_decl * pred, unsigned arg_index, relation_sort & result) {

View file

@ -248,6 +248,17 @@ namespace datalog {
SASSERT(dl.is_finite_sort(s));
return dl.mk_numeral(r.get_uint64(), s);
}
// Helper function to count bits needed to represent a size
unsigned udoc_plugin::count_bits_for_size(uint64_t size) const {
unsigned num_bits = 0;
while (size > 0) {
++num_bits;
size /= 2;
}
return num_bits;
}
bool udoc_plugin::is_numeral(expr* e, rational& r, unsigned& num_bits) {
if (bv.is_numeral(e, r, num_bits)) return true;
if (m.is_true(e)) {
@ -260,12 +271,13 @@ namespace datalog {
num_bits = 1;
return true;
}
uint64_t n, sz;
if (dl.is_numeral(e, n) && dl.try_get_size(e->get_sort(), sz)) {
num_bits = 0;
while (sz > 0) ++num_bits, sz = sz/2;
r = rational(n, rational::ui64());
return true;
uint64_t n;
if (dl.is_numeral(e, n)) {
if (auto sz = dl.try_get_size(e->get_sort())) {
num_bits = count_bits_for_size(*sz);
r = rational(n, rational::ui64());
return true;
}
}
return false;
}
@ -275,10 +287,8 @@ namespace datalog {
return bv.get_bv_size(s);
if (m.is_bool(s))
return 1;
uint64_t sz;
if (dl.try_get_size(s, sz)) {
while (sz > 0) ++num_bits, sz /= 2;
return num_bits;
if (auto sz = dl.try_get_size(s)) {
return count_bits_for_size(*sz);
}
UNREACHABLE();
return 0;

View file

@ -105,6 +105,7 @@ namespace datalog {
static udoc_relation const & get(relation_base const& r);
void mk_union(doc_manager& dm, udoc& dst, udoc const& src, udoc* delta);
bool is_numeral(expr* e, rational& r, unsigned& num_bits);
unsigned count_bits_for_size(uint64_t size) const;
unsigned num_sort_bits(expr* e) const { return num_sort_bits(e->get_sort()); }
unsigned num_sort_bits(sort* s) const;
bool is_finite_sort(sort* s) const;

View file

@ -116,7 +116,9 @@ namespace qe {
private:
bool is_small_domain(contains_app& x, eq_atoms& eqs, uint64_t& domain_size) {
VERIFY(m_util.try_get_size(x.x()->get_sort(), domain_size));
auto opt_size = m_util.try_get_size(x.x()->get_sort());
VERIFY(opt_size);
domain_size = *opt_size;
return domain_size < eqs.num_eqs() + eqs.num_neqs();
}

View file

@ -240,10 +240,10 @@ namespace smt {
}
app* max_value(sort* s) {
uint64_t sz;
VERIFY(u().try_get_size(s, sz));
SASSERT(sz > 0);
return mk_bv_constant(sz-1, s);
auto sz = u().try_get_size(s);
VERIFY(sz);
SASSERT(*sz > 0);
return mk_bv_constant(*sz-1, s);
}
void mk_lt(app* x, app* y) {

View file

@ -96,16 +96,17 @@ void dl_query_test(ast_manager & m, smt_params & fparams, params_ref& params,
f_b.reset();
f_q.reset();
for(unsigned col=0; col<sig_b.size(); ++col) {
uint64_t sort_sz;
if(!decl_util.try_get_size(sig_q[col], sort_sz)) {
if (auto sort_sz = decl_util.try_get_size(sig_q[col])) {
uint64_t num = ran()%(*sort_sz);
app * el_b = decl_util.mk_numeral(num, sig_b[col]);
f_b.push_back(el_b);
app * el_q = decl_util.mk_numeral(num, sig_q[col]);
f_q.push_back(el_q);
}
else {
warning_msg("cannot get sort size");
return;
}
uint64_t num = ran()%sort_sz;
app * el_b = decl_util.mk_numeral(num, sig_b[col]);
f_b.push_back(el_b);
app * el_q = decl_util.mk_numeral(num, sig_q[col]);
f_q.push_back(el_q);
}
bool found_in_b = rel_b.contains_fact(f_b);

View file

@ -19,7 +19,10 @@ Revision History:
#include "util/trace.h"
#include "util/debug.h"
#include "util/memory_manager.h"
#include "math/lp/lp_utils.h"
#include "ast/dl_decl_plugin.h"
#include <optional>
#include <unordered_map>
static void tst1() {
std::optional<int> v;
@ -66,9 +69,32 @@ static void tst3() {
ENSURE(*(*v) == 10);
}
static void tst_try_get_value() {
std::unordered_map<int, std::string> map;
map[1] = "one";
map[2] = "two";
map[3] = "three";
// Test successful retrieval
auto result1 = try_get_value(map, 1);
ENSURE(result1.has_value());
ENSURE(*result1 == "one");
auto result2 = try_get_value(map, 2);
ENSURE(result2.has_value());
ENSURE(*result2 == "two");
// Test unsuccessful retrieval
auto result_missing = try_get_value(map, 999);
ENSURE(!result_missing.has_value());
TRACE(optional, tout << "try_get_value tests passed\n";);
}
void tst_optional() {
tst1();
tst2();
tst3();
tst_try_get_value();
}