diff --git a/src/sat/smt/pb_pb.h b/src/sat/smt/pb_pb.h index 30d1f0c1c..5386051af 100644 --- a/src/sat/smt/pb_pb.h +++ b/src/sat/smt/pb_pb.h @@ -34,6 +34,7 @@ namespace pb { literal lit() const { return m_lit; } wliteral operator[](unsigned i) const { return m_wlits[i]; } wliteral& operator[](unsigned i) { return m_wlits[i]; } + wliteral *data() { return m_wlits; } wliteral const* begin() const { return m_wlits; } wliteral const* end() const { return begin() + m_size; } diff --git a/src/sat/smt/pb_solver.cpp b/src/sat/smt/pb_solver.cpp index 7d72c15b2..6944dcfa0 100644 --- a/src/sat/smt/pb_solver.cpp +++ b/src/sat/smt/pb_solver.cpp @@ -397,19 +397,17 @@ namespace pb { return l_undef; } - void solver::recompile(pbc& p) { - // IF_VERBOSE(2, verbose_stream() << "re: " << p << "\n";); + std::pair solver::normalize(wliteral* begin, wliteral* end, unsigned k) { SASSERT(p.num_watch() == 0); - m_weights.resize(2*s().num_vars(), 0); - for (auto [w, lit] : p) { + m_weights.resize(2 * s().num_vars(), 0); + for (auto it = begin; it != end; ++it) { + auto [w, lit] = *it; m_weights[lit.index()] += w; - } - unsigned k = p.k(); - unsigned sz = p.size(); - bool all_units = true; - unsigned j = 0; - for (unsigned i = 0; i < sz && 0 < k; ++i) { - auto [w, l] = p[i]; + } + auto j = begin; + unsigned sz = 0; + for (auto it = begin; it != end; ++it) { + auto [w, l] = *it; unsigned w1 = m_weights[l.index()]; unsigned w2 = m_weights[(~l).index()]; if (w1 == 0 || w1 < w2) { @@ -424,23 +422,33 @@ namespace pb { k -= w2; w1 -= w2; m_weights[l.index()] = 0; - m_weights[(~l).index()] = 0; + m_weights[(~l).index()] = 0; if (w1 == 0) { continue; - } + } else { - p[j] = wliteral(w1, l); - all_units &= w1 == 1; + *j = wliteral(w1, l); ++j; + ++sz; } } } - sz = j; // clear weights - for (auto [w, lit] : p) { - m_weights[lit.index()] = 0; - m_weights[(~lit).index()] = 0; + while (begin != end) { + auto [w, l] = *begin; + m_weights[l.index()] = 0; + m_weights[(~l).index()] = 0; + ++begin; } + return {sz, k}; + } + + void solver::recompile(pbc& p) { + // IF_VERBOSE(2, verbose_stream() << "re: " << p << "\n";); + + auto [sz, k] = normalize(p.data(), p.data() + p.size(), p.k()); + p.set_size(sz); + auto all_units = all_of(p, [](wliteral const& wl) { return wl.first == 1; }); if (k == 0) { if (p.lit() != sat::null_literal) { @@ -463,8 +471,7 @@ namespace pb { remove_constraint(p, "recompiled to cardinality"); return; } - else { - p.set_size(sz); + else { p.update_max_sum(); if (p.max_sum() < k) { if (p.lit() == sat::null_literal) { @@ -1463,7 +1470,10 @@ namespace pb { for (auto const&[w, l] : wlits) { auto v = l.var(); if (is_visited(v)) { - throw default_exception("malformed constraint: variable appears more than once - is pre-processing disabled?"); + svector wlits2(wlits); + auto [sz, k2] = normalize(wlits2.data(), wlits2.data() + wlits2.size(), k); + wlits2.shrink(sz); + return add_pb_ge(lit, wlits2, k2, learned); } mark_visited(v); } diff --git a/src/sat/smt/pb_solver.h b/src/sat/smt/pb_solver.h index 5a09742da..67c55c9d5 100644 --- a/src/sat/smt/pb_solver.h +++ b/src/sat/smt/pb_solver.h @@ -235,6 +235,7 @@ namespace pb { void simplify(constraint& p); void simplify2(pbc& p); bool is_cardinality(pbc const& p); + std::pair normalize(wliteral *begin, wliteral *end, unsigned k); void flush_roots(pbc& p); void recompile(pbc& p); bool clausify(pbc& p);