diff --git a/src/api/python/z3.py b/src/api/python/z3.py index ee33fbe51..82002c4f7 100644 --- a/src/api/python/z3.py +++ b/src/api/python/z3.py @@ -904,12 +904,8 @@ def _to_expr_ref(a, ctx): if sk == Z3_DATATYPE_SORT: return DatatypeRef(a, ctx) if sk == Z3_FLOATING_POINT_SORT: - if k == Z3_APP_AST: - e = ExprRef(a, ctx) - if e.decl().kind() == Z3_OP_FPA_NUM: - return FPNumRef(a, ctx) - else: - return FPRef(a, ctx) + if k == Z3_APP_AST and _is_numeral(ctx, a): + return FPNumRef(a, ctx) else: return FPRef(a, ctx) if sk == Z3_ROUNDING_MODE_SORT: @@ -7817,7 +7813,8 @@ class FPNumRef(FPRef): return self.decl().kind() == Z3_OP_FPA_PLUS_ZERO or self.decl().kind() == Z3_OP_FPA_MINUS_ZERO def isNegative(self): - return (self.num_args() == 0 and (Z3_OP_FPA_MINUS_INF or Z3_OP_FPA_MINUS_ZERO)) or (self.num_args() == 3 and self.arg(0) == BitVecVal(1)) + k = self.decl().kind() + return (self.num_args() == 0 and (k == Z3_OP_FPA_MINUS_INF or k == Z3_OP_FPA_MINUS_ZERO)) or (self.num_args() == 3 and self.arg(0) == BitVecVal(1)) def _to_fpnum(num, ctx=None): if isinstance(num, FPNum): @@ -7851,7 +7848,7 @@ def is_fp_value(a): True """ return is_fp(a) and _is_numeral(a.ctx, a.ast) - + def FPSort(ebits, sbits, ctx=None): """Return a Z3 floating-point sort of the given sizes. If `ctx=None`, then the global context is used. @@ -7989,7 +7986,7 @@ def fpAbs(a): FloatingPoint(8, 24) """ if __debug__: - _z3_assert(is_fp(a), "First argument must be Z3 floating-point expressions") + _z3_assert(is_fp(a), "First argument must be Z3 floating-point expression") return FPRef(Z3_mk_fpa_abs(a.ctx_ref(), a.as_ast()), a.ctx) def fpNeg(a): @@ -8004,7 +8001,7 @@ def fpNeg(a): FloatingPoint(8, 24) """ if __debug__: - _z3_assert(is_fp(a), "First argument must be Z3 floating-point expressions") + _z3_assert(is_fp(a), "First argument must be Z3 floating-point expression") return FPRef(Z3_mk_fpa_neg(a.ctx_ref(), a.as_ast()), a.ctx) def fpAdd(rm, a, b): @@ -8069,7 +8066,7 @@ def fpDiv(rm, a, b): if __debug__: _z3_assert(is_fprm(rm), "First argument must be a Z3 floating-point rounding mode expression") _z3_assert(is_fp(a) and is_fp(b), "Second and third argument must be Z3 floating-point expressions") - return FPRef(Z3_mk_fpa_mul(rm.ctx_ref(), rm.as_ast(), a.as_ast(), b.as_ast()), rm.ctx) + return FPRef(Z3_mk_fpa_div(rm.ctx_ref(), rm.as_ast(), a.as_ast(), b.as_ast()), rm.ctx) def fpRem(a, b): """Create a Z3 floating-point remainder expression. diff --git a/src/api/python/z3printer.py b/src/api/python/z3printer.py index a74a67fa8..da49f1719 100644 --- a/src/api/python/z3printer.py +++ b/src/api/python/z3printer.py @@ -71,7 +71,7 @@ _z3_op_to_fpa_normal_str = { Z3_OP_FPA_NAN : 'NaN', Z3_OP_FPA_PLUS_ZERO : 'PZero', Z3_OP_FPA_MINUS_ZERO : 'NZero', Z3_OP_FPA_ADD : 'fpAdd', Z3_OP_FPA_SUB : 'fpSub', Z3_OP_FPA_NEG : 'fpNeg', Z3_OP_FPA_MUL : 'fpMul', Z3_OP_FPA_DIV : 'fpDiv', Z3_OP_FPA_REM : 'fpRem', Z3_OP_FPA_ABS : 'fpAbs', - Z3_OP_FPA_NEG : 'fpNeg', Z3_OP_FPA_MIN : 'fpMin', Z3_OP_FPA_MAX : 'fpMax', + Z3_OP_FPA_MIN : 'fpMin', Z3_OP_FPA_MAX : 'fpMax', Z3_OP_FPA_FMA : 'fpFMA', Z3_OP_FPA_SQRT : 'fpSqrt', Z3_OP_FPA_ROUND_TO_INTEGRAL : 'fpRoundToIntegral', Z3_OP_FPA_EQ : 'fpEQ', Z3_OP_FPA_LT : 'fpLT', Z3_OP_FPA_GT : 'fpGT', Z3_OP_FPA_LE : 'fpLEQ', @@ -93,9 +93,8 @@ _z3_op_to_fpa_pretty_str = { Z3_OP_FPA_PLUS_INF : '+oo', Z3_OP_FPA_MINUS_INF : '-oo', Z3_OP_FPA_NAN : 'NaN', Z3_OP_FPA_PLUS_ZERO : '+0.0', Z3_OP_FPA_MINUS_ZERO : '-0.0', - Z3_OP_FPA_ADD : '+', Z3_OP_FPA_SUB : '-', Z3_OP_FPA_NEG : '-', Z3_OP_FPA_MUL : '*', - Z3_OP_FPA_DIV : '/', Z3_OP_FPA_REM : '%', - Z3_OP_FPA_NEG: '-', + Z3_OP_FPA_ADD : '+', Z3_OP_FPA_SUB : '-', Z3_OP_FPA_MUL : '*', Z3_OP_FPA_DIV : '/', + Z3_OP_FPA_REM : '%', Z3_OP_FPA_NEG : '-', Z3_OP_FPA_EQ : 'fpEQ', Z3_OP_FPA_LT : '<', Z3_OP_FPA_GT : '>', Z3_OP_FPA_LE : '<=', Z3_OP_FPA_GE : '>=' } @@ -578,7 +577,7 @@ class Formatter: def pp_fprm_value(self, a): z3._z3_assert(z3.is_fprm_value(a), 'expected FPRMNumRef') - if self.fpa_pretty and _z3_op_to_fpa_pretty_str.has_key(a.decl().kind()): + if self.fpa_pretty and a.decl().kind() in _z3_op_to_fpa_pretty_str: return to_format(_z3_op_to_fpa_pretty_str.get(a.decl().kind())) else: return to_format(a.as_string()) @@ -617,7 +616,7 @@ class Formatter: return compose(r) else: if (a.isNaN()): - return to_format(_z3_op_to_fpa_pretty_str[Z3_OP_NAN]) + return to_format(_z3_op_to_fpa_pretty_str[Z3_OP_FPA_NAN]) elif (a.isInf()): if (a.isNegative()): return to_format(_z3_op_to_fpa_pretty_str[Z3_OP_FPA_MINUS_INF]) @@ -649,28 +648,33 @@ class Formatter: z3._z3_assert(isinstance(a, z3.FPRef), "type mismatch") k = a.decl().kind() op = '?' - if self.fpa_pretty: - op = _z3_op_to_fpa_pretty_str.get(k, None) - if (op == None): - op = _z3_op_to_str.get(k, None) + if (self.fpa_pretty and k in _z3_op_to_fpa_pretty_str): + op = _z3_op_to_fpa_pretty_str[k] + elif k in _z3_op_to_fpa_normal_str: + op = _z3_op_to_fpa_normal_str[k] + elif k in _z3_op_to_str: + op = _z3_op_to_str[k] n = a.num_args() - if self.fpa_pretty and self.is_infix(k) and n >= 3: - rm = a.arg(0) - if z3.is_fprm_value(rm) and z3._dflt_rm(a.ctx).eq(rm): - arg1 = to_format(self.pp_expr(a.arg(1), d+1, xs)) - arg2 = to_format(self.pp_expr(a.arg(2), d+1, xs)) - r = [] - r.append(arg1) - r.append(to_format(' ')) - r.append(to_format(op)) - r.append(to_format(' ')) - r.append(arg2) - return compose(r) + if self.fpa_pretty: + if self.is_infix(k) and n >= 3: + rm = a.arg(0) + if z3.is_fprm_value(rm) and z3._dflt_rm(a.ctx).eq(rm): + arg1 = to_format(self.pp_expr(a.arg(1), d+1, xs)) + arg2 = to_format(self.pp_expr(a.arg(2), d+1, xs)) + r = [] + r.append(arg1) + r.append(to_format(' ')) + r.append(to_format(op)) + r.append(to_format(' ')) + r.append(arg2) + return compose(r) + elif k == Z3_OP_FPA_NEG: + return compose([to_format('-') , to_format(self.pp_expr(a.arg(0), d+1, xs))]) - if _z3_op_to_fpa_normal_str.has_key(k): - op = _z3_op_to_fpa_normal_str.get(k, None) + if k in _z3_op_to_fpa_normal_str: + op = _z3_op_to_fpa_normal_str[k] r = [] r.append(to_format(op))