3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-02 20:31:21 +00:00

Bugfixes for the Python FPA API

Signed-off-by: Christoph M. Wintersteiger <cwinter@microsoft.com>
This commit is contained in:
Christoph M. Wintersteiger 2015-01-22 18:31:30 +00:00
parent 0ab54b9e0c
commit c3ff342bea
2 changed files with 229 additions and 122 deletions

View file

@ -7509,7 +7509,7 @@ class FPSortRef(SortRef):
def sbits(self): def sbits(self):
"""Retrieves the number of bits reserved for the exponent in the FloatingPoint sort `self`. """Retrieves the number of bits reserved for the exponent in the FloatingPoint sort `self`.
>>> b = FloatingPointSort(8, 24) >>> b = FPSort(8, 24)
>>> b.sbits() >>> b.sbits()
24 24
""" """
@ -7520,9 +7520,9 @@ class FPSortRef(SortRef):
>>> b = FPSort(8, 24) >>> b = FPSort(8, 24)
>>> b.cast(1.0) >>> b.cast(1.0)
1.0 1
>>> b.cast(1.0).sexpr() >>> b.cast(1.0).sexpr()
'1.0' '(fp #b0 #x7f #b00000000000000000000000)'
""" """
if is_expr(val): if is_expr(val):
if __debug__: if __debug__:
@ -7579,7 +7579,7 @@ class FPRMSortRef(SortRef):
def is_fp_sort(s): def is_fp_sort(s):
"""Return True if `s` is a Z3 floating-point sort. """Return True if `s` is a Z3 floating-point sort.
>>> is_fp_sort(FloatingPointSort(8, 24)) >>> is_fp_sort(FPSort(8, 24))
True True
>>> is_fp_sort(IntSort()) >>> is_fp_sort(IntSort())
False False
@ -7591,10 +7591,10 @@ def is_fprm_sort(s):
>>> is_fprm_sort(FPSort(8, 24)) >>> is_fprm_sort(FPSort(8, 24))
False False
>>> is_fprm_sort() >>> is_fprm_sort(RNE().sort())
False True
""" """
return isinstance(s, FPSortRef) return isinstance(s, FPRMSortRef)
### FP Expressions ### FP Expressions
@ -7606,15 +7606,15 @@ class FPRef(ExprRef):
>>> x = FP('1.0', FPSort(8, 24)) >>> x = FP('1.0', FPSort(8, 24))
>>> x.sort() >>> x.sort()
(_ FloatingPoint 8 24) FPSort(8, 24)
>>> x.sort() == FloatingPointSort(8, 24) >>> x.sort() == FPSort(8, 24)
True True
""" """
return FPSortRef(Z3_get_sort(self.ctx_ref(), self.as_ast()), self.ctx) return FPSortRef(Z3_get_sort(self.ctx_ref(), self.as_ast()), self.ctx)
def ebits(self): def ebits(self):
"""Retrieves the number of bits reserved for the exponent in the FloatingPoint expression `self`. """Retrieves the number of bits reserved for the exponent in the FloatingPoint expression `self`.
>>> b = FloatingPointSort(8, 24) >>> b = FPSort(8, 24)
>>> b.ebits() >>> b.ebits()
8 8
""" """
@ -7622,7 +7622,7 @@ class FPRef(ExprRef):
def sbits(self): def sbits(self):
"""Retrieves the number of bits reserved for the exponent in the FloatingPoint expression `self`. """Retrieves the number of bits reserved for the exponent in the FloatingPoint expression `self`.
>>> b = FloatingPointSort(8, 24) >>> b = FPSort(8, 24)
>>> b.sbits() >>> b.sbits()
24 24
""" """
@ -7651,12 +7651,12 @@ class FPRef(ExprRef):
def __add__(self, other): def __add__(self, other):
"""Create the Z3 expression `self + other`. """Create the Z3 expression `self + other`.
>>> x = FP('x', 8, 24) >>> x = FP('x', FPSort(8, 24))
>>> y = FP('y', 8, 24) >>> y = FP('y', FPSort(8, 24))
>>> x + y >>> x + y
x + y x + y
>>> (x + y).sort() >>> (x + y).sort()
FloatingPoint(8 24) FPSort(8, 24)
""" """
a, b = z3._coerce_exprs(self, other) a, b = z3._coerce_exprs(self, other)
return fpAdd(_dflt_rm(), self, other) return fpAdd(_dflt_rm(), self, other)
@ -7666,7 +7666,7 @@ class FPRef(ExprRef):
>>> x = FP('x', FPSort(8, 24)) >>> x = FP('x', FPSort(8, 24))
>>> 10 + x >>> 10 + x
10 + x 1.25*(2**3) + x
""" """
a, b = _coerce_exprs(self, other) a, b = _coerce_exprs(self, other)
return fpAdd(_dflt_rm(), other, self) return fpAdd(_dflt_rm(), other, self)
@ -7674,12 +7674,12 @@ class FPRef(ExprRef):
def __sub__(self, other): def __sub__(self, other):
"""Create the Z3 expression `self - other`. """Create the Z3 expression `self - other`.
>>> x = FP('x', 8, 24) >>> x = FP('x', FPSort(8, 24))
>>> y = FP('y', 8, 24) >>> y = FP('y', FPSort(8, 24))
>>> x - y >>> x - y
x - y x - y
>>> (x - y).sort() >>> (x - y).sort()
FloatingPoint(8 24) FPSort(8, 24)
""" """
a, b = z3._coerce_exprs(self, other) a, b = z3._coerce_exprs(self, other)
return fpSub(_dflt_rm(), self, other) return fpSub(_dflt_rm(), self, other)
@ -7689,7 +7689,7 @@ class FPRef(ExprRef):
>>> x = FP('x', FPSort(8, 24)) >>> x = FP('x', FPSort(8, 24))
>>> 10 - x >>> 10 - x
10 - x 1.25*(2**3) - x
""" """
a, b = _coerce_exprs(self, other) a, b = _coerce_exprs(self, other)
return fpSub(_dflt_rm(), other, self) return fpSub(_dflt_rm(), other, self)
@ -7697,22 +7697,27 @@ class FPRef(ExprRef):
def __mul__(self, other): def __mul__(self, other):
"""Create the Z3 expression `self * other`. """Create the Z3 expression `self * other`.
>>> x = FP('x', 8, 24) >>> x = FP('x', FPSort(8, 24))
>>> y = FP('y', 8, 24) >>> y = FP('y', FPSort(8, 24))
>>> x * y >>> x * y
x * y x * y
>>> (x * y).sort() >>> (x * y).sort()
FloatingPoint(8 24) FPSort(8, 24)
>>> 10 * y
1.25*(2**3) * y
""" """
a, b = z3._coerce_exprs(self, other) a, b = z3._coerce_exprs(self, other)
return fpMul(_dflt_rm(), self, other) return fpMul(_dflt_rm(), self, other)
def __rmul_(self, other): def __rmul__(self, other):
"""Create the Z3 expression `other * self`. """Create the Z3 expression `other * self`.
>>> x = FP('x', FPSort(8, 24)) >>> x = FP('x', FPSort(8, 24))
>>> 10 * x >>> y = FP('y', FPSort(8, 24))
10 * x >>> x * y
x * y
>>> x * 10
x * 1.25*(2**3)
""" """
a, b = _coerce_exprs(self, other) a, b = _coerce_exprs(self, other)
return fpMul(_dflt_rm(), other, self) return fpMul(_dflt_rm(), other, self)
@ -7792,7 +7797,12 @@ def RTZ(ctx=None):
def is_fprm(a): def is_fprm(a):
"""Return `True` if `a` is a Z3 floating-point rounding mode expression. """Return `True` if `a` is a Z3 floating-point rounding mode expression.
>>> rm = ? >>> rm = RNE()
>>> is_fprm(rm)
True
>>> rm = 1.0
>>> is_fprm(rm)
False
""" """
return isinstance(a, FPRMRef) return isinstance(a, FPRMRef)
@ -7814,7 +7824,65 @@ class FPNumRef(FPRef):
def isNegative(self): def isNegative(self):
k = self.decl().kind() k = self.decl().kind()
return (self.num_args() == 0 and (k == Z3_OP_FPA_MINUS_INF or k == Z3_OP_FPA_MINUS_ZERO)) or (self.num_args() == 3 and self.arg(0) == BitVecVal(1)) return (self.num_args() == 0 and (k == Z3_OP_FPA_MINUS_INF or k == Z3_OP_FPA_MINUS_ZERO)) or (self.sign() == True)
"""
The sign of the numeral
>>> x = FPNumRef(+1.0, FPSort(8, 24))
>>> x.sign()
False
>>> x = FPNumRef(-1.0, FPSort(8, 24))
>>> x.sign()
True
"""
def sign(self):
l = (ctypes.c_int)()
if Z3_fpa_get_numeral_sign(self.ctx.ref(), self.as_ast(), byref(l)) == False:
raise Z3Exception("error retrieving the sign of a numeral.")
return l.value != 0
"""
The significand of the numeral
>>> x = FPNumRef(2.5, FPSort(8, 24))
1.25
"""
def significand(self):
return Z3_fpa_get_numeral_significand_string(self.ctx.ref(), self.as_ast())
"""
The exponent of the numeral
>>> x = FPNumRef(2.5, FPSort(8, 24))
>>>
1
"""
def exponent(self):
return Z3_fpa_get_numeral_exponent_string(self.ctx.ref(), self.as_ast())
"""
The exponent of the numeral as a long
>>> x = FPNumRef(2.5, FPSort(8, 24))
1
"""
def exponent_as_long(self):
ptr = (ctypes.c_longlong * 1)()
if not Z3_fpa_get_numeral_exponent_int64(self.ctx.ref(), self.as_ast(), ptr):
raise Z3Exception("error retrieving the exponent of a numeral.")
return ptr[0]
"""
The string representation of the numeral
>>> x = FPNumRef(20, FPSort(8, 24))
1.25*(2**4)
"""
def as_string(self):
s = Z3_fpa_get_numeral_string(self.ctx.ref(), self.as_ast())
return ("FPVal(%s, %s)" % (s, FPSortRef(self.sort()).as_string()))
def _to_fpnum(num, ctx=None): def _to_fpnum(num, ctx=None):
if isinstance(num, FPNum): if isinstance(num, FPNum):
@ -7843,7 +7911,7 @@ def is_fp_value(a):
False False
>>> b = FPVal(1.0, FPSort(8, 24)) >>> b = FPVal(1.0, FPSort(8, 24))
>>> b >>> b
1.0p0 1
>>> is_fp_value(b) >>> is_fp_value(b)
True True
""" """
@ -7855,9 +7923,9 @@ def FPSort(ebits, sbits, ctx=None):
>>> Single = FPSort(8, 24) >>> Single = FPSort(8, 24)
>>> Double = FPSort(11, 53) >>> Double = FPSort(11, 53)
>>> Single >>> Single
(_ FloatingPoint 8 24) FPSort(8, 24)
>>> x = Const('x', Single) >>> x = Const('x', Single)
>>> eq(x, FP('x', 8, 24)) >>> eq(x, FP('x', FPSort(8, 24)))
True True
""" """
ctx = z3._get_ctx(ctx) ctx = z3._get_ctx(ctx)
@ -7911,11 +7979,17 @@ def fpZero(s, negative):
def FPVal(sig, exp=None, fps=None, ctx=None): def FPVal(sig, exp=None, fps=None, ctx=None):
"""Return a floating-point value of value `val` and sort `fps`. If `ctx=None`, then the global context is used. """Return a floating-point value of value `val` and sort `fps`. If `ctx=None`, then the global context is used.
>>> v = FPVal(1.0, FPSort(8, 24))) >>> v = FPVal(20.0, FPSort(8, 24))
>>> v >>> v
1.0 1.25*(2**4)
>>> print("0x%.8x" % v.as_long()) >>> print("0x%.8x" % v.exponent_as_long())
0x0000000a 0x00000004
>>> v = FPVal(2.25, FPSort(8, 24))
>>> v
1.125*(2**1)
>>> v = FPVal(-2.25, FPSort(8, 24))
>>> v
-1.125*(2**1)
""" """
ctx = _get_ctx(ctx) ctx = _get_ctx(ctx)
if is_fp_sort(exp): if is_fp_sort(exp):
@ -7923,7 +7997,7 @@ def FPVal(sig, exp=None, fps=None, ctx=None):
exp = None exp = None
elif fps == None: elif fps == None:
fps = _dflt_fps(ctx) fps = _dflt_fps(ctx)
_z3_assert(is_fp_sort(fps), "sort mismatch") _z3_assert(is_fp_sort(fps), "sort mismatch")
if exp == None: if exp == None:
exp = 0 exp = 0
val = _to_float_str(sig) val = _to_float_str(sig)
@ -7942,7 +8016,7 @@ def FP(name, fpsort, ctx=None):
>>> x.ebits() >>> x.ebits()
8 8
>>> x.sort() >>> x.sort()
(_ FloatingPoint 8 24) FPSort(8, 24)
>>> word = FPSort(8, 24) >>> word = FPSort(8, 24)
>>> x2 = FP('x', word) >>> x2 = FP('x', word)
>>> eq(x, x2) >>> eq(x, x2)
@ -7975,17 +8049,28 @@ def fpAbs(a):
"""Create a Z3 floating-point absolute value expression. """Create a Z3 floating-point absolute value expression.
>>> s = FPSort(8, 24) >>> s = FPSort(8, 24)
>>> rm = FPRM.RNE >>> rm = RNE()
>>> x = FPVal(1.0, s) >>> x = FPVal(1.0, s)
>>> fpAbs(x) >>> fpAbs(x)
1.0 fpAbs(1)
>>> x = FPVal(-1.0, s) >>> y = FPVal(-20.0, s)
>>> fpAbs(x) >>> y
1.0 -1.25*(2**4)
>>> fpAbs(y)
fpAbs(-1.25*(2**4))
>>> fpAbs(-1.25*(2**4))
fpAbs(-1.25*(2**4))
>>> fpAbs(x).sort() >>> fpAbs(x).sort()
FloatingPoint(8, 24) FPSort(8, 24)
""" """
if __debug__: ctx = None
if not is_expr(a):
ctx =_get_ctx(ctx)
s = get_default_fp_sort(ctx)
a = FPVal(a, s)
else:
ctx = a.ctx
if __debug__:
_z3_assert(is_fp(a), "First argument must be Z3 floating-point expression") _z3_assert(is_fp(a), "First argument must be Z3 floating-point expression")
return FPRef(Z3_mk_fpa_abs(a.ctx_ref(), a.as_ast()), a.ctx) return FPRef(Z3_mk_fpa_abs(a.ctx_ref(), a.as_ast()), a.ctx)
@ -7993,13 +8078,20 @@ def fpNeg(a):
"""Create a Z3 floating-point addition expression. """Create a Z3 floating-point addition expression.
>>> s = FPSort(8, 24) >>> s = FPSort(8, 24)
>>> rm = FPRM.RNE >>> rm = RNE()
>>> x = FP('x', s) >>> x = FP('x', s)
>>> y = FP('y', s) >>> fpNeg(x)
fp.add(rm, x, y) -x
>>> fp.add(rm, x, y).sort() >>> fpNeg(x).sort()
FloatingPoint(8, 24) FPSort(8, 24)
""" """
ctx = None
if not is_expr(a):
ctx =_get_ctx(ctx)
s = get_default_fp_sort(ctx)
a = FPVal(a, s)
else:
ctx = a.ctx
if __debug__: if __debug__:
_z3_assert(is_fp(a), "First argument must be Z3 floating-point expression") _z3_assert(is_fp(a), "First argument must be Z3 floating-point expression")
return FPRef(Z3_mk_fpa_neg(a.ctx_ref(), a.as_ast()), a.ctx) return FPRef(Z3_mk_fpa_neg(a.ctx_ref(), a.as_ast()), a.ctx)
@ -8008,79 +8100,88 @@ def fpAdd(rm, a, b):
"""Create a Z3 floating-point addition expression. """Create a Z3 floating-point addition expression.
>>> s = FPSort(8, 24) >>> s = FPSort(8, 24)
>>> rm = FPRM.RNE >>> rm = RNE()
>>> x = FP('x', s) >>> x = FP('x', s)
>>> y = FP('y', s) >>> y = FP('y', s)
fp.add(rm, x, y) >>> fpAdd(rm, x, y)
>>> fp.add(rm, x, y).sort() fpAdd(RNE(), x, y)
FloatingPoint(8, 24) >>> fpAdd(rm, x, y).sort()
FPSort(8, 24)
""" """
if __debug__: if __debug__:
_z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression") _z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression")
_z3_assert(is_fp(a) and is_fp(b), "Second and third argument must be Z3 floating-point expressions") _z3_assert(is_fp(a) or is_fp(b), "Second or third argument must be a Z3 floating-point expression")
a, b = _coerce_exprs(a, b)
return FPRef(Z3_mk_fpa_add(rm.ctx_ref(), rm.as_ast(), a.as_ast(), b.as_ast()), rm.ctx) return FPRef(Z3_mk_fpa_add(rm.ctx_ref(), rm.as_ast(), a.as_ast(), b.as_ast()), rm.ctx)
def fpSub(rm, a, b): def fpSub(rm, a, b):
"""Create a Z3 floating-point subtraction expression. """Create a Z3 floating-point subtraction expression.
>>> s = FPSort(8, 24) >>> s = FPSort(8, 24)
>>> rm = FPRM.RNE >>> rm = RNE()
>>> x = FP('x', s) >>> x = FP('x', s)
>>> y = FP('y', s) >>> y = FP('y', s)
fp.add(rm, x, y) >>> fpSub(rm, x, y)
>>> fp.add(rm, x, y).sort() fpSub(RNE(), x, y)
FloatingPoint(8, 24) >>> fpSub(rm, x, y).sort()
FPSort(8, 24)
""" """
if __debug__: if __debug__:
_z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression") _z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression")
_z3_assert(is_fp(a) and is_fp(b), "Second and third argument must be Z3 floating-point expressions") _z3_assert(is_fp(a) or is_fp(b), "Second or third argument must be a Z3 floating-point expression")
a, b = _coerce_exprs(a, b)
return FPRef(Z3_mk_fpa_sub(rm.ctx_ref(), rm.as_ast(), a.as_ast(), b.as_ast()), rm.ctx) return FPRef(Z3_mk_fpa_sub(rm.ctx_ref(), rm.as_ast(), a.as_ast(), b.as_ast()), rm.ctx)
def fpMul(rm, a, b): def fpMul(rm, a, b):
"""Create a Z3 floating-point multiplication expression. """Create a Z3 floating-point multiplication expression.
>>> s = FPSort(8, 24) >>> s = FPSort(8, 24)
>>> rm = FPRM.RNE >>> rm = RNE()
>>> x = FP('x', s) >>> x = FP('x', s)
>>> y = FP('y', s) >>> y = FP('y', s)
fp.add(rm, x, y) >>> fpMul(rm, x, y)
>>> fp.add(rm, x, y).sort() fpMul(RNE(), x, y)
FloatingPoint(8, 24) >>> fpMul(rm, x, y).sort()
FPSort(8, 24)
""" """
if __debug__: if __debug__:
_z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression") _z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression")
_z3_assert(is_fp(a) and is_fp(b), "Second and third argument must be Z3 floating-point expressions") _z3_assert(is_fp(a) or is_fp(b), "Second or third argument must be a Z3 floating-point expression")
a, b = _coerce_exprs(a, b)
return FPRef(Z3_mk_fpa_mul(rm.ctx_ref(), rm.as_ast(), a.as_ast(), b.as_ast()), rm.ctx) return FPRef(Z3_mk_fpa_mul(rm.ctx_ref(), rm.as_ast(), a.as_ast(), b.as_ast()), rm.ctx)
def fpDiv(rm, a, b): def fpDiv(rm, a, b):
"""Create a Z3 floating-point divison expression. """Create a Z3 floating-point divison expression.
>>> s = FPSort(8, 24) >>> s = FPSort(8, 24)
>>> rm = FPRM.RNE >>> rm = RNE()
>>> x = FP('x', s) >>> x = FP('x', s)
>>> y = FP('y', s) >>> y = FP('y', s)
fpAdd(rm, x, y) >>> fpDiv(rm, x, y)
>>> fp.add(rm, x, y).sort() fpDiv(RNE(), x, y)
FloatingPoint(8, 24) >>> fpDiv(rm, x, y).sort()
FPSort(8, 24)
""" """
if __debug__: if __debug__:
_z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression") _z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression")
_z3_assert(is_fp(a) and is_fp(b), "Second and third argument must be Z3 floating-point expressions") _z3_assert(is_fp(a) or is_fp(b), "Second or third argument must be a Z3 floating-point expression")
a, b = _coerce_exprs(a, b)
return FPRef(Z3_mk_fpa_div(rm.ctx_ref(), rm.as_ast(), a.as_ast(), b.as_ast()), rm.ctx) return FPRef(Z3_mk_fpa_div(rm.ctx_ref(), rm.as_ast(), a.as_ast(), b.as_ast()), rm.ctx)
def fpRem(a, b): def fpRem(a, b):
"""Create a Z3 floating-point remainder expression. """Create a Z3 floating-point remainder expression.
>>> s = FPSort(8, 24) >>> s = FPSort(8, 24)
>>> rm = RNE()
>>> x = FP('x', s) >>> x = FP('x', s)
>>> y = FP('y', s) >>> y = FP('y', s)
>>> fpRem(x, y)
fpRem(x, y) fpRem(x, y)
>>> fpRem(rm, x, y).sort() >>> fpRem(x, y).sort()
FloatingPoint(8, 24) FPSort(8, 24)
""" """
if __debug__: if __debug__:
_z3_assert(is_fp(a) and is_fp(b), "Both arguments must be Z3 floating-point expressions") _z3_assert(is_fp(a) or is_fp(b), "Second or third argument must be a Z3 floating-point expression")
a, b = _coerce_exprs(a, b)
return FPRef(Z3_mk_fpa_rem(a.ctx_ref(), a.as_ast(), b.as_ast()), a.ctx) return FPRef(Z3_mk_fpa_rem(a.ctx_ref(), a.as_ast(), b.as_ast()), a.ctx)
def fpMin(a, b): def fpMin(a, b):
@ -8090,12 +8191,14 @@ def fpMin(a, b):
>>> rm = RNE() >>> rm = RNE()
>>> x = FP('x', s) >>> x = FP('x', s)
>>> y = FP('y', s) >>> y = FP('y', s)
>>> fpMin(x, y)
fpMin(x, y) fpMin(x, y)
>>> fpMin(rm, x, y).sort() >>> fpMin(x, y).sort()
FloatingPoint(8, 24) FPSort(8, 24)
""" """
if __debug__: if __debug__:
_z3_assert(is_fp(a) and is_fp(b), "Both arguments must be Z3 floating-point expressions") _z3_assert(is_fp(a) or is_fp(b), "Second or third argument must be a Z3 floating-point expression")
a, b = _coerce_exprs(a, b)
return FPRef(Z3_mk_fpa_min(a.ctx_ref(), a.as_ast(), b.as_ast()), a.ctx) return FPRef(Z3_mk_fpa_min(a.ctx_ref(), a.as_ast(), b.as_ast()), a.ctx)
def fpMax(a, b): def fpMax(a, b):
@ -8105,12 +8208,14 @@ def fpMax(a, b):
>>> rm = RNE() >>> rm = RNE()
>>> x = FP('x', s) >>> x = FP('x', s)
>>> y = FP('y', s) >>> y = FP('y', s)
fpMin(x, y) >>> fpMax(x, y)
>>> fpMin(rm, x, y).sort() fpMax(x, y)
FloatingPoint(8, 24) >>> fpMax(x, y).sort()
FPSort(8, 24)
""" """
if __debug__: if __debug__:
_z3_assert(is_fp(a) and is_fp(b), "Both arguments must be Z3 floating-point expressions") _z3_assert(is_fp(a) or is_fp(b), "Second or third argument must be a Z3 floating-point expression")
a, b = _coerce_exprs(a, b)
return FPRef(Z3_mk_fpa_max(a.ctx_ref(), a.as_ast(), b.as_ast()), a.ctx) return FPRef(Z3_mk_fpa_max(a.ctx_ref(), a.as_ast(), b.as_ast()), a.ctx)
def fpFMA(rm, a, b, c): def fpFMA(rm, a, b, c):
@ -8118,23 +8223,38 @@ def fpFMA(rm, a, b, c):
""" """
if __debug__: if __debug__:
_z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression") _z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression")
_z3_assert(is_fp(a) and is_fp(b) and is_fp(c), "Second, third, and fourth argument must be Z3 floating-point expressions") _z3_assert(is_fp(a) or is_fp(b) or is_fp(c), "Second, third, or fourth argument must be a Z3 floating-point expression")
a, b, c = _coerce_expr_list([a, b, c])
return FPRef(Z3_mk_fpa_fma(rm.ctx_ref(), rm.as_ast(), a.as_ast(), b.as_ast(), c.as_ast()), rm.ctx) return FPRef(Z3_mk_fpa_fma(rm.ctx_ref(), rm.as_ast(), a.as_ast(), b.as_ast(), c.as_ast()), rm.ctx)
def fpSqrt(rm, a): def fpSqrt(rm, a):
"""Create a Z3 floating-point square root expression. """Create a Z3 floating-point square root expression.
""" """
ctx = None
if not is_expr(a):
ctx =_get_ctx(ctx)
s = get_default_fp_sort(ctx)
a = FPVal(a, s)
else:
ctx = a.ctx
if __debug__: if __debug__:
_z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression") _z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression")
_z3_assert(is_fp(a), "Second argument must be Z3 floating-point expressions") _z3_assert(is_fp(a), "Second argument must be a Z3 floating-point expressions")
return FPRef(Z3_mk_fpa_sqrt(rm.ctx_ref(), rm.as_ast(), a.as_ast()), rm.ctx) return FPRef(Z3_mk_fpa_sqrt(rm.ctx_ref(), rm.as_ast(), a.as_ast()), rm.ctx)
def fpRoundToIntegral(rm, a): def fpRoundToIntegral(rm, a):
"""Create a Z3 floating-point roundToIntegral expression. """Create a Z3 floating-point roundToIntegral expression.
""" """
ctx = None
if not is_expr(a):
ctx =_get_ctx(ctx)
s = get_default_fp_sort(ctx)
a = FPVal(a, s)
else:
ctx = a.ctx
if __debug__: if __debug__:
_z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression") _z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression")
_z3_assert(is_fp(a), "Second argument must be Z3 floating-point expressions") _z3_assert(is_fp(a), "Second argument must be a Z3 floating-point expressions")
return FPRef(Z3_mk_fpa_round_to_integral(rm.ctx_ref(), rm.as_ast(), a.as_ast()), rm.ctx) return FPRef(Z3_mk_fpa_round_to_integral(rm.ctx_ref(), rm.as_ast(), a.as_ast()), rm.ctx)
def fpIsNaN(a): def fpIsNaN(a):
@ -8195,11 +8315,9 @@ def fpLT(a, b):
>>> x, y = FPs('x y', FPSort(8, 24)) >>> x, y = FPs('x y', FPSort(8, 24))
>>> fpLT(x, y) >>> fpLT(x, y)
x <= y x < y
>>> (x <= y).sexpr() >>> (x <= y).sexpr()
'?' '(fp.leq x y)'
>>> fpLT(x, y).sexpr()
'?'
""" """
_check_fp_args(a, b) _check_fp_args(a, b)
a, b = _coerce_exprs(a, b) a, b = _coerce_exprs(a, b)
@ -8212,9 +8330,7 @@ def fpLEQ(a, b):
>>> fpLEQ(x, y) >>> fpLEQ(x, y)
x <= y x <= y
>>> (x <= y).sexpr() >>> (x <= y).sexpr()
'?' '(fp.leq x y)'
>>> fpLEQ(x, y).sexpr()
'?'
""" """
_check_fp_args(a, b) _check_fp_args(a, b)
a, b = _coerce_exprs(a, b) a, b = _coerce_exprs(a, b)
@ -8225,11 +8341,9 @@ def fpGT(a, b):
>>> x, y = FPs('x y', FPSort(8, 24)) >>> x, y = FPs('x y', FPSort(8, 24))
>>> fpGT(x, y) >>> fpGT(x, y)
x <= y x > y
>>> (x <= y).sexpr() >>> (x > y).sexpr()
'?' '(fp.gt x y)'
>>> fpGT(x, y).sexpr()
'?'
""" """
_check_fp_args(a, b) _check_fp_args(a, b)
a, b = _coerce_exprs(a, b) a, b = _coerce_exprs(a, b)
@ -8240,12 +8354,12 @@ def fpGEQ(a, b):
"""Create the Z3 floating-point expression `other <= self`. """Create the Z3 floating-point expression `other <= self`.
>>> x, y = FPs('x y', FPSort(8, 24)) >>> x, y = FPs('x y', FPSort(8, 24))
>>> fp_geq(x, y) >>> x + y
x <= y x + y
>>> (x <= y).sexpr() >>> fpGEQ(x, y)
'?' x >= y
>>> fp_geq(x, y).sexpr() >>> (x >= y).sexpr()
'?' '(fp.geq x y)'
""" """
_check_fp_args(a, b) _check_fp_args(a, b)
a, b = _coerce_exprs(a, b) a, b = _coerce_exprs(a, b)
@ -8256,11 +8370,9 @@ def fpEQ(a, b):
>>> x, y = FPs('x y', FPSort(8, 24)) >>> x, y = FPs('x y', FPSort(8, 24))
>>> fpEQ(x, y) >>> fpEQ(x, y)
x <= y fpEQ(x, y)
>>> (x <= y).sexpr()
'?'
>>> fpEQ(x, y).sexpr() >>> fpEQ(x, y).sexpr()
'?' '(fp.eq x y)'
""" """
_check_fp_args(a, b) _check_fp_args(a, b)
a, b = _coerce_exprs(a, b) a, b = _coerce_exprs(a, b)
@ -8271,11 +8383,9 @@ def fpNEQ(a, b):
>>> x, y = FPs('x y', FPSort(8, 24)) >>> x, y = FPs('x y', FPSort(8, 24))
>>> fpNEQ(x, y) >>> fpNEQ(x, y)
x <= y Not(fpEQ(x, y))
>>> (x <= y).sexpr() >>> (x != y).sexpr()
'?' '(not (fp.eq x y))'
>>> fpNEQ(x, y).sexpr()
'?'
""" """
_check_fp_args(a, b) _check_fp_args(a, b)
a, b = _coerce_exprs(a, b) a, b = _coerce_exprs(a, b)

