mirror of
https://github.com/Z3Prover/z3
synced 2026-01-23 18:44:02 +00:00
Adopt std::optional for try_get_value and try_get_size functions (#8268)
* Initial plan * Convert try_get_value and try_get_size to use std::optional Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com> * Add unit tests for std::optional conversions and fix compilation error Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com> * Address code review comments - improve readability and reduce code duplication Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com>
This commit is contained in:
parent
2e7b700769
commit
1bb471447e
14 changed files with 92 additions and 44 deletions
|
|
@ -225,9 +225,8 @@ namespace api {
|
|||
e = m_bv_util.mk_numeral(n, s);
|
||||
}
|
||||
else if (fid == get_datalog_fid() && n.is_uint64()) {
|
||||
uint64_t sz;
|
||||
if (m_datalog_util.try_get_size(s, sz) &&
|
||||
sz <= n.get_uint64()) {
|
||||
if (auto size_opt = m_datalog_util.try_get_size(s);
|
||||
size_opt.has_value() && *size_opt <= n.get_uint64()) {
|
||||
invoke_error_handler(Z3_INVALID_ARG);
|
||||
}
|
||||
e = m_datalog_util.mk_numeral(n.get_uint64(), s);
|
||||
|
|
|
|||
|
|
@ -213,8 +213,11 @@ extern "C" {
|
|||
// must start logging here, since function uses Z3_get_sort_kind above
|
||||
LOG_Z3_get_finite_domain_sort_size(c, s, out);
|
||||
RESET_ERROR_CODE();
|
||||
VERIFY(mk_c(c)->datalog_util().try_get_size(to_sort(s), *out));
|
||||
return true;
|
||||
if (auto size = mk_c(c)->datalog_util().try_get_size(to_sort(s)); size) {
|
||||
*out = *size;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
Z3_CATCH_RETURN(false);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -659,8 +659,7 @@ namespace datalog {
|
|||
|
||||
app* dl_decl_util::mk_numeral(uint64_t value, sort* s) {
|
||||
if (is_finite_sort(s)) {
|
||||
uint64_t sz = 0;
|
||||
if (try_get_size(s, sz) && sz <= value) {
|
||||
if (auto sz = try_get_size(s); sz.has_value() && *sz <= value) {
|
||||
m.raise_exception("value is out of bounds");
|
||||
}
|
||||
parameter params[2] = { parameter(rational(value, rational::ui64())), parameter(s) };
|
||||
|
|
@ -758,13 +757,12 @@ namespace datalog {
|
|||
return m.mk_sort(get_family_id(), DL_FINITE_SORT, 2, params);
|
||||
}
|
||||
|
||||
bool dl_decl_util::try_get_size(const sort * s, uint64_t& size) const {
|
||||
std::optional<uint64_t> dl_decl_util::try_get_size(const sort * s) const {
|
||||
sort_size sz = s->get_info()->get_num_elements();
|
||||
if (sz.is_finite()) {
|
||||
size = sz.size();
|
||||
return true;
|
||||
return sz.size();
|
||||
}
|
||||
return false;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
app* dl_decl_util::mk_lt(expr* a, expr* b) {
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ Revision History:
|
|||
--*/
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include "ast/ast.h"
|
||||
#include "ast/arith_decl_plugin.h"
|
||||
#include "ast/bv_decl_plugin.h"
|
||||
|
|
@ -172,7 +173,7 @@ namespace datalog {
|
|||
|
||||
sort* mk_sort(const symbol& name, uint64_t domain_size);
|
||||
|
||||
bool try_get_size(const sort *, uint64_t& size) const;
|
||||
std::optional<uint64_t> try_get_size(const sort *) const;
|
||||
|
||||
bool is_finite_sort(sort* s) const {
|
||||
return is_sort_of(s, get_family_id(), DL_FINITE_SORT);
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ Revision History:
|
|||
--*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <optional>
|
||||
#include "math/lp/numeric_pair.h"
|
||||
#include "math/lp/lp_types.h"
|
||||
#include "util/debug.h"
|
||||
|
|
@ -52,11 +53,11 @@ std::ostream& print_vector(const C * t, unsigned size, std::ostream & out) {
|
|||
|
||||
|
||||
template <typename A, typename B>
|
||||
bool try_get_value(const std::unordered_map<A,B> & map, const A& key, B & val) {
|
||||
std::optional<B> try_get_value(const std::unordered_map<A,B> & map, const A& key) {
|
||||
const auto it = map.find(key);
|
||||
if (it == map.end()) return false;
|
||||
val = it->second;
|
||||
return true;
|
||||
if (it == map.end())
|
||||
return std::nullopt;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
template <typename A, typename B>
|
||||
|
|
|
|||
|
|
@ -50,7 +50,10 @@ namespace datalog {
|
|||
|
||||
sort_domain(sort_kind k, context & ctx, sort * s)
|
||||
: m_kind(k), m_sort(s, ctx.get_manager()) {
|
||||
m_limited_size = ctx.get_decl_util().try_get_size(s, m_size);
|
||||
auto opt_size = ctx.get_decl_util().try_get_size(s);
|
||||
m_limited_size = opt_size.has_value();
|
||||
if (m_limited_size)
|
||||
m_size = *opt_size;
|
||||
}
|
||||
public:
|
||||
virtual ~sort_domain() = default;
|
||||
|
|
|
|||
|
|
@ -1125,13 +1125,12 @@ protected:
|
|||
\brief Make a constant for DK_SYMBOL sort out of an integer
|
||||
*/
|
||||
app* mk_symbol_const(uint64_t el, sort* s) {
|
||||
uint64_t sz = 0;
|
||||
if (m_arith.is_int(s))
|
||||
return m_arith.mk_numeral(rational(el, rational::ui64()), s);
|
||||
else if (m_decl_util.try_get_size(s, sz)) {
|
||||
if (el >= sz) {
|
||||
else if (auto sz = m_decl_util.try_get_size(s)) {
|
||||
if (el >= *sz) {
|
||||
std::ostringstream ous;
|
||||
ous << "numeric value " << el << " is out of bounds of domain size " << sz;
|
||||
ous << "numeric value " << el << " is out of bounds of domain size " << *sz;
|
||||
throw default_exception(ous.str());
|
||||
}
|
||||
return m_decl_util.mk_numeral(el, s);
|
||||
|
|
|
|||
|
|
@ -400,7 +400,11 @@ namespace datalog {
|
|||
}
|
||||
|
||||
bool relation_manager::relation_sort_to_table(const relation_sort & from, table_sort & to) {
|
||||
return get_context().get_decl_util().try_get_size(from, to);
|
||||
if (auto size = get_context().get_decl_util().try_get_size(from)) {
|
||||
to = *size;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void relation_manager::from_predicate(func_decl * pred, unsigned arg_index, relation_sort & result) {
|
||||
|
|
|
|||
|
|
@ -248,6 +248,17 @@ namespace datalog {
|
|||
SASSERT(dl.is_finite_sort(s));
|
||||
return dl.mk_numeral(r.get_uint64(), s);
|
||||
}
|
||||
|
||||
// Helper function to count bits needed to represent a size
|
||||
unsigned udoc_plugin::count_bits_for_size(uint64_t size) const {
|
||||
unsigned num_bits = 0;
|
||||
while (size > 0) {
|
||||
++num_bits;
|
||||
size /= 2;
|
||||
}
|
||||
return num_bits;
|
||||
}
|
||||
|
||||
bool udoc_plugin::is_numeral(expr* e, rational& r, unsigned& num_bits) {
|
||||
if (bv.is_numeral(e, r, num_bits)) return true;
|
||||
if (m.is_true(e)) {
|
||||
|
|
@ -260,12 +271,13 @@ namespace datalog {
|
|||
num_bits = 1;
|
||||
return true;
|
||||
}
|
||||
uint64_t n, sz;
|
||||
if (dl.is_numeral(e, n) && dl.try_get_size(e->get_sort(), sz)) {
|
||||
num_bits = 0;
|
||||
while (sz > 0) ++num_bits, sz = sz/2;
|
||||
r = rational(n, rational::ui64());
|
||||
return true;
|
||||
uint64_t n;
|
||||
if (dl.is_numeral(e, n)) {
|
||||
if (auto sz = dl.try_get_size(e->get_sort())) {
|
||||
num_bits = count_bits_for_size(*sz);
|
||||
r = rational(n, rational::ui64());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
|
@ -275,10 +287,8 @@ namespace datalog {
|
|||
return bv.get_bv_size(s);
|
||||
if (m.is_bool(s))
|
||||
return 1;
|
||||
uint64_t sz;
|
||||
if (dl.try_get_size(s, sz)) {
|
||||
while (sz > 0) ++num_bits, sz /= 2;
|
||||
return num_bits;
|
||||
if (auto sz = dl.try_get_size(s)) {
|
||||
return count_bits_for_size(*sz);
|
||||
}
|
||||
UNREACHABLE();
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -105,6 +105,7 @@ namespace datalog {
|
|||
static udoc_relation const & get(relation_base const& r);
|
||||
void mk_union(doc_manager& dm, udoc& dst, udoc const& src, udoc* delta);
|
||||
bool is_numeral(expr* e, rational& r, unsigned& num_bits);
|
||||
unsigned count_bits_for_size(uint64_t size) const;
|
||||
unsigned num_sort_bits(expr* e) const { return num_sort_bits(e->get_sort()); }
|
||||
unsigned num_sort_bits(sort* s) const;
|
||||
bool is_finite_sort(sort* s) const;
|
||||
|
|
|
|||
|
|
@ -116,7 +116,9 @@ namespace qe {
|
|||
private:
|
||||
|
||||
bool is_small_domain(contains_app& x, eq_atoms& eqs, uint64_t& domain_size) {
|
||||
VERIFY(m_util.try_get_size(x.x()->get_sort(), domain_size));
|
||||
auto opt_size = m_util.try_get_size(x.x()->get_sort());
|
||||
VERIFY(opt_size);
|
||||
domain_size = *opt_size;
|
||||
return domain_size < eqs.num_eqs() + eqs.num_neqs();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -240,10 +240,10 @@ namespace smt {
|
|||
}
|
||||
|
||||
app* max_value(sort* s) {
|
||||
uint64_t sz;
|
||||
VERIFY(u().try_get_size(s, sz));
|
||||
SASSERT(sz > 0);
|
||||
return mk_bv_constant(sz-1, s);
|
||||
auto sz = u().try_get_size(s);
|
||||
VERIFY(sz);
|
||||
SASSERT(*sz > 0);
|
||||
return mk_bv_constant(*sz-1, s);
|
||||
}
|
||||
|
||||
void mk_lt(app* x, app* y) {
|
||||
|
|
|
|||
|
|
@ -96,16 +96,17 @@ void dl_query_test(ast_manager & m, smt_params & fparams, params_ref& params,
|
|||
f_b.reset();
|
||||
f_q.reset();
|
||||
for(unsigned col=0; col<sig_b.size(); ++col) {
|
||||
uint64_t sort_sz;
|
||||
if(!decl_util.try_get_size(sig_q[col], sort_sz)) {
|
||||
if (auto sort_sz = decl_util.try_get_size(sig_q[col])) {
|
||||
uint64_t num = ran()%(*sort_sz);
|
||||
app * el_b = decl_util.mk_numeral(num, sig_b[col]);
|
||||
f_b.push_back(el_b);
|
||||
app * el_q = decl_util.mk_numeral(num, sig_q[col]);
|
||||
f_q.push_back(el_q);
|
||||
}
|
||||
else {
|
||||
warning_msg("cannot get sort size");
|
||||
return;
|
||||
}
|
||||
uint64_t num = ran()%sort_sz;
|
||||
app * el_b = decl_util.mk_numeral(num, sig_b[col]);
|
||||
f_b.push_back(el_b);
|
||||
app * el_q = decl_util.mk_numeral(num, sig_q[col]);
|
||||
f_q.push_back(el_q);
|
||||
}
|
||||
|
||||
bool found_in_b = rel_b.contains_fact(f_b);
|
||||
|
|
|
|||
|
|
@ -19,7 +19,10 @@ Revision History:
|
|||
#include "util/trace.h"
|
||||
#include "util/debug.h"
|
||||
#include "util/memory_manager.h"
|
||||
#include "math/lp/lp_utils.h"
|
||||
#include "ast/dl_decl_plugin.h"
|
||||
#include <optional>
|
||||
#include <unordered_map>
|
||||
|
||||
static void tst1() {
|
||||
std::optional<int> v;
|
||||
|
|
@ -66,9 +69,32 @@ static void tst3() {
|
|||
ENSURE(*(*v) == 10);
|
||||
}
|
||||
|
||||
static void tst_try_get_value() {
|
||||
std::unordered_map<int, std::string> map;
|
||||
map[1] = "one";
|
||||
map[2] = "two";
|
||||
map[3] = "three";
|
||||
|
||||
// Test successful retrieval
|
||||
auto result1 = try_get_value(map, 1);
|
||||
ENSURE(result1.has_value());
|
||||
ENSURE(*result1 == "one");
|
||||
|
||||
auto result2 = try_get_value(map, 2);
|
||||
ENSURE(result2.has_value());
|
||||
ENSURE(*result2 == "two");
|
||||
|
||||
// Test unsuccessful retrieval
|
||||
auto result_missing = try_get_value(map, 999);
|
||||
ENSURE(!result_missing.has_value());
|
||||
|
||||
TRACE(optional, tout << "try_get_value tests passed\n";);
|
||||
}
|
||||
|
||||
void tst_optional() {
|
||||
tst1();
|
||||
tst2();
|
||||
tst3();
|
||||
tst_try_get_value();
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue