diff --git a/src/ast/rewriter/pb_rewriter.cpp b/src/ast/rewriter/pb_rewriter.cpp index ce6e294c7..eb85f8ec9 100644 --- a/src/ast/rewriter/pb_rewriter.cpp +++ b/src/ast/rewriter/pb_rewriter.cpp @@ -20,6 +20,7 @@ Notes: #include "pb_rewriter.h" #include "pb_rewriter_def.h" #include "ast_pp.h" +#include "ast_util.h" #include "ast_smt_pp.h" @@ -245,21 +246,31 @@ br_status pb_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * cons case l_false: result = m.mk_false(); break; - default: + default: { + bool all_unit = true; + unsigned sz = vec.size(); m_args.reset(); m_coeffs.reset(); - for (unsigned i = 0; i < vec.size(); ++i) { + for (unsigned i = 0; i < sz; ++i) { m_args.push_back(vec[i].first); m_coeffs.push_back(vec[i].second); + all_unit &= m_coeffs.back().is_one(); } if (is_eq) { - result = m_util.mk_eq(vec.size(), m_coeffs.c_ptr(), m_args.c_ptr(), k); + result = m_util.mk_eq(sz, m_coeffs.c_ptr(), m_args.c_ptr(), k); + } + else if (all_unit && k.is_one()) { + result = mk_or(m, sz, m_args.c_ptr()); + } + else if (all_unit && k == rational(sz)) { + result = mk_and(m, sz, m_args.c_ptr()); } else { - result = m_util.mk_ge(vec.size(), m_coeffs.c_ptr(), m_args.c_ptr(), k); + result = m_util.mk_ge(sz, m_coeffs.c_ptr(), m_args.c_ptr(), k); } break; } + } TRACE("pb", expr_ref tmp(m); tmp = m.mk_app(f, num_args, args); diff --git a/src/tactic/arith/lia2card_tactic.cpp b/src/tactic/arith/lia2card_tactic.cpp index 7d9efd44b..0d64353df 100644 --- a/src/tactic/arith/lia2card_tactic.cpp +++ b/src/tactic/arith/lia2card_tactic.cpp @@ -20,24 +20,90 @@ Notes: #include"cooperate.h" #include"bound_manager.h" #include"ast_pp.h" -#include"expr_safe_replace.h" // NB: should use proof-producing expr_substitute in polished version. #include"pb_decl_plugin.h" #include"arith_decl_plugin.h" +#include"rewriter_def.h" +#include"ast_util.h" class lia2card_tactic : public tactic { + struct lia_rewriter_cfg : public default_rewriter_cfg { + ast_manager& m; + lia2card_tactic& t; + arith_util a; + expr_ref_vector args; + vector coeffs; + rational coeff; + + br_status mk_app_core(func_decl* f, unsigned sz, expr*const* es, expr_ref& result) { + args.reset(); + coeffs.reset(); + coeff.reset(); + if (is_decl_of(f, a.get_family_id(), OP_LE) && + t.get_pb_sum(es[0], rational::one(), args, coeffs, coeff) && + t.get_pb_sum(es[1], -rational::one(), args, coeffs, coeff)) { + result = t.mk_le(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), -coeff); + return BR_DONE; + } + if (is_decl_of(f, a.get_family_id(), OP_GE) && + t.get_pb_sum(es[1], rational::one(), args, coeffs, coeff) && + t.get_pb_sum(es[0], -rational::one(), args, coeffs, coeff)) { + result = t.mk_le(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), -coeff); + return BR_DONE; + } + if (is_decl_of(f, a.get_family_id(), OP_LT) && + t.get_pb_sum(es[1], rational::one(), args, coeffs, coeff) && + t.get_pb_sum(es[0], -rational::one(), args, coeffs, coeff)) { + result = m.mk_not(t.mk_le(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), -coeff)); + return BR_DONE; + } + if (is_decl_of(f, a.get_family_id(), OP_GT) && + t.get_pb_sum(es[0], rational::one(), args, coeffs, coeff) && + t.get_pb_sum(es[1], -rational::one(), args, coeffs, coeff)) { + result = m.mk_not(t.mk_le(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), -coeff)); + return BR_DONE; + } + if (m.is_eq(f) && + t.get_pb_sum(es[0], rational::one(), args, coeffs, coeff) && + t.get_pb_sum(es[1], -rational::one(), args, coeffs, coeff)) { + result = t.mk_eq(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), -coeff); + return BR_DONE; + } + return BR_FAILED; + } + + bool rewrite_patterns() const { return false; } + bool flat_assoc(func_decl * f) const { return false; } + br_status reduce_app(func_decl * f, unsigned num, expr * const * args, expr_ref & result, proof_ref & result_pr) { + result_pr = 0; + return mk_app_core(f, num, args, result); + } + lia_rewriter_cfg(lia2card_tactic& t):m(t.m), t(t), a(m), args(m) {} + }; + + class lia_rewriter : public rewriter_tpl { + lia_rewriter_cfg m_cfg; + public: + lia_rewriter(lia2card_tactic& t): + rewriter_tpl(t.m, false, m_cfg), + m_cfg(t) + {} + }; + public: typedef obj_hashtable expr_set; ast_manager & m; arith_util a; + lia_rewriter m_rw; params_ref m_params; pb_util m_pb; mutable ptr_vector* m_todo; expr_set* m_01s; bool m_compile_equality; - + lia2card_tactic(ast_manager & _m, params_ref const & p): m(_m), a(m), + m_rw(*this), m_pb(m), m_todo(alloc(ptr_vector)), m_01s(alloc(expr_set)), @@ -50,6 +116,7 @@ public: } void set_cancel(bool f) { + m_rw.set_cancel(f); } void updt_params(params_ref const & p) { @@ -84,18 +151,12 @@ public: TRACE("pb", tout << "add bound " << mk_pp(x, m) << "\n";); } } - - expr_safe_replace sub(m); - extract_pb_substitution(g, sub); - - expr_ref new_curr(m); - proof_ref new_pr(m); - - for (unsigned i = 0; i < g->size(); i++) { - expr * curr = g->form(i); - sub(curr, new_curr); - if (m.proofs_enabled()) { - new_pr = m.mk_rewrite(curr, new_curr); + for (unsigned i = 0; i < g->size(); i++) { + expr_ref new_curr(m); + proof_ref new_pr(m); + m_rw(g->form(i), new_curr, new_pr); + if (m.proofs_enabled() && !new_pr) { + new_pr = m.mk_rewrite(g->form(i), new_curr); new_pr = m.mk_modus_ponens(g->pr(i), new_pr); } g->update(i, new_curr, new_pr, g->dep(i)); @@ -109,33 +170,6 @@ public: // TBD: support proof conversion (or not..) } - void extract_pb_substitution(goal_ref const& g, expr_safe_replace& sub) { - ast_mark mark; - for (unsigned i = 0; i < g->size(); i++) { - extract_pb_substitution(mark, g->form(i), sub); - } - } - - void extract_pb_substitution(ast_mark& mark, expr* fml, expr_safe_replace& sub) { - expr_ref tmp(m); - m_todo->reset(); - m_todo->push_back(fml); - while (!m_todo->empty()) { - expr* e = m_todo->back(); - m_todo->pop_back(); - if (mark.is_marked(e) || !is_app(e)) { - continue; - } - mark.mark(e, true); - if (get_pb_relation(sub, e, tmp)) { - sub.insert(e, tmp); - continue; - } - app* ap = to_app(e); - m_todo->append(ap->get_num_args(), ap->get_args()); - } - } - bool is_01var(expr* x) const { return m_01s->contains(x); @@ -146,31 +180,6 @@ public: return expr_ref(r, m); } - bool get_pb_relation(expr_safe_replace& sub, expr* fml, expr_ref& result) { - expr* x, *y; - expr_ref_vector args(m); - vector coeffs; - rational coeff; - if ((a.is_le(fml, x, y) || a.is_ge(fml, y, x)) && - get_pb_sum(x, rational::one(), args, coeffs, coeff) && - get_pb_sum(y, -rational::one(), args, coeffs, coeff)) { - result = mk_le(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), -coeff); - return true; - } - else if ((a.is_lt(fml, y, x) || a.is_gt(fml, x, y)) && - get_pb_sum(x, rational::one(), args, coeffs, coeff) && - get_pb_sum(y, -rational::one(), args, coeffs, coeff)) { - result = m.mk_not(mk_le(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), -coeff)); - return true; - } - else if (m.is_eq(fml, x, y) && - get_pb_sum(x, rational::one(), args, coeffs, coeff) && - get_pb_sum(y, -rational::one(), args, coeffs, coeff)) { - result = mk_eq(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), -coeff); - return true; - } - return false; - } expr* mk_le(unsigned sz, rational const* weights, expr* const* args, rational const& w) { if (sz == 0) { @@ -208,42 +217,55 @@ public: } bool get_pb_sum(expr* x, rational const& mul, expr_ref_vector& args, vector& coeffs, rational& coeff) { + expr_ref_vector conds(m); + return get_sum(x, mul, conds, args, coeffs, coeff); + } + + bool get_sum(expr* x, rational const& mul, expr_ref_vector& conds, expr_ref_vector& args, vector& coeffs, rational& coeff) { expr *y, *z, *u; rational r, q; app* f = to_app(x); bool ok = true; if (a.is_add(x)) { for (unsigned i = 0; ok && i < f->get_num_args(); ++i) { - ok = get_pb_sum(f->get_arg(i), mul, args, coeffs, coeff); + ok = get_sum(f->get_arg(i), mul, conds, args, coeffs, coeff); } } else if (a.is_sub(x, y, z)) { - ok = get_pb_sum(y, mul, args, coeffs, coeff); - ok = ok && get_pb_sum(z, -mul, args, coeffs, coeff); + ok = get_sum(y, mul, conds, args, coeffs, coeff); + ok = ok && get_sum(z, -mul, conds, args, coeffs, coeff); } else if (a.is_uminus(x, y)) { - ok = get_pb_sum(y, -mul, args, coeffs, coeff); + ok = get_sum(y, -mul, conds, args, coeffs, coeff); } else if (a.is_mul(x, y, z) && is_numeral(y, r)) { - ok = get_pb_sum(z, r*mul, args, coeffs, coeff); + ok = get_sum(z, r*mul, conds, args, coeffs, coeff); } else if (a.is_mul(x, z, y) && is_numeral(y, r)) { - ok = get_pb_sum(z, r*mul, args, coeffs, coeff); + ok = get_sum(z, r*mul, conds, args, coeffs, coeff); } else if (a.is_to_real(x, y)) { - ok = get_pb_sum(y, mul, args, coeffs, coeff); + ok = get_sum(y, mul, conds, args, coeffs, coeff); } else if (m.is_ite(x, y, z, u) && is_numeral(z, r) && is_numeral(u, q)) { - insert_arg(r*mul, y, args, coeffs, coeff); + insert_arg(r*mul, add_conds(conds, y), args, coeffs, coeff); // q*(1-y) = -q*y + q coeff += q*mul; - insert_arg(-q*mul, y, args, coeffs, coeff); + insert_arg(-q*mul, add_conds(conds, y), args, coeffs, coeff); } + else if (m.is_ite(x, y, z, u)) { + conds.push_back(y); + ok = get_sum(z, mul, conds, args, coeffs, coeff); + conds.pop_back(); + conds.push_back(m.mk_not(y)); + ok &= get_sum(u, mul, conds, args, coeffs, coeff); + conds.pop_back(); + } else if (is_01var(x)) { - insert_arg(mul, mk_01(x), args, coeffs, coeff); + insert_arg(mul, add_conds(conds, mk_01(x)), args, coeffs, coeff); } else if (is_numeral(x, r)) { - coeff += mul*r; + insert_arg(mul*r, add_conds(conds, m.mk_true()), args, coeffs, coeff); } else { TRACE("pb", tout << "Can't handle " << mk_pp(x, m) << "\n";); @@ -252,6 +274,13 @@ public: return ok; } + expr_ref add_conds(expr_ref_vector const& es, expr* e) { + if (es.empty()) return expr_ref(e, m); + expr_ref result = expr_ref(m.mk_and(es.size(), es.c_ptr()), m); + result = m.mk_and(e, result); + return result; + } + bool is_numeral(expr* e, rational& r) { if (a.is_uminus(e, e) && is_numeral(e, r)) { r.neg(); @@ -265,7 +294,10 @@ public: void insert_arg(rational const& p, expr* x, expr_ref_vector& args, vector& coeffs, rational& coeff) { - if (p.is_neg()) { + if (m.is_true(x)) { + coeff += p; + } + else if (p.is_neg()) { // p*x = -p*(1-x) + p args.push_back(m.mk_not(x)); coeffs.push_back(-p);