diff --git a/src/api/api_datatype.cpp b/src/api/api_datatype.cpp index 413cf05e6..5b199f469 100644 --- a/src/api/api_datatype.cpp +++ b/src/api/api_datatype.cpp @@ -306,23 +306,34 @@ extern "C" { Z3_CATCH; } - static datatype_decl* mk_datatype_decl(Z3_context c, - Z3_symbol name, - unsigned num_constructors, - Z3_constructor constructors[]) { + static datatype_decl* api_datatype_decl(Z3_context c, + Z3_symbol name, + unsigned num_parameters, + Z3_sort const parameters[], + unsigned num_constructors, + Z3_constructor constructors[]) { datatype_util& dt_util = mk_c(c)->dtutil(); ast_manager& m = mk_c(c)->m(); - // Collect type variables from field sorts in order of first appearance sort_ref_vector params(m); - obj_hashtable seen; - for (unsigned i = 0; i < num_constructors; ++i) { - constructor* cn = reinterpret_cast(constructors[i]); - for (unsigned j = 0; j < cn->m_sorts.size(); ++j) { - if (cn->m_sorts[j].get() && m.is_type_var(cn->m_sorts[j].get())) { - if (!seen.contains(cn->m_sorts[j].get())) { - params.push_back(cn->m_sorts[j].get()); - seen.insert(cn->m_sorts[j].get()); + + // If parameters are provided explicitly, use them + if (num_parameters > 0 && parameters) { + for (unsigned i = 0; i < num_parameters; ++i) { + params.push_back(to_sort(parameters[i])); + } + } + else { + // Otherwise, collect type variables from field sorts in order of first appearance + obj_hashtable seen; + for (unsigned i = 0; i < num_constructors; ++i) { + constructor* cn = reinterpret_cast(constructors[i]); + for (unsigned j = 0; j < cn->m_sorts.size(); ++j) { + if (cn->m_sorts[j].get() && m.is_type_var(cn->m_sorts[j].get())) { + if (!seen.contains(cn->m_sorts[j].get())) { + params.push_back(cn->m_sorts[j].get()); + seen.insert(cn->m_sorts[j].get()); + } } } } @@ -357,7 +368,7 @@ extern "C" { sort_ref_vector sorts(m); { - datatype_decl * data = mk_datatype_decl(c, name, num_constructors, constructors); + datatype_decl * data = api_datatype_decl(c, name, 0, nullptr, num_constructors, constructors); bool is_ok = mk_c(c)->get_dt_plugin()->mk_datatypes(1, &data, 0, nullptr, sorts); del_datatype_decl(data); @@ -379,6 +390,42 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } + Z3_sort Z3_API Z3_mk_polymorphic_datatype(Z3_context c, + Z3_symbol name, + unsigned num_parameters, + Z3_sort parameters[], + unsigned num_constructors, + Z3_constructor constructors[]) { + Z3_TRY; + LOG_Z3_mk_polymorphic_datatype(c, name, num_parameters, parameters, num_constructors, constructors); + RESET_ERROR_CODE(); + ast_manager& m = mk_c(c)->m(); + datatype_util data_util(m); + + sort_ref_vector sorts(m); + { + datatype_decl * data = api_datatype_decl(c, name, num_parameters, parameters, num_constructors, constructors); + bool is_ok = mk_c(c)->get_dt_plugin()->mk_datatypes(1, &data, 0, nullptr, sorts); + del_datatype_decl(data); + + if (!is_ok) { + SET_ERROR_CODE(Z3_INVALID_ARG, nullptr); + RETURN_Z3(nullptr); + } + } + sort * s = sorts.get(0); + + mk_c(c)->save_ast_trail(s); + ptr_vector const& cnstrs = *data_util.get_datatype_constructors(s); + + for (unsigned i = 0; i < num_constructors; ++i) { + constructor* cn = reinterpret_cast(constructors[i]); + cn->m_constructor = cnstrs[i]; + } + RETURN_Z3_mk_polymorphic_datatype(of_sort(s)); + Z3_CATCH_RETURN(nullptr); + } + typedef ptr_vector constructor_list; Z3_constructor_list Z3_API Z3_mk_constructor_list(Z3_context c, @@ -436,7 +483,7 @@ extern "C" { ptr_vector datas; for (unsigned i = 0; i < num_sorts; ++i) { constructor_list* cl = reinterpret_cast(constructor_lists[i]); - datas.push_back(mk_datatype_decl(c, sort_names[i], cl->size(), reinterpret_cast(cl->data()))); + datas.push_back(api_datatype_decl(c, sort_names[i], 0, nullptr, cl->size(), reinterpret_cast(cl->data()))); } sort_ref_vector _sorts(m); bool ok = mk_c(c)->get_dt_plugin()->mk_datatypes(datas.size(), datas.data(), 0, nullptr, _sorts); diff --git a/src/api/z3_api.h b/src/api/z3_api.h index a7dd59d86..baa2fa34c 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -2127,6 +2127,33 @@ extern "C" { unsigned num_constructors, Z3_constructor constructors[]); + /** + \brief Create a parametric datatype with explicit type parameters. + + This function is similar to #Z3_mk_datatype, except it takes an explicit set of type parameters. + The parameters can be type variables created with #Z3_mk_type_variable, allowing the definition + of polymorphic datatypes that can be instantiated with different concrete types. + + \param c logical context + \param name name of the datatype + \param num_parameters number of type parameters (can be 0) + \param parameters array of type parameters (type variables or concrete sorts) + \param num_constructors number of constructors + \param constructors array of constructor specifications + + \sa Z3_mk_datatype + \sa Z3_mk_type_variable + \sa Z3_mk_datatype_sort + + def_API('Z3_mk_polymorphic_datatype', SORT, (_in(CONTEXT), _in(SYMBOL), _in(UINT), _in_array(2, SORT), _in(UINT), _inout_array(4, CONSTRUCTOR))) + */ + Z3_sort Z3_API Z3_mk_polymorphic_datatype(Z3_context c, + Z3_symbol name, + unsigned num_parameters, + Z3_sort parameters[], + unsigned num_constructors, + Z3_constructor constructors[]); + /** \brief create a forward reference to a recursive datatype being declared. The forward reference can be used in a nested occurrence: the range of an array diff --git a/src/test/parametric_datatype.cpp b/src/test/parametric_datatype.cpp index 2c4f18be6..2958b934c 100644 --- a/src/test/parametric_datatype.cpp +++ b/src/test/parametric_datatype.cpp @@ -126,6 +126,104 @@ static void test_parametric_pair() { Z3_del_context(ctx); } +/** + * Test Z3_mk_polymorphic_datatype API with explicit parameters. + * + * This test demonstrates the new API that explicitly accepts type parameters. + */ +static void test_polymorphic_datatype_api() { + std::cout << "test_polymorphic_datatype_api\n"; + + Z3_config cfg = Z3_mk_config(); + Z3_context ctx = Z3_mk_context(cfg); + Z3_del_config(cfg); + + // Create type variables alpha and beta for polymorphic datatype + Z3_symbol alpha_sym = Z3_mk_string_symbol(ctx, "alpha"); + Z3_symbol beta_sym = Z3_mk_string_symbol(ctx, "beta"); + Z3_sort alpha = Z3_mk_type_variable(ctx, alpha_sym); + Z3_sort beta = Z3_mk_type_variable(ctx, beta_sym); + + // Define parametric triple datatype with constructor mk-triple(first: alpha, second: beta, third: alpha) + Z3_symbol triple_name = Z3_mk_string_symbol(ctx, "triple"); + Z3_symbol mk_triple_name = Z3_mk_string_symbol(ctx, "mk-triple"); + Z3_symbol is_triple_name = Z3_mk_string_symbol(ctx, "is-triple"); + Z3_symbol first_name = Z3_mk_string_symbol(ctx, "first"); + Z3_symbol second_name = Z3_mk_string_symbol(ctx, "second"); + Z3_symbol third_name = Z3_mk_string_symbol(ctx, "third"); + + Z3_symbol field_names[3] = {first_name, second_name, third_name}; + Z3_sort field_sorts[3] = {alpha, beta, alpha}; // Use type variables + unsigned sort_refs[3] = {0, 0, 0}; // Not recursive references + + Z3_constructor mk_triple_con = Z3_mk_constructor( + ctx, mk_triple_name, is_triple_name, 3, field_names, field_sorts, sort_refs + ); + + // Create the parametric datatype using Z3_mk_polymorphic_datatype + Z3_constructor constructors[1] = {mk_triple_con}; + Z3_sort type_params[2] = {alpha, beta}; + Z3_sort triple = Z3_mk_polymorphic_datatype(ctx, triple_name, 2, type_params, 1, constructors); + + Z3_del_constructor(ctx, mk_triple_con); + + std::cout << "Created parametric triple datatype using Z3_mk_polymorphic_datatype\n"; + std::cout << "triple sort: " << Z3_sort_to_string(ctx, triple) << "\n"; + + // Now instantiate the datatype with concrete types + Z3_sort int_sort = Z3_mk_int_sort(ctx); + Z3_sort bool_sort = Z3_mk_bool_sort(ctx); + + // Create (triple Int Bool) + Z3_sort params_int_bool[2] = {int_sort, bool_sort}; + Z3_sort triple_int_bool = Z3_mk_datatype_sort(ctx, triple_name, 2, params_int_bool); + + std::cout << "Instantiated triple with Int and Bool\n"; + std::cout << "triple_int_bool: " << Z3_sort_to_string(ctx, triple_int_bool) << "\n"; + + // Get constructors and accessors from the instantiated datatype + Z3_func_decl mk_triple_int_bool = Z3_get_datatype_sort_constructor(ctx, triple_int_bool, 0); + Z3_func_decl first_int_bool = Z3_get_datatype_sort_constructor_accessor(ctx, triple_int_bool, 0, 0); + Z3_func_decl second_int_bool = Z3_get_datatype_sort_constructor_accessor(ctx, triple_int_bool, 0, 1); + Z3_func_decl third_int_bool = Z3_get_datatype_sort_constructor_accessor(ctx, triple_int_bool, 0, 2); + + std::cout << "Got constructors and accessors from instantiated datatype\n"; + + // Create a constant t : (triple Int Bool) + Z3_symbol t_sym = Z3_mk_string_symbol(ctx, "t"); + Z3_ast t = Z3_mk_const(ctx, t_sym, triple_int_bool); + + // Create (first t) - should be Int + Z3_ast first_t = Z3_mk_app(ctx, first_int_bool, 1, &t); + + // Create (third t) - should also be Int + Z3_ast third_t = Z3_mk_app(ctx, third_int_bool, 1, &t); + + // Create the equality (= (first t) (third t)) + Z3_ast eq = Z3_mk_eq(ctx, first_t, third_t); + + std::cout << "Created term: " << Z3_ast_to_string(ctx, eq) << "\n"; + + // Verify the term was created successfully + ENSURE(eq != nullptr); + + // Check that first_t and third_t have the same sort (Int) + Z3_sort first_t_sort = Z3_get_sort(ctx, first_t); + Z3_sort third_t_sort = Z3_get_sort(ctx, third_t); + + std::cout << "Sort of (first t): " << Z3_sort_to_string(ctx, first_t_sort) << "\n"; + std::cout << "Sort of (third t): " << Z3_sort_to_string(ctx, third_t_sort) << "\n"; + + // Both should be Int + ENSURE(Z3_is_eq_sort(ctx, first_t_sort, int_sort)); + ENSURE(Z3_is_eq_sort(ctx, third_t_sort, int_sort)); + + std::cout << "test_polymorphic_datatype_api passed!\n"; + + Z3_del_context(ctx); +} + void tst_parametric_datatype() { test_parametric_pair(); + test_polymorphic_datatype_api(); }