View file

@ -182,9 +182,6 @@ _infix_map = {}
_unary_map = {} _unary_map = {}
_infix_compact_map = {} _infix_compact_map = {}
for (_k,_v) in _z3_op_to_fpa_normal_str.items():
_z3_op_to_str[_k] = _v
for _k in _z3_infix: for _k in _z3_infix:
_infix_map[_k] = True _infix_map[_k] = True
for _k in _z3_unary: for _k in _z3_unary:
@ -515,7 +512,7 @@ class Formatter:
self.precision = 10 self.precision = 10
self.ellipses = to_format(_ellipses) self.ellipses = to_format(_ellipses)
self.max_visited = 10000 self.max_visited = 10000
self.fpa_pretty = False self.fpa_pretty = True
def pp_ellipses(self): def pp_ellipses(self):
return self.ellipses return self.ellipses
@ -576,8 +573,8 @@ class Formatter:
return to_format(a.as_string()) return to_format(a.as_string())
def pp_fprm_value(self, a): def pp_fprm_value(self, a):
z3._z3_assert(z3.is_fprm_value(a), 'expected FPRMNumRef') z3._z3_assert(z3.is_fprm_value(a), 'expected FPRMNumRef')
if self.fpa_pretty and a.decl().kind() in _z3_op_to_fpa_pretty_str: if self.fpa_pretty and (a.decl().kind() in _z3_op_to_fpa_pretty_str):
return to_format(_z3_op_to_fpa_pretty_str.get(a.decl().kind())) return to_format(_z3_op_to_fpa_pretty_str.get(a.decl().kind()))
else: else:
return to_format(_z3_op_to_fpa_normal_str.get(a.decl().kind())) return to_format(_z3_op_to_fpa_normal_str.get(a.decl().kind()))
@ -600,12 +597,12 @@ class Formatter:
else: else:
z3._z3_assert(z3.is_fp_value(a), 'expecting FP num ast') z3._z3_assert(z3.is_fp_value(a), 'expecting FP num ast')
r = [] r = []
sgn = c_long(0) sgn = c_int(0)
sgnb = Z3_fpa_get_numeral_sign(a.ctx_ref(), a.ast, byref(sgn)) sgnb = Z3_fpa_get_numeral_sign(a.ctx_ref(), a.ast, byref(sgn))
sig = Z3_fpa_get_numeral_significand_string(a.ctx_ref(), a.ast) sig = Z3_fpa_get_numeral_significand_string(a.ctx_ref(), a.ast)
exp = Z3_fpa_get_numeral_exponent_string(a.ctx_ref(), a.ast) exp = Z3_fpa_get_numeral_exponent_string(a.ctx_ref(), a.ast)
r.append(to_format('FPVal(')) r.append(to_format('FPVal('))
if not sgnb and sgn: if sgnb and sgn.value != 0:
r.append(to_format('-')) r.append(to_format('-'))
r.append(to_format(sig)) r.append(to_format(sig))
r.append(to_format('*(2**')) r.append(to_format('*(2**'))
@ -634,17 +631,17 @@ class Formatter:
sgnb = Z3_fpa_get_numeral_sign(a.ctx_ref(), a.ast, byref(sgn)) sgnb = Z3_fpa_get_numeral_sign(a.ctx_ref(), a.ast, byref(sgn))
sig = Z3_fpa_get_numeral_significand_string(a.ctx_ref(), a.ast) sig = Z3_fpa_get_numeral_significand_string(a.ctx_ref(), a.ast)
exp = Z3_fpa_get_numeral_exponent_string(a.ctx_ref(), a.ast) exp = Z3_fpa_get_numeral_exponent_string(a.ctx_ref(), a.ast)
if not sgnb and sgn != 0: if sgnb and sgn.value != 0:
r.append(to_format('-')) r.append(to_format('-'))
r.append(to_format(sig)) r.append(to_format(sig))
if (exp != '0'): if (exp != '0'):
r.append(to_format('*(2**')) r.append(to_format('*(2**'))
r.append(to_format(exp)) r.append(to_format(exp))
r.append(to_format(')')) r.append(to_format(')'))
return compose(r) return compose(r)
def pp_fp(self, a, d, xs): def pp_fp(self, a, d, xs):
z3._z3_assert(isinstance(a, z3.FPRef), "type mismatch") z3._z3_assert(isinstance(a, z3.FPRef), "type mismatch")
k = a.decl().kind() k = a.decl().kind()
op = '?' op = '?'
@ -653,7 +650,7 @@ class Formatter:
elif k in _z3_op_to_fpa_normal_str: elif k in _z3_op_to_fpa_normal_str:
op = _z3_op_to_fpa_normal_str[k] op = _z3_op_to_fpa_normal_str[k]
elif k in _z3_op_to_str: elif k in _z3_op_to_str:
op = _z3_op_to_str[k] op = _z3_op_to_str[k]
n = a.num_args() n = a.num_args()
@ -1164,7 +1161,7 @@ def set_fpa_pretty(flag=True):
for _k in _z3_fpa_infix: for _k in _z3_fpa_infix:
_infix_map[_k] = False _infix_map[_k] = False
set_fpa_pretty(True)
def in_html_mode(): def in_html_mode():
return isinstance(_Formatter, HTMLFormatter) return isinstance(_Formatter, HTMLFormatter)