/*++
Copyright (c) 2020 Microsoft Corporation

Module Name:

    ba_internalize.cpp

Abstract:

    Internalize methods for Boolean algebra operators.

Author:

    Nikolaj Bjorner (nbjorner) 2020-08-25

--*/

#include "sat/smt/pb_solver.h"
#include "ast/pb_decl_plugin.h"
#include "sat/smt/euf_solver.h"

namespace pb {

    namespace {

        // "at most k of n literals" is encoded as "at least (n - k) of the
        // negated literals". When k >= n the constraint holds trivially, so the
        // bound is clamped to 0 instead of letting the unsigned subtraction wrap
        // around to a huge (and therefore unsatisfiable) bound.
        unsigned complement_card_bound(unsigned n, unsigned k) {
            return k >= n ? 0 : n - k;
        }

        // Bound of the negation of "at least k of n literals", i.e.
        // "at least (n + 1 - k) of the negated literals". When k > n the original
        // constraint is false, its negation is trivially true, and the bound is
        // clamped to 0 rather than wrapping around.
        unsigned negated_card_bound(unsigned n, unsigned k) {
            return k > n ? 0 : n + 1 - k;
        }

        // Negates every literal of  sum w_i*l_i >= k  in place and returns the
        // bound of the negated constraint  sum w_i*~l_i >= (sum w_i) + 1 - k.
        // Computed with rationals so it cannot wrap; clamped at 0 because when
        // k exceeds the total weight the negated constraint is trivially true.
        template<typename WLits>
        rational negated_pb_bound(WLits& wlits, rational const& k) {
            rational bound = rational::one() - k;
            for (auto& wl : wlits) {
                wl.second.neg();
                bound += rational(wl.first);
            }
            return bound.is_neg() ? rational::zero() : bound;
        }
    }

    // Internalizes e as a non-root, unsigned term with the given redundancy flag.
    void solver::internalize(expr* e, bool redundant) {
        internalize(e, false, false, redundant);
    }

    // Internalizes a pseudo-Boolean term. m_is_redundant is set for the duration
    // of the call. For non-root terms that yield a literal, the literal is
    // attached to e in the EUF context (when one is present).
    literal solver::internalize(expr* e, bool sign, bool root, bool redundant) {
        flet<bool> _redundant(m_is_redundant, redundant);
        if (m_pb.is_pb(e)) {
            sat::literal lit = internalize_pb(e, sign, root);
            if (m_ctx && !root && lit != sat::null_literal)
                m_ctx->attach_lit(lit, e);
            return lit;
        }
        UNREACHABLE();
        return sat::null_literal;
    }

    // Dispatches on the pseudo-Boolean operator. Constraints whose coefficients
    // are all 1 are handled by the cardinality converters.
    literal solver::internalize_pb(expr* e, bool sign, bool root) {
        SASSERT(m_pb.is_pb(e));
        app* t = to_app(e);
        rational k = m_pb.get_k(t);
        switch (t->get_decl_kind()) {
        case OP_AT_MOST_K:
            return convert_at_most_k(t, k, root, sign);
        case OP_AT_LEAST_K:
            return convert_at_least_k(t, k, root, sign);
        case OP_PB_LE:
            if (m_pb.has_unit_coefficients(t))
                return convert_at_most_k(t, k, root, sign);
            else
                return convert_pb_le(t, root, sign);
        case OP_PB_GE:
            if (m_pb.has_unit_coefficients(t))
                return convert_at_least_k(t, k, root, sign);
            else
                return convert_pb_ge(t, root, sign);
        case OP_PB_EQ:
            if (m_pb.has_unit_coefficients(t))
                return convert_eq_k(t, k, root, sign);
            else
                return convert_pb_eq(t, root, sign);
        default:
            UNREACHABLE();
        }
        return sat::null_literal;
    }

    // Throws default_exception if c does not fit in an unsigned.
    void solver::check_unsigned(rational const& c) {
        if (!c.is_unsigned()) {
            throw default_exception("unsigned coefficient expected");
        }
    }

    // Pairs lits[i] with the i-th coefficient of t; every coefficient must fit in an unsigned.
    void solver::convert_to_wlits(app* t, sat::literal_vector const& lits, svector<wliteral>& wlits) {
        for (unsigned i = 0; i < lits.size(); ++i) {
            rational c = m_pb.get_coeff(t, i);
            check_unsigned(c);
            wlits.push_back(std::make_pair(c.get_unsigned(), lits[i]));
        }
    }

