diff --git a/src/muz/spacer/spacer_proof_utils.cpp b/src/muz/spacer/spacer_proof_utils.cpp index aa7f611c8..ea91a7d04 100644 --- a/src/muz/spacer/spacer_proof_utils.cpp +++ b/src/muz/spacer/spacer_proof_utils.cpp @@ -73,6 +73,129 @@ namespace spacer { + class linear_combinator { + struct scaled_lit { + bool is_pos; + app *lit; + rational coeff; + scaled_lit(bool is_pos, app *lit, const rational &coeff) : + is_pos(is_pos), lit(lit), coeff(coeff) {} + }; + ast_manager &m; + th_rewriter m_rw; + arith_util m_arith; + expr_ref m_sum; + bool m_is_strict; + rational m_lc; + vector m_lits; + public: + linear_combinator(ast_manager &m) : m(m), m_rw(m), m_arith(m), + m_sum(m), m_is_strict(false), + m_lc(1) {} + + void add_lit(app* lit, rational const &coeff, bool is_pos = true) { + m_lits.push_back(scaled_lit(is_pos, lit, coeff)); + } + + void normalize_coeff() { + for (auto &lit : m_lits) + m_lc = lcm(m_lc, denominator(lit.coeff)); + if (!m_lc.is_one()) { + for (auto &lit : m_lits) + lit.coeff *= m_lc; + } + } + + rational const &lc() const {return m_lc;} + + bool process_lit(scaled_lit &lit0) { + arith_util a(m); + app* lit = lit0.lit; + rational &coeff = lit0.coeff; + bool is_pos = lit0.is_pos; + + + if (m.is_not(lit)) { + lit = to_app(lit->get_arg(0)); + is_pos = !is_pos; + } + if (!m_arith.is_le(lit) && !m_arith.is_lt(lit) && + !m_arith.is_ge(lit) && !m_arith.is_gt(lit) && !m.is_eq(lit)) { + return false; + } + SASSERT(lit->get_num_args() == 2); + sort* s = m.get_sort(lit->get_arg(0)); + bool is_int = m_arith.is_int(s); + if (!is_int && m_arith.is_int_expr(lit->get_arg(0))) { + is_int = true; + s = m_arith.mk_int(); + } + + if (!is_int && is_pos && (m_arith.is_gt(lit) || m_arith.is_lt(lit))) { + m_is_strict = true; + } + if (!is_int && !is_pos && (m_arith.is_ge(lit) || m_arith.is_le(lit))) { + m_is_strict = true; + } + + + SASSERT(m_arith.is_int(s) || m_arith.is_real(s)); + expr_ref sign1(m), sign2(m), term(m); + sign1 = m_arith.mk_numeral(m.is_eq(lit)?coeff:abs(coeff), s); + sign2 = m_arith.mk_numeral(m.is_eq(lit)?-coeff:-abs(coeff), s); + if (!m_sum.get()) { + m_sum = m_arith.mk_numeral(rational(0), s); + } + + expr* a0 = lit->get_arg(0); + expr* a1 = lit->get_arg(1); + + if (is_pos && (m_arith.is_ge(lit) || m_arith.is_gt(lit))) { + std::swap(a0, a1); + } + if (!is_pos && (m_arith.is_le(lit) || m_arith.is_lt(lit))) { + std::swap(a0, a1); + } + + // + // Multiplying by coefficients over strict + // and non-strict inequalities: + // + // (a <= b) * 2 + // (a - b <= 0) * 2 + // (2a - 2b <= 0) + + // (a < b) * 2 <=> + // (a +1 <= b) * 2 <=> + // 2a + 2 <= 2b <=> + // 2a+2-2b <= 0 + + bool strict_ineq = + is_pos?(m_arith.is_gt(lit) || m_arith.is_lt(lit)):(m_arith.is_ge(lit) || m_arith.is_le(lit)); + + if (is_int && strict_ineq) { + m_sum = m_arith.mk_add(m_sum, sign1); + } + + term = m_arith.mk_mul(sign1, a0); + m_sum = m_arith.mk_add(m_sum, term); + term = m_arith.mk_mul(sign2, a1); + m_sum = m_arith.mk_add(m_sum, term); + + m_rw(m_sum); + return true; + } + + expr_ref operator()(){ + if (!m_sum) normalize_coeff(); + m_sum.reset(); + for (auto &lit : m_lits) { + if (!process_lit(lit)) return expr_ref(m); + } + return m_sum; + } + }; + /* * ==================================== * methods for transforming proofs @@ -101,6 +224,60 @@ namespace spacer { return pf; } + static bool match_mul(expr *e, expr_ref &var, expr_ref &val, arith_util &a) { + expr *e1 = nullptr, *e2 = nullptr; + if (!a.is_mul(e, e1, e2)) { + if (a.is_numeral(e)) return false; + if (!var || var == e) { + var = e; + val = a.mk_numeral(rational(1), get_sort(e)); + return true; + } + return false; + } + + if (!a.is_numeral(e1)) std::swap(e1, e2); + if (!a.is_numeral(e1)) return false; + + // if variable is given, match it as well + if (!var || var == e2) { + var = e2; + val = e1; + return true; + } + return false; + } + + static expr_ref get_coeff(expr *lit0, expr_ref &var) { + ast_manager &m = var.m(); + arith_util a(m); + + expr *lit = nullptr; + if (!m.is_not(lit0, lit)) lit = lit0; + + expr *e1 = nullptr, *e2 = nullptr; + // assume e2 is numeral and ignore it + if ((a.is_le(lit, e1, e2) || a.is_lt(lit, e1, e2) || + a.is_ge(lit, e1, e2) || a.is_gt(lit, e1, e2) || + m.is_eq(lit, e1, e2))) { + if (a.is_numeral(e1)) std::swap(e1, e2); + } + else { + e1 = lit; + } + + expr_ref val(m); + if (!a.is_add(e1)) { + if (match_mul(e1, var, val, a)) return val; + } + else { + for (auto *arg : *to_app(e1)) { + if (match_mul(arg, var, val, a)) return val; + } + } + return expr_ref(m); + } + // convert assign-bounds lemma to a farkas lemma by adding missing coeff // assume that missing coeff is for premise at position 0 static proof_ref mk_fk_from_ab(ast_manager &m, @@ -108,9 +285,44 @@ namespace spacer { unsigned num_params, parameter const *params) { SASSERT(num_params == parents.size() + 1 /* one param is missing */); + + // compute missing coefficient + linear_combinator lcb(m); + for (unsigned i = 1, sz = parents.size(); i < sz; ++i) { + app *p = to_app(m.get_fact(parents.get(i))); + rational const &r = params[i+1].get_rational(); + lcb.add_lit(p, r); + } + + TRACE("spacer.fkab", + tout << "lit0 is: " << mk_pp(m.get_fact(parents.get(0)), m) << "\n" + << "LCB is: " << lcb() << "\n";); + + expr_ref var(m), val1(m), val2(m); + val1 = get_coeff(m.get_fact(parents.get(0)), var); + val2 = get_coeff(lcb(), var); + TRACE("spacer.fkab", + tout << "var: " << var + << " val1: " << val1 << " val2: " << val2 << "\n";); + + rational rat1, rat2, coeff0; + arith_util a(m); + if (a.is_numeral(val1, rat1) && a.is_numeral(val2, rat2)) { + coeff0 = abs(rat2/rat1); + coeff0 = coeff0 / lcb.lc(); + TRACE("spacer.fkab", tout << "coeff0: " << coeff0 << "\n";); + } + else { + IF_VERBOSE(1, verbose_stream() + << "\n\n\nFAILED TO FIND COEFFICIENT\n\n\n";); + // failed to find a coefficient + return proof_ref(m); + } + + buffer v; v.push_back(parameter(symbol("farkas"))); - v.push_back(parameter(rational(1))); + v.push_back(parameter(coeff0)); for (unsigned i = 2; i < num_params; ++i) v.push_back(params[i]); @@ -124,11 +336,11 @@ namespace spacer { v.size(), v.c_ptr()); SASSERT(is_arith_lemma(m, pf)); + DEBUG_CODE( proof_checker pc(m); expr_ref_vector side(m); - SASSERT(pc.check(pf, side)); - ); + ENSURE(pc.check(pf, side));); return pf; } @@ -179,7 +391,9 @@ namespace spacer { d->get_num_parameters(), d->get_parameters()); } - else { + + // fall back to th-lemma + if (!th_lemma) { th_lemma = mk_th_lemma(m, hyps, d->get_num_parameters(), d->get_parameters());