diff --git a/src/sat/ba_solver.cpp b/src/sat/ba_solver.cpp index eab579ff1..7f2ae5e89 100644 --- a/src/sat/ba_solver.cpp +++ b/src/sat/ba_solver.cpp @@ -873,38 +873,40 @@ namespace sat { m_active_vars.shrink(j); } - void ba_solver::inc_coeff(literal l, int64 offset) { + void ba_solver::inc_coeff(literal l, unsigned offset) { SASSERT(offset > 0); bool_var v = l.var(); SASSERT(v != null_bool_var); - if (static_cast(m_coeffs.size()) <= v) { - m_coeffs.resize(v + 1, 0); - } + m_coeffs.reserve(v + 1, 0); + int64 coeff0 = m_coeffs[v]; if (coeff0 == 0) { m_active_vars.push_back(v); } - int64 inc = l.sign() ? -offset : offset; + int64 loffset = static_cast(offset); + int64 inc = l.sign() ? -loffset : loffset; int64 coeff1 = inc + coeff0; m_coeffs[v] = coeff1; if (coeff1 > INT_MAX || coeff1 < INT_MIN) { - std::cout << "overflow update coefficient " << coeff1 << "\n"; + std::cout << "overflow update coefficient " << coeff1 << " offset: " << offset << " coeff0: " << coeff0 << "\n"; + UNREACHABLE(); m_overflow = true; return; } if (coeff0 > 0 && inc < 0) { - m_bound -= coeff0 - std::max(0LL, coeff1); + inc_bound(std::max(0LL, coeff1) - coeff0); } else if (coeff0 < 0 && inc > 0) { - m_bound -= std::min(0LL, coeff1) - coeff0; + inc_bound(coeff0 - std::min(0LL, coeff1)); } + // reduce coefficient to be no larger than bound. - if (coeff1 > m_bound) { + if (coeff1 > static_cast(m_bound)) { m_coeffs[v] = m_bound; } - else if (coeff1 < 0 && -coeff1 > m_bound) { + else if (coeff1 < 0 && -coeff1 > static_cast(m_bound)) { m_coeffs[v] = m_bound; } } @@ -913,8 +915,13 @@ namespace sat { return m_coeffs.get(v, 0); } - int64 ba_solver::get_abs_coeff(bool_var v) const { - return abs(get_coeff(v)); + unsigned ba_solver::get_abs_coeff(bool_var v) const { + int64 c = get_coeff(v); + if (c < INT_MIN+1 || c > UINT_MAX) { + m_overflow = true; + return UINT_MAX; + } + return static_cast(abs(c)); } int ba_solver::get_int_coeff(bool_var v) const { @@ -927,13 +934,22 @@ namespace sat { return static_cast(c); } - unsigned ba_solver::get_bound() const { - if (m_bound < 0 || m_bound > UINT_MAX) { - std::cout << "overflow bound " << m_bound << "\n"; + void ba_solver::inc_bound(int64 i) { + if (i < INT_MIN || i > INT_MAX) { m_overflow = true; - return 1; + return; + } + int64 new_bound = m_bound; + new_bound += i; + if (new_bound < 0) { + m_bound = 0; + } + else if (new_bound > UINT_MAX) { + m_overflow = true; + } + else { + m_bound = static_cast(new_bound); } - return static_cast(m_bound); } void ba_solver::reset_coeffs() { @@ -965,14 +981,14 @@ namespace sat { } literal_vector const& lits = s().m_trail; unsigned idx = lits.size() - 1; - int64 offset = 1; + unsigned offset = 1; DEBUG_CODE(active2pb(m_A);); unsigned init_marks = m_num_marks; do { - if (m_overflow || offset > (1 << 12)) { + if (m_overflow || offset > (1 << 12) || m_bound == 0) { IF_VERBOSE(20, verbose_stream() << "offset: " << offset << "\n"; active2pb(m_A); display(verbose_stream(), m_A); @@ -987,9 +1003,8 @@ namespace sat { TRACE("sat_verbose", display(tout, m_A);); TRACE("sat", tout << "process consequent: " << consequent << ":\n"; s().display_justification(tout, js) << "\n";); SASSERT(offset > 0); - SASSERT(m_bound >= 0); - DEBUG_CODE(justification2pb(js, consequent, offset, m_B);); + // DEBUG_CODE(justification2pb(js, consequent, offset, m_B);); if (_debug_conflict) { std::cout << consequent << "\n"; @@ -1000,23 +1015,23 @@ namespace sat { switch(js.get_kind()) { case justification::NONE: SASSERT (consequent != null_literal); - m_bound += offset; + inc_bound(offset); break; case justification::BINARY: - m_bound += offset; + inc_bound(offset); SASSERT (consequent != null_literal); inc_coeff(consequent, offset); process_antecedent(js.get_literal(), offset); break; case justification::TERNARY: - m_bound += offset; + inc_bound(offset); SASSERT (consequent != null_literal); inc_coeff(consequent, offset); process_antecedent(js.get_literal1(), offset); process_antecedent(js.get_literal2(), offset); break; case justification::CLAUSE: { - m_bound += offset; + inc_bound(offset); clause & c = *(s().m_cls_allocator.get_clause(js.get_clause_offset())); unsigned i = 0; if (consequent != null_literal) { @@ -1041,14 +1056,14 @@ namespace sat { switch (cnstr.tag()) { case card_t: { card& c = cnstr.to_card(); - m_bound += offset * c.k(); + inc_bound(static_cast(offset) * c.k()); process_card(c, offset); break; } case pb_t: { pb& p = cnstr.to_pb(); m_lemma.reset(); - m_bound += offset; + inc_bound(offset); inc_coeff(consequent, offset); get_antecedents(consequent, p, m_lemma); TRACE("sat", display(tout, p, true); tout << m_lemma << "\n";); @@ -1062,7 +1077,7 @@ namespace sat { case xor_t: { // jus.push_back(js); m_lemma.reset(); - m_bound += offset; + inc_bound(offset); inc_coeff(consequent, offset); get_xor_antecedents(consequent, idx, js, m_lemma); for (literal l : m_lemma) process_antecedent(~l, offset); @@ -1115,7 +1130,8 @@ namespace sat { js = s().m_justification[v]; offset = get_abs_coeff(v); if (offset > m_bound) { - m_coeffs[v] = (get_coeff(v) < 0) ? -m_bound : m_bound; + int64 bound64 = static_cast(m_bound); + m_coeffs[v] = (get_coeff(v) < 0) ? -bound64 : bound64; offset = m_bound; DEBUG_CODE(active2pb(m_A);); } @@ -1196,8 +1212,8 @@ namespace sat { bool ba_solver::create_asserting_lemma() { adjust_conflict_level: - - int64 slack = -m_bound; + int64 bound64 = m_bound; + int64 slack = -bound64; for (bool_var v : m_active_vars) { slack += get_abs_coeff(v); } @@ -1214,6 +1230,7 @@ namespace sat { bool append = coeff != 0 && val != l_undef && (coeff < 0 == is_true); if (append) { literal lit(v, !is_true); + unsigned acoeff = get_abs_coeff(v); if (lvl(lit) == m_conflict_lvl) { if (m_lemma[0] == null_literal) { asserting_coeff = abs(coeff); @@ -1269,31 +1286,30 @@ namespace sat { if (1 == get_abs_coeff(v)) return; } - SASSERT(0 <= m_bound && m_bound <= UINT_MAX); - unsigned g = 0; for (unsigned i = 0; g != 1 && i < m_active_vars.size(); ++i) { bool_var v = m_active_vars[i]; - int64 coeff = get_abs_coeff(v); + unsigned coeff = get_abs_coeff(v); if (coeff == 0) { continue; } - if (m_bound < coeff) { + if (m_bound < coeff) { + int64 bound64 = m_bound; if (get_coeff(v) > 0) { - m_coeffs[v] = m_bound; + m_coeffs[v] = bound64; } else { - m_coeffs[v] = -m_bound; + m_coeffs[v] = -bound64; } coeff = m_bound; } SASSERT(0 < coeff && coeff <= m_bound); if (g == 0) { - g = static_cast(coeff); + g = coeff; } else { - g = u_gcd(g, static_cast(coeff)); + g = u_gcd(g, coeff); } } @@ -1307,7 +1323,7 @@ namespace sat { } } - void ba_solver::process_card(card& c, int64 offset) { + void ba_solver::process_card(card& c, unsigned offset) { literal lit = c.lit(); SASSERT(c.k() <= c.size()); SASSERT(lit == null_literal || value(lit) == l_true); @@ -1319,11 +1335,18 @@ namespace sat { inc_coeff(c[i], offset); } if (lit != null_literal) { - process_antecedent(~lit, c.k() * offset); + uint64 offset1 = static_cast(offset) * c.k(); + if (offset1 > UINT_MAX) { + m_overflow = true; + std::cout << "cardinality offset overflow\n"; + } + else { + process_antecedent(~lit, static_cast(offset1)); + } } } - void ba_solver::process_antecedent(literal l, int64 offset) { + void ba_solver::process_antecedent(literal l, unsigned offset) { SASSERT(value(l) == l_false); bool_var v = l.var(); unsigned level = lvl(v); @@ -3286,7 +3309,8 @@ namespace sat { } bool ba_solver::validate_lemma() { - int64 val = -m_bound; + int64 bound64 = m_bound; + int64 val = -bound64; reset_active_var_set(); for (bool_var v : m_active_vars) { if (m_active_var_set.contains(v)) continue; @@ -3326,7 +3350,7 @@ namespace sat { ba_solver::constraint* ba_solver::active2constraint() { reset_active_var_set(); svector wlits; - uint64_t sum = 0; + uint64 sum = 0; if (m_bound == 1) return 0; if (m_overflow) return 0; @@ -3335,16 +3359,15 @@ namespace sat { if (m_active_var_set.contains(v) || coeff == 0) continue; m_active_var_set.insert(v); literal lit(v, coeff < 0); - wlits.push_back(wliteral(abs(coeff), lit)); - sum += abs(coeff); + wlits.push_back(wliteral(get_abs_coeff(v), lit)); + sum += get_abs_coeff(v); } - unsigned k = get_bound(); if (m_overflow || sum >= UINT_MAX/2) { return 0; } else { - return add_pb_ge(null_literal, wlits, k, true); + return add_pb_ge(null_literal, wlits, m_bound, true); } } @@ -3384,11 +3407,11 @@ namespace sat { svector wlits; for (bool_var v : m_active_vars) { int coeff = get_int_coeff(v); - wlits.push_back(std::make_pair(abs(coeff), literal(v, coeff < 0))); + wlits.push_back(std::make_pair(get_abs_coeff(v), literal(v, coeff < 0))); } std::sort(wlits.begin(), wlits.end(), compare_wlit()); unsigned k = 0; - int sum = 0, sum0 = 0; + uint64 sum = 0, sum0 = 0; for (wliteral wl : wlits) { if (sum >= m_bound) break; sum0 = sum; @@ -3400,7 +3423,7 @@ namespace sat { } while (!wlits.empty()) { wliteral wl = wlits.back(); - if (wl.first + sum0 >= get_bound()) break; + if (wl.first + sum0 >= m_bound) break; wlits.pop_back(); sum0 += wl.first; } diff --git a/src/sat/ba_solver.h b/src/sat/ba_solver.h index c30fd25b4..93c6f14e1 100644 --- a/src/sat/ba_solver.h +++ b/src/sat/ba_solver.h @@ -215,7 +215,7 @@ namespace sat { unsigned m_conflict_lvl; svector m_coeffs; svector m_active_vars; - int64 m_bound; + unsigned m_bound; tracked_uint_set m_active_var_set; literal_vector m_lemma; literal_vector m_skipped; @@ -366,15 +366,16 @@ namespace sat { mutable bool m_overflow; void reset_active_var_set(); void normalize_active_coeffs(); - void inc_coeff(literal l, int64 offset); + void inc_coeff(literal l, unsigned offset); int64 get_coeff(bool_var v) const; - int64 get_abs_coeff(bool_var v) const; + unsigned get_abs_coeff(bool_var v) const; int get_int_coeff(bool_var v) const; unsigned get_bound() const; + void inc_bound(int64 i); literal get_asserting_literal(literal conseq); - void process_antecedent(literal l, int64 offset); - void process_card(card& c, int64 offset); + void process_antecedent(literal l, unsigned offset); + void process_card(card& c, unsigned offset); void cut(); bool create_asserting_lemma();