    // Internalizes each argument of t to a SAT literal and marks its variable as external.
    void solver::convert_pb_args(app* t, literal_vector& lits) {
        for (expr* arg : *t) {
            lits.push_back(si.internalize(arg, m_is_redundant));
            s().set_external(lits.back().var());
        }
    }

    // Internalizes the arguments of t together with their coefficients.
    void solver::convert_pb_args(app* t, svector<wliteral>& wlits) {
        sat::literal_vector lits;
        convert_pb_args(t, lits);
        convert_to_wlits(t, lits, wlits);
    }

    // sum w_i*l_i <= k  is encoded as  sum w_i*~l_i >= (sum w_i) - k.
    // At base level (root, no user scopes) the constraint is asserted directly
    // (negated again when sign is set); otherwise it is reified by a fresh variable.
    literal solver::convert_pb_le(app* t, bool root, bool sign) {
        rational k = m_pb.get_k(t);
        k.neg();
        svector<wliteral> wlits;
        convert_pb_args(t, wlits);
        for (wliteral& wl : wlits) {
            wl.second.neg();
            k += rational(wl.first);
        }
        check_unsigned(k);
        if (root && s().num_user_scopes() == 0) {
            rational bound = k;
            if (sign)
                bound = negated_pb_bound(wlits, k);
            check_unsigned(bound);
            add_pb_ge(sat::null_bool_var, wlits, bound.get_unsigned());
            return sat::null_literal;
        }
        else {
            bool_var v = s().add_var(true);
            literal lit(v, sign);
            add_pb_ge(v, wlits, k.get_unsigned());
            TRACE("ba", tout << "root: " << root << " lit: " << lit << "\n";);
            return lit;
        }
    }

    // sum w_i*l_i >= k. At base level the constraint (or its negation when sign
    // is set) is asserted directly; otherwise it is reified by a fresh variable.
    literal solver::convert_pb_ge(app* t, bool root, bool sign) {
        rational k = m_pb.get_k(t);
        check_unsigned(k);
        svector<wliteral> wlits;
        convert_pb_args(t, wlits);
        if (root && s().num_user_scopes() == 0) {
            rational bound = k;
            if (sign)
                bound = negated_pb_bound(wlits, k);
            check_unsigned(bound);
            add_pb_ge(sat::null_bool_var, wlits, bound.get_unsigned());
            return sat::null_literal;
        }
        else {
            sat::bool_var v = s().add_var(true);
            sat::literal lit(v, sign);
            add_pb_ge(v, wlits, k.get_unsigned());
            TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";);
            return lit;
        }
    }

    // sum w_i*l_i = k  is split into  sum w_i*l_i >= k  and  sum w_i*~l_i >= (sum w_i) - k.
    // Unless asserted positively at base level, both halves are reified (v1, v2)
    // and a fresh literal l is defined by  l <=> (l1 & l2).
    literal solver::convert_pb_eq(app* t, bool root, bool sign) {
        rational k = m_pb.get_k(t);
        // checked (not just asserted) so release builds never use a truncated bound
        check_unsigned(k);
        svector<wliteral> wlits;
        convert_pb_args(t, wlits);
        bool base_assert = (root && !sign && s().num_user_scopes() == 0);
        bool_var v1 = base_assert ? sat::null_bool_var : s().add_var(true);
        bool_var v2 = base_assert ? sat::null_bool_var : s().add_var(true);
        add_pb_ge(v1, wlits, k.get_unsigned());
        k.neg();
        for (wliteral& wl : wlits) {
            wl.second.neg();
            k += rational(wl.first);
        }
        check_unsigned(k);
        add_pb_ge(v2, wlits, k.get_unsigned());
        if (base_assert) {
            return sat::null_literal;
        }
        else {
            literal l1(v1, false), l2(v2, false);
            bool_var v = s().add_var(false);
            literal l(v, false);
            s().mk_clause(~l, l1);
            s().mk_clause(~l, l2);
            s().mk_clause(~l1, ~l2, l);
            si.cache(t, l);
            if (sign) l.neg();
            return l;
        }
    }

    // "at least k of the arguments". At base level the constraint (or its
    // negation) is asserted directly; otherwise a fresh variable reifies it and
    // the unsigned literal is cached for t.
    literal solver::convert_at_least_k(app* t, rational const& k, bool root, bool sign) {
        SASSERT(k.is_unsigned());
        literal_vector lits;
        convert_pb_args(t, lits);
        unsigned k2 = k.get_unsigned();
        if (root && s().num_user_scopes() == 0) {
            if (sign) {
                for (literal& l : lits) l.neg();
                k2 = negated_card_bound(lits.size(), k2);
            }
            add_at_least(sat::null_bool_var, lits, k2);
            return sat::null_literal;
        }
        else {
            bool_var v = s().add_var(true);
            literal lit(v, false);
            add_at_least(v, lits, k.get_unsigned());
            si.cache(t, lit);
            if (sign) lit.neg();
            TRACE("ba", tout << "root: " << root << " lit: " << lit << "\n";);
            return lit;
        }
    }

