diff --git a/src/api/api_array.cpp b/src/api/api_array.cpp index c1ea4729e..7aa3a87bf 100644 --- a/src/api/api_array.cpp +++ b/src/api/api_array.cpp @@ -309,6 +309,22 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } + Z3_sort Z3_API Z3_get_array_sort_domain_n(Z3_context c, Z3_sort t, unsigned idx) { + Z3_TRY; + LOG_Z3_get_array_sort_domain_n(c, t, idx); + RESET_ERROR_CODE(); + CHECK_VALID_AST(t, nullptr); + if (to_sort(t)->get_family_id() == mk_c(c)->get_array_fid() && + to_sort(t)->get_decl_kind() == ARRAY_SORT && + get_array_arity(to_sort(t)) > idx) { + Z3_sort r = reinterpret_cast(get_array_domain(to_sort(t), idx)); + RETURN_Z3(r); + } + SET_ERROR_CODE(Z3_INVALID_ARG, nullptr); + RETURN_Z3(nullptr); + Z3_CATCH_RETURN(nullptr); + } + Z3_sort Z3_API Z3_get_array_sort_range(Z3_context c, Z3_sort t) { Z3_TRY; LOG_Z3_get_array_sort_range(c, t); diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 5559f708e..2f7d33f4e 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -4495,6 +4495,11 @@ class ArraySortRef(SortRef): """ return _to_sort_ref(Z3_get_array_sort_domain(self.ctx_ref(), self.ast), self.ctx) + def domain_n(self, i): + """Return the domain of the array sort `self`. + """ + return _to_sort_ref(Z3_get_array_sort_domain_n(self.ctx_ref(), self.ast, i), self.ctx) + def range(self): """Return the range of the array sort `self`. @@ -4526,6 +4531,10 @@ class ArrayRef(ExprRef): """ return self.sort().domain() + def domain_n(self, i): + """Shorthand for self.sort().domain_n(i)`.""" + return self.sort().domain_n(i) + def range(self): """Shorthand for `self.sort().range()`. @@ -4553,7 +4562,7 @@ class ArrayRef(ExprRef): def _array_select(ar, arg): if isinstance(arg, tuple): - args = [ar.domain().cast(a) for a in arg] + args = [ar.domain_n(i).cast(arg[i]) for i in range(len(arg))] _args, sz = _to_ast_array(args) return _to_expr_ref(Z3_mk_select_n(ar.ctx_ref(), ar.as_ast(), sz, _args), ar.ctx) arg = ar.domain().cast(arg) @@ -4686,7 +4695,7 @@ def ArraySort(*sig): return ArraySortRef(Z3_mk_array_sort_n(ctx.ref(), arity, dom, r.ast), ctx) -def Array(name, dom, rng): +def Array(name, *sorts): """Return an array constant named `name` with the given domain and range sorts. >>> a = Array('a', IntSort(), IntSort()) @@ -4695,12 +4704,12 @@ def Array(name, dom, rng): >>> a[0] a[0] """ - s = ArraySort(dom, rng) + s = ArraySort(sorts) ctx = s.ctx return ArrayRef(Z3_mk_const(ctx.ref(), to_symbol(name, ctx), s.ast), ctx) -def Update(a, i, v): +def Update(a, *args): """Return a Z3 store array expression. >>> a = Array('a', IntSort(), IntSort()) @@ -4716,10 +4725,20 @@ def Update(a, i, v): """ if z3_debug(): _z3_assert(is_array_sort(a), "First argument must be a Z3 array expression") - i = a.sort().domain().cast(i) - v = a.sort().range().cast(v) + args = _get_args(args) ctx = a.ctx - return _to_expr_ref(Z3_mk_store(ctx.ref(), a.as_ast(), i.as_ast(), v.as_ast()), ctx) + if len(args) <= 1: + raise Z3Exception("array update requires index and value arguments") + if len(args) == 2: + i = args[0] + v = args[1] + i = a.sort().domain().cast(i) + v = a.sort().range().cast(v) + return _to_expr_ref(Z3_mk_store(ctx.ref(), a.as_ast(), i.as_ast(), v.as_ast()), ctx) + v = a.sort().range().cast(args[-1]) + idxs = [a.sort().domain_n(i).cast(args[i]) for i in range(len(args)-1)] + _args, sz = _to_ast_array(idxs) + return _to_expr_ref(Z3_mk_store_n(ctx.ref(), a.as_ast(), sz, _args, v.as_ast()), ctx) def Default(a): @@ -4733,7 +4752,7 @@ def Default(a): return a.default() -def Store(a, i, v): +def Store(a, *args): """Return a Z3 store array expression. >>> a = Array('a', IntSort(), IntSort()) @@ -4747,10 +4766,10 @@ def Store(a, i, v): >>> prove(Implies(i != j, s[j] == a[j])) proved """ - return Update(a, i, v) + return Update(a, args) -def Select(a, i): +def Select(a, *args): """Return a Z3 select array expression. >>> a = Array('a', IntSort(), IntSort()) @@ -4760,9 +4779,10 @@ def Select(a, i): >>> eq(Select(a, i), a[i]) True """ + args = _get_args(args) if z3_debug(): _z3_assert(is_array_sort(a), "First argument must be a Z3 array expression") - return a[i] + return a[args] def Map(f, *args): diff --git a/src/api/z3_api.h b/src/api/z3_api.h index c9bf68a67..16d3a292c 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -4358,11 +4358,27 @@ extern "C" { \sa Z3_mk_array_sort \sa Z3_get_sort_kind + \sa Z3_get_array_sort_domain_n def_API('Z3_get_array_sort_domain', SORT, (_in(CONTEXT), _in(SORT))) */ Z3_sort Z3_API Z3_get_array_sort_domain(Z3_context c, Z3_sort t); + + /** + \brief Return the i'th domain sort of an n-dimensional array. + + \pre Z3_get_sort_kind(c, t) == Z3_ARRAY_SORT + + \sa Z3_mk_array_sort + \sa Z3_get_sort_kind + \sa Z3_get_array_sort_domain + + def_API('Z3_get_array_sort_domain_n', SORT, (_in(CONTEXT), _in(SORT), _in(UINT))) + + */ + Z3_sort Z3_API Z3_get_array_sort_domain_n(Z3_context c, Z3_sort t, unsigned idx); + /** \brief Return the range of the given array sort.