diff --git a/src/ast/rewriter/fpa_rewriter.cpp b/src/ast/rewriter/fpa_rewriter.cpp index 5fb9dfd37..40021e330 100644 --- a/src/ast/rewriter/fpa_rewriter.cpp +++ b/src/ast/rewriter/fpa_rewriter.cpp @@ -70,7 +70,7 @@ br_status fpa_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con case OP_FPA_MAX: SASSERT(num_args == 2); st = mk_max(args[0], args[1], result); break; case OP_FPA_FMA: SASSERT(num_args == 4); st = mk_fma(args[0], args[1], args[2], args[3], result); break; case OP_FPA_SQRT: SASSERT(num_args == 2); st = mk_sqrt(args[0], args[1], result); break; - case OP_FPA_ROUND_TO_INTEGRAL: SASSERT(num_args == 2); st = mk_round(args[0], args[1], result); break; + case OP_FPA_ROUND_TO_INTEGRAL: SASSERT(num_args == 2); st = mk_round_to_integral(args[0], args[1], result); break; case OP_FPA_EQ: SASSERT(num_args == 2); st = mk_float_eq(args[0], args[1], result); break; case OP_FPA_LT: SASSERT(num_args == 2); st = mk_lt(args[0], args[1], result); break; @@ -484,7 +484,7 @@ br_status fpa_rewriter::mk_sqrt(expr * arg1, expr * arg2, expr_ref & result) { return BR_FAILED; } -br_status fpa_rewriter::mk_round(expr * arg1, expr * arg2, expr_ref & result) { +br_status fpa_rewriter::mk_round_to_integral(expr * arg1, expr * arg2, expr_ref & result) { mpf_rounding_mode rm; if (m_util.is_rm_numeral(arg1, rm)) { scoped_mpf v2(m_fm); diff --git a/src/ast/rewriter/fpa_rewriter.h b/src/ast/rewriter/fpa_rewriter.h index 2c76fad6a..2da839718 100644 --- a/src/ast/rewriter/fpa_rewriter.h +++ b/src/ast/rewriter/fpa_rewriter.h @@ -57,7 +57,7 @@ public: br_status mk_max(expr * arg1, expr * arg2, expr_ref & result); br_status mk_fma(expr * arg1, expr * arg2, expr * arg3, expr * arg4, expr_ref & result); br_status mk_sqrt(expr * arg1, expr * arg2, expr_ref & result); - br_status mk_round(expr * arg1, expr * arg2, expr_ref & result); + br_status mk_round_to_integral(expr * arg1, expr * arg2, expr_ref & result); br_status mk_float_eq(expr * arg1, expr * arg2, expr_ref & result); br_status mk_lt(expr * arg1, expr * arg2, expr_ref & result); br_status mk_gt(expr * arg1, expr * arg2, expr_ref & result); diff --git a/src/util/mpf.cpp b/src/util/mpf.cpp index 9de56773a..9b6db5213 100644 --- a/src/util/mpf.cpp +++ b/src/util/mpf.cpp @@ -1003,9 +1003,30 @@ void mpf_manager::round_to_integral(mpf_rounding_mode rm, mpf const & x, mpf & o mk_nan(x.ebits, x.sbits, o); else if (is_inf(x)) set(o, x); - else if (x.exponent < 0) + else if (is_zero(x)) mk_zero(x.ebits, x.sbits, x.sign, o); - else if (x.exponent >= x.sbits-1) + else if (x.exponent < 0) { + if (rm == MPF_ROUND_TOWARD_ZERO || + rm == MPF_ROUND_TOWARD_NEGATIVE) + mk_pzero(x.ebits, x.sbits, o); + else if (rm == MPF_ROUND_NEAREST_TEVEN || + rm == MPF_ROUND_NEAREST_TAWAY) { + bool tie = m_mpz_manager.is_zero(x.significand) && x.exponent == -1; + if (tie && rm == MPF_ROUND_NEAREST_TEVEN) + mk_pzero(x.ebits, x.sbits, o); + else if (tie && rm == MPF_ROUND_NEAREST_TAWAY) + mk_one(x.ebits, x.sbits, o); + else if (x.exponent < -1 || m_mpz_manager.lt(x.significand, m_powers2(x.sbits-2))) + mk_pzero(x.ebits, x.sbits, o); + else + mk_one(x.ebits, x.sbits, o); + } + else { + SASSERT(rm == MPF_ROUND_TOWARD_POSITIVE); + mk_one(x.ebits, x.sbits, o); + } + } + else if (x.exponent >= x.sbits - 1) set(o, x); else { SASSERT(x.exponent >= 0 && x.exponent < x.sbits-1); @@ -1016,21 +1037,62 @@ void mpf_manager::round_to_integral(mpf_rounding_mode rm, mpf const & x, mpf & o scoped_mpf a(*this); set(a, x); - unpack(a, true); + unpack(a, true); // A includes hidden bit TRACE("mpf_dbg", tout << "A = " << to_string(a) << std::endl;); - + + SASSERT(m_mpz_manager.lt(a.significand(), m_powers2(x.sbits))); + SASSERT(m_mpz_manager.ge(a.significand(), m_powers2(x.sbits - 1))); + o.exponent = a.exponent(); m_mpz_manager.set(o.significand, a.significand()); - unsigned q = (unsigned) o.exponent; - unsigned shift = o.sbits-q-1; - TRACE("mpf_dbg", tout << "Q = " << q << " shift=" << shift << std::endl;); - m_mpz_manager.machine_div2k(o.significand, shift); - m_mpz_manager.mul2k(o.significand, shift+3); + unsigned shift = o.sbits - ((unsigned)o.exponent) - 1; + const mpz & shift_p = m_powers2(shift); + TRACE("mpf_dbg", tout << "shift=" << shift << std::endl;); + + scoped_mpz div(m_mpz_manager), rem(m_mpz_manager); + m_mpz_manager.machine_div_rem(o.significand, shift_p, div, rem); + TRACE("mpf_dbg", tout << "div=" << m_mpz_manager.to_string(div) << " rem=" << m_mpz_manager.to_string(rem) << std::endl;); - round(rm, o); - } + switch (rm) { + case MPF_ROUND_NEAREST_TEVEN: + case MPF_ROUND_NEAREST_TAWAY: { + scoped_mpz t(m_mpz_manager); + m_mpz_manager.mul2k(rem, 1, t); + bool tie = m_mpz_manager.eq(t, shift_p); + if (tie && + (rm == MPF_ROUND_NEAREST_TEVEN && m_mpz_manager.is_odd(div)) || + (rm == MPF_ROUND_NEAREST_TAWAY && m_mpz_manager.is_even(div))) + m_mpz_manager.inc(div); + else if (m_mpz_manager.gt(t, shift_p)) + m_mpz_manager.inc(div); + break; + } + case MPF_ROUND_TOWARD_POSITIVE: + if (!m_mpz_manager.is_zero(rem) && !o.sign) + m_mpz_manager.inc(div); + break; + case MPF_ROUND_TOWARD_NEGATIVE: + if (!m_mpz_manager.is_zero(rem) && o.sign) + m_mpz_manager.inc(div); + break; + case MPF_ROUND_TOWARD_ZERO: + default: + /* nothing */; + } + + m_mpz_manager.mul2k(div, shift, o.significand); + SASSERT(m_mpz_manager.ge(o.significand, m_powers2(o.sbits - 1))); + + // re-normalize + while (m_mpz_manager.ge(o.significand, m_powers2(o.sbits))) { + m_mpz_manager.machine_div2k(o.significand, 1); + o.exponent++; + } + + m_mpz_manager.sub(o.significand, m_powers2(o.sbits - 1), o.significand); // strip hidden bit + } TRACE("mpf_dbg", tout << "INTEGRAL = " << to_string(o) << std::endl;); } @@ -1449,6 +1511,14 @@ void mpf_manager::mk_nan(unsigned ebits, unsigned sbits, mpf & o) { o.sign = false; } +void mpf_manager::mk_one(unsigned ebits, unsigned sbits, mpf & o) const { + o.sbits = sbits; + o.ebits = ebits; + o.sign = false; + m_mpz_manager.set(o.significand, 0); + o.exponent = 0; +} + void mpf_manager::mk_max_value(unsigned ebits, unsigned sbits, bool sign, mpf & o) { o.sbits = sbits; o.ebits = ebits; diff --git a/src/util/mpf.h b/src/util/mpf.h index bac502c58..533944172 100644 --- a/src/util/mpf.h +++ b/src/util/mpf.h @@ -208,7 +208,9 @@ public: void to_sbv_mpq(mpf_rounding_mode rm, const mpf & x, scoped_mpq & o); -protected: +protected: + void mk_one(unsigned ebits, unsigned sbits, mpf & o) const; + bool has_bot_exp(mpf const & x); bool has_top_exp(mpf const & x);