diff --git a/src/ast/rewriter/arith_rewriter.cpp b/src/ast/rewriter/arith_rewriter.cpp index b67e873c0..93949f4db 100644 --- a/src/ast/rewriter/arith_rewriter.cpp +++ b/src/ast/rewriter/arith_rewriter.cpp @@ -804,6 +804,72 @@ bool arith_rewriter::is_arith_term(expr * n) const { return n->get_kind() == AST_APP && to_app(n)->get_family_id() == get_fid(); } +br_status arith_rewriter::mk_ite_core(expr* c, expr* t, expr* e, expr_ref & result) { + numeral v1, v2; + bool is_int; + bool is_num1 = m_util.is_numeral(t, v1, is_int); + bool is_num2 = m_util.is_numeral(e, v2, is_int); + if (is_num1 && is_num2 && v1 == 0 && v2 != 1) { + result = m_util.mk_mul(e, m.mk_ite(c, t, m_util.mk_numeral(rational(1), is_int))); + return BR_DONE; + } + if (is_num1 && is_num2 && v2 == 0 && v1 != 1) { + result = m_util.mk_mul(t, m.mk_ite(c, m_util.mk_numeral(rational(1), is_int), e)); + return BR_DONE; + } + if (is_num1 && is_num2 && is_int && gcd(v1, v2) != 1) { + auto g = gcd(v1, v2); + if (g > 0 && v1 < 0 && v2 < 0) + g = -g; + + result = m_util.mk_numeral(g, is_int); + result = m_util.mk_mul(result, m.mk_ite(c, m_util.mk_numeral(v1/g, true), m_util.mk_numeral(v2/g, true))); + return BR_REWRITE2; + } + if (is_num1 && is_num2 && v1 != 0 && v2 != 0 && v1 != v2) { + if (v1 > v2) + result = m_util.mk_add(e, m.mk_ite(c, m_util.mk_numeral(v1 - v2, is_int), m_util.mk_numeral(rational::zero(), is_int))); + else + result = m_util.mk_add(e, m.mk_ite(c, m_util.mk_numeral(rational::zero(), is_int), m_util.mk_numeral(v2 - v1, is_int))); + return BR_DONE; + } + expr* x, *y; + if (is_num1 && m_util.is_mul(e, x, y) && m_util.is_numeral(x, v2, is_int) && v2 != 0) { + if (v1 == 0) { + result = m_util.mk_mul(x, m.mk_ite(c, t, y)); + return BR_DONE; + } + if (is_int && divides(v2, v1)) { + result = m_util.mk_mul(x, m.mk_ite(c, m_util.mk_numeral(v1/v2, true), y)); + return BR_DONE; + } + + } + if (is_num2 && m_util.is_mul(t, x, y) && m_util.is_numeral(x, v1, is_int) && v1 != 0) { + if (v2 == 0) { + result = m_util.mk_mul(x, m.mk_ite(c, y, e)); + return BR_DONE; + } + if (is_int && divides(v1, v2)) { + result = m_util.mk_mul(x, m.mk_ite(c, y, m_util.mk_numeral(v2/v1, true))); + return BR_DONE; + } + + } + if (is_num1 && m_util.is_add(e, x, y) && m_util.is_numeral(x, v2, is_int)) { + result = m_util.mk_add(x, m.mk_ite(c, m_util.mk_numeral(v1 - v2, is_int), y)); + return BR_REWRITE2; + } + if (is_num2 && m_util.is_add(t, x, y) && m_util.is_numeral(x, v1, is_int)) { + result = m_util.mk_add(x, m.mk_ite(c, y, m_util.mk_numeral(v2 - v1, is_int))); + return BR_REWRITE2; + } + + + + return BR_FAILED; +} + br_status arith_rewriter::mk_eq_core(expr * arg1, expr * arg2, expr_ref & result) { br_status st = BR_FAILED; if (m_eq2ineq) { diff --git a/src/ast/rewriter/arith_rewriter.h b/src/ast/rewriter/arith_rewriter.h index a1aadfa7f..cfdd1e58f 100644 --- a/src/ast/rewriter/arith_rewriter.h +++ b/src/ast/rewriter/arith_rewriter.h @@ -137,6 +137,7 @@ public: br_status mk_lt_core(expr * arg1, expr * arg2, expr_ref & result); br_status mk_ge_core(expr * arg1, expr * arg2, expr_ref & result); br_status mk_gt_core(expr * arg1, expr * arg2, expr_ref & result); + br_status mk_ite_core(expr* c, expr* t, expr* e, expr_ref & result); br_status mk_add_core(unsigned num_args, expr * const * args, expr_ref & result); br_status mk_mul_core(unsigned num_args, expr * const * args, expr_ref & result); diff --git a/src/ast/rewriter/poly_rewriter_def.h b/src/ast/rewriter/poly_rewriter_def.h index f739579e6..a2c6b2a2f 100644 --- a/src/ast/rewriter/poly_rewriter_def.h +++ b/src/ast/rewriter/poly_rewriter_def.h @@ -1017,7 +1017,9 @@ bool poly_rewriter::hoist_ite(expr_ref& e) { ++i; } if (!pinned.empty()) { + TRACE("poly_rewriter", tout << e << "\n"); e = mk_add_app(adds.size(), adds.data()); + TRACE("poly_rewriter", tout << e << "\n"); return true; } return false; diff --git a/src/ast/rewriter/th_rewriter.cpp b/src/ast/rewriter/th_rewriter.cpp index 3af887008..50f0d3c65 100644 --- a/src/ast/rewriter/th_rewriter.cpp +++ b/src/ast/rewriter/th_rewriter.cpp @@ -172,6 +172,9 @@ struct th_rewriter_cfg : public default_rewriter_cfg { family_id s_fid = args[1]->get_sort()->get_family_id(); if (s_fid == m_bv_rw.get_fid()) st = m_bv_rw.mk_ite_core(args[0], args[1], args[2], result); + if (st == BR_FAILED && s_fid == m_a_rw.get_fid()) + st = m_a_rw.mk_ite_core(args[0], args[1], args[2], result); + CTRACE("th_rewriter_step", st != BR_FAILED, tout << result << "\n"); if (st != BR_FAILED) return st; } @@ -197,7 +200,9 @@ struct th_rewriter_cfg : public default_rewriter_cfg { return st; } - return m_b_rw.mk_app_core(f, num, args, result); + st = m_b_rw.mk_app_core(f, num, args, result); + CTRACE("th_rewriter_step", st != BR_FAILED, tout << result << "\n"); + return st; } if (fid == m_a_rw.get_fid() && OP_LE == f->get_decl_kind() && m_seq_rw.u().has_seq()) { st = m_seq_rw.mk_le_core(args[0], args[1], result); @@ -315,7 +320,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg { return pull_ite_core(f, to_app(args[1]), to_app(args[0]), result); } family_id fid = f->get_family_id(); - if (num == 2 && (fid == m().get_basic_family_id() || fid == m_a_rw.get_fid() || fid == m_bv_rw.get_fid())) { + if (num == 2 && (fid == m().get_basic_family_id() || fid == m_bv_rw.get_fid())) { // (f v3 (ite c v1 v2)) --> (ite v (f v3 v1) (f v3 v2)) if (m().is_value(args[0]) && is_ite_value_tree(args[1])) return pull_ite_core(f, to_app(args[1]), to_app(args[0]), result); @@ -554,6 +559,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg { result = m().mk_app(f_prime, common, m().mk_ite(c, new_t, new_e)); else result = m().mk_app(f_prime, m().mk_ite(c, new_t, new_e), common); + TRACE("push_ite", tout << result << "\n";); return BR_DONE; } TRACE("push_ite", tout << "failed\n";);