diff --git a/src/smt/theory_pb.cpp b/src/smt/theory_pb.cpp index bae79ca9c..5d9f42bfd 100644 --- a/src/smt/theory_pb.cpp +++ b/src/smt/theory_pb.cpp @@ -28,71 +28,228 @@ Notes: namespace smt { -#if 0 // parametric sorting network + // Described in Abio et.al. CP 2013. class psort_nw { - enum cmp_t { LE, GE, EQ }; - ast_manager& m; + class vc { + unsigned v; // number of vertices + unsigned c; // number of clauses + static const unsigned lambda = 5; + public: + vc(unsigned v, unsigned c):v(v), c(c) {} + + bool operator<(vc const& other) const { + return to_int() < other.to_int(); + } + vc operator+(vc const& other) const { + return vc(v + other.v, c + other.c); + } + unsigned to_int() const { + return lambda*v + c; + } + vc operator*(unsigned n) const { + return vc(n*v, n*c); + } + }; + + static vc min(vc const& v1, vc const& v2) { + return (v1.to_int() < v2.to_int())?v1:v2; + } + + + enum cmp_t { LE, GE, EQ, GE_FULL, LE_FULL }; context& ctx; cmp_t m_t; + // for testing + static const bool m_disable_dcard = false; + static const bool m_disable_dsorting = false; + static const bool m_disable_dsmerge = false; + static const bool m_force_dcard = false; + static const bool m_force_dsorting = false; + static const bool m_force_dsmerge = false; + public: - psort_nw(ast_manager& m, context& c): - m(m), ctx(c) {} + struct stats { + unsigned m_num_compiled_vars; + unsigned m_num_compiled_clauses; + void reset() { memset(this, 0, sizeof(*this)); } + stats() { reset(); } + }; + stats m_stats; - literal ge(unsigned m, unsigned n, literal const* xs) { - SASSERT(0 < m && m <= n); - literal_vector out; - m_t = GE; - card(m, n, xs, out); - return out[m-1]; // check + psort_nw(context& c): ctx(c) {} + + literal ge(bool full, unsigned k, unsigned n, literal const* xs) { + if (k > n) { + return false_literal; + } + if (k == 0) { + return true_literal; + } + SASSERT(0 < k && k <= n); + literal_vector in, out; + if (dualize(k, n, xs, in)) { + return le(full, k, in.size(), in.c_ptr()); + } + else { + SASSERT(2*k <= n); + m_t = full?GE_FULL:GE; + card(k, n, xs, out); + return out[k-1]; + } + } + + literal le(bool full, unsigned k, unsigned n, literal const* xs) { + if (k >= n) { + return true_literal; + } + SASSERT(k < n); + literal_vector in, out; + if (dualize(k, n, xs, in)) { + return ge(full, k, n, in.c_ptr()); + } + else { + SASSERT(2*k <= n); + m_t = full?LE_FULL:LE; + card(k + 1, n, xs, out); + return ~out[k]; + } } - literal le(unsigned m, unsigned n, literal const* xs) { - SASSERT(0 <= m && m < n); - literal_vector out; - m_t = LE; - card(m, n, xs, out); - return out[m-1]; // check - } - - literal eq(unsigned m, unsigned n, literal const* xs) { - SASSERT(0 <= m && m <= n); - literal_vector out; - m_t = EQ; - card(m, n, xs, out); - return null_literal; // TBD + literal eq(unsigned k, unsigned n, literal const* xs) { + if (k > n) { + return false_literal; + } + SASSERT(k <= n); + literal_vector in, out; + if (dualize(k, n, xs, in)) { + return eq(k, n, in.c_ptr()); + } + else { + SASSERT(2*k < n); + m_t = EQ; + card(k+1, n, xs, out); + SASSERT(out.size() >= k+1); + return out[k-1]; // & ~out[m] TBD + } } private: - void card(unsigned m, unsigned n, literal const* xs, literal_vector& out) { - if (n <= m) { - sorting(n, xs, out); + std::ostream& pp(std::ostream& out, unsigned n, literal const* lits) { + for (unsigned i = 0; i < n; ++i) out << lits[i] << " "; + return out; + } + + std::ostream& pp(std::ostream& out, literal_vector const& lits) { + for (unsigned i = 0; i < lits.size(); ++i) out << lits[i] << " "; + return out; + } + + std::ostream& ppv(std::ostream& out, unsigned n, literal const* lits) { + for (unsigned i = 0; i < n; ++i) { + expr_ref tmp(ctx.get_manager()); + ctx.literal2expr(lits[i], tmp); + out << tmp << " "; } - if (use_dcard(n, m)) { - dsorting(m, n, xs, out); + return out; + } + + std::ostream& ppv(std::ostream& out, literal_vector const& lits) { + for (unsigned i = 0; i < lits.size(); ++i) { + expr_ref tmp(ctx.get_manager()); + ctx.literal2expr(lits[i], tmp); + out << tmp << " "; } - else { - literal_vector out1, out2; - unsigned l = n/2; // TBD - card(m, l, xs, out1); - card(m, n-l, xs + l, out2); - smerge(m, out1.size(), out1.c_ptr(), out2.size(), out2.c_ptr(), out); + return out; + } + + // 0 <= k <= N + // SUM x_i >= k + // <=> + // SUM ~x_i <= N - k + // suppose k > N/2, then it is better to solve dual. + + bool dualize(unsigned& k, unsigned N, literal const* xs, literal_vector& in) { + SASSERT(0 <= k && k <= N); + if (2*k <= N) { + return false; } + k = N - k; + for (unsigned i = 0; i < N; ++i) { + in.push_back(~xs[i]); + } + TRACE("pb", + pp(tout << N << ": ", in); + tout << " ~ " << k << "\n";); + return true; } bool even(unsigned n) const { return (0 == (n & 0x1)); } bool odd(unsigned n) const { return !even(n); } - + unsigned ceil2(unsigned n) const { return n/2 + odd(n); } + unsigned floor2(unsigned n) const { return n/2; } + unsigned power2(unsigned n) const { SASSERT(n < 10); return 1 << n; } + + literal max(literal a, literal b) { + if (a == b) return a; + m_stats.m_num_compiled_vars++; + ast_manager& m = ctx.get_manager(); + expr_ref t1(m), t2(m), t3(m); + ctx.literal2expr(a, t1); + ctx.literal2expr(b, t2); + t3 = m.mk_or(t1, t2); + bool_var v = ctx.b_internalized(t3)?ctx.get_bool_var(t3):ctx.mk_bool_var(t3); + return literal(v); + } + + literal min(literal a, literal b) { + if (a == b) return a; + m_stats.m_num_compiled_vars++; + ast_manager& m = ctx.get_manager(); + expr_ref t1(m), t2(m), t3(m); + ctx.literal2expr(a, t1); + ctx.literal2expr(b, t2); + t3 = m.mk_and(t1, t2); + bool_var v = ctx.b_internalized(t3)?ctx.get_bool_var(t3):ctx.mk_bool_var(t3); + return literal(v); + } + + literal fresh() { + m_stats.m_num_compiled_vars++; + ast_manager& m = ctx.get_manager(); + app_ref y(m); + y = m.mk_fresh_const("y", m.mk_bool_sort()); + return literal(ctx.mk_bool_var(y)); + } + void add_clause(literal l1, literal l2, literal l3) { + literal lits[3] = { l1, l2, l3 }; + add_clause(3, lits); + } + void add_clause(literal l1, literal l2) { + literal lits[2] = { l1, l2 }; + add_clause(2, lits); + } + void add_clause(unsigned n, literal const* ls) { + m_stats.m_num_compiled_clauses++; + literal_vector tmp(n, ls); + TRACE("pb", pp(tout, n, ls) << "\n";); + ctx.mk_clause(n, tmp.c_ptr(), 0, CLS_AUX, 0); + } + + // y1 <= max(x1,x2) + // y2 <= min(x1,x2) void cmp_ge(literal x1, literal x2, literal y1, literal y2) { add_clause(~y2, x1); add_clause(~y2, x2); - add_clause(x1, x2, ~y1); + add_clause(~y1, x1, x2); } + // max(x1,x2) <= y1 + // min(x1,x2) <= y2 void cmp_le(literal x1, literal x2, literal y1, literal y2) { add_clause(~x1, y1); add_clause(~x2, y1); @@ -111,21 +268,57 @@ namespace smt { case EQ: cmp_eq(x1, x2, y1, y2); break; } } - - void size_merge(unsigned a, unsigned b, - unsigned& vD, unsigned& cD, - unsigned& vR, unsigned& cR) { - vD = a + b; - cD = a*b + a + b; - + vc vc_cmp() { + return vc(2, (m_t==EQ)?6:3); } + void card(unsigned k, unsigned n, literal const* xs, literal_vector& out) { + TRACE("pb", tout << "card k:" << k << " n: " << n << "\n";); + if (n <= k) { + sorting(n, xs, out); + } + else if (use_dcard(k, n)) { + dsorting(k, n, xs, out); + } + else { + literal_vector out1, out2; + unsigned l = n/2; // TBD + card(k, l, xs, out1); + card(k, n-l, xs + l, out2); + smerge(k, out1.size(), out1.c_ptr(), out2.size(), out2.c_ptr(), out); + } + TRACE("pb", tout << "card k:" << k << " n: " << n << "\n"; + pp(tout << "in:", n, xs) << "\n"; + pp(tout << "out:", out) << "\n";); + + } + vc vc_card(unsigned k, unsigned n) { + if (n <= k) { + return vc_sorting(n); + } + else if (use_dcard(k, n)) { + return vc_dsorting(k, n); + } + else { + return vc_card_rec(k, n); + } + } + vc vc_card_rec(unsigned k, unsigned n) { + unsigned l = n/2; + return vc_card(k, l) + vc_card(k, n-l) + vc_smerge(k, l, n-l); + } + bool use_dcard(unsigned k, unsigned n) { + return m_force_dcard || (!m_disable_dcard && n < 10 && vc_dsorting(k, n) < vc_card_rec(k, n)); + } + + void merge(unsigned a, literal const* as, unsigned b, literal const* bs, literal_vector& out) { + TRACE("pb", tout << "merge a: " << a << " b: " << b << "\n";); if (a == 1 && b == 1) { - y1 = fresh(); - y2 = fresh(); + literal y1 = max(as[0], bs[0]); + literal y2 = min(as[0], bs[0]); out.push_back(y1); out.push_back(y2); cmp(as[0], bs[0], y1, y2); @@ -134,27 +327,98 @@ namespace smt { out.append(b, bs); } else if (b == 0) { - merge(b, bs, a, as,out); + out.append(a, as); } else if (use_dsmerge(a, b, a + b)) { - dsmerge(a, as, b, bs, a + b, out); + dsmerge(a + b, a, as, b, bs, out); } - else if (even(a) && even(b) && b > 0) { - - } - else if (even(a) && odd(b) && (a > 1 || b > 1)) { - - } - else if (odd(a) && odd(b) && (a > 1 || b > 1)) { - + else if (even(a) && odd(b)) { + merge(b, bs, a, as, out); } else { - merge(b, bs, a, as, out); + literal_vector even_a, odd_a; + literal_vector even_b, odd_b; + literal_vector out1, out2; + SASSERT(a > 1 || b > 1); + split(a, as, even_a, odd_a); + split(b, bs, even_b, odd_b); + SASSERT(!even_a.empty()); + SASSERT(!even_b.empty()); + merge(even_a.size(), even_a.c_ptr(), + even_b.size(), even_b.c_ptr(), out1); + merge(odd_a.size(), odd_a.c_ptr(), + odd_b.size(), odd_b.c_ptr(), out2); + interleave(out1, out2, out); + } + TRACE("pb", tout << "merge a: " << a << " b: " << b << "\n"; + pp(tout << "a:", a, as) << "\n"; + pp(tout << "b:", b, bs) << "\n"; + pp(tout << "out:", out) << "\n";); + } + vc vc_merge(unsigned a, unsigned b) { + if (a == 1 && b == 1) { + return vc_cmp(); + } + else if (a == 0 || b == 0) { + return vc(0, 0); + } + else if (use_dsmerge(a, b, a + b)) { + return vc_dsmerge(a, b, a + b); + } + else { + return vc_merge_rec(a, b); } } + vc vc_merge_rec(unsigned a, unsigned b) { + return + vc_merge(ceil2(a), ceil2(b)) + + vc_merge(floor2(a), floor2(b)) + + vc_interleave(ceil2(a) + ceil2(b), floor2(a) + floor2(b)); + } + void split(unsigned n, literal const* ls, literal_vector& even, literal_vector& odd) { + for (unsigned i = 0; i < n; i += 2) { + even.push_back(ls[i]); + } + for (unsigned i = 1; i < n; i += 2) { + odd.push_back(ls[i]); + } + } + + void interleave(literal_vector const& as, + literal_vector const& bs, + literal_vector& out) { + TRACE("pb", tout << "interleave: " << as.size() << " " << bs.size() << "\n";); + SASSERT(as.size() >= bs.size()); + SASSERT(as.size() <= bs.size() + 2); + SASSERT(!as.empty()); + out.push_back(as[0]); + unsigned sz = std::min(as.size()-1, bs.size()); + for (unsigned i = 0; i < sz; ++i) { + literal y1 = max(as[i+1],bs[i]); + literal y2 = min(as[i+1],bs[i]); + cmp(as[i+1], bs[i], y1, y2); + out.push_back(y1); + out.push_back(y2); + } + if (as.size() == bs.size()) { + out.push_back(bs[sz]); + } + else if (as.size() == bs.size() + 2) { + out.push_back(as[sz+1]); + } + SASSERT(out.size() == as.size() + bs.size()); + TRACE("pb", tout << "interleave: " << as.size() << " " << bs.size() << "\n"; + pp(tout << "a: ", as) << "\n"; + pp(tout << "b: ", bs) << "\n"; + pp(tout << "out: ", out) << "\n";); + + } + vc vc_interleave(unsigned a, unsigned b) { + return vc_cmp()*std::min(a-1,b); + } - void sorting(cmp_t t, unsigned n, literal const* xs, - literal_vector& out) { + void sorting(unsigned n, literal const* xs, literal_vector& out) { + TRACE("pb", tout << "sorting: " << n << "\n";); switch(n) { case 0: break; @@ -179,94 +443,270 @@ namespace smt { } break; } + TRACE("pb", tout << "sorting: " << n << "\n"; + pp(tout << "in:", n, xs) << "\n"; + pp(tout << "out:", out) << "\n";); + + } + vc vc_sorting(unsigned n) { + switch(n) { + case 0: return vc(0,0); + case 1: return vc(0,0); + case 2: return vc_merge(1,1); + default: + if (use_dsorting(n)) { + return vc_dsorting(n, n); + } + else { + return vc_sorting_rec(n); + } + } + } + vc vc_sorting_rec(unsigned n) { + SASSERT(n > 2); + unsigned l = n/2; + return vc_sorting(l) + vc_sorting(n-l) + vc_merge(l, n-l); } - bool use_dsmerge(unsigned a, unsigned b, unsigned c) const { - return false; + bool use_dsorting(unsigned n) { + SASSERT(n > 2); + return m_force_dsorting || + (!m_disable_dsorting && n < 10 && vc_dsorting(n, n) < vc_sorting_rec(n)); } - void smerge(unsigned a, literal const* as, + void smerge(unsigned c, + unsigned a, literal const* as, unsigned b, literal const* bs, - unsigned c, literal_vector& out) { + TRACE("pb", tout << "smerge: c:" << c << " a:" << a << " b:" << b << "\n";); if (a == 1 && b == 1 && c == 1) { - literal y = fresh(); - add_clause(~as[0], y); - add_clause(~bs[0], y); + literal y = max(as[0], bs[0]); + if (m_t != GE) { + // x1 <= max(x1,x2) + // x2 <= max(x1,x2) + add_clause(~as[0], y); + add_clause(~bs[0], y); + } + if (m_t != LE) { + // max(x1,x2) <= x1, x2 + add_clause(~y, as[0], bs[0]); + } out.push_back(y); } + else if (a == 0) { + out.append(std::min(c, b), bs); + } + else if (b == 0) { + out.append(std::min(c, a), as); + } else if (a > c) { - smerge(a - c, as, b, bs, c, out); + smerge(c, c, as, b, bs, out); } else if (b > c) { - smerge(a, as, b - c, bs, c, out); + smerge(c, a, as, c, bs, out); } else if (a + b <= c) { merge(a, as, b, bs, out); } else if (use_dsmerge(a, b, c)) { - dsmerge(a, as, b, bs, c, out); + dsmerge(c, a, as, b, bs, out); } - else if (even(c)) { - - } - else if (odd(c)) { - SASSERT(c > 1); + else { + literal_vector even_a, odd_a; + literal_vector even_b, odd_b; + literal_vector out1, out2; + split(a, as, even_a, odd_a); + split(b, bs, even_b, odd_b); + SASSERT(!even_a.empty()); + SASSERT(!even_b.empty()); + unsigned c1, c2; + if (even(c)) { + c1 = 1 + c/2; c2 = c/2; + } + else { + c1 = (c + 1)/2; c2 = (c - 1)/2; + } + smerge(c1, even_a.size(), even_a.c_ptr(), + even_b.size(), even_b.c_ptr(), out1); + smerge(c2, odd_a.size(), odd_a.c_ptr(), + odd_b.size(), odd_b.c_ptr(), out2); + SASSERT(out1.size() == std::min(even_a.size()+even_b.size(), c1)); + SASSERT(out2.size() == std::min(odd_a.size()+odd_b.size(), c2)); + literal y; + if (even(c)) { + literal z1 = out1.back(); + literal z2 = out2.back(); + out1.pop_back(); + out2.pop_back(); + y = max(z1, z2); + if (m_t != GE) { + add_clause(~z1, y); + add_clause(~z2, y); + } + if (m_t != LE) { + add_clause(~y, z1, z2); + } + } + interleave(out1, out2, out); + if (even(c)) { + out.push_back(y); + } } + TRACE("pb", tout << "smerge: c:" << c << " a:" << a << " b:" << b << "\n"; + pp(tout << "a:", a, as) << "\n"; + pp(tout << "b:", b, bs) << "\n"; + pp(tout << "out:", out) << "\n"; + ); + SASSERT(out.size() == std::min(a + b, c)); } + vc vc_smerge(unsigned a, unsigned b, unsigned c) { + if (a == 1 && b == 1 && c == 1) { + vc v(1,0); + if (m_t != GE) v = v + vc(0, 2); + if (m_t != LE) v = v + vc(0, 1); + return v; + } + if (a == 0 || b == 0) return vc(0, 0); + if (a > c) return vc_smerge(c, b, c); + if (b > c) return vc_smerge(a, c, c); + if (a + b <= c) return vc_merge(a, b); + if (use_dsmerge(a, b, c)) return vc_dsmerge(a, b, c); + return vc_smerge_rec(a, b, c); + } + vc vc_smerge_rec(unsigned a, unsigned b, unsigned c) { + return + vc_smerge(ceil2(a), ceil2(b), even(c)?(1+c/2):((c+1)/2)) + + vc_smerge(floor2(a), floor2(b), even(c)?(c/2):((c-1)/2)) + + vc_interleave(ceil2(a)+ceil2(b),floor2(a)+floor2(b)) + + vc(1, 0) + + ((m_t != GE)?vc(0, 2):vc(0, 0)) + + ((m_t != LE)?vc(0, 1):vc(0, 0)); + } + bool use_dsmerge(unsigned a, unsigned b, unsigned c) { + return + m_force_dsmerge || + (!m_disable_dsmerge && + a < (1 << 15) && b < (1 << 15) && + vc_dsmerge(a, b, a + b) < vc_smerge_rec(a, b, c)); + } void dsmerge( + unsigned c, unsigned a, literal const* as, unsigned b, literal const* bs, - unsigned c, literal_vector& out) { + TRACE("pb", tout << "dsmerge: c:" << c << " a:" << a << " b:" << b << "\n";); + SASSERT(a <= c); + SASSERT(b <= c); + SASSERT(a + b > c); for (unsigned i = 0; i < c; ++i) { out.push_back(fresh()); } - for (unsigned i = 0; i < a; ++i) { - add_clause(~as[i],out[i]); + if (m_t != GE) { + for (unsigned i = 0; i < a; ++i) { + add_clause(~as[i], out[i]); + } + for (unsigned i = 0; i < b; ++i) { + add_clause(~bs[i], out[i]); + } + for (unsigned i = 1; i <= a; ++i) { + for (unsigned j = 1; j <= b && i + j <= c; ++j) { + add_clause(~as[i-1],~bs[j-1],out[i+j-1]); + } + } } - for (unsigned i = 0; i < b; ++i) { - add_clause(~bs[i],out[i]); - } - for (unsigned i = 0; i < a; ++i) { - for (unsigned j = 0; j < b && i + j < c; ++j) { - add_clause(~as[i],~bs[j],out[i+j]); + if (m_t != LE) { + for (unsigned k = 1; k <= c; ++k) { + literal_vector ls; + ls.push_back(~out[k-1]); + if (k <= a) { + ls.push_back(as[k-1]); + } + if (k <= b) { + ls.push_back(bs[k-1]); + } + for (unsigned i = 1; i <= std::min(a,k-1); ++i) { + if (k + 1 - i <= b) { + ls.push_back(as[i-1]); + ls.push_back(bs[k-i]); + add_clause(ls.size(), ls.c_ptr()); + ls.pop_back(); + ls.pop_back(); + } + } } } } + vc vc_dsmerge(unsigned a, unsigned b, unsigned c) { + vc v(c, 0); + if (m_t != GE) { + v = v + vc(0, a + b + std::min(a, c)*std::min(b, c)/2); + } + if (m_t != LE) { + v = v + vc(0, std::min(a, c)*std::min(b, c)/2); + } + return v; + } + void dsorting(unsigned m, unsigned n, literal const* xs, literal_vector& out) { + TRACE("pb", tout << "dsorting m: " << m << " n: " << n << "\n";); SASSERT(m <= n); + literal_vector lits; for (unsigned i = 0; i < m; ++i) { out.push_back(fresh()); } - for (unsigned k = 0; k < m; ++k) { - literal_vector lits; - lits.push_back(out[k]); - add_subset(k+1, 0, lits, n, xs); + if (m_t != GE) { + for (unsigned k = 1; k <= m; ++k) { + lits.push_back(out[k-1]); + add_subset(true, k, 0, lits, n, xs); + lits.pop_back(); + } + } + if (m_t != LE) { + for (unsigned k = 1; k <= m; ++k) { + lits.push_back(~out[k-1]); + add_subset(false, n-k+1, 0, lits, n, xs); + lits.pop_back(); + } } } + vc vc_dsorting(unsigned m, unsigned n) { + SASSERT(m <= n && n < 10); + vc v(m, 0); + if (m_t != GE) { + v = v + vc(0, power2(n-1)); + } + if (m_t != LE) { + v = v + vc(0, power2(n-1)); + } + return v; + } - void add_subset(unsigned k, unsigned offset, literal_vector& lits, + void add_subset(bool polarity, unsigned k, unsigned offset, literal_vector& lits, unsigned n, literal const* xs) { - SASSERT(k + offset < n); + TRACE("pb", tout << "k:" << k << " offset: " << offset << " n: " << n << " "; + pp(tout, lits) << "\n";); + SASSERT(k + offset <= n); if (k == 0) { - ctx.add_clause(lits.size(), lits.c_ptr()); + add_clause(lits.size(), lits.c_ptr()); return; } - for (unsigned i = offset; i < n-offset-k; ++i) { - lits.push_back(xs[i]); - add_subset(k-1, i+1, lits, n, xs); + for (unsigned i = offset; i < n - k + 1; ++i) { + lits.push_back(polarity?~xs[i]:xs[i]); + add_subset(polarity, k-1, i+1, lits, n, xs); lits.pop_back(); } } - }; -#endif + // for testing + literal theory_pb::assert_ge(context& ctx, unsigned k, unsigned n, literal const* xs) { + psort_nw sort(ctx); + return sort.ge(false, k, n, xs); + } class pb_lit_rewriter_util { public: @@ -448,9 +888,6 @@ namespace smt { case l_undef: break; } - -#if 1 - // TBD: special cases: k == 1, or args.size() == 1 if (c->k().is_one()) { literal_vector& lits = get_lits(); @@ -463,7 +900,6 @@ namespace smt { ctx.mk_th_axiom(get_id(), lits.size(), lits.c_ptr()); return true; } -#endif // maximal coefficient: numeral& max_watch = c->m_max_watch; @@ -516,18 +952,6 @@ namespace smt { ctx.set_var_theory(bv, get_id()); } has_bv = (ctx.get_var_theory(bv) == get_id()); -#if 0 - TBD: - if (!has_bv) { - if (!ctx.e_internalized(arg)) { - ctx.internalize(arg, false); - SASSERT(ctx.e_internalized(arg)); - } - enode* n = ctx.get_enode(arg); - theory_var v = mk_var(n); - ctx.attach_th_var(n, this, v); - } -#endif } else if (m.is_true(arg)) { bv = true_bool_var; @@ -997,8 +1421,38 @@ namespace smt { unsigned k = c.k().get_unsigned(); unsigned num_args = c.size(); - sort_expr se(*this); - sorting_network sn(se); + + literal thl = c.lit(); + literal at_least_k; + +#if 1 + literal_vector in; + for (unsigned i = 0; i < num_args; ++i) { + rational n = c.coeff(i); + while (n.is_pos()) { + in.push_back(c.lit(i)); + n -= rational::one(); + } + } + if (ctx.get_assignment(thl) == l_true && + ctx.get_assign_level(thl) == ctx.get_base_level()) { + psort_nw sortnw(ctx); + sortnw.m_stats.reset(); + at_least_k = sortnw.ge(false, k, in.size(), in.c_ptr()); + ctx.mk_clause(~thl, at_least_k, 0); + m_stats.m_num_compiled_vars += sortnw.m_stats.m_num_compiled_vars; + m_stats.m_num_compiled_clauses += sortnw.m_stats.m_num_compiled_clauses; + } + else { + psort_nw sortnw(ctx); + sortnw.m_stats.reset(); + literal at_least_k = sortnw.ge(true, k, in.size(), in.c_ptr()); + ctx.mk_clause(~thl, at_least_k, 0); + ctx.mk_clause(~at_least_k, thl, 0); + m_stats.m_num_compiled_vars += sortnw.m_stats.m_num_compiled_vars; + m_stats.m_num_compiled_clauses += sortnw.m_stats.m_num_compiled_clauses; + } +#else expr_ref_vector in(m), out(m); expr_ref tmp(m); for (unsigned i = 0; i < num_args; ++i) { @@ -1009,15 +1463,20 @@ namespace smt { n -= rational::one(); } } - IF_VERBOSE(1, verbose_stream() << "(compile " << k << ")\n";); + sort_expr se(*this); + sorting_network sn(se); sn(in, out); - literal at_least_k = se.internalize(c, out[k-1].get()); // first k outputs are 1. + at_least_k = se.internalize(c, out[k-1].get()); // first k outputs are 1. TRACE("pb", tout << "at_least: " << mk_pp(out[k-1].get(), m) << "\n";); - literal thl = c.lit(); se.add_clause(~thl, at_least_k); se.add_clause(thl, ~at_least_k); - TRACE("pb", tout << c.lit() << "\n";); +#endif + IF_VERBOSE(1, verbose_stream() + << "(smt.pb compile sorting network bound: " + << k << " literals: " << in.size() << ")\n";); + + TRACE("pb", tout << thl << "\n";); // auxiliary clauses get removed when popping scopes. // we have to recompile the circuit after back-tracking. c.m_compiled = l_false; @@ -1125,11 +1584,6 @@ namespace smt { } ctx.mk_clause(lits.size(), lits.c_ptr(), js, CLS_AUX_LEMMA, 0); - - // if (true || (c.m_num_propagations & 0xF) == 0) { - // resolve_conflict(c); - //} - } diff --git a/src/smt/theory_pb.h b/src/smt/theory_pb.h index 8f79f0494..b893594d7 100644 --- a/src/smt/theory_pb.h +++ b/src/smt/theory_pb.h @@ -196,5 +196,7 @@ namespace smt { void set_conflict_frequency(unsigned f) { m_conflict_frequency = f; } void set_learn_complements(bool l) { m_learn_complements = l; } + + static literal assert_ge(context& ctx, unsigned k, unsigned n, literal const* xs); }; }; diff --git a/src/test/main.cpp b/src/test/main.cpp index 94c4feb65..f6a4eab29 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -217,6 +217,7 @@ int main(int argc, char ** argv) { TST(qe_arith); TST(expr_substitution); TST(sorting_network); + TST(theory_pb); } void initialize_mam() {} diff --git a/src/test/theory_pb.cpp b/src/test/theory_pb.cpp new file mode 100644 index 000000000..8c9ef405b --- /dev/null +++ b/src/test/theory_pb.cpp @@ -0,0 +1,65 @@ +#include "smt_context.h" +#include "ast_pp.h" +#include "model_v2_pp.h" +#include "reg_decl_plugins.h" +#include "theory_pb.h" + +unsigned populate_literals(unsigned k, smt::literal_vector& lits) { + SASSERT(k < (1u << lits.size())); + unsigned t = 0; + for (unsigned i = 0; i < lits.size(); ++i) { + if (k & (1 << i)) { + lits[i] = smt::true_literal; + t++; + } + else { + lits[i] = smt::false_literal; + } + } + return t; +} + +void tst_theory_pb() { + ast_manager m; + smt_params params; + params.m_model = true; + reg_decl_plugins(m); + expr_ref tmp(m); + + enable_trace("pb"); + for (unsigned N = 4; N < 11; ++N) { + for (unsigned i = 0; i < (1u << N); ++i) { + smt::literal_vector lits(N, smt::false_literal); + unsigned k = populate_literals(i, lits); + std::cout << "k:" << k << " " << N << "\n"; + std::cout.flush(); + TRACE("pb", tout << "k " << k << ": "; + for (unsigned j = 0; j < lits.size(); ++j) { + tout << lits[j] << " "; + } + tout << "\n";); + { + smt::context ctx(m, params); + ctx.push(); + smt::literal l = smt::theory_pb::assert_ge(ctx, k+1, lits.size(), lits.c_ptr()); + if (l != smt::false_literal) { + ctx.assign(l, 0, false); + TRACE("pb", tout << "assign: " << l << "\n"; + ctx.display(tout);); + VERIFY(l_false == ctx.check()); + } + ctx.pop(1); + } + { + smt::context ctx(m, params); + ctx.push(); + smt::literal l = smt::theory_pb::assert_ge(ctx, k, lits.size(), lits.c_ptr()); + SASSERT(l != smt::false_literal); + ctx.assign(l, 0, false); + TRACE("pb", ctx.display(tout);); + VERIFY(l_true == ctx.check()); + ctx.pop(1); + } + } + } +}