diff --git a/src/ast/rewriter/pb2bv_rewriter.cpp b/src/ast/rewriter/pb2bv_rewriter.cpp index 3a134a13c..5cef19234 100644 --- a/src/ast/rewriter/pb2bv_rewriter.cpp +++ b/src/ast/rewriter/pb2bv_rewriter.cpp @@ -53,6 +53,7 @@ struct pb2bv_rewriter::imp { rational m_k; vector m_coeffs; bool m_keep_cardinality_constraints; + bool m_keep_pb_constraints; unsigned m_min_arity; template @@ -565,6 +566,7 @@ struct pb2bv_rewriter::imp { m_trail(m), m_args(m), m_keep_cardinality_constraints(false), + m_keep_pb_constraints(false), m_min_arity(2) {} @@ -701,11 +703,33 @@ struct pb2bv_rewriter::imp { if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false; result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args); } + else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_keep_pb_constraints) { + return false; + } + else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_keep_pb_constraints) { + return false; + } + else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_keep_pb_constraints) { + return false; + } else { result = mk_bv(f, sz, args); } return true; } + + bool has_small_coefficients(func_decl* f) { + unsigned sz = f->get_arity(); + unsigned sum = 0; + for (unsigned i = 0; i < sz; ++i) { + rational c = pb.get_coeff(f, i); + if (!c.is_unsigned()) return false; + unsigned sum1 = sum + c.get_unsigned(); + if (sum1 < sum) return false; + sum = sum1; + } + return true; + } // definitions used for sorting network literal mk_false() { return m.mk_false(); } @@ -733,6 +757,10 @@ struct pb2bv_rewriter::imp { void keep_cardinality_constraints(bool f) { m_keep_cardinality_constraints = f; } + + void keep_pb_constraints(bool f) { + m_keep_pb_constraints = f; + } }; struct card2bv_rewriter_cfg : public default_rewriter_cfg { @@ -745,6 +773,7 @@ struct pb2bv_rewriter::imp { } card2bv_rewriter_cfg(imp& i, ast_manager & m):m_r(i, m) {} void keep_cardinality_constraints(bool f) { m_r.keep_cardinality_constraints(f); } + void keep_pb_constraints(bool f) { m_r.keep_pb_constraints(f); } }; class card_pb_rewriter : public rewriter_tpl { @@ -754,6 +783,7 @@ struct pb2bv_rewriter::imp { rewriter_tpl(m, false, m_cfg), m_cfg(i, m) {} void keep_cardinality_constraints(bool f) { m_cfg.keep_cardinality_constraints(f); } + void keep_pb_constraints(bool f) { m_cfg.keep_pb_constraints(f); } }; card_pb_rewriter m_rw; @@ -764,14 +794,17 @@ struct pb2bv_rewriter::imp { m_num_translated(0), m_rw(*this, m) { m_rw.keep_cardinality_constraints(p.get_bool("keep_cardinality_constraints", false)); + m_rw.keep_pb_constraints(p.get_bool("keep_pb_constraints", false)); } void updt_params(params_ref const & p) { m_params.append(p); m_rw.keep_cardinality_constraints(m_params.get_bool("keep_cardinality_constraints", false)); + m_rw.keep_pb_constraints(m_params.get_bool("keep_pb_constraints", false)); } void collect_param_descrs(param_descrs& r) const { r.insert("keep_cardinality_constraints", CPK_BOOL, "(default: true) retain cardinality constraints (don't bit-blast them) and use built-in cardinality solver"); + r.insert("keep_pb_constraints", CPK_BOOL, "(default: true) retain pb constraints (don't bit-blast them) and use built-in pb solver"); } unsigned get_num_steps() const { return m_rw.get_num_steps(); } diff --git a/src/sat/card_extension.cpp b/src/sat/card_extension.cpp index 30206a441..bd32a7c40 100644 --- a/src/sat/card_extension.cpp +++ b/src/sat/card_extension.cpp @@ -49,6 +49,9 @@ namespace sat { m_max_sum(0) { for (unsigned i = 0; i < wlits.size(); ++i) { m_wlits[i] = wlits[i]; + if (m_max_sum + wlits[i].first < m_max_sum) { + throw default_exception("addition of pb coefficients overflows"); + } m_max_sum += wlits[i].first; } } diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index a13a8e8b5..226c79642 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -27,6 +27,7 @@ def_module_params('sat', ('drat.file', SYMBOL, '', 'file to dump DRAT proofs'), ('drat.check', BOOL, False, 'build up internal proof and check'), ('cardinality.solver', BOOL, False, 'use cardinality solver'), + ('pb.solver', BOOL, False, 'use pb solver'), ('xor.solver', BOOL, False, 'use xor solver'), ('local_search_threads', UINT, 0, 'number of local search threads to find satisfiable solution'), ('local_search', BOOL, False, 'use local search instead of CDCL'), diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index 0c0f82537..0d12c0a94 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -216,7 +216,8 @@ public: m_params.append(p); sat_params p1(p); m_params.set_bool("elim_vars", false); - m_params.set_bool("keep_cardinality_constraints", p1.cardinality_solver()); + m_params.set_bool("keep_cardinality_constraints", p1.pb_solver() || p1.cardinality_solver()); + m_params.set_bool("keep_pb_constraints", p1.pb_solver()); m_params.set_bool("xor_solver", p1.xor_solver()); m_solver.updt_params(m_params); m_optimize_model = m_params.get_bool("optimize_model", false); diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 971843d55..9d07ec3e6 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -419,6 +419,103 @@ struct goal2sat::imp { } } + typedef std::pair wliteral; + + void check_unsigned(rational const& c) { + if (!c.is_unsigned()) { + throw default_exception("unsigned coefficient expected"); + } + } + + void convert_to_wlits(app* t, sat::literal_vector const& lits, svector& wlits) { + for (unsigned i = 0; i < lits.size(); ++i) { + rational c = pb.get_coeff(t, i); + check_unsigned(c); + wlits.push_back(std::make_pair(c.get_unsigned(), lits[i])); + } + } + + void convert_pb_args(app* t, svector& wlits) { + sat::literal_vector lits; + convert_pb_args(t->get_num_args(), lits); + convert_to_wlits(t, lits, wlits); + } + + void convert_pb_ge(app* t, bool root, bool sign) { + rational k = pb.get_k(t); + check_unsigned(k); + svector wlits; + convert_pb_args(t, wlits); + unsigned sz = m_result_stack.size(); + if (root) { + m_result_stack.reset(); + m_ext->add_pb_ge(sat::null_bool_var, wlits, k.get_unsigned()); + } + else { + sat::bool_var v = m_solver.mk_var(true); + sat::literal lit(v, sign); + m_ext->add_pb_ge(v, wlits, k.get_unsigned()); + TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); + m_result_stack.shrink(sz - t->get_num_args()); + m_result_stack.push_back(lit); + } + } + + void convert_pb_le(app* t, bool root, bool sign) { + rational k = pb.get_k(t); + k.neg(); + svector wlits; + convert_pb_args(t, wlits); + for (unsigned i = 0; i < wlits.size(); ++i) { + wlits[i].second.neg(); + k += rational(wlits[i].first); + } + check_unsigned(k); + unsigned sz = m_result_stack.size(); + if (root) { + m_result_stack.reset(); + m_ext->add_pb_ge(sat::null_bool_var, wlits, k.get_unsigned()); + } + else { + sat::bool_var v = m_solver.mk_var(true); + sat::literal lit(v, sign); + m_ext->add_pb_ge(v, wlits, k.get_unsigned()); + TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); + m_result_stack.shrink(sz - t->get_num_args()); + m_result_stack.push_back(lit); + } + } + + void convert_pb_eq(app* t, bool root, bool sign) { + rational k = pb.get_k(t); + SASSERT(k.is_unsigned()); + svector wlits; + convert_pb_args(t, wlits); + sat::bool_var v1 = root ? sat::null_bool_var : m_solver.mk_var(true); + sat::bool_var v2 = root ? sat::null_bool_var : m_solver.mk_var(true); + m_ext->add_pb_ge(v1, wlits, k.get_unsigned()); + k.neg(); + for (unsigned i = 0; i < wlits.size(); ++i) { + wlits[i].second.neg(); + k += rational(wlits[i].first); + } + check_unsigned(k); + m_ext->add_pb_ge(v2, wlits, k.get_unsigned()); + if (root) { + m_result_stack.reset(); + } + else { + sat::literal l1(v1, false), l2(v2, false); + sat::bool_var v = m_solver.mk_var(); + sat::literal l(v, false); + mk_clause(~l, l1); + mk_clause(~l, l2); + mk_clause(~l1, ~l2, l); + m_result_stack.shrink(m_result_stack.size() - t->get_num_args()); + m_result_stack.push_back(l); + } + } + void convert_at_least_k(app* t, rational k, bool root, bool sign) { SASSERT(k.is_unsigned()); sat::literal_vector lits; @@ -529,16 +626,28 @@ struct goal2sat::imp { convert_at_least_k(t, pb.get_k(t), root, sign); break; case OP_PB_LE: - SASSERT(pb.has_unit_coefficients(t)); - convert_at_most_k(t, pb.get_k(t), root, sign); + if (pb.has_unit_coefficients(t)) { + convert_at_most_k(t, pb.get_k(t), root, sign); + } + else { + convert_pb_le(t, root, sign); + } break; case OP_PB_GE: - SASSERT(pb.has_unit_coefficients(t)); - convert_at_least_k(t, pb.get_k(t), root, sign); + if (pb.has_unit_coefficients(t)) { + convert_at_least_k(t, pb.get_k(t), root, sign); + } + else { + convert_pb_ge(t, root, sign); + } break; case OP_PB_EQ: - SASSERT(pb.has_unit_coefficients(t)); - convert_eq_k(t, pb.get_k(t), root, sign); + if (pb.has_unit_coefficients(t)) { + convert_eq_k(t, pb.get_k(t), root, sign); + } + else { + convert_pb_eq(t, root, sign); + } break; default: UNREACHABLE();