diff --git a/src/sat/sat_clause.h b/src/sat/sat_clause.h index 7bacf0777..1477cc6e3 100644 --- a/src/sat/sat_clause.h +++ b/src/sat/sat_clause.h @@ -65,8 +65,7 @@ namespace sat { literal & operator[](unsigned idx) { SASSERT(idx < m_size); return m_lits[idx]; } literal const & operator[](unsigned idx) const { SASSERT(idx < m_size); return m_lits[idx]; } bool is_learned() const { return m_learned; } - void set_learned() { SASSERT(!is_learned()); m_learned = true; } - void unset_learned() { SASSERT(is_learned()); m_learned = false; } + void set_learned(bool l) { SASSERT(is_learned() != l); m_learned = l; } void shrink(unsigned num_lits) { SASSERT(num_lits <= m_size); if (num_lits < m_size) { m_size = num_lits; mark_strengthened(); } } bool strengthened() const { return m_strengthened; } void mark_strengthened() { m_strengthened = true; update_approx(); } diff --git a/src/sat/sat_simplifier.cpp b/src/sat/sat_simplifier.cpp index 106cc791e..577badbc8 100644 --- a/src/sat/sat_simplifier.cpp +++ b/src/sat/sat_simplifier.cpp @@ -151,7 +151,7 @@ namespace sat { inline void simplifier::block_clause(clause & c) { if (m_retain_blocked_clauses) { m_need_cleanup = true; - c.set_learned(); + s.set_learned(c, true); m_use_list.block(c); } else { @@ -161,7 +161,7 @@ namespace sat { } inline void simplifier::unblock_clause(clause & c) { - c.unset_learned(); + s.set_learned(c, false); m_use_list.unblock(c); } @@ -499,7 +499,7 @@ namespace sat { if (!c2.was_removed() && *l_it == null_literal) { // c2 was subsumed if (c1.is_learned() && !c2.is_learned()) - c1.unset_learned(); + s.set_learned(c1, false); TRACE("subsumption", tout << c1 << " subsumed " << c2 << "\n";); remove_clause(c2); m_num_subsumed++; @@ -599,7 +599,7 @@ namespace sat { clause & c2 = *cp; // c2 was subsumed if (c1.is_learned() && !c2.is_learned()) - c1.unset_learned(); + s.set_learned(c1, false); TRACE("subsumption", tout << c1 << " subsumed " << c2 << "\n";); remove_clause(c2); m_num_subsumed++; @@ -759,7 +759,7 @@ namespace sat { SASSERT(wlist[j] == w); TRACE("set_not_learned_bug", tout << "marking as not learned: " << l2 << " " << wlist[j].is_learned() << "\n";); - wlist[j].set_not_learned(); + wlist[j].set_learned(false); mark_as_not_learned_core(get_wlist(~l2), l); } if (s.inconsistent()) @@ -776,7 +776,7 @@ namespace sat { void simplifier::mark_as_not_learned_core(watch_list & wlist, literal l2) { for (watched & w : wlist) { if (w.is_binary_clause() && w.get_literal() == l2 && w.is_learned()) { - w.set_not_learned(); + w.set_learned(false); return; } } @@ -1793,7 +1793,7 @@ namespace sat { w = find_binary_watch(wlist1, l2); if (w) { if (w->is_learned()) - w->set_not_learned(); + w->set_learned(false); } else { wlist1.push_back(watched(l2, false)); @@ -1802,7 +1802,7 @@ namespace sat { w = find_binary_watch(wlist2, l1); if (w) { if (w->is_learned()) - w->set_not_learned(); + w->set_learned(false); } else { wlist2.push_back(watched(l1, false)); diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 0c2d34dac..d21e773f3 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -336,13 +336,13 @@ namespace sat { watched* w0 = find_binary_watch(get_wlist(~l1), l2); if (w0) { if (w0->is_learned() && !learned) { - w0->set_not_learned(); + w0->set_learned(false); } w0 = find_binary_watch(get_wlist(~l2), l1); } if (w0) { if (w0->is_learned() && !learned) { - w0->set_not_learned(); + w0->set_learned(false); } return; } @@ -481,6 +481,19 @@ namespace sat { reinit = attach_nary_clause(c); } + void solver::set_learned(clause& c, bool learned) { + if (c.is_learned() == learned) + return; + + if (c.size() == 3) { + set_ternary_learned(get_wlist(~c[0]), c[1], c[2], learned); + set_ternary_learned(get_wlist(~c[1]), c[0], c[2], learned); + set_ternary_learned(get_wlist(~c[2]), c[0], c[1], learned); + } + c.set_learned(learned); + } + + /** \brief Select a watch literal starting the search at the given position. This method is only used for clauses created during the search. diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 8b97dc7ea..0370d634d 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -233,6 +233,7 @@ namespace sat { bool attach_nary_clause(clause & c); void attach_clause(clause & c, bool & reinit); void attach_clause(clause & c) { bool reinit; attach_clause(c, reinit); } + void set_learned(clause& c, bool learned); class scoped_disable_checkpoint { solver& s; public: @@ -611,7 +612,7 @@ namespace sat { clause * const * end_learned() const { return m_learned.end(); } clause_vector const& learned() const { return m_learned; } clause_vector const& clauses() const { return m_clauses; } - void collect_bin_clauses(svector & r, bool learned, bool learned_only = false) const; + void collect_bin_clauses(svector & r, bool learned, bool learned_only = false) const; // ----------------------- // diff --git a/src/sat/sat_watched.cpp b/src/sat/sat_watched.cpp index 1a0880282..d7c6520bd 100644 --- a/src/sat/sat_watched.cpp +++ b/src/sat/sat_watched.cpp @@ -52,7 +52,7 @@ namespace sat { } return nullptr; } - + void erase_binary_watch(watch_list& wlist, literal l) { watch_list::iterator it = wlist.begin(), end = wlist.end(); watch_list::iterator it2 = it; @@ -70,6 +70,30 @@ namespace sat { VERIFY(found); } + void erase_ternary_watch(watch_list& wlist, literal l1, literal l2) { + watch_list::iterator it = wlist.begin(), end = wlist.end(); + watch_list::iterator it2 = it; + bool found = false; + for (; it != end; ++it) { + if (it->is_ternary_clause() && it->get_literal1() == l1 && it->get_literal2() == l2) { + found = true; + continue; + } + *it2 = *it; + ++it2; + } + wlist.set_end(it2); + VERIFY(found); + } + + void set_ternary_learned(watch_list& wlist, literal l1, literal l2, bool learned) { + for (watched& w : wlist) { + if (w.is_ternary_clause() && w.get_literal1() == l1 && w.get_literal2() == l2) { + w.set_learned(learned); + } + } + } + void conflict_cleanup(watch_list::iterator it, watch_list::iterator it2, watch_list& wlist) { watch_list::iterator end = wlist.end(); for (; it != end; ++it, ++it2) diff --git a/src/sat/sat_watched.h b/src/sat/sat_watched.h index 88b6f3e51..8979405ae 100644 --- a/src/sat/sat_watched.h +++ b/src/sat/sat_watched.h @@ -33,7 +33,7 @@ namespace sat { For binary clauses: we use a bit to store whether the binary clause was learned or not. - Remark: there is not Clause object for binary clauses. + Remark: there are no clause objects for binary clauses. */ class watched { public: @@ -87,13 +87,12 @@ namespace sat { bool is_binary_clause() const { return get_kind() == BINARY; } literal get_literal() const { SASSERT(is_binary_clause()); return to_literal(static_cast(m_val1)); } void set_literal(literal l) { SASSERT(is_binary_clause()); m_val1 = l.to_uint(); } - bool is_learned() const { SASSERT(is_binary_clause()); return (m_val2 >> 2) == 1; } + bool is_learned() const { SASSERT(is_binary_clause() || is_ternary_clause()); return ((m_val2 >> 2) & 1) == 1; } bool is_binary_learned_clause() const { return is_binary_clause() && is_learned(); } bool is_binary_non_learned_clause() const { return is_binary_clause() && !is_learned(); } - void set_not_learned() { SASSERT(is_learned()); m_val2 &= 0x3; SASSERT(!is_learned()); } - void set_learned() { SASSERT(!is_learned()); m_val2 |= 0x4; SASSERT(is_learned()); } + void set_learned(bool l) { SASSERT(is_learned() != l); if (l) m_val2 |= 4; else m_val2 &= 3; SASSERT(is_learned() == l); } bool is_ternary_clause() const { return get_kind() == TERNARY; } literal get_literal1() const { SASSERT(is_ternary_clause()); return to_literal(static_cast(m_val1)); } @@ -136,7 +135,8 @@ namespace sat { watched* find_binary_watch(watch_list & wlist, literal l); watched const* find_binary_watch(watch_list const & wlist, literal l); bool erase_clause_watch(watch_list & wlist, clause_offset c); - inline void erase_ternary_watch(watch_list & wlist, literal l1, literal l2) { wlist.erase(watched(l1, l2, true)); wlist.erase(watched(l1, l2, false)); } + void erase_ternary_watch(watch_list & wlist, literal l1, literal l2); + void set_ternary_learned(watch_list& wlist, literal l1, literal l2, bool learned); class clause_allocator; std::ostream& display_watch_list(std::ostream & out, clause_allocator const & ca, watch_list const & wlist);