/*++
Copyright (c) 2015 Microsoft Corporation

--*/

// Unit tests for sorting_network (instantiated below over Boolean
// expressions via ast_ext and over unsigned values via unsigned_ext) and for
// the psort_nw eq/le/ge encodings, whose results are checked with smt::kernel.

#include <iostream>
#include "trace.h"
#include "vector.h"
#include "ast.h"
#include "ast_pp.h"
#include "reg_decl_plugins.h"
#include "sorting_network.h"
#include "smt_kernel.h"
#include "model_smt2_pp.h"
#include "smt_params.h"
#include "ast_util.h"

// sorting_network extension over Boolean expressions: mk_ite builds an
// if-then-else term, and "a <= b" on Booleans is encoded as a => b.
struct ast_ext {
    ast_manager& m;
    ast_ext(ast_manager& m):m(m) {}
    typedef expr* T;
    typedef expr_ref_vector vector;
    T mk_ite(T a, T b, T c) {
        return m.mk_ite(a, b, c);
    }
    T mk_le(T a, T b) {
        if (m.is_bool(a)) {
            return m.mk_implies(a, b);
        }
        UNREACHABLE(); // only Boolean elements are supported by this extension
        return 0;
    }
    T mk_default() {
        return m.mk_false();
    }
};

// sorting_network extension over unsigned values; mk_ite treats the value 1
// as "true" and mk_le yields 1/0.
struct unsigned_ext {
    unsigned_ext() {}
    typedef unsigned T;
    typedef svector<unsigned> vector;
    T mk_ite(T a, T b, T c) { return (a==1)?b:c; }
    T mk_le(T a, T b) { return (a <= b)?1:0; }
    T mk_default() { return 0; }
};

// Debug-asserts that v is in non-decreasing order.
static void is_sorted(svector<unsigned> const& v) {
    for (unsigned i = 0; i + 1 < v.size(); ++i) {
        SASSERT(v[i] <= v[i+1]);
    }
}

// Runs `in` through an unsigned sorting_network, checks the output is
// sorted, and prints the output values back-to-back followed by a newline.
static void sort_and_print(svector<unsigned>& in) {
    svector<unsigned> out;
    unsigned_ext uext;
    sorting_network<unsigned_ext> sn(uext);
    sn(in, out);
    is_sorted(out);
    for (unsigned i = 0; i < out.size(); ++i) {
        std::cout << out[i];
    }
    std::cout << "\n";
}

static void test_sorting1() {
    svector<unsigned> in;
    in.push_back(0);
    in.push_back(1);
    in.push_back(0);
    in.push_back(1);
    in.push_back(1);
    in.push_back(0);
    sort_and_print(in);
}

static void test_sorting2() {
    svector<unsigned> in;
    in.push_back(0);
    in.push_back(1);
    in.push_back(2);
    in.push_back(1);
    in.push_back(1);
    in.push_back(3);
    sort_and_print(in);
}

// Recursively enumerates every 0/1 assignment of in[i..in.size()-1] and
// checks that the network sorts each complete assignment.
static void test_sorting4_r(unsigned i, svector<unsigned>& in) {
    if (i == in.size()) {
        svector<unsigned> out;
        unsigned_ext uext;
        sorting_network<unsigned_ext> sn(uext);
        sn(in, out);
        is_sorted(out);
        std::cout << "sorted\n";
    }
    else {
        in[i] = 0;
        test_sorting4_r(i+1, in);
        in[i] = 1;
        test_sorting4_r(i+1, in);
    }
}

// Exhaustive 0/1 check for inputs of width 5 and 8.
static void test_sorting4() {
    svector<unsigned> in;
    in.resize(5);
    test_sorting4_r(0, in);
    in.resize(8);
    test_sorting4_r(0, in);
}

