3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-15 13:28:47 +00:00
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2018-04-13 16:22:36 -07:00
parent d57bca8f8c
commit f7e49501af
4 changed files with 109 additions and 57 deletions

View file

@ -56,9 +56,7 @@ struct pb2bv_rewriter::imp {
rational m_k; rational m_k;
vector<rational> m_coeffs; vector<rational> m_coeffs;
bool m_keep_cardinality_constraints; bool m_keep_cardinality_constraints;
bool m_keep_pb_constraints; symbol m_pb_solver;
bool m_pb_num_system;
bool m_pb_totalizer;
unsigned m_min_arity; unsigned m_min_arity;
template<lbool is_le> template<lbool is_le>
@ -85,7 +83,7 @@ struct pb2bv_rewriter::imp {
struct compare_coeffs { struct compare_coeffs {
bool operator()(ca const& x, ca const& y) const { bool operator()(ca const& x, ca const& y) const {
return x.first < y.first; return x.first > y.first;
} }
}; };
@ -126,11 +124,12 @@ struct pb2bv_rewriter::imp {
if (i + 1 < sz && !m_coeffs[i+1].is_neg()) tout << "+ "; if (i + 1 < sz && !m_coeffs[i+1].is_neg()) tout << "+ ";
} }
switch (is_le) { switch (is_le) {
case l_true: tout << "<= "; break; case l_true: tout << "<= "; break;
case l_undef: tout << "= "; break; case l_undef: tout << "= "; break;
case l_false: tout << ">= "; break; case l_false: tout << ">= "; break;
} }
tout << k << "\n";); tout << k << "\n";);
if (k.is_zero()) { if (k.is_zero()) {
if (is_le != l_false) { if (is_le != l_false) {
return expr_ref(m.mk_not(::mk_or(m_args)), m); return expr_ref(m.mk_not(::mk_or(m_args)), m);
@ -143,7 +142,7 @@ struct pb2bv_rewriter::imp {
return expr_ref((is_le == l_false)?m.mk_true():m.mk_false(), m); return expr_ref((is_le == l_false)?m.mk_true():m.mk_false(), m);
} }
if (m_pb_totalizer) { if (m_pb_solver == "totalizer") {
expr_ref result(m); expr_ref result(m);
switch (is_le) { switch (is_le) {
case l_true: if (mk_le_tot(sz, args, k, result)) return result; else break; case l_true: if (mk_le_tot(sz, args, k, result)) return result; else break;
@ -152,7 +151,7 @@ struct pb2bv_rewriter::imp {
} }
} }
if (m_pb_num_system) { if (m_pb_solver == "sorting") {
expr_ref result(m); expr_ref result(m);
switch (is_le) { switch (is_le) {
case l_true: if (mk_le(sz, args, k, result)) return result; else break; case l_true: if (mk_le(sz, args, k, result)) return result; else break;
@ -161,6 +160,15 @@ struct pb2bv_rewriter::imp {
} }
} }
if (m_pb_solver == "segmented") {
expr_ref result(m);
switch (is_le) {
case l_true: return mk_seg_le(k);
case l_false: return mk_seg_ge(k);
case l_undef: break;
}
}
// fall back to divide and conquer encoding. // fall back to divide and conquer encoding.
SASSERT(k.is_pos()); SASSERT(k.is_pos());
expr_ref zero(m), bound(m); expr_ref zero(m), bound(m);
@ -486,14 +494,82 @@ struct pb2bv_rewriter::imp {
return true; return true;
} }
expr_ref mk_and(expr_ref& a, expr_ref& b) { /**
\brief Segment based encoding.
The PB terms are partitoned into segments, such that each segment contains arguments with the same cofficient.
The segments are sorted, such that the segment with highest coefficient is first.
Then for each segment create circuits based on sorting networks the arguments of the segment.
*/
expr_ref mk_seg_ge(rational const& k) {
rational bound(-k);
for (unsigned i = 0; i < m_args.size(); ++i) {
m_args[i] = mk_not(m_args[i].get());
bound += m_coeffs[i];
}
return mk_seg_le(bound);
}
expr_ref mk_seg_le(rational const& k) {
sort_args();
unsigned sz = m_args.size();
expr* const* args = m_args.c_ptr();
// Create sorted entries.
vector<ptr_vector<expr>> outs;
vector<rational> coeffs;
for (unsigned i = 0, seg_size = 0; i < sz; i += seg_size) {
seg_size = segment_size(i);
ptr_vector<expr> out;
m_sort.sorting(seg_size, args + i, out);
out.push_back(m.mk_false());
outs.push_back(out);
coeffs.push_back(m_coeffs[i]);
}
return mk_seg_le_rec(outs, coeffs, 0, k);
}
expr_ref mk_seg_le_rec(vector<ptr_vector<expr>> const& outs, vector<rational> const& coeffs, unsigned i, rational const& k) {
rational const& c = coeffs[i];
ptr_vector<expr> const& out = outs[i];
if (k.is_neg()) {
return expr_ref(m.mk_false(), m);
}
if (i == outs.size()) {
return expr_ref(m.mk_true(), m);
}
if (i + 1 == outs.size() && k >= rational(out.size()-1)*c) {
return expr_ref(m.mk_true(), m);
}
expr_ref_vector fmls(m);
fmls.push_back(m.mk_implies(m.mk_not(out[0]), mk_seg_le_rec(outs, coeffs, i + 1, k)));
rational k1;
for (unsigned j = 0; j + 1 < out.size(); ++j) {
k1 = k - rational(j+1)*c;
if (k1.is_neg()) {
fmls.push_back(m.mk_not(out[j]));
break;
}
fmls.push_back(m.mk_implies(m.mk_and(out[j], m.mk_not(out[j+1])), mk_seg_le_rec(outs, coeffs, i + 1, k1)));
}
return ::mk_and(fmls);
}
// The number of arguments with the same coefficient.
unsigned segment_size(unsigned start) const {
unsigned i = start;
while (i < m_args.size() && m_coeffs[i] == m_coeffs[start]) ++i;
return i - start;
}
expr_ref mk_and(expr_ref& a, expr_ref& b) {
if (m.is_true(a)) return b; if (m.is_true(a)) return b;
if (m.is_true(b)) return a; if (m.is_true(b)) return a;
if (m.is_false(a)) return a; if (m.is_false(a)) return a;
if (m.is_false(b)) return b; if (m.is_false(b)) return b;
return expr_ref(m.mk_and(a, b), m); return expr_ref(m.mk_and(a, b), m);
} }
expr_ref mk_or(expr_ref& a, expr_ref& b) { expr_ref mk_or(expr_ref& a, expr_ref& b) {
if (m.is_true(a)) return a; if (m.is_true(a)) return a;
if (m.is_true(b)) return b; if (m.is_true(b)) return b;
@ -607,12 +683,12 @@ struct pb2bv_rewriter::imp {
m_trail(m), m_trail(m),
m_args(m), m_args(m),
m_keep_cardinality_constraints(false), m_keep_cardinality_constraints(false),
m_keep_pb_constraints(false), m_pb_solver(symbol("solver")),
m_pb_num_system(false),
m_pb_totalizer(false),
m_min_arity(9) m_min_arity(9)
{} {}
void set_pb_solver(symbol const& s) { m_pb_solver = s; }
bool mk_app(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) { bool mk_app(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
if (f->get_family_id() == pb.get_family_id() && mk_pb(full, f, sz, args, result)) { if (f->get_family_id() == pb.get_family_id() && mk_pb(full, f, sz, args, result)) {
// skip // skip
@ -756,13 +832,13 @@ struct pb2bv_rewriter::imp {
result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args); result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args);
++m_imp.m_compile_card; ++m_imp.m_compile_card;
} }
else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_keep_pb_constraints) { else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_pb_solver == "solver") {
return false; return false;
} }
else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_keep_pb_constraints) { else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_pb_solver == "solver") {
return false; return false;
} }
else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_keep_pb_constraints) { else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_pb_solver == "solver") {
return false; return false;
} }
else { else {
@ -811,17 +887,6 @@ struct pb2bv_rewriter::imp {
m_keep_cardinality_constraints = f; m_keep_cardinality_constraints = f;
} }
void keep_pb_constraints(bool f) {
m_keep_pb_constraints = f;
}
void pb_num_system(bool f) {
m_pb_num_system = f;
}
void pb_totalizer(bool f) {
m_pb_totalizer = f;
}
void set_at_most1(sorting_network_encoding enc) { m_sort.cfg().m_encoding = enc; } void set_at_most1(sorting_network_encoding enc) { m_sort.cfg().m_encoding = enc; }
}; };
@ -836,9 +901,7 @@ struct pb2bv_rewriter::imp {
} }
card2bv_rewriter_cfg(imp& i, ast_manager & m):m_r(i, m) {} card2bv_rewriter_cfg(imp& i, ast_manager & m):m_r(i, m) {}
void keep_cardinality_constraints(bool f) { m_r.keep_cardinality_constraints(f); } void keep_cardinality_constraints(bool f) { m_r.keep_cardinality_constraints(f); }
void keep_pb_constraints(bool f) { m_r.keep_pb_constraints(f); } void set_pb_solver(symbol const& s) { m_r.set_pb_solver(s); }
void pb_num_system(bool f) { m_r.pb_num_system(f); }
void pb_totalizer(bool f) { m_r.pb_totalizer(f); }
void set_at_most1(sorting_network_encoding enc) { m_r.set_at_most1(enc); } void set_at_most1(sorting_network_encoding enc) { m_r.set_at_most1(enc); }
}; };
@ -850,9 +913,7 @@ struct pb2bv_rewriter::imp {
rewriter_tpl<card2bv_rewriter_cfg>(m, false, m_cfg), rewriter_tpl<card2bv_rewriter_cfg>(m, false, m_cfg),
m_cfg(i, m) {} m_cfg(i, m) {}
void keep_cardinality_constraints(bool f) { m_cfg.keep_cardinality_constraints(f); } void keep_cardinality_constraints(bool f) { m_cfg.keep_cardinality_constraints(f); }
void keep_pb_constraints(bool f) { m_cfg.keep_pb_constraints(f); } void set_pb_solver(symbol const& s) { m_cfg.set_pb_solver(s); }
void pb_num_system(bool f) { m_cfg.pb_num_system(f); }
void pb_totalizer(bool f) { m_cfg.pb_totalizer(f); }
void set_at_most1(sorting_network_encoding e) { m_cfg.set_at_most1(e); } void set_at_most1(sorting_network_encoding e) { m_cfg.set_at_most1(e); }
void rewrite(expr* e, expr_ref& r, proof_ref& p) { void rewrite(expr* e, expr_ref& r, proof_ref& p) {
expr_ref ee(e, m()); expr_ref ee(e, m());
@ -875,26 +936,15 @@ struct pb2bv_rewriter::imp {
gparams::get_module("sat").get_bool("cardinality.solver", false); gparams::get_module("sat").get_bool("cardinality.solver", false);
} }
bool keep_pb() const { symbol pb_solver() const {
params_ref const& p = m_params; params_ref const& p = m_params;
return symbol s = p.get_sym("sat.pb.solver", symbol());
p.get_bool("keep_pb_constraints", false) || if (s != symbol()) return s;
p.get_bool("sat.pb.solver", false) || s = p.get_sym("pb.solver", symbol());
p.get_bool("pb.solver", false) || if (s != symbol()) return s;
gparams::get_module("sat").get_sym("pb.solver", symbol()) == symbol("solver") ; return gparams::get_module("sat").get_sym("pb.solver", symbol("solver"));
} }
bool pb_num_system() const {
return m_params.get_bool("pb_num_system", false) ||
gparams::get_module("sat").get_sym("pb.solver", symbol()) == symbol("sorting");
}
bool pb_totalizer() const {
return m_params.get_bool("pb_totalizer", false) ||
gparams::get_module("sat").get_sym("pb.solver", symbol()) == symbol("totalizer");
}
sorting_network_encoding atmost1_encoding() const { sorting_network_encoding atmost1_encoding() const {
symbol enc = m_params.get_sym("atmost1_encoding", symbol()); symbol enc = m_params.get_sym("atmost1_encoding", symbol());
if (enc == symbol()) { if (enc == symbol()) {
@ -920,16 +970,12 @@ struct pb2bv_rewriter::imp {
void updt_params(params_ref const & p) { void updt_params(params_ref const & p) {
m_params.append(p); m_params.append(p);
m_rw.keep_cardinality_constraints(keep_cardinality()); m_rw.keep_cardinality_constraints(keep_cardinality());
m_rw.keep_pb_constraints(keep_pb()); m_rw.set_pb_solver(pb_solver());
m_rw.pb_num_system(pb_num_system());
m_rw.pb_totalizer(pb_totalizer());
m_rw.set_at_most1(atmost1_encoding()); m_rw.set_at_most1(atmost1_encoding());
} }
void collect_param_descrs(param_descrs& r) const { void collect_param_descrs(param_descrs& r) const {
r.insert("keep_cardinality_constraints", CPK_BOOL, "(default: true) retain cardinality constraints (don't bit-blast them) and use built-in cardinality solver"); r.insert("keep_cardinality_constraints", CPK_BOOL, "(default: true) retain cardinality constraints (don't bit-blast them) and use built-in cardinality solver");
r.insert("keep_pb_constraints", CPK_BOOL, "(default: true) retain pb constraints (don't bit-blast them) and use built-in pb solver"); r.insert("pb.solver", CPK_SYMBOL, "(default: solver) retain pb constraints (don't bit-blast them) and use built-in pb solver");
r.insert("pb_num_system", CPK_BOOL, "(default: false) use pb number system encoding");
r.insert("pb_totalizer", CPK_BOOL, "(default: false) use pb totalizer encoding");
} }
unsigned get_num_steps() const { return m_rw.get_num_steps(); } unsigned get_num_steps() const { return m_rw.get_num_steps(); }

View file

@ -178,8 +178,10 @@ namespace sat {
m_pb_solver = PB_TOTALIZER; m_pb_solver = PB_TOTALIZER;
else if (s == symbol("solver")) else if (s == symbol("solver"))
m_pb_solver = PB_SOLVER; m_pb_solver = PB_SOLVER;
else if (s == symbol("segmented"))
m_pb_solver = PB_SEGMENTED;
else else
throw sat_param_exception("invalid PB solver: solver, totalizer, circuit, sorting"); throw sat_param_exception("invalid PB solver: solver, totalizer, circuit, sorting, segmented");
m_card_solver = p.cardinality_solver(); m_card_solver = p.cardinality_solver();

View file

@ -54,7 +54,8 @@ namespace sat {
PB_SOLVER, PB_SOLVER,
PB_CIRCUIT, PB_CIRCUIT,
PB_SORTING, PB_SORTING,
PB_TOTALIZER PB_TOTALIZER,
PB_SEGMENTED
}; };
enum reward_t { enum reward_t {

View file

@ -281,9 +281,12 @@ public:
m_params.append(p); m_params.append(p);
sat_params p1(p); sat_params p1(p);
m_params.set_bool("keep_cardinality_constraints", p1.cardinality_solver()); m_params.set_bool("keep_cardinality_constraints", p1.cardinality_solver());
m_params.set_sym("pb.solver", p1.pb_solver());
m_params.set_bool("keep_pb_constraints", m_solver.get_config().m_pb_solver == sat::PB_SOLVER); m_params.set_bool("keep_pb_constraints", m_solver.get_config().m_pb_solver == sat::PB_SOLVER);
m_params.set_bool("pb_num_system", m_solver.get_config().m_pb_solver == sat::PB_SORTING); m_params.set_bool("pb_num_system", m_solver.get_config().m_pb_solver == sat::PB_SORTING);
m_params.set_bool("pb_totalizer", m_solver.get_config().m_pb_solver == sat::PB_TOTALIZER); m_params.set_bool("pb_totalizer", m_solver.get_config().m_pb_solver == sat::PB_TOTALIZER);
m_params.set_bool("xor_solver", p1.xor_solver()); m_params.set_bool("xor_solver", p1.xor_solver());
m_solver.updt_params(m_params); m_solver.updt_params(m_params);
m_solver.set_incremental(is_incremental() && !override_incremental()); m_solver.set_incremental(is_incremental() && !override_incremental());