From 5262248823e11ffca0fa7971ad9bb61ccb9e1b79 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 4 Jul 2017 11:13:05 -0700 Subject: [PATCH] n/a Signed-off-by: Nikolaj Bjorner --- src/sat/ba_solver.cpp | 665 +++++++++++++++++++++++++------------- src/sat/ba_solver.h | 45 ++- src/sat/sat_extension.h | 5 +- src/sat/sat_lookahead.cpp | 4 +- src/sat/sat_solver.cpp | 22 +- 5 files changed, 501 insertions(+), 240 deletions(-) diff --git a/src/sat/ba_solver.cpp b/src/sat/ba_solver.cpp index ebf6079ea..0d4f5b1fe 100644 --- a/src/sat/ba_solver.cpp +++ b/src/sat/ba_solver.cpp @@ -106,8 +106,8 @@ namespace sat { // ---------------------- // card - ba_solver::card::card(literal lit, literal_vector const& lits, unsigned k): - pb_base(card_t, lit, lits.size(), get_obj_size(lits.size()), k) { + ba_solver::card::card(unsigned id, literal lit, literal_vector const& lits, unsigned k): + pb_base(card_t, id, lit, lits.size(), get_obj_size(lits.size()), k) { for (unsigned i = 0; i < size(); ++i) { m_lits[i] = lits[i]; } @@ -133,8 +133,8 @@ namespace sat { // ----------------------------------- // pb - ba_solver::pb::pb(literal lit, svector const& wlits, unsigned k): - pb_base(pb_t, lit, wlits.size(), get_obj_size(wlits.size()), k), + ba_solver::pb::pb(unsigned id, literal lit, svector const& wlits, unsigned k): + pb_base(pb_t, id, lit, wlits.size(), get_obj_size(wlits.size()), k), m_slack(0), m_num_watch(0), m_max_sum(0) { @@ -185,8 +185,8 @@ namespace sat { // ----------------------------------- // xor - ba_solver::xor::xor(literal lit, literal_vector const& lits): - constraint(xor_t, lit, lits.size(), get_obj_size(lits.size())) { + ba_solver::xor::xor(unsigned id, literal lit, literal_vector const& lits): + constraint(xor_t, id, lit, lits.size(), get_obj_size(lits.size())) { for (unsigned i = 0; i < size(); ++i) { m_lits[i] = lits[i]; } @@ -211,7 +211,7 @@ namespace sat { // ---------------------------- // card - void ba_solver::init_watch(card& c, bool is_true) { + bool ba_solver::init_watch(card& c, bool is_true) { clear_watch(c); if (c.lit() != null_literal && c.lit().sign() == is_true) { c.negate(); @@ -223,7 +223,7 @@ namespace sat { if (bound == sz) { for (literal l : c) assign(c, l); - return; + return false; } for (unsigned i = 0; i < sz; ++i) { @@ -242,6 +242,7 @@ namespace sat { }); // j is the number of non-false, sz - j the number of false. + if (j < bound) { SASSERT(0 < bound && bound < sz); literal alit = c[j]; @@ -259,16 +260,19 @@ namespace sat { } } set_conflict(c, alit); + return false; } else if (j == bound) { for (unsigned i = 0; i < bound; ++i) { assign(c, c[i]); } + return false; } else { for (unsigned i = 0; i <= bound; ++i) { watch_literal(c[i], c); } + return true; } } @@ -323,7 +327,7 @@ namespace sat { // watch a prefix of literals, such that the slack of these is >= k - void ba_solver::init_watch(pb& p, bool is_true) { + bool ba_solver::init_watch(pb& p, bool is_true) { clear_watch(p); if (p.lit() != null_literal && p.lit().sign() == is_true) { p.negate(); @@ -366,6 +370,7 @@ namespace sat { } } set_conflict(p, lit); + return false; } else { for (unsigned i = 0; i < num_watch; ++i) { @@ -374,6 +379,8 @@ namespace sat { p.set_slack(slack); p.set_num_watch(num_watch); + TRACE("sat", display(tout << "init watch: ", p, true);); + // slack is tight: if (slack + slack1 == bound) { SASSERT(slack1 == 0); @@ -382,8 +389,8 @@ namespace sat { assign(p, p[i].second); } } + return true; } - TRACE("sat", display(tout << "init watch: ", p, true);); } /* @@ -400,12 +407,9 @@ namespace sat { Lw = Lw u {l_s} Lu = Lu \ {l_s} } - if (Sw < k) conflict - while (Sw < k + a_max) { - assign (l_max) - a_max = max { ai | li in Lw, li = undef } - } - ASSERT(Sw >= bound) + if (Sw < k) return conflict + for (li in Lw | Sw < k + ai) + assign li return no-conflict a_max index: index of non-false literal with maximal weight. @@ -420,8 +424,22 @@ namespace sat { } } + static unsigned _bad_id = 62390000; + +#define BADLOG(_cmd_) if (p.id() == _bad_id) { _cmd_; } + + /* + \brief propagate assignment to alit in constraint p. + + TBD: + - consider reordering literals in watch list so that the search for watched literal takes average shorter time. + - combine with caching literals that are assigned to 'true' to a cold store where they are not being revisited. + Since 'true' literals may be unassigned (unless they are assigned at level 0) the cache has to be backtrack + friendly (and the overhead of backtracking has to be taken into account). + */ lbool ba_solver::add_assign(pb& p, literal alit) { + BADLOG(display(std::cout << "assign: " << alit << " watch: " << p.num_watch() << " size: " << p.size(), p, true)); TRACE("sat", display(tout << "assign: " << alit << "\n", p, true);); SASSERT(!inconsistent()); unsigned sz = p.size(); @@ -442,12 +460,16 @@ namespace sat { } add_index(p, index, lit); } - SASSERT(index < num_watch); - if (index >= num_watch) { - std::cout << "BAD assign. " << alit << " not found within " << num_watch << "\n"; - std::cout << p << "\n"; + if (index == num_watch) { + _bad_id = p.id(); + std::cout << p.id() << "\n"; + display(std::cout, p, true); + std::cout << "alit: " << alit << "\n"; + std::cout << "num watch " << num_watch << "\n"; + return l_undef; } - + + SASSERT(index < num_watch); unsigned index1 = index + 1; for (; m_a_max == 0 && index1 < num_watch; ++index1) { add_index(p, index1, p[index1].second); @@ -457,6 +479,7 @@ namespace sat { SASSERT(value(p[index].second) == l_false); SASSERT(val <= slack); slack -= val; + // find literals to swap with: for (unsigned j = num_watch; j < sz && slack < bound + m_a_max; ++j) { literal lit = p[j].second; @@ -465,6 +488,7 @@ namespace sat { watch_literal(p[j], p); p.swap(num_watch, j); add_index(p, num_watch, lit); + BADLOG(std::cout << "add watch: " << lit << " num watch: " << num_watch << "\n"); ++num_watch; } } @@ -477,16 +501,18 @@ namespace sat { slack += val; p.set_slack(slack); p.set_num_watch(num_watch); + BADLOG(display(std::cout << "conflict: " << alit << " watch: " << p.num_watch() << " size: " << p.size(), p, true)); SASSERT(bound <= slack); TRACE("sat", tout << "conflict " << alit << "\n";); set_conflict(p, alit); return l_false; } - if (index > p.size() || num_watch > p.size() || num_watch == 0) { + if (index > p.size() || num_watch > p.size() || num_watch == 0 || p.id() == _bad_id) { display(std::cout, p, true); - std::cout << "size: " << p.size() << "index: " << index << " num watch: " << num_watch << "\n"; + std::cout << "size: " << p.size() << " index: " << index << " num watch: " << num_watch << "\n"; } + // swap out the watched literal. --num_watch; SASSERT(num_watch > 0); @@ -494,32 +520,28 @@ namespace sat { p.set_num_watch(num_watch); p.swap(num_watch, index); - if (slack < bound + m_a_max) { - TRACE("sat", tout << p; for(auto j : m_pb_undef) tout << j << "\n";); - literal_vector to_assign; - while (!m_pb_undef.empty()) { - index1 = m_pb_undef.back(); - if (index1 == num_watch) index1 = index; // it was swapped with index above. - if (index1 >= num_watch) { - std::cout << "BAD assignment at position " << index1 << " with " << num_watch << "\n"; - std::cout << p << "\n"; - } - literal lit = p[index1].second; - SASSERT(value(lit) == l_undef); - TRACE("sat", tout << index1 << " " << lit << "\n";); - if (slack >= bound + p[index1].first) { - break; - } - m_pb_undef.pop_back(); - to_assign.push_back(lit); - } + BADLOG(std::cout << "swap watched: " << alit << " watch: " << p.num_watch() << " size: " << p.size() << " slack: " << p.slack() << "\n"); - for (literal lit : to_assign) { - assign(p, lit); + // + // slack >= bound, but slack - w(l) < bound + // l must be true. + // + if (slack < bound + m_a_max) { + TRACE("sat", tout << p; for(auto j : m_pb_undef) tout << j << "\n";); + for (unsigned index1 : m_pb_undef) { + if (index1 == num_watch) { + index1 = index; + } + wliteral wl = p[index1]; + literal lit = wl.second; + SASSERT(value(lit) == l_undef); + BADLOG(std::cout << "Assign " << lit << "\n"); + if (slack < bound + wl.first) { + assign(p, lit); + } } } - TRACE("sat", display(tout << "assign: " << alit << "\n", p, true);); return l_undef; @@ -564,7 +586,7 @@ namespace sat { if (p.is_cardinality()) { literal_vector lits(p.literals()); unsigned k = (p.k() + p[0].first - 1) / p[0].first; - add_at_least(p.lit(), lits, k); + add_at_least(p.lit(), lits, k, p.learned()); remove_constraint(p); } else if (p.lit() == null_literal) { @@ -711,7 +733,7 @@ namespace sat { return odd; } - void ba_solver::init_watch(xor& x, bool is_true) { + bool ba_solver::init_watch(xor& x, bool is_true) { clear_watch(x); if (x.lit() != null_literal && x.lit().sign() == is_true) { x.negate(); @@ -739,18 +761,18 @@ namespace sat { SASSERT(x.lit() == null_literal || value(x.lit()) == l_true); set_conflict(x, x[j]); } - break; + return false; case 1: SASSERT(x.lit() == null_literal || value(x.lit()) == l_true); assign(x, parity(x, 1) ? ~x[0] : x[0]); - break; + return false; default: SASSERT(j == 2); watch_literal(x[0], x); watch_literal(x[1], x); watch_literal(~x[0], x); watch_literal(~x[1], x); - break; + return true; } } @@ -801,7 +823,7 @@ namespace sat { // conflict resolution void ba_solver::normalize_active_coeffs() { - while (!m_active_var_set.empty()) m_active_var_set.erase(); + reset_active_var_set(); unsigned i = 0, j = 0, sz = m_active_vars.size(); for (; i < sz; ++i) { bool_var v = m_active_vars[i]; @@ -864,9 +886,9 @@ namespace sat { static bool _debug_conflict = false; - bool ba_solver::resolve_conflict() { + lbool ba_solver::resolve_conflict() { if (0 == m_num_propagations_since_pop) { - return false; + return l_undef; } reset_coeffs(); m_num_marks = 0; @@ -1044,10 +1066,62 @@ namespace sat { SASSERT(validate_lemma()); normalize_active_coeffs(); + + if (!create_asserting_lemma()) { + goto bail_out; + } + active2card(); + + SASSERT(validate_conflict(m_lemma, m_A)); + + TRACE("sat", tout << m_lemma << "\n";); + + if (get_config().m_drat) { + svector ps; // TBD fill in + drat_add(m_lemma, ps); + } + + s().m_lemma.reset(); + s().m_lemma.append(m_lemma); + for (unsigned i = 1; i < m_lemma.size(); ++i) { + CTRACE("sat", s().is_marked(m_lemma[i].var()), tout << "marked: " << m_lemma[i] << "\n";); + s().mark(m_lemma[i].var()); + } + + return l_true; + + bail_out: + while (m_num_marks > 0 && idx >= 0) { + bool_var v = lits[idx].var(); + if (s().is_marked(v)) { + s().reset_mark(v); + --m_num_marks; + } + if (idx == 0 && !_debug_conflict) { + _debug_conflict = true; + // s().display(std::cout); + std::cout << s().m_not_l << "\n"; + for (literal l : lits) { + if (s().is_marked(l.var())) { + std::cout << "missing mark: " << l << "\n"; + s().reset_mark(l.var()); + } + } + m_num_marks = 0; + resolve_conflict(); + } + --idx; + } + return l_undef; + } + + bool ba_solver::create_asserting_lemma() { + + adjust_conflict_level: + int slack = -m_bound; - for (unsigned i = 0; i < m_active_vars.size(); ++i) { - bool_var v = m_active_vars[i]; + for (bool_var v : m_active_vars) { slack += get_abs_coeff(v); } @@ -1085,69 +1159,39 @@ namespace sat { } } - if (slack >= 0) { IF_VERBOSE(2, verbose_stream() << "(sat.card slack: " << slack << " skipped: " << num_skipped << ")\n";); - goto bail_out; + return false; } - if (m_lemma[0] == null_literal) { - m_lemma[0] = m_lemma.back(); - m_lemma.pop_back(); - unsigned level = m_lemma.empty() ? 0 : lvl(m_lemma[0]); - for (unsigned i = 1; i < m_lemma.size(); ++i) { - if (lvl(m_lemma[i]) > level) { - level = lvl(m_lemma[i]); - std::swap(m_lemma[0], m_lemma[i]); - } - } - IF_VERBOSE(2, verbose_stream() << "(sat.card set level to " << level << " < " << m_conflict_lvl << ")\n";); - } - - if (slack < -1) std::cout << "lemma: " << m_lemma << " >= " << -slack << "\n"; - SASSERT(slack < 0); - - SASSERT(validate_conflict(m_lemma, m_A)); - TRACE("sat", tout << m_lemma << "\n";); - - if (get_config().m_drat) { - svector ps; // TBD fill in - drat_add(m_lemma, ps); + if (m_lemma[0] == null_literal) { + if (m_lemma.size() == 1) { + s().set_conflict(justification()); + return false; + } + unsigned old_level = m_conflict_lvl; + m_conflict_lvl = 0; + for (unsigned i = 1; i < m_lemma.size(); ++i) { + m_conflict_lvl = std::max(m_conflict_lvl, lvl(m_lemma[i])); + } + IF_VERBOSE(1, verbose_stream() << "(sat-backjump :new-level " << m_conflict_lvl << " :old-level " << old_level << ")\n";); + goto adjust_conflict_level; } - s().m_lemma.reset(); - s().m_lemma.append(m_lemma); - for (unsigned i = 1; i < m_lemma.size(); ++i) { - CTRACE("sat", s().is_marked(m_lemma[i].var()), tout << "marked: " << m_lemma[i] << "\n";); - s().mark(m_lemma[i].var()); + // slack is related to coefficients of m_lemma + // so does not apply to unit coefficients. + // std::cout << "lemma: " << m_lemma << " >= " << 1 << "\n"; + // active2pb(m_A); + // display(std::cout, m_A, true); +#if 0 + constraint* c = active2constraint(); + if (c) { + display(std::cout, *c, true); + std::cout << "Eval: " << eval(*c) << "\n"; } - +#endif return true; - - bail_out: - while (m_num_marks > 0 && idx >= 0) { - bool_var v = lits[idx].var(); - if (s().is_marked(v)) { - s().reset_mark(v); - --m_num_marks; - } - if (idx == 0 && !_debug_conflict) { - _debug_conflict = true; - // s().display(std::cout); - std::cout << s().m_not_l << "\n"; - for (literal l : lits) { - if (s().is_marked(l.var())) { - std::cout << "missing mark: " << l << "\n"; - s().reset_mark(l.var()); - } - } - m_num_marks = 0; - resolve_conflict(); - } - --idx; - } - return false; } void ba_solver::cut() { @@ -1184,12 +1228,17 @@ namespace sat { } } if (g >= 2) { + active2pb(m_A); + display(std::cout, m_A, true); normalize_active_coeffs(); + int ig = static_cast(g); for (unsigned i = 0; i < m_active_vars.size(); ++i) { - m_coeffs[m_active_vars[i]] /= g; + m_coeffs[m_active_vars[i]] /= ig; } m_bound = (m_bound + g - 1) / g; std::cout << "CUT " << g << "\n"; + active2pb(m_A); + display(std::cout, m_A, true); } } @@ -1238,32 +1287,51 @@ namespace sat { return p; } - ba_solver::ba_solver(): m_solver(0), m_lookahead(0) { + ba_solver::ba_solver(): m_solver(0), m_lookahead(0), m_constraint_id(0) { TRACE("sat", tout << this << "\n";); } ba_solver::~ba_solver() { m_stats.reset(); - while (!m_constraints.empty()) { - pop_constraint(); + for (constraint* c : m_constraints) { + m_allocator.deallocate(c->obj_size(), c); + } + for (constraint* c : m_learned) { + m_allocator.deallocate(c->obj_size(), c); } } void ba_solver::add_at_least(bool_var v, literal_vector const& lits, unsigned k) { literal lit = v == null_bool_var ? null_literal : literal(v, false); - add_at_least(lit, lits, k); + add_at_least(lit, lits, k, false); } - void ba_solver::add_at_least(literal lit, literal_vector const& lits, unsigned k) { + ba_solver::card& ba_solver::add_at_least(literal lit, literal_vector const& lits, unsigned k, bool learned) { void * mem = m_allocator.allocate(card::get_obj_size(lits.size())); - card* c = new (mem) card(lit, lits, k); + card* c = new (mem) card(next_id(), lit, lits, k); + c->set_learned(learned); add_constraint(c); + return *c; } void ba_solver::add_constraint(constraint* c) { - m_constraints.push_back(c); + if (c->id() == _bad_id) { + display(std::cout, *c, true); + } + if (c->learned()) { + m_learned.push_back(c); + } + else { + SASSERT(s().at_base_lvl()); + m_constraints.push_back(c); + } literal lit = c->lit(); - if (lit == null_literal) { + if (c->learned()) { + SASSERT(lit == null_literal); + // gets initialized after backjump. + m_constraint_to_reinit.push_back(c); + } + else if (lit == null_literal) { init_watch(*c, true); } else { @@ -1275,12 +1343,15 @@ namespace sat { } - void ba_solver::init_watch(constraint& c, bool is_true) { + bool ba_solver::init_watch(constraint& c, bool is_true) { + if (inconsistent()) return false; switch (c.tag()) { - case card_t: init_watch(c.to_card(), is_true); break; - case pb_t: init_watch(c.to_pb(), is_true); break; - case xor_t: init_watch(c.to_xor(), is_true); break; + case card_t: return init_watch(c.to_card(), is_true); + case pb_t: return init_watch(c.to_pb(), is_true); + case xor_t: return init_watch(c.to_xor(), is_true); } + UNREACHABLE(); + return false; } lbool ba_solver::add_assign(constraint& c, literal l) { @@ -1293,42 +1364,48 @@ namespace sat { return l_undef; } - ba_solver::pb const& ba_solver::add_pb_ge(literal lit, svector const& wlits, unsigned k) { + ba_solver::pb& ba_solver::add_pb_ge(literal lit, svector const& wlits, unsigned k, bool learned) { void * mem = m_allocator.allocate(pb::get_obj_size(wlits.size())); - pb* p = new (mem) pb(lit, wlits, k); + pb* p = new (mem) pb(next_id(), lit, wlits, k); + p->set_learned(learned); add_constraint(p); return *p; } void ba_solver::add_pb_ge(bool_var v, svector const& wlits, unsigned k) { literal lit = v == null_bool_var ? null_literal : literal(v, false); - add_pb_ge(lit, wlits, k); + add_pb_ge(lit, wlits, k, false); } void ba_solver::add_xor(bool_var v, literal_vector const& lits) { - add_xor(literal(v, false), lits); + add_xor(literal(v, false), lits, false); } - void ba_solver::add_xor(literal lit, literal_vector const& lits) { + ba_solver::xor& ba_solver::add_xor(literal lit, literal_vector const& lits, bool learned) { void * mem = m_allocator.allocate(xor::get_obj_size(lits.size())); - xor* x = new (mem) xor(lit, lits); + xor* x = new (mem) xor(next_id(), lit, lits); + x->set_learned(learned); add_constraint(x); for (literal l : lits) s().set_external(l.var()); // TBD: determine if goal2sat does this. + return *x; } - void ba_solver::propagate(literal l, ext_constraint_idx idx, bool & keep) { + /* + \brief return true to keep watching literal. + */ + bool ba_solver::propagate(literal l, ext_constraint_idx idx) { SASSERT(value(l) == l_true); TRACE("sat", tout << l << " " << idx << "\n";); constraint& c = index2constraint(idx); if (c.lit() != null_literal && l.var() == c.lit().var()) { init_watch(c, !l.sign()); - keep = true; + return true; } else if (c.lit() != null_literal && value(c.lit()) != l_true) { - keep = false; + return false; } else { - keep = l_undef != add_assign(c, ~l); + return l_undef != add_assign(c, ~l); } } @@ -1436,88 +1513,78 @@ namespace sat { TRACE("sat", tout << r << "\n";); } + /** + \brief retrieve a sufficient set of literals from p that imply l. + + Find partition: + + - Ax + coeff*l + B*y >= k + - all literals in x are false. + - B < k + + Then x is an explanation for l + + */ void ba_solver::get_antecedents(literal l, pb const& p, literal_vector& r) { - if (p.lit() != null_literal) r.push_back(p.lit()); - SASSERT(p.lit() == null_literal || value(p.lit()) == l_true); TRACE("sat", display(tout, p, true);); + SASSERT(p.lit() == null_literal || value(p.lit()) == l_true); + + if (p.lit() != null_literal) { + r.push_back(p.lit()); + } + + unsigned k = p.k(); if (value(l) == l_false) { // The literal comes from a conflict. // it is forced true, but assigned to false. unsigned slack = 0; - unsigned miss = 0; - unsigned worth = 0; - unsigned k = p.k(); for (wliteral wl : p) { - literal lit = wl.second; - if (lit == l) { - worth = wl.first; - } - else if (value(lit) == l_false) { - miss += wl.first; - } - else { + if (value(wl.second) != l_false) { slack += wl.first; } } SASSERT(slack < k); - SASSERT(0 < worth); - - slack += worth; for (wliteral wl : p) { literal lit = wl.second; if (lit != l && value(lit) == l_false) { unsigned w = wl.first; - if (slack + w >= k) { - r.push_back(~lit); + if (slack + w < k) { + slack += w; } else { - slack += w; - std::cout << "increase slack by " << w << " to " << slack << " worth: " << worth << "\n"; + r.push_back(~lit); } } } -#if 0 - std::cout << p << "\n"; - std::cout << r << "\n"; - std::cout << "slack:" << slack << " miss: " << miss << "\n"; -#endif - return; } - - unsigned coeff = 0; - for (unsigned j = 0; j < p.num_watch(); ++j) { - if (p[j].second == l) { - coeff = p[j].first; - break; + else { + unsigned coeff = 0; + for (unsigned j = 0; j < p.num_watch(); ++j) { + if (p[j].second == l) { + coeff = p[j].first; + break; + } + } + + CTRACE("sat", coeff == 0, display(tout << l << " coeff: " << coeff << "\n", p, true);); + + SASSERT(coeff > 0); + unsigned slack = p.slack() - coeff; + + for (unsigned i = p.num_watch(); i < p.size(); ++i) { + literal lit = p[i].second; + unsigned w = p[i].first; + SASSERT(l_false == value(lit)); + if (slack + w < k) { + slack += w; + } + else { + r.push_back(~lit); + } } } - - if (_debug_conflict) { - std::cout << p << "\n"; - std::cout << l << " " << coeff << " num_watch: " << p.num_watch() << "\n"; - } - - CTRACE("sat", coeff == 0, display(tout << l << " coeff: " << coeff << "\n", p, true);); - - SASSERT(coeff > 0); - unsigned slack = p.slack() - coeff; - unsigned i = p.num_watch(); - - // skip entries that are not required for unit propagation. - // slack - coeff + w_head < k - unsigned h = 0; - for (; i < p.size() && p[i].first + h + slack < p.k(); ++i) { - h += p[i].first; - } - for (; i < p.size(); ++i) { - literal lit = p[i].second; - CTRACE("sat", l_false != value(lit), - tout << l << " index: " << i << " slack: " << slack << " h: " << h << " coeff: " << coeff << "\n"; - display(tout, p, true);); - SASSERT(l_false == value(lit)); - r.push_back(~lit); - } + SASSERT(validate_unit_propagation(p, r, l)); } void ba_solver::simplify(xor& x) { @@ -1741,7 +1808,7 @@ namespace sat { if (lvl(l) == 0) continue; bool found = is_watching(l, c); if (found != c.is_watching(l)) { - std::cout << "Discrepancy of watched literal: " << l << ": " << c.index() << " " << c << (found?" is watched, but shouldn't be":" not watched, but should be") << "\n"; + std::cout << "Discrepancy of watched literal: " << l << ": " << c.id() << " " << c << (found?" is watched, but shouldn't be":" not watched, but should be") << "\n"; display_watch_list(std::cout << l << ": ", s().m_cls_allocator, get_wlist(l)) << "\n"; display_watch_list(std::cout << ~l << ": ", s().m_cls_allocator, get_wlist(~l)) << "\n"; std::cout << "value: " << value(l) << " level: " << lvl(l) << "\n"; @@ -1858,28 +1925,30 @@ namespace sat { check_result ba_solver::check() { return CR_DONE; } void ba_solver::push() { - m_constraint_lim.push_back(m_constraints.size()); + m_constraint_to_reinit_lim.push_back(m_constraint_to_reinit.size()); } - void ba_solver::pop_constraint() { - constraint* c = m_constraints.back(); - m_constraints.pop_back(); - remove_constraint(*c); - m_allocator.deallocate(c->obj_size(), c); - } - - void ba_solver::pop(unsigned n) { TRACE("sat_verbose", tout << "pop:" << n << "\n";); - unsigned new_lim = m_constraint_lim.size() - n; - unsigned sz = m_constraint_lim[new_lim]; - while (m_constraints.size() > sz) { - pop_constraint(); - } - m_constraint_lim.resize(new_lim); + unsigned new_lim = m_constraint_to_reinit_lim.size() - n; + m_constraint_to_reinit_last_sz = m_constraint_to_reinit_lim[new_lim]; + m_constraint_to_reinit_lim.shrink(new_lim); m_num_propagations_since_pop = 0; } + void ba_solver::pop_reinit() { + // TBD: need a stack to follow backtracking order. + unsigned sz = m_constraint_to_reinit_last_sz; + // if (sz < m_constraint_to_reinit.size()) std::cout << "REINIT " << s().scope_lvl() << " " << m_constraint_to_reinit.size() - sz << "\n"; + for (unsigned i = sz; i < m_constraint_to_reinit.size(); ++i) { + constraint* c = m_constraint_to_reinit[i]; + if (!init_watch(*c, true)) { + m_constraint_to_reinit[sz++] = c; + } + } + m_constraint_to_reinit.shrink(sz); + } + void ba_solver::simplify(constraint& c) { SASSERT(s().at_base_lvl()); switch (c.tag()) { @@ -1902,11 +1971,12 @@ namespace sat { unsigned trail_sz; do { trail_sz = s().init_trail_size(); - IF_VERBOSE(1, verbose_stream() << "(bool-algebra-solver simplify-begin :trail " << trail_sz << ")\n";); + IF_VERBOSE(1, verbose_stream() << "(bool-algebra-solver simplify-begin :trail " << trail_sz << " :learned " << m_learned.size() << ")\n";); m_simplify_change = false; m_clause_removed = false; m_constraint_removed = false; - // for (constraint* c : m_constraints) simplify(*c); + for (constraint* c : m_constraints) simplify(*c); + for (constraint* c : m_learned) simplify(*c); init_use_lists(); remove_unused_defs(); set_non_external(); @@ -1934,7 +2004,7 @@ namespace sat { if (mux.size() > 2) { IF_VERBOSE(1, verbose_stream() << "mux: " << mux << "\n";); for (unsigned i = 0; i < mux.size(); ++i) mux[i].neg(); - add_at_least(null_literal, mux, mux.size() - 1); + add_at_least(null_literal, mux, mux.size() - 1, false); } } } @@ -2081,6 +2151,9 @@ namespace sat { } void ba_solver::recompile(constraint& c) { + if (c.id() == _bad_id) { + display(std::cout << "recompile\n", c, true); + } switch (c.tag()) { case card_t: recompile(c.to_card()); @@ -2158,7 +2231,7 @@ namespace sat { } literal root = c.lit(); remove_constraint(c); - pb const& p = add_pb_ge(root, wlits, k); + pb const& p = add_pb_ge(root, wlits, k, c.learned()); IF_VERBOSE(1, verbose_stream() << p << "\n";); } else { @@ -2665,7 +2738,7 @@ namespace sat { card const& c = cp->to_card(); lits.reset(); for (literal l : c) lits.push_back(l); - result->add_at_least(c.lit(), lits, c.k()); + result->add_at_least(c.lit(), lits, c.k(), c.learned()); break; } case pb_t: { @@ -2674,14 +2747,14 @@ namespace sat { for (wliteral w : p) { wlits.push_back(w); } - result->add_pb_ge(p.lit(), wlits, p.k()); + result->add_pb_ge(p.lit(), wlits, p.k(), p.learned()); break; } case xor_t: { xor const& x = cp->to_xor(); lits.reset(); for (literal l : x) lits.push_back(l); - result->add_xor(x.lit(), lits); + result->add_xor(x.lit(), lits, x.learned()); break; } default: @@ -2832,6 +2905,21 @@ namespace sat { return sum < p.k(); } + bool ba_solver::validate_unit_propagation(pb const& p, literal_vector const& r, literal alit) const { + unsigned sum = 0; + // all elements of r are true, + for (literal l : r) { + if (value(l) != l_true) return false; + } + // the sum of elements not in r or alit add up to less than k. + for (wliteral wl : p) { + if (wl.second != alit && !r.contains(~wl.second)) { + sum += wl.first; + } + } + return sum < p.k(); + } + bool ba_solver::validate_unit_propagation(xor const& x, literal alit) const { if (value(x.lit()) != l_true) return false; for (unsigned i = 1; i < x.size(); ++i) { @@ -2842,7 +2930,7 @@ namespace sat { bool ba_solver::validate_lemma() { int val = -m_bound; - while (!m_active_var_set.empty()) m_active_var_set.erase(); + reset_active_var_set(); for (bool_var v : m_active_vars) { if (m_active_var_set.contains(v)) continue; int coeff = get_coeff(v); @@ -2860,20 +2948,163 @@ namespace sat { return val < 0; } - void ba_solver::active2pb(ineq& p) { + void ba_solver::reset_active_var_set() { while (!m_active_var_set.empty()) m_active_var_set.erase(); + } + + void ba_solver::active2pb(ineq& p) { + reset_active_var_set(); p.reset(m_bound); for (bool_var v : m_active_vars) { if (m_active_var_set.contains(v)) continue; int coeff = get_coeff(v); if (coeff == 0) continue; m_active_var_set.insert(v); - literal lit(v, get_coeff(v) < 0); + literal lit(v, coeff < 0); p.m_lits.push_back(lit); - p.m_coeffs.push_back(get_abs_coeff(v)); + p.m_coeffs.push_back(abs(coeff)); } } + ba_solver::constraint* ba_solver::active2constraint() { + reset_active_var_set(); + literal_vector lits; + unsigned_vector coeffs; + bool all_one = true; + uint64_t sum = 0; + if (m_bound == 1) return 0; + for (bool_var v : m_active_vars) { + int coeff = get_coeff(v); + if (m_active_var_set.contains(v) || coeff == 0) continue; + m_active_var_set.insert(v); + literal lit(v, coeff < 0); + lits.push_back(lit); + coeffs.push_back(abs(coeff)); + all_one &= abs(coeff) == 1; + sum += abs(coeff); + } + if (sum >= UINT_MAX/2) return 0; + if (all_one) { + card& c = add_at_least(null_literal, lits, m_bound, true); + return &c; + } + else { + svector wlits; + for (unsigned i = 0; i < lits.size(); ++i) { + wlits.push_back(wliteral(coeffs[i], lits[i])); + } + pb& p = add_pb_ge(null_literal, wlits, m_bound, true); + return &p; + } + } + + /* + Chai Kuhlmann: + + a1*l1 + ... + a_n*l_n >= k + s.t. + a1 >= a2 >= .. >= a_n + + let m be such that + + sum_{i = 1}^{m-1} a_i < k <= sum_{i = 1}^{m} + + then + + l1 + ... + l_n >= m + + furthermore, for the largest n' <= n, such that + + sum_{i = n'+1}^n a_i + sum_{i = 1}^{m-1} a_i < k + + then + + l1 + ... + l_n' >= m + + */ + struct compare_wlit { + bool operator()(ba_solver::wliteral l1, ba_solver::wliteral l2) const { + return l1.first > l2.first; + } + }; + + + ba_solver::card* ba_solver::active2card() { + normalize_active_coeffs(); + svector wlits; + for (bool_var v : m_active_vars) { + int coeff = get_coeff(v); + wlits.push_back(std::make_pair(abs(coeff), literal(v, coeff < 0))); + } + std::sort(wlits.begin(), wlits.end(), compare_wlit()); + unsigned k = 0; + int sum = 0, sum0 = 0; + for (wliteral wl : wlits) { + if (sum >= m_bound) break; + sum0 = sum; + sum += wl.first; + ++k; + } + if (k == 1) { + return 0; + } + while (!wlits.empty()) { + wliteral wl = wlits.back(); + if (wl.first + sum0 >= static_cast(m_bound)) break; + wlits.pop_back(); + sum0 += wl.first; + } + + unsigned slack = 0; + unsigned max_level = 0; + unsigned num_max_level = 0; + for (wliteral wl : wlits) { + if (value(wl.second) != l_false) ++slack; + unsigned level = lvl(wl.second); + if (level > max_level) { + max_level = level; + num_max_level = 1; + } + else if (max_level == level) { + ++num_max_level; + } + } + + + if (slack >= k) { + return 0; + } + + +#if 0 + std::cout << "card: "; + for (wliteral wl : wlits) std::cout << wl.second << " "; + std::cout << ">= " << k << "\n"; + + if (num_max_level > 1) { + std::cout << "max level " << num_max_level << "\n"; + } + + + if (wlits.size() < m_active_vars.size()) std::cout << "REMOVED " << m_active_vars.size() - wlits.size() << "\n"; +#endif + + // produce asserting cardinality constraint + literal_vector lits; + for (wliteral wl : wlits) lits.push_back(wl.second); + card& c = add_at_least(null_literal, lits, k, true); + + lits.reset(); + for (wliteral wl : wlits) { + if (value(wl.second) == l_false) lits.push_back(wl.second); + } + unsigned glue = s().num_diff_levels(lits.size(), lits.c_ptr()); + + c.set_glue(glue); + return &c; + } + + void ba_solver::justification2pb(justification const& js, literal lit, unsigned offset, ineq& ineq) { switch (js.get_kind()) { case justification::NONE: diff --git a/src/sat/ba_solver.h b/src/sat/ba_solver.h index 0651dd495..b8978b254 100644 --- a/src/sat/ba_solver.h +++ b/src/sat/ba_solver.h @@ -62,9 +62,12 @@ namespace sat { unsigned m_glue; unsigned m_size; size_t m_obj_size; + bool m_learned; + unsigned m_id; public: - constraint(tag_t t, literal l, unsigned sz, size_t osz): m_tag(t), m_removed(false), m_lit(l), m_glue(0), m_size(sz), m_obj_size(osz) {} + constraint(tag_t t, unsigned id, literal l, unsigned sz, size_t osz): m_tag(t), m_removed(false), m_lit(l), m_glue(0), m_size(sz), m_obj_size(osz), m_learned(false), m_id(id) {} ext_constraint_idx index() const { return reinterpret_cast(this); } + unsigned id() const { return m_id; } tag_t tag() const { return m_tag; } literal lit() const { return m_lit; } unsigned size() const { return m_size; } @@ -75,6 +78,8 @@ namespace sat { void nullify_literal() { m_lit = null_literal; } unsigned glue() const { return m_glue; } void set_glue(unsigned g) { m_glue = g; } + void set_learned(bool f) { m_learned = f; } + bool learned() const { return m_learned; } size_t obj_size() const { return m_obj_size; } card& to_card(); @@ -102,7 +107,7 @@ namespace sat { protected: unsigned m_k; public: - pb_base(tag_t t, literal l, unsigned sz, size_t osz, unsigned k): constraint(t, l, sz, osz), m_k(k) {} + pb_base(tag_t t, unsigned id, literal l, unsigned sz, size_t osz, unsigned k): constraint(t, id, l, sz, osz), m_k(k) {} virtual void set_k(unsigned k) { m_k = k; } virtual unsigned get_coeff(unsigned i) const { UNREACHABLE(); return 0; } unsigned k() const { return m_k; } @@ -113,7 +118,7 @@ namespace sat { literal m_lits[0]; public: static size_t get_obj_size(unsigned num_lits) { return sizeof(card) + num_lits * sizeof(literal); } - card(literal lit, literal_vector const& lits, unsigned k); + card(unsigned id, literal lit, literal_vector const& lits, unsigned k); literal operator[](unsigned i) const { return m_lits[i]; } literal& operator[](unsigned i) { return m_lits[i]; } literal const* begin() const { return m_lits; } @@ -138,7 +143,7 @@ namespace sat { void update_max_sum(); public: static size_t get_obj_size(unsigned num_lits) { return sizeof(pb) + num_lits * sizeof(wliteral); } - pb(literal lit, svector const& wlits, unsigned k); + pb(unsigned id, literal lit, svector const& wlits, unsigned k); literal lit() const { return m_lit; } wliteral operator[](unsigned i) const { return m_wlits[i]; } wliteral& operator[](unsigned i) { return m_wlits[i]; } @@ -165,7 +170,7 @@ namespace sat { literal m_lits[0]; public: static size_t get_obj_size(unsigned num_lits) { return sizeof(xor) + num_lits * sizeof(literal); } - xor(literal lit, literal_vector const& lits); + xor(unsigned id, literal lit, literal_vector const& lits); literal operator[](unsigned i) const { return m_lits[i]; } literal const* begin() const { return m_lits; } literal const* end() const { return begin() + m_size; } @@ -197,7 +202,10 @@ namespace sat { ptr_vector m_constraints; ptr_vector m_learned; - unsigned_vector m_constraint_lim; + ptr_vector m_constraint_to_reinit; + unsigned_vector m_constraint_to_reinit_lim; + unsigned m_constraint_to_reinit_last_sz; + unsigned m_constraint_id; // conflict resolution unsigned m_num_marks; @@ -272,7 +280,7 @@ namespace sat { void watch_literal(literal w, constraint& c); void watch_literal(wliteral w, pb& p); void add_constraint(constraint* c); - void init_watch(constraint& c, bool is_true); + bool init_watch(constraint& c, bool is_true); void init_watch(bool_var v); void clear_watch(constraint& c); lbool add_assign(constraint& c, literal l); @@ -290,10 +298,11 @@ namespace sat { void assert_unconstrained(literal lit, literal_vector const& lits); void flush_roots(constraint& c); void recompile(constraint& c); + unsigned next_id() { return m_constraint_id++; } // cardinality - void init_watch(card& c, bool is_true); + bool init_watch(card& c, bool is_true); lbool add_assign(card& c, literal lit); void clear_watch(card& c); void reset_coeffs(); @@ -305,7 +314,7 @@ namespace sat { // xor specific functionality void clear_watch(xor& x); - void init_watch(xor& x, bool is_true); + bool init_watch(xor& x, bool is_true); bool parity(xor const& x, unsigned offset) const; lbool add_assign(xor& x, literal alit); void get_xor_antecedents(literal l, unsigned index, justification js, literal_vector& r); @@ -316,7 +325,7 @@ namespace sat { // pb functionality unsigned m_a_max; - void init_watch(pb& p, bool is_true); + bool init_watch(pb& p, bool is_true); lbool add_assign(pb& p, literal alit); void add_index(pb& p, unsigned index, literal lit); void clear_watch(pb& p); @@ -342,6 +351,7 @@ namespace sat { inline void drat_add(literal_vector const& c, svector const& premises) { m_solver->m_drat.add(c, premises); } + void reset_active_var_set(); void normalize_active_coeffs(); void inc_coeff(literal l, int offset); int get_coeff(bool_var v) const; @@ -351,6 +361,7 @@ namespace sat { void process_antecedent(literal l, int offset); void process_card(card& c, int offset); void cut(); + bool create_asserting_lemma(); // validation utilities bool validate_conflict(card const& c) const; @@ -360,6 +371,7 @@ namespace sat { bool validate_lemma(); bool validate_unit_propagation(card const& c, literal alit) const; bool validate_unit_propagation(pb const& p, literal alit) const; + bool validate_unit_propagation(pb const& p, literal_vector const& r, literal alit) const; bool validate_unit_propagation(xor const& x, literal alit) const; bool validate_conflict(literal_vector const& lits, ineq& p); bool validate_watch_literals() const; @@ -369,6 +381,8 @@ namespace sat { ineq m_A, m_B, m_C; void active2pb(ineq& p); + constraint* active2constraint(); + card* active2card(); void justification2pb(justification const& j, literal lit, unsigned offset, ineq& p); bool validate_resolvent(); @@ -378,9 +392,9 @@ namespace sat { void display(std::ostream& out, pb const& p, bool values) const; void display(std::ostream& out, xor const& c, bool values) const; - void add_at_least(literal l, literal_vector const& lits, unsigned k); - pb const& add_pb_ge(literal l, svector const& wlits, unsigned k); - void add_xor(literal l, literal_vector const& lits); + card& add_at_least(literal l, literal_vector const& lits, unsigned k, bool learned); + pb& add_pb_ge(literal l, svector const& wlits, unsigned k, bool learned); + xor& add_xor(literal l, literal_vector const& lits, bool learned); public: ba_solver(); @@ -391,8 +405,8 @@ namespace sat { void add_pb_ge(bool_var v, svector const& wlits, unsigned k); void add_xor(bool_var v, literal_vector const& lits); - virtual void propagate(literal l, ext_constraint_idx idx, bool & keep); - virtual bool resolve_conflict(); + virtual bool propagate(literal l, ext_constraint_idx idx); + virtual lbool resolve_conflict(); virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r); virtual void asserted(literal l); virtual check_result check(); @@ -408,6 +422,7 @@ namespace sat { virtual void collect_statistics(statistics& st) const; virtual extension* copy(solver* s); virtual void find_mutexes(literal_vector& lits, vector & mutexes); + virtual void pop_reinit(); virtual void gc(); ptr_vector const & constraints() const { return m_constraints; } diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index f984091e3..4c052494a 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -34,11 +34,11 @@ namespace sat { virtual ~extension() {} virtual void set_solver(solver* s) = 0; virtual void set_lookahead(lookahead* s) = 0; - virtual void propagate(literal l, ext_constraint_idx idx, bool & keep) = 0; + virtual bool propagate(literal l, ext_constraint_idx idx) = 0; virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) = 0; virtual void asserted(literal l) = 0; virtual check_result check() = 0; - virtual bool resolve_conflict() { return false; } // stores result in sat::solver::m_lemma + virtual lbool resolve_conflict() { return l_undef; } // stores result in sat::solver::m_lemma virtual void push() = 0; virtual void pop(unsigned n) = 0; virtual void simplify() = 0; @@ -53,6 +53,7 @@ namespace sat { virtual extension* copy(solver* s) = 0; virtual void find_mutexes(literal_vector& lits, vector & mutexes) = 0; virtual void gc() = 0; + virtual void pop_reinit() = 0; virtual void validate() = 0; }; diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp index 2a1837e20..6c40c93fd 100644 --- a/src/sat/sat_lookahead.cpp +++ b/src/sat/sat_lookahead.cpp @@ -1170,10 +1170,10 @@ namespace sat { break; } case watched::EXT_CONSTRAINT: { - bool keep = true; SASSERT(m_s.m_ext); - m_s.m_ext->propagate(l, it->get_ext_constraint_idx(), keep); + bool keep = m_s.m_ext->propagate(l, it->get_ext_constraint_idx()); if (m_inconsistent) { + if (!keep) ++it; set_conflict(); } else if (keep) { diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 29c15c2d7..15737d65b 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -795,8 +795,12 @@ namespace sat { } case watched::EXT_CONSTRAINT: SASSERT(m_ext); - m_ext->propagate(l, it->get_ext_constraint_idx(), keep); + keep = m_ext->propagate(l, it->get_ext_constraint_idx()); if (m_inconsistent) { + if (!keep) { + std::cout << "CONFLICT - but throw away current watch literal\n"; + ++it; + } CONFLICT_CLEANUP(); return false; } @@ -1955,9 +1959,17 @@ namespace sat { forget_phase_of_vars(m_conflict_lvl); - if (m_ext && m_ext->resolve_conflict()) { - learn_lemma_and_backjump(); - return true; + if (m_ext) { + switch (m_ext->resolve_conflict()) { + case l_true: + learn_lemma_and_backjump(); + return true; + case l_undef: + break; + case l_false: + // backjumping was taken care of internally. + return true; + } } m_lemma.reset(); @@ -2770,6 +2782,8 @@ namespace sat { m_scope_lvl -= num_scopes; m_scopes.shrink(new_lvl); reinit_clauses(s.m_clauses_to_reinit_lim); + if (m_ext) + m_ext->pop_reinit(); } void solver::unassign_vars(unsigned old_sz) {