diff --git a/src/tactic/arith/fm_tactic.cpp b/src/tactic/arith/fm_tactic.cpp index d2e73e54a..d0564139a 100644 --- a/src/tactic/arith/fm_tactic.cpp +++ b/src/tactic/arith/fm_tactic.cpp @@ -61,7 +61,8 @@ class fm_tactic : public tactic { return m.is_false(val); } - r_kind process(func_decl * x, expr * cls, arith_util & u, model& ev, rational & r) { + r_kind process(func_decl * x, expr * cls, arith_util & u, model& ev, rational & r, expr_ref& r_e) { + r_e = nullptr; unsigned num_lits; expr * const * lits; if (m.is_or(cls)) { @@ -93,6 +94,7 @@ class fm_tactic : public tactic { expr * lhs = to_app(l)->get_arg(0); expr * rhs = to_app(l)->get_arg(1); rational c; + expr_ref c_e(m); if (!u.is_numeral(rhs, c)) return NONE; if (neg) @@ -133,27 +135,41 @@ class fm_tactic : public tactic { expr_ref val(m); val = ev(monomial); rational tmp; - if (!u.is_numeral(val, tmp)) - return NONE; - if (neg) - tmp.neg(); - c -= tmp; + if (u.is_numeral(val, tmp)) { + if (neg) + tmp.neg(); + c -= tmp; + } + else { + // this happens for algebraic numerals + if (neg) + val = u.mk_uminus(val); + if (!c_e) + c_e = u.mk_uminus(val); + else + c_e = u.mk_sub(c_e, val); + } } } if (u.is_int(x->get_range()) && strict) { // a*x < c --> a*x <= c-1 SASSERT(c.is_int()); c--; + SASSERT(!c_e); } is_lower = a_val.is_neg(); c /= a_val; + if (c_e) + c_e = u.mk_div(c_e, u.mk_numeral(a_val, false)); if (u.is_int(x->get_range())) { + SASSERT(!c_e); if (is_lower) c = ceil(c); else c = floor(c); } r = c; + r_e = c_e; } } (void)found; @@ -187,6 +203,12 @@ class fm_tactic : public tactic { //model_evaluator ev(*(md.get())); //ev.set_model_completion(true); arith_util u(m); + auto mk_max = [&](expr* a, expr* b) { + return expr_ref(m.mk_ite(u.mk_ge(a, b), a, b), m); + }; + auto mk_min = [&](expr* a, expr* b) { + return expr_ref(m.mk_ite(u.mk_ge(a, b), b, a), m); + }; unsigned i = m_xs.size(); while (i > 0) { --i; @@ -194,42 +216,67 @@ class fm_tactic : public tactic { rational lower; rational upper; rational val; - bool has_lower = false; - bool has_upper = false; + expr_ref val_e(m), val_upper_e(m), val_lower_e(m); + bool has_lower = false, has_upper = false; TRACE("fm_mc", tout << "processing " << x->get_name() << "\n";); for (expr* cl : m_clauses[i]) { if (!m.inc()) throw tactic_exception(m.limit().get_cancel_msg()); - switch (process(x, cl, u, *md, val)) { + switch (process(x, cl, u, *md, val, val_e)) { case NONE: TRACE("fm_mc", tout << "no bound for:\n" << mk_ismt2_pp(cl, m) << "\n";); break; case LOWER: TRACE("fm_mc", tout << "lower bound: " << val << " for:\n" << mk_ismt2_pp(cl, m) << "\n";); - if (!has_lower || val > lower) - lower = val; - has_lower = true; + if (val_e) + val_lower_e = val_lower_e != nullptr ? mk_max(val_lower_e, val_e) : val_e; + else if (!has_lower || val > lower) + lower = val, has_lower = true; break; case UPPER: TRACE("fm_mc", tout << "upper bound: " << val << " for:\n" << mk_ismt2_pp(cl, m) << "\n";); - if (!has_upper || val < upper) - upper = val; - has_upper = true; + if (val_e) + val_upper_e = val_upper_e != nullptr ? mk_min(val_upper_e, val_e) : val_e; + else if (!has_upper || val < upper) + upper = val, has_upper = true; break; } } expr * x_val; + if (u.is_int(x->get_range())) { - if (has_lower) + if (val_lower_e) { + x_val = val_lower_e; + if (has_lower) + x_val = mk_max(x_val, u.mk_numeral(lower, true)); + } + else if (val_upper_e) { + x_val = val_upper_e; + if (has_upper) + x_val = mk_min(x_val, u.mk_numeral(upper, true)); + } + else if (has_lower) x_val = u.mk_numeral(lower, true); else if (has_upper) x_val = u.mk_numeral(upper, true); else x_val = u.mk_numeral(rational(0), true); + } else { - if (has_lower && has_upper) + if (val_lower_e && has_lower) + val_lower_e = mk_max(val_lower_e, u.mk_numeral(lower, false)); + if (val_upper_e && has_upper) + val_upper_e = mk_min(val_upper_e, u.mk_numeral(upper, false)); + + if (val_lower_e && val_upper_e) + x_val = u.mk_div(u.mk_add(val_lower_e, val_upper_e), u.mk_real(2)); + else if (val_lower_e) + x_val = u.mk_add(val_lower_e, u.mk_real(1)); + else if (val_upper_e) + x_val = u.mk_sub(val_upper_e, u.mk_real(1)); + else if (has_lower && has_upper) x_val = u.mk_numeral((upper + lower)/rational(2), false); else if (has_lower) x_val = u.mk_numeral(lower + rational(1), false);