// Builds a sorting network over 7 fresh Boolean constants and prints the
// inputs and the resulting output terms.
void test_sorting3() {
    ast_manager m;
    reg_decl_plugins(m);
    expr_ref_vector in(m), out(m);
    for (unsigned i = 0; i < 7; ++i) {
        in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
    }
    for (unsigned i = 0; i < in.size(); ++i) {
        std::cout << mk_pp(in[i].get(), m) << "\n";
    }
    ast_ext aext(m);
    sorting_network<ast_ext> sn(aext);
    sn(in, out);
    std::cout << "size: " << out.size() << "\n";
    for (unsigned i = 0; i < out.size(); ++i) {
        std::cout << mk_pp(out[i].get(), m) << "\n";
    }
}

// psort_nw extension over Boolean expressions. Every term it creates is kept
// alive in m_trail; clauses requested through mk_clause are collected as
// disjunctions in m_clauses so the tests can assert them.
struct ast_ext2 {
    ast_manager& m;
    expr_ref_vector m_clauses;
    expr_ref_vector m_trail;
    ast_ext2(ast_manager& m):m(m), m_clauses(m), m_trail(m) {}
    typedef expr* literal;
    typedef ptr_vector<expr> literal_vector;

    expr* trail(expr* e) {
        m_trail.push_back(e);
        return e;
    }

    literal mk_false() { return m.mk_false(); }
    literal mk_true() { return m.mk_true(); }
    literal mk_max(literal a, literal b) {
        return trail(m.mk_or(a, b));
    }
    literal mk_min(literal a, literal b) { return trail(m.mk_and(a, b)); }
    literal mk_not(literal a) {
        // Strip an existing negation instead of building (not (not x)).
        if (m.is_not(a,a)) return a;
        return trail(m.mk_not(a));
    }
    std::ostream& pp(std::ostream& out, literal lit) {
        return out << mk_pp(lit, m);
    }
    literal fresh() {
        return trail(m.mk_fresh_const("x", m.mk_bool_sort()));
    }
    void mk_clause(unsigned n, literal const* lits) {
        m_clauses.push_back(mk_or(m, n, lits));
    }
};

// Appends n fresh Boolean constants (name prefix "a") to `in`.
static void mk_fresh_bools(ast_manager& m, unsigned n, expr_ref_vector& in) {
    for (unsigned i = 0; i < n; ++i) {
        in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
    }
}

// Asserts the constraint returned by psort_nw together with every clause
// collected in ext.m_clauses by ast_ext2::mk_clause.
static void assert_encoding(smt::kernel& solver, ast_ext2& ext, expr* result) {
    solver.assert_expr(result);
    for (unsigned i = 0; i < ext.m_clauses.size(); ++i) {
        solver.assert_expr(ext.m_clauses[i].get());
    }
}

// Debug-asserts that the current assertions are satisfiable.
static void check_sat(smt::kernel& solver) {
    lbool res = solver.check();
    SASSERT(res == l_true);
    (void)res; // otherwise only read inside SASSERT
}

// Debug-asserts that the current assertions are unsatisfiable. If a model is
// found instead, it is printed to std::cout (and the asserted formulas and
// the model are traced under the "pb" tag) to help diagnose the encoding.
static void check_unsat(smt::kernel& solver, ast_manager& m) {
    lbool res = solver.check();
    if (res == l_true) {
        TRACE("pb",
              unsigned sz = solver.size();
              for (unsigned i = 0; i < sz; ++i) {
                  tout << mk_pp(solver.get_formulas()[i], m) << "\n";
              });
        model_ref model;
        solver.get_model(model);
        model_smt2_pp(std::cout, m, *model, 0);
        TRACE("pb", model_smt2_pp(tout, m, *model, 0););
    }
    SASSERT(res == l_false);
}

// Checks psort_nw::eq(k, ...) over n fresh inputs: the encoding is
// satisfiable, stays satisfiable with inputs 0..k-1 forced true, and becomes
// unsatisfiable once input k is forced true as well.
static void test_sorting_eq(unsigned n, unsigned k) {
    SASSERT(k < n);
    ast_manager m;
    reg_decl_plugins(m);
    ast_ext2 ext(m);
    expr_ref_vector in(m);
    mk_fresh_bools(m, n, in);
    smt_params fp;
    smt::kernel solver(m, fp);
    psort_nw<ast_ext2> sn(ext);
    expr_ref result(m);

    // equality:
    std::cout << "eq " << k << "\n";
    solver.push();
    result = sn.eq(k, in.size(), in.c_ptr());
    assert_encoding(solver, ext, result);
    check_sat(solver);

    solver.push();
    for (unsigned i = 0; i < k; ++i) {
        solver.assert_expr(in[i].get());
    }
    check_sat(solver);
    solver.assert_expr(in[k].get());
    check_unsat(solver, m);
    solver.pop(2); // balance both push() calls above
    ext.m_clauses.reset();
}

// Checks psort_nw::le(false, k, ...) (B <= k): forcing inputs 0..k-1 true is
// satisfiable, additionally forcing input k true is not.
static void test_sorting_le(unsigned n, unsigned k) {
    SASSERT(k < n); // in[k] is accessed below
    ast_manager m;
    reg_decl_plugins(m);
    ast_ext2 ext(m);
    expr_ref_vector in(m);
    mk_fresh_bools(m, n, in);
    smt_params fp;
    smt::kernel solver(m, fp);
    psort_nw<ast_ext2> sn(ext);
    expr_ref result(m);

    // B <= k
    std::cout << "le " << k << "\n";
    solver.push();
    result = sn.le(false, k, in.size(), in.c_ptr());
    assert_encoding(solver, ext, result);
    check_sat(solver);

    for (unsigned i = 0; i < k; ++i) {
        solver.assert_expr(in[i].get());
    }
    check_sat(solver);
    solver.assert_expr(in[k].get());
    check_unsat(solver, m);
    solver.pop(1);
    ext.m_clauses.reset();
}

// Checks psort_nw::ge(false, k, ...) (k <= B): forcing inputs 0..n-k-1 false
// is satisfiable, additionally forcing input n-k false is not.
void test_sorting_ge(unsigned n, unsigned k) {
    SASSERT(0 < k && k <= n); // keeps n - k a valid index into in
    ast_manager m;
    reg_decl_plugins(m);
    ast_ext2 ext(m);
    expr_ref_vector in(m);
    mk_fresh_bools(m, n, in);
    smt_params fp;
    smt::kernel solver(m, fp);
    psort_nw<ast_ext2> sn(ext);
    expr_ref result(m);

    // k <= B
    std::cout << "ge " << k << "\n";
    solver.push();
    result = sn.ge(false, k, in.size(), in.c_ptr());
    assert_encoding(solver, ext, result);
    check_sat(solver);

    solver.push();
    for (unsigned i = 0; i < n - k; ++i) {
        solver.assert_expr(m.mk_not(in[i].get()));
    }
    check_sat(solver);
    solver.assert_expr(m.mk_not(in[n - k].get()));
    check_unsat(solver, m);
    solver.pop(2); // balance both push() calls above
    ext.m_clauses.reset();
}

// Runs the le, eq and ge checks for one (n, k) pair.
void test_sorting5(unsigned n, unsigned k) {
    std::cout << "n: " << n << " k: " << k << "\n";
    test_sorting_le(n, k);
    test_sorting_eq(n, k);
    test_sorting_ge(n, k);
}

// Test entry point.
void tst_sorting_network() {
    test_sorting_eq(11,7);
    for (unsigned n = 3; n < 20; n += 2) {
        for (unsigned k = 1; k < n; ++k) {
            test_sorting5(n, k);
        }
    }
    test_sorting1();
    test_sorting2();
    test_sorting3();
    test_sorting4();
}