From a3f4d58b000f8ef91c5db61f74adf05329686ba3 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 18 Apr 2017 16:58:56 -0700 Subject: [PATCH] use lookahead for simplification Signed-off-by: Nikolaj Bjorner --- src/sat/sat_config.cpp | 4 +- src/sat/sat_config.h | 1 + src/sat/sat_local_search.cpp | 42 ++++++----- src/sat/sat_local_search.h | 1 + src/sat/sat_lookahead.h | 135 +++++++++++++++++++++++++++++------ src/sat/sat_params.pyg | 1 + src/sat/sat_scc.cpp | 4 +- src/sat/sat_simplifier.cpp | 8 +++ src/sat/sat_simplifier.h | 6 +- src/sat/sat_solver.cpp | 23 ++++++ src/sat/sat_solver.h | 3 + 11 files changed, 183 insertions(+), 45 deletions(-) diff --git a/src/sat/sat_config.cpp b/src/sat/sat_config.cpp index 4b019c2b7..42c185ee1 100644 --- a/src/sat/sat_config.cpp +++ b/src/sat/sat_config.cpp @@ -37,6 +37,7 @@ namespace sat { m_psm_glue("psm_glue") { m_num_threads = 1; m_local_search = 0; + m_lookahead_search = false; updt_params(p); } @@ -81,7 +82,8 @@ namespace sat { m_max_conflicts = p.max_conflicts(); m_num_threads = p.threads(); m_local_search = p.local_search(); - + m_lookahead_search = p.lookahead_search(); + // These parameters are not exposed m_simplify_mult1 = _p.get_uint("simplify_mult1", 300); m_simplify_mult2 = _p.get_double("simplify_mult2", 1.5); diff --git a/src/sat/sat_config.h b/src/sat/sat_config.h index 313b4ec49..8c10983d2 100644 --- a/src/sat/sat_config.h +++ b/src/sat/sat_config.h @@ -59,6 +59,7 @@ namespace sat { unsigned m_max_conflicts; unsigned m_num_threads; unsigned m_local_search; + bool m_lookahead_search; unsigned m_simplify_mult1; double m_simplify_mult2; diff --git a/src/sat/sat_local_search.cpp b/src/sat/sat_local_search.cpp index c5e7fcf54..5372011ed 100644 --- a/src/sat/sat_local_search.cpp +++ b/src/sat/sat_local_search.cpp @@ -71,9 +71,8 @@ namespace sat { void local_search::init_cur_solution() { for (unsigned v = 0; v < num_vars(); ++v) { // use bias with a small probability - if (m_rand() % 100 < 3) { - //m_vars[v].m_value = ((unsigned)(m_rand() % 100) < m_vars[v].m_bias); - m_vars[v].m_value = (50 < m_vars[v].m_bias); + if (m_rand() % 100 < 2) { + m_vars[v].m_value = ((unsigned)(m_rand() % 100) < m_vars[v].m_bias); } } } @@ -138,20 +137,24 @@ namespace sat { } void local_search::reinit() { - // the following methods does NOT converge for pseudo-boolean - // can try other way to define "worse" and "better" - // the current best noise is below 1000 -#if 0 - if (m_best_unsat_rate > m_last_best_unsat_rate) { - // worse - m_noise -= m_noise * 2 * m_noise_delta; - m_best_unsat_rate *= 1000.0; + + if (!m_is_pb) { + // + // the following methods does NOT converge for pseudo-boolean + // can try other way to define "worse" and "better" + // the current best noise is below 1000 + // + if (m_best_unsat_rate > m_last_best_unsat_rate) { + // worse + m_noise -= m_noise * 2 * m_noise_delta; + m_best_unsat_rate *= 1000.0; + } + else { + // better + m_noise += (10000 - m_noise) * m_noise_delta; + } } - else { - // better - m_noise += (10000 - m_noise) * m_noise_delta; - } -#endif + for (unsigned i = 0; i < m_constraints.size(); ++i) { constraint& c = m_constraints[i]; c.m_slack = c.m_k; @@ -264,7 +267,7 @@ namespace sat { unsigned id = m_constraints.size(); m_constraints.push_back(constraint(k)); for (unsigned i = 0; i < sz; ++i) { - m_vars.reserve(c[i].var() + 1); + m_vars.reserve(c[i].var() + 1); literal t(~c[i]); m_vars[t.var()].m_watch[is_pos(t)].push_back(pbcoeff(id, coeffs[i])); m_constraints.back().push(t); // add coefficient to constraint? @@ -279,6 +282,7 @@ namespace sat { } void local_search::import(solver& s, bool _init) { + m_is_pb = false; m_vars.reset(); m_constraints.reset(); @@ -349,6 +353,7 @@ namespace sat { // = ~c.lit() or (~c.lits() <= n - k) // = k*c.lit() + ~c.lits() <= n // + m_is_pb = true; lits.reset(); coeffs.reset(); for (unsigned j = 0; j < n; ++j) lits.push_back(c[j]), coeffs.push_back(1); @@ -616,8 +621,7 @@ namespace sat { // verify_unsat_stack(); } - void local_search::flip_gsat(bool_var flipvar) - { + void local_search::flip_gsat(bool_var flipvar) { // already changed truth value!!!! m_vars[flipvar].m_value = !cur_solution(flipvar); diff --git a/src/sat/sat_local_search.h b/src/sat/sat_local_search.h index 918f5328d..edce6cc9c 100644 --- a/src/sat/sat_local_search.h +++ b/src/sat/sat_local_search.h @@ -144,6 +144,7 @@ namespace sat { literal_vector m_assumptions; unsigned m_num_non_binary_clauses; + bool m_is_pb; inline bool is_pos(literal t) const { return !t.sign(); } inline bool is_true(bool_var v) const { return cur_solution(v); } diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index c4f6a4bba..6bc528220 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -20,6 +20,8 @@ Notes: #ifndef _SAT_LOOKAHEAD_H_ #define _SAT_LOOKAHEAD_H_ +#include "sat_elim_eqs.h" + namespace sat { struct pp_prefix { @@ -34,6 +36,9 @@ namespace sat { for (unsigned i = 0; i <= d; ++i) { if (0 != (p.m_prefix & (1ull << i))) out << "1"; else out << "0"; } + if (d < p.m_depth) { + out << " d:" << p.m_depth; + } return out; } @@ -309,6 +314,7 @@ namespace sat { assign(u); return false; } + IF_VERBOSE(3, verbose_stream() << "tc1: " << u << " " << w << "\n";); add_binary(u, w); } } @@ -367,7 +373,8 @@ namespace sat { bool select(unsigned level) { init_pre_selection(level); - unsigned max_num_cand = level == 0 ? m_freevars.size() : m_config.m_level_cand / level; + unsigned level_cand = std::max(m_config.m_level_cand, m_freevars.size() / 50); + unsigned max_num_cand = level == 0 ? m_freevars.size() : level_cand / level; max_num_cand = std::max(m_config.m_min_cutoff, max_num_cand); float sum = 0; @@ -1010,7 +1017,8 @@ namespace sat { m_lits.push_back(lit_info()); m_rating.push_back(0); m_vprefix.push_back(prefix()); - m_freevars.insert(v); + if (!s.was_eliminated(v)) + m_freevars.insert(v); } void init() { @@ -1040,11 +1048,31 @@ namespace sat { } } + copy_clauses(s.m_clauses); + copy_clauses(s.m_learned); + + // copy units + unsigned trail_sz = s.init_trail_size(); + for (unsigned i = 0; i < trail_sz; ++i) { + literal l = s.m_trail[i]; + if (!s.was_eliminated(l.var())) + { + if (s.m_config.m_drat) m_drat.add(l, false); + assign(l); + } + } + propagate(); + m_qhead = m_trail.size(); + TRACE("sat", s.display(tout); display(tout);); + } + + void copy_clauses(clause_vector const& clauses) { // copy clauses - clause_vector::const_iterator it = s.m_clauses.begin(); - clause_vector::const_iterator end = s.m_clauses.end(); - for (; it != end; ++it) { + clause_vector::const_iterator it = clauses.begin(); + clause_vector::const_iterator end = clauses.end(); + for (; it != end; ++it) { clause& c = *(*it); + if (c.was_removed()) continue; clause* c1 = m_cls_allocator.mk_clause(c.size(), c.begin(), false); m_clauses.push_back(c1); attach_clause(*c1); @@ -1053,17 +1081,6 @@ namespace sat { } if (s.m_config.m_drat) m_drat.add(c, false); } - - // copy units - unsigned trail_sz = s.init_trail_size(); - for (unsigned i = 0; i < trail_sz; ++i) { - literal l = s.m_trail[i]; - if (s.m_config.m_drat) m_drat.add(l, false); - assign(l); - } - propagate(); - m_qhead = m_trail.size(); - TRACE("sat", s.display(tout); display(tout);); } // ------------------------------------ @@ -1393,19 +1410,24 @@ namespace sat { TRACE("sat", display_lookahead(tout); ); unsigned base = 2; bool change = true; + bool first = true; while (change && !inconsistent()) { change = false; for (unsigned i = 0; !inconsistent() && i < m_lookahead.size(); ++i) { + s.checkpoint(); literal lit = m_lookahead[i].m_lit; if (is_fixed_at(lit, c_fixed_truth)) continue; unsigned level = base + m_lookahead[i].m_offset; if (m_stamp[lit.var()] >= level) { continue; } + if (scope_lvl() == 1) { + IF_VERBOSE(3, verbose_stream() << scope_lvl() << " " << lit << " binary: " << m_binary_trail.size() << " trail: " << m_trail_lim.back() << "\n";); + } TRACE("sat", tout << "lookahead: " << lit << " @ " << m_lookahead[i].m_offset << "\n";); reset_wnb(lit); push_lookahead1(lit, level); - do_double(lit, base); + if (!first) do_double(lit, base); bool unsat = inconsistent(); pop_lookahead1(lit); if (unsat) { @@ -1424,7 +1446,13 @@ namespace sat { if (c_fixed_truth - 2 * m_lookahead.size() < base) { break; } - base += 2 * m_lookahead.size(); + if (first && !change) { + first = false; + change = true; + } + reset_wnb(); + init_wnb(); + // base += 2 * m_lookahead.size(); } reset_wnb(); TRACE("sat", display_lookahead(tout); ); @@ -1487,6 +1515,7 @@ namespace sat { } bool check_autarky(literal l, unsigned level) { + return false; // no propagations are allowed to reduce clauses. clause_vector::const_iterator it = m_full_watches[l.index()].begin(); clause_vector::const_iterator end = m_full_watches[l.index()].end(); @@ -1568,7 +1597,7 @@ namespace sat { } void do_double(literal l, unsigned& base) { - if (!inconsistent() && scope_lvl() > 0 && dl_enabled(l)) { + if (!inconsistent() && scope_lvl() > 1 && dl_enabled(l)) { if (get_wnb(l) > m_delta_trigger) { if (dl_no_overflow(base)) { ++m_stats.m_double_lookahead_rounds; @@ -1588,6 +1617,7 @@ namespace sat { SASSERT(dl_no_overflow(base)); unsigned dl_truth = base + 2 * m_lookahead.size() * (m_config.m_dl_max_iterations + 1); scoped_level _sl(*this, dl_truth); + IF_VERBOSE(2, verbose_stream() << "double: " << l << "\n";); init_wnb(); assign(l); propagate(); @@ -1769,9 +1799,6 @@ namespace sat { m_drat(s), m_level(2), m_prefix(0) { - m_search_mode = lookahead_mode::searching; - scoped_level _sl(*this, c_fixed_truth); - init(); } ~lookahead() { @@ -1779,9 +1806,73 @@ namespace sat { } lbool check() { + { + m_search_mode = lookahead_mode::searching; + scoped_level _sl(*this, c_fixed_truth); + init(); + } return search(); } + void simplify() { + SASSERT(m_prefix == 0); + SASSERT(m_watches.empty()); + m_search_mode = lookahead_mode::searching; + scoped_level _sl(*this, c_fixed_truth); + init(); + if (inconsistent()) return; + inc_istamp(); + literal l = choose(); + if (inconsistent()) return; + SASSERT(m_trail_lim.empty()); + unsigned num_units = 0; + for (unsigned i = 0; i < m_trail.size(); ++i) { + literal lit = m_trail[i]; + if (s.value(lit) == l_undef && !s.was_eliminated(lit.var())) { + s.m_simplifier.propagate_unit(lit); + ++num_units; + } + } + IF_VERBOSE(1, verbose_stream() << "units found: " << num_units << "\n";); + + s.m_simplifier.subsume(); + m_lookahead.reset(); + } + + void scc() { + SASSERT(m_prefix == 0); + SASSERT(m_watches.empty()); + m_search_mode = lookahead_mode::searching; + scoped_level _sl(*this, c_fixed_truth); + init(); + if (inconsistent()) return; + inc_istamp(); + m_lookahead.reset(); + if (select(0)) { + // extract equivalences + get_scc(); + if (inconsistent()) return; + literal_vector roots; + bool_var_vector to_elim; + for (unsigned i = 0; i < s.num_vars(); ++i) { + roots.push_back(literal(i, false)); + } + for (unsigned i = 0; i < m_candidates.size(); ++i) { + bool_var v = m_candidates[i].m_var; + literal lit = literal(v, false); + literal p = get_parent(lit); + if (p != null_literal && p.var() != v && !s.is_external(v) && !s.was_eliminated(v) && !s.was_eliminated(p.var())) { + to_elim.push_back(v); + roots[v] = p; + } + } + IF_VERBOSE(1, verbose_stream() << "eliminate " << to_elim.size() << " variables\n";); + elim_eqs elim(s); + elim(roots, to_elim); + } + m_lookahead.reset(); + } + std::ostream& display(std::ostream& out) const { out << "Prefix: " << pp_prefix(m_prefix, m_trail_lim.size()) << "\n"; out << "Level: " << m_level << "\n"; diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index 045fd803a..ffc699d02 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -29,4 +29,5 @@ def_module_params('sat', ('cardinality.solver', BOOL, False, 'use cardinality solver'), ('xor.solver', BOOL, False, 'use xor solver'), ('local_search', UINT, 0, 'number of local search threads to find satisfiable solution'), + ('lookahead_search', BOOL, False, 'use lookahead solver') )) diff --git a/src/sat/sat_scc.cpp b/src/sat/sat_scc.cpp index ffbdb31c6..3dfc42f6a 100644 --- a/src/sat/sat_scc.cpp +++ b/src/sat/sat_scc.cpp @@ -76,7 +76,9 @@ namespace sat { lowlink.resize(num_lits, UINT_MAX); in_s.resize(num_lits, false); literal_vector roots; - roots.resize(m_solver.num_vars(), null_literal); + for (unsigned i = 0; i < m_solver.num_vars(); ++i) { + roots.push_back(literal(i, false)); + } unsigned next_index = 0; svector frames; bool_var_vector to_elim; diff --git a/src/sat/sat_simplifier.cpp b/src/sat/sat_simplifier.cpp index fe019427f..8cbedb86b 100644 --- a/src/sat/sat_simplifier.cpp +++ b/src/sat/sat_simplifier.cpp @@ -21,6 +21,7 @@ Revision History: #include"sat_simplifier.h" #include"sat_simplifier_params.hpp" #include"sat_solver.h" +#include"sat_lookahead.h" #include"stopwatch.h" #include"trace.h" @@ -204,6 +205,11 @@ namespace sat { } while (!m_sub_todo.empty()); + if (!learned) { + // perform lookahead simplification + lookahead(s).simplify(); + } + bool vars_eliminated = m_num_elim_vars > old_num_elim_vars; if (m_need_cleanup) { @@ -219,9 +225,11 @@ namespace sat { cleanup_clauses(s.m_learned, true, true, learned_in_use_lists); } } + CASSERT("sat_solver", s.check_invariant()); TRACE("after_simplifier", s.display(tout); tout << "model_converter:\n"; s.m_mc.display(tout);); finalize(); + } /** diff --git a/src/sat/sat_simplifier.h b/src/sat/sat_simplifier.h index 9ee239083..47648cc10 100644 --- a/src/sat/sat_simplifier.h +++ b/src/sat/sat_simplifier.h @@ -130,13 +130,11 @@ namespace sat { bool cleanup_clause(clause & c, bool in_use_list); bool cleanup_clause(literal_vector & c); - void propagate_unit(literal l); void elim_lit(clause & c, literal l); void elim_dup_bins(); bool subsume_with_binaries(); void mark_as_not_learned_core(watch_list & wlist, literal l2); void mark_as_not_learned(literal l1, literal l2); - void subsume(); void cleanup_watches(); void cleanup_clauses(clause_vector & cs, bool learned, bool vars_eliminated, bool in_use_lists); @@ -191,6 +189,10 @@ namespace sat { void collect_statistics(statistics & st) const; void reset_statistics(); + + void propagate_unit(literal l); + void subsume(); + }; }; diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index d25c8bdee..5ebd661ee 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -22,6 +22,7 @@ Revision History: #include"trace.h" #include"max_cliques.h" #include"scoped_ptr_vector.h" +#include"sat_lookahead.h" // define to update glue during propagation #define UPDATE_GLUE @@ -783,6 +784,9 @@ namespace sat { pop_to_base_level(); IF_VERBOSE(2, verbose_stream() << "(sat.sat-solver)\n";); SASSERT(at_base_lvl()); + if (m_config.m_lookahead_search && num_lits == 0) { + return lookahead_search(); + } if ((m_config.m_num_threads > 1 || m_config.m_local_search > 0) && !m_par) { return check_par(num_lits, lits); } @@ -855,6 +859,20 @@ namespace sat { ERROR_EX }; + lbool solver::lookahead_search() { + lookahead lh(*this); + lbool r = l_undef; + try { + r = lh.check(); + m_model = lh.get_model(); + } + catch (z3_exception&) { + lh.collect_statistics(m_lookahead_stats); + throw; + } + lh.collect_statistics(m_lookahead_stats); + return r; + } lbool solver::check_par(unsigned num_lits, literal const* lits) { scoped_ptr_vector ls; @@ -1295,6 +1313,8 @@ namespace sat { CASSERT("sat_simplify_bug", check_invariant()); } + lookahead(*this).scc(); + sort_watch_lits(); CASSERT("sat_simplify_bug", check_invariant()); @@ -2762,6 +2782,7 @@ namespace sat { m_asymm_branch.collect_statistics(st); m_probing.collect_statistics(st); if (m_ext) m_ext->collect_statistics(st); + st.copy(m_lookahead_stats); } void solver::reset_statistics() { @@ -2770,6 +2791,7 @@ namespace sat { m_simplifier.reset_statistics(); m_asymm_branch.reset_statistics(); m_probing.reset_statistics(); + m_lookahead_stats.reset(); } // ----------------------- @@ -3605,6 +3627,7 @@ namespace sat { if (m_solver.m_num_frozen > 0) out << " :frozen " << m_solver.m_num_frozen; } + out << " :units " << m_solver.init_trail_size(); out << " :gc-clause " << m_solver.m_stats.m_gc_clause; out << mem_stat(); } diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index c35a0296c..1bf393696 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -141,6 +141,8 @@ namespace sat { unsigned m_par_num_vars; bool m_par_syncing_clauses; + statistics m_lookahead_stats; + void del_clauses(clause * const * begin, clause * const * end); friend class integrity_checker; @@ -346,6 +348,7 @@ namespace sat { void sort_watch_lits(); void exchange_par(); lbool check_par(unsigned num_lits, literal const* lits); + lbool lookahead_search(); // ----------------------- //