From 90490cb22f2600dd45826ba7070c97e091077298 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 3 Nov 2022 03:54:39 -0700 Subject: [PATCH] make visited_helper independent of literals re-introduce shorthands in sat::solver for visited and have them convert literals to unsigned. --- src/sat/sat_gc.cpp | 4 ++-- src/sat/sat_lut_finder.cpp | 8 +++---- src/sat/sat_solver.cpp | 14 ++++++------ src/sat/sat_solver.h | 6 +++++- src/sat/sat_xor_finder.cpp | 8 +++---- src/sat/smt/pb_solver.cpp | 8 +++---- src/util/visit_helper.h | 44 +++++++++++++++++++++++--------------- 7 files changed, 53 insertions(+), 39 deletions(-) diff --git a/src/sat/sat_gc.cpp b/src/sat/sat_gc.cpp index 69e91c745..a655956db 100644 --- a/src/sat/sat_gc.cpp +++ b/src/sat/sat_gc.cpp @@ -406,9 +406,9 @@ namespace sat { auto gc_watch = [&](literal lit) { auto& wl1 = get_wlist(lit); for (auto w : get_wlist(lit)) { - if (w.is_binary_clause() && w.get_literal().var() < max_var && !m_visited.is_visited(w.get_literal())) { + if (w.is_binary_clause() && w.get_literal().var() < max_var && !is_visited(w.get_literal())) { m_aux_literals.push_back(w.get_literal()); - m_visited.mark_visited(w.get_literal()); + mark_visited(w.get_literal()); } } wl1.reset(); diff --git a/src/sat/sat_lut_finder.cpp b/src/sat/sat_lut_finder.cpp index 26ec80143..60143f91c 100644 --- a/src/sat/sat_lut_finder.cpp +++ b/src/sat/sat_lut_finder.cpp @@ -70,7 +70,7 @@ namespace sat { for (literal l : m_clause) { m_vars.push_back(l.var()); m_var_position[l.var()] = i; - s.m_visited.mark_visited(l.var()); + s.mark_visited(l.var()); mask |= (l.sign() << (i++)); } m_clauses_to_remove.reset(); @@ -91,7 +91,7 @@ namespace sat { // TBD: replace by BIG // loop over binary clauses in watch list for (watched const & w : s.get_wlist(l)) { - if (w.is_binary_clause() && s.m_visited.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { + if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { if (extract_lut(~l, w.get_literal())) { add_lut(); return; @@ -100,7 +100,7 @@ namespace sat { } l.neg(); for (watched const & w : s.get_wlist(l)) { - if (w.is_binary_clause() && s.m_visited.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { + if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { if (extract_lut(~l, w.get_literal())) { add_lut(); return; @@ -144,7 +144,7 @@ namespace sat { bool lut_finder::extract_lut(clause& c2) { for (literal l : c2) { - if (!s.m_visited.is_visited(l.var())) + if (!s.is_visited(l.var())) return false; } if (c2.size() == m_vars.size()) { diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index f97a08001..14e7e9775 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -3441,10 +3441,10 @@ namespace sat { for (unsigned i = m_clauses_to_reinit.size(); i-- > old_sz; ) { clause_wrapper const& cw = m_clauses_to_reinit[i]; for (unsigned j = cw.size(); j-- > 0; ) - m_visited.mark_visited(cw[j].var()); + mark_visited(cw[j].var()); } for (literal lit : m_lemma) - m_visited.mark_visited(lit.var()); + mark_visited(lit.var()); auto is_active = [&](bool_var v) { return value(v) != l_undef && lvl(v) <= new_lvl; @@ -3452,7 +3452,7 @@ namespace sat { for (unsigned i = old_num_vars; i < sz; ++i) { bool_var v = m_active_vars[i]; - if (is_external(v) || m_visited.is_visited(v) || is_active(v)) { + if (is_external(v) || is_visited(v) || is_active(v)) { m_vars_to_reinit.push_back(v); m_active_vars[j++] = v; m_var_scope[v] = new_lvl; @@ -4697,10 +4697,10 @@ namespace sat { bool solver::all_distinct(literal_vector const& lits) { init_visited(); for (literal l : lits) { - if (m_visited.is_visited(l.var())) { + if (is_visited(l.var())) { return false; } - m_visited.mark_visited(l.var()); + mark_visited(l.var()); } return true; } @@ -4708,10 +4708,10 @@ namespace sat { bool solver::all_distinct(clause const& c) { init_visited(); for (literal l : c) { - if (m_visited.is_visited(l.var())) { + if (is_visited(l.var())) { return false; } - m_visited.mark_visited(l.var()); + mark_visited(l.var()); } return true; } diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index b75950f88..982a84307 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -343,7 +343,11 @@ namespace sat { void push_reinit_stack(clause & c); void push_reinit_stack(literal l1, literal l2); - void init_visited(unsigned lim = 1) { m_visited.init_visited(num_vars(), lim); } + void init_visited(unsigned lim = 1) { m_visited.init_visited(2 * num_vars(), lim); } + bool is_visited(sat::bool_var v) const { return is_visited(literal(v, false)); } + bool is_visited(literal lit) const { return m_visited.is_visited(lit.index()); } + void mark_visited(literal lit) { m_visited.mark_visited(lit.index()); } + void mark_visited(bool_var v) { mark_visited(literal(v, false)); } bool all_distinct(literal_vector const& lits); bool all_distinct(clause const& cl); diff --git a/src/sat/sat_xor_finder.cpp b/src/sat/sat_xor_finder.cpp index 0a20f4782..a34d1b7ad 100644 --- a/src/sat/sat_xor_finder.cpp +++ b/src/sat/sat_xor_finder.cpp @@ -62,7 +62,7 @@ namespace sat { unsigned mask = 0, i = 0; for (literal l : c) { m_var_position[l.var()] = i; - s.m_visited.mark_visited(l.var()); + s.mark_visited(l.var()); parity ^= !l.sign(); mask |= (!l.sign() << (i++)); } @@ -84,7 +84,7 @@ namespace sat { } // loop over binary clauses in watch list for (watched const & w : s.get_wlist(l)) { - if (w.is_binary_clause() && s.m_visited.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { + if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { if (extract_xor(parity, c, ~l, w.get_literal())) { add_xor(parity, c); return; @@ -93,7 +93,7 @@ namespace sat { } l.neg(); for (watched const & w : s.get_wlist(l)) { - if (w.is_binary_clause() && s.m_visited.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { + if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { if (extract_xor(parity, c, ~l, w.get_literal())) { add_xor(parity, c); return; @@ -144,7 +144,7 @@ namespace sat { bool xor_finder::extract_xor(bool parity, clause& c, clause& c2) { bool parity2 = false; for (literal l : c2) { - if (!s.m_visited.is_visited(l.var())) return false; + if (!s.is_visited(l.var())) return false; parity2 ^= !l.sign(); } if (c2.size() == c.size() && parity2 != parity) { diff --git a/src/sat/smt/pb_solver.cpp b/src/sat/smt/pb_solver.cpp index 5b2d851d3..424b20d4e 100644 --- a/src/sat/smt/pb_solver.cpp +++ b/src/sat/smt/pb_solver.cpp @@ -2709,10 +2709,10 @@ namespace pb { } void solver::init_visited() { s().init_visited(); } - void solver::mark_visited(literal l) { s().m_visited.mark_visited(l); } - void solver::mark_visited(bool_var v) { s().m_visited.mark_visited(v); } - bool solver::is_visited(bool_var v) const { return s().m_visited.is_visited(v); } - bool solver::is_visited(literal l) const { return s().m_visited.is_visited(l); } + void solver::mark_visited(literal l) { s().mark_visited(l); } + void solver::mark_visited(bool_var v) { s().mark_visited(v); } + bool solver::is_visited(bool_var v) const { return s().is_visited(v); } + bool solver::is_visited(literal l) const { return s().is_visited(l); } void solver::cleanup_clauses() { if (m_clause_removed) { diff --git a/src/util/visit_helper.h b/src/util/visit_helper.h index 1a0d4f5b9..a11d7bdc6 100644 --- a/src/util/visit_helper.h +++ b/src/util/visit_helper.h @@ -1,5 +1,21 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + visit_helper.h + +Abstract: + + Routine for marking and counting visited occurrences + +Author: + + Clemens Eisenhofer 2022-11-03 + +--*/ #pragma once -#include "sat_literal.h" + class visit_helper { @@ -7,7 +23,9 @@ class visit_helper { unsigned m_visited_begin = 0; unsigned m_visited_end = 0; - void init_ts(unsigned n, unsigned lim = 1) { +public: + + void init_visited(unsigned n, unsigned lim = 1) { SASSERT(lim > 0); if (m_visited_end >= m_visited_end + lim) { // overflow m_visited_begin = 0; @@ -18,22 +36,14 @@ class visit_helper { m_visited_begin = m_visited_end; m_visited_end = m_visited_end + lim; } - while (m_visited.size() < n) - m_visited.push_back(0); + while (m_visited.size() < n) + m_visited.push_back(0); } -public: - - void init_visited(unsigned num_vars, unsigned lim = 1) { - init_ts(2 * num_vars, lim); + void mark_visited(unsigned v) { m_visited[v] = m_visited_begin + 1; } + void inc_visited(unsigned v) { + m_visited[v] = std::min(m_visited_end, std::max(m_visited_begin, m_visited[v]) + 1); } - void mark_visited(sat::literal l) { m_visited[l.index()] = m_visited_begin + 1; } - void mark_visited(sat::bool_var v) { mark_visited(sat::literal(v, false)); } - void inc_visited(sat::literal l) { - m_visited[l.index()] = std::min(m_visited_end, std::max(m_visited_begin, m_visited[l.index()]) + 1); - } - void inc_visited(sat::bool_var v) { inc_visited(sat::literal(v, false)); } - bool is_visited(sat::bool_var v) const { return is_visited(sat::literal(v, false)); } - bool is_visited(sat::literal l) const { return m_visited[l.index()] > m_visited_begin; } - unsigned num_visited(unsigned i) { return std::max(m_visited_begin, m_visited[i]) - m_visited_begin; } + bool is_visited(unsigned v) const { return m_visited[v] > m_visited_begin; } + unsigned num_visited(unsigned v) { return std::max(m_visited_begin, m_visited[v]) - m_visited_begin; } }; \ No newline at end of file