From 6790f18132041040ffaa098b853afa57637c6ddf Mon Sep 17 00:00:00 2001 From: Clemens Eisenhofer <56730610+CEisenhofer@users.noreply.github.com> Date: Thu, 3 Nov 2022 11:34:52 +0100 Subject: [PATCH] Added limit to "visit" to allow detecting multiple visits (#6435) * Memory leak in .NET user-propagator The user-propagator object has to be manually disposed (IDisposable), otherwise it stays in memory forever, as it cannot be garbage collected automatically * Throw an exception if variable passed to decide is already assigned instead of running in an assertion violation * Added limit to "visit" to allow detecting multiple visits * Putting visit in a separate class (Reason: We will probably need two of them in the sat::solver) * Bugfix --- src/sat/sat_gc.cpp | 4 ++-- src/sat/sat_lut_finder.cpp | 12 ++++++------ src/sat/sat_solver.cpp | 32 +++++++------------------------ src/sat/sat_solver.h | 13 ++++--------- src/sat/sat_xor_finder.cpp | 12 ++++++------ src/sat/smt/pb_solver.cpp | 8 ++++---- src/util/visit_helper.h | 39 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 68 insertions(+), 52 deletions(-) create mode 100644 src/util/visit_helper.h diff --git a/src/sat/sat_gc.cpp b/src/sat/sat_gc.cpp index a655956db..69e91c745 100644 --- a/src/sat/sat_gc.cpp +++ b/src/sat/sat_gc.cpp @@ -406,9 +406,9 @@ namespace sat { auto gc_watch = [&](literal lit) { auto& wl1 = get_wlist(lit); for (auto w : get_wlist(lit)) { - if (w.is_binary_clause() && w.get_literal().var() < max_var && !is_visited(w.get_literal())) { + if (w.is_binary_clause() && w.get_literal().var() < max_var && !m_visited.is_visited(w.get_literal())) { m_aux_literals.push_back(w.get_literal()); - mark_visited(w.get_literal()); + m_visited.mark_visited(w.get_literal()); } } wl1.reset(); diff --git a/src/sat/sat_lut_finder.cpp b/src/sat/sat_lut_finder.cpp index 5459ab2a4..26ec80143 100644 --- a/src/sat/sat_lut_finder.cpp +++ b/src/sat/sat_lut_finder.cpp @@ -70,7 +70,7 @@ namespace sat { for (literal l : m_clause) { m_vars.push_back(l.var()); m_var_position[l.var()] = i; - s.mark_visited(l.var()); + s.m_visited.mark_visited(l.var()); mask |= (l.sign() << (i++)); } m_clauses_to_remove.reset(); @@ -91,7 +91,7 @@ namespace sat { // TBD: replace by BIG // loop over binary clauses in watch list for (watched const & w : s.get_wlist(l)) { - if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { + if (w.is_binary_clause() && s.m_visited.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { if (extract_lut(~l, w.get_literal())) { add_lut(); return; @@ -100,7 +100,7 @@ namespace sat { } l.neg(); for (watched const & w : s.get_wlist(l)) { - if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { + if (w.is_binary_clause() && s.m_visited.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { if (extract_lut(~l, w.get_literal())) { add_lut(); return; @@ -124,8 +124,8 @@ namespace sat { } bool lut_finder::extract_lut(literal l1, literal l2) { - SASSERT(s.is_visited(l1.var())); - SASSERT(s.is_visited(l2.var())); + SASSERT(s.m_visited.is_visited(l1.var())); + SASSERT(s.m_visited.is_visited(l2.var())); m_missing.reset(); unsigned mask = 0; for (unsigned i = 0; i < m_vars.size(); ++i) { @@ -144,7 +144,7 @@ namespace sat { bool lut_finder::extract_lut(clause& c2) { for (literal l : c2) { - if (!s.is_visited(l.var())) + if (!s.m_visited.is_visited(l.var())) return false; } if (c2.size() == m_vars.size()) { diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index eea3a2475..f97a08001 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -3441,10 +3441,10 @@ namespace sat { for (unsigned i = m_clauses_to_reinit.size(); i-- > old_sz; ) { clause_wrapper const& cw = m_clauses_to_reinit[i]; for (unsigned j = cw.size(); j-- > 0; ) - mark_visited(cw[j].var()); + m_visited.mark_visited(cw[j].var()); } for (literal lit : m_lemma) - mark_visited(lit.var()); + m_visited.mark_visited(lit.var()); auto is_active = [&](bool_var v) { return value(v) != l_undef && lvl(v) <= new_lvl; @@ -3452,7 +3452,7 @@ namespace sat { for (unsigned i = old_num_vars; i < sz; ++i) { bool_var v = m_active_vars[i]; - if (is_external(v) || is_visited(v) || is_active(v)) { + if (is_external(v) || m_visited.is_visited(v) || is_active(v)) { m_vars_to_reinit.push_back(v); m_active_vars[j++] = v; m_var_scope[v] = new_lvl; @@ -4697,10 +4697,10 @@ namespace sat { bool solver::all_distinct(literal_vector const& lits) { init_visited(); for (literal l : lits) { - if (is_visited(l.var())) { + if (m_visited.is_visited(l.var())) { return false; } - mark_visited(l.var()); + m_visited.mark_visited(l.var()); } return true; } @@ -4708,30 +4708,12 @@ namespace sat { bool solver::all_distinct(clause const& c) { init_visited(); for (literal l : c) { - if (is_visited(l.var())) { + if (m_visited.is_visited(l.var())) { return false; } - mark_visited(l.var()); + m_visited.mark_visited(l.var()); } return true; } - void solver::init_ts(unsigned n, svector& v, unsigned& ts) { - if (v.empty()) - ts = 0; - - ts++; - if (ts == 0) { - ts = 1; - v.reset(); - } - while (v.size() < n) - v.push_back(0); - } - - void solver::init_visited() { - init_ts(2 * num_vars(), m_visited, m_visited_ts); - } - - }; diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 2e53e4620..b75950f88 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -28,6 +28,7 @@ Revision History: #include "util/rlimit.h" #include "util/scoped_ptr_vector.h" #include "util/scoped_limit_trail.h" +#include "util/visit_helper.h" #include "sat/sat_types.h" #include "sat/sat_clause.h" #include "sat/sat_watched.h" @@ -176,8 +177,7 @@ namespace sat { std::string m_reason_unknown; bool m_trim = false; - svector m_visited; - unsigned m_visited_ts; + visit_helper m_visited; struct scope { unsigned m_trail_lim; @@ -342,13 +342,8 @@ namespace sat { void detach_nary_clause(clause & c); void push_reinit_stack(clause & c); void push_reinit_stack(literal l1, literal l2); - - void init_ts(unsigned n, svector& v, unsigned& ts); - void init_visited(); - void mark_visited(literal l) { m_visited[l.index()] = m_visited_ts; } - void mark_visited(bool_var v) { mark_visited(literal(v, false)); } - bool is_visited(bool_var v) const { return is_visited(literal(v, false)); } - bool is_visited(literal l) const { return m_visited[l.index()] == m_visited_ts; } + + void init_visited(unsigned lim = 1) { m_visited.init_visited(num_vars(), lim); } bool all_distinct(literal_vector const& lits); bool all_distinct(clause const& cl); diff --git a/src/sat/sat_xor_finder.cpp b/src/sat/sat_xor_finder.cpp index dbe08d96c..0a20f4782 100644 --- a/src/sat/sat_xor_finder.cpp +++ b/src/sat/sat_xor_finder.cpp @@ -62,7 +62,7 @@ namespace sat { unsigned mask = 0, i = 0; for (literal l : c) { m_var_position[l.var()] = i; - s.mark_visited(l.var()); + s.m_visited.mark_visited(l.var()); parity ^= !l.sign(); mask |= (!l.sign() << (i++)); } @@ -84,7 +84,7 @@ namespace sat { } // loop over binary clauses in watch list for (watched const & w : s.get_wlist(l)) { - if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { + if (w.is_binary_clause() && s.m_visited.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { if (extract_xor(parity, c, ~l, w.get_literal())) { add_xor(parity, c); return; @@ -93,7 +93,7 @@ namespace sat { } l.neg(); for (watched const & w : s.get_wlist(l)) { - if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { + if (w.is_binary_clause() && s.m_visited.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { if (extract_xor(parity, c, ~l, w.get_literal())) { add_xor(parity, c); return; @@ -122,8 +122,8 @@ namespace sat { } bool xor_finder::extract_xor(bool parity, clause& c, literal l1, literal l2) { - SASSERT(s.is_visited(l1.var())); - SASSERT(s.is_visited(l2.var())); + SASSERT(s.m_visited.is_visited(l1.var())); + SASSERT(s.m_visited.is_visited(l2.var())); m_missing.reset(); unsigned mask = 0; for (unsigned i = 0; i < c.size(); ++i) { @@ -144,7 +144,7 @@ namespace sat { bool xor_finder::extract_xor(bool parity, clause& c, clause& c2) { bool parity2 = false; for (literal l : c2) { - if (!s.is_visited(l.var())) return false; + if (!s.m_visited.is_visited(l.var())) return false; parity2 ^= !l.sign(); } if (c2.size() == c.size() && parity2 != parity) { diff --git a/src/sat/smt/pb_solver.cpp b/src/sat/smt/pb_solver.cpp index 424b20d4e..5b2d851d3 100644 --- a/src/sat/smt/pb_solver.cpp +++ b/src/sat/smt/pb_solver.cpp @@ -2709,10 +2709,10 @@ namespace pb { } void solver::init_visited() { s().init_visited(); } - void solver::mark_visited(literal l) { s().mark_visited(l); } - void solver::mark_visited(bool_var v) { s().mark_visited(v); } - bool solver::is_visited(bool_var v) const { return s().is_visited(v); } - bool solver::is_visited(literal l) const { return s().is_visited(l); } + void solver::mark_visited(literal l) { s().m_visited.mark_visited(l); } + void solver::mark_visited(bool_var v) { s().m_visited.mark_visited(v); } + bool solver::is_visited(bool_var v) const { return s().m_visited.is_visited(v); } + bool solver::is_visited(literal l) const { return s().m_visited.is_visited(l); } void solver::cleanup_clauses() { if (m_clause_removed) { diff --git a/src/util/visit_helper.h b/src/util/visit_helper.h new file mode 100644 index 000000000..1a0d4f5b9 --- /dev/null +++ b/src/util/visit_helper.h @@ -0,0 +1,39 @@ +#pragma once +#include "sat_literal.h" + +class visit_helper { + + unsigned_vector m_visited; + unsigned m_visited_begin = 0; + unsigned m_visited_end = 0; + + void init_ts(unsigned n, unsigned lim = 1) { + SASSERT(lim > 0); + if (m_visited_end >= m_visited_end + lim) { // overflow + m_visited_begin = 0; + m_visited_end = lim; + m_visited.reset(); + } + else { + m_visited_begin = m_visited_end; + m_visited_end = m_visited_end + lim; + } + while (m_visited.size() < n) + m_visited.push_back(0); + } + +public: + + void init_visited(unsigned num_vars, unsigned lim = 1) { + init_ts(2 * num_vars, lim); + } + void mark_visited(sat::literal l) { m_visited[l.index()] = m_visited_begin + 1; } + void mark_visited(sat::bool_var v) { mark_visited(sat::literal(v, false)); } + void inc_visited(sat::literal l) { + m_visited[l.index()] = std::min(m_visited_end, std::max(m_visited_begin, m_visited[l.index()]) + 1); + } + void inc_visited(sat::bool_var v) { inc_visited(sat::literal(v, false)); } + bool is_visited(sat::bool_var v) const { return is_visited(sat::literal(v, false)); } + bool is_visited(sat::literal l) const { return m_visited[l.index()] > m_visited_begin; } + unsigned num_visited(unsigned i) { return std::max(m_visited_begin, m_visited[i]) - m_visited_begin; } +}; \ No newline at end of file