diff --git a/src/api/api_datatype.cpp b/src/api/api_datatype.cpp index c2f4b8b3e..413cf05e6 100644 --- a/src/api/api_datatype.cpp +++ b/src/api/api_datatype.cpp @@ -312,6 +312,22 @@ extern "C" { Z3_constructor constructors[]) { datatype_util& dt_util = mk_c(c)->dtutil(); ast_manager& m = mk_c(c)->m(); + + // Collect type variables from field sorts in order of first appearance + sort_ref_vector params(m); + obj_hashtable seen; + for (unsigned i = 0; i < num_constructors; ++i) { + constructor* cn = reinterpret_cast(constructors[i]); + for (unsigned j = 0; j < cn->m_sorts.size(); ++j) { + if (cn->m_sorts[j].get() && m.is_type_var(cn->m_sorts[j].get())) { + if (!seen.contains(cn->m_sorts[j].get())) { + params.push_back(cn->m_sorts[j].get()); + seen.insert(cn->m_sorts[j].get()); + } + } + } + } + ptr_vector constrs; for (unsigned i = 0; i < num_constructors; ++i) { constructor* cn = reinterpret_cast(constructors[i]); @@ -326,7 +342,7 @@ extern "C" { } constrs.push_back(mk_constructor_decl(cn->m_name, cn->m_tester, acc.size(), acc.data())); } - return mk_datatype_decl(dt_util, to_symbol(name), 0, nullptr, num_constructors, constrs.data()); + return mk_datatype_decl(dt_util, to_symbol(name), params.size(), params.data(), num_constructors, constrs.data()); } Z3_sort Z3_API Z3_mk_datatype(Z3_context c, diff --git a/src/ast/datatype_decl_plugin.cpp b/src/ast/datatype_decl_plugin.cpp index f91afc9ac..5bb918c5f 100644 --- a/src/ast/datatype_decl_plugin.cpp +++ b/src/ast/datatype_decl_plugin.cpp @@ -300,6 +300,12 @@ namespace datatype { TRACE(datatype, tout << "expected sort parameter at position " << i << " got: " << s << "\n";); throw invalid_datatype(); } + // Allow type variables as parameters for polymorphic datatypes + sort* param_sort = to_sort(s.get_ast()); + if (!m_manager->is_type_var(param_sort) && param_sort->get_family_id() == null_family_id) { + // Type variables and concrete sorts are allowed, but not other uninterpreted sorts + // Actually, all sorts should be allowed including uninterpreted ones + } } sort* s = m_manager->mk_sort(name.get_symbol(), diff --git a/src/test/parametric_datatype.cpp b/src/test/parametric_datatype.cpp index 9b3d4704f..2c4f18be6 100644 --- a/src/test/parametric_datatype.cpp +++ b/src/test/parametric_datatype.cpp @@ -22,14 +22,9 @@ Author: /** * Test polymorphic type variables with algebraic datatype definitions. * - * This test demonstrates the new Z3_mk_datatype_sort API that accepts type parameters. - * It creates two concrete instantiations of a generic pair concept: - * - pair_int_real with fields (first:Int, second:Real) - * - pair_real_int with fields (first:Real, second:Int) - * Then creates constants p1 and p2 of these types and verifies that: - * - (first-ir p1) has type Int - * - (second-ri p2) has type Int - * - The equality (= (first-ir p1) (second-ri p2)) is well-typed + * This test uses Z3_mk_type_variable to create polymorphic type parameters alpha and beta, + * defines a generic pair datatype, then instantiates it with concrete types using + * Z3_mk_datatype_sort with parameters. */ static void test_parametric_pair() { std::cout << "test_parametric_pair\n"; @@ -38,78 +33,78 @@ static void test_parametric_pair() { Z3_context ctx = Z3_mk_context(cfg); Z3_del_config(cfg); + // Create type variables alpha and beta for polymorphic datatype + Z3_symbol alpha_sym = Z3_mk_string_symbol(ctx, "alpha"); + Z3_symbol beta_sym = Z3_mk_string_symbol(ctx, "beta"); + Z3_sort alpha = Z3_mk_type_variable(ctx, alpha_sym); + Z3_sort beta = Z3_mk_type_variable(ctx, beta_sym); + + // Define parametric pair datatype with constructor mk-pair(first: alpha, second: beta) + Z3_symbol pair_name = Z3_mk_string_symbol(ctx, "pair"); + Z3_symbol mk_pair_name = Z3_mk_string_symbol(ctx, "mk-pair"); + Z3_symbol is_pair_name = Z3_mk_string_symbol(ctx, "is-pair"); + Z3_symbol first_name = Z3_mk_string_symbol(ctx, "first"); + Z3_symbol second_name = Z3_mk_string_symbol(ctx, "second"); + + Z3_symbol field_names[2] = {first_name, second_name}; + Z3_sort field_sorts[2] = {alpha, beta}; // Use type variables + unsigned sort_refs[2] = {0, 0}; // Not recursive references + + Z3_constructor mk_pair_con = Z3_mk_constructor( + ctx, mk_pair_name, is_pair_name, 2, field_names, field_sorts, sort_refs + ); + + // Create the parametric datatype + Z3_constructor constructors[1] = {mk_pair_con}; + Z3_sort pair = Z3_mk_datatype(ctx, pair_name, 1, constructors); + + Z3_del_constructor(ctx, mk_pair_con); + + std::cout << "Created parametric pair datatype\n"; + std::cout << "pair sort: " << Z3_sort_to_string(ctx, pair) << "\n"; + + // Now instantiate the datatype with concrete types Z3_sort int_sort = Z3_mk_int_sort(ctx); Z3_sort real_sort = Z3_mk_real_sort(ctx); - // The approach: Create two separate datatypes - pair_int_real and pair_real_int - // Each is a concrete instantiation of the parametric pair concept + // Create (pair Int Real) + Z3_sort params_int_real[2] = {int_sort, real_sort}; + Z3_sort pair_int_real = Z3_mk_datatype_sort(ctx, pair_name, 2, params_int_real); - // First datatype: pair_int_real with fields (first:Int, second:Real) - Z3_symbol pair_int_real_name = Z3_mk_string_symbol(ctx, "pair_int_real"); - Z3_symbol mk_pair_ir_name = Z3_mk_string_symbol(ctx, "mk-pair-ir"); - Z3_symbol is_pair_ir_name = Z3_mk_string_symbol(ctx, "is-pair-ir"); - Z3_symbol first_ir_name = Z3_mk_string_symbol(ctx, "first-ir"); - Z3_symbol second_ir_name = Z3_mk_string_symbol(ctx, "second-ir"); + // Create (pair Real Int) + Z3_sort params_real_int[2] = {real_sort, int_sort}; + Z3_sort pair_real_int = Z3_mk_datatype_sort(ctx, pair_name, 2, params_real_int); - Z3_symbol field_names_ir[2] = {first_ir_name, second_ir_name}; - Z3_sort field_sorts_ir[2] = {int_sort, real_sort}; - unsigned sort_refs_ir[2] = {0, 0}; + std::cout << "Instantiated pair with Int and Real\n"; + std::cout << "pair_int_real: " << Z3_sort_to_string(ctx, pair_int_real) << "\n"; + std::cout << "pair_real_int: " << Z3_sort_to_string(ctx, pair_real_int) << "\n"; - Z3_constructor mk_pair_ir_con = Z3_mk_constructor( - ctx, mk_pair_ir_name, is_pair_ir_name, 2, field_names_ir, field_sorts_ir, sort_refs_ir - ); + // Get constructors and accessors from the instantiated datatypes + Z3_func_decl mk_pair_int_real = Z3_get_datatype_sort_constructor(ctx, pair_int_real, 0); + Z3_func_decl first_int_real = Z3_get_datatype_sort_constructor_accessor(ctx, pair_int_real, 0, 0); + Z3_func_decl second_int_real = Z3_get_datatype_sort_constructor_accessor(ctx, pair_int_real, 0, 1); - Z3_constructor constructors_ir[1] = {mk_pair_ir_con}; - Z3_sort pair_int_real = Z3_mk_datatype(ctx, pair_int_real_name, 1, constructors_ir); + Z3_func_decl mk_pair_real_int = Z3_get_datatype_sort_constructor(ctx, pair_real_int, 0); + Z3_func_decl first_real_int = Z3_get_datatype_sort_constructor_accessor(ctx, pair_real_int, 0, 0); + Z3_func_decl second_real_int = Z3_get_datatype_sort_constructor_accessor(ctx, pair_real_int, 0, 1); - Z3_func_decl mk_pair_ir_decl, is_pair_ir_decl; - Z3_func_decl accessors_ir[2]; - Z3_query_constructor(ctx, mk_pair_ir_con, 2, &mk_pair_ir_decl, &is_pair_ir_decl, accessors_ir); - Z3_func_decl first_ir_decl = accessors_ir[0]; - Z3_func_decl second_ir_decl = accessors_ir[1]; - Z3_del_constructor(ctx, mk_pair_ir_con); + std::cout << "Got constructors and accessors from instantiated datatypes\n"; - // Second datatype: pair_real_int with fields (first:Real, second:Int) - Z3_symbol pair_real_int_name = Z3_mk_string_symbol(ctx, "pair_real_int"); - Z3_symbol mk_pair_ri_name = Z3_mk_string_symbol(ctx, "mk-pair-ri"); - Z3_symbol is_pair_ri_name = Z3_mk_string_symbol(ctx, "is-pair-ri"); - Z3_symbol first_ri_name = Z3_mk_string_symbol(ctx, "first-ri"); - Z3_symbol second_ri_name = Z3_mk_string_symbol(ctx, "second-ri"); - - Z3_symbol field_names_ri[2] = {first_ri_name, second_ri_name}; - Z3_sort field_sorts_ri[2] = {real_sort, int_sort}; - unsigned sort_refs_ri[2] = {0, 0}; - - Z3_constructor mk_pair_ri_con = Z3_mk_constructor( - ctx, mk_pair_ri_name, is_pair_ri_name, 2, field_names_ri, field_sorts_ri, sort_refs_ri - ); - - Z3_constructor constructors_ri[1] = {mk_pair_ri_con}; - Z3_sort pair_real_int = Z3_mk_datatype(ctx, pair_real_int_name, 1, constructors_ri); - - Z3_func_decl mk_pair_ri_decl, is_pair_ri_decl; - Z3_func_decl accessors_ri[2]; - Z3_query_constructor(ctx, mk_pair_ri_con, 2, &mk_pair_ri_decl, &is_pair_ri_decl, accessors_ri); - Z3_func_decl first_ri_decl = accessors_ri[0]; - Z3_func_decl second_ri_decl = accessors_ri[1]; - Z3_del_constructor(ctx, mk_pair_ri_con); - - // Create constants p1 : pair_int_real and p2 : pair_real_int + // Create constants p1 : (pair Int Real) and p2 : (pair Real Int) Z3_symbol p1_sym = Z3_mk_string_symbol(ctx, "p1"); Z3_symbol p2_sym = Z3_mk_string_symbol(ctx, "p2"); Z3_ast p1 = Z3_mk_const(ctx, p1_sym, pair_int_real); Z3_ast p2 = Z3_mk_const(ctx, p2_sym, pair_real_int); - // Create (first-ir p1) - should be Int - Z3_ast first_p1 = Z3_mk_app(ctx, first_ir_decl, 1, &p1); + // Create (first p1) - should be Int + Z3_ast first_p1 = Z3_mk_app(ctx, first_int_real, 1, &p1); - // Create (second-ri p2) - should be Int - Z3_ast second_p2 = Z3_mk_app(ctx, second_ri_decl, 1, &p2); + // Create (second p2) - should be Int + Z3_ast second_p2 = Z3_mk_app(ctx, second_real_int, 1, &p2); - // Create the equality (= (first-ir p1) (second-ri p2)) + // Create the equality (= (first p1) (second p2)) Z3_ast eq = Z3_mk_eq(ctx, first_p1, second_p2); - // Print the term std::cout << "Created term: " << Z3_ast_to_string(ctx, eq) << "\n"; // Verify the term was created successfully @@ -119,8 +114,8 @@ static void test_parametric_pair() { Z3_sort first_p1_sort = Z3_get_sort(ctx, first_p1); Z3_sort second_p2_sort = Z3_get_sort(ctx, second_p2); - std::cout << "Sort of (first-ir p1): " << Z3_sort_to_string(ctx, first_p1_sort) << "\n"; - std::cout << "Sort of (second-ri p2): " << Z3_sort_to_string(ctx, second_p2_sort) << "\n"; + std::cout << "Sort of (first p1): " << Z3_sort_to_string(ctx, first_p1_sort) << "\n"; + std::cout << "Sort of (second p2): " << Z3_sort_to_string(ctx, second_p2_sort) << "\n"; // Both should be Int ENSURE(Z3_is_eq_sort(ctx, first_p1_sort, int_sort));