diff --git a/src/api/dotnet/Context.cs b/src/api/dotnet/Context.cs index 9293b1a31..a3cb41642 100644 --- a/src/api/dotnet/Context.cs +++ b/src/api/dotnet/Context.cs @@ -474,6 +474,36 @@ namespace Microsoft.Z3 return new DatatypeSort(this, symbol, constructors); } + /// + /// Create a forward reference to a datatype sort. + /// This is useful for creating recursive datatypes or parametric datatypes. + /// + /// name of the datatype sort + /// optional array of sort parameters for parametric datatypes + public DatatypeSort MkDatatypeSortRef(Symbol name, Sort[] params = null) + { + Debug.Assert(name != null); + CheckContextMatch(name); + if (params != null) + CheckContextMatch(params); + + var numParams = (params == null) ? 0 : (uint)params.Length; + var paramsNative = (params == null) ? null : AST.ArrayToNative(params); + return new DatatypeSort(this, Native.Z3_mk_datatype_sort(nCtx, name.NativeObject, numParams, paramsNative)); + } + + /// + /// Create a forward reference to a datatype sort. + /// This is useful for creating recursive datatypes or parametric datatypes. + /// + /// name of the datatype sort + /// optional array of sort parameters for parametric datatypes + public DatatypeSort MkDatatypeSortRef(string name, Sort[] params = null) + { + using var symbol = MkSymbol(name); + return MkDatatypeSortRef(symbol, params); + } + /// /// Create mutually recursive datatypes. /// diff --git a/src/api/java/Context.java b/src/api/java/Context.java index 2350b52ae..691ecd737 100644 --- a/src/api/java/Context.java +++ b/src/api/java/Context.java @@ -388,6 +388,54 @@ public class Context implements AutoCloseable { return new DatatypeSort<>(this, mkSymbol(name), constructors); } + /** + * Create a forward reference to a datatype sort. + * This is useful for creating recursive datatypes or parametric datatypes. + * @param name name of the datatype sort + * @param params optional array of sort parameters for parametric datatypes + **/ + public DatatypeSort mkDatatypeSortRef(Symbol name, Sort[] params) + { + checkContextMatch(name); + if (params != null) + checkContextMatch(params); + + int numParams = (params == null) ? 0 : params.length; + long[] paramsNative = (params == null) ? new long[0] : AST.arrayToNative(params); + return new DatatypeSort<>(this, Native.mkDatatypeSort(nCtx(), name.getNativeObject(), numParams, paramsNative)); + } + + /** + * Create a forward reference to a datatype sort (non-parametric). + * This is useful for creating recursive datatypes. + * @param name name of the datatype sort + **/ + public DatatypeSort mkDatatypeSortRef(Symbol name) + { + return mkDatatypeSortRef(name, null); + } + + /** + * Create a forward reference to a datatype sort. + * This is useful for creating recursive datatypes or parametric datatypes. + * @param name name of the datatype sort + * @param params optional array of sort parameters for parametric datatypes + **/ + public DatatypeSort mkDatatypeSortRef(String name, Sort[] params) + { + return mkDatatypeSortRef(mkSymbol(name), params); + } + + /** + * Create a forward reference to a datatype sort (non-parametric). + * This is useful for creating recursive datatypes. + * @param name name of the datatype sort + **/ + public DatatypeSort mkDatatypeSortRef(String name) + { + return mkDatatypeSortRef(name, null); + } + /** * Create mutually recursive datatypes. * @param names names of datatype sorts diff --git a/src/api/ml/z3.ml b/src/api/ml/z3.ml index 4d5238957..57db13f9d 100644 --- a/src/api/ml/z3.ml +++ b/src/api/ml/z3.ml @@ -909,11 +909,18 @@ struct mk_sort ctx (Symbol.mk_string ctx name) constructors let mk_sort_ref (ctx: context) (name:Symbol.symbol) = - Z3native.mk_datatype_sort ctx name + Z3native.mk_datatype_sort ctx name 0 [||] let mk_sort_ref_s (ctx: context) (name: string) = mk_sort_ref ctx (Symbol.mk_string ctx name) + let mk_sort_ref_p (ctx: context) (name:Symbol.symbol) (params:Sort.sort list) = + let param_array = Array.of_list params in + Z3native.mk_datatype_sort ctx name (List.length params) param_array + + let mk_sort_ref_ps (ctx: context) (name: string) (params:Sort.sort list) = + mk_sort_ref_p ctx (Symbol.mk_string ctx name) params + let mk_sorts (ctx:context) (names:Symbol.symbol list) (c:Constructor.constructor list list) = let n = List.length names in let f e = ConstructorList.create ctx e in diff --git a/src/api/ml/z3.mli b/src/api/ml/z3.mli index 7afc01918..6764b0e2d 100644 --- a/src/api/ml/z3.mli +++ b/src/api/ml/z3.mli @@ -1087,6 +1087,12 @@ sig (* [mk_sort_ref_s ctx s] is [mk_sort_ref ctx (Symbol.mk_string ctx s)] *) val mk_sort_ref_s : context -> string -> Sort.sort + (** Create a forward reference to a parametric datatype sort. *) + val mk_sort_ref_p : context -> Symbol.symbol -> Sort.sort list -> Sort.sort + + (** Create a forward reference to a parametric datatype sort. *) + val mk_sort_ref_ps : context -> string -> Sort.sort list -> Sort.sort + (** Create a new datatype sort. *) val mk_sort : context -> Symbol.symbol -> Constructor.constructor list -> Sort.sort diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 051265a78..128726dae 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -5474,10 +5474,30 @@ class DatatypeRef(ExprRef): """Return the datatype sort of the datatype expression `self`.""" return DatatypeSortRef(Z3_get_sort(self.ctx_ref(), self.as_ast()), self.ctx) -def DatatypeSort(name, ctx = None): - """Create a reference to a sort that was declared, or will be declared, as a recursive datatype""" +def DatatypeSort(name, params=None, ctx=None): + """Create a reference to a sort that was declared, or will be declared, as a recursive datatype. + + Args: + name: name of the datatype sort + params: optional list/tuple of sort parameters for parametric datatypes + ctx: Z3 context (optional) + + Example: + >>> # Non-parametric datatype + >>> TreeRef = DatatypeSort('Tree') + >>> # Parametric datatype with one parameter + >>> ListIntRef = DatatypeSort('List', [IntSort()]) + >>> # Parametric datatype with multiple parameters + >>> PairRef = DatatypeSort('Pair', [IntSort(), BoolSort()]) + """ ctx = _get_ctx(ctx) - return DatatypeSortRef(Z3_mk_datatype_sort(ctx.ref(), to_symbol(name, ctx)), ctx) + if params is None or len(params) == 0: + return DatatypeSortRef(Z3_mk_datatype_sort(ctx.ref(), to_symbol(name, ctx), 0, (Sort * 0)()), ctx) + else: + _params = (Sort * len(params))() + for i in range(len(params)): + _params[i] = params[i].ast + return DatatypeSortRef(Z3_mk_datatype_sort(ctx.ref(), to_symbol(name, ctx), len(params), _params), ctx) def TupleSort(name, sorts, ctx=None): """Create a named tuple sort base on a set of underlying sorts