3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-10-29 18:52:29 +00:00

add overloads for finite sets

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2025-10-23 17:29:11 +02:00
parent 541a554ecd
commit 69e0793f6c

View file

@ -5050,6 +5050,8 @@ def EmptySet(s):
K(Int, False)
"""
ctx = s.ctx
if is_finite_set_sort(s):
return FiniteSetEmpty(s)
return ArrayRef(Z3_mk_empty_set(ctx.ref(), s.ast), ctx)
@ -5070,6 +5072,9 @@ def SetUnion(*args):
union(a, b)
"""
args = _get_args(args)
if len(args) > 0 and is_finite_set(args[0]):
from functools import reduce
return reduce(FiniteSetUnion, args)
ctx = _ctx_from_ast_arg_list(args)
_args, sz = _to_ast_array(args)
return ArrayRef(Z3_mk_set_union(ctx.ref(), sz, _args), ctx)
@ -5084,6 +5089,9 @@ def SetIntersect(*args):
"""
args = _get_args(args)
ctx = _ctx_from_ast_arg_list(args)
if len(args) > 0 and is_finite_set(args[0]):
from functools import reduce
return reduce(FiniteSetIntersect, args)
_args, sz = _to_ast_array(args)
return ArrayRef(Z3_mk_set_intersect(ctx.ref(), sz, _args), ctx)
@ -5094,8 +5102,10 @@ def SetAdd(s, e):
>>> SetAdd(a, 1)
Store(a, 1, True)
"""
ctx = _ctx_from_ast_arg_list([s, e])
ctx = _ctx_from_ast_arg_list([s, e])
e = _py2expr(e, ctx)
if is_finite_set(s):
return FiniteSetSingleton(e) | s
return ArrayRef(Z3_mk_set_add(ctx.ref(), s.as_ast(), e.as_ast()), ctx)
@ -5107,6 +5117,8 @@ def SetDel(s, e):
"""
ctx = _ctx_from_ast_arg_list([s, e])
e = _py2expr(e, ctx)
if is_finite_set(s):
return s - FiniteSetSingleton(e)
return ArrayRef(Z3_mk_set_del(ctx.ref(), s.as_ast(), e.as_ast()), ctx)
@ -5128,6 +5140,8 @@ def SetDifference(a, b):
setminus(a, b)
"""
ctx = _ctx_from_ast_arg_list([a, b])
if is_finite_set(a):
return FiniteSetDifference(a, b)
return ArrayRef(Z3_mk_set_difference(ctx.ref(), a.as_ast(), b.as_ast()), ctx)
@ -5139,6 +5153,8 @@ def IsMember(e, s):
"""
ctx = _ctx_from_ast_arg_list([s, e])
e = _py2expr(e, ctx)
if is_finite_set(s):
return FiniteSetIsMember(e, s)
return BoolRef(Z3_mk_set_member(ctx.ref(), e.as_ast(), s.as_ast()), ctx)
@ -5150,6 +5166,8 @@ def IsSubset(a, b):
subset(a, b)
"""
ctx = _ctx_from_ast_arg_list([a, b])
if is_finite_set(a):
return FiniteSetIsSubset(a, b)
return BoolRef(Z3_mk_set_subset(ctx.ref(), a.as_ast(), b.as_ast()), ctx)
@ -5178,7 +5196,7 @@ class FiniteSetSortRef(SortRef):
elem_sort = self.element_sort()
result = FiniteSetEmpty(self)
for e in val:
result = FiniteSetUnion(result, FiniteSetSingleton(_py2expr(e, self.ctx, elem_sort)))
result = FiniteSetUnion(result, Singleton(_py2expr(e, self.ctx, elem_sort)))
return result
_z3_assert(False, "Cannot cast to finite set sort")
@ -5260,9 +5278,9 @@ def FiniteSetEmpty(set_sort):
return FiniteSetRef(Z3_mk_finite_set_empty(ctx.ref(), set_sort.ast), ctx)
def FiniteSetSingleton(elem):
def Singleton(elem):
"""Create a singleton finite set containing elem.
>>> FiniteSetSingleton(IntVal(1))
>>> Singleton(IntVal(1))
set.singleton(1)
"""
ctx = elem.ctx
@ -5311,6 +5329,8 @@ def FiniteSetMember(elem, set):
ctx = _ctx_from_ast_arg_list([elem, set])
return BoolRef(Z3_mk_finite_set_member(ctx.ref(), elem.as_ast(), set.as_ast()), ctx)
def In(elem, set):
return FiniteSetMember(elem, set)
def FiniteSetSize(set):
"""Get the size (cardinality) of a finite set.
@ -11726,6 +11746,14 @@ def Union(*args):
sz = len(args)
if z3_debug():
_z3_assert(sz > 0, "At least one argument expected.")
arg0 = args[0]
if is_finite_set(arg0):
for a in args[1:]:
if not is_finite_set(a):
raise Z3Exception("All arguments must be regular expressions or finite sets.")
arg0 = arg0 | a
return arg0
if z3_debug():
_z3_assert(all([is_re(a) for a in args]), "All arguments must be regular expressions.")
if sz == 1:
return args[0]
@ -11744,6 +11772,14 @@ def Intersect(*args):
sz = len(args)
if z3_debug():
_z3_assert(sz > 0, "At least one argument expected.")
arg0 = args[0]
if is_finite_set(arg0):
for a in args[1:]:
if not is_finite_set(a):
raise Z3Exception("All arguments must be regular expressions or finite sets.")
arg0 = arg0 & a
return arg0
if z3_debug():
_z3_assert(all([is_re(a) for a in args]), "All arguments must be regular expressions.")
if sz == 1:
return args[0]