    // "at most k of the arguments" is encoded as "at least (n - k) of the
    // negated arguments" (clamped at 0 when k >= n).
    literal solver::convert_at_most_k(app* t, rational const& k, bool root, bool sign) {
        SASSERT(k.is_unsigned());
        literal_vector lits;
        convert_pb_args(t, lits);
        for (literal& l : lits) {
            l.neg();
        }
        unsigned k2 = complement_card_bound(lits.size(), k.get_unsigned());
        if (root && s().num_user_scopes() == 0) {
            if (sign) {
                for (literal& l : lits) l.neg();
                k2 = negated_card_bound(lits.size(), k2);
            }
            add_at_least(sat::null_bool_var, lits, k2);
            return sat::null_literal;
        }
        else {
            bool_var v = s().add_var(true);
            literal lit(v, false);
            add_at_least(v, lits, k2);
            si.cache(t, lit);
            if (sign) lit.neg();
            return lit;
        }
    }

    // "exactly k of the arguments" is split into "at least k" and
    // "at least (n - k) of the negations". Unless asserted positively at root,
    // both halves are reified and a fresh literal l is defined by l <=> (l1 & l2).
    literal solver::convert_eq_k(app* t, rational const& k, bool root, bool sign) {
        SASSERT(k.is_unsigned());
        literal_vector lits;
        convert_pb_args(t, lits);
        bool_var v1 = (root && !sign) ? sat::null_bool_var : s().add_var(true);
        bool_var v2 = (root && !sign) ? sat::null_bool_var : s().add_var(true);
        add_at_least(v1, lits, k.get_unsigned());
        for (literal& l : lits) {
            l.neg();
        }
        add_at_least(v2, lits, complement_card_bound(lits.size(), k.get_unsigned()));
        if (!root || sign) {
            literal l1(v1, false), l2(v2, false);
            bool_var v = s().add_var(false);
            literal l(v, false);
            s().mk_clause(~l, l1);
            s().mk_clause(~l, l2);
            s().mk_clause(~l1, ~l2, l);
            si.cache(t, l);
            if (sign) l.neg();
            return l;
        }
        else {
            return sat::null_literal;
        }
    }

    // Rebuilds a cardinality constraint as an at-least-k term; if the
    // constraint is reified, the result is (lit == constraint).
    expr_ref solver::get_card(std::function<expr_ref(sat::literal)>& lit2expr, card const& c) {
        // expr_ref_vector keeps every returned term referenced until the
        // constraint is built, in case lit2expr hands back a freshly created
        // term that nothing else holds.
        expr_ref_vector lits(m);
        for (sat::literal l : c) {
            lits.push_back(lit2expr(l));
        }
        expr_ref fml(m_pb.mk_at_least_k(c.size(), lits.data(), c.k()), m);

        if (c.lit() != sat::null_literal) {
            fml = m.mk_eq(lit2expr(c.lit()), fml);
        }
        return fml;
    }

    // Rebuilds a weighted constraint as a pb >= term; if the constraint is
    // reified, the result is (lit == constraint).
    expr_ref solver::get_pb(std::function<expr_ref(sat::literal)>& lit2expr, pbc const& p) {
        // see get_card: hold references so the terms outlive lit2expr's result
        expr_ref_vector lits(m);
        vector<rational> coeffs;
        for (auto const& wl : p) {
            lits.push_back(lit2expr(wl.second));
            coeffs.push_back(rational(wl.first));
        }
        rational k(p.k());
        expr_ref fml(m_pb.mk_ge(p.size(), coeffs.data(), lits.data(), k), m);

        if (p.lit() != sat::null_literal) {
            fml = m.mk_eq(lit2expr(p.lit()), fml);
        }
        return fml;
    }

    // Appends one formula per stored constraint to fmls; always succeeds.
    bool solver::to_formulas(std::function<expr_ref(sat::literal)>& l2e, expr_ref_vector& fmls) {
        for (auto* c : constraints()) {
            switch (c->tag()) {
            case pb::tag_t::card_t:
                fmls.push_back(get_card(l2e, c->to_card()));
                break;
            case pb::tag_t::pb_t:
                fmls.push_back(get_pb(l2e, c->to_pb()));
                break;
            }
        }
        return true;
    }
}