3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-05-05 09:55:15 +00:00

Refactor pb_solver to use structured bindings for wliteral patterns (#8391)

* Initial plan

* Refactor pb_solver.cpp to use C++17 structured bindings for wliteral patterns

Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com>

* Fix active2wlits to avoid unnecessary wliteral reconstruction

Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com>
This commit is contained in:
Copilot 2026-01-27 13:58:14 -08:00 committed by GitHub
parent 8cb403384e
commit ef5ee85bfd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -263,8 +263,9 @@ namespace pb {
void solver::add_index(pbc& p, unsigned index, literal lit) { void solver::add_index(pbc& p, unsigned index, literal lit) {
if (value(lit) == l_undef) { if (value(lit) == l_undef) {
m_pb_undef.push_back(index); m_pb_undef.push_back(index);
if (p[index].first > m_a_max) { auto [w, l] = p[index];
m_a_max = p[index].first; if (w > m_a_max) {
m_a_max = w;
} }
} }
} }
@ -295,7 +296,7 @@ namespace pb {
m_a_max = 0; m_a_max = 0;
m_pb_undef.reset(); m_pb_undef.reset();
for (; index < num_watch; ++index) { for (; index < num_watch; ++index) {
literal lit = p[index].second; auto [w, lit] = p[index];
if (lit == alit) { if (lit == alit) {
break; break;
} }
@ -316,21 +317,22 @@ namespace pb {
SASSERT(index < num_watch); SASSERT(index < num_watch);
unsigned index1 = index + 1; unsigned index1 = index + 1;
for (; m_a_max == 0 && index1 < num_watch; ++index1) { for (; m_a_max == 0 && index1 < num_watch; ++index1) {
add_index(p, index1, p[index1].second); auto [w, lit] = p[index1];
add_index(p, index1, lit);
} }
unsigned val = p[index].first; auto [val, alit_lit] = p[index];
SASSERT(value(p[index].second) == l_false); SASSERT(value(alit_lit) == l_false);
SASSERT(val <= slack); SASSERT(val <= slack);
slack -= val; slack -= val;
// find literals to swap with: // find literals to swap with:
for (unsigned j = num_watch; j < sz && slack < bound + m_a_max; ++j) { for (unsigned j = num_watch; j < sz && slack < bound + m_a_max; ++j) {
literal lit = p[j].second; auto [w, lit] = p[j];
if (value(lit) != l_false) { if (value(lit) != l_false) {
slack += p[j].first; slack += w;
SASSERT(!p.is_watched(*this, p[j].second)); SASSERT(!p.is_watched(*this, lit));
p.watch_literal(*this, p[j].second); p.watch_literal(*this, lit);
p.swap(num_watch, j); p.swap(num_watch, j);
add_index(p, num_watch, lit); add_index(p, num_watch, lit);
++num_watch; ++num_watch;
@ -377,11 +379,10 @@ namespace pb {
if (index1 == num_watch) { if (index1 == num_watch) {
index1 = index; index1 = index;
} }
wliteral wl = p[index1]; auto [w, lit] = p[index1];
literal lit = wl.second;
SASSERT(value(lit) == l_undef); SASSERT(value(lit) == l_undef);
if (slack < bound + wl.first) { if (slack < bound + w) {
BADLOG(verbose_stream() << "Assign " << lit << " " << wl.first << "\n"); BADLOG(verbose_stream() << "Assign " << lit << " " << w << "\n");
assign(p, lit); assign(p, lit);
} }
} }
@ -400,15 +401,15 @@ namespace pb {
// IF_VERBOSE(2, verbose_stream() << "re: " << p << "\n";); // IF_VERBOSE(2, verbose_stream() << "re: " << p << "\n";);
SASSERT(p.num_watch() == 0); SASSERT(p.num_watch() == 0);
m_weights.resize(2*s().num_vars(), 0); m_weights.resize(2*s().num_vars(), 0);
for (wliteral wl : p) { for (auto [w, lit] : p) {
m_weights[wl.second.index()] += wl.first; m_weights[lit.index()] += w;
} }
unsigned k = p.k(); unsigned k = p.k();
unsigned sz = p.size(); unsigned sz = p.size();
bool all_units = true; bool all_units = true;
unsigned j = 0; unsigned j = 0;
for (unsigned i = 0; i < sz && 0 < k; ++i) { for (unsigned i = 0; i < sz && 0 < k; ++i) {
literal l = p[i].second; auto [w, l] = p[i];
unsigned w1 = m_weights[l.index()]; unsigned w1 = m_weights[l.index()];
unsigned w2 = m_weights[(~l).index()]; unsigned w2 = m_weights[(~l).index()];
if (w1 == 0 || w1 < w2) { if (w1 == 0 || w1 < w2) {
@ -436,9 +437,9 @@ namespace pb {
} }
sz = j; sz = j;
// clear weights // clear weights
for (wliteral wl : p) { for (auto [w, lit] : p) {
m_weights[wl.second.index()] = 0; m_weights[lit.index()] = 0;
m_weights[(~wl.second).index()] = 0; m_weights[(~lit).index()] = 0;
} }
if (k == 0) { if (k == 0) {
@ -606,7 +607,7 @@ namespace pb {
IF_VERBOSE(0, IF_VERBOSE(0,
active2pb(m_A); active2pb(m_A);
uint64_t c = 0; uint64_t c = 0;
for (wliteral l : m_A.m_wlits) c += l.first; for (auto [w, l] : m_A.m_wlits) c += w;
verbose_stream() << "sum of coefficients: " << c << "\n"; verbose_stream() << "sum of coefficients: " << c << "\n";
display(verbose_stream(), m_A, true); display(verbose_stream(), m_A, true);
verbose_stream() << "conflicting literal: " << s().m_not_l << "\n";); verbose_stream() << "conflicting literal: " << s().m_not_l << "\n";);
@ -882,10 +883,10 @@ namespace pb {
unsigned c = get_abs_coeff(w); unsigned c = get_abs_coeff(w);
if (c == 1 || c == 0) return; if (c == 1 || c == 0) return;
for (bool_var v : m_active_vars) { for (bool_var v : m_active_vars) {
wliteral wl = get_wliteral(v); auto [coeff, l] = get_wliteral(v);
unsigned q = wl.first % c; unsigned q = coeff % c;
if (q != 0 && !is_false(wl.second)) { if (q != 0 && !is_false(l)) {
m_coeffs[v] = wl.first - q; m_coeffs[v] = coeff - q;
m_bound -= q; m_bound -= q;
SASSERT(m_bound > 0); SASSERT(m_bound > 0);
} }
@ -951,8 +952,7 @@ namespace pb {
* below the current processing level. * below the current processing level.
*/ */
void solver::mark_variables(ineq const& ineq) { void solver::mark_variables(ineq const& ineq) {
for (wliteral wl : ineq.m_wlits) { for (auto [w, l] : ineq.m_wlits) {
literal l = wl.second;
if (!is_false(l)) continue; if (!is_false(l)) continue;
bool_var v = l.var(); bool_var v = l.var();
unsigned level = lvl(v); unsigned level = lvl(v);
@ -1061,8 +1061,8 @@ namespace pb {
if (consequent == sat::null_literal) { if (consequent == sat::null_literal) {
SASSERT(validate_ineq(m_A)); SASSERT(validate_ineq(m_A));
m_bound = static_cast<unsigned>(m_A.m_k); m_bound = static_cast<unsigned>(m_A.m_k);
for (wliteral wl : m_A.m_wlits) { for (auto [w, lit] : m_A.m_wlits) {
process_antecedent(wl.second, wl.first); process_antecedent(lit, w);
} }
} }
else { else {
@ -1432,8 +1432,8 @@ namespace pb {
constraint* solver::add_pb_ge(literal lit, svector<wliteral> const& wlits, unsigned k, bool learned) { constraint* solver::add_pb_ge(literal lit, svector<wliteral> const& wlits, unsigned k, bool learned) {
bool units = true; bool units = true;
for (wliteral wl : wlits) for (auto [w, l] : wlits)
units &= wl.first == 1; units &= w == 1;
if (k == 0) { if (k == 0) {
if (lit != sat::null_literal) if (lit != sat::null_literal)
@ -1451,12 +1451,12 @@ namespace pb {
return nullptr; return nullptr;
} }
if (!learned) { if (!learned) {
for (wliteral wl : wlits) for (auto [w, l] : wlits)
s().set_external(wl.second.var()); s().set_external(l.var());
} }
if (units || k == 1) { if (units || k == 1) {
literal_vector lits; literal_vector lits;
for (wliteral wl : wlits) lits.push_back(wl.second); for (auto [w, l] : wlits) lits.push_back(l);
return add_at_least(lit, lits, k, learned); return add_at_least(lit, lits, k, learned);
} }
void * mem = m_allocator.allocate(pbc::get_obj_size(wlits.size())); void * mem = m_allocator.allocate(pbc::get_obj_size(wlits.size()));
@ -1568,16 +1568,14 @@ namespace pb {
// The literal comes from a conflict. // The literal comes from a conflict.
// it is forced true, but assigned to false. // it is forced true, but assigned to false.
unsigned slack = 0; unsigned slack = 0;
for (wliteral wl : p) { for (auto [w, lit] : p) {
if (value(wl.second) != l_false) { if (value(lit) != l_false) {
slack += wl.first; slack += w;
} }
} }
SASSERT(slack < k); SASSERT(slack < k);
for (wliteral wl : p) { for (auto [w, lit] : p) {
literal lit = wl.second;
if (lit != l && value(lit) == l_false) { if (lit != l && value(lit) == l_false) {
unsigned w = wl.first;
if (slack + w < k) { if (slack + w < k) {
slack += w; slack += w;
} }
@ -1592,8 +1590,9 @@ namespace pb {
SASSERT(value(l) == l_true); SASSERT(value(l) == l_true);
unsigned coeff = 0, j = 0; unsigned coeff = 0, j = 0;
for (; j < p.size(); ++j) { for (; j < p.size(); ++j) {
if (p[j].second == l) { auto [w, lit] = p[j];
coeff = p[j].first; if (lit == l) {
coeff = w;
break; break;
} }
} }
@ -1613,8 +1612,7 @@ namespace pb {
// we need antecedents to be deeper than alit. // we need antecedents to be deeper than alit.
for (; j < p.size(); ++j) { for (; j < p.size(); ++j) {
literal lit = p[j].second; auto [w, lit] = p[j];
unsigned w = p[j].first;
if (l_false != value(lit)) { if (l_false != value(lit)) {
// skip // skip
} }
@ -1805,7 +1803,7 @@ namespace pb {
if (p.lit() == sat::null_literal || value(p.lit()) != l_true) if (p.lit() == sat::null_literal || value(p.lit()) != l_true)
return true; return true;
for (unsigned i = 0; i < p.size(); ++i) { for (unsigned i = 0; i < p.size(); ++i) {
literal l = p[i].second; auto [w, l] = p[i];
if (l != alit && lvl(l) != 0 && p.is_watched(*this, l) != (i < p.num_watch())) { if (l != alit && lvl(l) != 0 && p.is_watched(*this, l) != (i < p.num_watch())) {
IF_VERBOSE(0, display(verbose_stream(), p, true); IF_VERBOSE(0, display(verbose_stream(), p, true);
verbose_stream() << "literal " << l << " at position " << i << " " << p.is_watched(*this, l) << "\n";); verbose_stream() << "literal " << l << " at position " << i << " " << p.is_watched(*this, l) << "\n";);
@ -1814,8 +1812,10 @@ namespace pb {
} }
} }
unsigned slack = 0; unsigned slack = 0;
for (unsigned i = 0; i < p.num_watch(); ++i) for (unsigned i = 0; i < p.num_watch(); ++i) {
slack += p[i].first; auto [w, l] = p[i];
slack += w;
}
if (slack != p.slack()) { if (slack != p.slack()) {
IF_VERBOSE(0, display(verbose_stream(), p, true);); IF_VERBOSE(0, display(verbose_stream(), p, true););
UNREACHABLE(); UNREACHABLE();
@ -1847,8 +1847,8 @@ namespace pb {
} }
break; break;
case pb::tag_t::pb_t: case pb::tag_t::pb_t:
for (wliteral l : c.to_pb()) { for (auto [w, l] : c.to_pb()) {
if (s().m_phase[l.second.var()] == !l.second.sign()) ++r; if (s().m_phase[l.var()] == !l.sign()) ++r;
} }
break; break;
default: default:
@ -2391,8 +2391,8 @@ namespace pb {
bool ok = !p.learned(); bool ok = !p.learned();
bool is_def = p.lit() != sat::null_literal; bool is_def = p.lit() != sat::null_literal;
for (wliteral wl : p) { for (auto [w, lit] : p) {
ok &= !s().was_eliminated(wl.second); ok &= !s().was_eliminated(lit);
} }
ok &= !is_def || !s().was_eliminated(p.lit()); ok &= !is_def || !s().was_eliminated(p.lit());
if (!ok) { if (!ok) {
@ -2415,12 +2415,12 @@ namespace pb {
bool solver::is_cardinality(pbc const& p, literal_vector& lits) { bool solver::is_cardinality(pbc const& p, literal_vector& lits) {
lits.reset(); lits.reset();
p.size(); p.size();
for (wliteral wl : p) { for (auto [w, lit] : p) {
if (lits.size() > 2*p.size() + wl.first) { if (lits.size() > 2*p.size() + w) {
return false; return false;
} }
for (unsigned i = 0; i < wl.first; ++i) { for (unsigned i = 0; i < w; ++i) {
lits.push_back(wl.second); lits.push_back(lit);
} }
} }
return true; return true;
@ -3012,17 +3012,18 @@ namespace pb {
return; return;
} }
init_visited(); init_visited();
for (wliteral l : p1) { for (auto [w, l] : p1) {
SASSERT(m_weights.size() <= l.second.index() || m_weights[l.second.index()] == 0); SASSERT(m_weights.size() <= l.index() || m_weights[l.index()] == 0);
m_weights.setx(l.second.index(), l.first, 0); m_weights.setx(l.index(), w, 0);
mark_visited(l.second); mark_visited(l);
} }
for (unsigned i = 0; i < std::min(10u, p1.num_watch()); ++i) { for (unsigned i = 0; i < std::min(10u, p1.num_watch()); ++i) {
unsigned j = s().m_rand() % p1.num_watch(); unsigned j = s().m_rand() % p1.num_watch();
subsumes(p1, p1[j].second); auto [w, lit] = p1[j];
subsumes(p1, lit);
} }
for (wliteral l : p1) { for (auto [w, l] : p1) {
m_weights[l.second.index()] = 0; m_weights[l.index()] = 0;
} }
} }
@ -3228,10 +3229,9 @@ namespace pb {
// sum a < k // sum a < k
// val(r) = false // val(r) = false
// hence alit has to be true. // hence alit has to be true.
for (wliteral wl : p) { for (auto [w, lit] : p) {
literal lit = wl.second;
if (lit != alit && !r.contains(~lit)) { if (lit != alit && !r.contains(~lit)) {
sum += wl.first; sum += w;
} }
} }
if (sum >= p.k()) { if (sum >= p.k()) {
@ -3240,14 +3240,14 @@ namespace pb {
display(verbose_stream(), p, true); display(verbose_stream(), p, true);
verbose_stream() << "id: " << p.id() << "\n"; verbose_stream() << "id: " << p.id() << "\n";
sum = 0; sum = 0;
for (wliteral wl : p) sum += wl.first; for (auto [w, lit] : p) sum += w;
verbose_stream() << "overall sum " << sum << "\n"; verbose_stream() << "overall sum " << sum << "\n";
verbose_stream() << "asserting literal: " << alit << "\n"; verbose_stream() << "asserting literal: " << alit << "\n";
verbose_stream() << "reason: " << r << "\n";); verbose_stream() << "reason: " << r << "\n";);
return false; return false;
} }
for (wliteral wl : p) { for (auto [w, lit] : p) {
if (alit == wl.second) { if (alit == lit) {
return true; return true;
} }
} }
@ -3262,10 +3262,10 @@ namespace pb {
reset_active_var_set(); reset_active_var_set();
for (bool_var v : m_active_vars) { for (bool_var v : m_active_vars) {
if (!test_and_set_active(v)) continue; if (!test_and_set_active(v)) continue;
wliteral wl = get_wliteral(v); auto [w, l] = get_wliteral(v);
if (wl.first == 0) continue; if (w == 0) continue;
if (!is_false(wl.second)) { if (!is_false(l)) {
val += wl.first; val += w;
} }
} }
CTRACE(pb, val >= 0, active2pb(m_A); display(tout, m_A, true);); CTRACE(pb, val >= 0, active2pb(m_A); display(tout, m_A, true););
@ -3277,9 +3277,9 @@ namespace pb {
*/ */
bool solver::validate_ineq(ineq const& ineq) const { bool solver::validate_ineq(ineq const& ineq) const {
int64_t k = -static_cast<int64_t>(ineq.m_k); int64_t k = -static_cast<int64_t>(ineq.m_k);
for (wliteral wl : ineq.m_wlits) { for (auto [w, lit] : ineq.m_wlits) {
if (!is_false(wl.second)) if (!is_false(lit))
k += wl.first; k += w;
} }
CTRACE(pb, k > 0, display(tout, ineq, true);); CTRACE(pb, k > 0, display(tout, ineq, true););
return k <= 0; return k <= 0;
@ -3315,9 +3315,10 @@ namespace pb {
for (bool_var v : m_active_vars) { for (bool_var v : m_active_vars) {
if (!test_and_set_active(v)) continue; if (!test_and_set_active(v)) continue;
wliteral wl = get_wliteral(v); wliteral wl = get_wliteral(v);
if (wl.first == 0) continue; auto [w, l] = wl;
if (w == 0) continue;
wlits.push_back(wl); wlits.push_back(wl);
sum += wl.first; sum += w;
} }
m_overflow |= sum >= UINT_MAX/2; m_overflow |= sum >= UINT_MAX/2;
} }
@ -3384,27 +3385,27 @@ namespace pb {
std::sort(m_wlits.begin(), m_wlits.end(), compare_wlit()); std::sort(m_wlits.begin(), m_wlits.end(), compare_wlit());
unsigned k = 0; unsigned k = 0;
uint64_t sum = 0, sum0 = 0; uint64_t sum = 0, sum0 = 0;
for (wliteral wl : m_wlits) { for (auto [w, lit] : m_wlits) {
if (sum >= m_bound) break; if (sum >= m_bound) break;
sum0 = sum; sum0 = sum;
sum += wl.first; sum += w;
++k; ++k;
} }
if (k == 1) { if (k == 1) {
return nullptr; return nullptr;
} }
while (!m_wlits.empty()) { while (!m_wlits.empty()) {
wliteral wl = m_wlits.back(); auto [w, lit] = m_wlits.back();
if (wl.first + sum0 >= m_bound) break; if (w + sum0 >= m_bound) break;
m_wlits.pop_back(); m_wlits.pop_back();
sum0 += wl.first; sum0 += w;
} }
unsigned slack = 0; unsigned slack = 0;
unsigned max_level = 0; unsigned max_level = 0;
for (wliteral wl : m_wlits) { for (auto [w, lit] : m_wlits) {
if (value(wl.second) != l_false) ++slack; if (value(lit) != l_false) ++slack;
unsigned level = lvl(wl.second); unsigned level = lvl(lit);
if (level > max_level) { if (level > max_level) {
max_level = level; max_level = level;
} }
@ -3419,8 +3420,8 @@ namespace pb {
// produce asserting cardinality constraint // produce asserting cardinality constraint
literal_vector lits; literal_vector lits;
for (wliteral wl : m_wlits) { for (auto [w, lit] : m_wlits) {
lits.push_back(wl.second); lits.push_back(lit);
} }
constraint* c = add_at_least(sat::null_literal, lits, k, true); constraint* c = add_at_least(sat::null_literal, lits, k, true);
@ -3428,8 +3429,8 @@ namespace pb {
if (c) { if (c) {
lits.reset(); lits.reset();
for (wliteral wl : m_wlits) { for (auto [w, lit] : m_wlits) {
if (value(wl.second) == l_false) lits.push_back(wl.second); if (value(lit) == l_false) lits.push_back(lit);
} }
unsigned glue = s().num_diff_levels(lits.size(), lits.data()); unsigned glue = s().num_diff_levels(lits.size(), lits.data());
c->set_glue(glue); c->set_glue(glue);
@ -3449,7 +3450,7 @@ namespace pb {
case pb::tag_t::pb_t: { case pb::tag_t::pb_t: {
pbc& p = cnstr.to_pb(); pbc& p = cnstr.to_pb();
ineq.reset(static_cast<uint64_t>(offset) * p.k()); ineq.reset(static_cast<uint64_t>(offset) * p.k());
for (wliteral wl : p) ineq.push(wl.second, offset * wl.first); for (auto [w, lit] : p) ineq.push(lit, offset * w);
if (p.lit() != sat::null_literal) ineq.push(~p.lit(), offset * p.k()); if (p.lit() != sat::null_literal) ineq.push(~p.lit(), offset * p.k());
break; break;
} }
@ -3734,13 +3735,13 @@ namespace pb {
lits.reset(); lits.reset();
coeffs.reset(); coeffs.reset();
unsigned sum = 0; unsigned sum = 0;
for (wliteral wl : p) sum += wl.first; for (auto [w, lit] : p) sum += w;
if (p.lit() == sat::null_literal) { if (p.lit() == sat::null_literal) {
// w1 + .. + w_n >= k // w1 + .. + w_n >= k
// <=> // <=>
// ~wl + ... + ~w_n <= sum_of_weights - k // ~wl + ... + ~w_n <= sum_of_weights - k
for (wliteral wl : p) lits.push_back(~(wl.second)), coeffs.push_back(wl.first); for (auto [w, lit] : p) lits.push_back(~lit), coeffs.push_back(w);
add_pb(lits.size(), lits.data(), coeffs.data(), sum - p.k()); add_pb(lits.size(), lits.data(), coeffs.data(), sum - p.k());
} }
else { else {
@ -3752,13 +3753,13 @@ namespace pb {
// (sum - k + 1)*~lit + w1 + .. + w_n <= sum // (sum - k + 1)*~lit + w1 + .. + w_n <= sum
// k*lit + ~wl + ... + ~w_n <= sum // k*lit + ~wl + ... + ~w_n <= sum
lits.push_back(p.lit()), coeffs.push_back(p.k()); lits.push_back(p.lit()), coeffs.push_back(p.k());
for (wliteral wl : p) lits.push_back(~(wl.second)), coeffs.push_back(wl.first); for (auto [w, lit] : p) lits.push_back(~lit), coeffs.push_back(w);
add_pb(lits.size(), lits.data(), coeffs.data(), sum); add_pb(lits.size(), lits.data(), coeffs.data(), sum);
lits.reset(); lits.reset();
coeffs.reset(); coeffs.reset();
lits.push_back(~p.lit()), coeffs.push_back(sum + 1 - p.k()); lits.push_back(~p.lit()), coeffs.push_back(sum + 1 - p.k());
for (wliteral wl : p) lits.push_back(wl.second), coeffs.push_back(wl.first); for (auto [w, lit] : p) lits.push_back(lit), coeffs.push_back(w);
add_pb(lits.size(), lits.data(), coeffs.data(), sum); add_pb(lits.size(), lits.data(), coeffs.data(), sum);
} }
break; break;