From 770d0d58fe1580c8e1a6956157ffb120b1a8ef21 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Thu, 11 Sep 2014 21:53:12 -0700
Subject: [PATCH] bug fixes to sorting network

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
---
 src/test/sorting_network.cpp | 20 ++++++++++----------
 src/util/sorting_network.h   |  6 ++++++
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/src/test/sorting_network.cpp b/src/test/sorting_network.cpp
index 2a6ba3666..d54c575ca 100644
--- a/src/test/sorting_network.cpp
+++ b/src/test/sorting_network.cpp
@@ -237,7 +237,7 @@ static void test_sorting_le(unsigned n, unsigned k) {
     smt::kernel solver(m, fp);
     psort_nw<ast_ext2> sn(ext);
     expr_ref result(m);
-    // k <= B
+    // B <= k
     std::cout << "le " << k << "\n";
     solver.push();
     result = sn.le(false, k, in.size(), in.c_ptr());
@@ -248,12 +248,12 @@ static void test_sorting_le(unsigned n, unsigned k) {
     lbool res = solver.check();
     SASSERT(res == l_true);
 
-    for (unsigned i = 0; i < n - k; ++i) {
-        solver.assert_expr(m.mk_not(in[i].get()));        
+    for (unsigned i = 0; i < k; ++i) {
+        solver.assert_expr(in[i].get());        
     }
     res = solver.check();
     SASSERT(res == l_true);
-    solver.assert_expr(m.mk_not(in[n - k].get()));
+    solver.assert_expr(in[k].get());
     res = solver.check();
     if (res == l_true) {
         TRACE("pb",
@@ -284,7 +284,7 @@ void test_sorting_ge(unsigned n, unsigned k) {
     smt::kernel solver(m, fp);
     psort_nw<ast_ext2> sn(ext);
     expr_ref result(m);
-    // k >= B
+    // k <= B
     std::cout << "ge " << k << "\n";
     solver.push();
     result = sn.ge(false, k, in.size(), in.c_ptr());
@@ -326,14 +326,14 @@ void test_sorting5(unsigned n, unsigned k) {
 }
 
 void tst_sorting_network() {
-    test_sorting1();
-    test_sorting2();
-    test_sorting3();
-    test_sorting4();
-    test_sorting5(11,4);
+    test_sorting_eq(11,7);
     for (unsigned n = 3; n < 20; n += 2) {
         for (unsigned k = 1; k < n; ++k) {
             test_sorting5(n, k);
         }
     }
+    test_sorting1();
+    test_sorting2();
+    test_sorting3();
+    test_sorting4();
 }
diff --git a/src/util/sorting_network.h b/src/util/sorting_network.h
index c2bdc600e..ee01d2cd2 100644
--- a/src/util/sorting_network.h
+++ b/src/util/sorting_network.h
@@ -680,6 +680,12 @@ Notes:
                 for (unsigned k = 0; k < c; ++k) {
                     ls.reset();
                     ls.push_back(ctx.mk_not(out[k]));
+                    if (a <= k) {
+                        add_clause(ctx.mk_not(out[k]), bs[k-a]);
+                    }
+                    if (b <= k) {
+                        add_clause(ctx.mk_not(out[k]), as[k-b]);
+                    }
                     for (unsigned i = 0; i < std::min(a,k + 1); ++i) {
                         unsigned j = k - i;
                         SASSERT(i + j == k);