mirror of
https://github.com/Z3Prover/z3
synced 2025-04-08 10:25:18 +00:00
add binary_merge encoding option
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
parent
c4ee4ffae4
commit
4c76d43670
|
@ -161,7 +161,6 @@ struct pb2bv_rewriter::imp {
|
|||
}
|
||||
|
||||
if (m_pb_solver == "segmented") {
|
||||
expr_ref result(m);
|
||||
switch (is_le) {
|
||||
case l_true: return mk_seg_le(k);
|
||||
case l_false: return mk_seg_ge(k);
|
||||
|
@ -169,6 +168,11 @@ struct pb2bv_rewriter::imp {
|
|||
}
|
||||
}
|
||||
|
||||
if (m_pb_solver == "binary_merge") {
|
||||
expr_ref result = binary_merge(is_le, k);
|
||||
if (result) return result;
|
||||
}
|
||||
|
||||
// fall back to divide and conquer encoding.
|
||||
SASSERT(k.is_pos());
|
||||
expr_ref zero(m), bound(m);
|
||||
|
@ -494,6 +498,37 @@ struct pb2bv_rewriter::imp {
|
|||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
\brief binary merge encoding.
|
||||
*/
|
||||
expr_ref binary_merge(lbool is_le, rational const& k) {
|
||||
expr_ref result(m);
|
||||
unsigned_vector coeffs;
|
||||
for (rational const& c : m_coeffs) {
|
||||
if (c.is_unsigned()) {
|
||||
coeffs.push_back(c.get_unsigned());
|
||||
}
|
||||
else {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
if (!k.is_unsigned()) {
|
||||
return result;
|
||||
}
|
||||
switch (is_le) {
|
||||
case l_true:
|
||||
result = m_sort.le(k.get_unsigned(), coeffs.size(), coeffs.c_ptr(), m_args.c_ptr());
|
||||
break;
|
||||
case l_false:
|
||||
result = m_sort.ge(k.get_unsigned(), coeffs.size(), coeffs.c_ptr(), m_args.c_ptr());
|
||||
break;
|
||||
case l_undef:
|
||||
result = m_sort.eq(k.get_unsigned(), coeffs.size(), coeffs.c_ptr(), m_args.c_ptr());
|
||||
break;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Segment based encoding.
|
||||
The PB terms are partitoned into segments, such that each segment contains arguments with the same cofficient.
|
||||
|
|
|
@ -43,7 +43,7 @@ def_module_params('sat',
|
|||
('drat.check_unsat', BOOL, False, 'build up internal proof and check'),
|
||||
('drat.check_sat', BOOL, False, 'build up internal trace, check satisfying model'),
|
||||
('cardinality.solver', BOOL, True, 'use cardinality solver'),
|
||||
('pb.solver', SYMBOL, 'solver', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), solver (use native solver)'),
|
||||
('pb.solver', SYMBOL, 'solver', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), binary_merge, segmented, solver (use native solver)'),
|
||||
('xor.solver', BOOL, False, 'use xor solver'),
|
||||
('cardinality.encoding', SYMBOL, 'grouped', 'encoding used for at-most-k constraints: grouped, bimander, ordered, unate, circuit'),
|
||||
('pb.resolve', SYMBOL, 'cardinality', 'resolution strategy for boolean algebra solver: cardinality, rounding'),
|
||||
|
|
|
@ -383,8 +383,9 @@ namespace smt {
|
|||
return m_imp->next_decision();
|
||||
}
|
||||
|
||||
void kernel::display(std::ostream & out) const {
|
||||
std::ostream& kernel::display(std::ostream & out) const {
|
||||
m_imp->display(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
void kernel::collect_statistics(::statistics & st) const {
|
||||
|
|
|
@ -237,7 +237,7 @@ namespace smt {
|
|||
/**
|
||||
\brief (For debubbing purposes) Prints the state of the kernel
|
||||
*/
|
||||
void display(std::ostream & out) const;
|
||||
std::ostream& display(std::ostream & out) const;
|
||||
|
||||
/**
|
||||
\brief Collect runtime statistics.
|
||||
|
|
|
@ -522,7 +522,144 @@ static void tst_sorting_network(sorting_network_encoding enc) {
|
|||
test_sorting5(enc);
|
||||
}
|
||||
|
||||
static void test_pb(unsigned max_w, unsigned sz, unsigned_vector& ws) {
|
||||
if (ws.empty()) {
|
||||
for (unsigned w = 1; w <= max_w; ++w) {
|
||||
ws.push_back(w);
|
||||
test_pb(max_w, sz, ws);
|
||||
ws.pop_back();
|
||||
}
|
||||
}
|
||||
else if (ws.size() < sz) {
|
||||
for (unsigned w = ws.back(); w <= max_w; ++w) {
|
||||
ws.push_back(w);
|
||||
test_pb(max_w, sz, ws);
|
||||
ws.pop_back();
|
||||
}
|
||||
}
|
||||
else {
|
||||
SASSERT(ws.size() == sz);
|
||||
ast_manager m;
|
||||
reg_decl_plugins(m);
|
||||
expr_ref_vector xs(m), nxs(m);
|
||||
expr_ref ge(m), eq(m);
|
||||
smt_params fp;
|
||||
smt::kernel solver(m, fp);
|
||||
for (unsigned i = 0; i < sz; ++i) {
|
||||
xs.push_back(m.mk_const(symbol(i), m.mk_bool_sort()));
|
||||
nxs.push_back(m.mk_not(xs.back()));
|
||||
}
|
||||
std::cout << ws << " " << "\n";
|
||||
for (unsigned k = max_w + 1; k < ws.size()*max_w; ++k) {
|
||||
|
||||
ast_ext2 ext(m);
|
||||
psort_nw<ast_ext2> sn(ext);
|
||||
solver.push();
|
||||
//std::cout << "bound: " << k << "\n";
|
||||
//std::cout << ws << " " << xs << "\n";
|
||||
ge = sn.ge(k, sz, ws.c_ptr(), xs.c_ptr());
|
||||
//std::cout << "ge: " << ge << "\n";
|
||||
for (expr* cls : ext.m_clauses) {
|
||||
solver.assert_expr(cls);
|
||||
}
|
||||
// solver.display(std::cout);
|
||||
// for each truth assignment to xs, validate
|
||||
// that circuit computes the right value for ge
|
||||
for (unsigned i = 0; i < (1ul << sz); ++i) {
|
||||
solver.push();
|
||||
unsigned sum = 0;
|
||||
for (unsigned j = 0; j < sz; ++j) {
|
||||
if (0 == ((1 << j) & i)) {
|
||||
solver.assert_expr(xs.get(j));
|
||||
sum += ws[j];
|
||||
}
|
||||
else {
|
||||
solver.assert_expr(nxs.get(j));
|
||||
}
|
||||
}
|
||||
// std::cout << "bound: " << k << "\n";
|
||||
// std::cout << ws << " " << xs << "\n";
|
||||
// std::cout << sum << " >= " << k << " : " << (sum >= k) << " ";
|
||||
solver.push();
|
||||
if (sum < k) {
|
||||
solver.assert_expr(m.mk_not(ge));
|
||||
}
|
||||
else {
|
||||
solver.assert_expr(ge);
|
||||
}
|
||||
// solver.display(std::cout) << "\n";
|
||||
VERIFY(solver.check() == l_true);
|
||||
solver.pop(1);
|
||||
|
||||
solver.push();
|
||||
if (sum >= k) {
|
||||
solver.assert_expr(m.mk_not(ge));
|
||||
}
|
||||
else {
|
||||
solver.assert_expr(ge);
|
||||
}
|
||||
// solver.display(std::cout) << "\n";
|
||||
VERIFY(l_false == solver.check());
|
||||
solver.pop(1);
|
||||
solver.pop(1);
|
||||
}
|
||||
solver.pop(1);
|
||||
|
||||
solver.push();
|
||||
eq = sn.eq(k, sz, ws.c_ptr(), xs.c_ptr());
|
||||
|
||||
for (expr* cls : ext.m_clauses) {
|
||||
solver.assert_expr(cls);
|
||||
}
|
||||
// for each truth assignment to xs, validate
|
||||
// that circuit computes the right value for ge
|
||||
for (unsigned i = 0; i < (1ul << sz); ++i) {
|
||||
solver.push();
|
||||
unsigned sum = 0;
|
||||
for (unsigned j = 0; j < sz; ++j) {
|
||||
if (0 == ((1 << j) & i)) {
|
||||
solver.assert_expr(xs.get(j));
|
||||
sum += ws[j];
|
||||
}
|
||||
else {
|
||||
solver.assert_expr(nxs.get(j));
|
||||
}
|
||||
}
|
||||
solver.push();
|
||||
if (sum != k) {
|
||||
solver.assert_expr(m.mk_not(eq));
|
||||
}
|
||||
else {
|
||||
solver.assert_expr(eq);
|
||||
}
|
||||
// solver.display(std::cout) << "\n";
|
||||
VERIFY(solver.check() == l_true);
|
||||
solver.pop(1);
|
||||
|
||||
solver.push();
|
||||
if (sum == k) {
|
||||
solver.assert_expr(m.mk_not(eq));
|
||||
}
|
||||
else {
|
||||
solver.assert_expr(eq);
|
||||
}
|
||||
VERIFY(l_false == solver.check());
|
||||
solver.pop(1);
|
||||
solver.pop(1);
|
||||
}
|
||||
|
||||
solver.pop(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void tst_pb() {
|
||||
unsigned_vector ws;
|
||||
test_pb(3, 3, ws);
|
||||
}
|
||||
|
||||
void tst_sorting_network() {
|
||||
tst_pb();
|
||||
tst_sorting_network(sorting_network_encoding::unate_at_most);
|
||||
tst_sorting_network(sorting_network_encoding::circuit_at_most);
|
||||
tst_sorting_network(sorting_network_encoding::ordered_at_most);
|
||||
|
|
|
@ -357,6 +357,112 @@ Notes:
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
\brief encode clauses for ws*xs >= k
|
||||
|
||||
- normalize inequality to ws'*xs' >= a*2^(bits-1)
|
||||
- for each binary digit, sort contributions
|
||||
- merge with even digits from lower layer - creating 2*n vector
|
||||
- for last layer return that index a is on.
|
||||
*/
|
||||
|
||||
literal le(unsigned k, unsigned n, unsigned const* ws, literal const* xs) {
|
||||
unsigned sum = 0;
|
||||
literal_vector Xs;
|
||||
for (unsigned i = 0; i < n; ++i) {
|
||||
sum += ws[i];
|
||||
Xs.push_back(mk_not(xs[i]));
|
||||
}
|
||||
if (k >= sum) {
|
||||
return ctx.mk_true();
|
||||
}
|
||||
return ge(sum - k, n, ws, Xs.begin());
|
||||
}
|
||||
|
||||
literal ge(unsigned k, unsigned n, unsigned const* ws, literal const* xs) {
|
||||
m_t = GE_FULL;
|
||||
return cmp(k, n, ws, xs);
|
||||
}
|
||||
|
||||
literal eq(unsigned k, unsigned n, unsigned const* ws, literal const* xs) {
|
||||
return mk_and(ge(k, n, ws, xs), le(k, n, ws, xs));
|
||||
#if 0
|
||||
m_t = EQ;
|
||||
return cmp(k, n, ws, xs);
|
||||
#endif
|
||||
}
|
||||
|
||||
literal cmp(unsigned k, unsigned n, unsigned const* ws, literal const* xs) {
|
||||
unsigned w_max = 0, sum = 0;
|
||||
literal_vector Xs;
|
||||
unsigned_vector Ws;
|
||||
for (unsigned i = 0; i < n; ++i) {
|
||||
sum += ws[i];
|
||||
w_max = std::max(ws[i], w_max);
|
||||
Xs.push_back(xs[i]);
|
||||
Ws.push_back(ws[i]);
|
||||
}
|
||||
if (sum < k) {
|
||||
return ctx.mk_false();
|
||||
}
|
||||
|
||||
// Normalize to form Ws*Xs ~ a*2^{q-1}
|
||||
SASSERT(w_max > 0);
|
||||
unsigned bits = 0;
|
||||
while (w_max > 0) {
|
||||
bits++;
|
||||
w_max >>= 1;
|
||||
}
|
||||
unsigned pow = (1ul << (bits-1));
|
||||
unsigned a = (k + pow - 1) / pow; // a*pow >= k
|
||||
SASSERT(a*pow >= k);
|
||||
SASSERT((a-1)*pow < k);
|
||||
if (a*pow > k) {
|
||||
Ws.push_back(a*pow - k);
|
||||
Xs.push_back(ctx.mk_true());
|
||||
++n;
|
||||
k = a*pow;
|
||||
}
|
||||
literal_vector W, We, B, S, E;
|
||||
for (unsigned i = 0; i < bits; ++i) {
|
||||
|
||||
// B is digits from Xs that are set at bit position i
|
||||
B.reset();
|
||||
for (unsigned j = 0; j < n; ++j) {
|
||||
if (0 != ((1 << i) & Ws[j])) {
|
||||
B.push_back(Xs[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// We is every second position of W
|
||||
We.reset();
|
||||
for (unsigned j = 0; j + 2 <= W.size(); j += 2) {
|
||||
We.push_back(W[j+1]);
|
||||
}
|
||||
// if we test for equality, then what is not included has to be false.
|
||||
if (m_t == EQ && W.size() % 2 == 1) {
|
||||
E.push_back(mk_not(W.back()));
|
||||
}
|
||||
|
||||
// B is the sorted (from largest to smallest bit) version of S
|
||||
S.reset();
|
||||
sorting(B.size(), B.begin(), S);
|
||||
|
||||
// W is the merge of S and We
|
||||
W.reset();
|
||||
merge(S.size(), S.begin(), We.size(), We.begin(), W);
|
||||
}
|
||||
|
||||
if (m_t == EQ) {
|
||||
E.push_back(W[a - 1]);
|
||||
if (a < W.size()) E.push_back(mk_not(W[a]));
|
||||
return mk_and(E);
|
||||
}
|
||||
SASSERT(m_t == GE_FULL);
|
||||
return W[a - 1];
|
||||
}
|
||||
|
||||
|
||||
|
||||
private:
|
||||
|
||||
|
|
Loading…
Reference in a new issue