diff --git a/src/sat/card_extension.cpp b/src/sat/card_extension.cpp index 985579f18..2af1a5b5e 100644 --- a/src/sat/card_extension.cpp +++ b/src/sat/card_extension.cpp @@ -50,7 +50,7 @@ namespace sat { if (c.lit().sign() == is_true) { c.negate(); } - SASSERT(s().value(c.lit()) == l_true); + SASSERT(value(c.lit()) == l_true); unsigned j = 0, sz = c.size(), bound = c.k(); if (bound == sz) { for (unsigned i = 0; i < sz && !s().inconsistent(); ++i) { @@ -60,7 +60,7 @@ namespace sat { } // put the non-false literals into the head. for (unsigned i = 0; i < sz; ++i) { - if (s().value(c[i]) != l_false) { + if (value(c[i]) != l_false) { if (j != i) { c.swap(i, j); } @@ -70,8 +70,8 @@ namespace sat { DEBUG_CODE( bool is_false = false; for (unsigned k = 0; k < sz; ++k) { - SASSERT(!is_false || s().value(c[k]) == l_false); - is_false = s().value(c[k]) == l_false; + SASSERT(!is_false || value(c[k]) == l_false); + is_false = value(c[k]) == l_false; }); // j is the number of non-false, sz - j the number of false. @@ -86,7 +86,7 @@ namespace sat { // for (unsigned i = bound; i < sz; ++i) { - if (s().lvl(alit) < s().lvl(c[i])) { + if (lvl(alit) < lvl(c[i])) { c.swap(i, j); alit = c[j]; } @@ -133,7 +133,7 @@ namespace sat { } void card_extension::assign(card& c, literal lit) { - if (s().value(lit) == l_true) { + if (value(lit) == l_true) { return; } m_stats.m_num_propagations++; @@ -155,19 +155,20 @@ namespace sat { void card_extension::set_conflict(card& c, literal lit) { SASSERT(validate_conflict(c)); - literal_vector& lits = get_literals(); - SASSERT(s().value(lit) == l_false); - SASSERT(s().value(c.lit()) == l_true); - lits.push_back(~c.lit()); - lits.push_back(lit); - unsigned sz = c.size(); - for (unsigned i = c.k(); i < sz; ++i) { - SASSERT(s().value(c[i]) == l_false); - lits.push_back(c[i]); - } m_stats.m_num_conflicts++; - if (!resolve_conflict(c, lits)) { + if (!resolve_conflict(c, lit)) { + + literal_vector& lits = get_literals(); + SASSERT(value(lit) == l_false); + SASSERT(value(c.lit()) == l_true); + lits.push_back(~c.lit()); + lits.push_back(lit); + unsigned sz = c.size(); + for (unsigned i = c.k(); i < sz; ++i) { + SASSERT(value(c[i]) == l_false); + lits.push_back(c[i]); + } s().mk_clause_core(lits.size(), lits.c_ptr(), true); } SASSERT(s().inconsistent()); @@ -231,16 +232,16 @@ namespace sat { m_active_vars.reset(); } - bool card_extension::resolve_conflict(card& c, literal_vector const& conflict_clause) { + bool card_extension::resolve_conflict(card& c, literal alit) { bool_var v; m_conflict_lvl = 0; for (unsigned i = 0; i < c.size(); ++i) { literal lit = c[i]; - SASSERT(s().value(lit) == l_false); - m_conflict_lvl = std::max(m_conflict_lvl, s().lvl(lit)); + SASSERT(value(lit) == l_false); + m_conflict_lvl = std::max(m_conflict_lvl, lvl(lit)); } - if (m_conflict_lvl < s().lvl(c.lit()) || m_conflict_lvl == 0) { + if (m_conflict_lvl < lvl(c.lit()) || m_conflict_lvl == 0) { return false; } @@ -251,14 +252,14 @@ namespace sat { literal_vector const& lits = s().m_trail; unsigned idx = lits.size()-1; justification js; - literal consequent = ~conflict_clause[1]; + literal consequent = ~alit; process_card(c, 1); - literal alit; + + DEBUG_CODE(active2pb(m_A);); while (m_num_marks > 0) { - + SASSERT(value(consequent) == l_true); v = consequent.var(); - int offset = get_abs_coeff(v); if (offset == 0) { @@ -268,9 +269,11 @@ namespace sat { goto bail_out; } + SASSERT(validate_lemma()); SASSERT(offset > 0); js = s().m_justification[v]; + DEBUG_CODE(justification2pb(js, consequent, offset, m_B);); int bound = 1; switch(js.get_kind()) { @@ -289,15 +292,13 @@ namespace sat { inc_coeff(consequent, offset); clause & c = *(s().m_cls_allocator.get_clause(js.get_clause_offset())); unsigned i = 0; - if (consequent != null_literal) { - SASSERT(c[0] == consequent || c[1] == consequent); - if (c[0] == consequent) { - i = 1; - } - else { - process_antecedent(~c[0], offset); - i = 2; - } + SASSERT(c[0] == consequent || c[1] == consequent); + if (c[0] == consequent) { + i = 1; + } + else { + process_antecedent(~c[0], offset); + i = 2; } unsigned sz = c.size(); for (; i < sz; i++) @@ -316,6 +317,11 @@ namespace sat { break; } m_bound += offset * bound; + + DEBUG_CODE( + active2pb(m_C); + SASSERT(validate_resolvent()); + m_A = m_C;); // cut(); @@ -331,7 +337,7 @@ namespace sat { --idx; } - SASSERT(s().lvl(v) == m_conflict_lvl); + SASSERT(lvl(v) == m_conflict_lvl); s().reset_mark(v); --idx; --m_num_marks; @@ -390,24 +396,24 @@ namespace sat { void card_extension::process_card(card& c, int offset) { SASSERT(c.k() <= c.size()); - SASSERT(s().value(c.lit()) == l_true); + SASSERT(value(c.lit()) == l_true); for (unsigned i = c.k(); i < c.size(); ++i) { process_antecedent(c[i], offset); } for (unsigned i = 0; i < c.k(); ++i) { inc_coeff(c[i], offset); } - if (s().lvl(c.lit()) > 0) { + if (lvl(c.lit()) > 0) { m_conflict.push_back(~c.lit()); } } void card_extension::process_antecedent(literal l, int offset) { - SASSERT(s().value(l) == l_false); + SASSERT(value(l) == l_false); bool_var v = l.var(); - unsigned lvl = s().lvl(v); + unsigned level = lvl(v); - if (lvl > 0 && !s().is_marked(v) && lvl == m_conflict_lvl) { + if (level > 0 && !s().is_marked(v) && level == m_conflict_lvl) { s().mark(v); ++m_num_marks; } @@ -418,12 +424,13 @@ namespace sat { if (get_abs_coeff(p.var()) != 0) { return p; } - unsigned lvl = 0; + unsigned level = 0; for (unsigned i = 0; i < m_active_vars.size(); ++i) { bool_var v = m_active_vars[i]; literal lit(v, get_coeff(v) < 0); - if (s().value(lit) == l_false && s().lvl(lit) > lvl) { + if (value(lit) == l_false && lvl(lit) > level) { p = lit; + level = lvl(lit); } } return p; @@ -463,21 +470,24 @@ namespace sat { } SASSERT(found);); + r.push_back(c.lit()); + SASSERT(value(c.lit()) == l_true); for (unsigned i = c.k(); i < c.size(); ++i) { - SASSERT(s().value(c[i]) == l_false); - r.push_back(c[i]); + SASSERT(value(c[i]) == l_false); + r.push_back(~c[i]); } } + lbool card_extension::add_assign(card& c, literal alit) { // literal is assigned to false. unsigned sz = c.size(); unsigned bound = c.k(); - TRACE("pb", tout << "assign: " << c.lit() << " " << ~alit << " " << bound << "\n";); + TRACE("sat", tout << "assign: " << c.lit() << " " << ~alit << " " << bound << "\n";); SASSERT(0 < bound && bound < sz); - SASSERT(s().value(alit) == l_false); - SASSERT(s().value(c.lit()) == l_true); + SASSERT(value(alit) == l_false); + SASSERT(value(c.lit()) == l_true); unsigned index = 0; for (index = 0; index <= bound; ++index) { if (c[index] == alit) { @@ -494,7 +504,7 @@ namespace sat { // find a literal to swap with: for (unsigned i = bound + 1; i < sz; ++i) { literal lit2 = c[i]; - if (s().value(lit2) != l_false) { + if (value(lit2) != l_false) { c.swap(index, i); watch_literal(c, lit2); return l_undef; @@ -502,13 +512,13 @@ namespace sat { } // conflict - if (bound != index && s().value(c[bound]) == l_false) { + if (bound != index && value(c[bound]) == l_false) { TRACE("sat", tout << "conflict " << c[bound] << " " << alit << "\n";); set_conflict(c, alit); return l_false; } - TRACE("pb", tout << "no swap " << index << " " << alit << "\n";); + TRACE("sat", tout << "no swap " << index << " " << alit << "\n";); // there are no literals to swap with, // prepare for unit propagation by swapping the false literal into // position bound. Then literals in positions 0..bound-1 have to be @@ -532,7 +542,7 @@ namespace sat { ptr_vector::iterator it = cards->begin(), it2 = it, end = cards->end(); for (; it != end; ++it) { card& c = *(*it); - if (s().value(c.lit()) != l_true) { + if (value(c.lit()) != l_true) { continue; } switch (add_assign(c, l)) { @@ -604,9 +614,9 @@ namespace sat { out << c.lit(); if (c.lit() != null_literal) { if (values) { - out << "@(" << s().value(c.lit()); - if (s().value(c.lit()) != l_undef) { - out << ":" << s().lvl(c.lit()); + out << "@(" << value(c.lit()); + if (value(c.lit()) != l_undef) { + out << ":" << lvl(c.lit()); } out << ")"; } @@ -619,9 +629,9 @@ namespace sat { literal l = c[i]; out << l; if (values) { - out << "@(" << s().value(l); - if (s().value(l) != l_undef) { - out << ":" << s().lvl(l); + out << "@(" << value(l); + if (value(l) != l_undef) { + out << ":" << lvl(l); } out << ") "; } @@ -651,35 +661,142 @@ namespace sat { bool card_extension::validate_conflict(card& c) { if (!validate_unit_propagation(c)) return false; for (unsigned i = 0; i < c.k(); ++i) { - if (s().value(c[i]) == l_false) return true; + if (value(c[i]) == l_false) return true; } return false; } bool card_extension::validate_unit_propagation(card const& c) { + if (value(c.lit()) != l_true) return false; for (unsigned i = c.k(); i < c.size(); ++i) { - if (s().value(c[i]) != l_false) return false; + if (value(c[i]) != l_false) return false; } return true; } bool card_extension::validate_lemma() { - int value = -m_bound; + int val = -m_bound; normalize_active_coeffs(); for (unsigned i = 0; i < m_active_vars.size(); ++i) { bool_var v = m_active_vars[i]; int coeff = get_coeff(v); + literal lit(v, false); SASSERT(coeff != 0); - if (coeff < 0 && s().value(v) != l_true) { - value -= coeff; + if (coeff < 0 && value(lit) != l_true) { + val -= coeff; } - else if (coeff > 0 && s().value(v) != l_false) { - value += coeff; + else if (coeff > 0 && value(lit) != l_false) { + val += coeff; } } - return value < 0; + return val < 0; } - bool card_extension::validate_assign(literal_vector const& lits, literal lit) { return true; } - bool card_extension::validate_conflict(literal_vector const& lits) { return true; } + void card_extension::active2pb(ineq& p) { + normalize_active_coeffs(); + p.reset(m_bound); + for (unsigned i = 0; i < m_active_vars.size(); ++i) { + bool_var v = m_active_vars[i]; + literal lit(v, get_coeff(v) < 0); + p.m_lits.push_back(lit); + p.m_coeffs.push_back(get_abs_coeff(v)); + } + } + + void card_extension::justification2pb(justification const& js, literal lit, unsigned offset, ineq& p) { + switch (js.get_kind()) { + case justification::NONE: + p.reset(0); + break; + case justification::BINARY: + p.reset(offset); + p.push(lit, offset); + p.push(~js.get_literal(), offset); + break; + case justification::TERNARY: + p.reset(offset); + p.push(lit, offset); + p.push(~(js.get_literal1()), offset); + p.push(~(js.get_literal2()), offset); + break; + case justification::CLAUSE: { + p.reset(offset); + clause & c = *(s().m_cls_allocator.get_clause(js.get_clause_offset())); + unsigned sz = c.size(); + for (unsigned i = 0; i < sz; i++) + p.push(~c[i], offset); + break; + } + case justification::EXT_JUSTIFICATION: { + unsigned index = js.get_ext_justification_idx(); + card& c = *m_constraints[index]; + p.reset(offset*c.k()); + for (unsigned i = 0; i < c.size(); ++i) { + p.push(c[i], offset); + } + break; + } + default: + UNREACHABLE(); + break; + } + } + + + // validate that m_A & m_B implies m_C + + bool card_extension::validate_resolvent() { + u_map coeffs; + unsigned k = m_A.m_k + m_B.m_k; + for (unsigned i = 0; i < m_A.m_lits.size(); ++i) { + unsigned coeff = m_A.m_coeffs[i]; + SASSERT(!coeffs.contains(m_A.m_lits[i].index())); + coeffs.insert(m_A.m_lits[i].index(), coeff); + } + for (unsigned i = 0; i < m_B.m_lits.size(); ++i) { + unsigned coeff1 = m_B.m_coeffs[i], coeff2; + literal lit = m_B.m_lits[i]; + if (coeffs.find((~lit).index(), coeff2)) { + if (coeff1 == coeff2) { + coeffs.remove((~lit).index()); + k += coeff1; + } + else if (coeff1 < coeff2) { + coeffs.insert((~lit).index(), coeff2 - coeff1); + k += coeff1; + } + else { + SASSERT(coeff2 < coeff1); + coeffs.remove((~lit).index()); + coeffs.insert(lit.index(), coeff1 - coeff2); + k += coeff2; + } + } + else if (coeffs.find(lit.index(), coeff2)) { + coeffs.insert(lit.index(), coeff1 + coeff2); + } + else { + coeffs.insert(lit.index(), coeff1); + } + } + // C is above the sum of A and B + for (unsigned i = 0; i < m_C.m_lits.size(); ++i) { + literal lit = m_C.m_lits[i]; + unsigned coeff; + if (coeffs.find(lit.index(), coeff)) { + SASSERT(coeff <= m_C.m_coeffs[i]); + coeffs.remove(lit.index()); + } + } + SASSERT(coeffs.empty()); + SASSERT(m_C.m_k <= k); + return true; + } + + bool card_extension::validate_conflict(literal_vector const& lits) { + for (unsigned i = 0; i < lits.size(); ++i) { + if (value(lits[i]) != l_false) return false; + } + return true; + } }; diff --git a/src/sat/card_extension.h b/src/sat/card_extension.h index db8e041c2..1593ef26f 100644 --- a/src/sat/card_extension.h +++ b/src/sat/card_extension.h @@ -23,7 +23,7 @@ Revision History: #include"sat_solver.h" namespace sat { - + class card_extension : public extension { struct stats { unsigned m_num_propagations; @@ -31,7 +31,7 @@ namespace sat { stats() { reset(); } void reset() { memset(this, 0, sizeof(*this)); } }; - + class card { unsigned m_index; literal m_lit; @@ -49,6 +49,14 @@ namespace sat { void negate(); }; + struct ineq { + literal_vector m_lits; + unsigned_vector m_coeffs; + unsigned m_k; + void reset(unsigned k) { m_lits.reset(); m_coeffs.reset(); m_k = k; } + void push(literal l, unsigned c) { m_lits.push_back(l); m_coeffs.push_back(c); } + }; + typedef ptr_vector watch; struct var_info { watch* m_lit_watch[2]; @@ -95,6 +103,10 @@ namespace sat { void clear_watch(card& c); void reset_coeffs(); + inline lbool value(literal lit) const { return m_solver->value(lit); } + inline unsigned lvl(literal lit) const { return m_solver->lvl(lit); } + inline unsigned lvl(bool_var v) const { return m_solver->lvl(v); } + void unwatch_literal(literal w, card* c); void remove(ptr_vector& cards, card* c); @@ -105,7 +117,7 @@ namespace sat { literal_vector& get_literals() { m_literals.reset(); return m_literals; } literal get_asserting_literal(literal conseq); - bool resolve_conflict(card& c, literal_vector const& conflict_clause); + bool resolve_conflict(card& c, literal alit); void process_antecedent(literal l, int offset); void process_card(card& c, int offset); void cut(); @@ -117,6 +129,11 @@ namespace sat { bool validate_unit_propagation(card const& c); bool validate_conflict(literal_vector const& lits); + ineq m_A, m_B, m_C; + void active2pb(ineq& p); + void justification2pb(justification const& j, literal lit, unsigned offset, ineq& p); + bool validate_resolvent(); + void display(std::ostream& out, card& c, bool values) const; void display_watch(std::ostream& out, bool_var v, bool sign) const; public: