3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-10 19:27:06 +00:00
z3/lib/float_rewriter.cpp
Leonardo de Moura e9eab22e5c Z3 sources
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
2012-10-02 11:35:25 -07:00

442 lines
14 KiB
C++

/*++
Copyright (c) 2012 Microsoft Corporation
Module Name:
float_rewriter.cpp
Abstract:
Basic rewriting rules for floating point numbers.
Author:
Leonardo (leonardo) 2012-02-02
Notes:
--*/
#include"float_rewriter.h"
float_rewriter::float_rewriter(ast_manager & m, params_ref const & p):
m_util(m) {
updt_params(p);
}
float_rewriter::~float_rewriter() {
}
void float_rewriter::updt_params(params_ref const & p) {
}
void float_rewriter::get_param_descrs(param_descrs & r) {
}
br_status float_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) {
br_status st = BR_FAILED;
SASSERT(f->get_family_id() == get_fid());
switch (f->get_decl_kind()) {
case OP_TO_FLOAT: st = mk_to_float(f, num_args, args, result); break;
case OP_FLOAT_ADD: SASSERT(num_args == 3); st = mk_add(args[0], args[1], args[2], result); break;
case OP_FLOAT_SUB: SASSERT(num_args == 3); st = mk_sub(args[0], args[1], args[2], result); break;
case OP_FLOAT_UMINUS: SASSERT(num_args == 1); st = mk_uminus(args[0], result); break;
case OP_FLOAT_MUL: SASSERT(num_args == 3); st = mk_mul(args[0], args[1], args[2], result); break;
case OP_FLOAT_DIV: SASSERT(num_args == 3); st = mk_div(args[0], args[1], args[2], result); break;
case OP_FLOAT_REM: SASSERT(num_args == 2); st = mk_rem(args[0], args[1], result); break;
case OP_FLOAT_ABS: SASSERT(num_args == 1); st = mk_abs(args[0], result); break;
case OP_FLOAT_MIN: SASSERT(num_args == 2); st = mk_min(args[0], args[1], result); break;
case OP_FLOAT_MAX: SASSERT(num_args == 2); st = mk_max(args[0], args[1], result); break;
case OP_FLOAT_FUSED_MA: SASSERT(num_args == 4); st = mk_fused_ma(args[0], args[1], args[2], args[3], result); break;
case OP_FLOAT_SQRT: SASSERT(num_args == 2); st = mk_sqrt(args[0], args[1], result); break;
case OP_FLOAT_ROUND_TO_INTEGRAL: SASSERT(num_args == 2); st = mk_round(args[0], args[1], result); break;
case OP_FLOAT_EQ: SASSERT(num_args == 2); st = mk_float_eq(args[0], args[1], result); break;
case OP_FLOAT_LT: SASSERT(num_args == 2); st = mk_lt(args[0], args[1], result); break;
case OP_FLOAT_GT: SASSERT(num_args == 2); st = mk_gt(args[0], args[1], result); break;
case OP_FLOAT_LE: SASSERT(num_args == 2); st = mk_le(args[0], args[1], result); break;
case OP_FLOAT_GE: SASSERT(num_args == 2); st = mk_ge(args[0], args[1], result); break;
case OP_FLOAT_IS_ZERO: SASSERT(num_args == 1); st = mk_is_zero(args[0], result); break;
case OP_FLOAT_IS_NZERO: SASSERT(num_args == 1); st = mk_is_nzero(args[0], result); break;
case OP_FLOAT_IS_PZERO: SASSERT(num_args == 1); st = mk_is_pzero(args[0], result); break;
case OP_FLOAT_IS_SIGN_MINUS: SASSERT(num_args == 1); st = mk_is_sign_minus(args[0], result); break;
}
return st;
}
br_status float_rewriter::mk_to_float(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) {
SASSERT(f->get_num_parameters() == 2);
SASSERT(f->get_parameter(0).is_int());
SASSERT(f->get_parameter(1).is_int());
unsigned ebits = f->get_parameter(0).get_int();
unsigned sbits = f->get_parameter(1).get_int();
if (num_args == 2) {
mpf_rounding_mode rm;
if (!m_util.is_rm(args[0], rm))
return BR_FAILED;
rational q;
if (!m_util.au().is_numeral(args[1], q))
return BR_FAILED;
mpf v;
m_util.fm().set(v, ebits, sbits, rm, q.to_mpq());
result = m_util.mk_value(v);
m_util.fm().del(v);
return BR_DONE;
}
else if (num_args == 3 &&
m_util.is_rm(m().get_sort(args[0])) &&
m_util.au().is_real(args[1]) &&
m_util.au().is_int(args[2])) {
mpf_rounding_mode rm;
if (!m_util.is_rm(args[0], rm))
return BR_FAILED;
rational q;
if (!m_util.au().is_numeral(args[1], q))
return BR_FAILED;
rational e;
if (!m_util.au().is_numeral(args[2], e))
return BR_FAILED;
TRACE("fp_rewriter", tout << "q: " << q << ", e: " << e << "\n";);
mpf v;
m_util.fm().set(v, ebits, sbits, rm, q.to_mpq(), e.to_mpq().numerator());
result = m_util.mk_value(v);
m_util.fm().del(v);
return BR_DONE;
}
else {
return BR_FAILED;
}
}
br_status float_rewriter::mk_add(expr * arg1, expr * arg2, expr * arg3, expr_ref & result) {
mpf_rounding_mode rm;
if (m_util.is_rm(arg1, rm)) {
scoped_mpf v2(m_util.fm()), v3(m_util.fm());
if (m_util.is_value(arg2, v2) && m_util.is_value(arg3, v3)) {
scoped_mpf t(m_util.fm());
m_util.fm().add(rm, v2, v3, t);
result = m_util.mk_value(t);
return BR_DONE;
}
}
return BR_FAILED;
}
br_status float_rewriter::mk_sub(expr * arg1, expr * arg2, expr * arg3, expr_ref & result) {
// a - b = a + (-b)
result = m_util.mk_add(arg1, arg2, m_util.mk_uminus(arg3));
return BR_REWRITE2;
}
br_status float_rewriter::mk_mul(expr * arg1, expr * arg2, expr * arg3, expr_ref & result) {
mpf_rounding_mode rm;
if (m_util.is_rm(arg1, rm)) {
scoped_mpf v2(m_util.fm()), v3(m_util.fm());
if (m_util.is_value(arg2, v2) && m_util.is_value(arg3, v3)) {
scoped_mpf t(m_util.fm());
m_util.fm().mul(rm, v2, v3, t);
result = m_util.mk_value(t);
return BR_DONE;
}
}
return BR_FAILED;
}
br_status float_rewriter::mk_div(expr * arg1, expr * arg2, expr * arg3, expr_ref & result) {
mpf_rounding_mode rm;
if (m_util.is_rm(arg1, rm)) {
scoped_mpf v2(m_util.fm()), v3(m_util.fm());
if (m_util.is_value(arg2, v2) && m_util.is_value(arg3, v3)) {
scoped_mpf t(m_util.fm());
m_util.fm().div(rm, v2, v3, t);
result = m_util.mk_value(t);
return BR_DONE;
}
}
return BR_FAILED;
}
br_status float_rewriter::mk_uminus(expr * arg1, expr_ref & result) {
if (m_util.is_nan(arg1)) {
// -nan --> nan
result = arg1;
return BR_DONE;
}
if (m_util.is_plus_inf(arg1)) {
// - +oo --> -oo
result = m_util.mk_minus_inf(m().get_sort(arg1));
return BR_DONE;
}
if (m_util.is_minus_inf(arg1)) {
// - -oo -> +oo
result = m_util.mk_plus_inf(m().get_sort(arg1));
return BR_DONE;
}
if (m_util.is_uminus(arg1)) {
// - - a --> a
result = to_app(arg1)->get_arg(0);
return BR_DONE;
}
scoped_mpf v1(m_util.fm());
if (m_util.is_value(arg1, v1)) {
m_util.fm().neg(v1);
result = m_util.mk_value(v1);
return BR_DONE;
}
// TODO: more simplifications
return BR_FAILED;
}
br_status float_rewriter::mk_rem(expr * arg1, expr * arg2, expr_ref & result) {
scoped_mpf v1(m_util.fm()), v2(m_util.fm());
if (m_util.is_value(arg1, v1) && m_util.is_value(arg2, v2)) {
scoped_mpf t(m_util.fm());
m_util.fm().rem(v1, v2, t);
result = m_util.mk_value(t);
return BR_DONE;
}
return BR_FAILED;
}
br_status float_rewriter::mk_abs(expr * arg1, expr_ref & result) {
if (m_util.is_nan(arg1)) {
result = arg1;
return BR_DONE;
}
sort * s = m().get_sort(arg1);
result = m().mk_ite(m_util.mk_lt(arg1, m_util.mk_pzero(s)),
m_util.mk_uminus(arg1),
arg1);
return BR_REWRITE2;
}
br_status float_rewriter::mk_min(expr * arg1, expr * arg2, expr_ref & result) {
if (m_util.is_nan(arg1)) {
result = arg2;
return BR_DONE;
}
if (m_util.is_nan(arg2)) {
result = arg1;
return BR_DONE;
}
// expand as using ite's
result = m().mk_ite(mk_eq_nan(arg1),
arg2,
m().mk_ite(mk_eq_nan(arg2),
arg1,
m().mk_ite(m_util.mk_lt(arg1, arg2),
arg1,
arg2)));
return BR_REWRITE_FULL;
}
br_status float_rewriter::mk_max(expr * arg1, expr * arg2, expr_ref & result) {
if (m_util.is_nan(arg1)) {
result = arg2;
return BR_DONE;
}
if (m_util.is_nan(arg2)) {
result = arg1;
return BR_DONE;
}
// expand as using ite's
result = m().mk_ite(mk_eq_nan(arg1),
arg2,
m().mk_ite(mk_eq_nan(arg2),
arg1,
m().mk_ite(m_util.mk_gt(arg1, arg2),
arg1,
arg2)));
return BR_REWRITE_FULL;
}
br_status float_rewriter::mk_fused_ma(expr * arg1, expr * arg2, expr * arg3, expr * arg4, expr_ref & result) {
mpf_rounding_mode rm;
if (m_util.is_rm(arg1, rm)) {
scoped_mpf v2(m_util.fm()), v3(m_util.fm()), v4(m_util.fm());
if (m_util.is_value(arg2, v2) && m_util.is_value(arg3, v3) && m_util.is_value(arg4, v4)) {
scoped_mpf t(m_util.fm());
m_util.fm().fused_mul_add(rm, v2, v3, v4, t);
result = m_util.mk_value(t);
return BR_DONE;
}
}
return BR_FAILED;
}
br_status float_rewriter::mk_sqrt(expr * arg1, expr * arg2, expr_ref & result) {
mpf_rounding_mode rm;
if (m_util.is_rm(arg1, rm)) {
scoped_mpf v2(m_util.fm());
if (m_util.is_value(arg2, v2)) {
scoped_mpf t(m_util.fm());
m_util.fm().sqrt(rm, v2, t);
result = m_util.mk_value(t);
return BR_DONE;
}
}
return BR_FAILED;
}
br_status float_rewriter::mk_round(expr * arg1, expr * arg2, expr_ref & result) {
mpf_rounding_mode rm;
if (m_util.is_rm(arg1, rm)) {
scoped_mpf v2(m_util.fm());
if (m_util.is_value(arg2, v2)) {
scoped_mpf t(m_util.fm());
m_util.fm().round_to_integral(rm, v2, t);
result = m_util.mk_value(t);
return BR_DONE;
}
}
return BR_FAILED;
}
// This the floating point theory ==
br_status float_rewriter::mk_float_eq(expr * arg1, expr * arg2, expr_ref & result) {
scoped_mpf v1(m_util.fm()), v2(m_util.fm());
if (m_util.is_value(arg1, v1) && m_util.is_value(arg2, v2)) {
result = (m_util.fm().eq(v1, v2)) ? m().mk_true() : m().mk_false();
return BR_DONE;
}
return BR_FAILED;
}
// Return (= arg NaN)
app * float_rewriter::mk_eq_nan(expr * arg) {
return m().mk_eq(arg, m_util.mk_nan(m().get_sort(arg)));
}
// Return (not (= arg NaN))
app * float_rewriter::mk_neq_nan(expr * arg) {
return m().mk_not(mk_eq_nan(arg));
}
br_status float_rewriter::mk_lt(expr * arg1, expr * arg2, expr_ref & result) {
if (m_util.is_nan(arg1) || m_util.is_nan(arg2)) {
result = m().mk_false();
return BR_DONE;
}
if (m_util.is_minus_inf(arg1)) {
// -oo < arg2 --> not(arg2 = -oo) and not(arg2 = NaN)
result = m().mk_and(m().mk_not(m().mk_eq(arg2, arg1)), mk_neq_nan(arg2));
return BR_REWRITE2;
}
if (m_util.is_minus_inf(arg2)) {
// arg1 < -oo --> false
result = m().mk_false();
return BR_DONE;
}
if (m_util.is_plus_inf(arg1)) {
// +oo < arg2 --> false
result = m().mk_false();
return BR_DONE;
}
if (m_util.is_plus_inf(arg2)) {
// arg1 < +oo --> not(arg1 = +oo) and not(arg1 = NaN)
result = m().mk_and(m().mk_not(m().mk_eq(arg1, arg2)), mk_neq_nan(arg1));
return BR_REWRITE2;
}
scoped_mpf v1(m_util.fm()), v2(m_util.fm());
if (m_util.is_value(arg1, v1) && m_util.is_value(arg2, v2)) {
result = (m_util.fm().lt(v1, v2)) ? m().mk_true() : m().mk_false();
return BR_DONE;
}
// TODO: more simplifications
return BR_FAILED;
}
br_status float_rewriter::mk_gt(expr * arg1, expr * arg2, expr_ref & result) {
result = m_util.mk_lt(arg2, arg1);
return BR_REWRITE1;
}
br_status float_rewriter::mk_le(expr * arg1, expr * arg2, expr_ref & result) {
if (m_util.is_nan(arg1) || m_util.is_nan(arg2)) {
result = m().mk_false();
return BR_DONE;
}
scoped_mpf v1(m_util.fm()), v2(m_util.fm());
if (m_util.is_value(arg1, v1) && m_util.is_value(arg2, v2)) {
result = (m_util.fm().le(v1, v2)) ? m().mk_true() : m().mk_false();
return BR_DONE;
}
return BR_FAILED;
}
br_status float_rewriter::mk_ge(expr * arg1, expr * arg2, expr_ref & result) {
result = m_util.mk_le(arg2, arg1);
return BR_REWRITE1;
}
br_status float_rewriter::mk_is_zero(expr * arg1, expr_ref & result) {
scoped_mpf v(m_util.fm());
if (m_util.is_value(arg1, v)) {
result = (m_util.fm().is_zero(v)) ? m().mk_true() : m().mk_false();
return BR_DONE;
}
return BR_FAILED;
}
br_status float_rewriter::mk_is_nzero(expr * arg1, expr_ref & result) {
scoped_mpf v(m_util.fm());
if (m_util.is_value(arg1, v)) {
result = (m_util.fm().is_nzero(v)) ? m().mk_true() : m().mk_false();
return BR_DONE;
}
return BR_FAILED;
}
br_status float_rewriter::mk_is_pzero(expr * arg1, expr_ref & result) {
scoped_mpf v(m_util.fm());
if (m_util.is_value(arg1, v)) {
result = (m_util.fm().is_pzero(v)) ? m().mk_true() : m().mk_false();
return BR_DONE;
}
return BR_FAILED;
}
br_status float_rewriter::mk_is_sign_minus(expr * arg1, expr_ref & result) {
scoped_mpf v(m_util.fm());
if (m_util.is_value(arg1, v)) {
result = (m_util.fm().is_neg(v)) ? m().mk_true() : m().mk_false();
return BR_DONE;
}
return BR_FAILED;
}
// This the SMT =
br_status float_rewriter::mk_eq_core(expr * arg1, expr * arg2, expr_ref & result) {
scoped_mpf v1(m_util.fm()), v2(m_util.fm());
if (m_util.is_value(arg1, v1) && m_util.is_value(arg2, v2)) {
result = (v1 == v2) ? m().mk_true() : m().mk_false();
return BR_DONE;
}
return BR_FAILED;
}