From 4c76d43670f5307287506680de6c8eb24e263923 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 28 Feb 2019 08:35:22 -0800 Subject: [PATCH] add binary_merge encoding option Signed-off-by: Nikolaj Bjorner --- src/ast/rewriter/pb2bv_rewriter.cpp | 37 +++++++- src/sat/sat_params.pyg | 2 +- src/smt/smt_kernel.cpp | 3 +- src/smt/smt_kernel.h | 2 +- src/test/sorting_network.cpp | 137 ++++++++++++++++++++++++++++ src/util/sorting_network.h | 106 +++++++++++++++++++++ 6 files changed, 283 insertions(+), 4 deletions(-) diff --git a/src/ast/rewriter/pb2bv_rewriter.cpp b/src/ast/rewriter/pb2bv_rewriter.cpp index 0faf08515..015372933 100644 --- a/src/ast/rewriter/pb2bv_rewriter.cpp +++ b/src/ast/rewriter/pb2bv_rewriter.cpp @@ -161,7 +161,6 @@ struct pb2bv_rewriter::imp { } if (m_pb_solver == "segmented") { - expr_ref result(m); switch (is_le) { case l_true: return mk_seg_le(k); case l_false: return mk_seg_ge(k); @@ -169,6 +168,11 @@ struct pb2bv_rewriter::imp { } } + if (m_pb_solver == "binary_merge") { + expr_ref result = binary_merge(is_le, k); + if (result) return result; + } + // fall back to divide and conquer encoding. SASSERT(k.is_pos()); expr_ref zero(m), bound(m); @@ -494,6 +498,37 @@ struct pb2bv_rewriter::imp { return true; } + /** + \brief binary merge encoding. + */ + expr_ref binary_merge(lbool is_le, rational const& k) { + expr_ref result(m); + unsigned_vector coeffs; + for (rational const& c : m_coeffs) { + if (c.is_unsigned()) { + coeffs.push_back(c.get_unsigned()); + } + else { + return result; + } + } + if (!k.is_unsigned()) { + return result; + } + switch (is_le) { + case l_true: + result = m_sort.le(k.get_unsigned(), coeffs.size(), coeffs.c_ptr(), m_args.c_ptr()); + break; + case l_false: + result = m_sort.ge(k.get_unsigned(), coeffs.size(), coeffs.c_ptr(), m_args.c_ptr()); + break; + case l_undef: + result = m_sort.eq(k.get_unsigned(), coeffs.size(), coeffs.c_ptr(), m_args.c_ptr()); + break; + } + return result; + } + /** \brief Segment based encoding. The PB terms are partitoned into segments, such that each segment contains arguments with the same cofficient. diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index 88b196d04..178132f63 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -43,7 +43,7 @@ def_module_params('sat', ('drat.check_unsat', BOOL, False, 'build up internal proof and check'), ('drat.check_sat', BOOL, False, 'build up internal trace, check satisfying model'), ('cardinality.solver', BOOL, True, 'use cardinality solver'), - ('pb.solver', SYMBOL, 'solver', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), solver (use native solver)'), + ('pb.solver', SYMBOL, 'solver', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), binary_merge, segmented, solver (use native solver)'), ('xor.solver', BOOL, False, 'use xor solver'), ('cardinality.encoding', SYMBOL, 'grouped', 'encoding used for at-most-k constraints: grouped, bimander, ordered, unate, circuit'), ('pb.resolve', SYMBOL, 'cardinality', 'resolution strategy for boolean algebra solver: cardinality, rounding'), diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index adcda3979..6c5bad479 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -383,8 +383,9 @@ namespace smt { return m_imp->next_decision(); } - void kernel::display(std::ostream & out) const { + std::ostream& kernel::display(std::ostream & out) const { m_imp->display(out); + return out; } void kernel::collect_statistics(::statistics & st) const { diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index a46195e02..e21a49dc4 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -237,7 +237,7 @@ namespace smt { /** \brief (For debubbing purposes) Prints the state of the kernel */ - void display(std::ostream & out) const; + std::ostream& display(std::ostream & out) const; /** \brief Collect runtime statistics. diff --git a/src/test/sorting_network.cpp b/src/test/sorting_network.cpp index 2470df528..9a143c012 100644 --- a/src/test/sorting_network.cpp +++ b/src/test/sorting_network.cpp @@ -522,7 +522,144 @@ static void tst_sorting_network(sorting_network_encoding enc) { test_sorting5(enc); } +static void test_pb(unsigned max_w, unsigned sz, unsigned_vector& ws) { + if (ws.empty()) { + for (unsigned w = 1; w <= max_w; ++w) { + ws.push_back(w); + test_pb(max_w, sz, ws); + ws.pop_back(); + } + } + else if (ws.size() < sz) { + for (unsigned w = ws.back(); w <= max_w; ++w) { + ws.push_back(w); + test_pb(max_w, sz, ws); + ws.pop_back(); + } + } + else { + SASSERT(ws.size() == sz); + ast_manager m; + reg_decl_plugins(m); + expr_ref_vector xs(m), nxs(m); + expr_ref ge(m), eq(m); + smt_params fp; + smt::kernel solver(m, fp); + for (unsigned i = 0; i < sz; ++i) { + xs.push_back(m.mk_const(symbol(i), m.mk_bool_sort())); + nxs.push_back(m.mk_not(xs.back())); + } + std::cout << ws << " " << "\n"; + for (unsigned k = max_w + 1; k < ws.size()*max_w; ++k) { + + ast_ext2 ext(m); + psort_nw sn(ext); + solver.push(); + //std::cout << "bound: " << k << "\n"; + //std::cout << ws << " " << xs << "\n"; + ge = sn.ge(k, sz, ws.c_ptr(), xs.c_ptr()); + //std::cout << "ge: " << ge << "\n"; + for (expr* cls : ext.m_clauses) { + solver.assert_expr(cls); + } + // solver.display(std::cout); + // for each truth assignment to xs, validate + // that circuit computes the right value for ge + for (unsigned i = 0; i < (1ul << sz); ++i) { + solver.push(); + unsigned sum = 0; + for (unsigned j = 0; j < sz; ++j) { + if (0 == ((1 << j) & i)) { + solver.assert_expr(xs.get(j)); + sum += ws[j]; + } + else { + solver.assert_expr(nxs.get(j)); + } + } + // std::cout << "bound: " << k << "\n"; + // std::cout << ws << " " << xs << "\n"; + // std::cout << sum << " >= " << k << " : " << (sum >= k) << " "; + solver.push(); + if (sum < k) { + solver.assert_expr(m.mk_not(ge)); + } + else { + solver.assert_expr(ge); + } + // solver.display(std::cout) << "\n"; + VERIFY(solver.check() == l_true); + solver.pop(1); + + solver.push(); + if (sum >= k) { + solver.assert_expr(m.mk_not(ge)); + } + else { + solver.assert_expr(ge); + } + // solver.display(std::cout) << "\n"; + VERIFY(l_false == solver.check()); + solver.pop(1); + solver.pop(1); + } + solver.pop(1); + + solver.push(); + eq = sn.eq(k, sz, ws.c_ptr(), xs.c_ptr()); + + for (expr* cls : ext.m_clauses) { + solver.assert_expr(cls); + } + // for each truth assignment to xs, validate + // that circuit computes the right value for ge + for (unsigned i = 0; i < (1ul << sz); ++i) { + solver.push(); + unsigned sum = 0; + for (unsigned j = 0; j < sz; ++j) { + if (0 == ((1 << j) & i)) { + solver.assert_expr(xs.get(j)); + sum += ws[j]; + } + else { + solver.assert_expr(nxs.get(j)); + } + } + solver.push(); + if (sum != k) { + solver.assert_expr(m.mk_not(eq)); + } + else { + solver.assert_expr(eq); + } + // solver.display(std::cout) << "\n"; + VERIFY(solver.check() == l_true); + solver.pop(1); + + solver.push(); + if (sum == k) { + solver.assert_expr(m.mk_not(eq)); + } + else { + solver.assert_expr(eq); + } + VERIFY(l_false == solver.check()); + solver.pop(1); + solver.pop(1); + } + + solver.pop(1); + } + } +} + +static void tst_pb() { + unsigned_vector ws; + test_pb(3, 3, ws); +} + void tst_sorting_network() { + tst_pb(); tst_sorting_network(sorting_network_encoding::unate_at_most); tst_sorting_network(sorting_network_encoding::circuit_at_most); tst_sorting_network(sorting_network_encoding::ordered_at_most); diff --git a/src/util/sorting_network.h b/src/util/sorting_network.h index b094a5b66..b7358b58c 100644 --- a/src/util/sorting_network.h +++ b/src/util/sorting_network.h @@ -357,6 +357,112 @@ Notes: } } + /** + \brief encode clauses for ws*xs >= k + + - normalize inequality to ws'*xs' >= a*2^(bits-1) + - for each binary digit, sort contributions + - merge with even digits from lower layer - creating 2*n vector + - for last layer return that index a is on. + */ + + literal le(unsigned k, unsigned n, unsigned const* ws, literal const* xs) { + unsigned sum = 0; + literal_vector Xs; + for (unsigned i = 0; i < n; ++i) { + sum += ws[i]; + Xs.push_back(mk_not(xs[i])); + } + if (k >= sum) { + return ctx.mk_true(); + } + return ge(sum - k, n, ws, Xs.begin()); + } + + literal ge(unsigned k, unsigned n, unsigned const* ws, literal const* xs) { + m_t = GE_FULL; + return cmp(k, n, ws, xs); + } + + literal eq(unsigned k, unsigned n, unsigned const* ws, literal const* xs) { + return mk_and(ge(k, n, ws, xs), le(k, n, ws, xs)); +#if 0 + m_t = EQ; + return cmp(k, n, ws, xs); +#endif + } + + literal cmp(unsigned k, unsigned n, unsigned const* ws, literal const* xs) { + unsigned w_max = 0, sum = 0; + literal_vector Xs; + unsigned_vector Ws; + for (unsigned i = 0; i < n; ++i) { + sum += ws[i]; + w_max = std::max(ws[i], w_max); + Xs.push_back(xs[i]); + Ws.push_back(ws[i]); + } + if (sum < k) { + return ctx.mk_false(); + } + + // Normalize to form Ws*Xs ~ a*2^{q-1} + SASSERT(w_max > 0); + unsigned bits = 0; + while (w_max > 0) { + bits++; + w_max >>= 1; + } + unsigned pow = (1ul << (bits-1)); + unsigned a = (k + pow - 1) / pow; // a*pow >= k + SASSERT(a*pow >= k); + SASSERT((a-1)*pow < k); + if (a*pow > k) { + Ws.push_back(a*pow - k); + Xs.push_back(ctx.mk_true()); + ++n; + k = a*pow; + } + literal_vector W, We, B, S, E; + for (unsigned i = 0; i < bits; ++i) { + + // B is digits from Xs that are set at bit position i + B.reset(); + for (unsigned j = 0; j < n; ++j) { + if (0 != ((1 << i) & Ws[j])) { + B.push_back(Xs[j]); + } + } + + // We is every second position of W + We.reset(); + for (unsigned j = 0; j + 2 <= W.size(); j += 2) { + We.push_back(W[j+1]); + } + // if we test for equality, then what is not included has to be false. + if (m_t == EQ && W.size() % 2 == 1) { + E.push_back(mk_not(W.back())); + } + + // B is the sorted (from largest to smallest bit) version of S + S.reset(); + sorting(B.size(), B.begin(), S); + + // W is the merge of S and We + W.reset(); + merge(S.size(), S.begin(), We.size(), We.begin(), W); + } + + if (m_t == EQ) { + E.push_back(W[a - 1]); + if (a < W.size()) E.push_back(mk_not(W[a])); + return mk_and(E); + } + SASSERT(m_t == GE_FULL); + return W[a - 1]; + } + + private: