mirror of
https://github.com/Z3Prover/z3
synced 2025-04-08 10:25:18 +00:00
add optimized sorting network
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
parent
5965515385
commit
4027de42f6
|
@ -28,71 +28,228 @@ Notes:
|
|||
|
||||
namespace smt {
|
||||
|
||||
#if 0
|
||||
// parametric sorting network
|
||||
// Described in Abio et.al. CP 2013.
|
||||
class psort_nw {
|
||||
enum cmp_t { LE, GE, EQ };
|
||||
ast_manager& m;
|
||||
class vc {
|
||||
unsigned v; // number of vertices
|
||||
unsigned c; // number of clauses
|
||||
static const unsigned lambda = 5;
|
||||
public:
|
||||
vc(unsigned v, unsigned c):v(v), c(c) {}
|
||||
|
||||
bool operator<(vc const& other) const {
|
||||
return to_int() < other.to_int();
|
||||
}
|
||||
vc operator+(vc const& other) const {
|
||||
return vc(v + other.v, c + other.c);
|
||||
}
|
||||
unsigned to_int() const {
|
||||
return lambda*v + c;
|
||||
}
|
||||
vc operator*(unsigned n) const {
|
||||
return vc(n*v, n*c);
|
||||
}
|
||||
};
|
||||
|
||||
static vc min(vc const& v1, vc const& v2) {
|
||||
return (v1.to_int() < v2.to_int())?v1:v2;
|
||||
}
|
||||
|
||||
|
||||
enum cmp_t { LE, GE, EQ, GE_FULL, LE_FULL };
|
||||
context& ctx;
|
||||
cmp_t m_t;
|
||||
|
||||
// for testing
|
||||
static const bool m_disable_dcard = false;
|
||||
static const bool m_disable_dsorting = false;
|
||||
static const bool m_disable_dsmerge = false;
|
||||
static const bool m_force_dcard = false;
|
||||
static const bool m_force_dsorting = false;
|
||||
static const bool m_force_dsmerge = false;
|
||||
|
||||
public:
|
||||
psort_nw(ast_manager& m, context& c):
|
||||
m(m), ctx(c) {}
|
||||
struct stats {
|
||||
unsigned m_num_compiled_vars;
|
||||
unsigned m_num_compiled_clauses;
|
||||
void reset() { memset(this, 0, sizeof(*this)); }
|
||||
stats() { reset(); }
|
||||
};
|
||||
stats m_stats;
|
||||
|
||||
literal ge(unsigned m, unsigned n, literal const* xs) {
|
||||
SASSERT(0 < m && m <= n);
|
||||
literal_vector out;
|
||||
m_t = GE;
|
||||
card(m, n, xs, out);
|
||||
return out[m-1]; // check
|
||||
psort_nw(context& c): ctx(c) {}
|
||||
|
||||
literal ge(bool full, unsigned k, unsigned n, literal const* xs) {
|
||||
if (k > n) {
|
||||
return false_literal;
|
||||
}
|
||||
if (k == 0) {
|
||||
return true_literal;
|
||||
}
|
||||
SASSERT(0 < k && k <= n);
|
||||
literal_vector in, out;
|
||||
if (dualize(k, n, xs, in)) {
|
||||
return le(full, k, in.size(), in.c_ptr());
|
||||
}
|
||||
else {
|
||||
SASSERT(2*k <= n);
|
||||
m_t = full?GE_FULL:GE;
|
||||
card(k, n, xs, out);
|
||||
return out[k-1];
|
||||
}
|
||||
}
|
||||
|
||||
literal le(bool full, unsigned k, unsigned n, literal const* xs) {
|
||||
if (k >= n) {
|
||||
return true_literal;
|
||||
}
|
||||
SASSERT(k < n);
|
||||
literal_vector in, out;
|
||||
if (dualize(k, n, xs, in)) {
|
||||
return ge(full, k, n, in.c_ptr());
|
||||
}
|
||||
else {
|
||||
SASSERT(2*k <= n);
|
||||
m_t = full?LE_FULL:LE;
|
||||
card(k + 1, n, xs, out);
|
||||
return ~out[k];
|
||||
}
|
||||
}
|
||||
|
||||
literal le(unsigned m, unsigned n, literal const* xs) {
|
||||
SASSERT(0 <= m && m < n);
|
||||
literal_vector out;
|
||||
m_t = LE;
|
||||
card(m, n, xs, out);
|
||||
return out[m-1]; // check
|
||||
}
|
||||
|
||||
literal eq(unsigned m, unsigned n, literal const* xs) {
|
||||
SASSERT(0 <= m && m <= n);
|
||||
literal_vector out;
|
||||
m_t = EQ;
|
||||
card(m, n, xs, out);
|
||||
return null_literal; // TBD
|
||||
literal eq(unsigned k, unsigned n, literal const* xs) {
|
||||
if (k > n) {
|
||||
return false_literal;
|
||||
}
|
||||
SASSERT(k <= n);
|
||||
literal_vector in, out;
|
||||
if (dualize(k, n, xs, in)) {
|
||||
return eq(k, n, in.c_ptr());
|
||||
}
|
||||
else {
|
||||
SASSERT(2*k < n);
|
||||
m_t = EQ;
|
||||
card(k+1, n, xs, out);
|
||||
SASSERT(out.size() >= k+1);
|
||||
return out[k-1]; // & ~out[m] TBD
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
|
||||
void card(unsigned m, unsigned n, literal const* xs, literal_vector& out) {
|
||||
if (n <= m) {
|
||||
sorting(n, xs, out);
|
||||
std::ostream& pp(std::ostream& out, unsigned n, literal const* lits) {
|
||||
for (unsigned i = 0; i < n; ++i) out << lits[i] << " ";
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream& pp(std::ostream& out, literal_vector const& lits) {
|
||||
for (unsigned i = 0; i < lits.size(); ++i) out << lits[i] << " ";
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream& ppv(std::ostream& out, unsigned n, literal const* lits) {
|
||||
for (unsigned i = 0; i < n; ++i) {
|
||||
expr_ref tmp(ctx.get_manager());
|
||||
ctx.literal2expr(lits[i], tmp);
|
||||
out << tmp << " ";
|
||||
}
|
||||
if (use_dcard(n, m)) {
|
||||
dsorting(m, n, xs, out);
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream& ppv(std::ostream& out, literal_vector const& lits) {
|
||||
for (unsigned i = 0; i < lits.size(); ++i) {
|
||||
expr_ref tmp(ctx.get_manager());
|
||||
ctx.literal2expr(lits[i], tmp);
|
||||
out << tmp << " ";
|
||||
}
|
||||
else {
|
||||
literal_vector out1, out2;
|
||||
unsigned l = n/2; // TBD
|
||||
card(m, l, xs, out1);
|
||||
card(m, n-l, xs + l, out2);
|
||||
smerge(m, out1.size(), out1.c_ptr(), out2.size(), out2.c_ptr(), out);
|
||||
return out;
|
||||
}
|
||||
|
||||
// 0 <= k <= N
|
||||
// SUM x_i >= k
|
||||
// <=>
|
||||
// SUM ~x_i <= N - k
|
||||
// suppose k > N/2, then it is better to solve dual.
|
||||
|
||||
bool dualize(unsigned& k, unsigned N, literal const* xs, literal_vector& in) {
|
||||
SASSERT(0 <= k && k <= N);
|
||||
if (2*k <= N) {
|
||||
return false;
|
||||
}
|
||||
k = N - k;
|
||||
for (unsigned i = 0; i < N; ++i) {
|
||||
in.push_back(~xs[i]);
|
||||
}
|
||||
TRACE("pb",
|
||||
pp(tout << N << ": ", in);
|
||||
tout << " ~ " << k << "\n";);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
bool even(unsigned n) const { return (0 == (n & 0x1)); }
|
||||
bool odd(unsigned n) const { return !even(n); }
|
||||
|
||||
unsigned ceil2(unsigned n) const { return n/2 + odd(n); }
|
||||
unsigned floor2(unsigned n) const { return n/2; }
|
||||
unsigned power2(unsigned n) const { SASSERT(n < 10); return 1 << n; }
|
||||
|
||||
literal max(literal a, literal b) {
|
||||
if (a == b) return a;
|
||||
m_stats.m_num_compiled_vars++;
|
||||
ast_manager& m = ctx.get_manager();
|
||||
expr_ref t1(m), t2(m), t3(m);
|
||||
ctx.literal2expr(a, t1);
|
||||
ctx.literal2expr(b, t2);
|
||||
t3 = m.mk_or(t1, t2);
|
||||
bool_var v = ctx.b_internalized(t3)?ctx.get_bool_var(t3):ctx.mk_bool_var(t3);
|
||||
return literal(v);
|
||||
}
|
||||
|
||||
literal min(literal a, literal b) {
|
||||
if (a == b) return a;
|
||||
m_stats.m_num_compiled_vars++;
|
||||
ast_manager& m = ctx.get_manager();
|
||||
expr_ref t1(m), t2(m), t3(m);
|
||||
ctx.literal2expr(a, t1);
|
||||
ctx.literal2expr(b, t2);
|
||||
t3 = m.mk_and(t1, t2);
|
||||
bool_var v = ctx.b_internalized(t3)?ctx.get_bool_var(t3):ctx.mk_bool_var(t3);
|
||||
return literal(v);
|
||||
}
|
||||
|
||||
literal fresh() {
|
||||
m_stats.m_num_compiled_vars++;
|
||||
ast_manager& m = ctx.get_manager();
|
||||
app_ref y(m);
|
||||
y = m.mk_fresh_const("y", m.mk_bool_sort());
|
||||
return literal(ctx.mk_bool_var(y));
|
||||
}
|
||||
void add_clause(literal l1, literal l2, literal l3) {
|
||||
literal lits[3] = { l1, l2, l3 };
|
||||
add_clause(3, lits);
|
||||
}
|
||||
void add_clause(literal l1, literal l2) {
|
||||
literal lits[2] = { l1, l2 };
|
||||
add_clause(2, lits);
|
||||
}
|
||||
void add_clause(unsigned n, literal const* ls) {
|
||||
m_stats.m_num_compiled_clauses++;
|
||||
literal_vector tmp(n, ls);
|
||||
TRACE("pb", pp(tout, n, ls) << "\n";);
|
||||
ctx.mk_clause(n, tmp.c_ptr(), 0, CLS_AUX, 0);
|
||||
}
|
||||
|
||||
// y1 <= max(x1,x2)
|
||||
// y2 <= min(x1,x2)
|
||||
void cmp_ge(literal x1, literal x2, literal y1, literal y2) {
|
||||
add_clause(~y2, x1);
|
||||
add_clause(~y2, x2);
|
||||
add_clause(x1, x2, ~y1);
|
||||
add_clause(~y1, x1, x2);
|
||||
}
|
||||
|
||||
// max(x1,x2) <= y1
|
||||
// min(x1,x2) <= y2
|
||||
void cmp_le(literal x1, literal x2, literal y1, literal y2) {
|
||||
add_clause(~x1, y1);
|
||||
add_clause(~x2, y1);
|
||||
|
@ -111,21 +268,57 @@ namespace smt {
|
|||
case EQ: cmp_eq(x1, x2, y1, y2); break;
|
||||
}
|
||||
}
|
||||
|
||||
void size_merge(unsigned a, unsigned b,
|
||||
unsigned& vD, unsigned& cD,
|
||||
unsigned& vR, unsigned& cR) {
|
||||
vD = a + b;
|
||||
cD = a*b + a + b;
|
||||
|
||||
vc vc_cmp() {
|
||||
return vc(2, (m_t==EQ)?6:3);
|
||||
}
|
||||
|
||||
void card(unsigned k, unsigned n, literal const* xs, literal_vector& out) {
|
||||
TRACE("pb", tout << "card k:" << k << " n: " << n << "\n";);
|
||||
if (n <= k) {
|
||||
sorting(n, xs, out);
|
||||
}
|
||||
else if (use_dcard(k, n)) {
|
||||
dsorting(k, n, xs, out);
|
||||
}
|
||||
else {
|
||||
literal_vector out1, out2;
|
||||
unsigned l = n/2; // TBD
|
||||
card(k, l, xs, out1);
|
||||
card(k, n-l, xs + l, out2);
|
||||
smerge(k, out1.size(), out1.c_ptr(), out2.size(), out2.c_ptr(), out);
|
||||
}
|
||||
TRACE("pb", tout << "card k:" << k << " n: " << n << "\n";
|
||||
pp(tout << "in:", n, xs) << "\n";
|
||||
pp(tout << "out:", out) << "\n";);
|
||||
|
||||
}
|
||||
vc vc_card(unsigned k, unsigned n) {
|
||||
if (n <= k) {
|
||||
return vc_sorting(n);
|
||||
}
|
||||
else if (use_dcard(k, n)) {
|
||||
return vc_dsorting(k, n);
|
||||
}
|
||||
else {
|
||||
return vc_card_rec(k, n);
|
||||
}
|
||||
}
|
||||
vc vc_card_rec(unsigned k, unsigned n) {
|
||||
unsigned l = n/2;
|
||||
return vc_card(k, l) + vc_card(k, n-l) + vc_smerge(k, l, n-l);
|
||||
}
|
||||
bool use_dcard(unsigned k, unsigned n) {
|
||||
return m_force_dcard || (!m_disable_dcard && n < 10 && vc_dsorting(k, n) < vc_card_rec(k, n));
|
||||
}
|
||||
|
||||
|
||||
void merge(unsigned a, literal const* as,
|
||||
unsigned b, literal const* bs,
|
||||
literal_vector& out) {
|
||||
TRACE("pb", tout << "merge a: " << a << " b: " << b << "\n";);
|
||||
if (a == 1 && b == 1) {
|
||||
y1 = fresh();
|
||||
y2 = fresh();
|
||||
literal y1 = max(as[0], bs[0]);
|
||||
literal y2 = min(as[0], bs[0]);
|
||||
out.push_back(y1);
|
||||
out.push_back(y2);
|
||||
cmp(as[0], bs[0], y1, y2);
|
||||
|
@ -134,27 +327,98 @@ namespace smt {
|
|||
out.append(b, bs);
|
||||
}
|
||||
else if (b == 0) {
|
||||
merge(b, bs, a, as,out);
|
||||
out.append(a, as);
|
||||
}
|
||||
else if (use_dsmerge(a, b, a + b)) {
|
||||
dsmerge(a, as, b, bs, a + b, out);
|
||||
dsmerge(a + b, a, as, b, bs, out);
|
||||
}
|
||||
else if (even(a) && even(b) && b > 0) {
|
||||
|
||||
}
|
||||
else if (even(a) && odd(b) && (a > 1 || b > 1)) {
|
||||
|
||||
}
|
||||
else if (odd(a) && odd(b) && (a > 1 || b > 1)) {
|
||||
|
||||
else if (even(a) && odd(b)) {
|
||||
merge(b, bs, a, as, out);
|
||||
}
|
||||
else {
|
||||
merge(b, bs, a, as, out);
|
||||
literal_vector even_a, odd_a;
|
||||
literal_vector even_b, odd_b;
|
||||
literal_vector out1, out2;
|
||||
SASSERT(a > 1 || b > 1);
|
||||
split(a, as, even_a, odd_a);
|
||||
split(b, bs, even_b, odd_b);
|
||||
SASSERT(!even_a.empty());
|
||||
SASSERT(!even_b.empty());
|
||||
merge(even_a.size(), even_a.c_ptr(),
|
||||
even_b.size(), even_b.c_ptr(), out1);
|
||||
merge(odd_a.size(), odd_a.c_ptr(),
|
||||
odd_b.size(), odd_b.c_ptr(), out2);
|
||||
interleave(out1, out2, out);
|
||||
}
|
||||
TRACE("pb", tout << "merge a: " << a << " b: " << b << "\n";
|
||||
pp(tout << "a:", a, as) << "\n";
|
||||
pp(tout << "b:", b, bs) << "\n";
|
||||
pp(tout << "out:", out) << "\n";);
|
||||
}
|
||||
vc vc_merge(unsigned a, unsigned b) {
|
||||
if (a == 1 && b == 1) {
|
||||
return vc_cmp();
|
||||
}
|
||||
else if (a == 0 || b == 0) {
|
||||
return vc(0, 0);
|
||||
}
|
||||
else if (use_dsmerge(a, b, a + b)) {
|
||||
return vc_dsmerge(a, b, a + b);
|
||||
}
|
||||
else {
|
||||
return vc_merge_rec(a, b);
|
||||
}
|
||||
}
|
||||
vc vc_merge_rec(unsigned a, unsigned b) {
|
||||
return
|
||||
vc_merge(ceil2(a), ceil2(b)) +
|
||||
vc_merge(floor2(a), floor2(b)) +
|
||||
vc_interleave(ceil2(a) + ceil2(b), floor2(a) + floor2(b));
|
||||
}
|
||||
void split(unsigned n, literal const* ls, literal_vector& even, literal_vector& odd) {
|
||||
for (unsigned i = 0; i < n; i += 2) {
|
||||
even.push_back(ls[i]);
|
||||
}
|
||||
for (unsigned i = 1; i < n; i += 2) {
|
||||
odd.push_back(ls[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void interleave(literal_vector const& as,
|
||||
literal_vector const& bs,
|
||||
literal_vector& out) {
|
||||
TRACE("pb", tout << "interleave: " << as.size() << " " << bs.size() << "\n";);
|
||||
SASSERT(as.size() >= bs.size());
|
||||
SASSERT(as.size() <= bs.size() + 2);
|
||||
SASSERT(!as.empty());
|
||||
out.push_back(as[0]);
|
||||
unsigned sz = std::min(as.size()-1, bs.size());
|
||||
for (unsigned i = 0; i < sz; ++i) {
|
||||
literal y1 = max(as[i+1],bs[i]);
|
||||
literal y2 = min(as[i+1],bs[i]);
|
||||
cmp(as[i+1], bs[i], y1, y2);
|
||||
out.push_back(y1);
|
||||
out.push_back(y2);
|
||||
}
|
||||
if (as.size() == bs.size()) {
|
||||
out.push_back(bs[sz]);
|
||||
}
|
||||
else if (as.size() == bs.size() + 2) {
|
||||
out.push_back(as[sz+1]);
|
||||
}
|
||||
SASSERT(out.size() == as.size() + bs.size());
|
||||
TRACE("pb", tout << "interleave: " << as.size() << " " << bs.size() << "\n";
|
||||
pp(tout << "a: ", as) << "\n";
|
||||
pp(tout << "b: ", bs) << "\n";
|
||||
pp(tout << "out: ", out) << "\n";);
|
||||
|
||||
}
|
||||
vc vc_interleave(unsigned a, unsigned b) {
|
||||
return vc_cmp()*std::min(a-1,b);
|
||||
}
|
||||
|
||||
void sorting(cmp_t t, unsigned n, literal const* xs,
|
||||
literal_vector& out) {
|
||||
void sorting(unsigned n, literal const* xs, literal_vector& out) {
|
||||
TRACE("pb", tout << "sorting: " << n << "\n";);
|
||||
switch(n) {
|
||||
case 0:
|
||||
break;
|
||||
|
@ -179,94 +443,270 @@ namespace smt {
|
|||
}
|
||||
break;
|
||||
}
|
||||
TRACE("pb", tout << "sorting: " << n << "\n";
|
||||
pp(tout << "in:", n, xs) << "\n";
|
||||
pp(tout << "out:", out) << "\n";);
|
||||
|
||||
}
|
||||
vc vc_sorting(unsigned n) {
|
||||
switch(n) {
|
||||
case 0: return vc(0,0);
|
||||
case 1: return vc(0,0);
|
||||
case 2: return vc_merge(1,1);
|
||||
default:
|
||||
if (use_dsorting(n)) {
|
||||
return vc_dsorting(n, n);
|
||||
}
|
||||
else {
|
||||
return vc_sorting_rec(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
vc vc_sorting_rec(unsigned n) {
|
||||
SASSERT(n > 2);
|
||||
unsigned l = n/2;
|
||||
return vc_sorting(l) + vc_sorting(n-l) + vc_merge(l, n-l);
|
||||
}
|
||||
|
||||
bool use_dsmerge(unsigned a, unsigned b, unsigned c) const {
|
||||
return false;
|
||||
bool use_dsorting(unsigned n) {
|
||||
SASSERT(n > 2);
|
||||
return m_force_dsorting ||
|
||||
(!m_disable_dsorting && n < 10 && vc_dsorting(n, n) < vc_sorting_rec(n));
|
||||
}
|
||||
|
||||
void smerge(unsigned a, literal const* as,
|
||||
void smerge(unsigned c,
|
||||
unsigned a, literal const* as,
|
||||
unsigned b, literal const* bs,
|
||||
unsigned c,
|
||||
literal_vector& out) {
|
||||
TRACE("pb", tout << "smerge: c:" << c << " a:" << a << " b:" << b << "\n";);
|
||||
if (a == 1 && b == 1 && c == 1) {
|
||||
literal y = fresh();
|
||||
add_clause(~as[0], y);
|
||||
add_clause(~bs[0], y);
|
||||
literal y = max(as[0], bs[0]);
|
||||
if (m_t != GE) {
|
||||
// x1 <= max(x1,x2)
|
||||
// x2 <= max(x1,x2)
|
||||
add_clause(~as[0], y);
|
||||
add_clause(~bs[0], y);
|
||||
}
|
||||
if (m_t != LE) {
|
||||
// max(x1,x2) <= x1, x2
|
||||
add_clause(~y, as[0], bs[0]);
|
||||
}
|
||||
out.push_back(y);
|
||||
}
|
||||
else if (a == 0) {
|
||||
out.append(std::min(c, b), bs);
|
||||
}
|
||||
else if (b == 0) {
|
||||
out.append(std::min(c, a), as);
|
||||
}
|
||||
else if (a > c) {
|
||||
smerge(a - c, as, b, bs, c, out);
|
||||
smerge(c, c, as, b, bs, out);
|
||||
}
|
||||
else if (b > c) {
|
||||
smerge(a, as, b - c, bs, c, out);
|
||||
smerge(c, a, as, c, bs, out);
|
||||
}
|
||||
else if (a + b <= c) {
|
||||
merge(a, as, b, bs, out);
|
||||
}
|
||||
else if (use_dsmerge(a, b, c)) {
|
||||
dsmerge(a, as, b, bs, c, out);
|
||||
dsmerge(c, a, as, b, bs, out);
|
||||
}
|
||||
else if (even(c)) {
|
||||
|
||||
}
|
||||
else if (odd(c)) {
|
||||
SASSERT(c > 1);
|
||||
else {
|
||||
literal_vector even_a, odd_a;
|
||||
literal_vector even_b, odd_b;
|
||||
literal_vector out1, out2;
|
||||
split(a, as, even_a, odd_a);
|
||||
split(b, bs, even_b, odd_b);
|
||||
SASSERT(!even_a.empty());
|
||||
SASSERT(!even_b.empty());
|
||||
unsigned c1, c2;
|
||||
if (even(c)) {
|
||||
c1 = 1 + c/2; c2 = c/2;
|
||||
}
|
||||
else {
|
||||
c1 = (c + 1)/2; c2 = (c - 1)/2;
|
||||
}
|
||||
smerge(c1, even_a.size(), even_a.c_ptr(),
|
||||
even_b.size(), even_b.c_ptr(), out1);
|
||||
smerge(c2, odd_a.size(), odd_a.c_ptr(),
|
||||
odd_b.size(), odd_b.c_ptr(), out2);
|
||||
SASSERT(out1.size() == std::min(even_a.size()+even_b.size(), c1));
|
||||
SASSERT(out2.size() == std::min(odd_a.size()+odd_b.size(), c2));
|
||||
literal y;
|
||||
if (even(c)) {
|
||||
literal z1 = out1.back();
|
||||
literal z2 = out2.back();
|
||||
out1.pop_back();
|
||||
out2.pop_back();
|
||||
y = max(z1, z2);
|
||||
if (m_t != GE) {
|
||||
add_clause(~z1, y);
|
||||
add_clause(~z2, y);
|
||||
}
|
||||
if (m_t != LE) {
|
||||
add_clause(~y, z1, z2);
|
||||
}
|
||||
}
|
||||
interleave(out1, out2, out);
|
||||
if (even(c)) {
|
||||
out.push_back(y);
|
||||
}
|
||||
}
|
||||
TRACE("pb", tout << "smerge: c:" << c << " a:" << a << " b:" << b << "\n";
|
||||
pp(tout << "a:", a, as) << "\n";
|
||||
pp(tout << "b:", b, bs) << "\n";
|
||||
pp(tout << "out:", out) << "\n";
|
||||
);
|
||||
SASSERT(out.size() == std::min(a + b, c));
|
||||
}
|
||||
|
||||
vc vc_smerge(unsigned a, unsigned b, unsigned c) {
|
||||
if (a == 1 && b == 1 && c == 1) {
|
||||
vc v(1,0);
|
||||
if (m_t != GE) v = v + vc(0, 2);
|
||||
if (m_t != LE) v = v + vc(0, 1);
|
||||
return v;
|
||||
}
|
||||
if (a == 0 || b == 0) return vc(0, 0);
|
||||
if (a > c) return vc_smerge(c, b, c);
|
||||
if (b > c) return vc_smerge(a, c, c);
|
||||
if (a + b <= c) return vc_merge(a, b);
|
||||
if (use_dsmerge(a, b, c)) return vc_dsmerge(a, b, c);
|
||||
return vc_smerge_rec(a, b, c);
|
||||
}
|
||||
vc vc_smerge_rec(unsigned a, unsigned b, unsigned c) {
|
||||
return
|
||||
vc_smerge(ceil2(a), ceil2(b), even(c)?(1+c/2):((c+1)/2)) +
|
||||
vc_smerge(floor2(a), floor2(b), even(c)?(c/2):((c-1)/2)) +
|
||||
vc_interleave(ceil2(a)+ceil2(b),floor2(a)+floor2(b)) +
|
||||
vc(1, 0) +
|
||||
((m_t != GE)?vc(0, 2):vc(0, 0)) +
|
||||
((m_t != LE)?vc(0, 1):vc(0, 0));
|
||||
}
|
||||
bool use_dsmerge(unsigned a, unsigned b, unsigned c) {
|
||||
return
|
||||
m_force_dsmerge ||
|
||||
(!m_disable_dsmerge &&
|
||||
a < (1 << 15) && b < (1 << 15) &&
|
||||
vc_dsmerge(a, b, a + b) < vc_smerge_rec(a, b, c));
|
||||
}
|
||||
|
||||
void dsmerge(
|
||||
unsigned c,
|
||||
unsigned a, literal const* as,
|
||||
unsigned b, literal const* bs,
|
||||
unsigned c,
|
||||
literal_vector& out) {
|
||||
TRACE("pb", tout << "dsmerge: c:" << c << " a:" << a << " b:" << b << "\n";);
|
||||
SASSERT(a <= c);
|
||||
SASSERT(b <= c);
|
||||
SASSERT(a + b > c);
|
||||
for (unsigned i = 0; i < c; ++i) {
|
||||
out.push_back(fresh());
|
||||
}
|
||||
for (unsigned i = 0; i < a; ++i) {
|
||||
add_clause(~as[i],out[i]);
|
||||
if (m_t != GE) {
|
||||
for (unsigned i = 0; i < a; ++i) {
|
||||
add_clause(~as[i], out[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < b; ++i) {
|
||||
add_clause(~bs[i], out[i]);
|
||||
}
|
||||
for (unsigned i = 1; i <= a; ++i) {
|
||||
for (unsigned j = 1; j <= b && i + j <= c; ++j) {
|
||||
add_clause(~as[i-1],~bs[j-1],out[i+j-1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < b; ++i) {
|
||||
add_clause(~bs[i],out[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < a; ++i) {
|
||||
for (unsigned j = 0; j < b && i + j < c; ++j) {
|
||||
add_clause(~as[i],~bs[j],out[i+j]);
|
||||
if (m_t != LE) {
|
||||
for (unsigned k = 1; k <= c; ++k) {
|
||||
literal_vector ls;
|
||||
ls.push_back(~out[k-1]);
|
||||
if (k <= a) {
|
||||
ls.push_back(as[k-1]);
|
||||
}
|
||||
if (k <= b) {
|
||||
ls.push_back(bs[k-1]);
|
||||
}
|
||||
for (unsigned i = 1; i <= std::min(a,k-1); ++i) {
|
||||
if (k + 1 - i <= b) {
|
||||
ls.push_back(as[i-1]);
|
||||
ls.push_back(bs[k-i]);
|
||||
add_clause(ls.size(), ls.c_ptr());
|
||||
ls.pop_back();
|
||||
ls.pop_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
vc vc_dsmerge(unsigned a, unsigned b, unsigned c) {
|
||||
vc v(c, 0);
|
||||
if (m_t != GE) {
|
||||
v = v + vc(0, a + b + std::min(a, c)*std::min(b, c)/2);
|
||||
}
|
||||
if (m_t != LE) {
|
||||
v = v + vc(0, std::min(a, c)*std::min(b, c)/2);
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
|
||||
void dsorting(unsigned m, unsigned n, literal const* xs,
|
||||
literal_vector& out) {
|
||||
TRACE("pb", tout << "dsorting m: " << m << " n: " << n << "\n";);
|
||||
SASSERT(m <= n);
|
||||
literal_vector lits;
|
||||
for (unsigned i = 0; i < m; ++i) {
|
||||
out.push_back(fresh());
|
||||
}
|
||||
for (unsigned k = 0; k < m; ++k) {
|
||||
literal_vector lits;
|
||||
lits.push_back(out[k]);
|
||||
add_subset(k+1, 0, lits, n, xs);
|
||||
if (m_t != GE) {
|
||||
for (unsigned k = 1; k <= m; ++k) {
|
||||
lits.push_back(out[k-1]);
|
||||
add_subset(true, k, 0, lits, n, xs);
|
||||
lits.pop_back();
|
||||
}
|
||||
}
|
||||
if (m_t != LE) {
|
||||
for (unsigned k = 1; k <= m; ++k) {
|
||||
lits.push_back(~out[k-1]);
|
||||
add_subset(false, n-k+1, 0, lits, n, xs);
|
||||
lits.pop_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
vc vc_dsorting(unsigned m, unsigned n) {
|
||||
SASSERT(m <= n && n < 10);
|
||||
vc v(m, 0);
|
||||
if (m_t != GE) {
|
||||
v = v + vc(0, power2(n-1));
|
||||
}
|
||||
if (m_t != LE) {
|
||||
v = v + vc(0, power2(n-1));
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
void add_subset(unsigned k, unsigned offset, literal_vector& lits,
|
||||
void add_subset(bool polarity, unsigned k, unsigned offset, literal_vector& lits,
|
||||
unsigned n, literal const* xs) {
|
||||
SASSERT(k + offset < n);
|
||||
TRACE("pb", tout << "k:" << k << " offset: " << offset << " n: " << n << " ";
|
||||
pp(tout, lits) << "\n";);
|
||||
SASSERT(k + offset <= n);
|
||||
if (k == 0) {
|
||||
ctx.add_clause(lits.size(), lits.c_ptr());
|
||||
add_clause(lits.size(), lits.c_ptr());
|
||||
return;
|
||||
}
|
||||
for (unsigned i = offset; i < n-offset-k; ++i) {
|
||||
lits.push_back(xs[i]);
|
||||
add_subset(k-1, i+1, lits, n, xs);
|
||||
for (unsigned i = offset; i < n - k + 1; ++i) {
|
||||
lits.push_back(polarity?~xs[i]:xs[i]);
|
||||
add_subset(polarity, k-1, i+1, lits, n, xs);
|
||||
lits.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
#endif
|
||||
// for testing
|
||||
literal theory_pb::assert_ge(context& ctx, unsigned k, unsigned n, literal const* xs) {
|
||||
psort_nw sort(ctx);
|
||||
return sort.ge(false, k, n, xs);
|
||||
}
|
||||
|
||||
class pb_lit_rewriter_util {
|
||||
public:
|
||||
|
@ -448,9 +888,6 @@ namespace smt {
|
|||
case l_undef:
|
||||
break;
|
||||
}
|
||||
|
||||
#if 1
|
||||
// TBD: special cases: k == 1, or args.size() == 1
|
||||
|
||||
if (c->k().is_one()) {
|
||||
literal_vector& lits = get_lits();
|
||||
|
@ -463,7 +900,6 @@ namespace smt {
|
|||
ctx.mk_th_axiom(get_id(), lits.size(), lits.c_ptr());
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
// maximal coefficient:
|
||||
numeral& max_watch = c->m_max_watch;
|
||||
|
@ -516,18 +952,6 @@ namespace smt {
|
|||
ctx.set_var_theory(bv, get_id());
|
||||
}
|
||||
has_bv = (ctx.get_var_theory(bv) == get_id());
|
||||
#if 0
|
||||
TBD:
|
||||
if (!has_bv) {
|
||||
if (!ctx.e_internalized(arg)) {
|
||||
ctx.internalize(arg, false);
|
||||
SASSERT(ctx.e_internalized(arg));
|
||||
}
|
||||
enode* n = ctx.get_enode(arg);
|
||||
theory_var v = mk_var(n);
|
||||
ctx.attach_th_var(n, this, v);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if (m.is_true(arg)) {
|
||||
bv = true_bool_var;
|
||||
|
@ -997,8 +1421,38 @@ namespace smt {
|
|||
unsigned k = c.k().get_unsigned();
|
||||
unsigned num_args = c.size();
|
||||
|
||||
sort_expr se(*this);
|
||||
sorting_network<sort_expr> sn(se);
|
||||
|
||||
literal thl = c.lit();
|
||||
literal at_least_k;
|
||||
|
||||
#if 1
|
||||
literal_vector in;
|
||||
for (unsigned i = 0; i < num_args; ++i) {
|
||||
rational n = c.coeff(i);
|
||||
while (n.is_pos()) {
|
||||
in.push_back(c.lit(i));
|
||||
n -= rational::one();
|
||||
}
|
||||
}
|
||||
if (ctx.get_assignment(thl) == l_true &&
|
||||
ctx.get_assign_level(thl) == ctx.get_base_level()) {
|
||||
psort_nw sortnw(ctx);
|
||||
sortnw.m_stats.reset();
|
||||
at_least_k = sortnw.ge(false, k, in.size(), in.c_ptr());
|
||||
ctx.mk_clause(~thl, at_least_k, 0);
|
||||
m_stats.m_num_compiled_vars += sortnw.m_stats.m_num_compiled_vars;
|
||||
m_stats.m_num_compiled_clauses += sortnw.m_stats.m_num_compiled_clauses;
|
||||
}
|
||||
else {
|
||||
psort_nw sortnw(ctx);
|
||||
sortnw.m_stats.reset();
|
||||
literal at_least_k = sortnw.ge(true, k, in.size(), in.c_ptr());
|
||||
ctx.mk_clause(~thl, at_least_k, 0);
|
||||
ctx.mk_clause(~at_least_k, thl, 0);
|
||||
m_stats.m_num_compiled_vars += sortnw.m_stats.m_num_compiled_vars;
|
||||
m_stats.m_num_compiled_clauses += sortnw.m_stats.m_num_compiled_clauses;
|
||||
}
|
||||
#else
|
||||
expr_ref_vector in(m), out(m);
|
||||
expr_ref tmp(m);
|
||||
for (unsigned i = 0; i < num_args; ++i) {
|
||||
|
@ -1009,15 +1463,20 @@ namespace smt {
|
|||
n -= rational::one();
|
||||
}
|
||||
}
|
||||
IF_VERBOSE(1, verbose_stream() << "(compile " << k << ")\n";);
|
||||
sort_expr se(*this);
|
||||
sorting_network<sort_expr> sn(se);
|
||||
sn(in, out);
|
||||
literal at_least_k = se.internalize(c, out[k-1].get()); // first k outputs are 1.
|
||||
at_least_k = se.internalize(c, out[k-1].get()); // first k outputs are 1.
|
||||
TRACE("pb", tout << "at_least: " << mk_pp(out[k-1].get(), m) << "\n";);
|
||||
|
||||
literal thl = c.lit();
|
||||
se.add_clause(~thl, at_least_k);
|
||||
se.add_clause(thl, ~at_least_k);
|
||||
TRACE("pb", tout << c.lit() << "\n";);
|
||||
#endif
|
||||
IF_VERBOSE(1, verbose_stream()
|
||||
<< "(smt.pb compile sorting network bound: "
|
||||
<< k << " literals: " << in.size() << ")\n";);
|
||||
|
||||
TRACE("pb", tout << thl << "\n";);
|
||||
// auxiliary clauses get removed when popping scopes.
|
||||
// we have to recompile the circuit after back-tracking.
|
||||
c.m_compiled = l_false;
|
||||
|
@ -1125,11 +1584,6 @@ namespace smt {
|
|||
}
|
||||
|
||||
ctx.mk_clause(lits.size(), lits.c_ptr(), js, CLS_AUX_LEMMA, 0);
|
||||
|
||||
// if (true || (c.m_num_propagations & 0xF) == 0) {
|
||||
// resolve_conflict(c);
|
||||
//}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -196,5 +196,7 @@ namespace smt {
|
|||
|
||||
void set_conflict_frequency(unsigned f) { m_conflict_frequency = f; }
|
||||
void set_learn_complements(bool l) { m_learn_complements = l; }
|
||||
|
||||
static literal assert_ge(context& ctx, unsigned k, unsigned n, literal const* xs);
|
||||
};
|
||||
};
|
||||
|
|
|
@ -217,6 +217,7 @@ int main(int argc, char ** argv) {
|
|||
TST(qe_arith);
|
||||
TST(expr_substitution);
|
||||
TST(sorting_network);
|
||||
TST(theory_pb);
|
||||
}
|
||||
|
||||
void initialize_mam() {}
|
||||
|
|
65
src/test/theory_pb.cpp
Normal file
65
src/test/theory_pb.cpp
Normal file
|
@ -0,0 +1,65 @@
|
|||
#include "smt_context.h"
|
||||
#include "ast_pp.h"
|
||||
#include "model_v2_pp.h"
|
||||
#include "reg_decl_plugins.h"
|
||||
#include "theory_pb.h"
|
||||
|
||||
unsigned populate_literals(unsigned k, smt::literal_vector& lits) {
|
||||
SASSERT(k < (1u << lits.size()));
|
||||
unsigned t = 0;
|
||||
for (unsigned i = 0; i < lits.size(); ++i) {
|
||||
if (k & (1 << i)) {
|
||||
lits[i] = smt::true_literal;
|
||||
t++;
|
||||
}
|
||||
else {
|
||||
lits[i] = smt::false_literal;
|
||||
}
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
void tst_theory_pb() {
|
||||
ast_manager m;
|
||||
smt_params params;
|
||||
params.m_model = true;
|
||||
reg_decl_plugins(m);
|
||||
expr_ref tmp(m);
|
||||
|
||||
enable_trace("pb");
|
||||
for (unsigned N = 4; N < 11; ++N) {
|
||||
for (unsigned i = 0; i < (1u << N); ++i) {
|
||||
smt::literal_vector lits(N, smt::false_literal);
|
||||
unsigned k = populate_literals(i, lits);
|
||||
std::cout << "k:" << k << " " << N << "\n";
|
||||
std::cout.flush();
|
||||
TRACE("pb", tout << "k " << k << ": ";
|
||||
for (unsigned j = 0; j < lits.size(); ++j) {
|
||||
tout << lits[j] << " ";
|
||||
}
|
||||
tout << "\n";);
|
||||
{
|
||||
smt::context ctx(m, params);
|
||||
ctx.push();
|
||||
smt::literal l = smt::theory_pb::assert_ge(ctx, k+1, lits.size(), lits.c_ptr());
|
||||
if (l != smt::false_literal) {
|
||||
ctx.assign(l, 0, false);
|
||||
TRACE("pb", tout << "assign: " << l << "\n";
|
||||
ctx.display(tout););
|
||||
VERIFY(l_false == ctx.check());
|
||||
}
|
||||
ctx.pop(1);
|
||||
}
|
||||
{
|
||||
smt::context ctx(m, params);
|
||||
ctx.push();
|
||||
smt::literal l = smt::theory_pb::assert_ge(ctx, k, lits.size(), lits.c_ptr());
|
||||
SASSERT(l != smt::false_literal);
|
||||
ctx.assign(l, 0, false);
|
||||
TRACE("pb", ctx.display(tout););
|
||||
VERIFY(l_true == ctx.check());
|
||||
ctx.pop(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue