From 3ee8c3efb5923795afa672ae3225ebbe89eb8c42 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 7 Nov 2013 00:53:08 -0800 Subject: [PATCH] pb/car constraints Signed-off-by: Nikolaj Bjorner --- src/smt/theory_card.cpp | 553 ++++++++++++++------------- src/smt/theory_card.h | 29 +- src/tactic/arith/lia2card_tactic.cpp | 50 ++- src/test/main.cpp | 1 + src/test/sorting_network.cpp | 60 +++ src/util/sorting_network.h | 99 +++++ 6 files changed, 494 insertions(+), 298 deletions(-) create mode 100644 src/test/sorting_network.cpp create mode 100644 src/util/sorting_network.h diff --git a/src/smt/theory_card.cpp b/src/smt/theory_card.cpp index edae63108..080cbf55e 100644 --- a/src/smt/theory_card.cpp +++ b/src/smt/theory_card.cpp @@ -97,25 +97,9 @@ namespace smt { ctx.mk_th_axiom(get_id(), 1, &lit); ctx.mark_as_relevant(tmp); } - c->m_args.push_back(bv); - if (0 < k) { - add_watch(bv, c); - } - } - if (0 < k) { - add_card(c); - } - else { - // bv <=> (and (not bv1) ... (not bv_n)) - literal_vector& lits = get_lits(); - lits.push_back(literal(abv)); - for (unsigned i = 0; i < c->m_args.size(); ++i) { - ctx.mk_th_axiom(get_id(), ~literal(abv), ~literal(c->m_args[i])); - lits.push_back(literal(c->m_args[i])); - } - ctx.mk_th_axiom(get_id(), lits.size(), lits.c_ptr()); - dealloc(c); + c->m_args.push_back(std::make_pair(bv,1)); } + add_card(c); return true; } @@ -128,6 +112,49 @@ namespace smt { cards->push_back(c); m_watch_trail.push_back(bv); } + + void theory_card::add_card(card* c) { + bool_var abv = c->m_bv; + arg_t& args = c->m_args; + + // sort and coalesce arguments: + std::sort(args.begin(), args.end()); + for (unsigned i = 0; i + 1 < args.size(); ++i) { + if (args[i].first == args[i+1].first) { + args[i].second += args[i+1].second; + for (unsigned j = i+1; j + 1 < args.size(); ++j) { + args[j] = args[j+1]; + } + args.resize(args.size()-1); + } + if (args[i].second == 0) { + for (unsigned j = i; j + 1 < args.size(); ++j) { + args[j] = args[j+1]; + } + args.resize(args.size()-1); + } + } + + int min = 0, max = 0; + for (unsigned i = 0; i < args.size(); ++i) { + // update min and max: + int inc = args[i].second; + if (inc > 0) { + max += inc; + } + else { + SASSERT(inc < 0); + min += inc; + } + // add watch literals: + add_watch(args[i].first, c); + } + c->m_current_min = c->m_abs_min = min; + c->m_current_max = c->m_abs_max = max; + m_cards.insert(abv, c); + m_cards_trail.push_back(abv); + } + void theory_card::reset_eh() { @@ -149,6 +176,172 @@ namespace smt { m_watch_lim.reset(); } + void theory_card::update_min_max(bool_var v, bool is_true, card* c) { + context& ctx = get_context(); + ast_manager& m = get_manager(); + arg_t const& args = c->m_args; + int inc = find_inc(v, args); + int& min = c->m_current_min; + int& max = c->m_current_max; + int k = c->m_k; + // inc > 0 & is_true -> min += inc + // inc < 0 & is_true -> max += inc + // inc > 0 & !is_true -> max -= inc + // inc < 0 & !is_true -> min -= inc + + if (inc > 0 && is_true) { + ctx.push_trail(value_trail(min)); + min += inc; + } + else if (inc < 0 && is_true) { + ctx.push_trail(value_trail(max)); + max += inc; + } + else if (inc > 0 && !is_true) { + ctx.push_trail(value_trail(max)); + max -= inc; + } + else { + ctx.push_trail(value_trail(min)); + min -= inc; + } + // invariant min <= max + SASSERT(min <= max); + } + + void theory_card::assign_use(bool_var v, bool is_true, card* c) { + update_min_max(v, is_true, c); + propagate_assignment(c); + } + + lbool theory_card::inc_min(int inc, lbool val) { + if (inc > 0) { + return val; + } + else if (inc < 0) { + return ~val; + } + else { + return l_undef; + } + } + + lbool theory_card::dec_max(int inc, lbool val) { + if (inc > 0) { + return ~val; + } + else if (inc < 0) { + return val; + } + else { + return l_undef; + } + } + + int theory_card::accumulate_min(literal_vector& lits, card* c) { + context& ctx = get_context(); + int k = c->m_k; + arg_t const& args = c->m_args; + int curr_min = c->m_abs_min; + for (unsigned i = 0; i < args.size() && curr_min <= k; ++i) { + bool_var bv = args[i].first; + int inc = args[i].second; + lbool val = ctx.get_assignment(bv); + if (inc_min(inc, val) == l_true) { + curr_min += abs(inc); + lits.push_back(literal(bv, val != l_true)); + } + } + return curr_min; + } + + int theory_card::accumulate_max(literal_vector& lits, card* c) { + context& ctx = get_context(); + arg_t const& args = c->m_args; + int k = c->m_k; + int curr_max = c->m_abs_max; + for (unsigned i = 0; i < args.size() && k < curr_max; ++i) { + bool_var bv = args[i].first; + int inc = args[i].second; + lbool val = ctx.get_assignment(bv); + if (dec_max(inc, val) == l_true) { + curr_max -= abs(inc); + lits.push_back(literal(bv, val == l_true)); + } + } + return curr_max; + } + + void theory_card::propagate_assignment(card* c) { + context& ctx = get_context(); + arg_t const& args = c->m_args; + bool_var abv = c->m_bv; + int min = c->m_current_min; + int max = c->m_current_max; + int k = c->m_k; + + // + // if min > k && abv != l_false -> force abv false + // if max <= k && abv != l_true -> force abv true + // if min == k && abv == l_true -> force positive unassigned literals false + // if max == k + 1 && abv == l_false -> force negative unassigned literals false + // + lbool aval = ctx.get_assignment(abv); + if (min > k && aval != l_false) { + literal_vector& lits = get_lits(); + lits.push_back(~literal(abv)); + int curr_min = accumulate_min(lits, c); + SASSERT(curr_min > k); + add_clause(lits); + } + else if (max <= k && aval != l_true) { + literal_vector& lits = get_lits(); + lits.push_back(literal(abv)); + int curr_max = accumulate_max(lits, c); + SASSERT(curr_max <= k); + add_clause(lits); + } + else if (min == k && aval == l_true) { + literal_vector& lits = get_lits(); + lits.push_back(~literal(abv)); + int curr_min = accumulate_min(lits, c); + if (curr_min > k) { + add_clause(lits); + } + else { + SASSERT(curr_min == k); + for (unsigned i = 0; i < args.size(); ++i) { + bool_var bv = args[i].first; + int inc = args[i].second; + if (inc_min(inc, ctx.get_assignment(bv)) == l_undef) { + lits.push_back(literal(bv, inc > 0)); // avoid incrementing min. + add_clause(lits); + lits.pop_back(); + } + } + } + } + else if (max == k + 1 && aval == l_false) { + literal_vector& lits = get_lits(); + lits.push_back(literal(abv)); + int curr_max = accumulate_max(lits, c); + if (curr_max <= k) { + add_clause(lits); + } + else if (curr_max == k + 1) { + for (unsigned i = 0; i < args.size(); ++i) { + bool_var bv = args[i].first; + int inc = args[i].second; + if (dec_max(inc, ctx.get_assignment(bv)) == l_undef) { + lits.push_back(literal(bv, inc < 0)); // avoid decrementing max. + add_clause(lits); + lits.pop_back(); + } + } + } + } + } + void theory_card::assign_eh(bool_var v, bool is_true) { context& ctx = get_context(); ast_manager& m = get_manager(); @@ -158,125 +351,33 @@ namespace smt { if (m_watch.find(v, cards)) { for (unsigned i = 0; i < cards->size(); ++i) { - c = (*cards)[i]; - svector const& args = c->m_args; - // - // is_true && m_t + 1 > k -> force false - // !is_true && m_f + 1 >= arity - k -> force true - // - if (is_true && c->m_t >= c->m_k) { - unsigned k = c->m_k; - // force false - switch (ctx.get_assignment(c->m_bv)) { - case l_true: - case l_undef: { - literal_vector& lits = get_lits(); - lits.push_back(~literal(c->m_bv)); - for (unsigned i = 0; i < args.size() && lits.size() < k + 1; ++i) { - if (ctx.get_assignment(args[i]) == l_true) { - lits.push_back(~literal(args[i])); - } - } - SASSERT(lits.size() == k + 1); - add_clause(lits); - break; - } - default: - break; - } - } - else if (!is_true && c->m_k >= args.size() - c->m_f - 1) { - // forced true - switch (ctx.get_assignment(c->m_bv)) { - case l_false: - case l_undef: { - unsigned deficit = args.size() - c->m_k; - literal_vector& lits = get_lits(); - lits.push_back(literal(c->m_bv)); - for (unsigned i = 0; i < args.size() && lits.size() <= deficit; ++i) { - if (ctx.get_assignment(args[i]) == l_false) { - lits.push_back(literal(args[i])); - } - } - add_clause(lits); - break; - } - default: - break; - } - } - else if (is_true) { - ctx.push_trail(value_trail(c->m_t)); - c->m_t++; - } - else { - ctx.push_trail(value_trail(c->m_f)); - c->m_f++; - } + assign_use(v, is_true, (*cards)[i]); } } if (m_cards.find(v, c)) { - svector const& args = c->m_args; - SASSERT(args.size() >= c->m_f + c->m_t); - bool_var bv; + propagate_assignment(c); + } + } - TRACE("card", tout << " t:" << is_true << " k:" << c->m_k << " t:" << c->m_t << " f:" << c->m_f << "\n";); - - // at most k - // propagate false to children that are not yet assigned. - // v & t1 & ... & tk => ~l_j - if (is_true && c->m_k <= c->m_t) { - - literal_vector& lits = get_lits(); - lits.push_back(literal(v)); - bool done = false; - for (unsigned i = 0; !done && i < args.size(); ++i) { - bv = args[i]; - if (ctx.get_assignment(bv) == l_true) { - lits.push_back(literal(bv)); - } - if (lits.size() > c->m_k + 1) { - add_clause(lits); - done = true; - } - } - SASSERT(done || lits.size() == c->m_k + 1); - for (unsigned i = 0; !done && i < args.size(); ++i) { - bv = args[i]; - if (ctx.get_assignment(bv) == l_undef) { - lits.push_back(literal(bv)); - add_clause(lits); - lits.pop_back(); - } - } + int theory_card::find_inc(bool_var bv, svector >const& vars) { + unsigned mid = vars.size()/2; + unsigned lo = 0; + unsigned hi = vars.size()-1; + while (lo < hi) { + if (vars[mid].first == bv) { + return vars[mid].second; } - // at least k+1: - // !v & !f1 & .. & !f_m => l_j - // for m + k + 1 = arity() - if (!is_true && args.size() <= 1 + c->m_f + c->m_k) { - literal_vector& lits = get_lits(); - lits.push_back(literal(v)); - bool done = false; - for (unsigned i = 0; !done && i < args.size(); ++i) { - bv = args[i]; - if (ctx.get_assignment(bv) == l_false) { - lits.push_back(literal(bv)); - } - if (lits.size() > c->m_k + 1) { - add_clause(lits); - done = true; - } - } - for (unsigned i = 0; !done && i < args.size(); ++i) { - bv = args[i]; - if (ctx.get_assignment(bv) != l_false) { - lits.push_back(~literal(bv)); - add_clause(lits); - lits.pop_back(); - } - } + else if (vars[mid].first < bv) { + lo = mid; + mid += (hi-mid)/2; + } + else { + hi = mid; + mid = (mid-lo)/2 + lo; } } + SASSERT(vars[mid].first == bv); + return vars[mid].second; } void theory_card::init_search_eh() { @@ -319,151 +420,79 @@ namespace smt { ctx.mk_th_axiom(get_id(), lits.size(), lits.c_ptr()); } + +#if 1 + + +#endif + } + + #if 0 -class sorting_network { - ast_manager& m; - expr_ref_vector m_es; - expr_ref_vector* m_current; - expr_ref_vector* m_next; - - void exchange(unsigned i, unsigned j, expr_ref_vector& es) { - SASSERT(i <= j); - if (i == j) { - return; - } - expr* ei = es[i].get(); - expr* ej = es[j].get(); - es[i] = m.mk_ite(mk_le(ei,ej), ei, ej); - es[j] = m.mk_ite(mk_le(ej,ei), ei, ej); - } - - void sort(unsigned k) { - if (k == 2) { - for (unsigned i = 0; i < m_es.size()/2; ++i) { - exchange(current(2*i), current(2*i+1), m_es); - next(2*i) = current(2*i); - next(2*i+1) = current(2*i+1); + expr_ref_vector merge(expr_ref_vector const& l1, expr_ref_vector const& l2) { + if (l1.empty()) { + return l2; } - std::swap(m_current, m_next); - } - else { - - for (unsigned i = 0; i < m_es.size()/k; ++i) { - for (unsigned j = 0; j < k / 2; ++j) { - next((k * i) + j) = current((k * i) + (2 * j)); - next((k * i) + (k / 2) + j) = current((k * i) + (2 * j) + 1); - } + if (l2.empty()) { + return l1; } - - std::swap(m_current, m_next); - sort(k / 2); - for (unsigned i = 0; i < m_es.size() / k; ++i) { - for (unsigned j = 0; j < k / 2; ++j) { - next((k * i) + (2 * j)) = current((k * i) + j); - next((k * i) + (2 * j) + 1) = current((k * i) + (k / 2) + j); - } - - for (unsigned j = 0; j < (k / 2) - 1; ++j) { - exchange(next((k * i) + (2 * j) + 1), next((k * i) + (2 * (j + 1)))); - } + expr_ref_vector result(m); + if (l1.size() == 1 && l2.size() == 1) { + result.push_back(l1[0]); + result.push_back(l2[0]); + exchange(0, 1, result); + return result; } - std::swap(m_current, m_next); - } - } - - expr_ref_vector merge(expr_ref_vector const& l1, expr_ref_vector& l2) { - if (l1.empty()) { - return l2; - } - if (l2.empty()) { - return l1; - } - expr_ref_vector result(m); - if (l1.size() == 1 && l2.size() == 1) { - result.push_back(l1[0]); - result.push_back(l2[0]); - exchange(0, 1, result); - return result; - } - unsigned l1o = l1.size()/2; - unsigned l2o = l2.size()/2; - unsigned l1e = (l1.size() % 2 == 1) ? l1o + 1 : l1o; - unsigned l2e = (l2.size() % 2 == 1) ? l2o + 1 : l2o; - expr_ref_vector evenl1(m, l1e); - expr_ref_vector oddl1(m, l1o); - expr_ref_vector evenl2(m, l2e); - expr_ref_vector oddl2(m, l2o); - for (unsigned i = 0; i < l1.size(); ++i) { - if (i % 2 == 0) { - evenl1[i/2] = l1[i]; - } - else { - oddl1[i/2] = l1[i]; - } - } - for (unsigned i = 0; i < l2.size(); ++i) { - if (i % 2 == 0) { - evenl2[i/2] = l2[i]; - } - else { - oddl2[i/2] = l2[i]; - } - } - expr_ref_vector even = merge(evenl1, evenl2); - expr_ref_vector odd = merge(oddl1, oddl2); - - result.resize(l1.size() + l2.size()); - for (unsigned i = 0; i < result.size(); ++i) { - if (i % 2 == 0) { - result[i] = even[i/2].get(); - if (i > 0) { - exchange(i - 1, i, result); - } - } - else { - if (i /2 < odd.size()) { - result[i] = odd[i/2].get(); + unsigned l1o = l1.size()/2; + unsigned l2o = l2.size()/2; + unsigned l1e = (l1.size() % 2 == 1) ? l1o + 1 : l1o; + unsigned l2e = (l2.size() % 2 == 1) ? l2o + 1 : l2o; + expr_ref_vector evenl1(m), oddl1(m), evenl2(m), oddl2(m); + evenl1.resize(l1e); + oddl1.resize(l1o); + evenl2.resize(l2e); + oddl2.resize(l2o); + for (unsigned i = 0; i < l1.size(); ++i) { + if (i % 2 == 0) { + evenl1[i/2] = l1[i]; } else { - result[i] = even[(i/2)+1].get(); + oddl1[i/2] = l1[i]; } } - } - return result; - } + for (unsigned i = 0; i < l2.size(); ++i) { + if (i % 2 == 0) { + evenl2[i/2] = l2[i]; + } + else { + oddl2[i/2] = l2[i]; + } + } + expr_ref_vector even = merge(evenl1, evenl2); + expr_ref_vector odd = merge(oddl1, oddl2); -public: - sorting_network(ast_manager& m): - m(m), - m_es(m), - m_current(0), - m_next(0) - {} - - expr_ref_vector operator()(expr_ref_vector const& inputs) { - if (inputs.size() <= 1) { - return inputs; + result.resize(l1.size() + l2.size()); + for (unsigned i = 0; i < result.size(); ++i) { + if (i % 2 == 0) { + result[i] = even[i/2].get(); + if (i > 0) { + exchange(i - 1, i, result); + } + } + else { + if (i /2 < odd.size()) { + result[i] = odd[i/2].get(); + } + else { + result[i] = even[(i/2)+1].get(); + } + } + } + return result; } - m_es.reset(); - m_es.append(inputs); - while (!is_power_of2(m_es.size())) { - m_es.push_back(m.mk_false()); - } - m_es.reverse(); - for (unsigned i = 0; i < m_es.size(); ++i) { - current(i) = i; - } - unsigned k = 2; - while (k <= m_es.size()) { - sort(k); - // TBD - k *= 2; - } - } -}; Sorting networks used in Formula: diff --git a/src/smt/theory_card.h b/src/smt/theory_card.h index ca1d5a061..c53e38e37 100644 --- a/src/smt/theory_card.h +++ b/src/smt/theory_card.h @@ -25,14 +25,19 @@ Notes: namespace smt { class theory_card : public theory { + + typedef svector > arg_t; + struct card { - unsigned m_k; + int m_k; bool_var m_bv; - unsigned m_t; - unsigned m_f; - svector m_args; + int m_current_min; + int m_current_max; + int m_abs_min; + int m_abs_max; + arg_t m_args; card(bool_var bv, unsigned k): - m_k(k), m_bv(bv), m_t(0), m_f(0) + m_k(k), m_bv(bv) {} }; @@ -46,13 +51,19 @@ namespace smt { card_util m_util; void add_watch(bool_var bv, card* c); + void add_card(card* c); - void add_card(card* c) { - m_cards.insert(c->m_bv, c); - m_cards_trail.push_back(c->m_bv); - } void add_clause(literal_vector const& lits); literal_vector& get_lits(); + + int find_inc(bool_var bv, svector >const& vars); + void theory_card::propagate_assignment(card* c); + int theory_card::accumulate_max(literal_vector& lits, card* c); + int theory_card::accumulate_min(literal_vector& lits, card* c); + lbool theory_card::dec_max(int inc, lbool val); + lbool theory_card::inc_min(int inc, lbool val); + void theory_card::assign_use(bool_var v, bool is_true, card* c); + void theory_card::update_min_max(bool_var v, bool is_true, card* c); public: theory_card(ast_manager& m); diff --git a/src/tactic/arith/lia2card_tactic.cpp b/src/tactic/arith/lia2card_tactic.cpp index 7d257d86e..d2be433c6 100644 --- a/src/tactic/arith/lia2card_tactic.cpp +++ b/src/tactic/arith/lia2card_tactic.cpp @@ -175,54 +175,48 @@ class lia2card_tactic : public tactic { if (a.is_le(fml, x, y) || a.is_ge(fml, y, x)) { if (is_01var(x) && a.is_numeral(y, n)) { sub.insert(fml, mk_le(x, n)); - return; } - if (is_01var(y) && a.is_numeral(x, n)) { + else if (is_01var(y) && a.is_numeral(x, n)) { sub.insert(fml, mk_ge(y, n)); - return; } - if (is_add(x, args) && is_unsigned(y, k)) { // x <= k + else if (is_add(x, args) && is_unsigned(y, k)) { // x <= k sub.insert(fml, m_card.mk_at_most_k(args.size(), args.c_ptr(), k)); - return; } - if (is_add(y, args) && is_unsigned(x, k)) { // k <= y <=> not (y <= k-1) + else if (is_add(y, args) && is_unsigned(x, k)) { // k <= y <=> not (y <= k-1) if (k == 0) sub.insert(fml, m.mk_true()); else sub.insert(fml, m.mk_not(m_card.mk_at_most_k(args.size(), args.c_ptr(), k-1))); - return; } - UNREACHABLE(); + else { + UNREACHABLE(); + } } - - if (a.is_lt(fml, x, y) || a.is_gt(fml, y, x)) { + else if (a.is_lt(fml, x, y) || a.is_gt(fml, y, x)) { if (is_01var(x) && a.is_numeral(y, n)) { sub.insert(fml, mk_le(x, n-rational(1))); - return; } - if (is_01var(y) && a.is_numeral(x, n)) { + else if (is_01var(y) && a.is_numeral(x, n)) { sub.insert(fml, mk_ge(y, n+rational(1))); - return; } - if (is_add(x, args) && is_unsigned(y, k)) { // x < k + else if (is_add(x, args) && is_unsigned(y, k)) { // x < k if (k == 0) sub.insert(fml, m.mk_false()); else sub.insert(fml, m_card.mk_at_most_k(args.size(), args.c_ptr(), k-1)); - return; - } - - if (is_add(y, args) && is_unsigned(x, k)) { // k < y <=> not (y <= k) + } + else if (is_add(y, args) && is_unsigned(x, k)) { // k < y <=> not (y <= k) sub.insert(fml, m.mk_not(m_card.mk_at_most_k(args.size(), args.c_ptr(), k))); - return; } - UNREACHABLE(); + else { + UNREACHABLE(); + } } - if (m.is_eq(fml, x, y)) { + else if (m.is_eq(fml, x, y)) { if (!is_01var(x)) { std::swap(x, y); } - if (is_01var(x) && a.is_numeral(y, n)) { + else if (is_01var(x) && a.is_numeral(y, n)) { if (n.is_one()) { sub.insert(fml, mk_01(x)); } @@ -232,19 +226,21 @@ class lia2card_tactic : public tactic { else { sub.insert(fml, m.mk_false()); } - return; } - UNREACHABLE(); + else { + UNREACHABLE(); + } } - if (is_sum(fml)) { + else if (is_sum(fml)) { SASSERT(m_uses.contains(fml)); ptr_vector const& u = m_uses.find(fml); for (unsigned i = 0; i < u.size(); ++i) { convert_01(sub, u[i]); } - return; } - UNREACHABLE(); + else { + UNREACHABLE(); + } } expr_ref mk_01(expr* x) { diff --git a/src/test/main.cpp b/src/test/main.cpp index bc7e04124..94c4feb65 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -216,6 +216,7 @@ int main(int argc, char ** argv) { TST(polynorm); TST(qe_arith); TST(expr_substitution); + TST(sorting_network); } void initialize_mam() {} diff --git a/src/test/sorting_network.cpp b/src/test/sorting_network.cpp new file mode 100644 index 000000000..904bbb970 --- /dev/null +++ b/src/test/sorting_network.cpp @@ -0,0 +1,60 @@ + +#include "sorting_network.h" +#include "vector.h" +#include "ast.h" + +struct ast_ext { + ast_manager& m; + ast_ext(ast_manager& m):m(m) {} + typedef expr* T; + typedef expr_ref_vector vector; + T mk_ite(T a, T b, T c) { + return m.mk_ite(a, b, c); + } + T mk_le(T a, T b) { + if (m.is_bool(a)) { + return m.mk_implies(a, b); + } + UNREACHABLE(); + return 0; + } + T mk_default() { + return m.mk_false(); + } +}; + +struct unsigned_ext { + unsigned_ext() {} + typedef unsigned T; + typedef svector vector; + T mk_ite(T a, T b, T c) { + return (a==1)?b:c; + } + T mk_le(T a, T b) { + return (a <= b)?1:0; + } + T mk_default() { + return 0; + } +}; + +void tst_sorting_network() { + svector vec; + unsigned_ext uext; + sorting_network sn(uext, vec); + + svector in1; + in1.push_back(0); + in1.push_back(1); + in1.push_back(0); + in1.push_back(1); + in1.push_back(1); + in1.push_back(0); + + sn(in1); + + for (unsigned i = 0; i < vec.size(); ++i) { + std::cout << vec[i]; + } + std::cout << "\n"; +} diff --git a/src/util/sorting_network.h b/src/util/sorting_network.h new file mode 100644 index 000000000..d823cf44c --- /dev/null +++ b/src/util/sorting_network.h @@ -0,0 +1,99 @@ + +#include "vector.h" + +#ifndef _SORTING_NETWORK_H_ +#define _SORTING_NETWORK_H_ + + + template + class sorting_network { + typename Ext::vector& m_es; + Ext& m_ext; + svector m_currentv; + svector m_nextv; + svector* m_current; + svector* m_next; + + unsigned& current(unsigned i) { return (*m_current)[i]; } + unsigned& next(unsigned i) { return (*m_next)[i]; } + + void exchange(unsigned i, unsigned j) { + SASSERT(i <= j); + if (i < j) { + Ext::T ei = m_es.get(i); + Ext::T ej = m_es.get(j); + m_es.set(i, m_ext.mk_ite(m_ext.mk_le(ei, ej), ei, ej)); + m_es.set(j, m_ext.mk_ite(m_ext.mk_le(ej, ei), ei, ej)); + } + } + + void sort(unsigned k) { + SASSERT(is_power_of2(k) && k > 0); + if (k == 2) { + for (unsigned i = 0; i < m_es.size()/2; ++i) { + exchange(current(2*i), current(2*i+1)); + next(2*i) = current(2*i); + next(2*i+1) = current(2*i+1); + } + std::swap(m_current, m_next); + } + else { + + for (unsigned i = 0; i < m_es.size()/k; ++i) { + unsigned ki = k * i; + for (unsigned j = 0; j < k / 2; ++j) { + next(ki + j) = current(ki + (2 * j)); + next(ki + (k / 2) + j) = current(ki + (2 * j) + 1); + } + } + + std::swap(m_current, m_next); + sort(k / 2); + for (unsigned i = 0; i < m_es.size() / k; ++i) { + unsigned ki = k * i; + for (unsigned j = 0; j < k / 2; ++j) { + next(ki + (2 * j)) = current(ki + j); + next(ki + (2 * j) + 1) = current(ki + (k / 2) + j); + } + + for (unsigned j = 0; j < (k / 2) - 1; ++j) { + exchange(next(ki + (2 * j) + 1), next(ki + (2 * (j + 1)))); + } + } + std::swap(m_current, m_next); + } + } + + bool is_power_of2(unsigned n) const { + return n != 0 && ((n-1) & n) == 0; + } + + public: + sorting_network(Ext& ext, typename Ext::vector& es): + m_ext(ext), + m_es(es), + m_current(&m_currentv), + m_next(&m_nextv) + {} + + void operator()(typename Ext::vector const& inputs) { + if (inputs.size() <= 1) { + return; + } + m_es.reset(); + m_es.append(inputs); + while (!is_power_of2(m_es.size())) { + m_es.push_back(m_ext.mk_default()); + } + for (unsigned i = 0; i < m_es.size(); ++i) { + current(i) = i; + } + unsigned k = 2; + while (k <= m_es.size()) { + sort(k); + k *= 2; + } + } + }; + +#endif