From 3c4ac9aee5ec26830549c3b1bf750d38ef0fddc0 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 19 Sep 2017 12:02:50 -0700 Subject: [PATCH] add HS and unit literal reward schemes Signed-off-by: Nikolaj Bjorner --- src/sat/sat_lookahead.cpp | 120 +++++++++++++++++++++----------------- src/sat/sat_lookahead.h | 12 +++- 2 files changed, 77 insertions(+), 55 deletions(-) diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp index bfe5ec686..4e2d77f67 100644 --- a/src/sat/sat_lookahead.cpp +++ b/src/sat/sat_lookahead.cpp @@ -135,10 +135,8 @@ namespace sat { inc_bstamp(); set_bstamp(l); literal_vector const& conseq = m_binary[l.index()]; - literal_vector::const_iterator it = conseq.begin(); - literal_vector::const_iterator end = conseq.end(); - for (; it != end; ++it) { - set_bstamp(*it); + for (literal l : conseq) { + set_bstamp(l); } } @@ -365,27 +363,35 @@ namespace sat { } void lookahead::init_pre_selection(unsigned level) { - if (!m_config.m_use_ternary_reward) return; - unsigned max_level = m_config.m_max_hlevel; - if (level <= 1) { - ensure_H(2); - h_scores(m_H[0], m_H[1]); - for (unsigned j = 0; j < 2; ++j) { - for (unsigned i = 0; i < 2; ++i) { - h_scores(m_H[i + 1], m_H[(i + 2) % 3]); + switch (m_config.m_reward_type) { + case ternary_reward: { + unsigned max_level = m_config.m_max_hlevel; + if (level <= 1) { + ensure_H(2); + h_scores(m_H[0], m_H[1]); + for (unsigned j = 0; j < 2; ++j) { + for (unsigned i = 0; i < 2; ++i) { + h_scores(m_H[i + 1], m_H[(i + 2) % 3]); + } } + m_heur = &m_H[1]; } - m_heur = &m_H[1]; + else if (level < max_level) { + ensure_H(level); + h_scores(m_H[level-1], m_H[level]); + m_heur = &m_H[level]; + } + else { + ensure_H(max_level); + h_scores(m_H[max_level-1], m_H[max_level]); + m_heur = &m_H[max_level]; + } + break; } - else if (level < max_level) { - ensure_H(level); - h_scores(m_H[level-1], m_H[level]); - m_heur = &m_H[level]; - } - else { - ensure_H(max_level); - h_scores(m_H[max_level-1], m_H[max_level]); - m_heur = &m_H[max_level]; + case heule_schur_reward: + break; + case unit_literal_reward: + break; } } @@ -782,10 +788,8 @@ namespace sat { } void lookahead::del_clauses() { - clause * const* end = m_clauses.end(); - clause * const * it = m_clauses.begin(); - for (; it != end; ++it) { - m_cls_allocator.del_clause(*it); + for (clause * c : m_clauses) { + m_cls_allocator.del_clause(c); } } @@ -849,12 +853,10 @@ namespace sat { literal l = ~to_literal(l_idx); if (m_s.was_eliminated(l.var())) continue; watch_list const & wlist = m_s.m_watches[l_idx]; - watch_list::const_iterator it = wlist.begin(); - watch_list::const_iterator end = wlist.end(); - for (; it != end; ++it) { - if (!it->is_binary_non_learned_clause()) + for (auto& w : wlist) { + if (!w.is_binary_non_learned_clause()) continue; - literal l2 = it->get_literal(); + literal l2 = w.get_literal(); if (l.index() < l2.index() && !m_s.was_eliminated(l2.var())) add_binary(l, l2); } @@ -891,15 +893,10 @@ namespace sat { void lookahead::copy_clauses(clause_vector const& clauses) { // copy clauses - clause_vector::const_iterator it = clauses.begin(); - clause_vector::const_iterator end = clauses.end(); - for (; it != end; ++it) { - clause& c = *(*it); + for (clause* cp : clauses) { + clause& c = *cp; if (c.was_removed()) continue; // enable when there is a non-ternary reward system. - if (c.size() > 3) { - // m_config.m_use_ternary_reward = false; - } bool was_eliminated = false; for (unsigned i = 0; !was_eliminated && i < c.size(); ++i) { was_eliminated = m_s.was_eliminated(c[i].var()); @@ -1028,6 +1025,9 @@ namespace sat { } m_stats.m_windfall_binaries += m_wstack.size(); } + if (m_config.m_reward_type == unit_literal_reward) { + m_lookahead_reward += m_wstack.size(); + } m_wstack.reset(); } @@ -1219,16 +1219,20 @@ namespace sat { void lookahead::update_binary_clause_reward(literal l1, literal l2) { SASSERT(!is_false(l1)); SASSERT(!is_false(l2)); - if (m_config.m_use_ternary_reward) { + switch (m_config.m_reward_type) { + case ternary_reward: m_lookahead_reward += (*m_heur)[l1.index()] * (*m_heur)[l2.index()]; - } - else { - m_lookahead_reward += 0.5 * (literal_occs(l1) + literal_occs(l2)); + break; + case heule_schur_reward: + m_lookahead_reward += (literal_occs(l1) + literal_occs(l2)) / 8.0; + break; + case unit_literal_reward: + break; } } void lookahead::update_nary_clause_reward(clause const& c) { - if (m_config.m_use_ternary_reward && m_lookahead_reward != 0) { + if (m_config.m_reward_type == ternary_reward && m_lookahead_reward != 0) { return; } literal const * l_it = c.begin() + 2, *l_end = c.end(); @@ -1237,7 +1241,8 @@ namespace sat { if (is_true(*l_it)) return; if (!is_false(*l_it)) ++sz; } - if (!m_config.m_use_ternary_reward) { + switch (m_config.m_reward_type) { + case heule_schur_reward: { SASSERT(sz > 0); double to_add = 0; for (literal l : c) { @@ -1245,14 +1250,18 @@ namespace sat { to_add += literal_occs(l); } } - m_lookahead_reward += pow(0.5, sz) * to_add; + m_lookahead_reward += pow(0.5, sz) * to_add / sz; + break; } - else { + case ternary_reward: m_lookahead_reward = (double)0.001; + break; + case unit_literal_reward: + break; } } - // Sum_{ clause C that contains ~l } 1 / |C| + // Sum_{ clause C that contains ~l } 1 double lookahead::literal_occs(literal l) { double result = m_binary[l.index()].size(); for (clause const* c : m_full_watches[l.index()]) { @@ -1262,7 +1271,7 @@ namespace sat { if (has_true) break; } if (!has_true) { - result += 1.0 / c->size(); + result += 1.0; } } return result; @@ -1394,6 +1403,15 @@ namespace sat { } + double lookahead::mix_diff(double l, double r) const { + switch (m_config.m_reward_type) { + case ternary_reward: return l + r + (1 << 10) * l * r; + case heule_schur_reward: return l * r; + case unit_literal_reward: return l * r; + default: UNREACHABLE(); return l * r; + } + } + void lookahead::reset_lookahead_reward(literal l) { m_lookahead_reward = 0; @@ -1406,10 +1424,8 @@ namespace sat { bool lookahead::check_autarky(literal l, unsigned level) { return false; // no propagations are allowed to reduce clauses. - clause_vector::const_iterator it = m_full_watches[l.index()].begin(); - clause_vector::const_iterator end = m_full_watches[l.index()].end(); - for (; it != end; ++it) { - clause& c = *(*it); + for (clause * cp : m_full_watches[l.index()]) { + clause& c = *cp; unsigned sz = c.size(); bool found = false; for (unsigned i = 0; !found && i < sz; ++i) { diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index 38adc4505..7d4d7a39d 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -66,6 +66,12 @@ namespace sat { friend class ccc; friend class ba_solver; + enum reward_t { + ternary_reward, + unit_literal_reward, + heule_schur_reward + }; + struct config { double m_dl_success; double m_alpha; @@ -76,7 +82,7 @@ namespace sat { double m_delta_rho; unsigned m_dl_max_iterations; unsigned m_tc1_limit; - bool m_use_ternary_reward; + reward_t m_reward_type; config() { m_max_hlevel = 50; @@ -87,7 +93,7 @@ namespace sat { m_delta_rho = (double)0.9995; m_dl_max_iterations = 32; m_tc1_limit = 10000000; - m_use_ternary_reward = true; + m_reward_type = ternary_reward; } }; @@ -389,7 +395,7 @@ namespace sat { bool push_lookahead2(literal lit, unsigned level); void push_lookahead1(literal lit, unsigned level); void pop_lookahead1(literal lit); - double mix_diff(double l, double r) const { return l + r + (1 << 10) * l * r; } + double mix_diff(double l, double r) const; clause const& get_clause(watch_list::iterator it) const; bool is_nary_propagation(clause const& c, literal l) const; void propagate_clauses(literal l);