From 53c38f02d515661155851640df7fb1a9d0d013c0 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 6 Jul 2017 22:12:18 -0700 Subject: [PATCH] n/a Signed-off-by: Nikolaj Bjorner --- src/sat/ba_solver.cpp | 448 +++++++++++++++++++++++++------------- src/sat/ba_solver.h | 31 ++- src/sat/sat_extension.h | 7 + src/sat/sat_lookahead.cpp | 18 +- src/sat/sat_lookahead.h | 5 +- 5 files changed, 340 insertions(+), 169 deletions(-) diff --git a/src/sat/ba_solver.cpp b/src/sat/ba_solver.cpp index 243fe9b10..eab579ff1 100644 --- a/src/sat/ba_solver.cpp +++ b/src/sat/ba_solver.cpp @@ -94,6 +94,7 @@ namespace sat { bool ba_solver::pb_base::well_formed() const { uint_set vars; + if (lit() != null_literal) vars.insert(lit().var()); for (unsigned i = 0; i < size(); ++i) { bool_var v = get_lit(i).var(); if (vars.contains(v)) return false; @@ -149,6 +150,7 @@ namespace sat { for (unsigned i = 0; i < size(); ++i) { m_wlits[i].first = std::min(k(), m_wlits[i].first); if (m_max_sum + m_wlits[i].first < m_max_sum) { + std::cout << "update-max-sum overflows\n"; throw default_exception("addition of pb coefficients overflows"); } m_max_sum += m_wlits[i].first; @@ -200,10 +202,11 @@ namespace sat { bool ba_solver::xor::well_formed() const { uint_set vars; + if (lit() != null_literal) vars.insert(lit().var()); for (literal l : *this) { bool_var v = l.var(); if (vars.contains(v)) return false; - vars.insert(v); + vars.insert(v); } return true; } @@ -326,6 +329,8 @@ namespace sat { // pb static unsigned _bad_id = 11111111; // 2759; // +#define BADLOG(_cmd_) if (p.id() == _bad_id) { _cmd_; } + // watch a prefix of literals, such that the slack of these is >= k bool ba_solver::init_watch(pb& p, bool is_true) { @@ -354,9 +359,8 @@ namespace sat { ++j; } } - if (p.id() == _bad_id) { - std::cout << "watch " << num_watch << " out of " << sz << "\n"; - } + BADLOG(std::cout << "watch " << num_watch << " out of " << sz << "\n"); + DEBUG_CODE( bool is_false = false; for (unsigned k = 0; k < sz; ++k) { @@ -383,14 +387,14 @@ namespace sat { p.set_slack(slack); p.set_num_watch(num_watch); - validate_watch(p); + SASSERT(validate_watch(p)); TRACE("sat", display(tout << "init watch: ", p, true);); // slack is tight: if (slack + slack1 == bound) { SASSERT(slack1 == 0); - SASSERT(j == num_watch); + SASSERT(j == num_watch); for (unsigned i = 0; i < j; ++i) { assign(p, p[i].second); } @@ -412,11 +416,11 @@ namespace sat { Sw += a_s Lw = Lw u {l_s} Lu = Lu \ {l_s} - } - if (Sw < k) return conflict - for (li in Lw | Sw < k + ai) + } + if (Sw < k) return conflict + for (li in Lw | Sw < k + ai) assign li - return no-conflict + return no-conflict a_max index: index of non-false literal with maximal weight. */ @@ -428,10 +432,7 @@ namespace sat { m_a_max = p[index].first; } } - } - - -#define BADLOG(_cmd_) if (p.id() == _bad_id) { _cmd_; } + } /* \brief propagate assignment to alit in constraint p. @@ -465,7 +466,7 @@ namespace sat { } add_index(p, index, lit); } - if (index == num_watch) { + if (index == num_watch || num_watch == 0) { _bad_id = p.id(); std::cout << "BAD: " << p.id() << "\n"; display(std::cout, p, true); @@ -493,9 +494,7 @@ namespace sat { literal lit = p[j].second; if (value(lit) != l_false) { slack += p[j].first; - if (is_watched(p[j].second, p)) { - std::cout << "Swap literal already watched: " << p[j].second << "\n"; - } + SASSERT(!is_watched(p[j].second, p)); watch_literal(p[j], p); p.swap(num_watch, j); add_index(p, num_watch, lit); @@ -519,11 +518,10 @@ namespace sat { set_conflict(p, alit); return l_false; } - - if (index > p.size() || num_watch > p.size() || num_watch == 0 || p.id() == _bad_id) { - display(std::cout, p, true); - std::cout << "size: " << p.size() << " index: " << index << " num watch: " << num_watch << "\n"; - } + + if (num_watch == 1) { _bad_id = p.id(); } + + BADLOG(std::cout << "size: " << p.size() << " index: " << index << " num watch: " << num_watch << "\n"); // swap out the watched literal. --num_watch; @@ -532,6 +530,7 @@ namespace sat { p.set_num_watch(num_watch); p.swap(num_watch, index); + // // slack >= bound, but slack - w(l) < bound // l must be true. @@ -564,12 +563,10 @@ namespace sat { } void ba_solver::clear_watch(pb& p) { - validate_watch(p); for (unsigned i = 0; i < p.num_watch(); ++i) { unwatch_literal(p[i].second, p); } p.set_num_watch(0); - validate_watch(p); } /* @@ -687,10 +684,10 @@ namespace sat { --i; } } - if (p.id() == _bad_id) display(std::cout << "simplify ", p, true); + BADLOG(display(std::cout << "simplify ", p, true)); p.set_size(sz); p.set_k(p.k() - true_val); - if (p.id() == _bad_id) display(std::cout << "simplified ", p, true); + BADLOG(display(std::cout << "simplified ", p, true)); // display(verbose_stream(), c, true); if (p.k() == 1 && p.lit() == null_literal) { @@ -876,45 +873,69 @@ namespace sat { m_active_vars.shrink(j); } - void ba_solver::inc_coeff(literal l, int offset) { + void ba_solver::inc_coeff(literal l, int64 offset) { SASSERT(offset > 0); bool_var v = l.var(); SASSERT(v != null_bool_var); if (static_cast(m_coeffs.size()) <= v) { m_coeffs.resize(v + 1, 0); } - int coeff0 = m_coeffs[v]; + int64 coeff0 = m_coeffs[v]; if (coeff0 == 0) { m_active_vars.push_back(v); } - int inc = l.sign() ? -offset : offset; - int coeff1 = inc + coeff0; + int64 inc = l.sign() ? -offset : offset; + int64 coeff1 = inc + coeff0; m_coeffs[v] = coeff1; + if (coeff1 > INT_MAX || coeff1 < INT_MIN) { + std::cout << "overflow update coefficient " << coeff1 << "\n"; + m_overflow = true; + return; + } if (coeff0 > 0 && inc < 0) { - m_bound -= coeff0 - std::max(0, coeff1); + m_bound -= coeff0 - std::max(0LL, coeff1); } else if (coeff0 < 0 && inc > 0) { - m_bound -= std::min(0, coeff1) - coeff0; + m_bound -= std::min(0LL, coeff1) - coeff0; } // reduce coefficient to be no larger than bound. if (coeff1 > m_bound) { m_coeffs[v] = m_bound; } else if (coeff1 < 0 && -coeff1 > m_bound) { - m_coeffs[v] = -m_bound; + m_coeffs[v] = m_bound; } } - int ba_solver::get_coeff(bool_var v) const { + int64 ba_solver::get_coeff(bool_var v) const { return m_coeffs.get(v, 0); } - int ba_solver::get_abs_coeff(bool_var v) const { + int64 ba_solver::get_abs_coeff(bool_var v) const { return abs(get_coeff(v)); } + int ba_solver::get_int_coeff(bool_var v) const { + int64 c = m_coeffs.get(v, 0); + if (c < INT_MIN || c > INT_MAX) { + std::cout << "overflow " << c << "\n"; + m_overflow = true; + return 0; + } + return static_cast(c); + } + + unsigned ba_solver::get_bound() const { + if (m_bound < 0 || m_bound > UINT_MAX) { + std::cout << "overflow bound " << m_bound << "\n"; + m_overflow = true; + return 1; + } + return static_cast(m_bound); + } + void ba_solver::reset_coeffs() { for (unsigned i = 0; i < m_active_vars.size(); ++i) { m_coeffs[m_active_vars[i]] = 0; @@ -923,11 +944,14 @@ namespace sat { } static bool _debug_conflict = false; + static literal _debug_consequent = null_literal; + static unsigned_vector _debug_var2position; lbool ba_solver::resolve_conflict() { if (0 == m_num_propagations_since_pop) { return l_undef; } + m_overflow = false; reset_coeffs(); m_num_marks = 0; m_bound = 0; @@ -941,25 +965,25 @@ namespace sat { } literal_vector const& lits = s().m_trail; unsigned idx = lits.size() - 1; - int offset = 1; + int64 offset = 1; DEBUG_CODE(active2pb(m_A);); unsigned init_marks = m_num_marks; do { - if (offset == 0) { - goto process_next_resolvent; - } - // TBD: need proper check for overflow. - if (offset > (1 << 12)) { - IF_VERBOSE(12, verbose_stream() << "offset: " << offset << "\n"; + if (m_overflow || offset > (1 << 12)) { + IF_VERBOSE(20, verbose_stream() << "offset: " << offset << "\n"; active2pb(m_A); display(verbose_stream(), m_A); ); goto bail_out; } + if (offset == 0) { + goto process_next_resolvent; + } + TRACE("sat_verbose", display(tout, m_A);); TRACE("sat", tout << "process consequent: " << consequent << ":\n"; s().display_justification(tout, js) << "\n";); SASSERT(offset > 0); @@ -971,6 +995,7 @@ namespace sat { std::cout << consequent << "\n"; s().display_justification(std::cout, js); std::cout << "\n"; + _debug_consequent = consequent; } switch(js.get_kind()) { case justification::NONE: @@ -1056,7 +1081,6 @@ namespace sat { SASSERT(validate_lemma()); - DEBUG_CODE( active2pb(m_C); //SASSERT(validate_resolvent()); @@ -1096,7 +1120,6 @@ namespace sat { DEBUG_CODE(active2pb(m_A);); } SASSERT(value(consequent) == l_true); - } while (m_num_marks > 0); @@ -1111,6 +1134,10 @@ namespace sat { active2card(); + if (m_overflow) { + goto bail_out; + } + SASSERT(validate_conflict(m_lemma, m_A)); TRACE("sat", tout << m_lemma << "\n";); @@ -1130,6 +1157,9 @@ namespace sat { return l_true; bail_out: + + m_overflow = false; + while (m_num_marks > 0 && idx >= 0) { bool_var v = lits[idx].var(); if (s().is_marked(v)) { @@ -1138,11 +1168,20 @@ namespace sat { } if (idx == 0 && !_debug_conflict) { _debug_conflict = true; + _debug_var2position.reserve(s().num_vars()); + for (unsigned i = 0; i < lits.size(); ++i) { + _debug_var2position[lits[i].var()] = i; + } // s().display(std::cout); - std::cout << s().m_not_l << "\n"; + active2pb(m_A); + uint64 c = 0; + for (uint64 c1 : m_A.m_coeffs) c += c1; + std::cout << "sum of coefficients: " << c << "\n"; + display(std::cout, m_A, true); + std::cout << "conflicting literal: " << s().m_not_l << "\n"; for (literal l : lits) { if (s().is_marked(l.var())) { - std::cout << "missing mark: " << l << "\n"; + IF_VERBOSE(0, verbose_stream() << "missing mark: " << l << "\n";); s().reset_mark(l.var()); } } @@ -1158,7 +1197,7 @@ namespace sat { adjust_conflict_level: - int slack = -m_bound; + int64 slack = -m_bound; for (bool_var v : m_active_vars) { slack += get_abs_coeff(v); } @@ -1166,10 +1205,10 @@ namespace sat { m_lemma.reset(); m_lemma.push_back(null_literal); unsigned num_skipped = 0; - int asserting_coeff = 0; - for (unsigned i = 0; /* 0 <= slack && */ i < m_active_vars.size(); ++i) { + int64 asserting_coeff = 0; + for (unsigned i = 0; 0 <= slack && i < m_active_vars.size(); ++i) { bool_var v = m_active_vars[i]; - int coeff = get_coeff(v); + int64 coeff = get_coeff(v); lbool val = value(v); bool is_true = val == l_true; bool append = coeff != 0 && val != l_undef && (coeff < 0 == is_true); @@ -1216,38 +1255,30 @@ namespace sat { IF_VERBOSE(10, verbose_stream() << "(sat.backjump :new-level " << m_conflict_lvl << " :old-level " << old_level << ")\n";); goto adjust_conflict_level; } - - // slack is related to coefficients of m_lemma - // so does not apply to unit coefficients. - // std::cout << "lemma: " << m_lemma << " >= " << 1 << "\n"; - // active2pb(m_A); - // display(std::cout, m_A, true); -#if 0 - constraint* c = active2constraint(); - if (c) { - display(std::cout, *c, true); - std::cout << "Eval: " << eval(*c) << "\n"; - } -#endif return true; } + /* + \brief compute a cut for current resolvent. + */ + void ba_solver::cut() { - unsigned g = 0; - int sum_of_units = 0; + + // bypass cut if there is a unit coefficient for (bool_var v : m_active_vars) { - if (1 == get_abs_coeff(v) && ++sum_of_units >= m_bound) return; + if (1 == get_abs_coeff(v)) return; } - //active2pb(m_A); - //display(std::cout << "units can be removed\n", m_A, true); + + SASSERT(0 <= m_bound && m_bound <= UINT_MAX); + + unsigned g = 0; for (unsigned i = 0; g != 1 && i < m_active_vars.size(); ++i) { bool_var v = m_active_vars[i]; - int coeff = get_abs_coeff(v); + int64 coeff = get_abs_coeff(v); if (coeff == 0) { continue; } - if (coeff == 1) return; if (m_bound < coeff) { if (get_coeff(v) > 0) { m_coeffs[v] = m_bound; @@ -1265,23 +1296,18 @@ namespace sat { g = u_gcd(g, static_cast(coeff)); } } + if (g >= 2) { - active2pb(m_A); - //display(std::cout, m_A, true); normalize_active_coeffs(); - int ig = static_cast(g); - for (unsigned i = 0; i < m_active_vars.size(); ++i) { - m_coeffs[m_active_vars[i]] /= ig; + for (bool_var v : m_active_vars) { + m_coeffs[v] /= static_cast(g); } m_bound = (m_bound + g - 1) / g; ++m_stats.m_num_cut; - //std::cout << "CUT " << g << "\n"; - //active2pb(m_A); - //display(std::cout, m_A, true); } } - void ba_solver::process_card(card& c, int offset) { + void ba_solver::process_card(card& c, int64 offset) { literal lit = c.lit(); SASSERT(c.k() <= c.size()); SASSERT(lit == null_literal || value(lit) == l_true); @@ -1297,7 +1323,7 @@ namespace sat { } } - void ba_solver::process_antecedent(literal l, int offset) { + void ba_solver::process_antecedent(literal l, int64 offset) { SASSERT(value(l) == l_false); bool_var v = l.var(); unsigned level = lvl(v); @@ -1306,6 +1332,9 @@ namespace sat { s().mark(v); TRACE("sat", tout << "Mark: v" << v << "\n";); ++m_num_marks; + if (_debug_conflict && _debug_consequent != null_literal && _debug_var2position[_debug_consequent.var()] < _debug_var2position[l.var()]) { + std::cout << "antecedent " << l << " is above consequent in stack\n"; + } } inc_coeff(l, offset); } @@ -1359,9 +1388,6 @@ namespace sat { } void ba_solver::add_constraint(constraint* c) { - if (c->id() == _bad_id) { - display(std::cout, *c, true); - } if (c->learned()) { m_learned.push_back(c); } @@ -1463,6 +1489,50 @@ namespace sat { } } + double ba_solver::get_reward(card const& c, literal_occs_fun& literal_occs) const { + unsigned k = c.k(), slack = 0; + double to_add = 0; + for (literal l : c) { + switch (value(l)) { + case l_true: --k; if (k == 0) return 0; break; + case l_undef: to_add += literal_occs(l); ++slack; break; + case l_false: break; + } + } + if (k >= slack) return 1; + return pow(0.5, slack - k + 1) * to_add; + } + + double ba_solver::get_reward(pb const& c, literal_occs_fun& occs) const { + unsigned k = c.k(), slack = 0; + double to_add = 0; + double undefs = 0; + for (wliteral wl : c) { + literal l = wl.second; + unsigned w = wl.first; + switch (value(l)) { + case l_true: if (k <= w) return 0; k -= w; break; + case l_undef: to_add += occs(l); ++undefs; slack += w; break; // TBD multiplier factor on this + case l_false: break; + } + } + if (k >= slack || 0 == undefs) return 0; + double avg = slack / undefs; + return pow(0.5, (slack - k + 1)/avg) * to_add; + } + + double ba_solver::get_reward(literal l, ext_justification_idx idx, literal_occs_fun& occs) const { + constraint const& c = index2constraint(idx); + unsigned sz = c.size(); + switch (c.tag()) { + case card_t: return get_reward(c.to_card(), occs); + case pb_t: return get_reward(c.to_pb(), occs); + case xor_t: return 0; + default: UNREACHABLE(); return 0; + } + } + + void ba_solver::ensure_parity_size(bool_var v) { if (m_parity_marks.size() <= static_cast(v)) { @@ -1585,6 +1655,11 @@ namespace sat { unsigned k = p.k(); + if (_debug_conflict) { + display(std::cout, p, true); + std::cout << "literal: " << l << " value: " << value(l) << " num-watch: " << p.num_watch() << " slack: " << p.slack() << "\n"; + } + if (value(l) == l_false) { // The literal comes from a conflict. // it is forced true, but assigned to false. @@ -1619,6 +1694,10 @@ namespace sat { CTRACE("sat", coeff == 0, display(tout << l << " coeff: " << coeff << "\n", p, true);); + if (_debug_conflict) { + std::cout << "coeff " << coeff << "\n"; + } + SASSERT(coeff > 0); unsigned slack = p.slack() - coeff; @@ -1688,20 +1767,10 @@ namespace sat { } void ba_solver::unwatch_literal(literal lit, constraint& c) { - if (c.id() == _bad_id) { std::cout << "unwatch " << lit << "\n"; } get_wlist(~lit).erase(watched(c.index())); - if (is_watched(lit, c)) { - std::cout << "Not deleted " << lit << "\n"; - } } void ba_solver::watch_literal(literal lit, constraint& c) { - if (is_watched(lit, c)) { - std::cout << "Already watched " << lit << "\n"; - UNREACHABLE(); - exit(0); - } - if (c.id() == _bad_id) { std::cout << "watch " << lit << "\n"; } get_wlist(~lit).push_back(watched(c.index())); } @@ -1867,9 +1936,7 @@ namespace sat { if (c.lit() != null_literal && value(c.lit()) != l_true) return true; if (c.lit() != null_literal && lvl(c.lit()) != 0) { if (!is_watched(c.lit(), c) || !is_watched(~c.lit(), c)) { - std::cout << "Definition literal is not watched " << c.lit() << " " << c << "\n"; - display_watch_list(std::cout, s().m_cls_allocator, get_wlist(c.lit())) << "\n"; - display_watch_list(std::cout, s().m_cls_allocator, get_wlist(~c.lit())) << "\n"; + UNREACHABLE(); return false; } } @@ -1882,12 +1949,15 @@ namespace sat { bool found = is_watched(l, c); if (found != c.is_watching(l)) { - std::cout << "Discrepancy of watched literal: " << l << " id: " << c.id() << " clause: " << c << (found?" is watched, but shouldn't be":" not watched, but should be") << "\n"; - display_watch_list(std::cout << l << ": ", s().m_cls_allocator, get_wlist(l)) << "\n"; - display_watch_list(std::cout << ~l << ": ", s().m_cls_allocator, get_wlist(~l)) << "\n"; - std::cout << "value: " << value(l) << " level: " << lvl(l) << "\n"; - display(std::cout, c, true); - if (c.lit() != null_literal) std::cout << value(c.lit()) << "\n"; + IF_VERBOSE(0, + verbose_stream() << "Discrepancy of watched literal: " << l << " id: " << c.id() + << " clause: " << c << (found?" is watched, but shouldn't be":" not watched, but should be") << "\n"; + display_watch_list(verbose_stream() << l << ": ", s().m_cls_allocator, get_wlist(l)) << "\n"; + display_watch_list(verbose_stream() << ~l << ": ", s().m_cls_allocator, get_wlist(~l)) << "\n"; + verbose_stream() << "value: " << value(l) << " level: " << lvl(l) << "\n"; + display(verbose_stream(), c, true); + if (c.lit() != null_literal) verbose_stream() << value(c.lit()) << "\n";); + UNREACHABLE(); exit(1); return false; @@ -1900,8 +1970,6 @@ namespace sat { for (unsigned i = 0; i < p.size(); ++i) { literal l = p[i].second; if (lvl(l) != 0 && is_watched(l, p) != i < p.num_watch()) { - std::cout << "DISCREPANCY: " << l << " at " << i << " for " << p.num_watch() << " index: " << p.id() << "\n"; - display(std::cout, p, true); UNREACHABLE(); return false; } @@ -2067,13 +2135,14 @@ namespace sat { m_simplify_change = false; m_clause_removed = false; m_constraint_removed = false; - for (constraint* c : m_constraints) simplify(*c); - for (constraint* c : m_learned) simplify(*c); + for (unsigned sz = m_constraints.size(), i = 0; i < sz; ++i) simplify(*m_constraints[i]); + for (unsigned sz = m_learned.size(), i = 0; i < sz; ++i) simplify(*m_learned[i]); init_use_lists(); remove_unused_defs(); set_non_external(); elim_pure(); - subsumption(); + for (unsigned sz = m_constraints.size(), i = 0; i < sz; ++i) subsumption(*m_constraints[i]); + for (unsigned sz = m_learned.size(), i = 0; i < sz; ++i) subsumption(*m_learned[i]); cleanup_clauses(); cleanup_constraints(); } @@ -2242,12 +2311,10 @@ namespace sat { m_visited.resize(s().num_vars()*2, false); m_constraint_removed = false; - for (constraint* c : m_constraints) { - flush_roots(*c); - } - for (constraint* c : m_learned) { - flush_roots(*c); - } + for (unsigned sz = m_constraints.size(), i = 0; i < sz; ++i) + flush_roots(*m_constraints[i]); + for (unsigned sz = m_learned.size(), i = 0; i < sz; ++i) + flush_roots(*m_learned[i]); cleanup_constraints(); // validate(); @@ -2342,7 +2409,7 @@ namespace sat { } literal root = c.lit(); remove_constraint(c); - constraint* p = add_pb_ge(root, wlits, k, c.learned()); + add_pb_ge(root, wlits, k, c.learned()); } else { if (c.lit() == null_literal || value(c.lit()) == l_true) { @@ -2644,11 +2711,29 @@ namespace sat { unsigned ext = 0; for (unsigned v = 0; v < s().num_vars(); ++v) { literal lit(v, false); - if (s().is_external(v) && m_cnstr_use_list[lit.index()].size() == 0 && m_cnstr_use_list[(~lit).index()].size() == 0 && !s().is_assumption(v)) { + if (s().is_external(v) && + m_cnstr_use_list[lit.index()].size() == 0 && + m_cnstr_use_list[(~lit).index()].size() == 0 && !s().is_assumption(v)) { s().set_non_external(v); ++ext; } } + // ensure that lemmas use only external variables. + for (constraint* cp : m_learned) { + constraint& c = *cp; + if (c.was_removed()) continue; + SASSERT(c.lit() == null_literal); + for (unsigned i = 0; i < c.size(); ++i) { + bool_var v = c.get_lit(i).var(); + if (s().was_eliminated(v)) { + remove_constraint(c); + break; + } + if (!s().is_external(v)) { + s().set_external(v); + } + } + } IF_VERBOSE(10, verbose_stream() << "non-external variables converted: " << ext << "\n";); return ext; } @@ -2680,11 +2765,6 @@ namespace sat { return pure_literals; } - void ba_solver::subsumption() { - for (constraint* c : m_constraints) subsumption(*c); - for (constraint* c : m_learned) subsumption(*c); - } - void ba_solver::subsumption(constraint& cnstr) { if (cnstr.was_removed()) return; switch (cnstr.tag()) { @@ -2693,6 +2773,11 @@ namespace sat { if (c.k() > 1) subsumption(c); break; } + case pb_t: { + pb& p = cnstr.to_pb(); + if (p.k() > 1) subsumption(p); + break; + } default: break; } @@ -2805,6 +2890,45 @@ namespace sat { return c1_exclusive + 1 <= c1.k(); } + /* + \brief Ax >= k subsumes By >= k' if + all coefficients in A are <= B and k >= k' + */ + bool ba_solver::subsumes(pb const& p1, pb_base const& p2) { + if (p1.k() < p2.k()) return false; + unsigned num_marked = 0; + for (unsigned i = 0; i < p2.size(); ++i) { + literal l = p2.get_lit(i); + if (is_marked(l)) { + ++num_marked; + if (m_weights[l.index()] > p2.get_coeff(i)) return false; + } + } + return num_marked == p1.size(); + } + + void ba_solver::subsumes(pb& p1, literal lit) { + for (constraint* c : m_cnstr_use_list[lit.index()]) { + if (c == &p1 || c->was_removed()) continue; + bool s = false; + switch (c->tag()) { + case card_t: + s = subsumes(p1, c->to_card()); + break; + case pb_t: + s = subsumes(p1, c->to_pb()); + break; + default: + break; + } + if (s) { + ++m_stats.m_num_card_subsumes; + p1.set_learned(false); + remove_constraint(*c); + } + } + } + literal ba_solver::get_min_occurrence_literal(card const& c) { unsigned occ_count = UINT_MAX; literal lit = null_literal; @@ -2821,7 +2945,7 @@ namespace sat { void ba_solver::card_subsumption(card& c1, literal lit) { literal_vector slit; for (constraint* c : m_cnstr_use_list[lit.index()]) { - if (!c || c->tag() != card_t || c == &c1 || c->was_removed()) { + if (!c->is_card() || c == &c1 || c->was_removed()) { continue; } card& c2 = c->to_card(); @@ -2910,8 +3034,7 @@ namespace sat { } void ba_solver::subsumption(card& c1) { - SASSERT(!c1.was_removed()); - if (c1.lit() != null_literal) { + if (c1.was_removed() || c1.lit() != null_literal) { return; } clause_vector removed_clauses; @@ -2930,6 +3053,24 @@ namespace sat { } } + void ba_solver::subsumption(pb& p1) { + if (p1.was_removed() || p1.lit() != null_literal) { + return; + } + for (wliteral l : p1) { + mark_visited(l.second); + SASSERT(m_weights[l.second.index()] == 0); + m_weights[l.second.index()] = l.first; + } + for (unsigned i = 0; i < p1.num_watch(); ++i) { + subsumes(p1, p1[i].second); + } + for (wliteral l : p1) { + unmark_visited(l.second); + m_weights[l.second.index()] = 0; + } + } + void ba_solver::clauses_modifed() {} lbool ba_solver::get_phase(bool_var v) { return l_undef; } @@ -3145,11 +3286,11 @@ namespace sat { } bool ba_solver::validate_lemma() { - int val = -m_bound; + int64 val = -m_bound; reset_active_var_set(); for (bool_var v : m_active_vars) { if (m_active_var_set.contains(v)) continue; - int coeff = get_coeff(v); + int64 coeff = get_coeff(v); if (coeff == 0) continue; m_active_var_set.insert(v); literal lit(v, false); @@ -3173,7 +3314,7 @@ namespace sat { p.reset(m_bound); for (bool_var v : m_active_vars) { if (m_active_var_set.contains(v)) continue; - int coeff = get_coeff(v); + int64 coeff = get_coeff(v); if (coeff == 0) continue; m_active_var_set.insert(v); literal lit(v, coeff < 0); @@ -3184,32 +3325,27 @@ namespace sat { ba_solver::constraint* ba_solver::active2constraint() { reset_active_var_set(); - literal_vector lits; - unsigned_vector coeffs; - bool all_one = true; + svector wlits; uint64_t sum = 0; if (m_bound == 1) return 0; + if (m_overflow) return 0; + for (bool_var v : m_active_vars) { - int coeff = get_coeff(v); + int coeff = get_int_coeff(v); if (m_active_var_set.contains(v) || coeff == 0) continue; m_active_var_set.insert(v); literal lit(v, coeff < 0); - lits.push_back(lit); - coeffs.push_back(abs(coeff)); - all_one &= abs(coeff) == 1; + wlits.push_back(wliteral(abs(coeff), lit)); sum += abs(coeff); } - if (sum >= UINT_MAX/2) return 0; - if (all_one) { - return add_at_least(null_literal, lits, m_bound, true); + unsigned k = get_bound(); + + if (m_overflow || sum >= UINT_MAX/2) { + return 0; } else { - svector wlits; - for (unsigned i = 0; i < lits.size(); ++i) { - wlits.push_back(wliteral(coeffs[i], lits[i])); - } - return add_pb_ge(null_literal, wlits, m_bound, true); - } + return add_pb_ge(null_literal, wlits, k, true); + } } /* @@ -3247,7 +3383,7 @@ namespace sat { normalize_active_coeffs(); svector wlits; for (bool_var v : m_active_vars) { - int coeff = get_coeff(v); + int coeff = get_int_coeff(v); wlits.push_back(std::make_pair(abs(coeff), literal(v, coeff < 0))); } std::sort(wlits.begin(), wlits.end(), compare_wlit()); @@ -3264,7 +3400,7 @@ namespace sat { } while (!wlits.empty()) { wliteral wl = wlits.back(); - if (wl.first + sum0 >= static_cast(m_bound)) break; + if (wl.first + sum0 >= get_bound()) break; wlits.pop_back(); sum0 += wl.first; } @@ -3283,9 +3419,15 @@ namespace sat { ++num_max_level; } } - + if (m_overflow) return 0; if (slack >= k) { +#if 0 + return active2constraint(); + active2pb(m_A); + std::cout << "not asserting\n"; + display(std::cout, m_A, true); +#endif return 0; } @@ -3374,15 +3516,15 @@ namespace sat { // validate that m_A & m_B implies m_C bool ba_solver::validate_resolvent() { - u_map coeffs; - unsigned k = m_A.m_k + m_B.m_k; + u_map coeffs; + uint64 k = m_A.m_k + m_B.m_k; for (unsigned i = 0; i < m_A.m_lits.size(); ++i) { - unsigned coeff = m_A.m_coeffs[i]; + uint64 coeff = m_A.m_coeffs[i]; SASSERT(!coeffs.contains(m_A.m_lits[i].index())); coeffs.insert(m_A.m_lits[i].index(), coeff); } for (unsigned i = 0; i < m_B.m_lits.size(); ++i) { - unsigned coeff1 = m_B.m_coeffs[i], coeff2; + uint64 coeff1 = m_B.m_coeffs[i], coeff2; literal lit = m_B.m_lits[i]; if (coeffs.find((~lit).index(), coeff2)) { if (coeff1 == coeff2) { @@ -3410,7 +3552,7 @@ namespace sat { // C is above the sum of A and B for (unsigned i = 0; i < m_C.m_lits.size(); ++i) { literal lit = m_C.m_lits[i]; - unsigned coeff; + uint64 coeff; if (coeffs.find(lit.index(), coeff)) { if (coeff > m_C.m_coeffs[i] && m_C.m_coeffs[i] < m_C.m_k) { IF_VERBOSE(0, verbose_stream() << i << ": " << m_C.m_coeffs[i] << " " << m_C.m_k << "\n";); @@ -3444,9 +3586,9 @@ namespace sat { return false; } } - unsigned value = 0; + uint64 value = 0; for (unsigned i = 0; i < p.m_lits.size(); ++i) { - unsigned coeff = p.m_coeffs[i]; + uint64 coeff = p.m_coeffs[i]; if (!lits.contains(p.m_lits[i])) { value += coeff; } diff --git a/src/sat/ba_solver.h b/src/sat/ba_solver.h index 115b57fed..c30fd25b4 100644 --- a/src/sat/ba_solver.h +++ b/src/sat/ba_solver.h @@ -191,9 +191,9 @@ namespace sat { struct ineq { literal_vector m_lits; - unsigned_vector m_coeffs; - unsigned m_k; - void reset(unsigned k) { m_lits.reset(); m_coeffs.reset(); m_k = k; } + svector m_coeffs; + uint64 m_k; + void reset(uint64 k) { m_lits.reset(); m_coeffs.reset(); m_k = k; } void push(literal l, unsigned c) { m_lits.push_back(l); m_coeffs.push_back(c); } }; @@ -213,9 +213,9 @@ namespace sat { // conflict resolution unsigned m_num_marks; unsigned m_conflict_lvl; - svector m_coeffs; + svector m_coeffs; svector m_active_vars; - int m_bound; + int64 m_bound; tracked_uint_set m_active_var_set; literal_vector m_lemma; literal_vector m_skipped; @@ -246,6 +246,9 @@ namespace sat { bool subsumes(card& c1, card& c2, literal_vector& comp); bool subsumes(card& c1, clause& c2, literal_vector& comp); bool subsumed(card& c1, literal l1, literal l2); + bool subsumes(pb const& p1, pb_base const& p2); + void subsumes(pb& p1, literal lit); + void subsumption(pb& p1); void binary_subsumption(card& c1, literal lit); void clause_subsumption(card& c1, literal lit, clause_vector& removed_clauses); void card_subsumption(card& c1, literal lit); @@ -259,7 +262,6 @@ namespace sat { unsigned set_non_external(); unsigned elim_pure(); bool elim_pure(literal lit); - void subsumption(); void subsumption(constraint& c1); void subsumption(card& c1); void gc_half(char const* _method); @@ -317,6 +319,8 @@ namespace sat { void flush_roots(card& c); void recompile(card& c); lbool eval(card const& c) const; + double get_reward(card const& c, literal_occs_fun& occs) const; + // xor specific functionality void clear_watch(xor& x); @@ -343,6 +347,7 @@ namespace sat { void flush_roots(pb& p); void recompile(pb& p); lbool eval(pb const& p) const; + double get_reward(pb const& p, literal_occs_fun& occs) const; // access solver inline lbool value(bool_var v) const { return value(literal(v, false)); } @@ -358,15 +363,18 @@ namespace sat { inline void drat_add(literal_vector const& c, svector const& premises) { m_solver->m_drat.add(c, premises); } + mutable bool m_overflow; void reset_active_var_set(); void normalize_active_coeffs(); - void inc_coeff(literal l, int offset); - int get_coeff(bool_var v) const; - int get_abs_coeff(bool_var v) const; + void inc_coeff(literal l, int64 offset); + int64 get_coeff(bool_var v) const; + int64 get_abs_coeff(bool_var v) const; + int get_int_coeff(bool_var v) const; + unsigned get_bound() const; literal get_asserting_literal(literal conseq); - void process_antecedent(literal l, int offset); - void process_card(card& c, int offset); + void process_antecedent(literal l, int64 offset); + void process_card(card& c, int64 offset); void cut(); bool create_asserting_lemma(); @@ -432,6 +440,7 @@ namespace sat { virtual void find_mutexes(literal_vector& lits, vector & mutexes); virtual void pop_reinit(); virtual void gc(); + virtual double get_reward(literal l, ext_justification_idx idx, literal_occs_fun& occs) const; ptr_vector const & constraints() const { return m_constraints; } diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index a61c330c7..87f1904ed 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -29,12 +29,19 @@ namespace sat { CR_DONE, CR_CONTINUE, CR_GIVEUP }; + class literal_occs_fun { + public: + virtual double operator()(literal l) = 0; + + }; + class extension { public: virtual ~extension() {} virtual void set_solver(solver* s) = 0; virtual void set_lookahead(lookahead* s) = 0; virtual bool propagate(literal l, ext_constraint_idx idx) = 0; + virtual double get_reward(literal l, ext_constraint_idx idx, literal_occs_fun& occs) const = 0; virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) = 0; virtual void asserted(literal l) = 0; virtual check_result check() = 0; diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp index 5dadd0b73..b5229dc54 100644 --- a/src/sat/sat_lookahead.cpp +++ b/src/sat/sat_lookahead.cpp @@ -852,6 +852,8 @@ namespace sat { copy_clauses(m_s.m_clauses); copy_clauses(m_s.m_learned); + m_config.m_use_ternary_reward &= !m_s.m_ext; + // copy units unsigned trail_sz = m_s.init_trail_size(); for (unsigned i = 0; i < trail_sz; ++i) { @@ -883,12 +885,10 @@ namespace sat { for (; it != end; ++it) { clause& c = *(*it); if (c.was_removed()) continue; -#if 0 // enable when there is a non-ternary reward system. if (c.size() > 3) { m_config.m_use_ternary_reward = false; } -#endif bool was_eliminated = false; for (unsigned i = 0; !was_eliminated && i < c.size(); ++i) { was_eliminated = m_s.was_eliminated(c[i].var()); @@ -1042,6 +1042,14 @@ namespace sat { // Only the size indicator needs to be updated on backtracking. // + class lookahead_literal_occs_fun : public literal_occs_fun { + lookahead& lh; + public: + lookahead_literal_occs_fun(lookahead& lh): lh(lh) {} + double operator()(literal l) { return lh.literal_occs(l); } + }; + + void lookahead::propagate_clauses(literal l) { SASSERT(is_true(l)); if (inconsistent()) return; @@ -1172,6 +1180,10 @@ namespace sat { case watched::EXT_CONSTRAINT: { SASSERT(m_s.m_ext); bool keep = m_s.m_ext->propagate(l, it->get_ext_constraint_idx()); + if (m_search_mode == lookahead_mode::lookahead1) { + lookahead_literal_occs_fun literal_occs_fn(*this); + m_lookahead_reward += m_s.m_ext->get_reward(l, it->get_ext_constraint_idx(), literal_occs_fn); + } if (m_inconsistent) { if (!keep) ++it; set_conflict(); @@ -1222,7 +1234,7 @@ namespace sat { to_add += literal_occs(l); } } - m_lookahead_reward += pow(sz, -2) * to_add; + m_lookahead_reward += pow(0.5, sz) * to_add; } else { m_lookahead_reward = (double)0.001; diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index 067a95c55..38adc4505 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -402,8 +402,7 @@ namespace sat { literal select_literal(); void update_binary_clause_reward(literal l1, literal l2); void update_nary_clause_reward(clause const& c); - double literal_occs(literal l); - + void set_lookahead_reward(literal l, double f) { m_lits[l.index()].m_lookahead_reward = f; } void inc_lookahead_reward(literal l, double f) { m_lits[l.index()].m_lookahead_reward += f; } double get_lookahead_reward(literal l) const { return m_lits[l.index()].m_lookahead_reward; } @@ -486,6 +485,8 @@ namespace sat { model const& get_model(); void collect_statistics(statistics& st) const; + + double literal_occs(literal l); }; }