3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-14 04:48:45 +00:00
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2018-04-13 16:22:36 -07:00
parent d57bca8f8c
commit f7e49501af
4 changed files with 109 additions and 57 deletions

View file

@ -56,9 +56,7 @@ struct pb2bv_rewriter::imp {
rational m_k;
vector<rational> m_coeffs;
bool m_keep_cardinality_constraints;
bool m_keep_pb_constraints;
bool m_pb_num_system;
bool m_pb_totalizer;
symbol m_pb_solver;
unsigned m_min_arity;
template<lbool is_le>
@ -85,7 +83,7 @@ struct pb2bv_rewriter::imp {
struct compare_coeffs {
bool operator()(ca const& x, ca const& y) const {
return x.first < y.first;
return x.first > y.first;
}
};
@ -126,11 +124,12 @@ struct pb2bv_rewriter::imp {
if (i + 1 < sz && !m_coeffs[i+1].is_neg()) tout << "+ ";
}
switch (is_le) {
case l_true: tout << "<= "; break;
case l_true: tout << "<= "; break;
case l_undef: tout << "= "; break;
case l_false: tout << ">= "; break;
}
tout << k << "\n";);
if (k.is_zero()) {
if (is_le != l_false) {
return expr_ref(m.mk_not(::mk_or(m_args)), m);
@ -143,7 +142,7 @@ struct pb2bv_rewriter::imp {
return expr_ref((is_le == l_false)?m.mk_true():m.mk_false(), m);
}
if (m_pb_totalizer) {
if (m_pb_solver == "totalizer") {
expr_ref result(m);
switch (is_le) {
case l_true: if (mk_le_tot(sz, args, k, result)) return result; else break;
@ -152,7 +151,7 @@ struct pb2bv_rewriter::imp {
}
}
if (m_pb_num_system) {
if (m_pb_solver == "sorting") {
expr_ref result(m);
switch (is_le) {
case l_true: if (mk_le(sz, args, k, result)) return result; else break;
@ -161,6 +160,15 @@ struct pb2bv_rewriter::imp {
}
}
if (m_pb_solver == "segmented") {
expr_ref result(m);
switch (is_le) {
case l_true: return mk_seg_le(k);
case l_false: return mk_seg_ge(k);
case l_undef: break;
}
}
// fall back to divide and conquer encoding.
SASSERT(k.is_pos());
expr_ref zero(m), bound(m);
@ -486,14 +494,82 @@ struct pb2bv_rewriter::imp {
return true;
}
expr_ref mk_and(expr_ref& a, expr_ref& b) {
/**
\brief Segment based encoding.
The PB terms are partitoned into segments, such that each segment contains arguments with the same cofficient.
The segments are sorted, such that the segment with highest coefficient is first.
Then for each segment create circuits based on sorting networks the arguments of the segment.
*/
expr_ref mk_seg_ge(rational const& k) {
rational bound(-k);
for (unsigned i = 0; i < m_args.size(); ++i) {
m_args[i] = mk_not(m_args[i].get());
bound += m_coeffs[i];
}
return mk_seg_le(bound);
}
expr_ref mk_seg_le(rational const& k) {
sort_args();
unsigned sz = m_args.size();
expr* const* args = m_args.c_ptr();
// Create sorted entries.
vector<ptr_vector<expr>> outs;
vector<rational> coeffs;
for (unsigned i = 0, seg_size = 0; i < sz; i += seg_size) {
seg_size = segment_size(i);
ptr_vector<expr> out;
m_sort.sorting(seg_size, args + i, out);
out.push_back(m.mk_false());
outs.push_back(out);
coeffs.push_back(m_coeffs[i]);
}
return mk_seg_le_rec(outs, coeffs, 0, k);
}
expr_ref mk_seg_le_rec(vector<ptr_vector<expr>> const& outs, vector<rational> const& coeffs, unsigned i, rational const& k) {
rational const& c = coeffs[i];
ptr_vector<expr> const& out = outs[i];
if (k.is_neg()) {
return expr_ref(m.mk_false(), m);
}
if (i == outs.size()) {
return expr_ref(m.mk_true(), m);
}
if (i + 1 == outs.size() && k >= rational(out.size()-1)*c) {
return expr_ref(m.mk_true(), m);
}
expr_ref_vector fmls(m);
fmls.push_back(m.mk_implies(m.mk_not(out[0]), mk_seg_le_rec(outs, coeffs, i + 1, k)));
rational k1;
for (unsigned j = 0; j + 1 < out.size(); ++j) {
k1 = k - rational(j+1)*c;
if (k1.is_neg()) {
fmls.push_back(m.mk_not(out[j]));
break;
}
fmls.push_back(m.mk_implies(m.mk_and(out[j], m.mk_not(out[j+1])), mk_seg_le_rec(outs, coeffs, i + 1, k1)));
}
return ::mk_and(fmls);
}
// The number of arguments with the same coefficient.
unsigned segment_size(unsigned start) const {
unsigned i = start;
while (i < m_args.size() && m_coeffs[i] == m_coeffs[start]) ++i;
return i - start;
}
expr_ref mk_and(expr_ref& a, expr_ref& b) {
if (m.is_true(a)) return b;
if (m.is_true(b)) return a;
if (m.is_false(a)) return a;
if (m.is_false(b)) return b;
return expr_ref(m.mk_and(a, b), m);
}
expr_ref mk_or(expr_ref& a, expr_ref& b) {
if (m.is_true(a)) return a;
if (m.is_true(b)) return b;
@ -607,12 +683,12 @@ struct pb2bv_rewriter::imp {
m_trail(m),
m_args(m),
m_keep_cardinality_constraints(false),
m_keep_pb_constraints(false),
m_pb_num_system(false),
m_pb_totalizer(false),
m_pb_solver(symbol("solver")),
m_min_arity(9)
{}
void set_pb_solver(symbol const& s) { m_pb_solver = s; }
bool mk_app(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
if (f->get_family_id() == pb.get_family_id() && mk_pb(full, f, sz, args, result)) {
// skip
@ -756,13 +832,13 @@ struct pb2bv_rewriter::imp {
result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args);
++m_imp.m_compile_card;
}
else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_keep_pb_constraints) {
else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_pb_solver == "solver") {
return false;
}
else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_keep_pb_constraints) {
else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_pb_solver == "solver") {
return false;
}
else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_keep_pb_constraints) {
else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_pb_solver == "solver") {
return false;
}
else {
@ -811,17 +887,6 @@ struct pb2bv_rewriter::imp {
m_keep_cardinality_constraints = f;
}
void keep_pb_constraints(bool f) {
m_keep_pb_constraints = f;
}
void pb_num_system(bool f) {
m_pb_num_system = f;
}
void pb_totalizer(bool f) {
m_pb_totalizer = f;
}
void set_at_most1(sorting_network_encoding enc) { m_sort.cfg().m_encoding = enc; }
};
@ -836,9 +901,7 @@ struct pb2bv_rewriter::imp {
}
card2bv_rewriter_cfg(imp& i, ast_manager & m):m_r(i, m) {}
void keep_cardinality_constraints(bool f) { m_r.keep_cardinality_constraints(f); }
void keep_pb_constraints(bool f) { m_r.keep_pb_constraints(f); }
void pb_num_system(bool f) { m_r.pb_num_system(f); }
void pb_totalizer(bool f) { m_r.pb_totalizer(f); }
void set_pb_solver(symbol const& s) { m_r.set_pb_solver(s); }
void set_at_most1(sorting_network_encoding enc) { m_r.set_at_most1(enc); }
};
@ -850,9 +913,7 @@ struct pb2bv_rewriter::imp {
rewriter_tpl<card2bv_rewriter_cfg>(m, false, m_cfg),
m_cfg(i, m) {}
void keep_cardinality_constraints(bool f) { m_cfg.keep_cardinality_constraints(f); }
void keep_pb_constraints(bool f) { m_cfg.keep_pb_constraints(f); }
void pb_num_system(bool f) { m_cfg.pb_num_system(f); }
void pb_totalizer(bool f) { m_cfg.pb_totalizer(f); }
void set_pb_solver(symbol const& s) { m_cfg.set_pb_solver(s); }
void set_at_most1(sorting_network_encoding e) { m_cfg.set_at_most1(e); }
void rewrite(expr* e, expr_ref& r, proof_ref& p) {
expr_ref ee(e, m());
@ -875,26 +936,15 @@ struct pb2bv_rewriter::imp {
gparams::get_module("sat").get_bool("cardinality.solver", false);
}
bool keep_pb() const {
symbol pb_solver() const {
params_ref const& p = m_params;
return
p.get_bool("keep_pb_constraints", false) ||
p.get_bool("sat.pb.solver", false) ||
p.get_bool("pb.solver", false) ||
gparams::get_module("sat").get_sym("pb.solver", symbol()) == symbol("solver") ;
symbol s = p.get_sym("sat.pb.solver", symbol());
if (s != symbol()) return s;
s = p.get_sym("pb.solver", symbol());
if (s != symbol()) return s;
return gparams::get_module("sat").get_sym("pb.solver", symbol("solver"));
}
bool pb_num_system() const {
return m_params.get_bool("pb_num_system", false) ||
gparams::get_module("sat").get_sym("pb.solver", symbol()) == symbol("sorting");
}
bool pb_totalizer() const {
return m_params.get_bool("pb_totalizer", false) ||
gparams::get_module("sat").get_sym("pb.solver", symbol()) == symbol("totalizer");
}
sorting_network_encoding atmost1_encoding() const {
symbol enc = m_params.get_sym("atmost1_encoding", symbol());
if (enc == symbol()) {
@ -920,16 +970,12 @@ struct pb2bv_rewriter::imp {
void updt_params(params_ref const & p) {
m_params.append(p);
m_rw.keep_cardinality_constraints(keep_cardinality());
m_rw.keep_pb_constraints(keep_pb());
m_rw.pb_num_system(pb_num_system());
m_rw.pb_totalizer(pb_totalizer());
m_rw.set_pb_solver(pb_solver());
m_rw.set_at_most1(atmost1_encoding());
}
void collect_param_descrs(param_descrs& r) const {
r.insert("keep_cardinality_constraints", CPK_BOOL, "(default: true) retain cardinality constraints (don't bit-blast them) and use built-in cardinality solver");
r.insert("keep_pb_constraints", CPK_BOOL, "(default: true) retain pb constraints (don't bit-blast them) and use built-in pb solver");
r.insert("pb_num_system", CPK_BOOL, "(default: false) use pb number system encoding");
r.insert("pb_totalizer", CPK_BOOL, "(default: false) use pb totalizer encoding");
r.insert("pb.solver", CPK_SYMBOL, "(default: solver) retain pb constraints (don't bit-blast them) and use built-in pb solver");
}
unsigned get_num_steps() const { return m_rw.get_num_steps(); }

View file

@ -178,8 +178,10 @@ namespace sat {
m_pb_solver = PB_TOTALIZER;
else if (s == symbol("solver"))
m_pb_solver = PB_SOLVER;
else if (s == symbol("segmented"))
m_pb_solver = PB_SEGMENTED;
else
throw sat_param_exception("invalid PB solver: solver, totalizer, circuit, sorting");
throw sat_param_exception("invalid PB solver: solver, totalizer, circuit, sorting, segmented");
m_card_solver = p.cardinality_solver();

View file

@ -54,7 +54,8 @@ namespace sat {
PB_SOLVER,
PB_CIRCUIT,
PB_SORTING,
PB_TOTALIZER
PB_TOTALIZER,
PB_SEGMENTED
};
enum reward_t {

View file

@ -281,9 +281,12 @@ public:
m_params.append(p);
sat_params p1(p);
m_params.set_bool("keep_cardinality_constraints", p1.cardinality_solver());
m_params.set_sym("pb.solver", p1.pb_solver());
m_params.set_bool("keep_pb_constraints", m_solver.get_config().m_pb_solver == sat::PB_SOLVER);
m_params.set_bool("pb_num_system", m_solver.get_config().m_pb_solver == sat::PB_SORTING);
m_params.set_bool("pb_totalizer", m_solver.get_config().m_pb_solver == sat::PB_TOTALIZER);
m_params.set_bool("xor_solver", p1.xor_solver());
m_solver.updt_params(m_params);
m_solver.set_incremental(is_incremental() && !override_incremental());