3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 10:25:18 +00:00

add binary_merge encoding option

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2019-02-28 08:35:22 -08:00
parent c4ee4ffae4
commit 4c76d43670
6 changed files with 283 additions and 4 deletions

View file

@ -161,7 +161,6 @@ struct pb2bv_rewriter::imp {
}
if (m_pb_solver == "segmented") {
expr_ref result(m);
switch (is_le) {
case l_true: return mk_seg_le(k);
case l_false: return mk_seg_ge(k);
@ -169,6 +168,11 @@ struct pb2bv_rewriter::imp {
}
}
if (m_pb_solver == "binary_merge") {
expr_ref result = binary_merge(is_le, k);
if (result) return result;
}
// fall back to divide and conquer encoding.
SASSERT(k.is_pos());
expr_ref zero(m), bound(m);
@ -494,6 +498,37 @@ struct pb2bv_rewriter::imp {
return true;
}
/**
\brief binary merge encoding.
*/
expr_ref binary_merge(lbool is_le, rational const& k) {
expr_ref result(m);
unsigned_vector coeffs;
for (rational const& c : m_coeffs) {
if (c.is_unsigned()) {
coeffs.push_back(c.get_unsigned());
}
else {
return result;
}
}
if (!k.is_unsigned()) {
return result;
}
switch (is_le) {
case l_true:
result = m_sort.le(k.get_unsigned(), coeffs.size(), coeffs.c_ptr(), m_args.c_ptr());
break;
case l_false:
result = m_sort.ge(k.get_unsigned(), coeffs.size(), coeffs.c_ptr(), m_args.c_ptr());
break;
case l_undef:
result = m_sort.eq(k.get_unsigned(), coeffs.size(), coeffs.c_ptr(), m_args.c_ptr());
break;
}
return result;
}
/**
\brief Segment based encoding.
The PB terms are partitoned into segments, such that each segment contains arguments with the same cofficient.

View file

@ -43,7 +43,7 @@ def_module_params('sat',
('drat.check_unsat', BOOL, False, 'build up internal proof and check'),
('drat.check_sat', BOOL, False, 'build up internal trace, check satisfying model'),
('cardinality.solver', BOOL, True, 'use cardinality solver'),
('pb.solver', SYMBOL, 'solver', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), solver (use native solver)'),
('pb.solver', SYMBOL, 'solver', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), binary_merge, segmented, solver (use native solver)'),
('xor.solver', BOOL, False, 'use xor solver'),
('cardinality.encoding', SYMBOL, 'grouped', 'encoding used for at-most-k constraints: grouped, bimander, ordered, unate, circuit'),
('pb.resolve', SYMBOL, 'cardinality', 'resolution strategy for boolean algebra solver: cardinality, rounding'),

View file

@ -383,8 +383,9 @@ namespace smt {
return m_imp->next_decision();
}
void kernel::display(std::ostream & out) const {
std::ostream& kernel::display(std::ostream & out) const {
m_imp->display(out);
return out;
}
void kernel::collect_statistics(::statistics & st) const {

View file

@ -237,7 +237,7 @@ namespace smt {
/**
\brief (For debubbing purposes) Prints the state of the kernel
*/
void display(std::ostream & out) const;
std::ostream& display(std::ostream & out) const;
/**
\brief Collect runtime statistics.

View file

@ -522,7 +522,144 @@ static void tst_sorting_network(sorting_network_encoding enc) {
test_sorting5(enc);
}
static void test_pb(unsigned max_w, unsigned sz, unsigned_vector& ws) {
if (ws.empty()) {
for (unsigned w = 1; w <= max_w; ++w) {
ws.push_back(w);
test_pb(max_w, sz, ws);
ws.pop_back();
}
}
else if (ws.size() < sz) {
for (unsigned w = ws.back(); w <= max_w; ++w) {
ws.push_back(w);
test_pb(max_w, sz, ws);
ws.pop_back();
}
}
else {
SASSERT(ws.size() == sz);
ast_manager m;
reg_decl_plugins(m);
expr_ref_vector xs(m), nxs(m);
expr_ref ge(m), eq(m);
smt_params fp;
smt::kernel solver(m, fp);
for (unsigned i = 0; i < sz; ++i) {
xs.push_back(m.mk_const(symbol(i), m.mk_bool_sort()));
nxs.push_back(m.mk_not(xs.back()));
}
std::cout << ws << " " << "\n";
for (unsigned k = max_w + 1; k < ws.size()*max_w; ++k) {
ast_ext2 ext(m);
psort_nw<ast_ext2> sn(ext);
solver.push();
//std::cout << "bound: " << k << "\n";
//std::cout << ws << " " << xs << "\n";
ge = sn.ge(k, sz, ws.c_ptr(), xs.c_ptr());
//std::cout << "ge: " << ge << "\n";
for (expr* cls : ext.m_clauses) {
solver.assert_expr(cls);
}
// solver.display(std::cout);
// for each truth assignment to xs, validate
// that circuit computes the right value for ge
for (unsigned i = 0; i < (1ul << sz); ++i) {
solver.push();
unsigned sum = 0;
for (unsigned j = 0; j < sz; ++j) {
if (0 == ((1 << j) & i)) {
solver.assert_expr(xs.get(j));
sum += ws[j];
}
else {
solver.assert_expr(nxs.get(j));
}
}
// std::cout << "bound: " << k << "\n";
// std::cout << ws << " " << xs << "\n";
// std::cout << sum << " >= " << k << " : " << (sum >= k) << " ";
solver.push();
if (sum < k) {
solver.assert_expr(m.mk_not(ge));
}
else {
solver.assert_expr(ge);
}
// solver.display(std::cout) << "\n";
VERIFY(solver.check() == l_true);
solver.pop(1);
solver.push();
if (sum >= k) {
solver.assert_expr(m.mk_not(ge));
}
else {
solver.assert_expr(ge);
}
// solver.display(std::cout) << "\n";
VERIFY(l_false == solver.check());
solver.pop(1);
solver.pop(1);
}
solver.pop(1);
solver.push();
eq = sn.eq(k, sz, ws.c_ptr(), xs.c_ptr());
for (expr* cls : ext.m_clauses) {
solver.assert_expr(cls);
}
// for each truth assignment to xs, validate
// that circuit computes the right value for ge
for (unsigned i = 0; i < (1ul << sz); ++i) {
solver.push();
unsigned sum = 0;
for (unsigned j = 0; j < sz; ++j) {
if (0 == ((1 << j) & i)) {
solver.assert_expr(xs.get(j));
sum += ws[j];
}
else {
solver.assert_expr(nxs.get(j));
}
}
solver.push();
if (sum != k) {
solver.assert_expr(m.mk_not(eq));
}
else {
solver.assert_expr(eq);
}
// solver.display(std::cout) << "\n";
VERIFY(solver.check() == l_true);
solver.pop(1);
solver.push();
if (sum == k) {
solver.assert_expr(m.mk_not(eq));
}
else {
solver.assert_expr(eq);
}
VERIFY(l_false == solver.check());
solver.pop(1);
solver.pop(1);
}
solver.pop(1);
}
}
}
static void tst_pb() {
unsigned_vector ws;
test_pb(3, 3, ws);
}
void tst_sorting_network() {
tst_pb();
tst_sorting_network(sorting_network_encoding::unate_at_most);
tst_sorting_network(sorting_network_encoding::circuit_at_most);
tst_sorting_network(sorting_network_encoding::ordered_at_most);

View file

@ -357,6 +357,112 @@ Notes:
}
}
/**
\brief encode clauses for ws*xs >= k
- normalize inequality to ws'*xs' >= a*2^(bits-1)
- for each binary digit, sort contributions
- merge with even digits from lower layer - creating 2*n vector
- for last layer return that index a is on.
*/
literal le(unsigned k, unsigned n, unsigned const* ws, literal const* xs) {
unsigned sum = 0;
literal_vector Xs;
for (unsigned i = 0; i < n; ++i) {
sum += ws[i];
Xs.push_back(mk_not(xs[i]));
}
if (k >= sum) {
return ctx.mk_true();
}
return ge(sum - k, n, ws, Xs.begin());
}
literal ge(unsigned k, unsigned n, unsigned const* ws, literal const* xs) {
m_t = GE_FULL;
return cmp(k, n, ws, xs);
}
literal eq(unsigned k, unsigned n, unsigned const* ws, literal const* xs) {
return mk_and(ge(k, n, ws, xs), le(k, n, ws, xs));
#if 0
m_t = EQ;
return cmp(k, n, ws, xs);
#endif
}
literal cmp(unsigned k, unsigned n, unsigned const* ws, literal const* xs) {
unsigned w_max = 0, sum = 0;
literal_vector Xs;
unsigned_vector Ws;
for (unsigned i = 0; i < n; ++i) {
sum += ws[i];
w_max = std::max(ws[i], w_max);
Xs.push_back(xs[i]);
Ws.push_back(ws[i]);
}
if (sum < k) {
return ctx.mk_false();
}
// Normalize to form Ws*Xs ~ a*2^{q-1}
SASSERT(w_max > 0);
unsigned bits = 0;
while (w_max > 0) {
bits++;
w_max >>= 1;
}
unsigned pow = (1ul << (bits-1));
unsigned a = (k + pow - 1) / pow; // a*pow >= k
SASSERT(a*pow >= k);
SASSERT((a-1)*pow < k);
if (a*pow > k) {
Ws.push_back(a*pow - k);
Xs.push_back(ctx.mk_true());
++n;
k = a*pow;
}
literal_vector W, We, B, S, E;
for (unsigned i = 0; i < bits; ++i) {
// B is digits from Xs that are set at bit position i
B.reset();
for (unsigned j = 0; j < n; ++j) {
if (0 != ((1 << i) & Ws[j])) {
B.push_back(Xs[j]);
}
}
// We is every second position of W
We.reset();
for (unsigned j = 0; j + 2 <= W.size(); j += 2) {
We.push_back(W[j+1]);
}
// if we test for equality, then what is not included has to be false.
if (m_t == EQ && W.size() % 2 == 1) {
E.push_back(mk_not(W.back()));
}
// B is the sorted (from largest to smallest bit) version of S
S.reset();
sorting(B.size(), B.begin(), S);
// W is the merge of S and We
W.reset();
merge(S.size(), S.begin(), We.size(), We.begin(), W);
}
if (m_t == EQ) {
E.push_back(W[a - 1]);
if (a < W.size()) E.push_back(mk_not(W[a]));
return mk_and(E);
}
SASSERT(m_t == GE_FULL);
return W[a - 1];
}
private: