diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 104647284..a23656d8b 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -5050,6 +5050,8 @@ def EmptySet(s): K(Int, False) """ ctx = s.ctx + if is_finite_set_sort(s): + return FiniteSetEmpty(s) return ArrayRef(Z3_mk_empty_set(ctx.ref(), s.ast), ctx) @@ -5070,6 +5072,9 @@ def SetUnion(*args): union(a, b) """ args = _get_args(args) + if len(args) > 0 and is_finite_set(args[0]): + from functools import reduce + return reduce(FiniteSetUnion, args) ctx = _ctx_from_ast_arg_list(args) _args, sz = _to_ast_array(args) return ArrayRef(Z3_mk_set_union(ctx.ref(), sz, _args), ctx) @@ -5084,6 +5089,9 @@ def SetIntersect(*args): """ args = _get_args(args) ctx = _ctx_from_ast_arg_list(args) + if len(args) > 0 and is_finite_set(args[0]): + from functools import reduce + return reduce(FiniteSetIntersect, args) _args, sz = _to_ast_array(args) return ArrayRef(Z3_mk_set_intersect(ctx.ref(), sz, _args), ctx) @@ -5094,8 +5102,10 @@ def SetAdd(s, e): >>> SetAdd(a, 1) Store(a, 1, True) """ - ctx = _ctx_from_ast_arg_list([s, e]) + ctx = _ctx_from_ast_arg_list([s, e]) e = _py2expr(e, ctx) + if is_finite_set(s): + return FiniteSetSingleton(e) | s return ArrayRef(Z3_mk_set_add(ctx.ref(), s.as_ast(), e.as_ast()), ctx) @@ -5107,6 +5117,8 @@ def SetDel(s, e): """ ctx = _ctx_from_ast_arg_list([s, e]) e = _py2expr(e, ctx) + if is_finite_set(s): + return s - FiniteSetSingleton(e) return ArrayRef(Z3_mk_set_del(ctx.ref(), s.as_ast(), e.as_ast()), ctx) @@ -5128,6 +5140,8 @@ def SetDifference(a, b): setminus(a, b) """ ctx = _ctx_from_ast_arg_list([a, b]) + if is_finite_set(a): + return FiniteSetDifference(a, b) return ArrayRef(Z3_mk_set_difference(ctx.ref(), a.as_ast(), b.as_ast()), ctx) @@ -5139,6 +5153,8 @@ def IsMember(e, s): """ ctx = _ctx_from_ast_arg_list([s, e]) e = _py2expr(e, ctx) + if is_finite_set(s): + return FiniteSetIsMember(e, s) return BoolRef(Z3_mk_set_member(ctx.ref(), e.as_ast(), s.as_ast()), ctx) @@ -5150,6 +5166,8 @@ def IsSubset(a, b): subset(a, b) """ ctx = _ctx_from_ast_arg_list([a, b]) + if is_finite_set(a): + return FiniteSetIsSubset(a, b) return BoolRef(Z3_mk_set_subset(ctx.ref(), a.as_ast(), b.as_ast()), ctx) @@ -5178,7 +5196,7 @@ class FiniteSetSortRef(SortRef): elem_sort = self.element_sort() result = FiniteSetEmpty(self) for e in val: - result = FiniteSetUnion(result, FiniteSetSingleton(_py2expr(e, self.ctx, elem_sort))) + result = FiniteSetUnion(result, Singleton(_py2expr(e, self.ctx, elem_sort))) return result _z3_assert(False, "Cannot cast to finite set sort") @@ -5260,9 +5278,9 @@ def FiniteSetEmpty(set_sort): return FiniteSetRef(Z3_mk_finite_set_empty(ctx.ref(), set_sort.ast), ctx) -def FiniteSetSingleton(elem): +def Singleton(elem): """Create a singleton finite set containing elem. - >>> FiniteSetSingleton(IntVal(1)) + >>> Singleton(IntVal(1)) set.singleton(1) """ ctx = elem.ctx @@ -5311,6 +5329,8 @@ def FiniteSetMember(elem, set): ctx = _ctx_from_ast_arg_list([elem, set]) return BoolRef(Z3_mk_finite_set_member(ctx.ref(), elem.as_ast(), set.as_ast()), ctx) +def In(elem, set): + return FiniteSetMember(elem, set) def FiniteSetSize(set): """Get the size (cardinality) of a finite set. @@ -11726,6 +11746,14 @@ def Union(*args): sz = len(args) if z3_debug(): _z3_assert(sz > 0, "At least one argument expected.") + arg0 = args[0] + if is_finite_set(arg0): + for a in args[1:]: + if not is_finite_set(a): + raise Z3Exception("All arguments must be regular expressions or finite sets.") + arg0 = arg0 | a + return arg0 + if z3_debug(): _z3_assert(all([is_re(a) for a in args]), "All arguments must be regular expressions.") if sz == 1: return args[0] @@ -11744,6 +11772,14 @@ def Intersect(*args): sz = len(args) if z3_debug(): _z3_assert(sz > 0, "At least one argument expected.") + arg0 = args[0] + if is_finite_set(arg0): + for a in args[1:]: + if not is_finite_set(a): + raise Z3Exception("All arguments must be regular expressions or finite sets.") + arg0 = arg0 & a + return arg0 + if z3_debug(): _z3_assert(all([is_re(a) for a in args]), "All arguments must be regular expressions.") if sz == 1: return args[0]