diff --git a/src/sat/sat_config.cpp b/src/sat/sat_config.cpp index 49a146489..c77774283 100644 --- a/src/sat/sat_config.cpp +++ b/src/sat/sat_config.cpp @@ -100,6 +100,9 @@ namespace sat { else if (p.lookahead_reward() == symbol("unit")) { m_lookahead_reward = unit_literal_reward; } + else if (p.lookahead_reward() == symbol("march_cu")) { + m_lookahead_reward = march_cu_reward; + } else { throw sat_param_exception("invalid reward type supplied: accepted heuristics are 'ternary', 'heuleu', 'unit' or 'heule_schur'"); } diff --git a/src/sat/sat_config.h b/src/sat/sat_config.h index a7c8590fb..214a93f5d 100644 --- a/src/sat/sat_config.h +++ b/src/sat/sat_config.h @@ -61,7 +61,8 @@ namespace sat { ternary_reward, unit_literal_reward, heule_schur_reward, - heule_unit_reward + heule_unit_reward, + march_cu_reward }; struct config { diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp index 7ce2b53b9..0374aca9c 100644 --- a/src/sat/sat_lookahead.cpp +++ b/src/sat/sat_lookahead.cpp @@ -217,7 +217,7 @@ namespace sat { bool lookahead::select(unsigned level) { init_pre_selection(level); unsigned level_cand = std::max(m_config.m_level_cand, m_freevars.size() / 50); - unsigned max_num_cand = level == 0 ? m_freevars.size() : level_cand / level; + unsigned max_num_cand = (level > 0 && m_config.m_preselect) ? level_cand / level : m_freevars.size(); max_num_cand = std::max(m_config.m_min_cutoff, max_num_cand); double sum = 0; @@ -251,30 +251,40 @@ namespace sat { } TRACE("sat", display_candidates(tout);); SASSERT(!m_candidates.empty()); - if (m_candidates.size() > max_num_cand) { - unsigned j = m_candidates.size()/2; - while (j > 0) { - --j; - sift_up(j); - } - while (true) { - m_candidates[0] = m_candidates.back(); - m_candidates.pop_back(); - if (m_candidates.size() == max_num_cand) break; - sift_up(0); - } + heap_sort(); + while (m_candidates.size() > max_num_cand) { + m_candidates.pop_back(); } SASSERT(!m_candidates.empty() && m_candidates.size() <= max_num_cand); TRACE("sat", display_candidates(tout);); return true; } - void lookahead::sift_up(unsigned j) { + void lookahead::heap_sort() { + if (m_candidates.size() > 1) { + heapify(); + for (unsigned i = m_candidates.size() - 1; i > 0; --i) { + candidate c = m_candidates[i]; + m_candidates[i] = m_candidates[0]; + m_candidates[0] = c; + sift_down(0, i); + } + } + } + + void lookahead::heapify() { + unsigned i = 1 + (m_candidates.size() - 2) / 2; + while(i > 0) { + sift_down(--i, m_candidates.size()); + } + } + + void lookahead::sift_down(unsigned j, unsigned sz) { unsigned i = j; candidate c = m_candidates[j]; - for (unsigned k = 2*j + 1; k < m_candidates.size(); i = k, k = 2*k + 1) { - // pick largest parent - if (k + 1 < m_candidates.size() && m_candidates[k].m_rating < m_candidates[k+1].m_rating) { + for (unsigned k = 2 * j + 1; k < sz; i = k, k = 2 * k + 1) { + // pick smallest child + if (k + 1 < sz && m_candidates[k].m_rating > m_candidates[k + 1].m_rating) { ++k; } if (c.m_rating <= m_candidates[k].m_rating) break; @@ -452,6 +462,9 @@ namespace sat { case heule_unit_reward: heule_unit_scores(); break; + case march_cu_reward: + march_cu_scores(); + break; case unit_literal_reward: heule_schur_scores(); break; @@ -502,7 +515,7 @@ namespace sat { for (bool_var x : m_freevars) { literal l(x, false); m_rating[l.var()] = heule_unit_score(l) * heule_unit_score(~l); - } + } } double lookahead::heule_unit_score(literal l) { @@ -524,7 +537,23 @@ namespace sat { } #endif return sum; - } + } + + void lookahead::march_cu_scores() { + for (bool_var x : m_freevars) { + literal l(x, false); + double pos = march_cu_score(l), neg = march_cu_score(~l); + m_rating[l.var()] = 1024 * pos * neg + pos + neg + 1; + } + } + + double lookahead::march_cu_score(literal l) { + double sum = 1.0 + literal_big_occs(~l); + for (literal lit : m_binary[l.index()]) { + if (is_undef(lit)) sum += literal_big_occs(lit); + } + return sum; + } void lookahead::ensure_H(unsigned level) { while (m_H.size() <= level) { @@ -920,7 +949,8 @@ namespace sat { } void lookahead::init() { - m_delta_trigger = m_num_vars/10; + m_delta_trigger = 0.0; + m_delta_decrease = 0.0; m_config.m_dl_success = 0.8; m_inconsistent = false; m_qhead = 0; @@ -1104,6 +1134,7 @@ namespace sat { m_lookahead_reward += num_units; break; case heule_unit_reward: + case march_cu_reward: case heule_schur_reward: break; default: @@ -1496,6 +1527,9 @@ namespace sat { case heule_unit_reward: m_lookahead_reward += pow(0.5, nonfixed); break; + case march_cu_reward: + m_lookahead_reward += 3.3 * pow(0.5, nonfixed - 2); + break; case ternary_reward: if (nonfixed == 2) { m_lookahead_reward += (*m_heur)[l1.index()] * (*m_heur)[l2.index()]; @@ -1697,6 +1731,9 @@ namespace sat { case heule_unit_reward: m_lookahead_reward += 0.25; break; + case march_cu_reward: + m_lookahead_reward += 3.3; + break; case unit_literal_reward: break; } @@ -1727,6 +1764,9 @@ namespace sat { case heule_unit_reward: m_lookahead_reward += pow(0.5, sz); break; + case march_cu_reward: + m_lookahead_reward += 3.3 * pow(0.5, sz - 2); + break; case ternary_reward: m_lookahead_reward = (double)0.001; break; @@ -1736,10 +1776,17 @@ namespace sat { } // Sum_{ clause C that contains ~l } 1 + // FIXME: counts occurences of ~l; misleading double lookahead::literal_occs(literal l) { double result = m_binary[l.index()].size(); - // unsigned_vector const& nclauses = m_nary[(~l).index()]; - result += m_nary_count[(~l).index()]; + result += literal_big_occs(l); + return result; + } + + // Sum_{ clause C that contains ~l such that |C| > 2} 1 + // FIXME: counts occurences of ~l; misleading + double lookahead::literal_big_occs(literal l) { + double result = m_nary_count[(~l).index()]; result += m_ternary_count[(~l).index()]; return result; } @@ -1753,36 +1800,37 @@ namespace sat { } } - void lookahead::propagate() { while (!inconsistent() && m_qhead < m_trail.size()) { unsigned i = m_qhead; - unsigned sz = m_trail.size(); - for (; i < sz && !inconsistent(); ++i) { + for (; i < m_trail.size() && !inconsistent(); ++i) { literal l = m_trail[i]; TRACE("sat", tout << "propagate " << l << " @ " << m_level << "\n";); propagate_binary(l); } - while (m_qhead < sz && !inconsistent()) { + while (m_qhead < m_trail.size() && !inconsistent()) { propagate_clauses(m_trail[m_qhead++]); } - SASSERT(m_qhead == sz || (inconsistent() && m_qhead < sz)); + SASSERT(m_qhead == m_trail.size() || (inconsistent() && m_qhead < m_trail.size())); } - TRACE("sat_verbose", display(tout << scope_lvl() << " " << (inconsistent()?"unsat":"sat") << "\n");); } void lookahead::compute_lookahead_reward() { init_lookahead_reward(); TRACE("sat", display_lookahead(tout); ); + m_delta_decrease = pow(m_config.m_delta_rho, 1.0 / (double)m_lookahead.size()); unsigned base = 2; bool change = true; - bool first = true; + literal last_changed = null_literal; while (change && !inconsistent()) { change = false; for (unsigned i = 0; !inconsistent() && i < m_lookahead.size(); ++i) { checkpoint(); literal lit = m_lookahead[i].m_lit; + if (lit == last_changed) { + break; + } if (is_fixed_at(lit, c_fixed_truth)) continue; unsigned level = base + m_lookahead[i].m_offset; if (m_stamp[lit.var()] >= level) { @@ -1795,7 +1843,7 @@ namespace sat { unsigned old_trail_sz = m_trail.size(); reset_lookahead_reward(lit); push_lookahead1(lit, level); - if (!first) do_double(lit, base); + do_double(lit, base); bool unsat = inconsistent(); unsigned num_units = m_trail.size() - old_trail_sz; pop_lookahead1(lit, num_units); @@ -1806,6 +1854,7 @@ namespace sat { propagate(); init_lookahead_reward(); change = true; + last_changed = lit; } else { update_lookahead_reward(lit, level); @@ -1815,13 +1864,9 @@ namespace sat { if (c_fixed_truth - 2 * m_lookahead.size() < base) { break; } - if (first && !change) { - first = false; - change = true; - } reset_lookahead_reward(); init_lookahead_reward(); - // base += 2 * m_lookahead.size(); + base += 2 * m_lookahead.size(); } reset_lookahead_reward(); TRACE("sat", display_lookahead(tout); ); @@ -1877,6 +1922,7 @@ namespace sat { case ternary_reward: return l + r + (1 << 10) * l * r; case heule_schur_reward: return l * r; case heule_unit_reward: return l * r; + case march_cu_reward: return 1024 * (1024 * l * r + l + r); case unit_literal_reward: return l * r; default: UNREACHABLE(); return l * r; } @@ -1966,18 +2012,21 @@ namespace sat { } } - void lookahead::do_double(literal l, unsigned& base) { - if (!inconsistent() && scope_lvl() > 1 && dl_enabled(l)) { + void lookahead::do_double(literal l, unsigned& base) { + if (!inconsistent() && dl_enabled(l)) { if (get_lookahead_reward(l) > m_delta_trigger) { if (dl_no_overflow(base)) { ++m_stats.m_double_lookahead_rounds; double_look(l, base); - m_delta_trigger = get_lookahead_reward(l); - dl_disable(l); + if (!inconsistent()) { + m_delta_trigger = get_lookahead_reward(l); + dl_disable(l); + } } } else { - m_delta_trigger *= m_config.m_delta_rho; + SASSERT(m_delta_decrease > 0.0); + m_delta_trigger *= m_delta_decrease; } } } @@ -1985,21 +2034,30 @@ namespace sat { void lookahead::double_look(literal l, unsigned& base) { SASSERT(!inconsistent()); SASSERT(dl_no_overflow(base)); - unsigned dl_truth = base + 2 * m_lookahead.size() * (m_config.m_dl_max_iterations + 1); + base += m_lookahead.size(); + unsigned dl_truth = base + m_lookahead.size() * m_config.m_dl_max_iterations; scoped_level _sl(*this, dl_truth); IF_VERBOSE(2, verbose_stream() << "double: " << l << " depth: " << m_trail_lim.size() << "\n";); init_lookahead_reward(); assign(l); propagate(); bool change = true; + literal last_changed = null_literal; unsigned num_iterations = 0; while (change && num_iterations < m_config.m_dl_max_iterations && !inconsistent()) { change = false; num_iterations++; - base += 2*m_lookahead.size(); for (unsigned i = 0; !inconsistent() && i < m_lookahead.size(); ++i) { literal lit = m_lookahead[i].m_lit; - if (is_fixed_at(lit, dl_truth)) continue; + if (lit == last_changed) { + SASSERT(change == false); + break; + } + if (is_fixed_at(lit, dl_truth)) continue; + if (base + m_lookahead.size() + m_lookahead[i].m_offset >= dl_truth) { + change = false; + break; + } if (push_lookahead2(lit, base + m_lookahead[i].m_offset)) { TRACE("sat", tout << "unit: " << ~lit << "\n";); ++m_stats.m_double_lookahead_propagations; @@ -2008,10 +2066,12 @@ namespace sat { assign(~lit); propagate(); change = true; + last_changed = lit; init_lookahead_reward(); } } - SASSERT(dl_truth - 2 * m_lookahead.size() > base); + base += 2 * m_lookahead.size(); + SASSERT(dl_truth >= base); } reset_lookahead_reward(); SASSERT(m_level == dl_truth); @@ -2051,14 +2111,13 @@ namespace sat { void lookahead::propagated(literal l) { assign(l); - switch (m_search_mode) { - case lookahead_mode::searching: - break; - case lookahead_mode::lookahead1: + for (unsigned i = m_trail.size() - 1; i < m_trail.size() && !inconsistent(); ++i) { + literal l = m_trail[i]; + TRACE("sat", tout << "propagate " << l << " @ " << m_level << "\n";); + propagate_binary(l); + } + if (m_search_mode == lookahead_mode::lookahead1) { m_wstack.push_back(l); - break; - case lookahead_mode::lookahead2: - break; } } @@ -2166,10 +2225,11 @@ namespace sat { backtrack(m_cube_state.m_cube, m_cube_state.m_is_decision); return l_undef; } + int prev_nfreevars = m_freevars.size(); literal lit = choose(); if (inconsistent()) { TRACE("sat", tout << "inconsistent: " << m_cube_state.m_cube << "\n";); - m_cube_state.m_freevars_threshold = m_freevars.size(); + m_cube_state.m_freevars_threshold = prev_nfreevars; if (!backtrack(m_cube_state.m_cube, m_cube_state.m_is_decision)) return l_false; continue; } diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index 2972bc167..38f9b0481 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -74,6 +74,7 @@ namespace sat { double m_max_score; unsigned m_max_hlevel; unsigned m_min_cutoff; + bool m_preselect; unsigned m_level_cand; double m_delta_rho; unsigned m_dl_max_iterations; @@ -87,9 +88,10 @@ namespace sat { m_alpha = 3.5; m_max_score = 20.0; m_min_cutoff = 30; + m_preselect = false; m_level_cand = 600; - m_delta_rho = (double)0.9995; - m_dl_max_iterations = 32; + m_delta_rho = (double)0.25; + m_dl_max_iterations = 2; m_tc1_limit = 10000000; m_reward_type = ternary_reward; m_cube_cutoff = 0; @@ -175,7 +177,8 @@ namespace sat { }; config m_config; - double m_delta_trigger; + double m_delta_trigger; + double m_delta_decrease; drat m_drat; literal_vector m_assumptions; @@ -327,7 +330,10 @@ namespace sat { double get_rating(bool_var v) const { return m_rating[v]; } double get_rating(literal l) const { return get_rating(l.var()); } bool select(unsigned level); - void sift_up(unsigned j); + //void sift_up(unsigned j); + void heap_sort(); + void heapify(); + void sift_down(unsigned j, unsigned sz); double init_candidates(unsigned level, bool newbies); std::ostream& display_candidates(std::ostream& out) const; bool is_unsat() const; @@ -339,6 +345,8 @@ namespace sat { double heule_schur_score(literal l); void heule_unit_scores(); double heule_unit_score(literal l); + void march_cu_scores(); + double march_cu_score(literal l); double l_score(literal l, svector const& h, double factor, double sqfactor, double afactor); // ------------------------------------ @@ -503,6 +511,7 @@ namespace sat { void do_double(literal l, unsigned& base); void double_look(literal l, unsigned& base); void set_conflict() { TRACE("sat", tout << "conflict\n";); m_inconsistent = true; } + //void set_conflict() { TRACE("sat", tout << "conflict\n";); printf("CONFLICT\n"); m_inconsistent = true; } bool inconsistent() { return m_inconsistent; } unsigned scope_lvl() const { return m_trail_lim.size(); } @@ -593,7 +602,8 @@ namespace sat { void collect_statistics(statistics& st) const; - double literal_occs(literal l); + double literal_occs(literal l); + double literal_big_occs(literal l); }; } diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index 01708b775..296c75180 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -31,7 +31,7 @@ def_module_params('sat', ('drat.check', BOOL, False, 'build up internal proof and check'), ('cardinality.solver', BOOL, False, 'use cardinality solver'), ('pb.solver', SYMBOL, 'circuit', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), solver (use SMT solver)'), - ('xor.solver', BOOL, False, 'use xor solver'), + ('xor.solver', BOOL, False, 'use xor solver'), ('atmost1_encoding', SYMBOL, 'grouped', 'encoding used for at-most-1 constraints grouped, bimander, ordered'), ('local_search_threads', UINT, 0, 'number of local search threads to find satisfiable solution'), ('local_search', BOOL, False, 'use local search instead of CDCL'), @@ -40,7 +40,7 @@ def_module_params('sat', ('lookahead.cube.cutoff', UINT, 0, 'cut-off depth to create cubes. Only enabled when non-zero. Used when lookahead_cube is true.'), ('lookahead_search', BOOL, False, 'use lookahead solver'), ('lookahead_simplify', BOOL, False, 'use lookahead solver during simplification'), - ('lookahead.reward', SYMBOL, 'heuleu', 'select lookahead heuristic: ternary, heule_schur (Heule Schur), heuleu (Heule Unit), or unit'), + ('lookahead.reward', SYMBOL, 'march_cu', 'select lookahead heuristic: ternary, heule_schur (Heule Schur), heuleu (Heule Unit), unit, or march_cu'), ('dimacs.display', BOOL, False, 'display SAT instance in DIMACS format and return unknown instead of solving'), ('dimacs.inprocess.display', BOOL, False, 'display SAT instance in DIMACS format if unsolved after inprocess.max inprocessing passes')))