diff --git a/src/ast/rewriter/pb2bv_rewriter.cpp b/src/ast/rewriter/pb2bv_rewriter.cpp index 0aeeea81a..48d566f11 100644 --- a/src/ast/rewriter/pb2bv_rewriter.cpp +++ b/src/ast/rewriter/pb2bv_rewriter.cpp @@ -24,6 +24,7 @@ Notes: #include"sorting_network.h" #include"ast_util.h" #include"ast_pp.h" +#include"lbool.h" struct pb2bv_rewriter::imp { @@ -79,150 +80,118 @@ struct pb2bv_rewriter::imp { bv_util bv; expr_ref_vector m_trail; - unsigned get_num_bits(func_decl* f) { - rational r(0); - unsigned sz = f->get_arity(); - for (unsigned i = 0; i < sz; ++i) { - r += pb.get_coeff(f, i); + template + expr_ref mk_le_ge(expr_ref_vector& fmls, expr* a, expr* b, expr* bound) { + expr_ref x(m), y(m), result(m); + unsigned nb = bv.get_bv_size(a); + x = bv.mk_zero_extend(1, a); + y = bv.mk_zero_extend(1, b); + result = bv.mk_bv_add(x, y); + x = bv.mk_extract(nb, nb, result); + result = bv.mk_extract(nb-1, 0, result); + if (is_le != l_false) { + fmls.push_back(m.mk_eq(x, bv.mk_numeral(rational::zero(), 1))); + fmls.push_back(bv.mk_ule(result, bound)); } - r = r > pb.get_k(f)? r : pb.get_k(f); - return r.get_num_bits(); + else { + fmls.push_back(m.mk_eq(x, bv.mk_numeral(rational::one(), 1))); + fmls.push_back(bv.mk_ule(bound, result)); + } + return result; + } - void mk_bv(func_decl * f, unsigned sz, expr * const* args, expr_ref & result) { - - expr_ref zero(m), a(m), b(m); - expr_ref_vector es(m); - unsigned bw = get_num_bits(f); - zero = bv.mk_numeral(rational(0), bw); + // + // create a circuit of size sz*log(k) + // by forming a binary tree adding pairs of values that are assumed <= k, + // and in each step we check that the result is <= k by checking the overflow + // bit and that the non-overflow bits are <= k. + // The procedure for checking >= k is symmetric and checking for = k is + // achieved by checking <= k on intermediary addends and the resulting sum is = k. + // + template + expr_ref mk_le_ge(func_decl *f, unsigned sz, expr * const* args, rational const & k) { + if (k.is_zero()) { + if (is_le != l_false) { + return expr_ref(m.mk_not(mk_or(m, sz, args)), m); + } + else { + return expr_ref(m.mk_true(), m); + } + } + SASSERT(k.is_pos()); + expr_ref zero(m), bound(m); + expr_ref_vector es(m), fmls(m); + unsigned nb = k.get_num_bits(); + zero = bv.mk_numeral(rational(0), nb); + bound = bv.mk_numeral(k, nb); for (unsigned i = 0; i < sz; ++i) { - es.push_back(mk_ite(args[i], bv.mk_numeral(pb.get_coeff(f, i), bw), zero)); - } - switch (es.size()) { - case 0: a = zero; break; - case 1: a = es[0].get(); break; - default: - a = es[0].get(); - for (unsigned i = 1; i < es.size(); ++i) { - a = bv.mk_bv_add(a, es[i].get()); - } - break; - } - b = bv.mk_numeral(pb.get_k(f), bw); - - switch (f->get_decl_kind()) { - case OP_AT_MOST_K: - case OP_PB_LE: - result = bv.mk_ule(a, b); - break; - case OP_AT_LEAST_K: - case OP_PB_GE: - result = bv.mk_ule(b, a); - break; - case OP_PB_EQ: - result = m.mk_eq(a, b); - break; - default: - UNREACHABLE(); - } - TRACE("pb", tout << result << "\n";); - } - - bool mk_shannon(func_decl * f, unsigned sz, expr * const* args, expr_ref & result) { - decl_kind kind = f->get_decl_kind(); - if (kind != OP_PB_GE && kind != OP_AT_LEAST_K) { - return false; - } - unsigned max_clauses = sz*10; - vector argcs; - for (unsigned i = 0; i < sz; ++i) { - argcs.push_back(argc_t(args[i], pb.get_coeff(f, i))); - } - std::sort(argcs.begin(), argcs.end(), argc_gt()); - DEBUG_CODE( - for (unsigned i = 0; i + 1 < sz; ++i) { - SASSERT(argcs[i].m_coeff >= argcs[i+1].m_coeff); - }); - result = m.mk_app(f, sz, args); - TRACE("pb", tout << result << "\n";); - argc_cache cache; - expr_ref_vector trail(m); - vector todo_k; - unsigned_vector todo_i; - todo_k.push_back(pb.get_k(f)); - todo_i.push_back(0); - argc_entry entry1; - while (!todo_i.empty()) { - SASSERT(todo_i.size() == todo_k.size()); - if (cache.size() > max_clauses) { - return false; - } - unsigned i = todo_i.back(); - rational k = todo_k.back(); - argc_entry entry(i, k); - if (cache.contains(entry)) { - todo_i.pop_back(); - todo_k.pop_back(); - continue; - } - SASSERT(i < sz); - SASSERT(!k.is_neg()); - rational const& coeff = argcs[i].m_coeff; - expr* arg = argcs[i].m_arg; - if (i + 1 == sz) { - if (k.is_zero()) { - entry.m_value = m.mk_true(); - } - else if (coeff < k) { - entry.m_value = m.mk_false(); - } - else if (coeff.is_zero()) { - entry.m_value = m.mk_true(); + if (pb.get_coeff(f, i) > k) { + if (is_le != l_false) { + fmls.push_back(m.mk_not(args[i])); } else { - SASSERT(coeff >= k && k.is_pos()); - entry.m_value = arg; + fmls.push_back(args[i]); } - todo_i.pop_back(); - todo_k.pop_back(); - cache.insert(entry); - continue; - } - entry.m_index++; - expr* lo = 0, *hi = 0; - if (cache.find(entry, entry1)) { - lo = entry1.m_value; } else { - todo_i.push_back(i+1); - todo_k.push_back(k); - } - entry.m_k -= coeff; - if (!entry.m_k.is_pos()) { - hi = m.mk_true(); + es.push_back(mk_ite(args[i], bv.mk_numeral(pb.get_coeff(f, i), nb), zero)); } - else if (cache.find(entry, entry1)) { - hi = entry1.m_value; + } + while (es.size() > 1) { + for (unsigned i = 0; i + 1 < es.size(); i += 2) { + es[i/2] = mk_le_ge(fmls, es[i].get(), es[i+1].get(), bound); } - else { - todo_i.push_back(i+1); - todo_k.push_back(entry.m_k); + if ((es.size() % 2) == 1) { + es[es.size()/2] = es.back(); } - if (hi && lo) { - todo_i.pop_back(); - todo_k.pop_back(); - entry.m_index = i; - entry.m_k = k; - entry.m_value = mk_ite(arg, hi, lo); - trail.push_back(entry.m_value); - cache.insert(entry); - } - } - argc_entry entry(0, pb.get_k(f)); - VERIFY(cache.find(entry, entry)); - result = entry.m_value; - TRACE("pb", tout << result << "\n";); - return true; + es.shrink((1 + es.size())/2); + } + switch (is_le) { + case l_true: + return mk_and(fmls); + case l_false: + fmls.push_back(bv.mk_ule(bound, es.back())); + return mk_or(fmls); + case l_undef: + fmls.push_back(m.mk_eq(bound, es.back())); + return mk_and(fmls); + default: + UNREACHABLE(); + return expr_ref(m.mk_true(), m); + } + } + + expr_ref mk_bv(func_decl * f, unsigned sz, expr * const* args) { + decl_kind kind = f->get_decl_kind(); + rational k = pb.get_k(f); + SASSERT(!k.is_neg()); + switch (kind) { + case OP_PB_GE: + case OP_AT_LEAST_K: { + expr_ref_vector nargs(m); + nargs.append(sz, args); + dualize(f, nargs, k); + SASSERT(!k.is_neg()); + return mk_le_ge(f, sz, nargs.c_ptr(), k); + } + case OP_PB_LE: + case OP_AT_MOST_K: + return mk_le_ge(f, sz, args, k); + case OP_PB_EQ: + return mk_le_ge(f, sz, args, k); + default: + UNREACHABLE(); + return expr_ref(m.mk_true(), m); + } + } + + void dualize(func_decl* f, expr_ref_vector & args, rational & k) { + k.neg(); + for (unsigned i = 0; i < args.size(); ++i) { + k += pb.get_coeff(f, i); + args[i] = ::mk_not(m, args[i].get()); + } } expr* negate(expr* e) { @@ -346,8 +315,8 @@ struct pb2bv_rewriter::imp { else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) { result = m_sort.ge(true, pb.get_k(f).get_unsigned(), sz, args); } - else if (!mk_shannon(f, sz, args, result)) { - mk_bv(f, sz, args, result); + else { + result = mk_bv(f, sz, args); } }