From bcf0ee77096fd70d3cd60da2097c2562ca6bf249 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 29 Jun 2017 18:53:58 -0700 Subject: [PATCH] n/a Signed-off-by: Nikolaj Bjorner --- src/sat/ba_solver.cpp | 255 +++++++++++++++++++++++++++++++++------- src/sat/ba_solver.h | 29 ++++- src/sat/sat_extension.h | 1 + src/sat/sat_lookahead.h | 1 + src/sat/sat_solver.cpp | 21 ++-- src/sat/sat_watched.cpp | 3 +- src/sat/sat_watched.h | 2 +- 7 files changed, 252 insertions(+), 60 deletions(-) diff --git a/src/sat/ba_solver.cpp b/src/sat/ba_solver.cpp index 7d9e16df0..22c77f164 100644 --- a/src/sat/ba_solver.cpp +++ b/src/sat/ba_solver.cpp @@ -69,6 +69,14 @@ namespace sat { SASSERT(m_size >= m_k && m_k > 0); } + bool ba_solver::card::is_watching(literal l) const { + unsigned sz = std::min(k() + 1, size()); + for (unsigned i = 0; i < sz; ++i) { + if ((*this)[i] == l) return true; + } + return false; + } + std::ostream& operator<<(std::ostream& out, ba_solver::constraint const& cnstr) { if (cnstr.lit() != null_literal) out << cnstr.lit() << " == "; switch (cnstr.tag()) { @@ -135,6 +143,14 @@ namespace sat { m_k = w - m_k + 1; SASSERT(w >= m_k && m_k > 0); } + + bool ba_solver::pb::is_watching(literal l) const { + for (unsigned i = 0; i < m_num_watch; ++i) { + if ((*this)[i].second == l) return true; + } + return false; + } + ba_solver::xor::xor(literal lit, literal_vector const& lits): constraint(xor_t, lit, lits.size(), get_obj_size(lits.size())) { @@ -143,6 +159,12 @@ namespace sat { } } + bool ba_solver::xor::is_watching(literal l) const { + return + l == (*this)[0] || l == (*this)[1] || + ~l == (*this)[0] || ~l == (*this)[1]; + } + void ba_solver::init_watch(card& c, bool is_true) { clear_watch(c); if (c.lit() != null_literal && c.lit().sign() == is_true) { @@ -211,14 +233,6 @@ namespace sat { } } - void ba_solver::unwatch_literal(literal lit, constraint& c) { - get_wlist(~lit).erase(watched(c.index())); - } - - void ba_solver::watch_literal(literal lit, constraint& c) { - get_wlist(~lit).push_back(watched(c.index())); - } - void ba_solver::set_conflict(constraint& c, literal lit) { m_stats.m_num_conflicts++; TRACE("sat", display(tout, c, true); ); @@ -510,7 +524,6 @@ namespace sat { } void ba_solver::simplify(pb& p) { - s().pop_to_base_level(); if (p.lit() != null_literal && value(p.lit()) == l_false) { TRACE("sat", tout << "pb: flip sign " << p << "\n";); return; @@ -629,6 +642,8 @@ namespace sat { void ba_solver::clear_watch(xor& x) { unwatch_literal(x[0], x); unwatch_literal(x[1], x); + unwatch_literal(~x[0], x); + unwatch_literal(~x[1], x); } bool ba_solver::parity(xor const& x, unsigned offset) const { @@ -678,6 +693,8 @@ namespace sat { SASSERT(j == 2); watch_literal(x[0], x); watch_literal(x[1], x); + watch_literal(~x[0], x); + watch_literal(~x[1], x); break; } } @@ -1157,6 +1174,14 @@ namespace sat { s().set_external(lit.var()); get_wlist(lit).push_back(c->index()); get_wlist(~lit).push_back(c->index()); + if (!validate_watched_constraint(*c)) { + std::cout << "wrong: " << *c << "\n"; + } + } + if (lit.var() == 102770) { + display(std::cout, *c, true); + display_watch_list(std::cout, s().m_cls_allocator, get_wlist(lit)) << "\n"; + display_watch_list(std::cout, s().m_cls_allocator, get_wlist(~lit)) << "\n"; } } @@ -1208,6 +1233,7 @@ namespace sat { if (c.lit() != null_literal && l.var() == c.lit().var()) { init_watch(c, !l.sign()); keep = true; + if (!inconsistent()) validate_watched_constraint(c); } else if (c.lit() != null_literal && value(c.lit()) != l_true) { keep = false; @@ -1215,6 +1241,7 @@ namespace sat { else { keep = l_undef != add_assign(c, ~l); } + std::cout << c.lit() << " " << l << " " << keep << "\n"; } @@ -1444,10 +1471,21 @@ namespace sat { } } + // ---------------------------- + // constraint generic methods + void ba_solver::get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) { get_antecedents(l, index2constraint(idx), r); } + void ba_solver::unwatch_literal(literal lit, constraint& c) { + get_wlist(~lit).erase(watched(c.index())); + } + + void ba_solver::watch_literal(literal lit, constraint& c) { + get_wlist(~lit).push_back(watched(c.index())); + } + void ba_solver::get_antecedents(literal l, constraint const& c, literal_vector& r) { switch (c.tag()) { case card_t: get_antecedents(l, c.to_card(), r); break; @@ -1457,7 +1495,38 @@ namespace sat { } } + void ba_solver::nullify_tracking_literal(constraint& c) { + if (c.lit() != null_literal) { + get_wlist(c.lit()).erase(watched(c.index())); + get_wlist(~c.lit()).erase(watched(c.index())); + c.nullify_literal(); + } + } + + void ba_solver::remove_constraint(constraint& c) { + nullify_tracking_literal(c); + switch (c.tag()) { + case card_t: + clear_watch(c.to_card()); + break; + case pb_t: + clear_watch(c.to_pb()); + break; + case xor_t: + clear_watch(c.to_xor()); + break; + default: + UNREACHABLE(); + } + c.remove(); + m_constraint_removed = true; + } + + // -------------------------------- + // validation + bool ba_solver::validate_unit_propagation(constraint const& c, literal l) const { + return true; switch (c.tag()) { case card_t: return validate_unit_propagation(c.to_card(), l); case pb_t: return validate_unit_propagation(c.to_pb(), l); @@ -1468,6 +1537,7 @@ namespace sat { } bool ba_solver::validate_conflict(constraint const& c) const { + return true; switch (c.tag()) { case card_t: return validate_conflict(c.to_card()); case pb_t: return validate_conflict(c.to_pb()); @@ -1477,6 +1547,133 @@ namespace sat { return false; } + bool ba_solver::is_true(constraint const& c) const { + lbool v1 = c.lit() == null_literal ? l_true : value(c.lit()); + if (v1 == l_undef) return false; + switch (c.tag()) { + case card_t: return v1 == value(c.to_card()); + case pb_t: return v1 == value(c.to_pb()); + case xor_t: return v1 == value(c.to_xor()); + default: UNREACHABLE(); break; + } + return false; + } + + lbool ba_solver::value(card const& c) const { + unsigned trues = 0, undefs = 0; + for (literal l : c) { + switch (value(l)) { + case l_true: trues++; break; + case l_undef: undefs++; break; + default: break; + } + } + if (trues + undefs < c.k()) return l_false; + if (trues >= c.k()) return l_true; + return l_undef; + } + + lbool ba_solver::value(pb const& p) const { + unsigned trues = 0, undefs = 0; + for (wliteral wl : p) { + switch (value(wl.second)) { + case l_true: trues += wl.first; break; + case l_undef: undefs += wl.first; break; + default: break; + } + } + if (trues + undefs < p.k()) return l_false; + if (trues >= p.k()) return l_true; + return l_undef; + } + + lbool ba_solver::value(xor const& x) const { + bool odd = false; + + for (auto l : x) { + switch (value(l)) { + case l_true: odd = !odd; break; + case l_false: break; + default: return l_undef; + } + } + return odd ? l_true : l_false; + } + + void ba_solver::validate() { + if (validate_watch_literals()) { + for (constraint* c : m_constraints) { + if (!validate_watched_constraint(*c)) break; + } + } + } + + bool ba_solver::validate_watch_literals() const { + for (unsigned v = 0; v < s().num_vars(); ++v) { + literal lit(v, false); + if (lvl(lit) == 0) continue; + if (!validate_watch_literal(lit)) return false; + if (!validate_watch_literal(~lit)) return false; + } + return true; + } + + bool ba_solver::validate_watch_literal(literal lit) const { + if (lvl(lit) == 0) return true; + for (auto const & w : get_wlist(lit)) { + if (w.get_kind() == watched::EXT_CONSTRAINT) { + constraint const& c = index2constraint(w.get_ext_constraint_idx()); + if (!c.is_watching(~lit)) { + std::cout << lit << " " << lvl(lit) << " is not watched in " << c << "\n"; + display(std::cout, c, true); + UNREACHABLE(); + return false; + } + } + } + return true; + } + + bool ba_solver::validate_watched_constraint(constraint const& c) const { + if (c.lit() != null_literal && value(c.lit()) != l_true) return true; + if (c.lit() != null_literal && lvl(c.lit()) != 0) { + if (!is_watching(c.lit(), c) || !is_watching(~c.lit(), c)) { + std::cout << "Definition literal is not watched " << c.lit() << " " << c << "\n"; + display_watch_list(std::cout, s().m_cls_allocator, get_wlist(c.lit())) << "\n"; + display_watch_list(std::cout, s().m_cls_allocator, get_wlist(~c.lit())) << "\n"; + return false; + } + } + if (is_true(c)) { + return true; + } + literal_vector lits(c.literals()); + for (literal l : lits) { + if (lvl(l) == 0) continue; + bool found = is_watching(l, c); + if (found != c.is_watching(l)) { + std::cout << "Discrepancy of watched literal: " << l << ": " << c.index() << " " << c << (found?" is watched, but shouldn't be":" not watched, but should be") << "\n"; + display_watch_list(std::cout << l << ": ", s().m_cls_allocator, get_wlist(l)) << "\n"; + display_watch_list(std::cout << ~l << ": ", s().m_cls_allocator, get_wlist(~l)) << "\n"; + std::cout << "value: " << value(l) << " level: " << lvl(l) << "\n"; + display(std::cout, c, true); + if (c.lit() != null_literal) std::cout << value(c.lit()) << "\n"; + UNREACHABLE(); + exit(1); + return false; + } + } + return true; + } + + bool ba_solver::is_watching(literal lit, constraint const& c) const { + for (auto w : get_wlist(~lit)) { + if (w.get_kind() == watched::EXT_CONSTRAINT && w.get_ext_constraint_idx() == c.index()) + return true; + } + return false; + } + /** \brief Lex on (glue, size) */ @@ -1507,33 +1704,6 @@ namespace sat { } - void ba_solver::nullify_tracking_literal(constraint& c) { - if (c.lit() != null_literal) { - get_wlist(c.lit()).erase(watched(c.index())); - get_wlist(~c.lit()).erase(watched(c.index())); - c.nullify_literal(); - } - } - - void ba_solver::remove_constraint(constraint& c) { - nullify_tracking_literal(c); - switch (c.tag()) { - case card_t: - clear_watch(c.to_card()); - break; - case pb_t: - clear_watch(c.to_pb()); - break; - case xor_t: - clear_watch(c.to_xor()); - break; - default: - UNREACHABLE(); - } - c.remove(); - m_constraint_removed = true; - } - void ba_solver::simplify(card& c) { SASSERT(c.lit() == null_literal || value(c.lit()) != l_false); if (c.lit() != null_literal && value(c.lit()) == l_false) { @@ -1722,7 +1892,7 @@ namespace sat { void ba_solver::simplify() { return; - if (!s().at_base_lvl()) s().pop_to_base_level(); + SASSERT(s().at_base_lvl()); unsigned trail_sz; do { m_simplify_change = false; @@ -1760,6 +1930,10 @@ namespace sat { void ba_solver::flush_roots() { if (m_roots.empty()) return; + + std::cout << "pre\n"; + validate(); + m_visited.resize(s().num_vars()*2, false); m_constraint_removed = false; for (constraint* c : m_constraints) { @@ -1778,6 +1952,8 @@ namespace sat { } } cleanup_constraints(); + std::cout << "post\n"; + validate(); // display(std::cout << "flush roots\n"); } @@ -1907,7 +2083,7 @@ namespace sat { } void ba_solver::recompile(pb& p) { - // IF_VERBOSE(0, verbose_stream() << "re: " << p << "\n";); + IF_VERBOSE(0, verbose_stream() << "re: " << p << "\n";); m_weights.resize(2*s().num_vars(), 0); for (wliteral wl : p) { m_weights[wl.second.index()] += wl.first; @@ -1965,11 +2141,10 @@ namespace sat { p.update_k(k); p.update_max_sum(); - literal root = null_literal; if (p.lit() != null_literal) root = m_roots[p.lit().index()]; - // IF_VERBOSE(0, verbose_stream() << "new: " << p << "\n";); + IF_VERBOSE(0, verbose_stream() << "new: " << p << "\n";); // std::cout << "simplified " << p << "\n"; if (p.lit() != root) { diff --git a/src/sat/ba_solver.h b/src/sat/ba_solver.h index ca6b0d093..2c2a5d914 100644 --- a/src/sat/ba_solver.h +++ b/src/sat/ba_solver.h @@ -74,6 +74,7 @@ namespace sat { void nullify_literal() { m_lit = null_literal; } unsigned glue() const { return m_glue; } void set_glue(unsigned g) { m_glue = g; } + size_t obj_size() const { return m_obj_size; } card& to_card(); @@ -85,6 +86,9 @@ namespace sat { bool is_card() const { return m_tag == card_t; } bool is_pb() const { return m_tag == pb_t; } bool is_xor() const { return m_tag == xor_t; } + + virtual bool is_watching(literal l) const { return false; }; + virtual literal_vector literals() const { return literal_vector(); } }; friend std::ostream& operator<<(std::ostream& out, constraint const& c); @@ -103,7 +107,8 @@ namespace sat { void swap(unsigned i, unsigned j) { std::swap(m_lits[i], m_lits[j]); } void negate(); void update_k(unsigned k) { m_k = k; } - literal_vector literals() const { return literal_vector(m_size, m_lits); } + virtual literal_vector literals() const { return literal_vector(m_size, m_lits); } + virtual bool is_watching(literal l) const; }; @@ -122,7 +127,7 @@ namespace sat { wliteral operator[](unsigned i) const { return m_wlits[i]; } wliteral& operator[](unsigned i) { return m_wlits[i]; } wliteral const* begin() const { return m_wlits; } - wliteral const* end() const { return static_cast(m_wlits) + m_size; } + wliteral const* end() const { return begin() + m_size; } unsigned k() const { return m_k; } unsigned slack() const { return m_slack; } @@ -134,7 +139,8 @@ namespace sat { void negate(); void update_k(unsigned k) { m_k = k; } void update_max_sum(); - literal_vector literals() const { literal_vector lits; for (auto wl : *this) lits.push_back(wl.second); return lits; } + virtual literal_vector literals() const { literal_vector lits; for (auto wl : *this) lits.push_back(wl.second); return lits; } + virtual bool is_watching(literal l) const; }; class xor : public constraint { @@ -144,9 +150,11 @@ namespace sat { xor(literal lit, literal_vector const& lits); literal operator[](unsigned i) const { return m_lits[i]; } literal const* begin() const { return m_lits; } - literal const* end() const { return static_cast(m_lits) + m_size; } + literal const* end() const { return begin() + m_size; } void swap(unsigned i, unsigned j) { std::swap(m_lits[i], m_lits[j]); } void negate() { m_lits[0].neg(); } + virtual bool is_watching(literal l) const; + virtual literal_vector literals() const { return literal_vector(size(), begin()); } }; @@ -244,6 +252,9 @@ namespace sat { void get_antecedents(literal l, constraint const& c, literal_vector & r); bool validate_conflict(constraint const& c) const; bool validate_unit_propagation(constraint const& c, literal alit) const; + void attach_constraint(constraint const& c); + void detach_constraint(constraint const& c); + bool is_true(constraint const& c) const; // cardinality void init_watch(card& c, bool is_true); @@ -256,6 +267,7 @@ namespace sat { void unit_propagation_simplification(literal lit, literal_vector const& lits); void flush_roots(card& c); void recompile(card& c); + lbool value(card const& c) const; // xor specific functionality void clear_watch(xor& x); @@ -266,6 +278,7 @@ namespace sat { void get_antecedents(literal l, xor const& x, literal_vector & r); void simplify(xor& x); void flush_roots(xor& x); + lbool value(xor const& x) const; // pb functionality unsigned m_a_max; @@ -279,6 +292,7 @@ namespace sat { bool is_cardinality(pb const& p); void flush_roots(pb& p); void recompile(pb& p); + lbool value(pb const& p) const; // access solver inline lbool value(literal lit) const { return m_lookahead ? m_lookahead->value(lit) : m_solver->value(lit); } @@ -286,6 +300,7 @@ namespace sat { inline unsigned lvl(bool_var v) const { return m_solver->lvl(v); } inline bool inconsistent() const { return m_lookahead ? m_lookahead->inconsistent() : m_solver->inconsistent(); } inline watch_list& get_wlist(literal l) { return m_lookahead ? m_lookahead->get_wlist(l) : m_solver->get_wlist(l); } + inline watch_list const& get_wlist(literal l) const { return m_lookahead ? m_lookahead->get_wlist(l) : m_solver->get_wlist(l); } inline void assign(literal l, justification j) { if (m_lookahead) m_lookahead->assign(l); else m_solver->assign(l, j); } inline void set_conflict(justification j, literal l) { if (m_lookahead) m_lookahead->set_conflict(); else m_solver->set_conflict(j, l); } inline config const& get_config() const { return m_solver->get_config(); } @@ -312,6 +327,10 @@ namespace sat { bool validate_unit_propagation(pb const& p, literal alit) const; bool validate_unit_propagation(xor const& x, literal alit) const; bool validate_conflict(literal_vector const& lits, ineq& p); + bool validate_watch_literals() const; + bool validate_watch_literal(literal lit) const; + bool validate_watched_constraint(constraint const& c) const; + bool is_watching(literal lit, constraint const& c) const; ineq m_A, m_B, m_C; void active2pb(ineq& p); @@ -358,6 +377,8 @@ namespace sat { ptr_vector const & constraints() const { return m_constraints; } + virtual void validate(); + }; }; diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index 642171610..f984091e3 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -53,6 +53,7 @@ namespace sat { virtual extension* copy(solver* s) = 0; virtual void find_mutexes(literal_vector& lits, vector & mutexes) = 0; virtual void gc() = 0; + virtual void validate() = 0; }; }; diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index 43c2612b8..067a95c55 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -372,6 +372,7 @@ namespace sat { void attach_ternary(ternary const& t); void attach_ternary(literal l1, literal l2, literal l3); watch_list& get_wlist(literal l) { return m_watches[l.index()]; } + watch_list const& get_wlist(literal l) const { return m_watches[l.index()]; } // ------------------------------------ // initialization diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 7a6955ef8..dfa85569c 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -1391,6 +1391,7 @@ namespace sat { */ void solver::simplify_problem() { + if (m_ext) m_ext->validate(); if (m_conflicts_since_init < m_next_simplify) { return; } @@ -1403,12 +1404,15 @@ namespace sat { SASSERT(at_base_lvl()); m_cleaner(); + if (m_ext) m_ext->validate(); CASSERT("sat_simplify_bug", check_invariant()); m_scc(); + if (m_ext) m_ext->validate(); CASSERT("sat_simplify_bug", check_invariant()); m_simplifier(false); + if (m_ext) m_ext->validate(); CASSERT("sat_simplify_bug", check_invariant()); CASSERT("sat_missed_prop", check_missed_propagation()); @@ -1417,6 +1421,7 @@ namespace sat { CASSERT("sat_missed_prop", check_missed_propagation()); CASSERT("sat_simplify_bug", check_invariant()); } + if (m_ext) m_ext->validate(); if (m_config.m_lookahead_simplify) { { @@ -1435,10 +1440,12 @@ namespace sat { CASSERT("sat_simplify_bug", check_invariant()); m_probing(); + if (m_ext) m_ext->validate(); CASSERT("sat_missed_prop", check_missed_propagation()); CASSERT("sat_simplify_bug", check_invariant()); m_asymm_branch(); + if (m_ext) m_ext->validate(); CASSERT("sat_missed_prop", check_missed_propagation()); CASSERT("sat_simplify_bug", check_invariant()); @@ -1460,21 +1467,7 @@ namespace sat { m_next_simplify = m_conflicts_since_init + m_config.m_simplify_max; } -#if 0 - static unsigned file_no = 0; - #pragma omp critical (print_sat) - { - ++file_no; - std::ostringstream ostrm; - ostrm << "s" << file_no << ".txt"; - std::ofstream ous(ostrm.str()); - display(ous); - } -#endif - if (m_par) m_par->set_phase(*this); - - } bool solver::set_root(literal l, literal r) { diff --git a/src/sat/sat_watched.cpp b/src/sat/sat_watched.cpp index 1b294351f..d1642d50f 100644 --- a/src/sat/sat_watched.cpp +++ b/src/sat/sat_watched.cpp @@ -39,7 +39,7 @@ namespace sat { return false; } - void display_watch_list(std::ostream & out, clause_allocator const & ca, watch_list const & wlist) { + std::ostream& display_watch_list(std::ostream & out, clause_allocator const & ca, watch_list const & wlist) { watch_list::const_iterator it = wlist.begin(); watch_list::const_iterator end = wlist.end(); for (bool first = true; it != end; ++it) { @@ -66,6 +66,7 @@ namespace sat { UNREACHABLE(); } } + return out; } }; diff --git a/src/sat/sat_watched.h b/src/sat/sat_watched.h index b9a0962e9..fa7008818 100644 --- a/src/sat/sat_watched.h +++ b/src/sat/sat_watched.h @@ -130,7 +130,7 @@ namespace sat { inline void erase_ternary_watch(watch_list & wlist, literal l1, literal l2) { wlist.erase(watched(l1, l2)); } class clause_allocator; - void display_watch_list(std::ostream & out, clause_allocator const & ca, watch_list const & wlist); + std::ostream& display_watch_list(std::ostream & out, clause_allocator const & ca, watch_list const & wlist); }; #endif