diff --git a/src/ast/rewriter/pb2bv_rewriter.cpp b/src/ast/rewriter/pb2bv_rewriter.cpp index f546ef70b..d85bdec41 100644 --- a/src/ast/rewriter/pb2bv_rewriter.cpp +++ b/src/ast/rewriter/pb2bv_rewriter.cpp @@ -5,7 +5,7 @@ Module Name: pb2bv_rewriter.cpp -Abstract: +Abstralct: Conversion from pseudo-booleans to bit-vectors. @@ -25,6 +25,7 @@ Notes: #include"ast_util.h" #include"ast_pp.h" #include"lbool.h" +#include"uint_set.h" const unsigned g_primes[7] = { 2, 3, 5, 7, 11, 13, 17}; @@ -111,7 +112,13 @@ struct pb2bv_rewriter::imp { if (k.is_neg()) { return expr_ref((is_le == l_false)?m.mk_true():m.mk_false(), m); } - + + expr_ref result(m); + switch (is_le) { + case l_true: if (mk_le_tot(sz, args, k, result)) return result; else break; + case l_false: if (mk_ge_tot(sz, args, k, result)) return result; else break; + case l_undef: break; + } #if 0 expr_ref result(m); switch (is_le) { @@ -172,6 +179,141 @@ struct pb2bv_rewriter::imp { } } + /** + \brief Totalizer encoding. Based on a version by Miguel. + */ + + bool mk_le_tot(unsigned sz, expr * const * args, rational const& _k, expr_ref& result) { + SASSERT(sz == m_coeffs.size()); + if (!_k.is_unsigned() || sz == 0) return false; + unsigned k = _k.get_unsigned(); + expr_ref_vector args1(m); + rational bound; + flip(sz, args, args1, _k, bound); + if (bound.get_unsigned() < k) { + return mk_ge_tot(sz, args1.c_ptr(), bound, result); + } + if (k > 20) { + return false; + } + result = m.mk_not(bounded_addition(sz, args, k + 1)); + TRACE("pb", tout << result << "\n";); + return true; + } + + bool mk_ge_tot(unsigned sz, expr * const * args, rational const& _k, expr_ref& result) { + SASSERT(sz == m_coeffs.size()); + if (!_k.is_unsigned() || sz == 0) return false; + unsigned k = _k.get_unsigned(); + expr_ref_vector args1(m); + rational bound; + flip(sz, args, args1, _k, bound); + if (bound.get_unsigned() < k) { + return mk_le_tot(sz, args1.c_ptr(), bound, result); + } + if (k > 20) { + return false; + } + result = bounded_addition(sz, args, k); + TRACE("pb", tout << result << "\n";); + return true; + } + + void flip(unsigned sz, expr* const* args, expr_ref_vector& args1, rational const& k, rational& bound) { + bound = -k; + for (unsigned i = 0; i < sz; ++i) { + args1.push_back(mk_not(args[i])); + bound += m_coeffs[i]; + } + } + + expr_ref bounded_addition(unsigned sz, expr * const * args, unsigned k) { + SASSERT(sz > 0); + expr_ref result(m); + vector es; + vector coeffs; + for (unsigned i = 0; i < m_coeffs.size(); ++i) { + unsigned_vector v; + expr_ref_vector e(m); + unsigned c = m_coeffs[i].get_unsigned(); + v.push_back(c >= k ? k : c); + e.push_back(args[i]); + es.push_back(e); + coeffs.push_back(v); + } + while (es.size() > 1) { + for (unsigned i = 0; i + 1 < es.size(); i += 2) { + expr_ref_vector o(m); + unsigned_vector oc; + tot_adder(es[i], coeffs[i], es[i + 1], coeffs[i + 1], k, o, oc); + es[i / 2].set(o); + coeffs[i / 2] = oc; + } + if ((es.size() % 2) == 1) { + es[es.size() / 2].set(es.back()); + coeffs[es.size() / 2] = coeffs.back(); + } + es.shrink((1 + es.size())/2); + coeffs.shrink((1 + coeffs.size())/2); + } + SASSERT(coeffs.size() == 1); + SASSERT(coeffs[0].back() <= k); + if (coeffs[0].back() == k) { + result = es[0].back(); + } + else { + result = m.mk_false(); + } + return result; + } + + void tot_adder(expr_ref_vector const& l, unsigned_vector const& lc, + expr_ref_vector const& r, unsigned_vector const& rc, + unsigned k, + expr_ref_vector& o, unsigned_vector & oc) { + SASSERT(l.size() == lc.size()); + SASSERT(r.size() == rc.size()); + uint_set sums; + vector trail; + u_map sum2def; + for (unsigned i = 0; i <= l.size(); ++i) { + for (unsigned j = (i == 0) ? 1 : 0; j <= r.size(); ++j) { + unsigned sum = std::min(k, ((i == 0) ? 0 : lc[i - 1]) + ((j == 0) ? 0 : rc[j - 1])); + sums.insert(sum); + } + } + uint_set::iterator it = sums.begin(), end = sums.end(); + for (; it != end; ++it) { + oc.push_back(*it); + } + std::sort(oc.begin(), oc.end()); + DEBUG_CODE( + for (unsigned i = 0; i + 1 < oc.size(); ++i) { + SASSERT(oc[i] < oc[i+1]); + }); + for (unsigned i = 0; i < oc.size(); ++i) { + sum2def.insert(oc[i], i); + trail.push_back(expr_ref_vector(m)); + } + for (unsigned i = 0; i <= l.size(); ++i) { + for (unsigned j = (i == 0) ? 1 : 0; j <= r.size(); ++j) { + if (i != 0 && j != 0 && (lc[i - 1] >= k || rc[j - 1] >= k)) continue; + unsigned sum = std::min(k, ((i == 0) ? 0 : lc[i - 1]) + ((j == 0) ? 0 : rc[j - 1])); + expr_ref_vector ands(m); + if (i != 0) { + ands.push_back(l[i - 1]); + } + if (j != 0) { + ands.push_back(r[j - 1]); + } + trail[sum2def.find(sum)].push_back(::mk_and(ands)); + } + } + for (unsigned i = 0; i < oc.size(); ++i) { + o.push_back(::mk_or(trail[sum2def.find(oc[i])])); + } + } + /** \brief MiniSat+ based encoding of PB constraints. The procedure is described in "Translating Pseudo-Boolean Constraints into SAT " diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index 3344e82e1..c4f6a4bba 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -282,8 +282,10 @@ namespace sat { inc_bstamp(); set_bstamp(l); literal_vector const& conseq = m_binary[l.index()]; - for (unsigned i = 0; i < conseq.size(); ++i) { - set_bstamp(conseq[i]); + literal_vector::const_iterator it = conseq.begin(); + literal_vector::const_iterator end = conseq.end(); + for (; it != end; ++it) { + set_bstamp(*it); } } bool is_stamped(literal l) const { return m_bstamp[l.index()] == m_bstamp_id; } @@ -557,6 +559,7 @@ namespace sat { literal_vector::iterator it = m_binary[l.index()].begin(), end = m_binary[l.index()].end(); for (; it != end; ++it) { if (is_undef(*it)) sum += h[it->index()]; + // if (m_freevars.contains(it->var())) sum += h[it->index()]; } watch_list& wlist = m_watches[l.index()]; watch_list::iterator wit = wlist.begin(), wend = wlist.end(); @@ -568,9 +571,8 @@ namespace sat { case watched::TERNARY: { literal l1 = wit->get_literal1(); literal l2 = wit->get_literal2(); - if (is_undef(l1) && is_undef(l2)) { - tsum += h[l1.index()] * h[l2.index()]; - } + // if (is_undef(l1) && is_undef(l2)) + tsum += h[l1.index()] * h[l2.index()]; break; } case watched::CLAUSE: { @@ -1155,14 +1157,19 @@ namespace sat { // convert windfalls to binary clauses. if (!unsat) { literal nlit = ~lit; + for (unsigned i = 0; i < m_wstack.size(); ++i) { - ++m_stats.m_windfall_binaries; literal l2 = m_wstack[i]; //update_prefix(~lit); //update_prefix(m_wstack[i]); TRACE("sat", tout << "windfall: " << nlit << " " << l2 << "\n";); + // if we use try_add_binary, then this may produce new assignments + // these assignments get put on m_trail, and they are cleared by + // reset_wnb. We would need to distinguish the trail that comes + // from lookahead levels and the main search level for this to work. add_binary(nlit, l2); } + m_stats.m_windfall_binaries += m_wstack.size(); } m_wstack.reset(); } @@ -1180,6 +1187,15 @@ namespace sat { return r; } + // + // The current version is modeled after CDCL SAT solving data-structures. + // It borrows from the watch list data-structure. The cost tradeoffs are somewhat + // biased towards CDCL search overheads. + // If we walk over the positive occurrences of l, then those clauses can be retired so + // that they don't interfere with calculation of H. Instead of removing clauses from the watch + // list one can swap them to the "back" and adjust a size indicator of the watch list + // Only the size indicator needs to be updated on backtracking. + // void propagate_clauses(literal l) { SASSERT(is_true(l)); if (inconsistent()) return; @@ -1237,17 +1253,17 @@ namespace sat { break; } case watched::CLAUSE: { - clause_offset cls_off = it->get_clause_offset(); - clause & c = *(s.m_cls_allocator.get_clause(cls_off)); if (is_true(it->get_blocked_literal())) { *it2 = *it; ++it2; break; } - + clause_offset cls_off = it->get_clause_offset(); + clause & c = *(s.m_cls_allocator.get_clause(cls_off)); if (c[0] == ~l) std::swap(c[0], c[1]); if (is_true(c[0])) { + it->set_blocked_literal(c[0]); *it2 = *it; it2++; break; @@ -1337,13 +1353,21 @@ namespace sat { } void propagate() { - for (; m_qhead < m_trail.size(); ++m_qhead) { - if (inconsistent()) break; - literal l = m_trail[m_qhead]; - TRACE("sat", tout << "propagate " << l << " @ " << m_level << "\n";); - propagate_binary(l); - propagate_clauses(l); + while (!inconsistent() && m_qhead < m_trail.size()) { + unsigned i = m_qhead; + unsigned sz = m_trail.size(); + for (; i < sz && !inconsistent(); ++i) { + literal l = m_trail[i]; + TRACE("sat", tout << "propagate " << l << " @ " << m_level << "\n";); + propagate_binary(l); + } + i = m_qhead; + for (; i < sz && !inconsistent(); ++i) { + propagate_clauses(m_trail[i]); + } + m_qhead = sz; } + TRACE("sat_verbose", display(tout << scope_lvl() << " " << (inconsistent()?"unsat":"sat") << "\n");); }