3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-16 05:48:44 +00:00
This commit is contained in:
Nikolaj Bjorner 2019-02-28 11:44:45 -08:00
commit e76cea4684
6 changed files with 283 additions and 4 deletions

View file

@ -161,7 +161,6 @@ struct pb2bv_rewriter::imp {
} }
if (m_pb_solver == "segmented") { if (m_pb_solver == "segmented") {
expr_ref result(m);
switch (is_le) { switch (is_le) {
case l_true: return mk_seg_le(k); case l_true: return mk_seg_le(k);
case l_false: return mk_seg_ge(k); case l_false: return mk_seg_ge(k);
@ -169,6 +168,11 @@ struct pb2bv_rewriter::imp {
} }
} }
if (m_pb_solver == "binary_merge") {
expr_ref result = binary_merge(is_le, k);
if (result) return result;
}
// fall back to divide and conquer encoding. // fall back to divide and conquer encoding.
SASSERT(k.is_pos()); SASSERT(k.is_pos());
expr_ref zero(m), bound(m); expr_ref zero(m), bound(m);
@ -494,6 +498,37 @@ struct pb2bv_rewriter::imp {
return true; return true;
} }
/**
\brief binary merge encoding.
*/
expr_ref binary_merge(lbool is_le, rational const& k) {
expr_ref result(m);
unsigned_vector coeffs;
for (rational const& c : m_coeffs) {
if (c.is_unsigned()) {
coeffs.push_back(c.get_unsigned());
}
else {
return result;
}
}
if (!k.is_unsigned()) {
return result;
}
switch (is_le) {
case l_true:
result = m_sort.le(k.get_unsigned(), coeffs.size(), coeffs.c_ptr(), m_args.c_ptr());
break;
case l_false:
result = m_sort.ge(k.get_unsigned(), coeffs.size(), coeffs.c_ptr(), m_args.c_ptr());
break;
case l_undef:
result = m_sort.eq(k.get_unsigned(), coeffs.size(), coeffs.c_ptr(), m_args.c_ptr());
break;
}
return result;
}
/** /**
\brief Segment based encoding. \brief Segment based encoding.
The PB terms are partitoned into segments, such that each segment contains arguments with the same cofficient. The PB terms are partitoned into segments, such that each segment contains arguments with the same cofficient.

View file

@ -43,7 +43,7 @@ def_module_params('sat',
('drat.check_unsat', BOOL, False, 'build up internal proof and check'), ('drat.check_unsat', BOOL, False, 'build up internal proof and check'),
('drat.check_sat', BOOL, False, 'build up internal trace, check satisfying model'), ('drat.check_sat', BOOL, False, 'build up internal trace, check satisfying model'),
('cardinality.solver', BOOL, True, 'use cardinality solver'), ('cardinality.solver', BOOL, True, 'use cardinality solver'),
('pb.solver', SYMBOL, 'solver', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), solver (use native solver)'), ('pb.solver', SYMBOL, 'solver', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), binary_merge, segmented, solver (use native solver)'),
('xor.solver', BOOL, False, 'use xor solver'), ('xor.solver', BOOL, False, 'use xor solver'),
('cardinality.encoding', SYMBOL, 'grouped', 'encoding used for at-most-k constraints: grouped, bimander, ordered, unate, circuit'), ('cardinality.encoding', SYMBOL, 'grouped', 'encoding used for at-most-k constraints: grouped, bimander, ordered, unate, circuit'),
('pb.resolve', SYMBOL, 'cardinality', 'resolution strategy for boolean algebra solver: cardinality, rounding'), ('pb.resolve', SYMBOL, 'cardinality', 'resolution strategy for boolean algebra solver: cardinality, rounding'),

View file

@ -383,8 +383,9 @@ namespace smt {
return m_imp->next_decision(); return m_imp->next_decision();
} }
void kernel::display(std::ostream & out) const { std::ostream& kernel::display(std::ostream & out) const {
m_imp->display(out); m_imp->display(out);
return out;
} }
void kernel::collect_statistics(::statistics & st) const { void kernel::collect_statistics(::statistics & st) const {

View file

@ -237,7 +237,7 @@ namespace smt {
/** /**
\brief (For debubbing purposes) Prints the state of the kernel \brief (For debubbing purposes) Prints the state of the kernel
*/ */
void display(std::ostream & out) const; std::ostream& display(std::ostream & out) const;
/** /**
\brief Collect runtime statistics. \brief Collect runtime statistics.

View file

@ -522,7 +522,144 @@ static void tst_sorting_network(sorting_network_encoding enc) {
test_sorting5(enc); test_sorting5(enc);
} }
static void test_pb(unsigned max_w, unsigned sz, unsigned_vector& ws) {
if (ws.empty()) {
for (unsigned w = 1; w <= max_w; ++w) {
ws.push_back(w);
test_pb(max_w, sz, ws);
ws.pop_back();
}
}
else if (ws.size() < sz) {
for (unsigned w = ws.back(); w <= max_w; ++w) {
ws.push_back(w);
test_pb(max_w, sz, ws);
ws.pop_back();
}
}
else {
SASSERT(ws.size() == sz);
ast_manager m;
reg_decl_plugins(m);
expr_ref_vector xs(m), nxs(m);
expr_ref ge(m), eq(m);
smt_params fp;
smt::kernel solver(m, fp);
for (unsigned i = 0; i < sz; ++i) {
xs.push_back(m.mk_const(symbol(i), m.mk_bool_sort()));
nxs.push_back(m.mk_not(xs.back()));
}
std::cout << ws << " " << "\n";
for (unsigned k = max_w + 1; k < ws.size()*max_w; ++k) {
ast_ext2 ext(m);
psort_nw<ast_ext2> sn(ext);
solver.push();
//std::cout << "bound: " << k << "\n";
//std::cout << ws << " " << xs << "\n";
ge = sn.ge(k, sz, ws.c_ptr(), xs.c_ptr());
//std::cout << "ge: " << ge << "\n";
for (expr* cls : ext.m_clauses) {
solver.assert_expr(cls);
}
// solver.display(std::cout);
// for each truth assignment to xs, validate
// that circuit computes the right value for ge
for (unsigned i = 0; i < (1ul << sz); ++i) {
solver.push();
unsigned sum = 0;
for (unsigned j = 0; j < sz; ++j) {
if (0 == ((1 << j) & i)) {
solver.assert_expr(xs.get(j));
sum += ws[j];
}
else {
solver.assert_expr(nxs.get(j));
}
}
// std::cout << "bound: " << k << "\n";
// std::cout << ws << " " << xs << "\n";
// std::cout << sum << " >= " << k << " : " << (sum >= k) << " ";
solver.push();
if (sum < k) {
solver.assert_expr(m.mk_not(ge));
}
else {
solver.assert_expr(ge);
}
// solver.display(std::cout) << "\n";
VERIFY(solver.check() == l_true);
solver.pop(1);
solver.push();
if (sum >= k) {
solver.assert_expr(m.mk_not(ge));
}
else {
solver.assert_expr(ge);
}
// solver.display(std::cout) << "\n";
VERIFY(l_false == solver.check());
solver.pop(1);
solver.pop(1);
}
solver.pop(1);
solver.push();
eq = sn.eq(k, sz, ws.c_ptr(), xs.c_ptr());
for (expr* cls : ext.m_clauses) {
solver.assert_expr(cls);
}
// for each truth assignment to xs, validate
// that circuit computes the right value for ge
for (unsigned i = 0; i < (1ul << sz); ++i) {
solver.push();
unsigned sum = 0;
for (unsigned j = 0; j < sz; ++j) {
if (0 == ((1 << j) & i)) {
solver.assert_expr(xs.get(j));
sum += ws[j];
}
else {
solver.assert_expr(nxs.get(j));
}
}
solver.push();
if (sum != k) {
solver.assert_expr(m.mk_not(eq));
}
else {
solver.assert_expr(eq);
}
// solver.display(std::cout) << "\n";
VERIFY(solver.check() == l_true);
solver.pop(1);
solver.push();
if (sum == k) {
solver.assert_expr(m.mk_not(eq));
}
else {
solver.assert_expr(eq);
}
VERIFY(l_false == solver.check());
solver.pop(1);
solver.pop(1);
}
solver.pop(1);
}
}
}
static void tst_pb() {
unsigned_vector ws;
test_pb(3, 3, ws);
}
void tst_sorting_network() { void tst_sorting_network() {
tst_pb();
tst_sorting_network(sorting_network_encoding::unate_at_most); tst_sorting_network(sorting_network_encoding::unate_at_most);
tst_sorting_network(sorting_network_encoding::circuit_at_most); tst_sorting_network(sorting_network_encoding::circuit_at_most);
tst_sorting_network(sorting_network_encoding::ordered_at_most); tst_sorting_network(sorting_network_encoding::ordered_at_most);

View file

@ -357,6 +357,112 @@ Notes:
} }
} }
/**
\brief encode clauses for ws*xs >= k
- normalize inequality to ws'*xs' >= a*2^(bits-1)
- for each binary digit, sort contributions
- merge with even digits from lower layer - creating 2*n vector
- for last layer return that index a is on.
*/
literal le(unsigned k, unsigned n, unsigned const* ws, literal const* xs) {
unsigned sum = 0;
literal_vector Xs;
for (unsigned i = 0; i < n; ++i) {
sum += ws[i];
Xs.push_back(mk_not(xs[i]));
}
if (k >= sum) {
return ctx.mk_true();
}
return ge(sum - k, n, ws, Xs.begin());
}
literal ge(unsigned k, unsigned n, unsigned const* ws, literal const* xs) {
m_t = GE_FULL;
return cmp(k, n, ws, xs);
}
literal eq(unsigned k, unsigned n, unsigned const* ws, literal const* xs) {
return mk_and(ge(k, n, ws, xs), le(k, n, ws, xs));
#if 0
m_t = EQ;
return cmp(k, n, ws, xs);
#endif
}
literal cmp(unsigned k, unsigned n, unsigned const* ws, literal const* xs) {
unsigned w_max = 0, sum = 0;
literal_vector Xs;
unsigned_vector Ws;
for (unsigned i = 0; i < n; ++i) {
sum += ws[i];
w_max = std::max(ws[i], w_max);
Xs.push_back(xs[i]);
Ws.push_back(ws[i]);
}
if (sum < k) {
return ctx.mk_false();
}
// Normalize to form Ws*Xs ~ a*2^{q-1}
SASSERT(w_max > 0);
unsigned bits = 0;
while (w_max > 0) {
bits++;
w_max >>= 1;
}
unsigned pow = (1ul << (bits-1));
unsigned a = (k + pow - 1) / pow; // a*pow >= k
SASSERT(a*pow >= k);
SASSERT((a-1)*pow < k);
if (a*pow > k) {
Ws.push_back(a*pow - k);
Xs.push_back(ctx.mk_true());
++n;
k = a*pow;
}
literal_vector W, We, B, S, E;
for (unsigned i = 0; i < bits; ++i) {
// B is digits from Xs that are set at bit position i
B.reset();
for (unsigned j = 0; j < n; ++j) {
if (0 != ((1 << i) & Ws[j])) {
B.push_back(Xs[j]);
}
}
// We is every second position of W
We.reset();
for (unsigned j = 0; j + 2 <= W.size(); j += 2) {
We.push_back(W[j+1]);
}
// if we test for equality, then what is not included has to be false.
if (m_t == EQ && W.size() % 2 == 1) {
E.push_back(mk_not(W.back()));
}
// B is the sorted (from largest to smallest bit) version of S
S.reset();
sorting(B.size(), B.begin(), S);
// W is the merge of S and We
W.reset();
merge(S.size(), S.begin(), We.size(), We.begin(), W);
}
if (m_t == EQ) {
E.push_back(W[a - 1]);
if (a < W.size()) E.push_back(mk_not(W[a]));
return mk_and(E);
}
SASSERT(m_t == GE_FULL);
return W[a - 1];
}
private: private: