From 8d2b70a5e259044c19cc2c24bc133bf088f2517e Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 10 Oct 2016 23:46:03 -0700 Subject: [PATCH] better encodings for at-most-1, #755 Signed-off-by: Nikolaj Bjorner --- src/ast/pb_decl_plugin.h | 2 + src/opt/opt_solver.cpp | 5 + src/opt/opt_solver.h | 1 + src/sat/sat_solver.cpp | 12 ++ src/smt/smt_consequences.cpp | 244 +++++++--------------------- src/smt/smt_context.h | 12 +- src/smt/smt_kernel.cpp | 10 +- src/smt/smt_kernel.h | 5 + src/smt/smt_solver.cpp | 5 + src/smt/theory_pb.cpp | 7 +- src/solver/solver.cpp | 57 ------- src/solver/solver_na2as.cpp | 6 + src/solver/solver_na2as.h | 1 + src/tactic/arith/card2bv_tactic.cpp | 40 ++++- src/tactic/arith/card2bv_tactic.h | 1 + src/test/sorting_network.cpp | 25 +++ src/util/sorting_network.h | 52 ++++++ 17 files changed, 232 insertions(+), 253 deletions(-) diff --git a/src/ast/pb_decl_plugin.h b/src/ast/pb_decl_plugin.h index e1b16f0c9..d0649729b 100644 --- a/src/ast/pb_decl_plugin.h +++ b/src/ast/pb_decl_plugin.h @@ -119,6 +119,8 @@ public: app* mk_fresh_bool(); + expr_ref mk_at_most_1(unsigned num_args, expr * const * args); + private: rational to_rational(parameter const& p) const; diff --git a/src/opt/opt_solver.cpp b/src/opt/opt_solver.cpp index 64ab2823f..0d0283232 100644 --- a/src/opt/opt_solver.cpp +++ b/src/opt/opt_solver.cpp @@ -194,6 +194,11 @@ namespace opt { } } + lbool opt_solver::find_mutexes(expr_ref_vector const& vars, vector& mutexes) { + return m_context.find_mutexes(vars, mutexes); + } + + /** \brief maximize the value of objective i in the current state. diff --git a/src/opt/opt_solver.h b/src/opt/opt_solver.h index b2fa18ad8..16c2061c7 100644 --- a/src/opt/opt_solver.h +++ b/src/opt/opt_solver.h @@ -106,6 +106,7 @@ namespace opt { virtual expr * get_assertion(unsigned idx) const; virtual std::ostream& display(std::ostream & out) const; virtual ast_manager& get_manager() const { return m; } + virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes); void set_logic(symbol const& logic); smt::theory_var add_objective(app* term); diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index f553acb29..f2279b3c4 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -3053,6 +3053,12 @@ namespace sat { return r; } + // ----------------------- + // + // Extraction of mutexes + // + // ----------------------- + lbool solver::find_mutexes(literal_vector const& lits, vector & mutexes) { literal_vector ps(lits); m_user_bin_clauses.reset(); @@ -3111,6 +3117,12 @@ namespace sat { } } + // ----------------------- + // + // Consequence generation. + // + // ----------------------- + lbool solver::get_consequences(literal_vector const& asms, bool_var_vector const& vars, vector& conseq) { literal_vector lits; lbool is_sat = check(asms.size(), asms.c_ptr()); diff --git a/src/smt/smt_consequences.cpp b/src/smt/smt_consequences.cpp index 8619ff815..e44782a31 100644 --- a/src/smt/smt_consequences.cpp +++ b/src/smt/smt_consequences.cpp @@ -251,191 +251,6 @@ namespace smt { expr_ref_vector& conseq, expr_ref_vector& unfixed) { - m_antecedents.reset(); - pop_to_base_lvl(); - lbool is_sat = check(assumptions.size(), assumptions.c_ptr()); - if (is_sat != l_true) { - return is_sat; - } - obj_map var2val; - index_set _assumptions; - for (unsigned i = 0; i < assumptions.size(); ++i) { - _assumptions.insert(get_literal(assumptions[i]).var()); - } - model_ref mdl; - get_model(mdl); - ast_manager& m = m_manager; - expr_ref_vector trail(m); - model_evaluator eval(*mdl.get()); - expr_ref val(m); - TRACE("context", model_pp(tout, *mdl);); - for (unsigned i = 0; i < vars.size(); ++i) { - eval(vars[i], val); - if (m.is_value(val)) { - trail.push_back(val); - var2val.insert(vars[i], val); - } - else { - unfixed.push_back(vars[i]); - } - } - unsigned num_units = 0; - extract_fixed_consequences(num_units, var2val, _assumptions, conseq); - app_ref eq(m); - TRACE("context", - tout << "vars: " << vars.size() << "\n"; - tout << "lits: " << num_units << "\n";); - m_case_split_queue->init_search_eh(); - unsigned num_iterations = 0; - unsigned model_threshold = 2; - unsigned num_fixed_eqs = 0; - unsigned num_reiterations = 0; - while (!var2val.empty()) { - obj_map::iterator it = var2val.begin(); - expr* e = it->m_key; - expr* val = it->m_value; - - TRACE("context", tout << "scope level: " << get_scope_level() << "\n";); - SASSERT(!inconsistent()); - - // - // The current variable is checked to be a backbone - // We add the negation of the reference assignment to the variable. - // If the variable is a Boolean, it means adding literal that has - // the opposite value of the current reference model. - // If the variable is a non-Boolean, it means adding a disequality. - // - literal lit = mk_diseq(e, val); - mark_as_relevant(lit); - push_scope(); - assign(lit, b_justification::mk_axiom(), true); - flet l(m_searching, true); - - // - // We check if the current assignment stack can be extended to a - // satisfying assignment. bounded search may decide to restart, - // in which case it returns l_undef and clears search failure. - // - while (true) { - is_sat = bounded_search(); - TRACE("context", tout << "search result: " << is_sat << "\n";); - if (is_sat != l_true && m_last_search_failure != OK) { - return is_sat; - } - if (is_sat == l_undef) { - TRACE("context", tout << "restart\n";); - inc_limits(); - continue; - } - break; - } - // - // If the state is satisfiable with the current variable assigned to - // a different value from the reference model, it is unfixed. - // - // If it is assigned above the search level we can't conclude anything - // about its value. - // extract_fixed_consequences pops the assignment stack to the search level - // so this sets up the state to retry finding fixed values. - // - // Otherwise, the variable is fixed. - // - it is either assigned at the search level to l_false, or - // - the state is l_false, which means that the variable is fixed by - // the background constraints (and does not depend on assumptions). - // - if (is_sat == l_true && get_assignment(lit) == l_true && is_relevant(lit)) { - var2val.erase(e); - unfixed.push_back(e); - SASSERT(!are_equal(e, val)); - TRACE("context", tout << mk_pp(e, m) << " is unfixed\n"; - display_literal_verbose(tout, lit); tout << "\n"; - tout << "relevant: " << is_relevant(lit) << "\n"; - display(tout);); - } - else if (is_sat == l_true && (get_assign_level(lit) > get_search_level() || !is_relevant(lit))) { - TRACE("context", tout << "Retry fixing: " << mk_pp(e, m) << "\n";); - extract_fixed_consequences(num_units, var2val, _assumptions, conseq); - ++num_reiterations; - continue; - } - else { - // - // The state can be labeled as inconsistent when the implied consequence does - // not depend on assumptions, then the conflict level sits at the search level - // which causes the conflict resolver to decide that the state is unsat. - // - if (l_false == is_sat) { - SASSERT(inconsistent()); - m_conflict = null_b_justification; - m_not_l = null_literal; - } - TRACE("context", tout << "Fixed: " << mk_pp(e, m) << " " << is_sat << "\n"; - if (is_sat == l_false) display(tout);); - - } - ++num_iterations; - - // - // Check the slow pass: it retrieves an updated model and checks if the - // values in the updated model differ from the values in the reference - // model. - // - bool apply_slow_pass = model_threshold <= num_iterations || num_iterations <= 2; - if (apply_slow_pass && is_sat == l_true) { - delete_unfixed(var2val, unfixed); - // The next time we check the model is after 1.5 additional iterations. - model_threshold *= 3; - model_threshold /= 2; - } - - // - // Walk the assignment stack at level 1 for learned consequences. - // The current literal should be assigned at the search level unless - // the state is is_sat == l_true and the assignment to lit is l_true. - // This condition is checked above. - // - extract_fixed_consequences(num_units, var2val, _assumptions, conseq); - - // - // Fixed equalities can be extracted by walking all variables and checking - // if the congruence roots are equal at the search level. - // - if (apply_slow_pass) { - num_fixed_eqs += extract_fixed_eqs(var2val, conseq); - IF_VERBOSE(1, display_consequence_progress(verbose_stream(), num_iterations, var2val.size(), conseq.size(), - unfixed.size(), num_fixed_eqs);); - TRACE("context", display_consequence_progress(tout, num_iterations, var2val.size(), conseq.size(), - unfixed.size(), num_fixed_eqs);); - } - TRACE("context", tout << "finishing " << mk_pp(e, m) << "\n";); - SASSERT(!inconsistent()); - - // - // This becomes unnecessary when the fixed consequence are - // completely extracted. - // - if (var2val.contains(e)) { - TRACE("context", tout << "Fixed value to " << mk_pp(e, m) << " was not processed\n";); - expr_ref fml(m); - fml = m.mk_eq(e, var2val.find(e)); - if (!m_antecedents.contains(lit.var())) { - extract_fixed_consequences(lit, var2val, _assumptions, conseq); - } - fml = m.mk_implies(antecedent2fml(m_antecedents[lit.var()]), fml); - conseq.push_back(fml); - var2val.erase(e); - } - } - end_search(); - DEBUG_CODE(validate_consequences(assumptions, vars, conseq, unfixed);); - return l_true; - } - - lbool context::get_consequences2(expr_ref_vector const& assumptions, - expr_ref_vector const& vars, - expr_ref_vector& conseq, - expr_ref_vector& unfixed) { - m_antecedents.reset(); pop_to_base_lvl(); lbool is_sat = check(assumptions.size(), assumptions.c_ptr()); @@ -552,6 +367,65 @@ namespace smt { } + lbool context::find_mutexes(expr_ref_vector const& vars, vector& mutexes) { + index_set lits; + for (unsigned i = 0; i < vars.size(); ++i) { + expr* n = vars[i]; + bool neg = m_manager.is_not(n, n); + if (b_internalized(n)) { + lits.insert(literal(get_bool_var(n), !neg).index()); + } + } + while (!lits.empty()) { + literal_vector mutex; + index_set other(lits); + while (!other.empty()) { + index_set conseq; + literal p = to_literal(*other.begin()); + other.erase(p.index()); + mutex.push_back(p); + if (other.empty()) { + break; + } + get_reachable(p, other, conseq); + other = conseq; + } + if (mutex.size() > 1) { + expr_ref_vector mux(m_manager); + for (unsigned i = 0; i < mutex.size(); ++i) { + expr_ref e(m_manager); + literal2expr(mutex[i], e); + mux.push_back(e); + } + mutexes.push_back(mux); + } + for (unsigned i = 0; i < mutex.size(); ++i) { + lits.remove(mutex[i].index()); + } + } + return l_true; + } + + void context::get_reachable(literal p, index_set& goal, index_set& reachable) { + index_set seen; + literal_vector todo; + todo.push_back(p); + while (!todo.empty()) { + p = todo.back(); + todo.pop_back(); + if (seen.contains(p.index())) { + continue; + } + seen.insert(p.index()); + literal np = ~p; + if (goal.contains(np.index())) { + reachable.insert(np.index()); + } + watch_list & w = m_watches[np.index()]; + todo.append(static_cast(w.end_literals() - w.begin_literals()), w.begin_literals()); + } + } + // // Validate, in a slow pass, that the current consequences are correctly // extracted. diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 98093d6ed..d9f24ddf6 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -1344,6 +1344,9 @@ namespace smt { literal lit, context& src_ctx, context& dst_ctx, vector b2v, ast_translation& tr); + /* + \brief Utilities for consequence finding. + */ typedef hashtable index_set; //typedef uint_set index_set; u_map m_antecedents; @@ -1358,11 +1361,17 @@ namespace smt { expr_ref antecedent2fml(index_set const& ante); + literal mk_diseq(expr* v, expr* val); void validate_consequences(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector const& conseq, expr_ref_vector const& unfixed); + /* + \brief Auxiliry function for mutex finding. + */ + + void get_reachable(literal p, index_set& goal, index_set& reached); public: context(ast_manager & m, smt_params & fp, params_ref const & p = params_ref()); @@ -1404,7 +1413,8 @@ namespace smt { lbool check(unsigned num_assumptions = 0, expr * const * assumptions = 0, bool reset_cancel = true); lbool get_consequences(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq, expr_ref_vector& unfixed); - lbool get_consequences2(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq, expr_ref_vector& unfixed); + + lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes); lbool setup_and_check(bool reset_cancel = true); diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index e5a8a639e..3819f05cb 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -112,7 +112,11 @@ namespace smt { } lbool get_consequences(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq, expr_ref_vector& unfixed) { - return m_kernel.get_consequences2(assumptions, vars, conseq, unfixed); + return m_kernel.get_consequences(assumptions, vars, conseq, unfixed); + } + + lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { + return m_kernel.find_mutexes(vars, mutexes); } void get_model(model_ref & m) const { @@ -272,6 +276,10 @@ namespace smt { return m_imp->get_consequences(assumptions, vars, conseq, unfixed); } + lbool kernel::find_mutexes(expr_ref_vector const& vars, vector& mutexes) { + return m_imp->find_mutexes(vars, mutexes); + } + void kernel::get_model(model_ref & m) const { m_imp->get_model(m); } diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index a10961207..0fec4a21b 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -132,6 +132,11 @@ namespace smt { lbool get_consequences(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq, expr_ref_vector& unfixed); + /* + \brief find mutually exclusive variables. + */ + lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes); + /** \brief Return the model associated with the last check command. */ diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index f9da1bf60..2ea4fea20 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -73,6 +73,10 @@ namespace smt { return m_context.get_consequences(assumptions, vars, conseq, unfixed); } + virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { + return m_context.find_mutexes(vars, mutexes); + } + virtual void assert_expr(expr * t) { m_context.assert_expr(t); } @@ -160,6 +164,7 @@ namespace smt { SASSERT(idx < get_num_assertions()); return m_context.get_formulas()[idx]; } + }; }; diff --git a/src/smt/theory_pb.cpp b/src/smt/theory_pb.cpp index a85f9aa80..2461a9db2 100644 --- a/src/smt/theory_pb.cpp +++ b/src/smt/theory_pb.cpp @@ -519,6 +519,7 @@ namespace smt { c->m_compilation_threshold = th; IF_VERBOSE(2, verbose_stream() << "(smt.pb setting compilation threhshold to " << th << ")\n";); TRACE("pb", tout << "compilation threshold: " << th << "\n";); + compile_ineq(*c); } else { c->m_compilation_threshold = UINT_MAX; @@ -1216,7 +1217,7 @@ namespace smt { void theory_pb::inc_propagations(ineq& c) { ++c.m_num_propagations; - if (c.m_compiled == l_false && c.m_num_propagations > c.m_compilation_threshold) { + if (c.m_compiled == l_false && c.m_num_propagations >= c.m_compilation_threshold) { c.m_compiled = l_undef; m_to_compile.push_back(&c); } @@ -1263,12 +1264,14 @@ namespace smt { n -= rational::one(); } } + + if (ctx.get_assignment(thl) == l_true && ctx.get_assign_level(thl) == ctx.get_base_level()) { psort_expr ps(ctx, *this); psort_nw sortnw(ps); sortnw.m_stats.reset(); - at_least_k = sortnw.ge(false, k, in.size(), in.c_ptr()); + at_least_k = sortnw.ge(false, k, in.size(), in.c_ptr()); ctx.mk_clause(~thl, at_least_k, justify(~thl, at_least_k)); m_stats.m_num_compiled_vars += sortnw.m_stats.m_num_compiled_vars; m_stats.m_num_compiled_clauses += sortnw.m_stats.m_num_compiled_clauses; diff --git a/src/solver/solver.cpp b/src/solver/solver.cpp index c75088246..8f32019d5 100644 --- a/src/solver/solver.cpp +++ b/src/solver/solver.cpp @@ -138,63 +138,6 @@ lbool solver::get_consequences_core(expr_ref_vector const& asms, expr_ref_vector lbool solver::find_mutexes(expr_ref_vector const& vars, vector& mutexes) { return l_true; -#if 0 - // complete for literals, but inefficient. - // see more efficient (incomplete) version in sat_solver - - mutexes.reset(); - ast_manager& m = vars.get_manager(); - - typedef obj_hashtable expr_set; - - expr_set A, P; - - for (unsigned i = 0; i < vars.size(); ++i) { - A.insert(vars[i]); - } - - while (!A.empty()) { - P = A; - expr_ref_vector mutex(m); - while (!P.empty()) { - expr_ref_vector asms(m); - expr* p = *P.begin(); - P.remove(p); - if (!is_literal(m, p)) { - break; - } - mutex.push_back(p); - asms.push_back(p); - expr_set Q; - expr_set::iterator it = P.begin(), end = P.end(); - for (; it != end; ++it) { - expr* q = *it; - scoped_assumption_push _scoped_push(asms, q); - if (is_literal(m, q)) { - lbool is_sat = check_sat(asms); - switch (is_sat) { - case l_false: - Q.insert(q); - break; - case l_true: - break; - case l_undef: - return l_undef; - } - } - } - P = Q; - } - if (mutex.size() > 1) { - mutexes.push_back(mutex); - } - for (unsigned i = 0; i < mutex.size(); ++i) { - A.remove(mutex[i].get()); - } - } - return l_true; -#endif - } bool solver::is_literal(ast_manager& m, expr* e) { diff --git a/src/solver/solver_na2as.cpp b/src/solver/solver_na2as.cpp index 29ce4864f..31895b8ef 100644 --- a/src/solver/solver_na2as.cpp +++ b/src/solver/solver_na2as.cpp @@ -22,6 +22,7 @@ Notes: #include"solver_na2as.h" #include"ast_smt2_pp.h" + solver_na2as::solver_na2as(ast_manager & m): m(m), m_assumptions(m) { @@ -71,6 +72,11 @@ lbool solver_na2as::get_consequences(expr_ref_vector const& asms, expr_ref_vecto return get_consequences_core(m_assumptions, vars, consequences); } +lbool solver_na2as::find_mutexes(expr_ref_vector const& vars, vector& mutexes) { + return l_true; +} + + void solver_na2as::push() { m_scopes.push_back(m_assumptions.size()); push_core(); diff --git a/src/solver/solver_na2as.h b/src/solver/solver_na2as.h index 45253e950..aaa48efe7 100644 --- a/src/solver/solver_na2as.h +++ b/src/solver/solver_na2as.h @@ -46,6 +46,7 @@ public: virtual unsigned get_num_assumptions() const { return m_assumptions.size(); } virtual expr * get_assumption(unsigned idx) const { return m_assumptions[idx]; } virtual lbool get_consequences(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences); + virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes); protected: virtual lbool check_sat_core(unsigned num_assumptions, expr * const * assumptions) = 0; virtual void push_core() = 0; diff --git a/src/tactic/arith/card2bv_tactic.cpp b/src/tactic/arith/card2bv_tactic.cpp index 2b551229d..5019b6550 100644 --- a/src/tactic/arith/card2bv_tactic.cpp +++ b/src/tactic/arith/card2bv_tactic.cpp @@ -168,13 +168,39 @@ namespace pb { return BR_FAILED; } - expr_ref card2bv_rewriter::mk_atmost1(unsigned sz, expr * const* args) { - expr_ref f1(m), f2(m), f3(m), result(m); - f1 = bv.mk_bv(sz, args); - f2 = bv.mk_bv_sub(f1, bv.mk_numeral(rational(1), sz)); - f3 = m.mk_app(bv.get_fid(), OP_BAND, f1, f2); - result = m.mk_eq(f3, bv.mk_numeral(rational(0), sz)); - return result; + expr_ref card2bv_rewriter::mk_atmost1(unsigned n, expr * const* xs) { + expr_ref_vector result(m), in(m); + in.append(n, xs); + unsigned inc_size = 4; + while (!in.empty()) { + expr_ref_vector ors(m); + unsigned i = 0; + unsigned n = in.size(); + bool last = n <= inc_size; + for (; i + inc_size < n; i += inc_size) { + mk_at_most_1_small(last, inc_size, in.c_ptr() + i, result, ors); + } + if (i < n) { + mk_at_most_1_small(last, n - i, in.c_ptr() + i, result, ors); + } + if (last) { + break; + } + in.reset(); + in.append(ors); + } + return mk_and(result); + } + + void card2bv_rewriter::mk_at_most_1_small(bool last, unsigned n, literal const* xs, expr_ref_vector& result, expr_ref_vector& ors) { + if (!last) { + ors.push_back(m.mk_or(n, xs)); + } + for (unsigned i = 0; i < n; ++i) { + for (unsigned j = i + 1; j < n; ++j) { + result.push_back(m.mk_not(m.mk_and(xs[i], xs[j]))); + } + } } bool card2bv_rewriter::is_atmost1(func_decl* f, unsigned sz, expr * const* args, expr_ref& result) { diff --git a/src/tactic/arith/card2bv_tactic.h b/src/tactic/arith/card2bv_tactic.h index 91ed68969..9bf21d2c3 100644 --- a/src/tactic/arith/card2bv_tactic.h +++ b/src/tactic/arith/card2bv_tactic.h @@ -54,6 +54,7 @@ namespace pb { bool is_and(func_decl* f); bool is_atmost1(func_decl* f, unsigned sz, expr * const* args, expr_ref& result); expr_ref mk_atmost1(unsigned sz, expr * const* args); + void mk_at_most_1_small(bool last, unsigned n, literal const* xs, expr_ref_vector& result, expr_ref_vector& ors); public: card2bv_rewriter(ast_manager& m); diff --git a/src/test/sorting_network.cpp b/src/test/sorting_network.cpp index 4802e2bec..57e818542 100644 --- a/src/test/sorting_network.cpp +++ b/src/test/sorting_network.cpp @@ -332,7 +332,32 @@ void test_sorting5(unsigned n, unsigned k) { test_sorting_ge(n, k); } +void test_at_most_1(unsigned n) { + ast_manager m; + reg_decl_plugins(m); + expr_ref_vector in(m), out(m); + for (unsigned i = 0; i < n; ++i) { + in.push_back(m.mk_fresh_const("a",m.mk_bool_sort())); + } + + + ast_ext2 ext(m); + psort_nw sn(ext); + expr_ref result(m); + result = sn.le(false, 1, in.size(), in.c_ptr()); + std::cout << result << "\n"; + std::cout << ext.m_clauses << "\n"; +} + void tst_sorting_network() { + test_at_most_1(1); + test_at_most_1(2); + test_at_most_1(3); + test_at_most_1(4); + test_at_most_1(5); + test_at_most_1(10); + return; + test_sorting_eq(11,7); for (unsigned n = 3; n < 20; n += 2) { for (unsigned k = 1; k < n; ++k) { diff --git a/src/util/sorting_network.h b/src/util/sorting_network.h index b106117ea..a29428dbc 100644 --- a/src/util/sorting_network.h +++ b/src/util/sorting_network.h @@ -201,7 +201,11 @@ Notes: if (dualize(k, n, xs, in)) { return ge(full, k, n, in.c_ptr()); } + else if (k == 1) { + return mk_at_most_1(full, n, xs); + } else { + std::cout << "sort " << k << "\n"; SASSERT(2*k <= n); m_t = full?LE_FULL:LE; card(k + 1, n, xs, out); @@ -230,6 +234,54 @@ Notes: private: + literal mk_at_most_1(bool full, unsigned n, literal const* xs) { + literal_vector in(n, xs); + literal result = ctx.fresh(); + unsigned inc_size = 4; + while (!in.empty()) { + literal_vector ors; + unsigned i = 0; + unsigned n = in.size(); + bool last = n <= inc_size; + for (; i + inc_size < n; i += inc_size) { + mk_at_most_1_small(full, last, inc_size, in.c_ptr() + i, result, ors); + } + if (i < n) { + mk_at_most_1_small(full, last, n - i, in.c_ptr() + i, result, ors); + } + if (last) { + break; + } + in.reset(); + in.append(ors); + ors.reset(); + } + return result; + } + + void mk_at_most_1_small(bool full, bool last, unsigned n, literal const* xs, literal result, literal_vector& ors) { + if (!last) { + literal ex = ctx.fresh(); + for (unsigned j = 0; j < n; ++j) { + add_clause(ctx.mk_not(xs[j]), ex); + } + if (full) { + literal_vector lits(n, xs); + lits.push_back(ctx.mk_not(ex)); + add_clause(lits.size(), lits.c_ptr()); + } + ors.push_back(ex); + } + for (unsigned i = 0; i < n; ++i) { + for (unsigned j = i + 1; j < n; ++j) { + add_clause(ctx.mk_not(result), ctx.mk_not(xs[i]), ctx.mk_not(xs[j])); + } + if (full) { + add_clause(result, xs[i]); + } + } + } + std::ostream& pp(std::ostream& out, unsigned n, literal const* lits) { for (unsigned i = 0; i < n; ++i) ctx.pp(out, lits[i]) << " "; return out;