diff --git a/src/ast/rewriter/arith_rewriter.cpp b/src/ast/rewriter/arith_rewriter.cpp index c9fc89b89..47cec7aa1 100644 --- a/src/ast/rewriter/arith_rewriter.cpp +++ b/src/ast/rewriter/arith_rewriter.cpp @@ -515,28 +515,38 @@ br_status arith_rewriter::reduce_power(expr * arg1, expr * arg2, op_kind kind, e } } -bool arith_rewriter::is_factor(expr* s, expr* t) { +bool arith_rewriter::is_mul_factor(expr* s, expr* t) { if (m_util.is_mul(t)) - return any_of(*to_app(t), [&](expr* m) { return m == s; }); - if (m_util.is_add(t)) - return all_of(*to_app(t), [&](expr* f) { return is_factor(s, f); }); + return any_of(*to_app(t), [&](expr* m) { return is_mul_factor(s, m); }); return s == t; } +bool arith_rewriter::is_add_factor(expr* s, expr* t) { + if (m_util.is_add(t)) + return all_of(*to_app(t), [&](expr* f) { return is_add_factor(s, f); }); + return is_mul_factor(s, t); +} + expr_ref arith_rewriter::remove_factor(expr* s, expr* t) { if (m_util.is_mul(t)) { ptr_buffer r; - r.append(to_app(t)->get_num_args(), to_app(t)->get_args()); + r.push_back(t); for (unsigned i = 0; i < r.size(); ++i) { expr* arg = r[i]; + if (m_util.is_mul(arg)) { + r.append(to_app(arg)->get_num_args(), to_app(arg)->get_args()); + r[i] = r.back(); + r.pop_back(); + --i; + continue; + } if (s == arg) { r[i] = r.back(); r.pop_back(); break; } } - SASSERT(to_app(t)->get_num_args() == r.size() + 1); switch (r.size()) { case 0: return expr_ref(m_util.mk_numeral(rational(1), m_util.is_int(t)), m); @@ -546,49 +556,60 @@ expr_ref arith_rewriter::remove_factor(expr* s, expr* t) { return expr_ref(m_util.mk_mul(r.size(), r.data()), m); } } - if (m_util.is_add(t)) { - expr_ref_vector sum(m); - for (expr* arg : *to_app(t)) - sum.push_back(remove_factor(s, arg)); - return expr_ref(m_util.mk_add(sum.size(), sum.data()), m); + expr_ref_vector sum(m); + sum.push_back(t); + for (unsigned i = 0; i < sum.size(); ++i) { + expr* arg = sum.get(i); + if (m_util.is_add(arg)) { + sum.append(to_app(arg)->get_num_args(), to_app(arg)->get_args()); + sum[i] = sum.back(); + sum.pop_back(); + --i; + continue; + } + sum[i] = remove_factor(s, arg); } - SASSERT(s == t); - return expr_ref(m_util.mk_numeral(rational(1), m_util.is_int(t)), m); + if (sum.size() == 1) + return expr_ref(sum.get(0), m); + else + return expr_ref(m_util.mk_add(sum.size(), sum.data()), m); } -br_status arith_rewriter::factor_le_ge_eq(expr * arg1, expr * arg2, op_kind kind, expr_ref & result) { - auto is_nl_mul = [&](expr* t) { - if (!m_util.is_mul(t)) - return false; - if (to_app(t)->get_num_args() <= 1) - return false; - unsigned num_vars = 0; +void arith_rewriter::get_nl_muls(expr* t, ptr_buffer& muls) { + if (m_util.is_mul(t)) { for (expr* arg : *to_app(t)) - if (!m_util.is_numeral(arg)) - ++num_vars; - return num_vars > 1; - }; - auto find_nl_factor = [&](expr* s) -> expr* { - if (is_nl_mul(s)) { - for (expr* arg : *to_app(s)) - if (!m_util.is_numeral(arg) && is_factor(arg, s)) - return arg; + get_nl_muls(arg, muls); + } + else if (!m_util.is_numeral(t)) + muls.push_back(t); +} + +expr* arith_rewriter::find_nl_factor(expr* t) { + ptr_buffer sum, muls; + sum.push_back(t); + + for (unsigned i = 0; i < sum.size(); ++i) { + expr* arg = sum[i]; + if (m_util.is_add(arg)) + sum.append(to_app(arg)->get_num_args(), to_app(arg)->get_args()); + else if (m_util.is_mul(arg)) { + muls.reset(); + get_nl_muls(arg, muls); + if (muls.size() <= 1) + continue; + for (auto m : muls) { + if (is_add_factor(m, t)) + return m; + } return nullptr; } - if (m_util.is_add(s)) { - for (expr* arg : *to_app(s)) { - if (is_nl_mul(arg)) { - for (expr* arg1 : *to_app(arg)) - if (!m_util.is_numeral(arg1) && is_factor(arg1, s)) - return arg1; - return nullptr; - } - } - } - return nullptr; - }; + } + return nullptr; +} +br_status arith_rewriter::factor_le_ge_eq(expr * arg1, expr * arg2, op_kind kind, expr_ref & result) { + if (is_zero(arg2)) { expr* f = find_nl_factor(arg1); if (!f) @@ -604,8 +625,7 @@ br_status arith_rewriter::factor_le_ge_eq(expr * arg1, expr * arg2, op_kind kind break; case LE: result = m.mk_or(m.mk_not(m.mk_iff(m_util.mk_ge(f, z), m_util.mk_ge(f2, z))), result); - break; - + break; } return BR_REWRITE3; } diff --git a/src/ast/rewriter/arith_rewriter.h b/src/ast/rewriter/arith_rewriter.h index 489dc06a6..a1aadfa7f 100644 --- a/src/ast/rewriter/arith_rewriter.h +++ b/src/ast/rewriter/arith_rewriter.h @@ -73,7 +73,10 @@ class arith_rewriter : public poly_rewriter { br_status is_separated(expr * arg1, expr * arg2, op_kind kind, expr_ref & result); bool is_non_negative(expr* e); br_status mk_le_ge_eq_core(expr * arg1, expr * arg2, op_kind kind, expr_ref & result); - bool is_factor(expr* s, expr* t); + bool is_add_factor(expr* s, expr* t); + bool is_mul_factor(expr* s, expr* t); + expr* find_nl_factor(expr* t); + void get_nl_muls(expr* t, ptr_buffer& muls); expr_ref remove_factor(expr* s, expr* t); br_status factor_le_ge_eq(expr * arg1, expr * arg2, op_kind kind, expr_ref & result); diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index a8a89b0d1..88a51fee1 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -558,7 +558,6 @@ namespace sls { ctx.force_restart(); num_fail = 0; } -// m_stats.m_num_steps++; return false; } @@ -1673,7 +1672,7 @@ namespace sls { result += 1; if (dtt_new != 0 && dtt_old == 0) { - if (/*m_use_tabu && */ctx.is_unit(lit)) + if (m_use_tabu && ctx.is_unit(lit)) return 0; result -= 1; }