diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index e5c0bcddb..5d4eb3fc5 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -351,7 +351,7 @@ namespace opt { void context::get_model_core(model_ref& mdl) { mdl = m_model; fix_model(mdl); - mdl->set_model_completion(true); + if (mdl) mdl->set_model_completion(true); TRACE("opt", tout << *mdl;); } diff --git a/src/sat/ba_solver.cpp b/src/sat/ba_solver.cpp index c86ea3317..2c33bd955 100644 --- a/src/sat/ba_solver.cpp +++ b/src/sat/ba_solver.cpp @@ -49,6 +49,11 @@ namespace sat { return static_cast(*this); } + ba_solver::pb_base const& ba_solver::constraint::to_pb_base() const{ + SASSERT(is_pb() || is_card()); + return static_cast(*this); + } + ba_solver::xr& ba_solver::constraint::to_xr() { SASSERT(is_xr()); return static_cast(*this); @@ -1015,6 +1020,7 @@ namespace sat { bool_var v = l.var(); SASSERT(v != null_bool_var); m_coeffs.reserve(v + 1, 0); + TRACE("ba_verbose", tout << l << " " << offset << "\n";); int64_t coeff0 = m_coeffs[v]; if (coeff0 == 0) { @@ -1064,6 +1070,7 @@ namespace sat { l = literal(v, c1 < 0); c1 = std::abs(c1); c = static_cast(c1); + // TRACE("ba", tout << l << " " << c << "\n";); m_overflow |= c != c1; } @@ -1096,6 +1103,17 @@ namespace sat { m_active_vars.reset(); } + void ba_solver::init_visited() { + m_visited_ts++; + if (m_visited_ts == 0) { + m_visited_ts = 1; + m_visited.reset(); + } + while (m_visited.size() < 2*s().num_vars()) { + m_visited.push_back(0); + } + } + static bool _debug_conflict = false; static literal _debug_consequent = null_literal; static unsigned_vector _debug_var2position; @@ -1120,7 +1138,7 @@ namespace sat { IF_VERBOSE(0, active2pb(m_A); uint64_t c = 0; - for (uint64_t c1 : m_A.m_coeffs) c += c1; + for (wliteral l : m_A.m_wlits) c += l.first; verbose_stream() << "sum of coefficients: " << c << "\n"; display(verbose_stream(), m_A, true); verbose_stream() << "conflicting literal: " << s().m_not_l << "\n";); @@ -1154,6 +1172,9 @@ namespace sat { justification js = s().m_conflict; TRACE("ba", tout << consequent << " " << js << "\n";); m_conflict_lvl = s().get_max_lvl(consequent, js); + if (m_conflict_lvl == 0) { + return l_undef; + } if (consequent != null_literal) { consequent.neg(); process_antecedent(consequent, 1); @@ -1316,27 +1337,14 @@ namespace sat { DEBUG_CODE(for (bool_var i = 0; i < static_cast(s().num_vars()); ++i) SASSERT(!s().is_marked(i));); SASSERT(validate_lemma()); - if (!create_asserting_lemma()) { goto bail_out; } + active2card(); + DEBUG_CODE(VERIFY(validate_conflict(m_lemma, m_A));); - TRACE("ba", tout << m_lemma << "\n";); - - if (get_config().m_drat) { - svector ps; // TBD fill in - drat_add(m_lemma, ps); - } - - s().m_lemma.reset(); - s().m_lemma.append(m_lemma); - for (unsigned i = 1; i < m_lemma.size(); ++i) { - CTRACE("ba", s().is_marked(m_lemma[i].var()), tout << "marked: " << m_lemma[i] << "\n";); - s().mark(m_lemma[i].var()); - } - return l_true; bail_out: @@ -1345,8 +1353,7 @@ namespace sat { } - uint64_t ba_solver::ineq::coeff(literal l) const { - bool_var v = l.var(); + unsigned ba_solver::ineq::bv_coeff(bool_var v) const { for (unsigned i = size(); i-- > 0; ) { if (lit(i).var() == v) return coeff(i); } @@ -1354,10 +1361,10 @@ namespace sat { return 0; } - void ba_solver::ineq::divide(uint64_t c) { + void ba_solver::ineq::divide(unsigned c) { if (c == 1) return; for (unsigned i = size(); i-- > 0; ) { - m_coeffs[i] = (m_coeffs[i] + c - 1) / c; + m_wlits[i].first = (coeff(i) + c - 1) / c; } m_k = (m_k + c - 1) / c; } @@ -1366,48 +1373,57 @@ namespace sat { * Remove literal at position i, subtract coefficient from bound. */ void ba_solver::ineq::weaken(unsigned i) { - uint64_t ci = coeff(i); + unsigned ci = coeff(i); SASSERT(m_k >= ci); m_k -= ci; - m_lits[i] = m_lits.back(); - m_coeffs[i] = m_coeffs.back(); - m_lits.pop_back(); - m_coeffs.pop_back(); + m_wlits[i] = m_wlits.back(); + m_wlits.pop_back(); } /** * Round coefficient of inequality to 1. */ - void ba_solver::round_to_one(ineq& ineq, literal lit) { - uint64_t c = ineq.coeff(lit); + void ba_solver::round_to_one(ineq& ineq, bool_var v) { + unsigned c = ineq.bv_coeff(v); if (c == 1) return; unsigned sz = ineq.size(); for (unsigned i = 0; i < sz; ++i) { - uint64_t ci = ineq.coeff(i); - if (ci % c != 0 && !is_false(ineq.lit(i))) { - ineq.weaken(i); - --i; - --sz; + unsigned ci = ineq.coeff(i); + unsigned q = ci % c; + if (q != 0 && !is_false(ineq.lit(i))) { + if (q == ci) { + ineq.weaken(i); + --i; + --sz; + } + else { + ineq.m_wlits[i].first -= q; + ineq.m_k -= q; + } } } ineq.divide(c); } - void ba_solver::round_to_one(literal lit) { - uint64_t c = get_coeff(lit); - if (c == 1) return; + void ba_solver::round_to_one(bool_var w) { + unsigned c = get_abs_coeff(w); + if (c == 1 || c == 0) return; for (bool_var v : m_active_vars) { literal l; unsigned ci; get_coeff(v, l, ci); - if (ci > 0 && ci % c != 0 && !is_false(l)) { - m_coeffs[v] = 0; + unsigned q = ci % c; + if (q != 0 && !is_false(l)) { + m_coeffs[v] = ci - q; + m_bound -= q; + SASSERT(m_bound > 0); } } divide(c); + SASSERT(validate_lemma()); } - void ba_solver::divide(uint64_t c) { + void ba_solver::divide(unsigned c) { SASSERT(c != 0); if (c == 1) return; reset_active_var_set(); @@ -1426,20 +1442,17 @@ namespace sat { m_active_vars[j++] = v; } m_active_vars.shrink(j); - if (m_bound % c != 0) { - ++m_stats.m_num_cut; - m_bound = static_cast((m_bound + c - 1) / c); - } + m_bound = static_cast((m_bound + c - 1) / c); } void ba_solver::resolve_on(literal consequent) { - round_to_one(consequent); + round_to_one(consequent.var()); m_coeffs[consequent.var()] = 0; } void ba_solver::resolve_with(ineq const& ineq) { TRACE("ba", display(tout, ineq, true);); - inc_bound(1 + ineq.m_k); + inc_bound(ineq.m_k); for (unsigned i = ineq.size(); i-- > 0; ) { literal l = ineq.lit(i); inc_coeff(l, static_cast(ineq.coeff(i))); @@ -1457,6 +1470,23 @@ namespace sat { --idx; } } + + /** + * \brief mark variables that are on the assignment stack but + * below the current processing level. + */ + void ba_solver::mark_variables(ineq const& ineq) { + for (wliteral wl : ineq.m_wlits) { + literal l = wl.second; + if (!is_false(l)) continue; + bool_var v = l.var(); + unsigned level = lvl(v); + if (!s().is_marked(v) && !is_visited(v) && level == m_conflict_lvl) { + s().mark(v); + ++m_num_marks; + } + } + } lbool ba_solver::resolve_conflict_rs() { if (0 == m_num_propagations_since_pop) { @@ -1464,49 +1494,56 @@ namespace sat { } m_overflow = false; reset_coeffs(); + init_visited(); m_num_marks = 0; m_bound = 0; literal consequent = s().m_not_l; justification js = s().m_conflict; - TRACE("ba", tout << consequent << " " << js << "\n";); m_conflict_lvl = s().get_max_lvl(consequent, js); + if (m_conflict_lvl == 0) { + return l_undef; + } if (consequent != null_literal) { consequent.neg(); process_antecedent(consequent, 1); } + TRACE("ba", tout << consequent << " " << js << "\n";); unsigned idx = s().m_trail.size() - 1; do { - // TBD: termination condition - // if UIP is below m_conflict level - TRACE("ba", s().display_justification(tout << "process consequent: " << consequent << " : ", js) << "\n"; - active2pb(m_A); display(tout, m_A, true); + if (consequent != null_literal) { active2pb(m_A); display(tout, m_A, true); } ); + switch (js.get_kind()) { case justification::NONE: SASSERT(consequent != null_literal); - resolve_on(consequent); + inc_bound(1); + round_to_one(consequent.var()); + inc_coeff(consequent, 1); break; case justification::BINARY: SASSERT(consequent != null_literal); - resolve_on(consequent); + inc_bound(1); + round_to_one(consequent.var()); + inc_coeff(consequent, 1); process_antecedent(js.get_literal()); break; case justification::TERNARY: SASSERT(consequent != null_literal); - resolve_on(consequent); + inc_bound(1); + round_to_one(consequent.var()); + inc_coeff(consequent, 1); process_antecedent(js.get_literal1()); process_antecedent(js.get_literal2()); break; case justification::CLAUSE: { + inc_bound(1); clause & c = s().get_clause(js); unsigned i = 0; - if (consequent == null_literal) { - m_bound = 1; - } - else { - resolve_on(consequent); + if (consequent != null_literal) { + round_to_one(consequent.var()); + inc_coeff(consequent, 1); if (c[0] == consequent) { i = 1; } @@ -1525,25 +1562,51 @@ namespace sat { ++m_stats.m_num_resolves; ext_justification_idx index = js.get_ext_justification_idx(); constraint& cnstr = index2constraint(index); - constraint2pb(cnstr, consequent, 1, m_A); + switch (cnstr.tag()) { + case card_t: + case pb_t: { + pb_base const& p = cnstr.to_pb_base(); + unsigned k = p.k(), sz = p.size(); + m_A.reset(0); + for (unsigned i = 0; i < sz; ++i) { + literal l = p.get_lit(i); + unsigned c = p.get_coeff(i); + if (l == consequent || !is_visited(l.var())) { + m_A.push(l, c); + } + else { + SASSERT(k > c); + k -= c; + } + } + SASSERT(k > 0); + if (p.lit() != null_literal) m_A.push(~p.lit(), k); + m_A.m_k = k; + break; + } + default: + constraint2pb(cnstr, consequent, 1, m_A); + break; + } + mark_variables(m_A); if (consequent == null_literal) { m_bound = static_cast(m_A.m_k); - for (unsigned i = m_A.size(); i-- > 0; ) { - inc_coeff(m_A.lit(i), static_cast(m_A.coeff(i))); + for (wliteral wl : m_A.m_wlits) { + process_antecedent(wl.second, wl.first); } } else { - round_to_one(consequent); - round_to_one(m_A, consequent); + round_to_one(consequent.var()); + if (cnstr.tag() == pb_t) round_to_one(m_A, consequent.var()); resolve_with(m_A); } - break; } default: UNREACHABLE(); break; } + SASSERT(validate_lemma()); cut(); // find the next marked variable in the assignment stack @@ -1551,7 +1614,14 @@ namespace sat { while (true) { consequent = s().m_trail[idx]; v = consequent.var(); - if (s().is_marked(v)) break; + mark_visited(v); + if (s().is_marked(v)) { + if (get_coeff(v) != 0) { + break; + } + s().reset_mark(v); + --m_num_marks; + } if (idx == 0) { goto bail_out; } @@ -1567,20 +1637,24 @@ namespace sat { while (m_num_marks > 0 && !m_overflow); TRACE("ba", active2pb(m_A); display(tout, m_A, true);); - active2constraint(); - if (!m_overflow) { +#if 0 + // why this? + if (!m_overflow && consequent != null_literal) { + round_to_one(consequent.var()); + } +#endif + if (!m_overflow && create_asserting_lemma()) { + active2constraint(); return l_true; } bail_out: + IF_VERBOSE(1, verbose_stream() << "bail\n"); m_overflow = false; return l_undef; } bool ba_solver::create_asserting_lemma() { - bool adjusted = false; - - adjust_conflict_level: int64_t bound64 = m_bound; int64_t slack = -bound64; for (bool_var v : m_active_vars) { @@ -1629,20 +1703,22 @@ namespace sat { if (m_lemma[0] == null_literal) { if (m_lemma.size() == 1) { s().set_conflict(justification()); - return false; } return false; - unsigned old_level = m_conflict_lvl; - m_conflict_lvl = 0; - for (unsigned i = 1; i < m_lemma.size(); ++i) { - m_conflict_lvl = std::max(m_conflict_lvl, lvl(m_lemma[i])); - } - IF_VERBOSE(1, verbose_stream() << "(sat.backjump :new-level " << m_conflict_lvl << " :old-level " << old_level << ")\n";); - adjusted = true; - goto adjust_conflict_level; } - if (!adjusted) { - active2card(); + + TRACE("ba", tout << m_lemma << "\n";); + + if (get_config().m_drat) { + svector ps; // TBD fill in + drat_add(m_lemma, ps); + } + + s().m_lemma.reset(); + s().m_lemma.append(m_lemma); + for (unsigned i = 1; i < m_lemma.size(); ++i) { + CTRACE("ba", s().is_marked(m_lemma[i].var()), tout << "marked: " << m_lemma[i] << "\n";); + s().mark(m_lemma[i].var()); } return true; } @@ -1732,7 +1808,7 @@ namespace sat { bool_var v = l.var(); unsigned level = lvl(v); - if (level > 0 && !s().is_marked(v) && level == m_conflict_lvl) { + if (!s().is_marked(v) && level == m_conflict_lvl) { s().mark(v); ++m_num_marks; if (_debug_conflict && _debug_consequent != null_literal && _debug_var2position[_debug_consequent.var()] < _debug_var2position[l.var()]) { @@ -2724,7 +2800,8 @@ namespace sat { set_non_external(); if (get_config().m_elim_vars) elim_pure(); for (unsigned sz = m_constraints.size(), i = 0; i < sz; ++i) subsumption(*m_constraints[i]); - for (unsigned sz = m_learned.size(), i = 0; i < sz; ++i) subsumption(*m_learned[i]); + for (unsigned sz = m_learned.size(), i = 0; i < sz; ++i) subsumption(*m_learned[i]); + unit_strengthen(); cleanup_clauses(); cleanup_constraints(); update_pure(); @@ -3170,9 +3247,10 @@ namespace sat { bool found_dup = false; bool found_root = false; + init_visited(); for (unsigned i = 0; i < c.size(); ++i) { literal l = c.get_lit(i); - if (is_marked(l)) { + if (is_visited(l)) { found_dup = true; break; } @@ -3182,10 +3260,7 @@ namespace sat { } } for (unsigned i = 0; i < c.size(); ++i) { - literal l = c.get_lit(i); - unmark_visited(l); - unmark_visited(~l); - found_root |= l.var() == root.var(); + found_root |= c.get_lit(i).var() == root.var(); } if (found_root) { @@ -3342,6 +3417,78 @@ namespace sat { return pure_literals; } + /** + * Strengthen inequalities using binary implication information. + * + * x -> ~y, x -> ~z, y + z + u >= 2 + * ---------------------------------- + * y + z + u + ~x >= 3 + * + * for c : constraints + * for l : c: + * slack <- of c under root(~l) + * if slack < 0: + * add ~root(~l) to c, k <- k + 1 + */ + void ba_solver::unit_strengthen() { + big big(s().m_rand); + big.init(s(), true); + for (unsigned sz = m_constraints.size(), i = 0; i < sz; ++i) + unit_strengthen(big, *m_constraints[i]); + for (unsigned sz = m_learned.size(), i = 0; i < sz; ++i) + unit_strengthen(big, *m_learned[i]); + } + + void ba_solver::unit_strengthen(big& big, constraint& c) { + if (c.was_removed()) return; + switch (c.tag()) { + case card_t: + unit_strengthen(big, c.to_card()); + break; + case pb_t: + unit_strengthen(big, c.to_pb()); + break; + default: + break; + } + } + + void ba_solver::unit_strengthen(big& big, card& c) { + for (literal l : c) { + literal r = big.get_root(~l); + if (r == ~l) continue; + unsigned k = c.k(); + for (literal u : c) { + if (big.reaches(r, ~u)) { + if (k == 0) { + // ~r + C >= c.k() + 1 + IF_VERBOSE(0, verbose_stream() << "TBD add " << ~r << " to " << c << "\n";); + return; + } + --k; + } + } + } + } + + void ba_solver::unit_strengthen(big& big, pb& p) { + for (wliteral wl : p) { + literal r = big.get_root(~wl.second); + if (r == ~wl.second) continue; + unsigned k = p.k(); + for (wliteral u : p) { + if (big.reaches(r, ~u.second)) { + if (k < u.first) { + // ~r + p >= p.k() + 1 + IF_VERBOSE(0, verbose_stream() << "TBD add " << ~r << " to " << p << "\n";); + return; + } + k -= u.first; + } + } + } + } + void ba_solver::subsumption(constraint& cnstr) { if (cnstr.was_removed()) return; switch (cnstr.tag()) { @@ -3430,10 +3577,10 @@ namespace sat { unsigned common = 0; comp.reset(); for (literal l : c2) { - if (is_marked(l)) { + if (is_visited(l)) { ++common; } - else if (is_marked(~l)) { + else if (is_visited(~l)) { comp.push_back(l); } else { @@ -3455,10 +3602,10 @@ namespace sat { self = false; for (literal l : c2) { - if (is_marked(l)) { + if (is_visited(l)) { ++common; } - else if (is_marked(~l)) { + else if (is_visited(~l)) { ++complement; } else { @@ -3482,10 +3629,10 @@ namespace sat { unsigned num_sub = 0; for (unsigned i = 0; i < p2.size(); ++i) { literal l = p2.get_lit(i); - if (is_marked(l) && m_weights[l.index()] <= p2.get_coeff(i)) { + if (is_visited(l) && m_weights[l.index()] <= p2.get_coeff(i)) { ++num_sub; } - if (p1.size() + i > p2.size() + num_sub) return false; + if (p1.size() + i > p2.size() + num_sub) return false; } return num_sub == p1.size(); } @@ -3550,7 +3697,7 @@ namespace sat { clear_watch(c2); unsigned j = 0; for (unsigned i = 0; i < c2.size(); ++i) { - if (!is_marked(~c2[i])) { + if (!is_visited(~c2[i])) { c2[j++] = c2[i]; } } @@ -3586,7 +3733,7 @@ namespace sat { void ba_solver::binary_subsumption(card& c1, literal lit) { if (c1.k() + 1 != c1.size()) return; - SASSERT(is_marked(lit)); + SASSERT(is_visited(lit)); SASSERT(!c1.was_removed()); watch_list & wlist = get_wlist(~lit); watch_list::iterator it = wlist.begin(); @@ -3594,7 +3741,7 @@ namespace sat { watch_list::iterator end = wlist.end(); for (; it != end; ++it) { watched w = *it; - if (w.is_binary_clause() && is_marked(w.get_literal())) { + if (w.is_binary_clause() && is_visited(w.get_literal())) { ++m_stats.m_num_bin_subsumes; IF_VERBOSE(10, verbose_stream() << c1 << " subsumes (" << lit << " " << w.get_literal() << ")\n";); if (!w.is_learned()) { @@ -3616,6 +3763,7 @@ namespace sat { return; } clause_vector removed_clauses; + init_visited(); for (literal l : c1) mark_visited(l); for (unsigned i = 0; i < std::min(c1.size(), c1.k() + 1); ++i) { literal lit = c1[i]; @@ -3623,7 +3771,6 @@ namespace sat { clause_subsumption(c1, lit, removed_clauses); binary_subsumption(c1, lit); } - for (literal l : c1) unmark_visited(l); m_clause_removed |= !removed_clauses.empty(); for (clause *c : removed_clauses) { c->set_removed(true); @@ -3635,6 +3782,7 @@ namespace sat { if (p1.was_removed() || p1.lit() != null_literal) { return; } + init_visited(); for (wliteral l : p1) { SASSERT(m_weights[l.second.index()] == 0); m_weights.setx(l.second.index(), l.first, 0); @@ -3646,7 +3794,6 @@ namespace sat { } for (wliteral l : p1) { m_weights[l.second.index()] = 0; - unmark_visited(l.second); } } @@ -3844,9 +3991,9 @@ namespace sat { } void ba_solver::display(std::ostream& out, ineq const& ineq, bool values) const { - for (unsigned i = 0; i < ineq.m_lits.size(); ++i) { - out << ineq.m_coeffs[i] << "*" << ineq.m_lits[i] << " "; - if (values) out << value(ineq.m_lits[i]) << " "; + for (unsigned i = 0; i < ineq.size(); ++i) { + out << ineq.coeff(i) << "*" << ineq.lit(i) << " "; + if (values) out << value(ineq.lit(i)) << " "; } out << ">= " << ineq.m_k << "\n"; } @@ -4031,14 +4178,11 @@ namespace sat { reset_active_var_set(); for (bool_var v : m_active_vars) { if (m_active_var_set.contains(v)) continue; - int64_t coeff = get_coeff(v); + unsigned coeff; + literal lit; + get_coeff(v, lit, coeff); if (coeff == 0) continue; - m_active_var_set.insert(v); - literal lit(v, false); - if (coeff < 0 && value(lit) != l_true) { - val -= coeff; - } - else if (coeff > 0 && value(lit) != l_false) { + if (!is_false(lit)) { val += coeff; } } @@ -4051,31 +4195,26 @@ namespace sat { } void ba_solver::active2pb(ineq& p) { - reset_active_var_set(); p.reset(m_bound); + active2wlits(p.m_wlits); + } + + void ba_solver::active2wlits() { + m_wlits.reset(); + active2wlits(m_wlits); + } + + void ba_solver::active2wlits(svector& wlits) { + reset_active_var_set(); + uint64_t sum = 0; for (bool_var v : m_active_vars) { - if (m_active_var_set.contains(v)) continue; + if (m_active_var_set.contains(v)) continue; unsigned coeff; literal lit; get_coeff(v, lit, coeff); if (coeff == 0) continue; m_active_var_set.insert(v); - p.m_lits.push_back(lit); - p.m_coeffs.push_back(coeff); - } - } - - void ba_solver::active2wlits() { - reset_active_var_set(); - m_wlits.reset(); - uint64_t sum = 0; - for (bool_var v : m_active_vars) { - unsigned coeff; - literal lit; - get_coeff(v, lit, coeff); - if (m_active_var_set.contains(v) || coeff == 0) continue; - m_active_var_set.insert(v); - m_wlits.push_back(wliteral(static_cast(coeff), lit)); + wlits.push_back(wliteral(coeff, lit)); sum += coeff; } m_overflow |= sum >= UINT_MAX/2; @@ -4086,7 +4225,9 @@ namespace sat { if (m_overflow) { return nullptr; } - return add_pb_ge(null_literal, m_wlits, m_bound, true); + constraint* c = add_pb_ge(null_literal, m_wlits, m_bound, true); + TRACE("ba", if (c) display(tout, *c, true);); + return c; } /* @@ -4269,14 +4410,14 @@ namespace sat { return true; u_map coeffs; uint64_t k = m_A.m_k + m_B.m_k; - for (unsigned i = 0; i < m_A.m_lits.size(); ++i) { - uint64_t coeff = m_A.m_coeffs[i]; - SASSERT(!coeffs.contains(m_A.m_lits[i].index())); - coeffs.insert(m_A.m_lits[i].index(), coeff); + for (unsigned i = 0; i < m_A.size(); ++i) { + uint64_t coeff = m_A.coeff(i); + SASSERT(!coeffs.contains(m_A.lit(i).index())); + coeffs.insert(m_A.lit(i).index(), coeff); } - for (unsigned i = 0; i < m_B.m_lits.size(); ++i) { - uint64_t coeff1 = m_B.m_coeffs[i], coeff2; - literal lit = m_B.m_lits[i]; + for (unsigned i = 0; i < m_B.size(); ++i) { + uint64_t coeff1 = m_B.coeff(i), coeff2; + literal lit = m_B.lit(i); if (coeffs.find((~lit).index(), coeff2)) { if (coeff1 == coeff2) { coeffs.remove((~lit).index()); @@ -4301,11 +4442,11 @@ namespace sat { } } // C is above the sum of A and B - for (unsigned i = 0; i < m_C.m_lits.size(); ++i) { - literal lit = m_C.m_lits[i]; + for (unsigned i = 0; i < m_C.size(); ++i) { + literal lit = m_C.lit(i); uint64_t coeff; if (coeffs.find(lit.index(), coeff)) { - if (coeff > m_C.m_coeffs[i] && m_C.m_coeffs[i] < m_C.m_k) { + if (coeff > m_C.coeff(i) && m_C.coeff(i) < m_C.m_k) { goto violated; } coeffs.remove(lit.index()); @@ -4352,15 +4493,15 @@ namespace sat { */ literal ba_solver::translate_to_sat(solver& s, u_map& translation, ineq const& pb) { SASSERT(pb.m_k > 0); - if (pb.m_lits.size() > 1) { + if (pb.size() > 1) { ineq a, b; a.reset(pb.m_k); b.reset(pb.m_k); - for (unsigned i = 0; i < pb.m_lits.size()/2; ++i) { - a.push(pb.m_lits[i], pb.m_coeffs[i]); + for (unsigned i = 0; i < pb.size()/2; ++i) { + a.push(pb.lit(i), pb.coeff(i)); } - for (unsigned i = pb.m_lits.size()/2; i < pb.m_lits.size(); ++i) { - b.push(pb.m_lits[i], pb.m_coeffs[i]); + for (unsigned i = pb.size()/2; i < pb.size(); ++i) { + b.push(pb.lit(i), pb.coeff(i)); } bool_var v = s.mk_var(); literal lit(v, false); @@ -4372,8 +4513,8 @@ namespace sat { s.mk_clause(lits); return lit; } - if (pb.m_coeffs[0] >= pb.m_k) { - return translate_to_sat(s, translation, pb.m_lits[0]); + if (pb.coeff(0) >= pb.m_k) { + return translate_to_sat(s, translation, pb.lit(0)); } else { return null_literal; @@ -4425,9 +4566,9 @@ namespace sat { ba_solver::ineq ba_solver::negate(ineq const& a) const { ineq result; uint64_t sum = 0; - for (unsigned i = 0; i < a.m_lits.size(); ++i) { - result.push(~a.m_lits[i], a.m_coeffs[i]); - sum += a.m_coeffs[i]; + for (unsigned i = 0; i < a.size(); ++i) { + result.push(~a.lit(i), a.coeff(i)); + sum += a.coeff(i); } SASSERT(sum >= a.m_k + 1); result.m_k = sum + 1 - a.m_k; @@ -4446,15 +4587,15 @@ namespace sat { TRACE("ba", tout << "literal " << l << " is not false\n";); return false; } - if (!p.m_lits.contains(l)) { + if (!p.contains(l)) { TRACE("ba", tout << "lemma contains literal " << l << " not in inequality\n";); return false; } } uint64_t value = 0; - for (unsigned i = 0; i < p.m_lits.size(); ++i) { - uint64_t coeff = p.m_coeffs[i]; - if (!lits.contains(p.m_lits[i])) { + for (unsigned i = 0; i < p.size(); ++i) { + uint64_t coeff = p.coeff(i); + if (!lits.contains(p.lit(i))) { value += coeff; } } diff --git a/src/sat/ba_solver.h b/src/sat/ba_solver.h index 6adde156d..9418cdd99 100644 --- a/src/sat/ba_solver.h +++ b/src/sat/ba_solver.h @@ -25,6 +25,7 @@ Revision History: #include "sat/sat_solver.h" #include "sat/sat_lookahead.h" #include "sat/sat_unit_walk.h" +#include "sat/sat_big.h" #include "util/scoped_ptr_vector.h" #include "util/sorting_network.h" @@ -57,6 +58,7 @@ namespace sat { class card; class pb; class xr; + class pb_base; class constraint { protected: @@ -104,6 +106,7 @@ namespace sat { card const& to_card() const; pb const& to_pb() const; xr const& to_xr() const; + pb_base const& to_pb_base() const; bool is_card() const { return m_tag == card_t; } bool is_pb() const { return m_tag == pb_t; } bool is_xr() const { return m_tag == xr_t; } @@ -118,7 +121,7 @@ namespace sat { }; friend std::ostream& operator<<(std::ostream& out, constraint const& c); - + // base class for pb and cardinality constraints class pb_base : public constraint { protected: @@ -204,18 +207,18 @@ namespace sat { protected: struct ineq { - literal_vector m_lits; - svector m_coeffs; + svector m_wlits; uint64_t m_k; ineq(): m_k(0) {} - unsigned size() const { return m_lits.size(); } - literal lit(unsigned i) const { return m_lits[i]; } - uint64_t coeff(unsigned i) const { return m_coeffs[i]; } - void reset(uint64_t k) { m_lits.reset(); m_coeffs.reset(); m_k = k; } - void push(literal l, uint64_t c) { m_lits.push_back(l); m_coeffs.push_back(c); } - uint64_t coeff(literal lit) const; - void divide(uint64_t c); + unsigned size() const { return m_wlits.size(); } + literal lit(unsigned i) const { return m_wlits[i].second; } + unsigned coeff(unsigned i) const { return m_wlits[i].first; } + void reset(uint64_t k) { m_wlits.reset(); m_k = k; } + void push(literal l, unsigned c) { m_wlits.push_back(wliteral(c,l)); } + unsigned bv_coeff(bool_var v) const; + void divide(unsigned c); void weaken(unsigned i); + bool contains(literal l) const { for (auto wl : m_wlits) if (wl.second == l) return true; return false; } }; solver* m_solver; @@ -279,7 +282,8 @@ namespace sat { // simplification routines - svector m_visited; + svector m_visited; + unsigned m_visited_ts; vector> m_cnstr_use_list; use_list m_clause_use_list; bool m_simplify_change; @@ -298,9 +302,11 @@ namespace sat { void binary_subsumption(card& c1, literal lit); void clause_subsumption(card& c1, literal lit, clause_vector& removed_clauses); void card_subsumption(card& c1, literal lit); - void mark_visited(literal l) { m_visited[l.index()] = true; } - void unmark_visited(literal l) { m_visited[l.index()] = false; } - bool is_marked(literal l) const { return m_visited[l.index()] != 0; } + void init_visited(); + void mark_visited(literal l) { m_visited[l.index()] = m_visited_ts; } + void mark_visited(bool_var v) { mark_visited(literal(v, false)); } + bool is_visited(bool_var v) const { return is_visited(literal(v, false)); } + bool is_visited(literal l) const { return m_visited[l.index()] == m_visited_ts; } unsigned get_num_unblocked_bin(literal l); literal get_min_occurrence_literal(card const& c); void init_use_lists(); @@ -308,6 +314,10 @@ namespace sat { unsigned set_non_external(); unsigned elim_pure(); bool elim_pure(literal lit); + void unit_strengthen(); + void unit_strengthen(big& big, constraint& cs); + void unit_strengthen(big& big, card& c); + void unit_strengthen(big& big, pb& p); void subsumption(constraint& c1); void subsumption(card& c1); void gc_half(char const* _method); @@ -404,12 +414,13 @@ namespace sat { // RoundingPb conflict resolution lbool resolve_conflict_rs(); - void round_to_one(ineq& ineq, literal lit); - void round_to_one(literal lit); - void divide(uint64_t c); + void round_to_one(ineq& ineq, bool_var v); + void round_to_one(bool_var v); + void divide(unsigned c); void resolve_on(literal lit); void resolve_with(ineq const& ineq); void reset_marks(unsigned idx); + void mark_variables(ineq const& ineq); void bail_resolve_conflict(unsigned idx); @@ -487,6 +498,7 @@ namespace sat { constraint* active2constraint(); constraint* active2card(); void active2wlits(); + void active2wlits(svector& wlits); void justification2pb(justification const& j, literal lit, unsigned offset, ineq& p); void constraint2pb(constraint& cnstr, literal lit, unsigned offset, ineq& p); bool validate_resolvent(); diff --git a/src/sat/sat_big.cpp b/src/sat/sat_big.cpp index 35898a110..c1eeecd27 100644 --- a/src/sat/sat_big.cpp +++ b/src/sat/sat_big.cpp @@ -22,7 +22,8 @@ Revision History: namespace sat { big::big(random_gen& rand): - m_rand(rand) { + m_rand(rand), + m_include_cardinality(false) { } void big::init(solver& s, bool learned) { @@ -42,22 +43,22 @@ namespace sat { m_roots[v.index()] = false; edges.push_back(v); } -#if 0 - if (w.is_ext_constraint() && + if (m_include_cardinality && + w.is_ext_constraint() && s.m_ext && - learned && + learned && // cannot (yet) observe if ext constraints are learned !seen_idx.contains(w.get_ext_constraint_idx()) && s.m_ext->is_extended_binary(w.get_ext_constraint_idx(), r)) { seen_idx.insert(w.get_ext_constraint_idx(), true); - for (unsigned i = 0; i < r.size(); ++i) { - literal u = r[i]; - for (unsigned j = i + 1; j < r.size(); ++j) { - // add ~r[i] -> r[j] - literal v = r[j]; - literal u = ~r[j]; + for (unsigned i = 0; i < std::min(4u, r.size()); ++i) { + shuffle(r.size(), r.c_ptr(), m_rand); + literal u = r[0]; + for (unsigned j = 1; j < r.size(); ++j) { + literal v = ~r[j]; + // add u -> v m_roots[v.index()] = false; m_dag[u.index()].push_back(v); - // add ~r[j] -> r[i] + // add ~v -> ~u v.neg(); u.neg(); m_roots[u.index()] = false; @@ -65,7 +66,6 @@ namespace sat { } } } -#endif } } done_adding_edges(); @@ -268,6 +268,16 @@ namespace sat { return out << v; } + literal big::get_root(literal l) { + literal r = l; + do { + l = r; + r = m_root[l.index()]; + } + while (r != l); + return r; + } + void big::display(std::ostream& out) const { unsigned idx = 0; for (auto& next : m_dag) { diff --git a/src/sat/sat_big.h b/src/sat/sat_big.h index 898ddd1e8..25093fd60 100644 --- a/src/sat/sat_big.h +++ b/src/sat/sat_big.h @@ -34,6 +34,7 @@ namespace sat { svector m_left, m_right; literal_vector m_root, m_parent; bool m_learned; + bool m_include_cardinality; svector> m_del_bin; @@ -54,6 +55,9 @@ namespace sat { // static svector> s_del_bin; big(random_gen& rand); + + void set_include_cardinality(bool f) { m_include_cardinality = f; } + /** \brief initialize a BIG from a solver. */ @@ -77,7 +81,7 @@ namespace sat { int get_left(literal l) const { return m_left[l.index()]; } int get_right(literal l) const { return m_right[l.index()]; } literal get_parent(literal l) const { return m_parent[l.index()]; } - literal get_root(literal l) const { return m_root[l.index()]; } + literal get_root(literal l); bool reaches(literal u, literal v) const { return m_left[u.index()] < m_left[v.index()] && m_right[v.index()] < m_right[u.index()]; } bool connected(literal u, literal v) const { return reaches(u, v) || reaches(~v, ~u); } void display(std::ostream& out) const; diff --git a/src/sat/sat_scc.h b/src/sat/sat_scc.h index 146bd2366..1ba646992 100644 --- a/src/sat/sat_scc.h +++ b/src/sat/sat_scc.h @@ -60,7 +60,6 @@ namespace sat { void ensure_big(bool learned) { m_big.ensure_big(m_solver, learned); } int get_left(literal l) const { return m_big.get_left(l); } int get_right(literal l) const { return m_big.get_right(l); } - literal get_root(literal l) const { return m_big.get_root(l); } bool connected(literal u, literal v) const { return m_big.connected(u, v); } }; };