diff --git a/src/test/sorting_network.cpp b/src/test/sorting_network.cpp index 2a6ba3666..d54c575ca 100644 --- a/src/test/sorting_network.cpp +++ b/src/test/sorting_network.cpp @@ -237,7 +237,7 @@ static void test_sorting_le(unsigned n, unsigned k) { smt::kernel solver(m, fp); psort_nw sn(ext); expr_ref result(m); - // k <= B + // B <= k std::cout << "le " << k << "\n"; solver.push(); result = sn.le(false, k, in.size(), in.c_ptr()); @@ -248,12 +248,12 @@ static void test_sorting_le(unsigned n, unsigned k) { lbool res = solver.check(); SASSERT(res == l_true); - for (unsigned i = 0; i < n - k; ++i) { - solver.assert_expr(m.mk_not(in[i].get())); + for (unsigned i = 0; i < k; ++i) { + solver.assert_expr(in[i].get()); } res = solver.check(); SASSERT(res == l_true); - solver.assert_expr(m.mk_not(in[n - k].get())); + solver.assert_expr(in[k].get()); res = solver.check(); if (res == l_true) { TRACE("pb", @@ -284,7 +284,7 @@ void test_sorting_ge(unsigned n, unsigned k) { smt::kernel solver(m, fp); psort_nw sn(ext); expr_ref result(m); - // k >= B + // k <= B std::cout << "ge " << k << "\n"; solver.push(); result = sn.ge(false, k, in.size(), in.c_ptr()); @@ -326,14 +326,14 @@ void test_sorting5(unsigned n, unsigned k) { } void tst_sorting_network() { - test_sorting1(); - test_sorting2(); - test_sorting3(); - test_sorting4(); - test_sorting5(11,4); + test_sorting_eq(11,7); for (unsigned n = 3; n < 20; n += 2) { for (unsigned k = 1; k < n; ++k) { test_sorting5(n, k); } } + test_sorting1(); + test_sorting2(); + test_sorting3(); + test_sorting4(); } diff --git a/src/util/sorting_network.h b/src/util/sorting_network.h index c2bdc600e..ee01d2cd2 100644 --- a/src/util/sorting_network.h +++ b/src/util/sorting_network.h @@ -680,6 +680,12 @@ Notes: for (unsigned k = 0; k < c; ++k) { ls.reset(); ls.push_back(ctx.mk_not(out[k])); + if (a <= k) { + add_clause(ctx.mk_not(out[k]), bs[k-a]); + } + if (b <= k) { + add_clause(ctx.mk_not(out[k]), as[k-b]); + } for (unsigned i = 0; i < std::min(a,k + 1); ++i) { unsigned j = k - i; SASSERT(i + j == k);