3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-22 13:53:39 +00:00

port improvements to arith rewriter

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2024-01-03 13:57:09 -08:00
parent a7bfdcd0ea
commit b75367ffc7
2 changed files with 29 additions and 28 deletions

View file

@ -1119,7 +1119,7 @@ br_status arith_rewriter::mk_idiv_core(expr * arg1, expr * arg2, expr_ref & resu
return BR_REWRITE3; return BR_REWRITE3;
} }
} }
if (divides(arg1, arg2, result)) { if (get_divides(arg1, arg2, result)) {
expr_ref zero(m_util.mk_int(0), m); expr_ref zero(m_util.mk_int(0), m);
result = m.mk_ite(m.mk_eq(zero, arg2), m_util.mk_idiv(arg1, zero), result); result = m.mk_ite(m.mk_eq(zero, arg2), m_util.mk_idiv(arg1, zero), result);
return BR_REWRITE_FULL; return BR_REWRITE_FULL;
@ -1137,7 +1137,7 @@ br_status arith_rewriter::mk_idiv_core(expr * arg1, expr * arg2, expr_ref & resu
// //
// implement div ab ac = floor( ab / ac) = floor (b / c) = div b c // implement div ab ac = floor( ab / ac) = floor (b / c) = div b c
// //
bool arith_rewriter::divides(expr* num, expr* den, expr_ref& result) { bool arith_rewriter::get_divides(expr* num, expr* den, expr_ref& result) {
expr_fast_mark1 mark; expr_fast_mark1 mark;
rational num_r(1), den_r(1); rational num_r(1), den_r(1);
expr* num_e = nullptr, *den_e = nullptr; expr* num_e = nullptr, *den_e = nullptr;
@ -1232,20 +1232,22 @@ static rational symmod(rational const& a, rational const& b) {
br_status arith_rewriter::mk_mod_core(expr * arg1, expr * arg2, expr_ref & result) { br_status arith_rewriter::mk_mod_core(expr * arg1, expr * arg2, expr_ref & result) {
set_curr_sort(arg1->get_sort()); set_curr_sort(arg1->get_sort());
numeral x, y; numeral v1, v2;
bool is_num_x = m_util.is_numeral(arg1, x); bool is_int;
bool is_num_y = m_util.is_numeral(arg2, y); bool is_num1 = m_util.is_numeral(arg1, v1, is_int);
if (is_num_x && is_num_y && !y.is_zero()) { bool is_num2 = m_util.is_numeral(arg2, v2, is_int);
result = m_util.mk_int(mod(x, y));
if (is_num1 && is_num2 && !v2.is_zero()) {
result = m_util.mk_numeral(mod(v1, v2), is_int);
return BR_DONE; return BR_DONE;
} }
if (is_num_y && y.is_int() && (y.is_one() || y.is_minus_one())) { if (is_num2 && is_int && (v2.is_one() || v2.is_minus_one())) {
result = m_util.mk_numeral(numeral(0), true); result = m_util.mk_numeral(numeral(0), true);
return BR_DONE; return BR_DONE;
} }
if (arg1 == arg2 && !is_num_y) { if (arg1 == arg2 && !is_num2) {
expr_ref zero(m_util.mk_int(0), m); expr_ref zero(m_util.mk_int(0), m);
result = m.mk_ite(m.mk_eq(arg2, zero), m_util.mk_mod(zero, zero), zero); result = m.mk_ite(m.mk_eq(arg2, zero), m_util.mk_mod(zero, zero), zero);
return BR_DONE; return BR_DONE;
@ -1253,47 +1255,46 @@ br_status arith_rewriter::mk_mod_core(expr * arg1, expr * arg2, expr_ref & resul
// mod is idempotent on non-zero modulus. // mod is idempotent on non-zero modulus.
expr* t1, *t2; expr* t1, *t2;
if (m_util.is_mod(arg1, t1, t2) && t2 == arg2 && is_num_y && y.is_int() && !y.is_zero()) { if (m_util.is_mod(arg1, t1, t2) && t2 == arg2 && is_num2 && is_int && !v2.is_zero()) {
result = arg1;
return BR_DONE;
}
rational lo, hi;
if (is_num_y && get_range(arg1, lo, hi) && 0 <= lo && hi < y) {
result = arg1; result = arg1;
return BR_DONE; return BR_DONE;
} }
// propagate mod inside only if there is something to reduce. // propagate mod inside only if there is something to reduce.
if (is_num_y && y.is_int() && y.is_pos() && (is_add(arg1) || is_mul(arg1))) { if (is_num2 && is_int && v2.is_pos() && (is_add(arg1) || is_mul(arg1))) {
TRACE("mod_bug", tout << "mk_mod:\n" << mk_ismt2_pp(arg1, m) << "\n" << mk_ismt2_pp(arg2, m) << "\n";); TRACE("mod_bug", tout << "mk_mod:\n" << mk_ismt2_pp(arg1, m) << "\n" << mk_ismt2_pp(arg2, m) << "\n";);
expr_ref_buffer args(m); expr_ref_buffer args(m);
bool change = false; bool change = false;
for (expr* arg : *to_app(arg1)) { for (expr* arg : *to_app(arg1)) {
rational arg_v; rational arg_v;
if (m_util.is_numeral(arg, arg_v) && mod(arg_v, y) != arg_v) { if (m_util.is_numeral(arg, arg_v) && mod(arg_v, v2) != arg_v) {
change = true; change = true;
args.push_back(m_util.mk_numeral(mod(arg_v, y), true)); args.push_back(m_util.mk_numeral(mod(arg_v, v2), true));
} }
else if (m_util.is_mod(arg, t1, t2) && t2 == arg2) { else if (m_util.is_mod(arg, t1, t2) && t2 == arg2) {
change = true; change = true;
args.push_back(t1); args.push_back(t1);
} }
else if (m_util.is_mul(arg, t1, t2) && m_util.is_numeral(t1, arg_v) && symmod(arg_v, y) != arg_v) { else if (m_util.is_mul(arg, t1, t2) && m_util.is_numeral(t1, arg_v) && symmod(arg_v, v2) != arg_v) {
change = true; change = true;
args.push_back(m_util.mk_mul(m_util.mk_numeral(symmod(arg_v, y), true), t2)); args.push_back(m_util.mk_mul(m_util.mk_numeral(symmod(arg_v, v2), true), t2));
} }
else { else {
args.push_back(arg); args.push_back(arg);
} }
} }
if (!change) { if (change) {
return BR_FAILED; // did not find any target for applying simplification
}
result = m_util.mk_mod(m.mk_app(to_app(arg1)->get_decl(), args.size(), args.data()), arg2); result = m_util.mk_mod(m.mk_app(to_app(arg1)->get_decl(), args.size(), args.data()), arg2);
TRACE("mod_bug", tout << "mk_mod result: " << mk_ismt2_pp(result, m) << "\n";); TRACE("mod_bug", tout << "mk_mod result: " << mk_ismt2_pp(result, m) << "\n";);
return BR_REWRITE3; return BR_REWRITE3;
} }
}
expr* x, *y;
if (is_num2 && v2.is_pos() && m_util.is_mul(arg1, x, y) && m_util.is_numeral(x, v1, is_int) && divides(v1, v2)) {
result = m_util.mk_mul(x, m_util.mk_mod(y, m_util.mk_int(v2/v1)));
return BR_REWRITE1;
}
return BR_FAILED; return BR_FAILED;
} }

View file

@ -104,7 +104,7 @@ class arith_rewriter : public poly_rewriter<arith_rewriter_core> {
expr_ref neg_monomial(expr * e); expr_ref neg_monomial(expr * e);
expr * mk_sin_value(rational const & k); expr * mk_sin_value(rational const & k);
app * mk_sqrt(rational const & k); app * mk_sqrt(rational const & k);
bool divides(expr* d, expr* n, expr_ref& result); bool get_divides(expr* d, expr* n, expr_ref& result);
expr_ref remove_divisor(expr* arg, expr* num, expr* den); expr_ref remove_divisor(expr* arg, expr* num, expr* den);
void flat_mul(expr* e, ptr_buffer<expr>& args); void flat_mul(expr* e, ptr_buffer<expr>& args);
void remove_divisor(expr* d, ptr_buffer<expr>& args); void remove_divisor(expr* d, ptr_buffer<expr>& args);