diff --git a/src/sat/sat_clause.cpp b/src/sat/sat_clause.cpp index c59ce7289..913c729d4 100644 --- a/src/sat/sat_clause.cpp +++ b/src/sat/sat_clause.cpp @@ -34,7 +34,8 @@ namespace sat { m_reinit_stack(false), m_inact_rounds(0), m_glue(255), - m_psm(255) { + m_psm(255), + m_scope_lim(0) { memcpy(m_lits, lits, sizeof(literal) * sz); mark_strengthened(); SASSERT(check_approx()); @@ -192,6 +193,7 @@ namespace sat { cls->m_psm = other.psm(); cls->m_frozen = other.frozen(); cls->m_approx = other.approx(); + cls->m_scope_lim = other.scope_lim(); return cls; } diff --git a/src/sat/sat_clause.h b/src/sat/sat_clause.h index 0129febbf..f986e7e41 100644 --- a/src/sat/sat_clause.h +++ b/src/sat/sat_clause.h @@ -53,6 +53,7 @@ namespace sat { unsigned m_inact_rounds:8; unsigned m_glue:8; unsigned m_psm:8; // transient field used during gc + unsigned m_scope_lim:2; // user scope level when clause was learned, saturated at 3 literal m_lits[0]; static size_t get_obj_size(unsigned num_lits) { return sizeof(clause) + num_lits * sizeof(literal); } @@ -103,6 +104,8 @@ namespace sat { bool on_reinit_stack() const { return m_reinit_stack; } void set_reinit_stack(bool f) { m_reinit_stack = f; } + unsigned scope_lim() const { return m_scope_lim; } + void set_scope_lim(unsigned lim) { m_scope_lim = lim > 3 ? 3 : lim; } }; std::ostream & operator<<(std::ostream & out, clause_vector const & cs); diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 09f874e78..33f6fd80c 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -549,8 +549,10 @@ namespace sat { if (reinit || has_variables_to_reinit(*r)) push_reinit_stack(*r); - if (st.is_redundant()) + if (st.is_redundant()) { + r->set_scope_lim(m_user_scope_literals.size()); m_learned.push_back(r); + } else m_clauses.push_back(r); if (m_config.m_drat) @@ -3721,6 +3723,22 @@ namespace sat { m_ext->user_pop(num_scopes); gc_vars(max_var); + + // remove learned clauses that were added during the popped user scopes + // scope_lim is saturated at 3, so clauses at scope > old_sz can be identified when old_sz < 3 + if (old_sz < 3) { + unsigned j = 0; + for (clause* c : m_learned) { + if (c->scope_lim() > old_sz) { + SASSERT(!c->on_reinit_stack()); + detach_clause(*c); + del_clause(*c); + } + else + m_learned[j++] = c; + } + m_learned.shrink(j); + } TRACE(sat, display(tout);); m_qhead = 0;