diff --git a/src/ast/pb_decl_plugin.cpp b/src/ast/pb_decl_plugin.cpp index 21ca1e968..808f4e3dc 100644 --- a/src/ast/pb_decl_plugin.cpp +++ b/src/ast/pb_decl_plugin.cpp @@ -23,7 +23,8 @@ pb_decl_plugin::pb_decl_plugin(): m_at_most_sym("at-most"), m_at_least_sym("at-least"), m_pble_sym("pble"), - m_pbge_sym("pbge") + m_pbge_sym("pbge"), + m_pbeq_sym("pbeq") {} func_decl * pb_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters, @@ -41,6 +42,7 @@ func_decl * pb_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, p case OP_AT_MOST_K: sym = m_at_most_sym; break; case OP_PB_LE: sym = m_pble_sym; break; case OP_PB_GE: sym = m_pbge_sym; break; + case OP_PB_EQ: sym = m_pbeq_sym; break; default: break; } switch(k) { @@ -53,7 +55,8 @@ func_decl * pb_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, p return m.mk_func_decl(sym, arity, domain, m.mk_bool_sort(), info); } case OP_PB_GE: - case OP_PB_LE: { + case OP_PB_LE: + case OP_PB_EQ: { if (num_parameters != 1 + arity) { m.raise_exception("function expects arity+1 rational parameters"); } @@ -74,7 +77,7 @@ func_decl * pb_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, p } } else { - m.raise_exception("function 'pble' expects arity+1 integer parameters"); + m.raise_exception("functions 'pble/pbge/pbeq' expect arity+1 integer parameters"); } } func_decl_info info(m_family_id, k, num_parameters, params.c_ptr()); @@ -89,8 +92,10 @@ func_decl * pb_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, p void pb_decl_plugin::get_op_names(svector & op_names, symbol const & logic) { if (logic == symbol::null) { op_names.push_back(builtin_name(m_at_most_sym.bare_str(), OP_AT_MOST_K)); + op_names.push_back(builtin_name(m_at_least_sym.bare_str(), OP_AT_LEAST_K)); op_names.push_back(builtin_name(m_pble_sym.bare_str(), OP_PB_LE)); op_names.push_back(builtin_name(m_pbge_sym.bare_str(), OP_PB_GE)); + op_names.push_back(builtin_name(m_pbeq_sym.bare_str(), OP_PB_EQ)); } } @@ -112,6 +117,15 @@ app * pb_util::mk_ge(unsigned num_args, rational const * coeffs, expr * const * return m.mk_app(m_fid, OP_PB_GE, params.size(), params.c_ptr(), num_args, args, m.mk_bool_sort()); } +app * pb_util::mk_eq(unsigned num_args, rational const * coeffs, expr * const * args, rational const& k) { + vector params; + params.push_back(parameter(k)); + for (unsigned i = 0; i < num_args; ++i) { + params.push_back(parameter(coeffs[i])); + } + return m.mk_app(m_fid, OP_PB_EQ, params.size(), params.c_ptr(), num_args, args, m.mk_bool_sort()); +} + app * pb_util::mk_at_most_k(unsigned num_args, expr * const * args, unsigned k) { @@ -123,7 +137,7 @@ bool pb_util::is_at_most_k(func_decl *a) const { return is_decl_of(a, m_fid, OP_AT_MOST_K); } -bool pb_util::is_at_most_k(app *a, rational& k) const { +bool pb_util::is_at_most_k(expr *a, rational& k) const { if (is_at_most_k(a)) { k = get_k(a); return true; @@ -133,7 +147,6 @@ bool pb_util::is_at_most_k(app *a, rational& k) const { } } - app * pb_util::mk_at_least_k(unsigned num_args, expr * const * args, unsigned k) { parameter param(k); return m.mk_app(m_fid, OP_AT_LEAST_K, 1, ¶m, num_args, args, m.mk_bool_sort()); @@ -143,7 +156,7 @@ bool pb_util::is_at_least_k(func_decl *a) const { return is_decl_of(a, m_fid, OP_AT_LEAST_K); } -bool pb_util::is_at_least_k(app *a, rational& k) const { +bool pb_util::is_at_least_k(expr *a, rational& k) const { if (is_at_least_k(a)) { k = get_k(a); return true; @@ -159,7 +172,7 @@ rational pb_util::get_k(func_decl *a) const { return to_rational(p); } else { - SASSERT(is_le(a) || is_ge(a)); + SASSERT(is_le(a) || is_ge(a) || is_eq(a)); return to_rational(p); } } @@ -169,7 +182,7 @@ bool pb_util::is_le(func_decl *a) const { return is_decl_of(a, m_fid, OP_PB_LE); } -bool pb_util::is_le(app* a, rational& k) const { +bool pb_util::is_le(expr* a, rational& k) const { if (is_le(a)) { k = get_k(a); return true; @@ -183,7 +196,7 @@ bool pb_util::is_ge(func_decl *a) const { return is_decl_of(a, m_fid, OP_PB_GE); } -bool pb_util::is_ge(app* a, rational& k) const { +bool pb_util::is_ge(expr* a, rational& k) const { if (is_ge(a)) { k = get_k(a); return true; @@ -193,11 +206,26 @@ bool pb_util::is_ge(app* a, rational& k) const { } } + +bool pb_util::is_eq(func_decl *a) const { + return is_decl_of(a, m_fid, OP_PB_EQ); +} + +bool pb_util::is_eq(expr* a, rational& k) const { + if (is_eq(a)) { + k = get_k(a); + return true; + } + else { + return false; + } +} + rational pb_util::get_coeff(func_decl* a, unsigned index) const { if (is_at_most_k(a) || is_at_least_k(a)) { return rational::one(); } - SASSERT(is_le(a) || is_ge(a)); + SASSERT(is_le(a) || is_ge(a) || is_eq(a)); SASSERT(1 + index < a->get_num_parameters()); return to_rational(a->get_parameter(index + 1)); } diff --git a/src/ast/pb_decl_plugin.h b/src/ast/pb_decl_plugin.h index 9ee74dfaf..54da51775 100644 --- a/src/ast/pb_decl_plugin.h +++ b/src/ast/pb_decl_plugin.h @@ -34,6 +34,7 @@ enum pb_op_kind { OP_AT_LEAST_K, // at least K Booleans are true. OP_PB_LE, // pseudo-Boolean <= (generalizes at_most_k) OP_PB_GE, // pseudo-Boolean >= + OP_PB_EQ, // equality LAST_PB_OP }; @@ -43,10 +44,12 @@ class pb_decl_plugin : public decl_plugin { symbol m_at_least_sym; symbol m_pble_sym; symbol m_pbge_sym; + symbol m_pbeq_sym; func_decl * mk_at_most(unsigned arity, unsigned k); func_decl * mk_at_least(unsigned arity, unsigned k); func_decl * mk_le(unsigned arity, rational const* coeffs, int k); func_decl * mk_ge(unsigned arity, rational const* coeffs, int k); + func_decl * mk_eq(unsigned arity, rational const* coeffs, int k); public: pb_decl_plugin(); virtual ~pb_decl_plugin() {} @@ -82,22 +85,27 @@ public: app * mk_at_least_k(unsigned num_args, expr * const * args, unsigned k); app * mk_le(unsigned num_args, rational const * coeffs, expr * const * args, rational const& k); app * mk_ge(unsigned num_args, rational const * coeffs, expr * const * args, rational const& k); + app * mk_eq(unsigned num_args, rational const * coeffs, expr * const * args, rational const& k); bool is_at_most_k(func_decl *a) const; bool is_at_most_k(expr *a) const { return is_app(a) && is_at_most_k(to_app(a)->get_decl()); } - bool is_at_most_k(app *a, rational& k) const; + bool is_at_most_k(expr *a, rational& k) const; bool is_at_least_k(func_decl *a) const; bool is_at_least_k(expr *a) const { return is_app(a) && is_at_least_k(to_app(a)->get_decl()); } - bool is_at_least_k(app *a, rational& k) const; + bool is_at_least_k(expr *a, rational& k) const; rational get_k(func_decl *a) const; rational get_k(expr *a) const { return get_k(to_app(a)->get_decl()); } bool is_le(func_decl *a) const; bool is_le(expr *a) const { return is_app(a) && is_le(to_app(a)->get_decl()); } - bool is_le(app* a, rational& k) const; + bool is_le(expr* a, rational& k) const; bool is_ge(func_decl* a) const; bool is_ge(expr* a) const { return is_app(a) && is_ge(to_app(a)->get_decl()); } - bool is_ge(app* a, rational& k) const; + bool is_ge(expr* a, rational& k) const; rational get_coeff(expr* a, unsigned index) const { return get_coeff(to_app(a)->get_decl(), index); } rational get_coeff(func_decl* a, unsigned index) const; + + bool is_eq(func_decl* f) const; + bool is_eq(expr* e) const { return is_app(e) && is_eq(to_app(e)->get_decl()); } + bool is_eq(expr* e, rational& k) const; private: rational to_rational(parameter const& p) const; }; diff --git a/src/ast/rewriter/pb_rewriter.cpp b/src/ast/rewriter/pb_rewriter.cpp index 5b1b37120..76dcdd958 100644 --- a/src/ast/rewriter/pb_rewriter.cpp +++ b/src/ast/rewriter/pb_rewriter.cpp @@ -102,19 +102,21 @@ br_status pb_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * cons break; case OP_AT_LEAST_K: case OP_PB_GE: + case OP_PB_EQ: break; default: UNREACHABLE(); return BR_FAILED; } + bool is_eq = f->get_decl_kind() == OP_PB_EQ; pb_ast_rewriter_util pbu(m); pb_rewriter_util util(pbu); - util.unique(vec, k); - lbool is_sat = util.normalize(vec, k); - util.prune(vec, k); + util.unique(vec, k, is_eq); + lbool is_sat = util.normalize(vec, k, is_eq); + util.prune(vec, k, is_eq); switch (is_sat) { case l_true: result = m.mk_true(); @@ -129,7 +131,12 @@ br_status pb_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * cons m_args.push_back(vec[i].first); m_coeffs.push_back(vec[i].second); } - result = m_util.mk_ge(vec.size(), m_coeffs.c_ptr(), m_args.c_ptr(), k); + if (is_eq) { + result = m_util.mk_eq(vec.size(), m_coeffs.c_ptr(), m_args.c_ptr(), k); + } + else { + result = m_util.mk_ge(vec.size(), m_coeffs.c_ptr(), m_args.c_ptr(), k); + } break; } TRACE("pb", diff --git a/src/ast/rewriter/pb_rewriter.h b/src/ast/rewriter/pb_rewriter.h index 8ea41b668..0e0986d1a 100644 --- a/src/ast/rewriter/pb_rewriter.h +++ b/src/ast/rewriter/pb_rewriter.h @@ -28,12 +28,12 @@ Notes: template class pb_rewriter_util { PBU& m_util; - void display(std::ostream& out, typename PBU::args_t& args, typename PBU::numeral& k); + void display(std::ostream& out, typename PBU::args_t& args, typename PBU::numeral& k, bool is_eq); public: pb_rewriter_util(PBU& u) : m_util(u) {} - void unique(typename PBU::args_t& args, typename PBU::numeral& k); - lbool normalize(typename PBU::args_t& args, typename PBU::numeral& k); - void prune(typename PBU::args_t& args, typename PBU::numeral& k); + void unique(typename PBU::args_t& args, typename PBU::numeral& k, bool is_eq); + lbool normalize(typename PBU::args_t& args, typename PBU::numeral& k, bool is_eq); + void prune(typename PBU::args_t& args, typename PBU::numeral& k, bool is_eq); }; /** diff --git a/src/ast/rewriter/pb_rewriter_def.h b/src/ast/rewriter/pb_rewriter_def.h index c6e21b6ce..3c60babce 100644 --- a/src/ast/rewriter/pb_rewriter_def.h +++ b/src/ast/rewriter/pb_rewriter_def.h @@ -23,20 +23,20 @@ Notes: template -void pb_rewriter_util::display(std::ostream& out, typename PBU::args_t& args, typename PBU::numeral& k) { +void pb_rewriter_util::display(std::ostream& out, typename PBU::args_t& args, typename PBU::numeral& k, bool is_eq) { for (unsigned i = 0; i < args.size(); ++i) { out << args[i].second << " * "; m_util.display(out, args[i].first); out << " "; if (i+1 < args.size()) out << "+ "; } - out << " >= " << k << "\n"; + out << (is_eq?" = ":" >= ") << k << "\n"; } template -void pb_rewriter_util::unique(typename PBU::args_t& args, typename PBU::numeral& k) { +void pb_rewriter_util::unique(typename PBU::args_t& args, typename PBU::numeral& k, bool is_eq) { - TRACE("pb_verbose", display(tout << "pre-unique:", args, k);); + TRACE("pb_verbose", display(tout << "pre-unique:", args, k, is_eq);); for (unsigned i = 0; i < args.size(); ++i) { if (m_util.is_negated(args[i].first)) { args[i].first = m_util.negate(args[i].first); @@ -85,19 +85,19 @@ void pb_rewriter_util::unique(typename PBU::args_t& args, typename PBU::num } } args.resize(i); - TRACE("pb_verbose", display(tout << "post-unique:", args, k);); + TRACE("pb_verbose", display(tout << "post-unique:", args, k, is_eq);); } template -lbool pb_rewriter_util::normalize(typename PBU::args_t& args, typename PBU::numeral& k) { - TRACE("pb_verbose", display(tout << "pre-normalize:", args, k);); +lbool pb_rewriter_util::normalize(typename PBU::args_t& args, typename PBU::numeral& k, bool is_eq) { + TRACE("pb_verbose", display(tout << "pre-normalize:", args, k, is_eq);); DEBUG_CODE( bool found = false; for (unsigned i = 0; !found && i < args.size(); ++i) { found = args[i].second.is_zero(); } - if (found) display(verbose_stream(), args, k); + if (found) display(verbose_stream(), args, k, is_eq); SASSERT(!found);); // @@ -121,17 +121,29 @@ lbool pb_rewriter_util::normalize(typename PBU::args_t& args, typename PBU: sum += args[i].second; } // detect tautologies: - if (k <= PBU::numeral::zero()) { + if (!is_eq && k <= PBU::numeral::zero()) { args.reset(); k = PBU::numeral::zero(); return l_true; } + if (is_eq && k.is_zero() && args.empty()) { + return l_true; + } + // detect infeasible constraints: if (sum < k) { args.reset(); k = PBU::numeral::one(); return l_false; } + + if (is_eq && k == sum) { + for (unsigned i = 0; i < args.size(); ++i) { + args[i].second = PBU::numeral::one(); + } + k = PBU::numeral::one(); + return l_undef; + } bool all_int = true; for (unsigned i = 0; all_int && i < args.size(); ++i) { @@ -150,6 +162,11 @@ lbool pb_rewriter_util::normalize(typename PBU::args_t& args, typename PBU: args[i].second *= d; } } + + if (is_eq) { + TRACE("pb_verbose", display(tout << "post-normalize:", args, k, is_eq);); + return l_undef; + } // Ensure the largest coefficient is not larger than k: sum = PBU::numeral::zero(); @@ -193,7 +210,7 @@ lbool pb_rewriter_util::normalize(typename PBU::args_t& args, typename PBU: } else if (g > PBU::numeral::one()) { IF_VERBOSE(3, verbose_stream() << "cut " << g << "\n"; - display(verbose_stream(), args, k); + display(verbose_stream(), args, k, is_eq); ); // @@ -241,7 +258,7 @@ lbool pb_rewriter_util::normalize(typename PBU::args_t& args, typename PBU: PBU::numeral n1 = floor(n0); PBU::numeral n2 = ceil(k/min) - PBU::numeral::one(); if (n1 == n2 && !n0.is_int()) { - IF_VERBOSE(3, display(verbose_stream() << "set cardinality\n", args, k);); + IF_VERBOSE(3, display(verbose_stream() << "set cardinality\n", args, k, is_eq);); for (unsigned i = 0; i < args.size(); ++i) { args[i].second = PBU::numeral::one(); @@ -249,13 +266,15 @@ lbool pb_rewriter_util::normalize(typename PBU::args_t& args, typename PBU: k = n1 + PBU::numeral::one(); } } + TRACE("pb_verbose", display(tout << "post-normalize:", args, k, is_eq);); return l_undef; } - template -void pb_rewriter_util::prune(typename PBU::args_t& args, typename PBU::numeral& k) { - +void pb_rewriter_util::prune(typename PBU::args_t& args, typename PBU::numeral& k, bool is_eq) { + if (is_eq) { + return; + } PBU::numeral nlt(0); unsigned occ = 0; for (unsigned i = 0; nlt < k && i < args.size(); ++i) { @@ -264,8 +283,7 @@ void pb_rewriter_util::prune(typename PBU::args_t& args, typename PBU::nume ++occ; } } - if (0 < occ && nlt < k) { - + if (0 < occ && nlt < k) { for (unsigned i = 0; i < args.size(); ++i) { if (args[i].second < k) { args[i] = args.back(); @@ -273,8 +291,8 @@ void pb_rewriter_util::prune(typename PBU::args_t& args, typename PBU::nume --i; } } - unique(args, k); - normalize(args, k); + unique(args, k, is_eq); + normalize(args, k, is_eq); } } diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index 6c9df93b0..b17f1ba29 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -27,6 +27,7 @@ Notes: #include "tactic.h" #include "lia2card_tactic.h" #include "elim01_tactic.h" +#include "solve_eqs_tactic.h" #include "simplify_tactic.h" #include "tactical.h" #include "model_smt2_pp.h" @@ -299,15 +300,15 @@ namespace opt { g->assert_expr(fmls[i].get()); } tactic_ref tac0 = mk_simplify_tactic(m); - tactic_ref tac1 = mk_elim01_tactic(m); - tactic_ref tac2 = mk_lia2card_tactic(m); + tactic_ref tac2 = mk_elim01_tactic(m); + tactic_ref tac3 = mk_lia2card_tactic(m); tactic_ref tac; opt_params optp(m_params); if (optp.elim_01()) { - tac = and_then(tac0.get(), tac1.get(), tac2.get()); + tac = and_then(tac0.get(), tac2.get(), tac3.get()); } else { - tac = tac0; + tac = tac0.get(); } proof_converter_ref pc; expr_dependency_ref core(m); diff --git a/src/opt/weighted_maxsat.cpp b/src/opt/weighted_maxsat.cpp index 8bf5e8430..12a52208f 100644 --- a/src/opt/weighted_maxsat.cpp +++ b/src/opt/weighted_maxsat.cpp @@ -38,20 +38,26 @@ namespace smt { stats() { reset(); } }; - opt::opt_solver& s; - app_ref_vector m_vars; // Auxiliary variables per soft clause - expr_ref_vector m_fmls; // Formulas per soft clause + opt::opt_solver& s; + mutable unsynch_mpz_manager m_mpz; + app_ref_vector m_vars; // Auxiliary variables per soft clause + expr_ref_vector m_fmls; // Formulas per soft clause app_ref m_min_cost_atom; // atom tracking modified lower bound app_ref_vector m_min_cost_atoms; - bool_var m_min_cost_bv; // max cost Boolean variable - vector m_weights; // weights of theory variables. - svector m_costs; // set of asserted theory variables - svector m_cost_save; // set of asserted theory variables - rational m_cost; // current sum of asserted costs - rational m_min_cost; // current maximal cost assignment. - u_map m_bool2var; // bool_var -> theory_var - svector m_var2bool; // theory_var -> bool_var + bool_var m_min_cost_bv; // max cost Boolean variable + vector m_rweights; // weights of theory variables. + scoped_mpz_vector m_zweights; + svector m_costs; // set of asserted theory variables + svector m_cost_save; // set of asserted theory variables + rational m_rcost; // current sum of asserted costs + rational m_rmin_cost; // current maximal cost assignment. + scoped_mpz m_zcost; // current sum of asserted costs + scoped_mpz m_zmin_cost; // current maximal cost assignment. + u_map m_bool2var; // bool_var -> theory_var + svector m_var2bool; // theory_var -> bool_var bool m_propagate; + bool m_normalize; + rational m_den; // lcm of denominators for rational weights. svector m_assigned; stats m_stats; @@ -63,9 +69,15 @@ namespace smt { m_fmls(m), m_min_cost_atom(m), m_min_cost_atoms(m), - m_propagate(false) + m_zweights(m_mpz), + m_zcost(m_mpz), + m_zmin_cost(m_mpz), + m_propagate(false), + m_normalize(false) {} + virtual ~theory_weighted_maxsat() { } + /** \brief return the complement of variables that are currently assigned. */ @@ -100,14 +112,7 @@ namespace smt { } virtual void init_search_eh() { - context & ctx = get_context(); - ast_manager& m = get_manager(); - bool initialized = !m_var2bool.empty(); m_propagate = true; - - for (unsigned i = 0; i < m_min_cost_atoms.size(); ++i) { - app* var = m_min_cost_atoms[i].get(); - } } void assert_weighted(expr* fml, rational const& w) { @@ -118,12 +123,12 @@ namespace smt { s.mc().insert(var->get_decl()); wfml = m.mk_or(var, fml); ctx.assert_expr(wfml); - m_weights.push_back(w); + m_rweights.push_back(w); m_vars.push_back(var); m_fmls.push_back(fml); m_assigned.push_back(false); - m_min_cost += w; - + m_rmin_cost += w; + m_normalize = true; register_var(var, true); } @@ -152,15 +157,20 @@ namespace smt { return bv; } - rational const& get_min_cost() const { - return m_min_cost; + rational const& get_min_cost() { + unsynch_mpq_manager mgr; + scoped_mpq q(mgr); + mgr.set(q, m_zmin_cost, m_den.to_mpq().numerator()); + m_rmin_cost = rational(q); + return m_rmin_cost; } expr* set_min_cost(rational const& c) { + m_normalize = true; ast_manager& m = get_manager(); std::ostringstream strm; strm << "cost <= " << c; - m_min_cost = c; + m_rmin_cost = c; m_min_cost_atom = m.mk_fresh_const(strm.str().c_str(), m.mk_bool_sort()); m_min_cost_atoms.push_back(m_min_cost_atom); s.mc().insert(m_min_cost_atom->get_decl()); @@ -177,17 +187,18 @@ namespace smt { virtual void assign_eh(bool_var v, bool is_true) { TRACE("opt", tout << "Assign " << mk_pp(m_vars[m_bool2var[v]].get(), get_manager()) << " " << is_true << "\n";); if (is_true) { + if (m_normalize) normalize(); context& ctx = get_context(); theory_var tv = m_bool2var[v]; if (m_assigned[tv]) return; - rational const& w = m_weights[tv]; - ctx.push_trail(value_trail(m_cost)); + mpz const& w = m_zweights[tv]; + ctx.push_trail(value_trail(m_zcost)); ctx.push_trail(push_back_vector >(m_costs)); ctx.push_trail(value_trail(m_assigned[tv])); - m_cost += w; + m_zcost += w; m_costs.push_back(tv); m_assigned[tv] = true; - if (m_cost > m_min_cost) { + if (m_zcost > m_zmin_cost) { block(); } } @@ -213,10 +224,13 @@ namespace smt { theory::reset_eh(); m_vars.reset(); m_fmls.reset(); - m_weights.reset(); + m_rweights.reset(); m_costs.reset(); - m_min_cost.reset(); - m_cost.reset(); + m_rmin_cost.reset(); + m_rcost.reset(); + m_zweights.reset(); + m_zcost.reset(); + m_zmin_cost.reset(); m_cost_save.reset(); m_bool2var.reset(); m_var2bool.reset(); @@ -254,7 +268,7 @@ namespace smt { } bool is_optimal() const { - return m_cost < m_min_cost; + return m_mpz.lt(m_zcost, m_zmin_cost); } expr_ref mk_block() { @@ -263,17 +277,22 @@ namespace smt { compare_cost compare_cost(*this); svector costs(m_costs); std::sort(costs.begin(), costs.end(), compare_cost); - rational weight(0); - for (unsigned i = 0; i < costs.size() && weight < m_min_cost; ++i) { - weight += m_weights[costs[i]]; + scoped_mpz weight(m_mpz); + m_mpz.reset(weight); + for (unsigned i = 0; i < costs.size() && m_mpz.lt(weight, m_zmin_cost); ++i) { + weight += m_zweights[costs[i]]; disj.push_back(m.mk_not(m_vars[costs[i]].get())); } if (m_min_cost_atom) { disj.push_back(m.mk_not(m_min_cost_atom)); } if (is_optimal()) { - IF_VERBOSE(1, verbose_stream() << "(wmaxsat with upper bound: " << weight << ")\n";); - m_min_cost = weight; + unsynch_mpq_manager mgr; + scoped_mpq q(mgr); + mgr.set(q, m_zmin_cost, m_den.to_mpq().numerator()); + rational rw = rational(q); + IF_VERBOSE(1, verbose_stream() << "(wmaxsat with upper bound: " << rw << ")\n";); + m_zmin_cost = weight; m_cost_save.reset(); m_cost_save.append(m_costs); } @@ -307,9 +326,11 @@ namespace smt { compare_cost compare_cost(*this); svector costs(m_costs); std::sort(costs.begin(), costs.end(), compare_cost); - rational weight(0); - for (unsigned i = 0; i < costs.size() && weight < m_min_cost; ++i) { - weight += m_weights[costs[i]]; + + scoped_mpz weight(m_mpz); + m_mpz.reset(weight); + for (unsigned i = 0; i < costs.size() && weight < m_zmin_cost; ++i) { + weight += m_zweights[costs[i]]; lits.push_back(~literal(m_var2bool[costs[i]])); } if (m_min_cost_atom) { @@ -328,13 +349,33 @@ namespace smt { ctx.mk_th_axiom(get_id(), lits.size(), lits.c_ptr()); } + + void normalize() { + m_den = rational::one(); + for (unsigned i = 0; i < m_rweights.size(); ++i) { + m_den = lcm(m_den, denominator(m_rweights[i])); + } + m_den = lcm(m_den, denominator(m_rmin_cost)); + m_zweights.reset(); + for (unsigned i = 0; i < m_rweights.size(); ++i) { + rational r = m_rweights[i]*m_den; + SASSERT(r.is_int()); + mpq const& q = r.to_mpq(); + m_zweights.push_back(q.numerator()); + } + rational r = m_rcost* m_den; + m_zcost = r.to_mpq().numerator(); + r = m_rmin_cost * m_den; + m_zmin_cost = r.to_mpq().numerator(); + m_normalize = false; + } class compare_cost { theory_weighted_maxsat& m_th; public: compare_cost(theory_weighted_maxsat& t):m_th(t) {} bool operator() (theory_var v, theory_var w) const { - return m_th.m_weights[v] > m_th.m_weights[w]; + return m_th.m_mpz.gt(m_th.m_zweights[v], m_th.m_zweights[w]); } }; }; diff --git a/src/smt/theory_pb.cpp b/src/smt/theory_pb.cpp index 57d81a56b..ebb5aee10 100644 --- a/src/smt/theory_pb.cpp +++ b/src/smt/theory_pb.cpp @@ -28,686 +28,6 @@ Notes: namespace smt { - // parametric sorting network - // Described in Abio et.al. CP 2013. - class psort_nw { - class vc { - unsigned v; // number of vertices - unsigned c; // number of clauses - static const unsigned lambda = 5; - public: - vc(unsigned v, unsigned c):v(v), c(c) {} - - bool operator<(vc const& other) const { - return to_int() < other.to_int(); - } - vc operator+(vc const& other) const { - return vc(v + other.v, c + other.c); - } - unsigned to_int() const { - return lambda*v + c; - } - vc operator*(unsigned n) const { - return vc(n*v, n*c); - } - }; - - static vc min(vc const& v1, vc const& v2) { - return (v1.to_int() < v2.to_int())?v1:v2; - } - - - enum cmp_t { LE, GE, EQ, GE_FULL, LE_FULL }; - context& ctx; - cmp_t m_t; - - // for testing - static const bool m_disable_dcard = false; - static const bool m_disable_dsorting = false; - static const bool m_disable_dsmerge = false; - static const bool m_force_dcard = false; - static const bool m_force_dsorting = false; - static const bool m_force_dsmerge = false; - - public: - struct stats { - unsigned m_num_compiled_vars; - unsigned m_num_compiled_clauses; - void reset() { memset(this, 0, sizeof(*this)); } - stats() { reset(); } - }; - stats m_stats; - - psort_nw(context& c): ctx(c) {} - - literal ge(bool full, unsigned k, unsigned n, literal const* xs) { - if (k > n) { - return false_literal; - } - if (k == 0) { - return true_literal; - } - SASSERT(0 < k && k <= n); - literal_vector in, out; - if (dualize(k, n, xs, in)) { - return le(full, k, in.size(), in.c_ptr()); - } - else { - SASSERT(2*k <= n); - m_t = full?GE_FULL:GE; - card(k, n, xs, out); - return out[k-1]; - } - } - - literal le(bool full, unsigned k, unsigned n, literal const* xs) { - if (k >= n) { - return true_literal; - } - SASSERT(k < n); - literal_vector in, out; - if (dualize(k, n, xs, in)) { - return ge(full, k, n, in.c_ptr()); - } - else { - SASSERT(2*k <= n); - m_t = full?LE_FULL:LE; - card(k + 1, n, xs, out); - return ~out[k]; - } - } - - literal eq(unsigned k, unsigned n, literal const* xs) { - if (k > n) { - return false_literal; - } - SASSERT(k <= n); - literal_vector in, out; - if (dualize(k, n, xs, in)) { - return eq(k, n, in.c_ptr()); - } - else { - SASSERT(2*k < n); - m_t = EQ; - card(k+1, n, xs, out); - SASSERT(out.size() >= k+1); - return out[k-1]; // & ~out[m] TBD - } - } - - - private: - - std::ostream& pp(std::ostream& out, unsigned n, literal const* lits) { - for (unsigned i = 0; i < n; ++i) out << lits[i] << " "; - return out; - } - - std::ostream& pp(std::ostream& out, literal_vector const& lits) { - for (unsigned i = 0; i < lits.size(); ++i) out << lits[i] << " "; - return out; - } - - std::ostream& ppv(std::ostream& out, unsigned n, literal const* lits) { - for (unsigned i = 0; i < n; ++i) { - expr_ref tmp(ctx.get_manager()); - ctx.literal2expr(lits[i], tmp); - out << tmp << " "; - } - return out; - } - - std::ostream& ppv(std::ostream& out, literal_vector const& lits) { - for (unsigned i = 0; i < lits.size(); ++i) { - expr_ref tmp(ctx.get_manager()); - ctx.literal2expr(lits[i], tmp); - out << tmp << " "; - } - return out; - } - - // 0 <= k <= N - // SUM x_i >= k - // <=> - // SUM ~x_i <= N - k - // suppose k > N/2, then it is better to solve dual. - - bool dualize(unsigned& k, unsigned N, literal const* xs, literal_vector& in) { - SASSERT(0 <= k && k <= N); - if (2*k <= N) { - return false; - } - k = N - k; - for (unsigned i = 0; i < N; ++i) { - in.push_back(~xs[i]); - } - TRACE("pb", - pp(tout << N << ": ", in); - tout << " ~ " << k << "\n";); - return true; - } - - - bool even(unsigned n) const { return (0 == (n & 0x1)); } - bool odd(unsigned n) const { return !even(n); } - unsigned ceil2(unsigned n) const { return n/2 + odd(n); } - unsigned floor2(unsigned n) const { return n/2; } - unsigned power2(unsigned n) const { SASSERT(n < 10); return 1 << n; } - - literal max(literal a, literal b) { - if (a == b) return a; - m_stats.m_num_compiled_vars++; - ast_manager& m = ctx.get_manager(); - expr_ref t1(m), t2(m), t3(m); - ctx.literal2expr(a, t1); - ctx.literal2expr(b, t2); - t3 = m.mk_or(t1, t2); - bool_var v = ctx.b_internalized(t3)?ctx.get_bool_var(t3):ctx.mk_bool_var(t3); - return literal(v); - } - - literal min(literal a, literal b) { - if (a == b) return a; - m_stats.m_num_compiled_vars++; - ast_manager& m = ctx.get_manager(); - expr_ref t1(m), t2(m), t3(m); - ctx.literal2expr(a, t1); - ctx.literal2expr(b, t2); - t3 = m.mk_and(t1, t2); - bool_var v = ctx.b_internalized(t3)?ctx.get_bool_var(t3):ctx.mk_bool_var(t3); - return literal(v); - } - - literal fresh() { - m_stats.m_num_compiled_vars++; - ast_manager& m = ctx.get_manager(); - app_ref y(m); - y = m.mk_fresh_const("y", m.mk_bool_sort()); - return literal(ctx.mk_bool_var(y)); - } - void add_clause(literal l1, literal l2, literal l3) { - literal lits[3] = { l1, l2, l3 }; - add_clause(3, lits); - } - void add_clause(literal l1, literal l2) { - literal lits[2] = { l1, l2 }; - add_clause(2, lits); - } - void add_clause(unsigned n, literal const* ls) { - m_stats.m_num_compiled_clauses++; - literal_vector tmp(n, ls); - TRACE("pb", pp(tout, n, ls) << "\n";); - ctx.mk_clause(n, tmp.c_ptr(), 0, CLS_AUX, 0); - } - - // y1 <= max(x1,x2) - // y2 <= min(x1,x2) - void cmp_ge(literal x1, literal x2, literal y1, literal y2) { - add_clause(~y2, x1); - add_clause(~y2, x2); - add_clause(~y1, x1, x2); - } - - // max(x1,x2) <= y1 - // min(x1,x2) <= y2 - void cmp_le(literal x1, literal x2, literal y1, literal y2) { - add_clause(~x1, y1); - add_clause(~x2, y1); - add_clause(~x1, ~x2, y2); - } - - void cmp_eq(literal x1, literal x2, literal y1, literal y2) { - cmp_ge(x1, x2, y1, y2); - cmp_le(x1, x2, y1, y2); - } - - void cmp(literal x1, literal x2, literal y1, literal y2) { - switch(m_t) { - case LE: cmp_le(x1, x2, y1, y2); break; - case GE: cmp_ge(x1, x2, y1, y2); break; - case EQ: cmp_eq(x1, x2, y1, y2); break; - } - } - vc vc_cmp() { - return vc(2, (m_t==EQ)?6:3); - } - - void card(unsigned k, unsigned n, literal const* xs, literal_vector& out) { - TRACE("pb", tout << "card k:" << k << " n: " << n << "\n";); - if (n <= k) { - sorting(n, xs, out); - } - else if (use_dcard(k, n)) { - dsorting(k, n, xs, out); - } - else { - literal_vector out1, out2; - unsigned l = n/2; // TBD - card(k, l, xs, out1); - card(k, n-l, xs + l, out2); - smerge(k, out1.size(), out1.c_ptr(), out2.size(), out2.c_ptr(), out); - } - TRACE("pb", tout << "card k:" << k << " n: " << n << "\n"; - pp(tout << "in:", n, xs) << "\n"; - pp(tout << "out:", out) << "\n";); - - } - vc vc_card(unsigned k, unsigned n) { - if (n <= k) { - return vc_sorting(n); - } - else if (use_dcard(k, n)) { - return vc_dsorting(k, n); - } - else { - return vc_card_rec(k, n); - } - } - vc vc_card_rec(unsigned k, unsigned n) { - unsigned l = n/2; - return vc_card(k, l) + vc_card(k, n-l) + vc_smerge(k, l, n-l); - } - bool use_dcard(unsigned k, unsigned n) { - return m_force_dcard || (!m_disable_dcard && n < 10 && vc_dsorting(k, n) < vc_card_rec(k, n)); - } - - - void merge(unsigned a, literal const* as, - unsigned b, literal const* bs, - literal_vector& out) { - TRACE("pb", tout << "merge a: " << a << " b: " << b << "\n";); - if (a == 1 && b == 1) { - literal y1 = max(as[0], bs[0]); - literal y2 = min(as[0], bs[0]); - out.push_back(y1); - out.push_back(y2); - cmp(as[0], bs[0], y1, y2); - } - else if (a == 0) { - out.append(b, bs); - } - else if (b == 0) { - out.append(a, as); - } - else if (use_dsmerge(a, b, a + b)) { - dsmerge(a + b, a, as, b, bs, out); - } - else if (even(a) && odd(b)) { - merge(b, bs, a, as, out); - } - else { - literal_vector even_a, odd_a; - literal_vector even_b, odd_b; - literal_vector out1, out2; - SASSERT(a > 1 || b > 1); - split(a, as, even_a, odd_a); - split(b, bs, even_b, odd_b); - SASSERT(!even_a.empty()); - SASSERT(!even_b.empty()); - merge(even_a.size(), even_a.c_ptr(), - even_b.size(), even_b.c_ptr(), out1); - merge(odd_a.size(), odd_a.c_ptr(), - odd_b.size(), odd_b.c_ptr(), out2); - interleave(out1, out2, out); - } - TRACE("pb", tout << "merge a: " << a << " b: " << b << "\n"; - pp(tout << "a:", a, as) << "\n"; - pp(tout << "b:", b, bs) << "\n"; - pp(tout << "out:", out) << "\n";); - } - vc vc_merge(unsigned a, unsigned b) { - if (a == 1 && b == 1) { - return vc_cmp(); - } - else if (a == 0 || b == 0) { - return vc(0, 0); - } - else if (use_dsmerge(a, b, a + b)) { - return vc_dsmerge(a, b, a + b); - } - else { - return vc_merge_rec(a, b); - } - } - vc vc_merge_rec(unsigned a, unsigned b) { - return - vc_merge(ceil2(a), ceil2(b)) + - vc_merge(floor2(a), floor2(b)) + - vc_interleave(ceil2(a) + ceil2(b), floor2(a) + floor2(b)); - } - void split(unsigned n, literal const* ls, literal_vector& even, literal_vector& odd) { - for (unsigned i = 0; i < n; i += 2) { - even.push_back(ls[i]); - } - for (unsigned i = 1; i < n; i += 2) { - odd.push_back(ls[i]); - } - } - - void interleave(literal_vector const& as, - literal_vector const& bs, - literal_vector& out) { - TRACE("pb", tout << "interleave: " << as.size() << " " << bs.size() << "\n";); - SASSERT(as.size() >= bs.size()); - SASSERT(as.size() <= bs.size() + 2); - SASSERT(!as.empty()); - out.push_back(as[0]); - unsigned sz = std::min(as.size()-1, bs.size()); - for (unsigned i = 0; i < sz; ++i) { - literal y1 = max(as[i+1],bs[i]); - literal y2 = min(as[i+1],bs[i]); - cmp(as[i+1], bs[i], y1, y2); - out.push_back(y1); - out.push_back(y2); - } - if (as.size() == bs.size()) { - out.push_back(bs[sz]); - } - else if (as.size() == bs.size() + 2) { - out.push_back(as[sz+1]); - } - SASSERT(out.size() == as.size() + bs.size()); - TRACE("pb", tout << "interleave: " << as.size() << " " << bs.size() << "\n"; - pp(tout << "a: ", as) << "\n"; - pp(tout << "b: ", bs) << "\n"; - pp(tout << "out: ", out) << "\n";); - - } - vc vc_interleave(unsigned a, unsigned b) { - return vc_cmp()*std::min(a-1,b); - } - - void sorting(unsigned n, literal const* xs, literal_vector& out) { - TRACE("pb", tout << "sorting: " << n << "\n";); - switch(n) { - case 0: - break; - case 1: - out.push_back(xs[0]); - break; - case 2: - merge(1, xs, 1, xs+1, out); - break; - default: - if (use_dsorting(n)) { - dsorting(n, n, xs, out); - } - else { - literal_vector out1, out2; - unsigned l = n/2; // TBD - sorting(l, xs, out1); - sorting(n-l, xs+l, out2); - merge(out1.size(), out1.c_ptr(), - out2.size(), out2.c_ptr(), - out); - } - break; - } - TRACE("pb", tout << "sorting: " << n << "\n"; - pp(tout << "in:", n, xs) << "\n"; - pp(tout << "out:", out) << "\n";); - - } - vc vc_sorting(unsigned n) { - switch(n) { - case 0: return vc(0,0); - case 1: return vc(0,0); - case 2: return vc_merge(1,1); - default: - if (use_dsorting(n)) { - return vc_dsorting(n, n); - } - else { - return vc_sorting_rec(n); - } - } - } - vc vc_sorting_rec(unsigned n) { - SASSERT(n > 2); - unsigned l = n/2; - return vc_sorting(l) + vc_sorting(n-l) + vc_merge(l, n-l); - } - - bool use_dsorting(unsigned n) { - SASSERT(n > 2); - return m_force_dsorting || - (!m_disable_dsorting && n < 10 && vc_dsorting(n, n) < vc_sorting_rec(n)); - } - - void smerge(unsigned c, - unsigned a, literal const* as, - unsigned b, literal const* bs, - literal_vector& out) { - TRACE("pb", tout << "smerge: c:" << c << " a:" << a << " b:" << b << "\n";); - if (a == 1 && b == 1 && c == 1) { - literal y = max(as[0], bs[0]); - if (m_t != GE) { - // x1 <= max(x1,x2) - // x2 <= max(x1,x2) - add_clause(~as[0], y); - add_clause(~bs[0], y); - } - if (m_t != LE) { - // max(x1,x2) <= x1, x2 - add_clause(~y, as[0], bs[0]); - } - out.push_back(y); - } - else if (a == 0) { - out.append(std::min(c, b), bs); - } - else if (b == 0) { - out.append(std::min(c, a), as); - } - else if (a > c) { - smerge(c, c, as, b, bs, out); - } - else if (b > c) { - smerge(c, a, as, c, bs, out); - } - else if (a + b <= c) { - merge(a, as, b, bs, out); - } - else if (use_dsmerge(a, b, c)) { - dsmerge(c, a, as, b, bs, out); - } - else { - literal_vector even_a, odd_a; - literal_vector even_b, odd_b; - literal_vector out1, out2; - split(a, as, even_a, odd_a); - split(b, bs, even_b, odd_b); - SASSERT(!even_a.empty()); - SASSERT(!even_b.empty()); - unsigned c1, c2; - if (even(c)) { - c1 = 1 + c/2; c2 = c/2; - } - else { - c1 = (c + 1)/2; c2 = (c - 1)/2; - } - smerge(c1, even_a.size(), even_a.c_ptr(), - even_b.size(), even_b.c_ptr(), out1); - smerge(c2, odd_a.size(), odd_a.c_ptr(), - odd_b.size(), odd_b.c_ptr(), out2); - SASSERT(out1.size() == std::min(even_a.size()+even_b.size(), c1)); - SASSERT(out2.size() == std::min(odd_a.size()+odd_b.size(), c2)); - literal y; - if (even(c)) { - literal z1 = out1.back(); - literal z2 = out2.back(); - out1.pop_back(); - out2.pop_back(); - y = max(z1, z2); - if (m_t != GE) { - add_clause(~z1, y); - add_clause(~z2, y); - } - if (m_t != LE) { - add_clause(~y, z1, z2); - } - } - interleave(out1, out2, out); - if (even(c)) { - out.push_back(y); - } - } - TRACE("pb", tout << "smerge: c:" << c << " a:" << a << " b:" << b << "\n"; - pp(tout << "a:", a, as) << "\n"; - pp(tout << "b:", b, bs) << "\n"; - pp(tout << "out:", out) << "\n"; - ); - SASSERT(out.size() == std::min(a + b, c)); - } - - vc vc_smerge(unsigned a, unsigned b, unsigned c) { - if (a == 1 && b == 1 && c == 1) { - vc v(1,0); - if (m_t != GE) v = v + vc(0, 2); - if (m_t != LE) v = v + vc(0, 1); - return v; - } - if (a == 0 || b == 0) return vc(0, 0); - if (a > c) return vc_smerge(c, b, c); - if (b > c) return vc_smerge(a, c, c); - if (a + b <= c) return vc_merge(a, b); - if (use_dsmerge(a, b, c)) return vc_dsmerge(a, b, c); - return vc_smerge_rec(a, b, c); - } - vc vc_smerge_rec(unsigned a, unsigned b, unsigned c) { - return - vc_smerge(ceil2(a), ceil2(b), even(c)?(1+c/2):((c+1)/2)) + - vc_smerge(floor2(a), floor2(b), even(c)?(c/2):((c-1)/2)) + - vc_interleave(ceil2(a)+ceil2(b),floor2(a)+floor2(b)) + - vc(1, 0) + - ((m_t != GE)?vc(0, 2):vc(0, 0)) + - ((m_t != LE)?vc(0, 1):vc(0, 0)); - } - bool use_dsmerge(unsigned a, unsigned b, unsigned c) { - return - m_force_dsmerge || - (!m_disable_dsmerge && - a < (1 << 15) && b < (1 << 15) && - vc_dsmerge(a, b, a + b) < vc_smerge_rec(a, b, c)); - } - - void dsmerge( - unsigned c, - unsigned a, literal const* as, - unsigned b, literal const* bs, - literal_vector& out) { - TRACE("pb", tout << "dsmerge: c:" << c << " a:" << a << " b:" << b << "\n";); - SASSERT(a <= c); - SASSERT(b <= c); - SASSERT(a + b > c); - for (unsigned i = 0; i < c; ++i) { - out.push_back(fresh()); - } - if (m_t != GE) { - for (unsigned i = 0; i < a; ++i) { - add_clause(~as[i], out[i]); - } - for (unsigned i = 0; i < b; ++i) { - add_clause(~bs[i], out[i]); - } - for (unsigned i = 1; i <= a; ++i) { - for (unsigned j = 1; j <= b && i + j <= c; ++j) { - add_clause(~as[i-1],~bs[j-1],out[i+j-1]); - } - } - } - if (m_t != LE) { - for (unsigned k = 1; k <= c; ++k) { - literal_vector ls; - ls.push_back(~out[k-1]); - if (k <= a) { - ls.push_back(as[k-1]); - } - if (k <= b) { - ls.push_back(bs[k-1]); - } - for (unsigned i = 1; i <= std::min(a,k-1); ++i) { - if (k + 1 - i <= b) { - ls.push_back(as[i-1]); - ls.push_back(bs[k-i]); - add_clause(ls.size(), ls.c_ptr()); - ls.pop_back(); - ls.pop_back(); - } - } - } - } - } - vc vc_dsmerge(unsigned a, unsigned b, unsigned c) { - vc v(c, 0); - if (m_t != GE) { - v = v + vc(0, a + b + std::min(a, c)*std::min(b, c)/2); - } - if (m_t != LE) { - v = v + vc(0, std::min(a, c)*std::min(b, c)/2); - } - return v; - } - - - void dsorting(unsigned m, unsigned n, literal const* xs, - literal_vector& out) { - TRACE("pb", tout << "dsorting m: " << m << " n: " << n << "\n";); - SASSERT(m <= n); - literal_vector lits; - for (unsigned i = 0; i < m; ++i) { - out.push_back(fresh()); - } - if (m_t != GE) { - for (unsigned k = 1; k <= m; ++k) { - lits.push_back(out[k-1]); - add_subset(true, k, 0, lits, n, xs); - lits.pop_back(); - } - } - if (m_t != LE) { - for (unsigned k = 1; k <= m; ++k) { - lits.push_back(~out[k-1]); - add_subset(false, n-k+1, 0, lits, n, xs); - lits.pop_back(); - } - } - } - vc vc_dsorting(unsigned m, unsigned n) { - SASSERT(m <= n && n < 10); - vc v(m, 0); - if (m_t != GE) { - v = v + vc(0, power2(n-1)); - } - if (m_t != LE) { - v = v + vc(0, power2(n-1)); - } - return v; - } - - void add_subset(bool polarity, unsigned k, unsigned offset, literal_vector& lits, - unsigned n, literal const* xs) { - TRACE("pb", tout << "k:" << k << " offset: " << offset << " n: " << n << " "; - pp(tout, lits) << "\n";); - SASSERT(k + offset <= n); - if (k == 0) { - add_clause(lits.size(), lits.c_ptr()); - return; - } - for (unsigned i = offset; i < n - k + 1; ++i) { - lits.push_back(polarity?~xs[i]:xs[i]); - add_subset(polarity, k-1, i+1, lits, n, xs); - lits.pop_back(); - } - } - }; - - // for testing - literal theory_pb::assert_ge(context& ctx, unsigned k, unsigned n, literal const* xs) { - psort_nw sort(ctx); - return sort.ge(false, k, n, xs); - } - class pb_lit_rewriter_util { public: typedef std::pair arg_t; @@ -742,6 +62,7 @@ namespace smt { }; void theory_pb::ineq::negate() { + SASSERT(!m_is_eq); m_lit.neg(); numeral sum(0); for (unsigned i = 0; i < size(); ++i) { @@ -762,25 +83,28 @@ namespace smt { m_compiled = l_false; m_args.reset(); m_k.reset(); + m_nfixed = 0; + m_max_sum.reset(); + m_min_sum.reset(); } void theory_pb::ineq::unique() { pb_lit_rewriter_util pbu; pb_rewriter_util util(pbu); - util.unique(m_args, m_k); + util.unique(m_args, m_k, m_is_eq); } void theory_pb::ineq::prune() { pb_lit_rewriter_util pbu; pb_rewriter_util util(pbu); - util.prune(m_args, m_k); + util.prune(m_args, m_k, m_is_eq); } lbool theory_pb::ineq::normalize() { pb_lit_rewriter_util pbu; pb_rewriter_util util(pbu); - return util.normalize(m_args, m_k); + return util.normalize(m_args, m_k, m_is_eq); } app_ref theory_pb::ineq::to_expr(context& ctx, ast_manager& m) { @@ -794,7 +118,12 @@ namespace smt { coeffs.push_back(coeff(i)); } pb_util pb(m); - result = pb.mk_ge(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), k()); + if (m_is_eq) { + result = pb.mk_eq(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), k()); + } + else { + result = pb.mk_ge(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), k()); + } return result; } @@ -820,7 +149,8 @@ namespace smt { theory(m.mk_family_id("pb")), m_params(p), m_util(m), - m_lemma(null_literal) + m_lemma(null_literal, false), + m_max_compiled_coeff(rational(8)) { m_learn_complements = p.m_pb_learn_complements; m_conflict_frequency = p.m_pb_conflict_frequency; @@ -835,23 +165,26 @@ namespace smt { return alloc(theory_pb, new_ctx->get_manager(), m_params); } - bool theory_pb::internalize_atom(app * atom, bool gate_ctx) { - context& ctx = get_context(); - ast_manager& m = get_manager(); - unsigned num_args = atom->get_num_args(); - SASSERT(m_util.is_at_most_k(atom) || m_util.is_le(atom) || m_util.is_ge(atom)); + bool theory_pb::internalize_atom(app * atom, bool gate_ctx) { + SASSERT(m_util.is_at_most_k(atom) || m_util.is_le(atom) || + m_util.is_ge(atom) || m_util.is_at_least_k(atom) || + m_util.is_eq(atom)); + + context& ctx = get_context(); if (ctx.b_internalized(atom)) { return false; } + SASSERT(!ctx.b_internalized(atom)); m_stats.m_num_predicates++; - SASSERT(!ctx.b_internalized(atom)); + ast_manager& m = get_manager(); + unsigned num_args = atom->get_num_args(); bool_var abv = ctx.mk_bool_var(atom); ctx.set_var_theory(abv, get_id()); - ineq* c = alloc(ineq, literal(abv)); + ineq* c = alloc(ineq, literal(abv), m_util.is_eq(atom)); c->m_k = m_util.get_k(atom); numeral& k = c->m_k; arg_t& args = c->m_args; @@ -871,11 +204,17 @@ namespace smt { k = -k; } else { - SASSERT(m_util.is_at_least_k(atom) || m_util.is_ge(atom)); + SASSERT(m_util.is_at_least_k(atom) || m_util.is_ge(atom) || m_util.is_eq(atom)); } c->unique(); lbool is_true = c->normalize(); c->prune(); +#if 1 + if (c->is_ge() && is_true == l_undef) { + c->negate(); // Hack: negation further normalizes + c->negate(); + } +#endif literal lit(abv); @@ -892,7 +231,7 @@ namespace smt { break; } - if (c->k().is_one()) { + if (c->k().is_one() && c->is_ge()) { literal_vector& lits = get_lits(); lits.push_back(~lit); for (unsigned i = 0; i < c->size(); ++i) { @@ -911,11 +250,10 @@ namespace smt { max_watch = std::max(max_watch, args[i].second); } - // pre-compile threshold for cardinality - bool enable_compile = m_enable_compilation; + bool enable_compile = m_enable_compilation && c->is_ge(); for (unsigned i = 0; enable_compile && i < args.size(); ++i) { - enable_compile = (args[i].second < rational(8)); + enable_compile = (args[i].second <= m_max_compiled_coeff); } if (enable_compile) { unsigned log = 1, n = 1; @@ -931,6 +269,7 @@ namespace smt { else { c->m_compilation_threshold = UINT_MAX; } + init_watch_var(*c); m_ineqs.insert(abv, c); m_ineqs_trail.push_back(abv); @@ -967,10 +306,9 @@ namespace smt { negate = !negate; } - // assumes relevancy level = 2 or 0. // TBD: should should have been like an uninterpreted - // function intenalize, where enodes for each argument + // function internalize, where enodes for each argument // is available. if (!has_bv) { expr_ref tmp(m), fml(m); @@ -990,6 +328,7 @@ namespace smt { } void theory_pb::del_watch(watch_list& watch, unsigned index, ineq& c, unsigned ineq_index) { + SASSERT(c.is_ge()); if (index < watch.size()) { std::swap(watch[index], watch[watch.size()-1]); } @@ -1014,8 +353,8 @@ namespace smt { } void theory_pb::add_watch(ineq& c, unsigned i) { + SASSERT(c.is_ge()); literal lit = c.lit(i); - bool_var v = lit.var(); numeral coeff = c.coeff(i); c.m_watch_sum += coeff; SASSERT(i >= c.watch_size()); @@ -1027,13 +366,50 @@ namespace smt { if (coeff > c.max_watch()) { c.set_max_watch(coeff); } + watch_literal(lit, &c); + } + void theory_pb::watch_literal(literal lit, ineq* c) { ptr_vector* ineqs; - if (!m_watch.find(lit.index(), ineqs)) { + if (!m_lwatch.find(lit.index(), ineqs)) { ineqs = alloc(ptr_vector); - m_watch.insert(lit.index(), ineqs); + m_lwatch.insert(lit.index(), ineqs); + } + ineqs->push_back(c); + } + + void theory_pb::watch_var(bool_var v, ineq* c) { + ptr_vector* ineqs; + if (!m_vwatch.find(v, ineqs)) { + ineqs = alloc(ptr_vector); + m_vwatch.insert(v, ineqs); + } + ineqs->push_back(c); + } + + void theory_pb::unwatch_var(bool_var v, ineq* c) { + ptr_vector* ineqs = 0; + if (m_vwatch.find(v, ineqs)) { + remove(*ineqs, c); + } + } + + void theory_pb::unwatch_literal(literal w, ineq* c) { + ptr_vector* ineqs = 0; + if (m_lwatch.find(w.index(), ineqs)) { + remove(*ineqs, c); + } + } + + void theory_pb::remove(ptr_vector& ineqs, ineq* c) { + context& ctx = get_context(); + for (unsigned j = 0; j < ineqs.size(); ++j) { + if (ineqs[j] == c) { + std::swap(ineqs[j], ineqs[ineqs.size()-1]); + ineqs.pop_back(); + break; + } } - ineqs->push_back(&c); } void theory_pb::collect_statistics(::statistics& st) const { @@ -1048,7 +424,11 @@ namespace smt { void theory_pb::reset_eh() { // m_watch; - u_map*>::iterator it = m_watch.begin(), end = m_watch.end(); + u_map*>::iterator it = m_lwatch.begin(), end = m_lwatch.end(); + for (; it != end; ++it) { + dealloc(it->m_value); + } + it = m_vwatch.begin(), end = m_vwatch.end(); for (; it != end; ++it) { dealloc(it->m_value); } @@ -1056,18 +436,17 @@ namespace smt { for (; itc != endc; ++itc) { dealloc(itc->m_value); } - m_watch.reset(); + m_lwatch.reset(); + m_vwatch.reset(); m_ineqs.reset(); m_ineqs_trail.reset(); m_ineqs_lim.reset(); - m_assign_ineqs_trail.reset(); - m_assign_ineqs_lim.reset(); m_stats.reset(); m_to_compile.reset(); } void theory_pb::new_eq_eh(theory_var v1, theory_var v2) { - IF_VERBOSE(0, verbose_stream() << v1 << " = " << v2 << "\n";); + UNREACHABLE(); } final_check_status theory_pb::final_check_eh() { @@ -1081,20 +460,53 @@ namespace smt { ptr_vector* ineqs = 0; literal nlit(v, is_true); TRACE("pb", tout << "assign: " << ~nlit << "\n";); - if (m_watch.find(nlit.index(), ineqs)) { + if (m_lwatch.find(nlit.index(), ineqs)) { for (unsigned i = 0; i < ineqs->size(); ++i) { - if (assign_watch(v, is_true, *ineqs, i)) { + ineq* c = (*ineqs)[i]; + SASSERT(c->is_ge()); + if (assign_watch_ge(v, is_true, *ineqs, i)) { // i was removed from watch list. --i; } } } + if (m_vwatch.find(v, ineqs)) { + for (unsigned i = 0; i < ineqs->size(); ++i) { + ineq* c = (*ineqs)[i]; + assign_watch(v, is_true, *c); + } + } ineq* c = 0; if (m_ineqs.find(v, c)) { - assign_ineq(*c, is_true); + if (c->is_ge()) { + assign_ineq(*c, is_true); + } + else { + assign_eq(*c, is_true); + } } } + literal_vector& theory_pb::get_all_literals(ineq& c, bool negate) { + context& ctx = get_context(); + literal_vector& lits = get_lits(); + for (unsigned i = 0; i < c.size(); ++i) { + literal l = c.lit(i); + switch(ctx.get_assignment(l)) { + case l_true: + lits.push_back(negate?(~l):l); + break; + case l_false: + lits.push_back(negate?l:(~l)); + break; + default: + break; + } + } + return lits; + + } + literal_vector& theory_pb::get_helpful_literals(ineq& c, bool negate) { numeral sum = numeral::zero(); context& ctx = get_context(); @@ -1124,6 +536,27 @@ namespace smt { return lits; } + class theory_pb::rewatch_vars : public trail { + theory_pb& pb; + ineq& c; + public: + rewatch_vars(theory_pb& p, ineq& c): pb(p), c(c) {} + virtual void undo(context& ctx) { + for (unsigned i = 0; i < c.size(); ++i) { + pb.watch_var(c.lit(i).var(), &c); + } + } + }; + + class theory_pb::negate_ineq : public trail { + ineq& c; + public: + negate_ineq(ineq& c): c(c) {} + virtual void undo(context& ctx) { + c.negate(); + } + }; + /** \brief propagate assignment to inequality. This is a basic, non-optimized implementation based @@ -1131,13 +564,22 @@ namespace smt { and/or relatively few compared to number of argumets. */ void theory_pb::assign_ineq(ineq& c, bool is_true) { + context& ctx = get_context(); + ctx.push_trail(value_trail(c.m_max_sum)); + ctx.push_trail(value_trail(c.m_min_sum)); + ctx.push_trail(value_trail(c.m_nfixed)); + ctx.push_trail(rewatch_vars(*this, c)); + clear_watch(c); + SASSERT(c.is_ge()); if (c.lit().sign() == is_true) { + unsigned sz = c.size(); c.negate(); + SASSERT(sz == c.size()); + ctx.push_trail(negate_ineq(c)); } SASSERT(c.well_formed()); - context& ctx = get_context(); numeral maxsum = numeral::zero(); numeral mininc = numeral::zero(); for (unsigned i = 0; i < c.size(); ++i) { @@ -1160,16 +602,8 @@ namespace smt { add_clause(c, lits); } else { - c.m_watch_sum = numeral::zero(); - c.m_watch_sz = 0; - c.m_max_watch = numeral::zero(); - for (unsigned i = 0; c.watch_sum() < c.k() + c.max_watch() && i < c.size(); ++i) { - if (ctx.get_assignment(c.lit(i)) != l_false) { - add_watch(c, i); - } - } + init_watch_literal(c); SASSERT(c.watch_sum() >= c.k()); - m_assign_ineqs_trail.push_back(&c); DEBUG_CODE(validate_watch(c);); } @@ -1186,19 +620,123 @@ namespace smt { } } + /** + \brief propagate assignment to equality. + */ + void theory_pb::assign_eq(ineq& c, bool is_true) { + SASSERT(c.is_eq()); + + } + + /** + Propagation rules: + + nfixed = N & minsum = k -> T + nfixed = N & minsum != k -> F + + minsum > k or maxsum < k -> F + minsum = k & = -> fix 0 variables + nfixed+1 = N & = -> fix unassigned variable or conflict + nfixed+1 = N & != -> maybe forced unassigned to ensure disequal + minsum >= k -> T + maxsum < k -> F + */ + + void theory_pb::assign_watch(bool_var v, bool is_true, ineq& c) { + + context& ctx = get_context(); + unsigned i; + literal l = c.lit(); + lbool asgn = ctx.get_assignment(l); + + if (c.max_sum() < c.k() && asgn == l_false) { + return; + } + if (c.is_ge() && c.min_sum() >= c.k() && asgn == l_true) { + return; + } + for (i = 0; i < c.size(); ++i) { + if (c.lit(i).var() == v) { + break; + } + } + + TRACE("pb", display(tout << "assign watch " << literal(v,!is_true) << " ", c, true);); + + SASSERT(i < c.size()); + if (c.lit(i).sign() == is_true) { + ctx.push_trail(value_trail(c.m_max_sum)); + c.m_max_sum -= c.coeff(i); + } + else { + ctx.push_trail(value_trail(c.m_min_sum)); + c.m_min_sum += c.coeff(i); + } + DEBUG_CODE( + numeral sum(0); + numeral maxs(0); + for (unsigned i = 0; i < c.size(); ++i) { + if (ctx.get_assignment(c.lit(i)) == l_true) sum += c.coeff(i); + if (ctx.get_assignment(c.lit(i)) != l_false) maxs += c.coeff(i); + } + CTRACE("pb", (maxs > c.max_sum()), display(tout, c, true);); + SASSERT(c.min_sum() <= sum); + SASSERT(sum <= maxs); + SASSERT(maxs <= c.max_sum()); + ); + SASSERT(c.min_sum() <= c.max_sum()); + SASSERT(!c.min_sum().is_neg()); + ctx.push_trail(value_trail(c.m_nfixed)); + ++c.m_nfixed; + SASSERT(c.nfixed() <= c.size()); + if (c.is_ge() && c.min_sum() >= c.k() && asgn != l_true) { + TRACE("pb", display(tout << "Set " << l << "\n", c, true);); + add_assign(c, get_helpful_literals(c, false), l); + } + else if (c.max_sum() < c.k() && asgn != l_false) { + TRACE("pb", display(tout << "Set " << ~l << "\n", c, true);); + add_assign(c, get_unhelpful_literals(c, true), ~l); + } + else if (c.is_eq() && c.nfixed() == c.size() && c.min_sum() == c.k() && asgn != l_true) { + TRACE("pb", display(tout << "Set " << l << "\n", c, true);); + add_assign(c, get_all_literals(c, false), l); + } + else if (c.is_eq() && c.nfixed() == c.size() && c.min_sum() != c.k() && asgn != l_false) { + TRACE("pb", display(tout << "Set " << ~l << "\n", c, true);); + add_assign(c, get_all_literals(c, false), ~l); + } +#if 0 + else if (c.is_eq() && c.min_sum() > c.k() && asgn != l_false) { + TRACE("pb", display(tout << "Set " << ~l << "\n", c, true);); + add_assign(c, get_all_literals(c, false), ~l); + } + else if (c.is_eq() && asgn == l_true && c.min_sum() == c.k() && c.max_sum() > c.k()) { + literal_vector& lits = get_all_literals(c, false); + lits.push_back(c.lit()); + for (unsigned i = 0; i < c.size(); ++i) { + if (ctx.get_assignment(c.lit(i)) == l_undef) { + add_assign(c, lits, ~c.lit(i)); + } + } + } +#endif + else { + IF_VERBOSE(3, display(verbose_stream() << "no propagation ", c, true);); + } + } + + /** \brief v is assigned in inequality c. Update current bounds and watch list. Optimize for case where the c.lit() is True. This covers the case where inequalities are unit literals and formulas in negation normal form - (inequalities are closed under negation). - + (inequalities are closed under negation). */ - bool theory_pb::assign_watch(bool_var v, bool is_true, watch_list& watch, unsigned watch_index) { + bool theory_pb::assign_watch_ge(bool_var v, bool is_true, watch_list& watch, unsigned watch_index) { bool removed = false; context& ctx = get_context(); ineq& c = *watch[watch_index]; unsigned w = c.find_lit(v, 0, c.watch_size()); - SASSERT(ctx.get_assignment(c.lit()) == l_true); SASSERT(is_true == c.lit(w).sign()); @@ -1241,14 +779,15 @@ namespace smt { // Create clauses x1 or ~L or x2 // x1 or ~L or x4 // - + literal_vector& lits = get_unhelpful_literals(c, true); lits.push_back(c.lit()); numeral deficit = c.watch_sum() - k; for (unsigned i = 0; i < c.size(); ++i) { if (ctx.get_assignment(c.lit(i)) == l_undef && deficit < c.coeff(i)) { DEBUG_CODE(validate_assign(c, lits, c.lit(i));); - add_assign(c, lits, c.lit(i)); + add_assign(c, lits, c.lit(i)); + // break; } } } @@ -1265,10 +804,11 @@ namespace smt { return removed; } + // plugin for simple sorting network struct theory_pb::sort_expr { - theory_pb& th; - context& ctx; - ast_manager& m; + theory_pb& th; + context& ctx; + ast_manager& m; expr_ref_vector m_trail; sort_expr(theory_pb& th): th(th), @@ -1396,9 +936,63 @@ namespace smt { void add_clause(literal l1, literal l2) { add_clause(l1, l2, null_literal); } - }; + struct theory_pb::psort_expr { + context& ctx; + ast_manager& m; + typedef literal literal; + typedef literal_vector literal_vector; + + psort_expr(context& c): + ctx(c), + m(c.get_manager()) {} + + literal fresh() { + app_ref y(m); + y = m.mk_fresh_const("y", m.mk_bool_sort()); + return literal(ctx.mk_bool_var(y)); + } + + literal max(literal a, literal b) { + if (a == b) return a; + expr_ref t1(m), t2(m), t3(m); + ctx.literal2expr(a, t1); + ctx.literal2expr(b, t2); + t3 = m.mk_or(t1, t2); + bool_var v = ctx.b_internalized(t3)?ctx.get_bool_var(t3):ctx.mk_bool_var(t3); + return literal(v); + } + + literal min(literal a, literal b) { + if (a == b) return a; + expr_ref t1(m), t2(m), t3(m); + ctx.literal2expr(a, t1); + ctx.literal2expr(b, t2); + t3 = m.mk_and(t1, t2); + bool_var v = ctx.b_internalized(t3)?ctx.get_bool_var(t3):ctx.mk_bool_var(t3); + return literal(v); + } + + void mk_clause(unsigned n, literal const* ls) { + literal_vector tmp(n, ls); + ctx.mk_clause(n, tmp.c_ptr(), 0, CLS_AUX, 0); + } + + literal mk_false() { return false_literal; } + literal mk_true() { return true_literal; } + + std::ostream& pp(std::ostream& out, literal l) { return out << l; } + + }; + + // for testing + literal theory_pb::assert_ge(context& ctx, unsigned k, unsigned n, literal const* xs) { + psort_expr ps(ctx); + psort_nw sort(ps); + return sort.ge(false, k, n, xs); + } + void theory_pb::inc_propagations(ineq& c) { ++c.m_num_propagations; @@ -1415,6 +1009,7 @@ namespace smt { m_to_compile.reset(); } + void theory_pb::compile_ineq(ineq& c) { ++m_stats.m_num_compiles; ast_manager& m = get_manager(); @@ -1452,7 +1047,8 @@ namespace smt { } if (ctx.get_assignment(thl) == l_true && ctx.get_assign_level(thl) == ctx.get_base_level()) { - psort_nw sortnw(ctx); + psort_expr ps(ctx); + psort_nw sortnw(ps); sortnw.m_stats.reset(); at_least_k = sortnw.ge(false, k, in.size(), in.c_ptr()); ctx.mk_clause(~thl, at_least_k, 0); @@ -1460,7 +1056,8 @@ namespace smt { m_stats.m_num_compiled_clauses += sortnw.m_stats.m_num_compiled_clauses; } else { - psort_nw sortnw(ctx); + psort_expr ps(ctx); + psort_nw sortnw(ps); sortnw.m_stats.reset(); literal at_least_k = sortnw.ge(true, k, in.size(), in.c_ptr()); ctx.mk_clause(~thl, at_least_k, 0); @@ -1507,46 +1104,82 @@ namespace smt { void theory_pb::push_scope_eh() { m_ineqs_lim.push_back(m_ineqs_trail.size()); - m_assign_ineqs_lim.push_back(m_assign_ineqs_trail.size()); } void theory_pb::pop_scope_eh(unsigned num_scopes) { - // remove watched literals. - unsigned new_lim = m_assign_ineqs_lim.size()-num_scopes; - unsigned sz = m_assign_ineqs_lim[new_lim]; - while (m_assign_ineqs_trail.size() > sz) { - ineq* c = m_assign_ineqs_trail.back(); - for (unsigned i = 0; i < c->watch_size(); ++i) { - literal w = c->lit(i); - ptr_vector* ineqs = 0; - VERIFY(m_watch.find(w.index(), ineqs)); - for (unsigned j = 0; j < ineqs->size(); ++j) { - if ((*ineqs)[j] == c) { - std::swap((*ineqs)[j],(*ineqs)[ineqs->size()-1]); - ineqs->pop_back(); - break; - } - } - } - m_assign_ineqs_trail.pop_back(); - } - m_assign_ineqs_lim.resize(new_lim); - // remove inequalities. - new_lim = m_ineqs_lim.size()-num_scopes; - sz = m_ineqs_lim[new_lim]; + unsigned new_lim = m_ineqs_lim.size()-num_scopes; + unsigned sz = m_ineqs_lim[new_lim]; while (m_ineqs_trail.size() > sz) { bool_var v = m_ineqs_trail.back(); ineq* c = 0; VERIFY(m_ineqs.find(v, c)); + clear_watch(*c); m_ineqs.remove(v); - m_ineqs_trail.pop_back(); + m_ineqs_trail.pop_back(); dealloc(c); } m_ineqs_lim.resize(new_lim); } + void theory_pb::clear_watch(ineq& c) { + for (unsigned i = 0; i < c.size(); ++i) { + literal w = c.lit(i); + unwatch_var(w.var(), &c); + unwatch_literal(w, &c); + } + c.m_watch_sum.reset(); + c.m_watch_sz = 0; + c.m_max_watch.reset(); + c.m_nfixed = 0; + c.m_max_sum.reset(); + c.m_min_sum.reset(); + } + + class theory_pb::unwatch_ge : public trail { + theory_pb& pb; + ineq& c; + public: + unwatch_ge(theory_pb& p, ineq& c): pb(p), c(c) {} + + virtual void undo(context& ctx) { + for (unsigned i = 0; i < c.watch_size(); ++i) { + pb.unwatch_literal(c.lit(i), &c); + } + c.m_watch_sz = 0; + c.m_watch_sum.reset(); + c.m_max_watch.reset(); + } + }; + + + void theory_pb::init_watch_literal(ineq& c) { + context& ctx = get_context(); + c.m_watch_sum = numeral::zero(); + c.m_watch_sz = 0; + c.m_max_watch = numeral::zero(); + for (unsigned i = 0; c.watch_sum() < c.k() + c.max_watch() && i < c.size(); ++i) { + if (ctx.get_assignment(c.lit(i)) != l_false) { + add_watch(c, i); + } + } + ctx.push_trail(unwatch_ge(*this, c)); + } + + void theory_pb::init_watch_var(ineq& c) { + c.m_min_sum.reset(); + c.m_max_sum.reset(); + c.m_nfixed = 0; + c.m_watch_sum.reset(); + c.m_max_watch.reset(); + c.m_watch_sz = 0; + for (unsigned i = 0; i < c.size(); ++i) { + watch_var(c.lit(i).var(), &c); + c.m_max_sum += c.coeff(i); + } + } + literal_vector& theory_pb::get_lits() { m_literals.reset(); return m_literals; @@ -1574,7 +1207,6 @@ namespace smt { tout << "=> " << l << "\n"; display(tout, c, true);); - ctx.assign(l, ctx.mk_justification( pb_justification( c, get_id(), ctx.get_region(), lits.size(), lits.c_ptr(), l))); @@ -1586,6 +1218,11 @@ namespace smt { inc_propagations(c); m_stats.m_num_conflicts++; context& ctx = get_context(); +#if 0 + if (m_stats.m_num_conflicts == 1000) { + display(std::cout); + } +#endif TRACE("pb", tout << "#prop:" << c.m_num_propagations << " - "; for (unsigned i = 0; i < lits.size(); ++i) { tout << lits[i] << " "; @@ -1720,7 +1357,10 @@ namespace smt { // modeled after sat_solver/smt_context // bool theory_pb::resolve_conflict(ineq& c) { - + + if (!c.is_ge()) { + return false; + } TRACE("pb", display(tout, c, true);); bool_var v; @@ -1756,8 +1396,8 @@ namespace smt { break; } if (is_sat == l_true) { - IF_VERBOSE(0, verbose_stream() << "lemma already evaluated ";); - TRACE("pb", tout << "lemma already evaluated ";); + IF_VERBOSE(0, verbose_stream() << "lemma already evaluated\n";); + TRACE("pb", tout << "lemma already evaluated\n";); return false; } TRACE("pb", display(tout, m_lemma, true);); @@ -1838,18 +1478,26 @@ namespace smt { TRACE("pb", tout << "axiom " << conseq << "\n";); break; case b_justification::JUSTIFICATION: { - justification& j = *js.get_justification(); - if (j.get_from_theory() != get_id()) { - TRACE("pb", tout << "skipping justification for " << conseq - << " from theory " << j.get_from_theory() << " " - << typeid(j).name() << "\n";); - m_ineq_literals.push_back(conseq); + justification* j = js.get_justification(); + pb_justification* pbj = 0; + + if (!conseq.sign() && j->get_from_theory() == get_id()) { + pbj = dynamic_cast(j); + } + if (pbj && pbj->get_ineq().is_ge()) { + // only resolve >= that are positive consequences. + pbj = 0; + } + if (pbj) { + // weaken the lemma and resolve. + TRACE("pb", display(tout << "resolve with inequality", pbj->get_ineq(), true);); + process_ineq(pbj->get_ineq(), conseq, conseq_coeff); } else { - pb_justification& pbj = dynamic_cast(j); - // weaken the lemma and resolve. - TRACE("pb", display(tout << "resolve with inequality", pbj.get_ineq(), true);); - process_ineq(pbj.get_ineq(), conseq, conseq_coeff); + TRACE("pb", tout << "skipping justification for " << conseq + << " from theory " << j->get_from_theory() << " " + << typeid(*j).name() << "\n";); + m_ineq_literals.push_back(conseq); } break; } @@ -2002,8 +1650,9 @@ namespace smt { ); SASSERT(sum <= maxsum); - SASSERT((sum >= c.k()) == (ctx.get_assignment(c.lit()) == l_true)); - SASSERT((maxsum < c.k()) == (ctx.get_assignment(c.lit()) == l_false)); + SASSERT(!c.is_ge() || (sum >= c.k()) == (ctx.get_assignment(c.lit()) == l_true)); + SASSERT(!c.is_ge() || (maxsum < c.k()) == (ctx.get_assignment(c.lit()) == l_false)); + SASSERT(!c.is_eq() || (sum == c.k()) == (ctx.get_assignment(c.lit()) == l_true)); } // display methods @@ -2083,12 +1732,14 @@ namespace smt { out << " + "; } } - out << " >= " << c.m_k << "\n"; + out << (c.is_ge()?" >= ":" = ") << c.m_k << "\n"; if (c.m_num_propagations) out << "propagations: " << c.m_num_propagations << " "; if (c.max_watch().is_pos()) out << "max_watch: " << c.max_watch() << " "; if (c.watch_size()) out << "watch size: " << c.watch_size() << " "; if (c.watch_sum().is_pos()) out << "watch-sum: " << c.watch_sum() << " "; - if (c.m_num_propagations || c.max_watch().is_pos() || c.watch_size() || c.watch_sum().is_pos()) out << "\n"; + if (!c.max_sum().is_zero()) out << "sum: [" << c.min_sum() << ":" << c.max_sum() << "] "; + if (c.m_num_propagations || c.max_watch().is_pos() || c.watch_size() || + c.watch_sum().is_pos() || !c.max_sum().is_zero()) out << "\n"; return out; } @@ -2174,7 +1825,7 @@ namespace smt { } void theory_pb::display(std::ostream& out) const { - u_map*>::iterator it = m_watch.begin(), end = m_watch.end(); + u_map*>::iterator it = m_lwatch.begin(), end = m_lwatch.end(); for (; it != end; ++it) { out << "watch: " << to_literal(it->m_key) << " |-> "; watch_list const& wl = *it->m_value; @@ -2183,10 +1834,19 @@ namespace smt { } out << "\n"; } + it = m_vwatch.begin(), end = m_vwatch.end(); + for (; it != end; ++it) { + out << "watch (v): " << literal(it->m_key) << " |-> "; + watch_list const& wl = *it->m_value; + for (unsigned i = 0; i < wl.size(); ++i) { + out << wl[i]->lit() << " "; + } + out << "\n"; + } u_map::iterator itc = m_ineqs.begin(), endc = m_ineqs.end(); for (; itc != endc; ++itc) { ineq& c = *itc->m_value; - display(out, c); + display(out, c, true); } } diff --git a/src/smt/theory_pb.h b/src/smt/theory_pb.h index 4cb0a337e..922ab76f8 100644 --- a/src/smt/theory_pb.h +++ b/src/smt/theory_pb.h @@ -29,8 +29,12 @@ namespace smt { class theory_pb : public theory { struct sort_expr; + struct psort_expr; class pb_justification; - class pb_model_value_proc; + class pb_model_value_proc; + class unwatch_ge; + class rewatch_vars; + class negate_ineq; typedef rational numeral; typedef vector > arg_t; @@ -48,20 +52,24 @@ namespace smt { struct ineq { literal m_lit; // literal repesenting predicate + bool m_is_eq; // is this an = or >=. arg_t m_args; // encode args[0]*coeffs[0]+...+args[n-1]*coeffs[n-1] >= m_k; numeral m_k; // invariants: m_k > 0, coeffs[i] > 0 // Watch the first few positions until the sum satisfies: - // sum coeffs[i] >= m_lower + max_watch - + // sum coeffs[i] >= m_lower + max_watch numeral m_max_watch; // maximal coefficient. unsigned m_watch_sz; // number of literals being watched. - numeral m_watch_sum; // maximal sum of watch literals. + numeral m_watch_sum; // maximal sum of watch literals. + // Watch infrastructure for = and unassigned >=: + unsigned m_nfixed; // number of variables that are fixed. + numeral m_max_sum; // maximal possible sum. + numeral m_min_sum; // minimal possible sum. unsigned m_num_propagations; unsigned m_compilation_threshold; lbool m_compiled; - ineq(literal l) : m_lit(l) { + ineq(literal l, bool is_eq) : m_lit(l), m_is_eq(is_eq) { reset(); } @@ -75,10 +83,16 @@ namespace smt { numeral const& watch_sum() const { return m_watch_sum; } numeral const& max_watch() const { return m_max_watch; } - void set_max_watch(numeral const& n) { m_max_watch = n; } - + void set_max_watch(numeral const& n) { m_max_watch = n; } unsigned watch_size() const { return m_watch_sz; } + // variable watch infrastructure + numeral min_sum() const { return m_min_sum; } + numeral max_sum() const { return m_max_sum; } + unsigned nfixed() const { return m_nfixed; } + bool vwatch_initialized() const { return !max_sum().is_zero(); } + void vwatch_reset() { m_min_sum.reset(); m_max_sum.reset(); m_nfixed = 0; } + unsigned find_lit(bool_var v, unsigned begin, unsigned end) { while (lit(begin).var() != v) { ++begin; @@ -100,17 +114,19 @@ namespace smt { bool well_formed() const; app_ref to_expr(context& ctx, ast_manager& m); + + bool is_eq() const { return m_is_eq; } + bool is_ge() const { return !m_is_eq; } }; typedef ptr_vector watch_list; theory_pb_params m_params; - u_map m_watch; // per literal. + u_map m_lwatch; // per literal. + u_map m_vwatch; // per variable. u_map m_ineqs; // per inequality. unsigned_vector m_ineqs_trail; unsigned_vector m_ineqs_lim; - ptr_vector m_assign_ineqs_trail; - unsigned_vector m_assign_ineqs_lim; literal_vector m_literals; // temporary vector pb_util m_util; stats m_stats; @@ -118,13 +134,24 @@ namespace smt { unsigned m_conflict_frequency; bool m_learn_complements; bool m_enable_compilation; + rational m_max_compiled_coeff; // internalize_atom: literal compile_arg(expr* arg); void add_watch(ineq& c, unsigned index); void del_watch(watch_list& watch, unsigned index, ineq& c, unsigned ineq_index); - bool assign_watch(bool_var v, bool is_true, watch_list& watch, unsigned index); + void init_watch_literal(ineq& c); + void init_watch_var(ineq& c); + void clear_watch(ineq& c); + void watch_literal(literal lit, ineq* c); + void watch_var(bool_var v, ineq* c); + void unwatch_literal(literal w, ineq* c); + void unwatch_var(bool_var v, ineq* c); + void remove(ptr_vector& ineqs, ineq* c); + bool assign_watch_ge(bool_var v, bool is_true, watch_list& watch, unsigned index); + void assign_watch(bool_var v, bool is_true, ineq& c); void assign_ineq(ineq& c, bool is_true); + void assign_eq(ineq& c, bool is_true); std::ostream& display(std::ostream& out, ineq const& c, bool values = false) const; virtual void display(std::ostream& out) const; @@ -134,6 +161,7 @@ namespace smt { void add_assign(ineq& c, literal_vector const& lits, literal l); literal_vector& get_lits(); + literal_vector& get_all_literals(ineq& c, bool negate); literal_vector& get_helpful_literals(ineq& c, bool negate); literal_vector& get_unhelpful_literals(ineq& c, bool negate); diff --git a/src/tactic/arith/elim01_tactic.cpp b/src/tactic/arith/elim01_tactic.cpp index af427ccc5..e6c345734 100644 --- a/src/tactic/arith/elim01_tactic.cpp +++ b/src/tactic/arith/elim01_tactic.cpp @@ -26,10 +26,12 @@ Notes: #include"model_smt2_pp.h" class bool2int_model_converter : public model_converter { - ast_manager& m; - arith_util a; - func_decl_ref_vector m_refs; - obj_map m_bool2int; + ast_manager& m; + arith_util a; + func_decl_ref_vector m_refs; + obj_hashtable m_bools; + vector > m_nums_as_bool; + ptr_vector m_nums_as_int; public: bool2int_model_converter(ast_manager& m): @@ -42,26 +44,34 @@ public: SASSERT(goal_idx == 0); model * new_model = alloc(model, m); unsigned num = old_model->get_num_constants(); + for (unsigned i = 0; i < m_nums_as_int.size(); ++i) { + func_decl* f_old = m_nums_as_int[i]; + rational val(0); + rational po(1); + bool is_value = true; + for (unsigned j = 0; is_value && j < m_nums_as_bool[i].size(); ++j) { + func_decl* f = m_nums_as_bool[i][j]; + expr* fi = old_model->get_const_interp(f); + if (!fi) { + is_value = false; + } + else if (m.is_true(fi)) { + val += po; + } + else if (!m.is_false(fi)) { + is_value = false; + } + po *= rational(2); + } + if (is_value) { + expr* fi = a.mk_numeral(val, true); + new_model->register_decl(f_old, fi); + } + } for (unsigned i = 0; i < num; ++i) { func_decl* f = old_model->get_constant(i); expr* fi = old_model->get_const_interp(f); - func_decl* f_old; - if (m_bool2int.find(f, f_old)) { - if (!fi) { - // no-op - } - else if (m.is_false(fi)) { - fi = a.mk_numeral(rational(0), true); - } - else if (m.is_true(fi)) { - fi = a.mk_numeral(rational(1), true); - } - else { - fi = 0; - } - new_model->register_decl(f_old, fi); - } - else { + if (!m_bools.contains(f)) { new_model->register_decl(f, fi); } } @@ -78,15 +88,27 @@ public: void insert(func_decl* x_new, func_decl* x_old) { m_refs.push_back(x_new); m_refs.push_back(x_old); - m_bool2int.insert(x_new, x_old); + m_bools.insert(x_new); + m_nums_as_int.push_back(x_old); + m_nums_as_bool.push_back(ptr_vector()); + m_nums_as_bool.back().push_back(x_new); } + void insert(func_decl* x_old, unsigned sz, func_decl * const* x_new) { + m_nums_as_int.push_back(x_old); + m_nums_as_bool.push_back(ptr_vector()); + m_refs.push_back(x_old); + for (unsigned i = 0; i < sz; ++i) { + m_refs.push_back(x_new[i]); + m_nums_as_bool.back().push_back(x_new[i]); + m_bools.insert(x_new[i]); + } + } virtual model_converter * translate(ast_translation & translator) { bool2int_model_converter* mc = alloc(bool2int_model_converter, translator.to()); - obj_map::iterator it = m_bool2int.begin(), end = m_bool2int.end(); - for (; it != end; ++it) { - mc->insert(translator(it->m_key), translator(it->m_value)); + for (unsigned i = 0; i < m_nums_as_int.size(); ++i) { + mc->insert(m_nums_as_int[i], m_nums_as_bool[i].size(), m_nums_as_bool[i].c_ptr()); } return mc; } @@ -99,10 +121,14 @@ public: ast_manager & m; arith_util a; params_ref m_params; + unsigned m_max_hi_default; + rational m_max_hi; elim01_tactic(ast_manager & _m, params_ref const & p): m(_m), - a(m) { + a(m), + m_max_hi_default(8), + m_max_hi(rational(m_max_hi_default)) { } virtual ~elim01_tactic() { @@ -111,9 +137,15 @@ public: void set_cancel(bool f) { } - void updt_params(params_ref const & p) { + virtual void updt_params(params_ref const & p) { + m_max_hi = rational(p.get_uint("max_coefficient", m_max_hi_default)); m_params = p; } + + virtual void collect_param_descrs(param_descrs & r) { + r.insert("max_coefficient", CPK_UINT, "(default: 1) maximal upper bound for finite range -> Bool conversion"); + } + virtual void operator()(goal_ref const & g, goal_ref_buffer & result, @@ -129,6 +161,7 @@ public: bool2int_model_converter* b2i = alloc(bool2int_model_converter, m); mc = b2i; bound_manager bounds(m); + expr_ref_vector axioms(m); bounds(*g); bound_manager::iterator bit = bounds.begin(), bend = bounds.end(); @@ -139,11 +172,8 @@ public: rational lo, hi; if (a.is_int(x) && bounds.has_lower(x, lo, s1) && !s1 && lo.is_zero() && - bounds.has_upper(x, hi, s2) && !s2 && hi.is_one()) { - app* x_new = m.mk_fresh_const(x->get_decl()->get_name().str().c_str(), m.mk_bool_sort()); - sub.insert(x, m.mk_ite(x_new, a.mk_numeral(rational(1), true), - a.mk_numeral(rational(0), true))); - b2i->insert(x_new->get_decl(), x->get_decl()); + bounds.has_upper(x, hi, s2) && !s2 && hi <= m_max_hi) { + add_variable(b2i, sub, x, hi.get_unsigned(), axioms); } } @@ -155,6 +185,9 @@ public: sub(curr, new_curr); g->update(i, new_curr, new_pr, g->dep(i)); } + for (unsigned i = 0; i < axioms.size(); ++i) { + g->assert_expr(axioms[i].get()); + } g->inc_depth(); result.push_back(g.get()); TRACE("pb", g->display(tout);); @@ -166,10 +199,41 @@ public: virtual tactic * translate(ast_manager & m) { return alloc(elim01_tactic, m, m_params); } - - virtual void collect_param_descrs(param_descrs & r) {} - + virtual void cleanup() {} + + void add_variable(bool2int_model_converter* b2i, + expr_safe_replace& sub, + app* x, + unsigned max_value, + expr_ref_vector& axioms) { + std::string name = x->get_decl()->get_name().str(); + unsigned sh = 0; + app_ref_vector xs(m), ites(m); + func_decl_ref_vector xfs(m); + app_ref zero(m), sum(m); + zero = a.mk_numeral(rational(0), true); + while (max_value >= (1ul << sh)) { + xs.push_back(m.mk_fresh_const(name.c_str(), m.mk_bool_sort())); + xfs.push_back(xs.back()->get_decl()); + ites.push_back(m.mk_ite(xs.back(), a.mk_numeral(rational(1 << sh), true), zero)); + ++sh; + } + if (ites.size() == 1) { + sum = ites[0].get(); + } + else { + sum = a.mk_add(ites.size(), (expr*const*)ites.c_ptr()); + } + + sub.insert(x, sum); + b2i->insert(x->get_decl(), xfs.size(), xfs.c_ptr()); + // if max_value+1 is not a power of two: + if ((max_value & (max_value + 1)) != 0) { + axioms.push_back(a.mk_le(sum, a.mk_numeral(rational(max_value), true))); + } + } + }; tactic * mk_elim01_tactic(ast_manager & m, params_ref const & p) { diff --git a/src/tactic/arith/lia2card_tactic.cpp b/src/tactic/arith/lia2card_tactic.cpp index c6c49ec5b..a819791ec 100644 --- a/src/tactic/arith/lia2card_tactic.cpp +++ b/src/tactic/arith/lia2card_tactic.cpp @@ -155,8 +155,7 @@ public: else if (m.is_eq(fml, x, y) && get_pb_sum(x, rational::one(), args, coeffs, coeff) && get_pb_sum(y, -rational::one(), args, coeffs, coeff)) { - result = m.mk_and(mk_le(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), -coeff), - mk_ge(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), -coeff)); + result = mk_eq(coeffs.size(), coeffs.c_ptr(), args.c_ptr(), -coeff); return true; } return false; @@ -174,6 +173,14 @@ public: } return m_pb.mk_le(sz, weights, args, w); } + + expr* mk_eq(unsigned sz, rational const* weights, expr* const* args, rational const& w) { +#if 1 + return m.mk_and(mk_ge(sz, weights, args, w), mk_le(sz, weights, args, w)); +#else + return m_pb.mk_eq(sz, weights, args, w); +#endif + } expr* mk_ge(unsigned sz, rational const* weights, expr* const* args, rational const& w) { if (sz == 0) { diff --git a/src/util/sorting_network.h b/src/util/sorting_network.h index f4eb4b032..8ba260cae 100644 --- a/src/util/sorting_network.h +++ b/src/util/sorting_network.h @@ -116,4 +116,649 @@ Notes: } }; + // parametric sorting network + // Described in Abio et.al. CP 2013. + template + class psort_nw { + typedef typename psort_expr::literal literal; + typedef typename psort_expr::literal_vector literal_vector; + + class vc { + unsigned v; // number of vertices + unsigned c; // number of clauses + static const unsigned lambda = 5; + public: + vc(unsigned v, unsigned c):v(v), c(c) {} + + bool operator<(vc const& other) const { + return to_int() < other.to_int(); + } + vc operator+(vc const& other) const { + return vc(v + other.v, c + other.c); + } + unsigned to_int() const { + return lambda*v + c; + } + vc operator*(unsigned n) const { + return vc(n*v, n*c); + } + }; + + static vc min(vc const& v1, vc const& v2) { + return (v1.to_int() < v2.to_int())?v1:v2; + } + + + enum cmp_t { LE, GE, EQ, GE_FULL, LE_FULL }; + psort_expr& ctx; + cmp_t m_t; + + // for testing + static const bool m_disable_dcard = false; + static const bool m_disable_dsorting = false; + static const bool m_disable_dsmerge = false; + static const bool m_force_dcard = false; + static const bool m_force_dsorting = false; + static const bool m_force_dsmerge = false; + + public: + struct stats { + unsigned m_num_compiled_vars; + unsigned m_num_compiled_clauses; + void reset() { memset(this, 0, sizeof(*this)); } + stats() { reset(); } + }; + stats m_stats; + + psort_nw(psort_expr& c): ctx(c) {} + + literal ge(bool full, unsigned k, unsigned n, literal const* xs) { + if (k > n) { + return ctx.mk_false(); + } + if (k == 0) { + return ctx.mk_true(); + } + SASSERT(0 < k && k <= n); + literal_vector in, out; + if (dualize(k, n, xs, in)) { + return le(full, k, in.size(), in.c_ptr()); + } + else { + SASSERT(2*k <= n); + m_t = full?GE_FULL:GE; + card(k, n, xs, out); + return out[k-1]; + } + } + + literal le(bool full, unsigned k, unsigned n, literal const* xs) { + if (k >= n) { + return ctx.mk_true(); + } + SASSERT(k < n); + literal_vector in, out; + if (dualize(k, n, xs, in)) { + return ge(full, k, n, in.c_ptr()); + } + else { + SASSERT(2*k <= n); + m_t = full?LE_FULL:LE; + card(k + 1, n, xs, out); + return ~out[k]; + } + } + + literal eq(unsigned k, unsigned n, literal const* xs) { + if (k > n) { + return ctx.mk_false(); + } + SASSERT(k <= n); + literal_vector in, out; + if (dualize(k, n, xs, in)) { + return eq(k, n, in.c_ptr()); + } + else { + SASSERT(2*k < n); + m_t = EQ; + card(k+1, n, xs, out); + SASSERT(out.size() >= k+1); + return out[k-1]; // & ~out[m] TBD + } + } + + + private: + + std::ostream& pp(std::ostream& out, unsigned n, literal const* lits) { + for (unsigned i = 0; i < n; ++i) ctx.pp(out, lits[i]) << " "; + return out; + } + + std::ostream& pp(std::ostream& out, literal_vector const& lits) { + for (unsigned i = 0; i < lits.size(); ++i) ctx.pp(out, lits[i]) << " "; + return out; + } + + // 0 <= k <= N + // SUM x_i >= k + // <=> + // SUM ~x_i <= N - k + // suppose k > N/2, then it is better to solve dual. + + bool dualize(unsigned& k, unsigned N, literal const* xs, literal_vector& in) { + SASSERT(0 <= k && k <= N); + if (2*k <= N) { + return false; + } + k = N - k; + for (unsigned i = 0; i < N; ++i) { + in.push_back(~xs[i]); + } + TRACE("pb", + pp(tout << N << ": ", in); + tout << " ~ " << k << "\n";); + return true; + } + + + bool even(unsigned n) const { return (0 == (n & 0x1)); } + bool odd(unsigned n) const { return !even(n); } + unsigned ceil2(unsigned n) const { return n/2 + odd(n); } + unsigned floor2(unsigned n) const { return n/2; } + unsigned power2(unsigned n) const { SASSERT(n < 10); return 1 << n; } + + + literal max(literal a, literal b) { + if (a == b) return a; + m_stats.m_num_compiled_vars++; + return ctx.max(a, b); + } + + literal min(literal a, literal b) { + if (a == b) return a; + m_stats.m_num_compiled_vars++; + return ctx.min(a, b); + } + + literal fresh() { + m_stats.m_num_compiled_vars++; + return ctx.fresh(); + } + void add_clause(literal l1, literal l2, literal l3) { + literal lits[3] = { l1, l2, l3 }; + add_clause(3, lits); + } + void add_clause(literal l1, literal l2) { + literal lits[2] = { l1, l2 }; + add_clause(2, lits); + } + void add_clause(unsigned n, literal const* ls) { + m_stats.m_num_compiled_clauses++; + literal_vector tmp(n, ls); + ctx.mk_clause(n, tmp.c_ptr()); + } + + // y1 <= max(x1,x2) + // y2 <= min(x1,x2) + void cmp_ge(literal x1, literal x2, literal y1, literal y2) { + add_clause(~y2, x1); + add_clause(~y2, x2); + add_clause(~y1, x1, x2); + } + + // max(x1,x2) <= y1 + // min(x1,x2) <= y2 + void cmp_le(literal x1, literal x2, literal y1, literal y2) { + add_clause(~x1, y1); + add_clause(~x2, y1); + add_clause(~x1, ~x2, y2); + } + + void cmp_eq(literal x1, literal x2, literal y1, literal y2) { + cmp_ge(x1, x2, y1, y2); + cmp_le(x1, x2, y1, y2); + } + + void cmp(literal x1, literal x2, literal y1, literal y2) { + switch(m_t) { + case LE: cmp_le(x1, x2, y1, y2); break; + case GE: cmp_ge(x1, x2, y1, y2); break; + case EQ: cmp_eq(x1, x2, y1, y2); break; + } + } + vc vc_cmp() { + return vc(2, (m_t==EQ)?6:3); + } + + void card(unsigned k, unsigned n, literal const* xs, literal_vector& out) { + TRACE("pb", tout << "card k:" << k << " n: " << n << "\n";); + if (n <= k) { + sorting(n, xs, out); + } + else if (use_dcard(k, n)) { + dsorting(k, n, xs, out); + } + else { + literal_vector out1, out2; + unsigned l = n/2; // TBD + card(k, l, xs, out1); + card(k, n-l, xs + l, out2); + smerge(k, out1.size(), out1.c_ptr(), out2.size(), out2.c_ptr(), out); + } + TRACE("pb", tout << "card k:" << k << " n: " << n << "\n"; + pp(tout << "in:", n, xs) << "\n"; + pp(tout << "out:", out) << "\n";); + + } + vc vc_card(unsigned k, unsigned n) { + if (n <= k) { + return vc_sorting(n); + } + else if (use_dcard(k, n)) { + return vc_dsorting(k, n); + } + else { + return vc_card_rec(k, n); + } + } + vc vc_card_rec(unsigned k, unsigned n) { + unsigned l = n/2; + return vc_card(k, l) + vc_card(k, n-l) + vc_smerge(k, l, n-l); + } + bool use_dcard(unsigned k, unsigned n) { + return m_force_dcard || (!m_disable_dcard && n < 10 && vc_dsorting(k, n) < vc_card_rec(k, n)); + } + + + void merge(unsigned a, literal const* as, + unsigned b, literal const* bs, + literal_vector& out) { + TRACE("pb", tout << "merge a: " << a << " b: " << b << "\n";); + if (a == 1 && b == 1) { + literal y1 = max(as[0], bs[0]); + literal y2 = min(as[0], bs[0]); + out.push_back(y1); + out.push_back(y2); + cmp(as[0], bs[0], y1, y2); + } + else if (a == 0) { + out.append(b, bs); + } + else if (b == 0) { + out.append(a, as); + } + else if (use_dsmerge(a, b, a + b)) { + dsmerge(a + b, a, as, b, bs, out); + } + else if (even(a) && odd(b)) { + merge(b, bs, a, as, out); + } + else { + literal_vector even_a, odd_a; + literal_vector even_b, odd_b; + literal_vector out1, out2; + SASSERT(a > 1 || b > 1); + split(a, as, even_a, odd_a); + split(b, bs, even_b, odd_b); + SASSERT(!even_a.empty()); + SASSERT(!even_b.empty()); + merge(even_a.size(), even_a.c_ptr(), + even_b.size(), even_b.c_ptr(), out1); + merge(odd_a.size(), odd_a.c_ptr(), + odd_b.size(), odd_b.c_ptr(), out2); + interleave(out1, out2, out); + } + TRACE("pb", tout << "merge a: " << a << " b: " << b << "\n"; + pp(tout << "a:", a, as) << "\n"; + pp(tout << "b:", b, bs) << "\n"; + pp(tout << "out:", out) << "\n";); + } + vc vc_merge(unsigned a, unsigned b) { + if (a == 1 && b == 1) { + return vc_cmp(); + } + else if (a == 0 || b == 0) { + return vc(0, 0); + } + else if (use_dsmerge(a, b, a + b)) { + return vc_dsmerge(a, b, a + b); + } + else { + return vc_merge_rec(a, b); + } + } + vc vc_merge_rec(unsigned a, unsigned b) { + return + vc_merge(ceil2(a), ceil2(b)) + + vc_merge(floor2(a), floor2(b)) + + vc_interleave(ceil2(a) + ceil2(b), floor2(a) + floor2(b)); + } + void split(unsigned n, literal const* ls, literal_vector& even, literal_vector& odd) { + for (unsigned i = 0; i < n; i += 2) { + even.push_back(ls[i]); + } + for (unsigned i = 1; i < n; i += 2) { + odd.push_back(ls[i]); + } + } + + void interleave(literal_vector const& as, + literal_vector const& bs, + literal_vector& out) { + TRACE("pb", tout << "interleave: " << as.size() << " " << bs.size() << "\n";); + SASSERT(as.size() >= bs.size()); + SASSERT(as.size() <= bs.size() + 2); + SASSERT(!as.empty()); + out.push_back(as[0]); + unsigned sz = std::min(as.size()-1, bs.size()); + for (unsigned i = 0; i < sz; ++i) { + literal y1 = max(as[i+1],bs[i]); + literal y2 = min(as[i+1],bs[i]); + cmp(as[i+1], bs[i], y1, y2); + out.push_back(y1); + out.push_back(y2); + } + if (as.size() == bs.size()) { + out.push_back(bs[sz]); + } + else if (as.size() == bs.size() + 2) { + out.push_back(as[sz+1]); + } + SASSERT(out.size() == as.size() + bs.size()); + TRACE("pb", tout << "interleave: " << as.size() << " " << bs.size() << "\n"; + pp(tout << "a: ", as) << "\n"; + pp(tout << "b: ", bs) << "\n"; + pp(tout << "out: ", out) << "\n";); + + } + vc vc_interleave(unsigned a, unsigned b) { + return vc_cmp()*std::min(a-1,b); + } + + void sorting(unsigned n, literal const* xs, literal_vector& out) { + TRACE("pb", tout << "sorting: " << n << "\n";); + switch(n) { + case 0: + break; + case 1: + out.push_back(xs[0]); + break; + case 2: + merge(1, xs, 1, xs+1, out); + break; + default: + if (use_dsorting(n)) { + dsorting(n, n, xs, out); + } + else { + literal_vector out1, out2; + unsigned l = n/2; // TBD + sorting(l, xs, out1); + sorting(n-l, xs+l, out2); + merge(out1.size(), out1.c_ptr(), + out2.size(), out2.c_ptr(), + out); + } + break; + } + TRACE("pb", tout << "sorting: " << n << "\n"; + pp(tout << "in:", n, xs) << "\n"; + pp(tout << "out:", out) << "\n";); + + } + vc vc_sorting(unsigned n) { + switch(n) { + case 0: return vc(0,0); + case 1: return vc(0,0); + case 2: return vc_merge(1,1); + default: + if (use_dsorting(n)) { + return vc_dsorting(n, n); + } + else { + return vc_sorting_rec(n); + } + } + } + vc vc_sorting_rec(unsigned n) { + SASSERT(n > 2); + unsigned l = n/2; + return vc_sorting(l) + vc_sorting(n-l) + vc_merge(l, n-l); + } + + bool use_dsorting(unsigned n) { + SASSERT(n > 2); + return m_force_dsorting || + (!m_disable_dsorting && n < 10 && vc_dsorting(n, n) < vc_sorting_rec(n)); + } + + void smerge(unsigned c, + unsigned a, literal const* as, + unsigned b, literal const* bs, + literal_vector& out) { + TRACE("pb", tout << "smerge: c:" << c << " a:" << a << " b:" << b << "\n";); + if (a == 1 && b == 1 && c == 1) { + literal y = max(as[0], bs[0]); + if (m_t != GE) { + // x1 <= max(x1,x2) + // x2 <= max(x1,x2) + add_clause(~as[0], y); + add_clause(~bs[0], y); + } + if (m_t != LE) { + // max(x1,x2) <= x1, x2 + add_clause(~y, as[0], bs[0]); + } + out.push_back(y); + } + else if (a == 0) { + out.append(std::min(c, b), bs); + } + else if (b == 0) { + out.append(std::min(c, a), as); + } + else if (a > c) { + smerge(c, c, as, b, bs, out); + } + else if (b > c) { + smerge(c, a, as, c, bs, out); + } + else if (a + b <= c) { + merge(a, as, b, bs, out); + } + else if (use_dsmerge(a, b, c)) { + dsmerge(c, a, as, b, bs, out); + } + else { + literal_vector even_a, odd_a; + literal_vector even_b, odd_b; + literal_vector out1, out2; + split(a, as, even_a, odd_a); + split(b, bs, even_b, odd_b); + SASSERT(!even_a.empty()); + SASSERT(!even_b.empty()); + unsigned c1, c2; + if (even(c)) { + c1 = 1 + c/2; c2 = c/2; + } + else { + c1 = (c + 1)/2; c2 = (c - 1)/2; + } + smerge(c1, even_a.size(), even_a.c_ptr(), + even_b.size(), even_b.c_ptr(), out1); + smerge(c2, odd_a.size(), odd_a.c_ptr(), + odd_b.size(), odd_b.c_ptr(), out2); + SASSERT(out1.size() == std::min(even_a.size()+even_b.size(), c1)); + SASSERT(out2.size() == std::min(odd_a.size()+odd_b.size(), c2)); + literal y; + if (even(c)) { + literal z1 = out1.back(); + literal z2 = out2.back(); + out1.pop_back(); + out2.pop_back(); + y = max(z1, z2); + if (m_t != GE) { + add_clause(~z1, y); + add_clause(~z2, y); + } + if (m_t != LE) { + add_clause(~y, z1, z2); + } + } + interleave(out1, out2, out); + if (even(c)) { + out.push_back(y); + } + } + TRACE("pb", tout << "smerge: c:" << c << " a:" << a << " b:" << b << "\n"; + pp(tout << "a:", a, as) << "\n"; + pp(tout << "b:", b, bs) << "\n"; + pp(tout << "out:", out) << "\n"; + ); + SASSERT(out.size() == std::min(a + b, c)); + } + + vc vc_smerge(unsigned a, unsigned b, unsigned c) { + if (a == 1 && b == 1 && c == 1) { + vc v(1,0); + if (m_t != GE) v = v + vc(0, 2); + if (m_t != LE) v = v + vc(0, 1); + return v; + } + if (a == 0 || b == 0) return vc(0, 0); + if (a > c) return vc_smerge(c, b, c); + if (b > c) return vc_smerge(a, c, c); + if (a + b <= c) return vc_merge(a, b); + if (use_dsmerge(a, b, c)) return vc_dsmerge(a, b, c); + return vc_smerge_rec(a, b, c); + } + vc vc_smerge_rec(unsigned a, unsigned b, unsigned c) { + return + vc_smerge(ceil2(a), ceil2(b), even(c)?(1+c/2):((c+1)/2)) + + vc_smerge(floor2(a), floor2(b), even(c)?(c/2):((c-1)/2)) + + vc_interleave(ceil2(a)+ceil2(b),floor2(a)+floor2(b)) + + vc(1, 0) + + ((m_t != GE)?vc(0, 2):vc(0, 0)) + + ((m_t != LE)?vc(0, 1):vc(0, 0)); + } + bool use_dsmerge(unsigned a, unsigned b, unsigned c) { + return + m_force_dsmerge || + (!m_disable_dsmerge && + a < (1 << 15) && b < (1 << 15) && + vc_dsmerge(a, b, a + b) < vc_smerge_rec(a, b, c)); + } + + void dsmerge( + unsigned c, + unsigned a, literal const* as, + unsigned b, literal const* bs, + literal_vector& out) { + TRACE("pb", tout << "dsmerge: c:" << c << " a:" << a << " b:" << b << "\n";); + SASSERT(a <= c); + SASSERT(b <= c); + SASSERT(a + b >= c); + for (unsigned i = 0; i < c; ++i) { + out.push_back(fresh()); + } + if (m_t != GE) { + for (unsigned i = 0; i < a; ++i) { + add_clause(~as[i], out[i]); + } + for (unsigned i = 0; i < b; ++i) { + add_clause(~bs[i], out[i]); + } + for (unsigned i = 1; i <= a; ++i) { + for (unsigned j = 1; j <= b && i + j <= c; ++j) { + add_clause(~as[i-1],~bs[j-1],out[i+j-1]); + } + } + } + if (m_t != LE) { + for (unsigned k = 1; k <= c; ++k) { + literal_vector ls; + ls.push_back(~out[k-1]); + if (k <= a) { + ls.push_back(as[k-1]); + } + if (k <= b) { + ls.push_back(bs[k-1]); + } + for (unsigned i = 1; i <= std::min(a,k-1); ++i) { + if (k + 1 - i <= b) { + ls.push_back(as[i-1]); + ls.push_back(bs[k-i]); + add_clause(ls.size(), ls.c_ptr()); + ls.pop_back(); + ls.pop_back(); + } + } + } + } + } + vc vc_dsmerge(unsigned a, unsigned b, unsigned c) { + vc v(c, 0); + if (m_t != GE) { + v = v + vc(0, a + b + std::min(a, c)*std::min(b, c)/2); + } + if (m_t != LE) { + v = v + vc(0, std::min(a, c)*std::min(b, c)/2); + } + return v; + } + + + void dsorting(unsigned m, unsigned n, literal const* xs, + literal_vector& out) { + TRACE("pb", tout << "dsorting m: " << m << " n: " << n << "\n";); + SASSERT(m <= n); + literal_vector lits; + for (unsigned i = 0; i < m; ++i) { + out.push_back(fresh()); + } + if (m_t != GE) { + for (unsigned k = 1; k <= m; ++k) { + lits.push_back(out[k-1]); + add_subset(true, k, 0, lits, n, xs); + lits.pop_back(); + } + } + if (m_t != LE) { + for (unsigned k = 1; k <= m; ++k) { + lits.push_back(~out[k-1]); + add_subset(false, n-k+1, 0, lits, n, xs); + lits.pop_back(); + } + } + } + vc vc_dsorting(unsigned m, unsigned n) { + SASSERT(m <= n && n < 10); + vc v(m, 0); + if (m_t != GE) { + v = v + vc(0, power2(n-1)); + } + if (m_t != LE) { + v = v + vc(0, power2(n-1)); + } + return v; + } + + void add_subset(bool polarity, unsigned k, unsigned offset, literal_vector& lits, + unsigned n, literal const* xs) { + TRACE("pb", tout << "k:" << k << " offset: " << offset << " n: " << n << " "; + pp(tout, lits) << "\n";); + SASSERT(k + offset <= n); + if (k == 0) { + add_clause(lits.size(), lits.c_ptr()); + return; + } + for (unsigned i = offset; i < n - k + 1; ++i) { + lits.push_back(polarity?~xs[i]:xs[i]); + add_subset(polarity, k-1, i+1, lits, n, xs); + lits.pop_back(); + } + } + }; + #endif