diff --git a/src/ast/rewriter/pb2bv_rewriter.cpp b/src/ast/rewriter/pb2bv_rewriter.cpp index 3862aecae..d1def83a1 100644 --- a/src/ast/rewriter/pb2bv_rewriter.cpp +++ b/src/ast/rewriter/pb2bv_rewriter.cpp @@ -865,8 +865,8 @@ struct pb2bv_rewriter::imp { // definitions used for sorting network pliteral mk_false() { return m.mk_false(); } pliteral mk_true() { return m.mk_true(); } - pliteral mk_max(pliteral a, pliteral b) { return trail(m.mk_or(a, b)); } - pliteral mk_min(pliteral a, pliteral b) { return trail(m.mk_and(a, b)); } + pliteral mk_max(unsigned n, pliteral const* lits) { return trail(m.mk_or(n, lits)); } + pliteral mk_min(unsigned n, pliteral const* lits) { return trail(m.mk_and(n, lits)); } pliteral mk_not(pliteral a) { if (m.is_not(a,a)) return a; return trail(m.mk_not(a)); } std::ostream& pp(std::ostream& out, pliteral lit) { return out << mk_ismt2_pp(lit, m); } @@ -889,7 +889,7 @@ struct pb2bv_rewriter::imp { m_keep_cardinality_constraints = f; } - void set_at_most1(sorting_network_encoding enc) { m_sort.cfg().m_encoding = enc; } + void set_cardinality_encoding(sorting_network_encoding enc) { m_sort.cfg().m_encoding = enc; } }; @@ -904,7 +904,7 @@ struct pb2bv_rewriter::imp { card2bv_rewriter_cfg(imp& i, ast_manager & m):m_r(i, m) {} void keep_cardinality_constraints(bool f) { m_r.keep_cardinality_constraints(f); } void set_pb_solver(symbol const& s) { m_r.set_pb_solver(s); } - void set_at_most1(sorting_network_encoding enc) { m_r.set_at_most1(enc); } + void set_cardinality_encoding(sorting_network_encoding enc) { m_r.set_cardinality_encoding(enc); } }; @@ -916,7 +916,7 @@ struct pb2bv_rewriter::imp { m_cfg(i, m) {} void keep_cardinality_constraints(bool f) { m_cfg.keep_cardinality_constraints(f); } void set_pb_solver(symbol const& s) { m_cfg.set_pb_solver(s); } - void set_at_most1(sorting_network_encoding e) { m_cfg.set_at_most1(e); } + void set_cardinality_encoding(sorting_network_encoding e) { m_cfg.set_cardinality_encoding(e); } void rewrite(bool full, expr* e, expr_ref& r, proof_ref& p) { expr_ref ee(e, m()); if (m_cfg.m_r.mk_app(full, e, r)) { @@ -947,15 +947,17 @@ struct pb2bv_rewriter::imp { return gparams::get_module("sat").get_sym("pb.solver", symbol("solver")); } - sorting_network_encoding atmost1_encoding() const { - symbol enc = m_params.get_sym("atmost1_encoding", symbol()); + sorting_network_encoding cardinality_encoding() const { + symbol enc = m_params.get_sym("cardinality.encoding", symbol()); if (enc == symbol()) { - enc = gparams::get_module("sat").get_sym("atmost1_encoding", symbol()); + enc = gparams::get_module("sat").get_sym("cardinality.encoding", symbol()); } - if (enc == symbol("grouped")) return sorting_network_encoding::grouped_at_most_1; - if (enc == symbol("bimander")) return sorting_network_encoding::bimander_at_most_1; - if (enc == symbol("ordered")) return sorting_network_encoding::ordered_at_most_1; - return grouped_at_most_1; + if (enc == symbol("grouped")) return sorting_network_encoding::grouped_at_most; + if (enc == symbol("bimander")) return sorting_network_encoding::bimander_at_most; + if (enc == symbol("ordered")) return sorting_network_encoding::ordered_at_most; + if (enc == symbol("unate")) return sorting_network_encoding::unate_at_most; + if (enc == symbol("circuit")) return sorting_network_encoding::circuit_at_most; + return grouped_at_most; } @@ -973,10 +975,11 @@ struct pb2bv_rewriter::imp { m_params.append(p); m_rw.keep_cardinality_constraints(keep_cardinality()); m_rw.set_pb_solver(pb_solver()); - m_rw.set_at_most1(atmost1_encoding()); + m_rw.set_cardinality_encoding(cardinality_encoding()); } + void collect_param_descrs(param_descrs& r) const { - r.insert("keep_cardinality_constraints", CPK_BOOL, "(default: true) retain cardinality constraints (don't bit-blast them) and use built-in cardinality solver"); + r.insert("keep_cardinality_constraints", CPK_BOOL, "(default: false) retain cardinality constraints (don't bit-blast them) and use built-in cardinality solver"); r.insert("pb.solver", CPK_SYMBOL, "(default: solver) retain pb constraints (don't bit-blast them) and use built-in pb solver"); } diff --git a/src/opt/sortmax.cpp b/src/opt/sortmax.cpp index 4313cfbec..1fafa12bd 100644 --- a/src/opt/sortmax.cpp +++ b/src/opt/sortmax.cpp @@ -124,8 +124,8 @@ namespace opt { // definitions used for sorting network pliteral mk_false() { return m.mk_false(); } pliteral mk_true() { return m.mk_true(); } - pliteral mk_max(pliteral a, pliteral b) { return trail(m.mk_or(a, b)); } - pliteral mk_min(pliteral a, pliteral b) { return trail(m.mk_and(a, b)); } + pliteral mk_max(unsigned n, pliteral const* as) { return trail(m.mk_or(n, as)); } + pliteral mk_min(unsigned n, pliteral const* as) { return trail(m.mk_and(n, as)); } pliteral mk_not(pliteral a) { if (m.is_not(a,a)) return a; return trail(m.mk_not(a)); } std::ostream& pp(std::ostream& out, pliteral lit) { return out << mk_pp(lit, m); } diff --git a/src/qe/qe_arith_plugin.cpp b/src/qe/qe_arith_plugin.cpp index 21e50182f..f8c519285 100644 --- a/src/qe/qe_arith_plugin.cpp +++ b/src/qe/qe_arith_plugin.cpp @@ -1301,6 +1301,7 @@ namespace qe { ptr_vector todo; todo.push_back(a); rational k1, k2; + expr* e1 = nullptr, *e2 = nullptr; expr_ref rest(m); while (!todo.empty()) { expr* e = todo.back(); @@ -1319,9 +1320,9 @@ namespace qe { return false; } a = to_app(e); - if (m_util.m_arith.is_mod(e) && - m_util.m_arith.is_numeral(to_app(e)->get_arg(1), k1) && - m_util.get_coeff(contains_x, to_app(e)->get_arg(0), k2, rest)) { + if (m_util.m_arith.is_mod(e, e1, e2) && + m_util.m_arith.is_numeral(e2, k1) && + m_util.get_coeff(contains_x, e1, k2, rest)) { app_ref z(m), z_bv(m); m_util.mk_bounded_var(k1, z_bv, z); m_nested_div_terms.push_back(rest); @@ -1331,10 +1332,9 @@ namespace qe { m_nested_div_z.push_back(z); continue; } - unsigned num_args = a->get_num_args(); - for (unsigned i = 0; i < num_args; ++i) { - todo.push_back(a->get_arg(i)); - } + for (expr* arg : *a) { + todo.push_back(arg); + } } return true; } diff --git a/src/sat/ba_solver.cpp b/src/sat/ba_solver.cpp index 930617301..38cd07fa7 100644 --- a/src/sat/ba_solver.cpp +++ b/src/sat/ba_solver.cpp @@ -2591,22 +2591,54 @@ namespace sat { return literal(v, false); } - literal ba_solver::ba_sort::mk_max(literal l1, literal l2) { - VERIFY(l1 != null_literal); - VERIFY(l2 != null_literal); - if (l1 == m_true) return l1; - if (l2 == m_true) return l2; - if (l1 == ~m_true) return l2; - if (l2 == ~m_true) return l1; - literal max = fresh("max"); - s.s().mk_clause(~l1, max); - s.s().mk_clause(~l2, max); - s.s().mk_clause(~max, l1, l2); - return max; + + literal ba_solver::ba_sort::mk_max(unsigned n, literal const* lits) { + m_lits.reset(); + for (unsigned i = 0; i < n; ++i) { + if (lits[i] == m_true) return m_true; + if (lits[i] == ~m_true) continue; + m_lits.push_back(lits[i]); + } + switch (m_lits.size()) { + case 0: + return ~m_true; + case 1: + return m_lits[0]; + default: { + literal max = fresh("max"); + for (unsigned i = 0; i < n; ++i) { + s.s().mk_clause(~m_lits[i], max); + } + m_lits.push_back(~max); + s.s().mk_clause(m_lits.size(), m_lits.c_ptr()); + return max; + } + } } - literal ba_solver::ba_sort::mk_min(literal l1, literal l2) { - return ~mk_max(~l1, ~l2); + literal ba_solver::ba_sort::mk_min(unsigned n, literal const* lits) { + m_lits.reset(); + for (unsigned i = 0; i < n; ++i) { + if (lits[i] == ~m_true) return ~m_true; + if (lits[i] == m_true) continue; + m_lits.push_back(lits[i]); + } + switch (m_lits.size()) { + case 0: + return m_true; + case 1: + return m_lits[0]; + default: { + literal min = fresh("min"); + for (unsigned i = 0; i < n; ++i) { + s.s().mk_clause(~min, m_lits[i]); + m_lits[i] = ~m_lits[i]; + } + m_lits.push_back(min); + s.s().mk_clause(m_lits.size(), m_lits.c_ptr()); + return min; + } + } } void ba_solver::ba_sort::mk_clause(unsigned n, literal const* lits) { diff --git a/src/sat/ba_solver.h b/src/sat/ba_solver.h index e947cee96..bae59f45a 100644 --- a/src/sat/ba_solver.h +++ b/src/sat/ba_solver.h @@ -255,8 +255,8 @@ namespace sat { pliteral mk_true(); pliteral mk_not(pliteral l); pliteral fresh(char const*); - pliteral mk_max(pliteral l1, pliteral l2); - pliteral mk_min(pliteral l1, pliteral l2); + pliteral mk_min(unsigned, pliteral const* lits); + pliteral mk_max(unsigned, pliteral const* lits); void mk_clause(unsigned n, literal const* lits); std::ostream& pp(std::ostream& out, pliteral l) const; }; diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index dd840468e..89776c479 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -43,7 +43,7 @@ def_module_params('sat', ('cardinality.solver', BOOL, True, 'use cardinality solver'), ('pb.solver', SYMBOL, 'solver', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), solver (use native solver)'), ('xor.solver', BOOL, False, 'use xor solver'), - ('atmost1_encoding', SYMBOL, 'grouped', 'encoding used for at-most-1 constraints grouped, bimander, ordered'), + ('cardinality.encoding', SYMBOL, 'grouped', 'encoding used for at-most-k constraints: grouped, bimander, ordered, unate, circuit'), ('local_search', BOOL, False, 'use local search instead of CDCL'), ('local_search_threads', UINT, 0, 'number of local search threads to find satisfiable solution'), ('local_search_mode', SYMBOL, 'wsat', 'local search algorithm, either default wsat or qsat'), diff --git a/src/smt/theory_pb.cpp b/src/smt/theory_pb.cpp index ed0481938..e389c819e 100644 --- a/src/smt/theory_pb.cpp +++ b/src/smt/theory_pb.cpp @@ -1429,23 +1429,27 @@ namespace smt { return literal(ctx.mk_bool_var(y)); } - literal mk_max(literal a, literal b) { - if (a == b) return a; - expr_ref t1(m), t2(m), t3(m); - ctx.literal2expr(a, t1); - ctx.literal2expr(b, t2); - t3 = m.mk_or(t1, t2); - bool_var v = ctx.b_internalized(t3)?ctx.get_bool_var(t3):ctx.mk_bool_var(t3); + literal mk_max(unsigned n, literal const* lits) { + expr_ref_vector es(m); + expr_ref tmp(m); + for (unsigned i = 0; i < n; ++i) { + ctx.literal2expr(lits[i], tmp); + es.push_back(tmp); + } + tmp = m.mk_or(es.size(), es.c_ptr()); + bool_var v = ctx.b_internalized(tmp)?ctx.get_bool_var(tmp):ctx.mk_bool_var(tmp); return literal(v); } - - literal mk_min(literal a, literal b) { - if (a == b) return a; - expr_ref t1(m), t2(m), t3(m); - ctx.literal2expr(a, t1); - ctx.literal2expr(b, t2); - t3 = m.mk_and(t1, t2); - bool_var v = ctx.b_internalized(t3)?ctx.get_bool_var(t3):ctx.mk_bool_var(t3); + + literal mk_min(unsigned n, literal const* lits) { + expr_ref_vector es(m); + expr_ref tmp(m); + for (unsigned i = 0; i < n; ++i) { + ctx.literal2expr(lits[i], tmp); + es.push_back(tmp); + } + tmp = m.mk_and(es.size(), es.c_ptr()); + bool_var v = ctx.b_internalized(tmp)?ctx.get_bool_var(tmp):ctx.mk_bool_var(tmp); return literal(v); } diff --git a/src/test/sorting_network.cpp b/src/test/sorting_network.cpp index f5c415c04..2470df528 100644 --- a/src/test/sorting_network.cpp +++ b/src/test/sorting_network.cpp @@ -134,16 +134,12 @@ void test_sorting3() { for (unsigned i = 0; i < 7; ++i) { in.push_back(m.mk_fresh_const("a",m.mk_bool_sort())); } - for (unsigned i = 0; i < in.size(); ++i) { - std::cout << mk_pp(in[i].get(), m) << "\n"; - } + for (expr* e : in) std::cout << mk_pp(e, m) << "\n"; ast_ext aext(m); sorting_network sn(aext); sn(in, out); std::cout << "size: " << out.size() << "\n"; - for (unsigned i = 0; i < out.size(); ++i) { - std::cout << mk_pp(out[i].get(), m) << "\n"; - } + for (expr* e : out) std::cout << mk_pp(e, m) << "\n"; } @@ -162,10 +158,12 @@ struct ast_ext2 { pliteral mk_false() { return m.mk_false(); } pliteral mk_true() { return m.mk_true(); } - pliteral mk_max(pliteral a, pliteral b) { - return trail(m.mk_or(a, b)); + pliteral mk_max(unsigned n, pliteral const* lits) { + return trail(m.mk_or(n, lits)); + } + pliteral mk_min(unsigned n, pliteral const* lits) { + return trail(m.mk_and(n, lits)); } - pliteral mk_min(pliteral a, pliteral b) { return trail(m.mk_and(a, b)); } pliteral mk_not(pliteral a) { if (m.is_not(a,a)) return a; return trail(m.mk_not(a)); } @@ -199,8 +197,8 @@ static void test_eq1(unsigned n, sorting_network_encoding enc) { // equality: solver.push(); result1 = sn.eq(true, 1, in.size(), in.c_ptr()); - for (expr* cl : ext.m_clauses) { - solver.assert_expr(cl); + for (expr* cls : ext.m_clauses) { + solver.assert_expr(cls); } expr_ref_vector ors(m); for (unsigned i = 0; i < n; ++i) { @@ -245,12 +243,15 @@ static void test_sorting_eq(unsigned n, unsigned k, sorting_network_encoding enc std::cout << "eq " << k << " out of " << n << " for encoding " << enc << "\n"; solver.push(); result = sn.eq(false, k, in.size(), in.c_ptr()); - std::cout << result << "\n" << ext.m_clauses << "\n"; solver.assert_expr(result); for (expr* cl : ext.m_clauses) { solver.assert_expr(cl); } lbool res = solver.check(); + if (res != l_true) { + std::cout << res << "\n"; + solver.display(std::cout); + } ENSURE(res == l_true); solver.push(); @@ -258,6 +259,9 @@ static void test_sorting_eq(unsigned n, unsigned k, sorting_network_encoding enc solver.assert_expr(in[i].get()); } res = solver.check(); + if (res != l_true) { + std::cout << result << "\n" << ext.m_clauses << "\n"; + } ENSURE(res == l_true); solver.assert_expr(in[k].get()); res = solver.check(); @@ -295,16 +299,26 @@ static void test_sorting_le(unsigned n, unsigned k, sorting_network_encoding enc solver.push(); result = sn.le(false, k, in.size(), in.c_ptr()); solver.assert_expr(result); - for (unsigned i = 0; i < ext.m_clauses.size(); ++i) { - solver.assert_expr(ext.m_clauses[i].get()); + for (expr* cls : ext.m_clauses) { + solver.assert_expr(cls); } lbool res = solver.check(); + if (res != l_true) { + std::cout << res << "\n"; + solver.display(std::cout); + std::cout << "clauses: " << ext.m_clauses << "\n"; + std::cout << "result: " << result << "\n"; + } ENSURE(res == l_true); for (unsigned i = 0; i < k; ++i) { solver.assert_expr(in[i].get()); } res = solver.check(); + if (res != l_true) { + std::cout << res << "\n"; + solver.display(std::cout); + } ENSURE(res == l_true); solver.assert_expr(in[k].get()); res = solver.check(); @@ -343,8 +357,8 @@ void test_sorting_ge(unsigned n, unsigned k, sorting_network_encoding enc) { solver.push(); result = sn.ge(false, k, in.size(), in.c_ptr()); solver.assert_expr(result); - for (unsigned i = 0; i < ext.m_clauses.size(); ++i) { - solver.assert_expr(ext.m_clauses[i].get()); + for (expr* cls : ext.m_clauses) { + solver.assert_expr(cls); } lbool res = solver.check(); ENSURE(res == l_true); @@ -407,13 +421,13 @@ void test_at_most_1(unsigned n, bool full, sorting_network_encoding enc) { std::cout << "clauses: " << ext.m_clauses << "\n-----\n"; - std::cout << "encoded: " << result1 << "\n"; - std::cout << "naive: " << result2 << "\n"; + //std::cout << "encoded: " << result1 << "\n"; + //std::cout << "naive: " << result2 << "\n"; smt_params fp; smt::kernel solver(m, fp); - for (unsigned i = 0; i < ext.m_clauses.size(); ++i) { - solver.assert_expr(ext.m_clauses[i].get()); + for (expr* cls : ext.m_clauses) { + solver.assert_expr(cls); } if (full) { solver.push(); @@ -481,8 +495,8 @@ static void test_at_most1(sorting_network_encoding enc) { sn.cfg().m_encoding = enc; expr_ref result(m); result = sn.le(true, 1, in.size(), in.c_ptr()); - std::cout << result << "\n"; - std::cout << ext.m_clauses << "\n"; + //std::cout << result << "\n"; + //std::cout << ext.m_clauses << "\n"; } static void test_sorting5(sorting_network_encoding enc) { @@ -509,9 +523,11 @@ static void tst_sorting_network(sorting_network_encoding enc) { } void tst_sorting_network() { - tst_sorting_network(sorting_network_encoding::ordered_at_most_1); - tst_sorting_network(sorting_network_encoding::grouped_at_most_1); - tst_sorting_network(sorting_network_encoding::bimander_at_most_1); + tst_sorting_network(sorting_network_encoding::unate_at_most); + tst_sorting_network(sorting_network_encoding::circuit_at_most); + tst_sorting_network(sorting_network_encoding::ordered_at_most); + tst_sorting_network(sorting_network_encoding::grouped_at_most); + tst_sorting_network(sorting_network_encoding::bimander_at_most); test_sorting1(); test_sorting2(); test_sorting3(); diff --git a/src/util/sorting_network.h b/src/util/sorting_network.h index 2a0d929dd..d7d8e8bbe 100644 --- a/src/util/sorting_network.h +++ b/src/util/sorting_network.h @@ -25,16 +25,22 @@ Notes: #define SORTING_NETWORK_H_ enum sorting_network_encoding { - grouped_at_most_1, - bimander_at_most_1, - ordered_at_most_1 + sorted_at_most, + grouped_at_most, + bimander_at_most, + ordered_at_most, + unate_at_most, + circuit_at_most }; inline std::ostream& operator<<(std::ostream& out, sorting_network_encoding enc) { switch (enc) { - case grouped_at_most_1: return out << "grouped"; - case bimander_at_most_1: return out << "bimander"; - case ordered_at_most_1: return out << "ordered"; + case grouped_at_most: return out << "grouped"; + case bimander_at_most: return out << "bimander"; + case ordered_at_most: return out << "ordered"; + case sorted_at_most: return out << "sorted"; + case unate_at_most: return out << "unate"; + case circuit_at_most: return out << "circuit"; } return out << "???"; } @@ -42,7 +48,7 @@ Notes: struct sorting_network_config { sorting_network_encoding m_encoding; sorting_network_config() { - m_encoding = grouped_at_most_1; + m_encoding = sorted_at_most; } }; @@ -158,12 +164,19 @@ Notes: vc operator+(vc const& other) const { return vc(v + other.v, c + other.c); } + vc operator-(vc const& other) const { + return vc(v - other.v, c - other.c); + } unsigned to_int() const { return lambda*v + c; } vc operator*(unsigned n) const { return vc(n*v, n*c); } + + std::ostream& pp(std::ostream& out) const { + return out << "v: " << v << " c: " << c; + } }; static vc mk_min(vc const& v1, vc const& v2) { @@ -176,13 +189,17 @@ Notes: cmp_t m_t; // for testing - static const bool m_disable_dcard = false; - static const bool m_disable_dsorting = true; // false; - static const bool m_disable_dsmerge = true; // false; + static const bool m_disable_dcard = false; + static const bool m_disable_dsorting = false; + static const bool m_disable_dsmerge = false; static const bool m_force_dcard = false; static const bool m_force_dsorting = false; static const bool m_force_dsmerge = false; + bool is_power_of2(unsigned n) const { + return n != 0 && ((n-1) & n) == 0; + } + public: struct stats { unsigned m_num_compiled_vars; @@ -221,15 +238,31 @@ Notes: } SASSERT(0 < k && k <= n); literal_vector in, out; + if (k == 1) { + return mk_or(n, xs); + } if (dualize(k, n, xs, in)) { return le(full, k, in.size(), in.c_ptr()); } else { - SASSERT(2*k <= n); - m_t = full?GE_FULL:GE; - // scoped_stats _ss(m_stats, k, n); - psort_nw::card(k, n, xs, out); - return out[k-1]; + switch (m_cfg.m_encoding) { + case sorted_at_most: + case bimander_at_most: + case ordered_at_most: + case grouped_at_most: + SASSERT(2*k <= n); + m_t = full?GE_FULL:GE; + // scoped_stats _ss(m_stats, k, n); + psort_nw::card(k, n, xs, out); + return out[k-1]; + case unate_at_most: + return unate_ge(full, k, n, xs); + case circuit_at_most: + return circuit_ge(full, k, n, xs); + default: + UNREACHABLE(); + return xs[0]; + } } } @@ -246,23 +279,40 @@ Notes: literal_vector ors; // scoped_stats _ss(m_stats, k, n); switch (m_cfg.m_encoding) { - case grouped_at_most_1: + case grouped_at_most: + case sorted_at_most: + case unate_at_most: + case circuit_at_most: return mk_at_most_1(full, n, xs, ors, false); - case bimander_at_most_1: + case bimander_at_most: return mk_at_most_1_bimander(full, n, xs, ors); - case ordered_at_most_1: + case ordered_at_most: return mk_ordered_atmost_1(full, n, xs); + default: UNREACHABLE(); return xs[0]; } } else { - SASSERT(2*k <= n); - m_t = full?LE_FULL:LE; - // scoped_stats _ss(m_stats, k, n); - card(k + 1, n, xs, out); - return ctx.mk_not(out[k]); + switch (m_cfg.m_encoding) { + case sorted_at_most: + case bimander_at_most: + case ordered_at_most: + case grouped_at_most: + SASSERT(2*k <= n); + m_t = full?LE_FULL:LE; + // scoped_stats _ss(m_stats, k, n); + card(k + 1, n, xs, out); + return mk_not(out[k]); + case unate_at_most: + return unate_le(full, k, n, xs); + case circuit_at_most: + return circuit_le(full, k, n, xs); + default: + UNREACHABLE(); + return xs[0]; + } } } @@ -280,16 +330,29 @@ Notes: return mk_exactly_1(full, n, xs); } else { - // scoped_stats _ss(m_stats, k, n); - SASSERT(2*k <= n); - m_t = EQ; - card(k+1, n, xs, out); - SASSERT(out.size() >= k+1); - if (k == 0) { - return ctx.mk_not(out[k]); - } - else { - return ctx.mk_min(out[k-1], ctx.mk_not(out[k])); + switch (m_cfg.m_encoding) { + case sorted_at_most: + case bimander_at_most: + case grouped_at_most: + case ordered_at_most: + // scoped_stats _ss(m_stats, k, n); + SASSERT(2*k <= n); + m_t = EQ; + card(k+1, n, xs, out); + SASSERT(out.size() >= k+1); + if (k == 0) { + return mk_not(out[k]); + } + else { + return mk_min(out[k-1], mk_not(out[k])); + } + case unate_at_most: + return unate_eq(k, n, xs); + case circuit_at_most: + return circuit_eq(k, n, xs); + default: + UNREACHABLE(); + return xs[0]; } } } @@ -297,49 +360,193 @@ Notes: private: + // perform unate addition up to k. + literal unate_cmp(cmp_t cmp, unsigned k, unsigned n, literal const* xs) { + unsigned last = k; + if (cmp == LE || cmp == EQ || cmp == LE_FULL) { + last = k + 1; + } + bool full = cmp == LE_FULL || cmp == GE_FULL; + + literal_vector carry; + for (unsigned i = 0; i < last; ++i) { + carry.push_back(ctx.mk_false()); + } + for (unsigned i = 0; i < n; ++i) { + for (unsigned j = last; j-- > 0; ) { + // c'[j] <-> (xs[i] & c[j-1]) | c[j] + literal c0 = j > 0 ? carry[j-1] : ctx.mk_true(); + carry[j] = mk_or(mk_and(xs[i], c0), carry[j]); + } + } + switch (cmp) { + case LE: + case LE_FULL: + return mk_not(carry[k]); + case GE: + case GE_FULL: + return carry[k-1]; + case EQ: + return mk_and(mk_not(carry[k]), carry[k-1]); + default: + UNREACHABLE(); + return xs[0]; + } + } + + literal unate_ge(bool full, unsigned k, unsigned n, literal const* xs) { + return unate_cmp(full ? GE_FULL : GE, k, n, xs); + } + + literal unate_le(bool full, unsigned k, unsigned n, literal const* xs) { + return unate_cmp(full ? LE_FULL : LE, k, n, xs); + } + + literal unate_eq(unsigned k, unsigned n, literal const* xs) { + return unate_cmp(EQ, k, n, xs); + } + + // circuit encoding + void mk_unit_circuit(unsigned k, literal x, literal_vector& out) { + out.push_back(x); + for (unsigned i = 1; i < k; ++i) out.push_back(ctx.mk_false()); + } + + literal mk_add_circuit(literal_vector const& x, literal_vector const& y, literal_vector& out) { + literal c = ctx.mk_false(); + SASSERT(x.size() == y.size()); + for (unsigned i = 0; i < x.size(); ++i) { + // out[i] = c + x[i] + y[i] + // c' = c&x[i] | c&y[i] | x[i]&y[i]; + literal_vector ors; + ors.push_back(mk_and(c, mk_not(x[i]), mk_not(y[i]))); + ors.push_back(mk_and(x[i], mk_not(c), mk_not(y[i]))); + ors.push_back(mk_and(y[i], mk_not(c), mk_not(x[i]))); + ors.push_back(mk_and(c, x[i], y[i])); + literal o = mk_or(4, ors.c_ptr()); + out.push_back(o); + ors[0] = mk_and(c, x[i]); + ors[1] = mk_and(c, y[i]); + ors[2] = mk_and(x[i], y[i]); + c = mk_or(3, ors.c_ptr()); + } + return c; + } + + literal circuit_add(unsigned k, unsigned n, literal const* xs, literal_vector& out) { + switch (n) { + case 0: + for (unsigned i = 0; i < k; ++i) { + out.push_back(ctx.mk_false()); + } + return ctx.mk_false(); + case 1: + mk_unit_circuit(k, xs[0], out); + return ctx.mk_false(); + default: { + literal_vector o1, o2; + unsigned half = n / 2; + literal ovfl1 = circuit_add(k, half, xs, o1); + literal ovfl2 = circuit_add(k, n - half, xs + half, o2); + literal ovfl3 = mk_add_circuit(o1, o2, out); + return mk_or(ovfl1, ovfl2, ovfl3); + } + } + } + + literal circuit_cmp(cmp_t cmp, unsigned k, unsigned n, literal const* xs) { + literal_vector out, kvec; + unsigned num_bits = 0; + unsigned k1 = (cmp == LE || cmp == LE_FULL) ? k + 1 : k; + unsigned k0 = k1; + while (k0 > 0) { ++num_bits; k0 >>= 1; } + for (unsigned i = 0; i < num_bits; ++i) { + kvec.push_back((0 != (k1 & (1 << i))) ? ctx.mk_true() : ctx.mk_false()); + } + literal ovfl = circuit_add(num_bits, n, xs, out); + switch (cmp) { + case LE: + case LE_FULL: + return mk_not(mk_or(ovfl, mk_ge(out, kvec))); + case GE: + case GE_FULL: + return mk_or(ovfl, mk_ge(out, kvec)); + case EQ: { + literal_vector eqs; + SASSERT(kvec.size() == out.size()); + for (unsigned i = 0; i < num_bits; ++i) { + eqs.push_back(mk_or(mk_not(kvec[i]), out[i])); + eqs.push_back(mk_or(kvec[i], mk_not(out[i]))); + } + eqs.push_back(mk_not(ovfl)); + return mk_and(eqs); + } + default: + UNREACHABLE(); + return xs[0]; + } + } + + literal mk_ge(literal_vector const& x, literal_vector const& y) { + literal r = ctx.mk_true(); + literal g = ctx.mk_false(); + for (unsigned j = x.size(); j-- > 0; ) { + g = mk_or(g, mk_and(r, mk_and(x[j], mk_not(y[j])))); + r = mk_or(g, mk_and(r, mk_or( x[j], mk_not(y[j])))); + } + return r; + } + + literal circuit_ge(bool full, unsigned k, unsigned n, literal const* xs) { + return circuit_cmp(full ? GE_FULL : GE, k, n, xs); + } + + literal circuit_le(bool full, unsigned k, unsigned n, literal const* xs) { + return circuit_cmp(full ? LE_FULL : LE, k, n, xs); + } + + literal circuit_eq(unsigned k, unsigned n, literal const* xs) { + return circuit_cmp(EQ, k, n, xs); + } + void add_implies_or(literal l, unsigned n, literal const* xs) { literal_vector lits(n, xs); - lits.push_back(ctx.mk_not(l)); + lits.push_back(mk_not(l)); add_clause(lits); } - void add_or_implies(literal l, unsigned n, literal const* xs) { - for (unsigned j = 0; j < n; ++j) { - add_clause(ctx.mk_not(xs[j]), l); + literal mk_or(unsigned n, literal const* _ors) { + literal_vector ors(n, _ors); + unsigned j = 0; + for (literal lit : ors) { + if (is_true(lit)) return lit; + if (!is_false(lit)) ors[j++] = lit; } - } - - literal mk_or(unsigned n, literal const* ors) { - if (n == 1) { - return ors[0]; + ors.shrink(j); + switch (j) { + case 0: return ctx.mk_false(); + case 1: return ors[0]; + default: return ctx.mk_max(ors.size(), ors.c_ptr()); } - literal result = fresh("or"); - add_implies_or(result, n, ors); - add_or_implies(result, n, ors); - return result; } literal mk_or(literal l1, literal l2) { literal ors[2] = { l1, l2 }; return mk_or(2, ors); } + literal mk_or(literal l1, literal l2, literal l3) { + literal ors[3] = { l1, l2, l3 }; + return mk_or(3, ors); + } + literal mk_or(literal_vector const& ors) { return mk_or(ors.size(), ors.c_ptr()); } - void add_implies_and(literal l, literal_vector const& xs) { - for (literal const& x : xs) { - add_clause(ctx.mk_not(l), x); - } - } - - void add_and_implies(literal l, literal_vector const& xs) { - literal_vector lits; - for (literal const& x : xs) { - lits.push_back(ctx.mk_not(x)); - } - lits.push_back(l); - add_clause(lits); + literal mk_not(literal lit) { + if (is_true(lit)) return ctx.mk_false(); + if (is_false(lit)) return ctx.mk_true(); + return ctx.mk_not(lit); } literal mk_and(literal l1, literal l2) { @@ -348,14 +555,39 @@ Notes: return mk_and(xs); } - literal mk_and(literal_vector const& ands) { - if (ands.size() == 1) { - return ands[0]; + literal mk_and(literal l1, literal l2, literal l3) { + literal_vector xs; + xs.push_back(l1); xs.push_back(l2); xs.push_back(l3); + return mk_and(xs); + } + + bool is_true(literal l) { + return l == ctx.mk_true(); + } + + bool is_false(literal l) { + return l == ctx.mk_false(); + } + + literal mk_and(literal_vector const& _ands) { + literal_vector ands(_ands); + unsigned j = 0; + for (literal lit : ands) { + if (is_false(lit)) return lit; + if (!is_true(lit)) ands[j++] = lit; + } + ands.shrink(j); + switch (j) { + case 0: + return ctx.mk_true(); + case 1: + return ands[0]; + case 2: + return mk_min(ands[0], ands[1]); + default: { + return ctx.mk_min(ands.size(), ands.c_ptr()); + } } - literal result = fresh("and"); - add_implies_and(result, ands); - add_and_implies(result, ands); - return result; } literal mk_exactly_1(bool full, unsigned n, literal const* xs) { @@ -363,13 +595,16 @@ Notes: literal_vector ors; literal r1; switch (m_cfg.m_encoding) { - case grouped_at_most_1: + case grouped_at_most: + case sorted_at_most: + case unate_at_most: + case circuit_at_most: r1 = mk_at_most_1(full, n, xs, ors, true); break; - case bimander_at_most_1: + case bimander_at_most: r1 = mk_at_most_1_bimander(full, n, xs, ors); break; - case ordered_at_most_1: + case ordered_at_most: return mk_ordered_exactly_1(full, n, xs); default: UNREACHABLE(); @@ -426,7 +661,7 @@ Notes: // result => xs[0] + ... + xs[n-1] <= 1 for (unsigned i = 0; i < n; ++i) { for (unsigned j = i + 1; j < n; ++j) { - add_clause(ctx.mk_not(result), ctx.mk_not(xs[i]), ctx.mk_not(xs[j])); + add_clause(mk_not(result), mk_not(xs[i]), mk_not(xs[j])); } } @@ -441,7 +676,7 @@ Notes: } add_clause(lits); } - ands.push_back(ctx.mk_not(and_i)); + ands.push_back(mk_not(and_i)); } } @@ -457,7 +692,7 @@ Notes: literal_vector ands; for (unsigned i = 0; i < n; ++i) { for (unsigned j = i + 1; j < n; ++j) { - ands.push_back(mk_or(ctx.mk_not(xs[i]), ctx.mk_not(xs[j]))); + ands.push_back(mk_or(mk_not(xs[i]), mk_not(xs[j]))); } } return mk_and(ands); @@ -513,36 +748,36 @@ Notes: ys.push_back(fresh("y")); } for (unsigned i = 0; i + 2 < n; ++i) { - add_clause(ctx.mk_not(ys[i]), ys[i + 1]); + add_clause(mk_not(ys[i]), ys[i + 1]); } for (unsigned i = 0; i + 1 < n; ++i) { - add_clause(ctx.mk_not(xs[i]), ys[i]); - add_clause(ctx.mk_not(r), ctx.mk_not(ys[i]), ctx.mk_not(xs[i + 1])); + add_clause(mk_not(xs[i]), ys[i]); + add_clause(mk_not(r), mk_not(ys[i]), mk_not(xs[i + 1])); } if (is_eq) { - add_clause(ctx.mk_not(r), ys[n-2], xs[n-1]); + add_clause(mk_not(r), ys[n-2], xs[n-1]); } for (unsigned i = 1; i < n - 1; ++i) { - add_clause(ctx.mk_not(ys[i]), xs[i], ys[i-1]); + add_clause(mk_not(ys[i]), xs[i], ys[i-1]); } - add_clause(ctx.mk_not(ys[0]), xs[0]); + add_clause(mk_not(ys[0]), xs[0]); if (full) { literal_vector twos; for (unsigned i = 0; i < n - 1; ++i) { twos.push_back(fresh("two")); } - add_clause(ctx.mk_not(twos[0]), ys[0]); - add_clause(ctx.mk_not(twos[0]), xs[1]); + add_clause(mk_not(twos[0]), ys[0]); + add_clause(mk_not(twos[0]), xs[1]); for (unsigned i = 1; i < n - 1; ++i) { - add_clause(ctx.mk_not(twos[i]), ys[i], twos[i-1]); - add_clause(ctx.mk_not(twos[i]), xs[i + 1], twos[i-1]); + add_clause(mk_not(twos[i]), ys[i], twos[i-1]); + add_clause(mk_not(twos[i]), xs[i + 1], twos[i-1]); } if (is_eq) { literal zero = fresh("zero"); - add_clause(ctx.mk_not(zero), ctx.mk_not(xs[n-1])); - add_clause(ctx.mk_not(zero), ctx.mk_not(ys[n-2])); + add_clause(mk_not(zero), mk_not(xs[n-1])); + add_clause(mk_not(zero), mk_not(ys[n-2])); add_clause(r, zero, twos.back()); } else { @@ -579,7 +814,7 @@ Notes: for (unsigned i = 0; i < ors.size(); ++i) { for (unsigned k = 0; k < nbits; ++k) { bool bit_set = (i & (static_cast(1 << k))) != 0; - add_clause(ctx.mk_not(result), ctx.mk_not(ors[i]), bit_set ? bits[k] : ctx.mk_not(bits[k])); + add_clause(mk_not(result), mk_not(ors[i]), bit_set ? bits[k] : mk_not(bits[k])); } } return result; @@ -608,7 +843,7 @@ Notes: } k = N - k; for (unsigned i = 0; i < N; ++i) { - in.push_back(ctx.mk_not(xs[i])); + in.push_back(mk_not(xs[i])); } TRACE("pb_verbose", //pp(tout << N << ": ", in); @@ -627,13 +862,15 @@ Notes: literal mk_max(literal a, literal b) { if (a == b) return a; m_stats.m_num_compiled_vars++; - return ctx.mk_max(a, b); + literal lits[2] = { a, b}; + return ctx.mk_max(2, lits); } literal mk_min(literal a, literal b) { if (a == b) return a; m_stats.m_num_compiled_vars++; - return ctx.mk_min(a, b); + literal lits[2] = { a, b}; + return ctx.mk_min(2, lits); } literal fresh(char const* n) { @@ -652,6 +889,9 @@ Notes: add_clause(lits.size(), lits.c_ptr()); } void add_clause(unsigned n, literal const* ls) { + for (unsigned i = 0; i < n; ++i) { + if (is_true(ls[i])) return; + } m_stats.m_num_compiled_clauses++; m_stats.m_num_clause_vars += n; literal_vector tmp(n, ls); @@ -661,17 +901,17 @@ Notes: // y1 <= mk_max(x1,x2) // y2 <= mk_min(x1,x2) void cmp_ge(literal x1, literal x2, literal y1, literal y2) { - add_clause(ctx.mk_not(y2), x1); - add_clause(ctx.mk_not(y2), x2); - add_clause(ctx.mk_not(y1), x1, x2); + add_clause(mk_not(y2), x1); + add_clause(mk_not(y2), x2); + add_clause(mk_not(y1), x1, x2); } // mk_max(x1,x2) <= y1 // mk_min(x1,x2) <= y2 void cmp_le(literal x1, literal x2, literal y1, literal y2) { - add_clause(ctx.mk_not(x1), y1); - add_clause(ctx.mk_not(x2), y1); - add_clause(ctx.mk_not(x1), ctx.mk_not(x2), y2); + add_clause(mk_not(x1), y1); + add_clause(mk_not(x2), y1); + add_clause(mk_not(x1), mk_not(x2), y2); } void cmp_eq(literal x1, literal x2, literal y1, literal y2) { @@ -773,6 +1013,8 @@ Notes: } TRACE("pb_verbose", tout << "merge a: " << a << " b: " << b << " "; tout << "num clauses " << m_stats.m_num_compiled_clauses - nc << "\n"; + vc_dsmerge(a, b, a + b).pp(tout << "vc_dsmerge ") << "\n"; + vc_smerge_rec(a, b, a + b).pp(tout << "vc_smerge_rec ") << "\n"; //pp(tout << "a:", a, as) << "\n"; //pp(tout << "b:", b, bs) << "\n"; //pp(tout << "out:", out) << "\n"; @@ -796,7 +1038,8 @@ Notes: return vc_merge(ceil2(a), ceil2(b)) + vc_merge(floor2(a), floor2(b)) + - vc_interleave(ceil2(a) + ceil2(b), floor2(a) + floor2(b)); + vc_interleave(ceil2(a) + ceil2(b), floor2(a) + floor2(b)) - + vc(0, 2); } void split(unsigned n, literal const* ls, literal_vector& even, literal_vector& odd) { for (unsigned i = 0; i < n; i += 2) { @@ -914,12 +1157,12 @@ Notes: if (m_t != GE) { // x1 <= mk_max(x1,x2) // x2 <= mk_max(x1,x2) - add_clause(ctx.mk_not(as[0]), y); - add_clause(ctx.mk_not(bs[0]), y); + add_clause(mk_not(as[0]), y); + add_clause(mk_not(bs[0]), y); } if (m_t != LE) { // mk_max(x1,x2) <= x1, x2 - add_clause(ctx.mk_not(y), as[0], bs[0]); + add_clause(mk_not(y), as[0], bs[0]); } out.push_back(y); } @@ -970,11 +1213,11 @@ Notes: out2.pop_back(); y = mk_max(z1, z2); if (m_t != GE) { - add_clause(ctx.mk_not(z1), y); - add_clause(ctx.mk_not(z2), y); + add_clause(mk_not(z1), y); + add_clause(mk_not(z2), y); } if (m_t != LE) { - add_clause(ctx.mk_not(y), z1, z2); + add_clause(mk_not(y), z1, z2); } } interleave(out1, out2, out); @@ -1018,7 +1261,7 @@ Notes: return m_force_dsmerge || (!m_disable_dsmerge && - a < (1 << 7) && b < (1 << 7) && + a < 10 && b < 10 && vc_dsmerge(a, b, a + b) < vc_smerge_rec(a, b, c)); } @@ -1027,7 +1270,7 @@ Notes: unsigned a, literal const* as, unsigned b, literal const* bs, literal_vector& out) { - TRACE("pb_verbose", tout << "dsmerge: c:" << c << " a:" << a << " b:" << b << "\n";); + unsigned nc = m_stats.m_num_compiled_clauses; SASSERT(a <= c); SASSERT(b <= c); SASSERT(a + b >= c); @@ -1036,14 +1279,14 @@ Notes: } if (m_t != GE) { for (unsigned i = 0; i < a; ++i) { - add_clause(ctx.mk_not(as[i]), out[i]); + add_clause(mk_not(as[i]), out[i]); } for (unsigned i = 0; i < b; ++i) { - add_clause(ctx.mk_not(bs[i]), out[i]); + add_clause(mk_not(bs[i]), out[i]); } for (unsigned i = 1; i <= a; ++i) { for (unsigned j = 1; j <= b && i + j <= c; ++j) { - add_clause(ctx.mk_not(as[i-1]),ctx.mk_not(bs[j-1]),out[i+j-1]); + add_clause(mk_not(as[i-1]),mk_not(bs[j-1]),out[i+j-1]); } } } @@ -1051,12 +1294,12 @@ Notes: literal_vector ls; for (unsigned k = 0; k < c; ++k) { ls.reset(); - ls.push_back(ctx.mk_not(out[k])); + ls.push_back(mk_not(out[k])); if (a <= k) { - add_clause(ctx.mk_not(out[k]), bs[k-a]); + add_clause(mk_not(out[k]), bs[k-a]); } if (b <= k) { - add_clause(ctx.mk_not(out[k]), as[k-b]); + add_clause(mk_not(out[k]), as[k-b]); } for (unsigned i = 0; i < std::min(a,k + 1); ++i) { unsigned j = k - i; @@ -1071,7 +1314,13 @@ Notes: } } } + TRACE("pb_verbose", tout << "dsmerge: c:" << c << " a:" << a << " b:" << b << " "; + tout << "num clauses: " << m_stats.m_num_compiled_clauses - nc << "\n"; + vc_dsmerge(a, b, c).pp(tout << "vc_dsmerge ") << "\n"; + vc_smerge_rec(a, b, c).pp(tout << "vc_smerge_rec ") << "\n"; + ); } + vc vc_dsmerge(unsigned a, unsigned b, unsigned c) { vc v(c, 0); if (m_t != GE) { @@ -1101,7 +1350,7 @@ Notes: } if (m_t != LE) { for (unsigned k = 1; k <= m; ++k) { - lits.push_back(ctx.mk_not(out[k-1])); + lits.push_back(mk_not(out[k-1])); add_subset(false, n-k+1, 0, lits, n, xs); lits.pop_back(); } @@ -1134,7 +1383,7 @@ Notes: return; } for (unsigned i = offset; i < n - k + 1; ++i) { - lits.push_back(polarity?ctx.mk_not(xs[i]):xs[i]); + lits.push_back(polarity?mk_not(xs[i]):xs[i]); add_subset(polarity, k-1, i+1, lits, n, xs); lits.pop_back(); }