diff --git a/src/tactic/fpa/fpa2bv_converter.cpp b/src/tactic/fpa/fpa2bv_converter.cpp index 0913fbbcf..d085ce277 100644 --- a/src/tactic/fpa/fpa2bv_converter.cpp +++ b/src/tactic/fpa/fpa2bv_converter.cpp @@ -1390,7 +1390,142 @@ void fpa2bv_converter::mk_fusedma(func_decl * f, unsigned num, expr * const * ar } void fpa2bv_converter::mk_sqrt(func_decl * f, unsigned num, expr * const * args, expr_ref & result) { - NOT_IMPLEMENTED_YET(); + SASSERT(num == 2); + + expr_ref rm(m), x(m); + rm = args[0]; + x = args[1]; + + expr_ref nan(m), nzero(m), pzero(m), ninf(m), pinf(m); + mk_nan(f, nan); + mk_nzero(f, nzero); + mk_pzero(f, pzero); + mk_minus_inf(f, ninf); + mk_plus_inf(f, pinf); + + expr_ref x_is_nan(m), x_is_zero(m), x_is_pos(m), x_is_inf(m); + mk_is_nan(x, x_is_nan); + mk_is_zero(x, x_is_zero); + mk_is_pos(x, x_is_pos); + mk_is_inf(x, x_is_inf); + + expr_ref zero1(m), one1(m); + zero1 = m_bv_util.mk_numeral(0, 1); + one1 = m_bv_util.mk_numeral(1, 1); + + expr_ref c1(m), c2(m), c3(m), c4(m), c5(m), c6(m); + expr_ref v1(m), v2(m), v3(m), v4(m), v5(m), v6(m), v7(m); + + // (x is NaN) -> NaN + c1 = x_is_nan; + v1 = x; + + // (x is +oo) -> +oo + mk_is_pinf(x, c2); + v2 = x; + + // (x is +-0) -> +-0 + mk_is_zero(x, c3); + v3 = x; + + // (x < 0) -> NaN + mk_is_neg(x, c4); + v4 = nan; + + // else comes the actual square root. + unsigned ebits = m_util.get_ebits(f->get_range()); + unsigned sbits = m_util.get_sbits(f->get_range()); + SASSERT(ebits <= sbits); + + expr_ref a_sgn(m), a_sig(m), a_exp(m), a_lz(m); + unpack(x, a_sgn, a_sig, a_exp, a_lz, true); + + dbg_decouple("fpa2bv_sqrt_sig", a_sig); + dbg_decouple("fpa2bv_sqrt_exp", a_exp); + + SASSERT(m_bv_util.get_bv_size(a_sig) == sbits); + SASSERT(m_bv_util.get_bv_size(a_exp) == ebits); + + expr_ref res_sgn(m), res_sig(m), res_exp(m); + + res_sgn = zero1; + + expr_ref real_exp(m); + real_exp = m_bv_util.mk_bv_sub(m_bv_util.mk_sign_extend(1, a_exp), m_bv_util.mk_zero_extend(1, a_lz)); + res_exp = m_bv_util.mk_sign_extend(2, m_bv_util.mk_extract(ebits, 1, real_exp)); + + expr_ref e_is_odd(m); + e_is_odd = m.mk_eq(m_bv_util.mk_extract(0, 0, real_exp), one1); + + dbg_decouple("fpa2bv_sqrt_e_is_odd", e_is_odd); + dbg_decouple("fpa2bv_sqrt_real_exp", real_exp); + + expr_ref sig_prime(m); + m_simp.mk_ite(e_is_odd, m_bv_util.mk_concat(a_sig, zero1), + m_bv_util.mk_concat(zero1, a_sig), + sig_prime); + SASSERT(m_bv_util.get_bv_size(sig_prime) == sbits+1); + dbg_decouple("fpa2bv_sqrt_sig_prime", sig_prime); + + // This is algorithm 10.2 in the Handbook of Floating-Point Arithmetic + expr_ref Q(m), R(m), S(m), T(m); + + const mpz & p2 = fu().fm().m_powers2(sbits-1); + Q = m_bv_util.mk_numeral(p2, sbits+2); + R = m_bv_util.mk_bv_sub(m_bv_util.mk_concat(zero1, sig_prime), Q); + S = Q; + + for (unsigned i = 0; i < sbits; i++) { + dbg_decouple("fpa2bv_sqrt_Q", Q); + dbg_decouple("fpa2bv_sqrt_R", R); + + S = m_bv_util.mk_concat(zero1, m_bv_util.mk_extract(sbits+1, 1, S)); + + expr_ref twoQ_plus_S(m); + twoQ_plus_S = m_bv_util.mk_bv_add(m_bv_util.mk_concat(Q, zero1), m_bv_util.mk_concat(zero1, S)); + T = m_bv_util.mk_bv_sub(m_bv_util.mk_concat(R, zero1), twoQ_plus_S); + + dbg_decouple("fpa2bv_sqrt_T", T); + + SASSERT(m_bv_util.get_bv_size(Q) == sbits + 2); + SASSERT(m_bv_util.get_bv_size(R) == sbits + 2); + SASSERT(m_bv_util.get_bv_size(S) == sbits + 2); + SASSERT(m_bv_util.get_bv_size(T) == sbits + 3); + + expr_ref t_lt_0(m); + m_simp.mk_eq(m_bv_util.mk_extract(sbits+2, sbits+2, T), one1, t_lt_0); + + m_simp.mk_ite(t_lt_0, Q, + m_bv_util.mk_bv_add(Q, S), + Q); + m_simp.mk_ite(t_lt_0, m_bv_util.mk_concat(m_bv_util.mk_extract(sbits, 0, R), zero1), + m_bv_util.mk_extract(sbits+1, 0, T), + R); + } + + expr_ref rest(m), last(m), q_is_odd(m), rest_ext(m); + last = m_bv_util.mk_extract(0, 0, Q); + rest = m_bv_util.mk_extract(sbits, 1, Q); + m_simp.mk_eq(last, one1, q_is_odd); + dbg_decouple("fpa2bv_sqrt_q_is_odd", q_is_odd); + rest_ext = m_bv_util.mk_concat(rest, m_bv_util.mk_numeral(0, 4)); + m_simp.mk_ite(q_is_odd, m_bv_util.mk_bv_add(rest_ext, m_bv_util.mk_numeral(8, sbits+4)), + rest_ext, + res_sig); + + SASSERT(m_bv_util.get_bv_size(res_sig) == sbits + 4); + + expr_ref rounded(m); + round(f->get_range(), rm, res_sgn, res_sig, res_exp, rounded); + v5 = rounded; + + // And finally, we tie them together. + mk_ite(c4, v4, v5, result); + mk_ite(c3, v3, result, result); + mk_ite(c2, v2, result, result); + mk_ite(c1, v1, result, result); + + SASSERT(is_well_sorted(m, result)); } void fpa2bv_converter::mk_round_to_integral(func_decl * f, unsigned num, expr * const * args, expr_ref & result) {