From df4065536fa51e81a664f506751a4ea547a20e6e Mon Sep 17 00:00:00 2001 From: "Christoph M. Wintersteiger" Date: Wed, 3 Jul 2019 12:32:28 +0100 Subject: [PATCH] Cleaned up FP predicates in the Python API. Fixes #2323. --- src/api/python/z3/z3.py | 119 +++++++++++++++++++--------------------- 1 file changed, 55 insertions(+), 64 deletions(-) diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 712b9aace..a33eedb36 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -61,7 +61,7 @@ def z3_debug(): if sys.version < '3': def _is_int(v): - return isinstance(v, (int, long)) + return isinstance(v, (int, long)) else: def _is_int(v): return isinstance(v, int) @@ -191,7 +191,7 @@ class Context: Z3_del_config(conf) def __del__(self): - Z3_del_context(self.ctx) + Z3_del_context(self.ctx) self.ctx = None self.eh = None @@ -322,7 +322,7 @@ class AstRef(Z3PPObject): def __nonzero__(self): return self.__bool__() - + def __bool__(self): if is_true(self): return True @@ -728,9 +728,9 @@ class FuncDeclRef(AstRef): if k == Z3_PARAMETER_INT: result[i] = Z3_get_decl_int_parameter(self.ctx_ref(), self.ast, i) elif k == Z3_PARAMETER_DOUBLE: - result[i] = Z3_get_decl_double_parameter(self.ctx_ref(), self.ast, i) + result[i] = Z3_get_decl_double_parameter(self.ctx_ref(), self.ast, i) elif k == Z3_PARAMETER_RATIONAL: - result[i] = Z3_get_decl_rational_parameter(self.ctx_ref(), self.ast, i) + result[i] = Z3_get_decl_rational_parameter(self.ctx_ref(), self.ast, i) elif k == Z3_PARAMETER_SYMBOL: result[i] = Z3_get_decl_symbol_parameter(self.ctx_ref(), self.ast, i) elif k == Z3_PARAMETER_SORT: @@ -1318,7 +1318,7 @@ def FreshConst(sort, prefix='c'): """Create a fresh constant of a specified sort""" ctx = _get_ctx(sort.ctx) return _to_expr_ref(Z3_mk_fresh_const(ctx.ref(), prefix, sort.ast), ctx) - + def Var(idx, s): """Create a Z3 free variable. Free variables are used to create quantified formulas. @@ -1399,14 +1399,14 @@ class BoolRef(ExprRef): def __rmul__(self, other): return self * other - + def __mul__(self, other): """Create the Z3 expression `self * other`. """ if other == 1: return self if other == 0: - return 0 + return 0 return If(self, other, 0) @@ -1996,7 +1996,7 @@ def _mk_quantifier(is_forall, vs, body, weight=1, qid="", skid="", patterns=[], _z3_assert(is_bool(body) or is_app(vs) or (len(vs) > 0 and is_app(vs[0])), "Z3 expression expected") _z3_assert(is_const(vs) or (len(vs) > 0 and all([ is_const(v) for v in vs])), "Invalid bounded variable(s)") _z3_assert(all([is_pattern(a) or is_expr(a) for a in patterns]), "Z3 patterns expected") - _z3_assert(all([is_expr(p) for p in no_patterns]), "no patterns are Z3 expressions") + _z3_assert(all([is_expr(p) for p in no_patterns]), "no patterns are Z3 expressions") if is_app(vs): ctx = vs.ctx vs = [vs] @@ -4163,7 +4163,7 @@ def BVSubNoOverflow(a, b): _check_bv_args(a, b) a, b = _coerce_exprs(a, b) return BoolRef(Z3_mk_bvsub_no_overflow(a.ctx_ref(), a.as_ast(), b.as_ast()), a.ctx) - + def BVSubNoUnderflow(a, b, signed): """A predicate the determines that bit-vector subtraction does not underflow""" @@ -4223,7 +4223,7 @@ class ArraySortRef(SortRef): >>> A.range() Bool """ - return _to_sort_ref(Z3_get_array_sort_range(self.ctx_ref(), self.ast), self.ctx) + return _to_sort_ref(Z3_get_array_sort_range(self.ctx_ref(), self.ast), self.ctx) class ArrayRef(ExprRef): """Array expressions. """ @@ -4378,7 +4378,7 @@ def ArraySort(*sig): return ArraySortRef(Z3_mk_array_sort(ctx.ref(), d.ast, r.ast), ctx) dom = (Sort * arity)() for i in range(arity): - dom[i] = sig[i].ast + dom[i] = sig[i].ast return ArraySortRef(Z3_mk_array_sort_n(ctx.ref(), arity, dom, r.ast), ctx) def Array(name, dom, rng): @@ -4456,7 +4456,7 @@ def Select(a, i): _z3_assert(is_array(a), "First argument must be a Z3 array expression") return a[i] - + def Map(f, *args): """Return a Z3 map array expression. @@ -4628,7 +4628,7 @@ def SetDifference(a, b): """ ctx = _ctx_from_ast_arg_list([a, b]) return ArrayRef(Z3_mk_set_difference(ctx.ref(), a.as_ast(), b.as_ast()), ctx) - + def IsMember(e, s): """ Check if e is a member of set s >>> a = Const('a', SetSort(IntSort())) @@ -4638,7 +4638,7 @@ def IsMember(e, s): ctx = _ctx_from_ast_arg_list([s,e]) e = _py2expr(e, ctx) return BoolRef(Z3_mk_set_member(ctx.ref(), e.as_ast(), s.as_ast()), ctx) - + def IsSubset(a, b): """ Check if a is a subset of b >>> a = Const('a', SetSort(IntSort())) @@ -4757,7 +4757,7 @@ class ScopedConstructorList: self.c = c self.ctx = ctx def __del__(self): - if self.ctx.ref() is not None: + if self.ctx.ref() is not None: Z3_del_constructor_list(self.ctx.ref(), self.c) def CreateDatatypes(*ds): @@ -4953,7 +4953,7 @@ def TupleSort(name, sorts, ctx = None): """Create a named tuple sort base on a set of underlying sorts Example: >>> pair, mk_pair, (first, second) = TupleSort("pair", [IntSort(), StringSort()]) - """ + """ tuple = Datatype(name, ctx) projects = [ ('project%d' % i, sorts[i]) for i in range(len(sorts)) ] tuple.declare(name, *projects) @@ -4970,8 +4970,8 @@ def DisjointSum(name, sorts, ctx=None): sum.declare("inject%d" % i, ("project%d" % i, sorts[i])) sum = sum.create() return sum, [(sum.constructor(i), sum.accessor(i, 0)) for i in range(len(sorts))] - - + + def EnumSort(name, values, ctx=None): """Return a new enumeration sort named `name` containing the given values. @@ -6679,7 +6679,7 @@ class Solver(Z3PPObject): return AstVector(Z3_solver_get_unsat_core(self.ctx.ref(), self.solver), self.ctx) def consequences(self, assumptions, variables): - """Determine fixed values for the variables based on the solver state and assumptions. + """Determine fixed values for the variables based on the solver state and assumptions. >>> s = Solver() >>> a, b, c, d = Bools('a b c d') >>> s.add(Implies(a,b), Implies(b, c)) @@ -6697,7 +6697,7 @@ class Solver(Z3PPObject): _vars = AstVector(None, self.ctx) for a in variables: _vars.push(a) - variables = _vars + variables = _vars _z3_assert(isinstance(assumptions, AstVector), "ast vector expected") _z3_assert(isinstance(variables, AstVector), "ast vector expected") consequences = AstVector(None, self.ctx) @@ -6713,7 +6713,7 @@ class Solver(Z3PPObject): def from_string(self, s): """Parse assertions from a string""" Z3_solver_from_string(self.ctx.ref(), self.solver, s) - + def cube(self, vars = None): """Get set of cubes The method takes an optional set of variables that restrict which @@ -6728,11 +6728,11 @@ class Solver(Z3PPObject): while True: lvl = self.backtrack_level self.backtrack_level = 4000000000 - r = AstVector(Z3_solver_cube(self.ctx.ref(), self.solver, self.cube_vs.vector, lvl), self.ctx) + r = AstVector(Z3_solver_cube(self.ctx.ref(), self.solver, self.cube_vs.vector, lvl), self.ctx) if (len(r) == 1 and is_false(r[0])): - return - yield r - if (len(r) == 0): + return + yield r + if (len(r) == 0): return def cube_vars(self): @@ -6745,7 +6745,7 @@ class Solver(Z3PPObject): def proof(self): """Return a proof for the last `check()`. Proof construction must be enabled.""" return _to_expr_ref(Z3_solver_get_proof(self.ctx.ref(), self.solver), self.ctx) - + def assertions(self): """Return an AST vector containing all added constraints. @@ -6771,7 +6771,7 @@ class Solver(Z3PPObject): return AstVector(Z3_solver_get_non_units(self.ctx.ref(), self.solver), self.ctx) def trail_levels(self): - """Return trail and decision levels of the solver state after a check() call. + """Return trail and decision levels of the solver state after a check() call. """ trail = self.trail() levels = (ctypes.c_uint * len(trail))() @@ -6779,10 +6779,10 @@ class Solver(Z3PPObject): return trail, levels def trail(self): - """Return trail of the solver state after a check() call. + """Return trail of the solver state after a check() call. """ return AstVector(Z3_solver_get_trail(self.ctx.ref(), self.solver), self.ctx) - + def statistics(self): """Return statistics for the last `check()`. @@ -7443,7 +7443,7 @@ class Optimize(Z3PPObject): for i in range(num): _assumptions[i] = assumptions[i].as_ast() return CheckSatResult(Z3_optimize_check(self.ctx.ref(), self.optimize, num, _assumptions)) - + def reason_unknown(self): """Return a string that describes why the last `check()` returned `unknown`.""" return Z3_optimize_get_reason_unknown(self.ctx.ref(), self.optimize) @@ -7476,7 +7476,7 @@ class Optimize(Z3PPObject): def upper_values(self, obj): if not isinstance(obj, OptimizeObjective): raise Z3Exception("Expecting objective handle returned by maximize/minimize") - return obj.upper_values() + return obj.upper_values() def from_file(self, filename): """Parse assertions and objectives from a file""" @@ -8268,7 +8268,7 @@ def Product(*args): return 1 ctx = _ctx_from_ast_arg_list(args) if ctx is None: - return _reduce(lambda a, b: a * b, args, 1) + return _reduce(lambda a, b: a * b, args, 1) args = _coerce_expr_list(args, ctx) if is_bv(args[0]): return _reduce(lambda a, b: a * b, args, 1) @@ -8880,7 +8880,7 @@ class FPRef(ExprRef): def __neg__(self): """Create the Z3 expression `-self`. - + >>> x = FP('x', Float32()) >>> -x -x @@ -8998,7 +8998,7 @@ def is_fprm_value(a): ### FP Numerals -class FPNumRef(FPRef): +class FPNumRef(FPRef): """The sign of the numeral. >>> x = FPVal(+1.0, FPSort(8, 24)) @@ -9015,7 +9015,7 @@ class FPNumRef(FPRef): return l.value != 0 """The sign of a floating-point numeral as a bit-vector expression. - + Remark: NaN's are invalid arguments. """ def sign_as_bv(self): @@ -9041,7 +9041,7 @@ class FPNumRef(FPRef): if not Z3_fpa_get_numeral_significand_uint64(self.ctx.ref(), self.as_ast(), ptr): raise Z3Exception("error retrieving the significand of a numeral.") return ptr[0] - + """The significand of the numeral as a bit-vector expression. Remark: NaN are invalid arguments. @@ -9390,18 +9390,11 @@ def _mk_fp_unary(f, rm, a, ctx): _z3_assert(is_fp(a), "Second argument must be a Z3 floating-point expression") return FPRef(f(ctx.ref(), rm.as_ast(), a.as_ast()), ctx) -def _mk_fp_unary_norm(f, a, ctx): - ctx = _get_ctx(ctx) - [a] = _coerce_fp_expr_list([a], ctx) - if z3_debug(): - _z3_assert(is_fp(a), "First argument must be a Z3 floating-point expression") - return BoolRef(f(ctx.ref(), a.as_ast()), ctx) - def _mk_fp_unary_pred(f, a, ctx): ctx = _get_ctx(ctx) [a] = _coerce_fp_expr_list([a], ctx) if z3_debug(): - _z3_assert(is_fp(a) or is_fp(b), "Second or third argument must be a Z3 floating-point expression") + _z3_assert(is_fp(a), "First argument must be a Z3 floating-point expression") return BoolRef(f(ctx.ref(), a.as_ast()), ctx) def _mk_fp_bin(f, rm, a, b, ctx): @@ -9557,7 +9550,7 @@ def fpIsNaN(a, ctx=None): >>> fpIsNaN(x) fpIsNaN(x) """ - return _mk_fp_unary_norm(Z3_mk_fpa_is_nan, a, ctx) + return _mk_fp_unary_pred(Z3_mk_fpa_is_nan, a, ctx) def fpIsInf(a, ctx=None): """Create a Z3 floating-point isInfinite expression. @@ -9567,33 +9560,32 @@ def fpIsInf(a, ctx=None): >>> fpIsInf(x) fpIsInf(x) """ - return _mk_fp_unary_norm(Z3_mk_fpa_is_infinite, a, ctx) + return _mk_fp_unary_pred(Z3_mk_fpa_is_infinite, a, ctx) def fpIsZero(a, ctx=None): """Create a Z3 floating-point isZero expression. """ - return _mk_fp_unary_norm(Z3_mk_fpa_is_zero, a, ctx) + return _mk_fp_unary_pred(Z3_mk_fpa_is_zero, a, ctx) def fpIsNormal(a, ctx=None): """Create a Z3 floating-point isNormal expression. """ - return _mk_fp_unary_norm(Z3_mk_fpa_is_normal, a, ctx) + return _mk_fp_unary_pred(Z3_mk_fpa_is_normal, a, ctx) def fpIsSubnormal(a, ctx=None): """Create a Z3 floating-point isSubnormal expression. """ - return _mk_fp_unary_norm(Z3_mk_fpa_is_subnormal, a, ctx) + return _mk_fp_unary_pred(Z3_mk_fpa_is_subnormal, a, ctx) def fpIsNegative(a, ctx=None): """Create a Z3 floating-point isNegative expression. """ - return _mk_fp_unary_norm(Z3_mk_fpa_is_negative, a, ctx) + return _mk_fp_unary_pred(Z3_mk_fpa_is_negative, a, ctx) def fpIsPositive(a, ctx=None): """Create a Z3 floating-point isPositive expression. """ - return _mk_fp_unary_norm(Z3_mk_fpa_is_positive, a, ctx) - return FPRef(Z3_mk_fpa_is_positive(a.ctx_ref(), a.as_ast()), a.ctx) + return _mk_fp_unary_pred(Z3_mk_fpa_is_positive, a, ctx) def _check_fp_args(a, b): if z3_debug(): @@ -9732,7 +9724,7 @@ def fpToFP(a1, a2=None, a3=None, ctx=None): raise Z3Exception("Unsupported combination of arguments for conversion to floating-point term.") def fpBVToFP(v, sort, ctx=None): - """Create a Z3 floating-point conversion expression that represents the + """Create a Z3 floating-point conversion expression that represents the conversion from a bit-vector term to a floating-point term. >>> x_bv = BitVecVal(0x3F800000, 32) @@ -9748,7 +9740,7 @@ def fpBVToFP(v, sort, ctx=None): return FPRef(Z3_mk_fpa_to_fp_bv(ctx.ref(), v.ast, sort.ast), ctx) def fpFPToFP(rm, v, sort, ctx=None): - """Create a Z3 floating-point conversion expression that represents the + """Create a Z3 floating-point conversion expression that represents the conversion from a floating-point term to a floating-point term of different precision. >>> x_sgl = FPVal(1.0, Float32()) @@ -9767,7 +9759,7 @@ def fpFPToFP(rm, v, sort, ctx=None): return FPRef(Z3_mk_fpa_to_fp_float(ctx.ref(), rm.ast, v.ast, sort.ast), ctx) def fpRealToFP(rm, v, sort, ctx=None): - """Create a Z3 floating-point conversion expression that represents the + """Create a Z3 floating-point conversion expression that represents the conversion from a real term to a floating-point term. >>> x_r = RealVal(1.5) @@ -9784,7 +9776,7 @@ def fpRealToFP(rm, v, sort, ctx=None): return FPRef(Z3_mk_fpa_to_fp_real(ctx.ref(), rm.ast, v.ast, sort.ast), ctx) def fpSignedToFP(rm, v, sort, ctx=None): - """Create a Z3 floating-point conversion expression that represents the + """Create a Z3 floating-point conversion expression that represents the conversion from a signed bit-vector term (encoding an integer) to a floating-point term. >>> x_signed = BitVecVal(-5, BitVecSort(32)) @@ -9801,7 +9793,7 @@ def fpSignedToFP(rm, v, sort, ctx=None): return FPRef(Z3_mk_fpa_to_fp_signed(ctx.ref(), rm.ast, v.ast, sort.ast), ctx) def fpUnsignedToFP(rm, v, sort, ctx=None): - """Create a Z3 floating-point conversion expression that represents the + """Create a Z3 floating-point conversion expression that represents the conversion from an unsigned bit-vector term (encoding an integer) to a floating-point term. >>> x_signed = BitVecVal(-5, BitVecSort(32)) @@ -9936,7 +9928,7 @@ class SeqSortRef(SortRef): def basis(self): return _to_sort_ref(Z3_get_seq_sort_basis(self.ctx_ref(), self.ast), self.ctx) - + def StringSort(ctx=None): """Create a string sort @@ -9992,13 +9984,13 @@ class SeqRef(ExprRef): def __le__(self, other): return SeqRef(Z3_mk_str_le(self.ctx_ref(), self.as_ast(), other.as_ast()), self.ctx) - + def __lt__(self, other): return SeqRef(Z3_mk_str_lt(self.ctx_ref(), self.as_ast(), other.as_ast()), self.ctx) def __ge__(self, other): return SeqRef(Z3_mk_str_le(self.ctx_ref(), other.as_ast(), self.as_ast()), self.ctx) - + def __gt__(self, other): return SeqRef(Z3_mk_str_lt(self.ctx_ref(), other.as_ast(), self.as_ast()), self.ctx) @@ -10062,7 +10054,7 @@ def String(name, ctx=None): return SeqRef(Z3_mk_const(ctx.ref(), to_symbol(name, ctx), StringSort(ctx).ast), ctx) def Strings(names, ctx=None): - """Return string constants""" + """Return string constants""" ctx = _get_ctx(ctx) if isinstance(names, str): names = names.split(" ") @@ -10208,7 +10200,7 @@ def LastIndexOf(s, substr): s = _coerce_seq(s, ctx) substr = _coerce_seq(substr, ctx) return ArithRef(Z3_mk_seq_last_index(s.ctx_ref(), s.as_ast(), substr.as_ast()), s.ctx) - + def Length(s): """Obtain the length of a sequence 's' @@ -10414,4 +10406,3 @@ def TransitiveClosure(f): The transitive closure R+ is a new relation. """ return FuncDeclRef(Z3_mk_transitive_closure(f.ctx_ref(), f.ast), f.ctx) -