diff --git a/src/sat/card_extension.cpp b/src/sat/card_extension.cpp index bd32a7c40..df5bb1eed 100644 --- a/src/sat/card_extension.cpp +++ b/src/sat/card_extension.cpp @@ -482,7 +482,7 @@ namespace sat { } } - void card_extension::display(std::ostream& out, pb& p, bool values) const { + void card_extension::display(std::ostream& out, pb const& p, bool values) const { out << p.lit() << "[" << p.size() << "]"; if (p.lit() != null_literal && values) { out << "@(" << value(p.lit()); @@ -835,8 +835,9 @@ namespace sat { bool card_extension::resolve_conflict() { - if (0 == m_num_propagations_since_pop) + if (0 == m_num_propagations_since_pop) { return false; + } reset_coeffs(); m_num_marks = 0; m_bound = 0; @@ -941,6 +942,7 @@ namespace sat { m_bound += offset; inc_coeff(consequent, offset); get_pb_antecedents(consequent, p, m_lemma); + TRACE("sat", tout << m_lemma << "\n";); for (unsigned i = 0; i < m_lemma.size(); ++i) { process_antecedent(~m_lemma[i], offset); } @@ -1187,6 +1189,7 @@ namespace sat { void card_extension::add_at_least(bool_var v, literal_vector const& lits, unsigned k) { unsigned index = 4*m_cards.size(); + SASSERT(is_card_index(index)); literal lit = v == null_bool_var ? null_literal : literal(v, false); card* c = new (memory::allocate(card::get_obj_size(lits.size()))) card(index, lit, lits, k); m_cards.push_back(c); @@ -1203,7 +1206,8 @@ namespace sat { } void card_extension::add_pb_ge(bool_var v, svector const& wlits, unsigned k) { - unsigned index = 4*m_pbs.size() + 0x11; + unsigned index = 4*m_pbs.size() + 0x3; + SASSERT(is_pb_index(index)); literal lit = v == null_bool_var ? null_literal : literal(v, false); pb* p = new (memory::allocate(pb::get_obj_size(wlits.size()))) pb(index, lit, wlits, k); m_pbs.push_back(p); @@ -1221,7 +1225,8 @@ namespace sat { void card_extension::add_xor(bool_var v, literal_vector const& lits) { m_has_xor = true; - unsigned index = 4*m_xors.size() + 0x01; + unsigned index = 4*m_xors.size() + 0x1; + SASSERT(is_xor_index(index)); xor* x = new (memory::allocate(xor::get_obj_size(lits.size()))) xor(index, literal(v, false), lits); m_xors.push_back(x); init_watch(v); @@ -1358,46 +1363,50 @@ namespace sat { SASSERT(max_sum < k); } + void card_extension::get_card_antecedents(literal l, card const& c, literal_vector& r) { + DEBUG_CODE( + bool found = false; + for (unsigned i = 0; !found && i < c.k(); ++i) { + found = c[i] == l; + } + SASSERT(found);); + + if (c.lit() != null_literal) r.push_back(c.lit()); + SASSERT(c.lit() == null_literal || value(c.lit()) == l_true); + for (unsigned i = c.k(); i < c.size(); ++i) { + SASSERT(value(c[i]) == l_false); + r.push_back(~c[i]); + } + } + + void card_extension::get_xor_antecedents(literal l, xor const& x, literal_vector& r) { + if (x.lit() != null_literal) r.push_back(x.lit()); + // TRACE("sat", display(tout << l << " ", x, true);); + SASSERT(x.lit() == null_literal || value(x.lit()) == l_true); + SASSERT(x[0].var() == l.var() || x[1].var() == l.var()); + if (x[0].var() == l.var()) { + SASSERT(value(x[1]) != l_undef); + r.push_back(value(x[1]) == l_true ? x[1] : ~x[1]); + } + else { + SASSERT(value(x[0]) != l_undef); + r.push_back(value(x[0]) == l_true ? x[0] : ~x[0]); + } + for (unsigned i = 2; i < x.size(); ++i) { + SASSERT(value(x[i]) != l_undef); + r.push_back(value(x[i]) == l_true ? x[i] : ~x[i]); + } + } + void card_extension::get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) { if (is_card_index(idx)) { - card& c = index2card(idx); - - DEBUG_CODE( - bool found = false; - for (unsigned i = 0; !found && i < c.k(); ++i) { - found = c[i] == l; - } - SASSERT(found);); - - if (c.lit() != null_literal) r.push_back(c.lit()); - SASSERT(c.lit() == null_literal || value(c.lit()) == l_true); - for (unsigned i = c.k(); i < c.size(); ++i) { - SASSERT(value(c[i]) == l_false); - r.push_back(~c[i]); - } + get_card_antecedents(l, index2card(idx), r); } else if (is_xor_index(idx)) { - xor& x = index2xor(idx); - if (x.lit() != null_literal) r.push_back(x.lit()); - TRACE("sat", display(tout << l << " ", x, true);); - SASSERT(x.lit() == null_literal || value(x.lit()) == l_true); - SASSERT(x[0].var() == l.var() || x[1].var() == l.var()); - if (x[0].var() == l.var()) { - SASSERT(value(x[1]) != l_undef); - r.push_back(value(x[1]) == l_true ? x[1] : ~x[1]); - } - else { - SASSERT(value(x[0]) != l_undef); - r.push_back(value(x[0]) == l_true ? x[0] : ~x[0]); - } - for (unsigned i = 2; i < x.size(); ++i) { - SASSERT(value(x[i]) != l_undef); - r.push_back(value(x[i]) == l_true ? x[i] : ~x[i]); - } + get_xor_antecedents(l, index2xor(idx), r); } else if (is_pb_index(idx)) { - pb const& p = index2pb(idx); - get_pb_antecedents(l, p, r); + get_pb_antecedents(l, index2pb(idx), r); } else { UNREACHABLE(); @@ -1635,7 +1644,7 @@ namespace sat { out << ">= " << ineq.m_k << "\n"; } - void card_extension::display(std::ostream& out, xor& x, bool values) const { + void card_extension::display(std::ostream& out, xor const& x, bool values) const { out << "xor " << x.lit(); if (x.lit() != null_literal && values) { out << "@(" << value(x.lit()); @@ -1664,7 +1673,7 @@ namespace sat { out << "\n"; } - void card_extension::display(std::ostream& out, card& c, bool values) const { + void card_extension::display(std::ostream& out, card const& c, bool values) const { out << c.lit() << "[" << c.size() << "]"; if (c.lit() != null_literal && values) { out << "@(" << value(c.lit()); @@ -1728,10 +1737,7 @@ namespace sat { pb& p = index2pb(idx); out << "pb " << p.lit() << ": "; for (unsigned i = 0; i < p.size(); ++i) { - if (p[i].first != 1) { - out << p[i].first << " "; - } - out << p[i].second << " "; + out << p[i].first << "*" << p[i].second << " "; } out << ">= " << p.k(); } diff --git a/src/sat/card_extension.h b/src/sat/card_extension.h index 0e229a056..f7e54b843 100644 --- a/src/sat/card_extension.h +++ b/src/sat/card_extension.h @@ -74,7 +74,7 @@ namespace sat { unsigned m_max_sum; wliteral m_wlits[0]; public: - static size_t get_obj_size(unsigned num_lits) { return sizeof(card) + num_lits * sizeof(wliteral); } + static size_t get_obj_size(unsigned num_lits) { return sizeof(pb) + num_lits * sizeof(wliteral); } pb(unsigned index, literal lit, svector const& wlits, unsigned k); unsigned index() const { return m_index; } literal lit() const { return m_lit; } @@ -205,6 +205,8 @@ namespace sat { void reset_coeffs(); void reset_marked_literals(); void unwatch_literal(literal w, card* c); + void get_card_antecedents(literal l, card const& c, literal_vector & r); + // xor specific functionality void copy_xor(card_extension& result); @@ -218,13 +220,14 @@ namespace sat { lbool add_assign(xor& x, literal alit); void asserted_xor(literal l, ptr_vector* xors, xor* x); void get_xor_antecedents(literal l, unsigned index, justification js, literal_vector& r); + void get_xor_antecedents(literal l, xor const& x, literal_vector & r); - bool is_card_index(unsigned idx) const { return 0x00 == (idx & 0x11); } - bool is_xor_index(unsigned idx) const { return 0x01 == (idx & 0x11); } - bool is_pb_index(unsigned idx) const { return 0x11 == (idx & 0x11); } + bool is_card_index(unsigned idx) const { return 0x0 == (idx & 0x3); } + bool is_xor_index(unsigned idx) const { return 0x1 == (idx & 0x3); } + bool is_pb_index(unsigned idx) const { return 0x3 == (idx & 0x3); } card& index2card(unsigned idx) const { SASSERT(is_card_index(idx)); return *m_cards[idx >> 2]; } - xor& index2xor(unsigned idx) const { SASSERT(!is_card_index(idx)); return *m_xors[idx >> 2]; } + xor& index2xor(unsigned idx) const { SASSERT(is_xor_index(idx)); return *m_xors[idx >> 2]; } pb& index2pb(unsigned idx) const { SASSERT(is_pb_index(idx)); return *m_pbs[idx >> 2]; } @@ -285,9 +288,9 @@ namespace sat { bool validate_resolvent(); void display(std::ostream& out, ineq& p) const; - void display(std::ostream& out, card& c, bool values) const; - void display(std::ostream& out, pb& p, bool values) const; - void display(std::ostream& out, xor& c, bool values) const; + void display(std::ostream& out, card const& c, bool values) const; + void display(std::ostream& out, pb const& p, bool values) const; + void display(std::ostream& out, xor const& c, bool values) const; void display_watch(std::ostream& out, bool_var v) const; void display_watch(std::ostream& out, bool_var v, bool sign) const; diff --git a/src/sat/sat_config.cpp b/src/sat/sat_config.cpp index 45ce213b0..9d81ce886 100644 --- a/src/sat/sat_config.cpp +++ b/src/sat/sat_config.cpp @@ -38,6 +38,7 @@ namespace sat { m_num_threads = 1; m_local_search = 0; m_lookahead_search = false; + m_lookahead_simplify = false; m_ccc = false; updt_params(p); } @@ -83,7 +84,8 @@ namespace sat { m_max_conflicts = p.max_conflicts(); m_num_threads = p.threads(); m_local_search = p.local_search(); - m_local_search_threads = p.local_search_threads(); + m_local_search_threads = p.local_search_threads(); + m_lookahead_simplify = p.lookahead_simplify(); m_lookahead_search = p.lookahead_search(); m_ccc = p.ccc(); diff --git a/src/sat/sat_config.h b/src/sat/sat_config.h index 2e3d4ec86..a34384e87 100644 --- a/src/sat/sat_config.h +++ b/src/sat/sat_config.h @@ -61,6 +61,7 @@ namespace sat { unsigned m_local_search_threads; bool m_local_search; bool m_lookahead_search; + bool m_lookahead_simplify; bool m_ccc; unsigned m_simplify_mult1; diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index f1d0833c4..d5abfd747 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -251,19 +251,16 @@ namespace sat { void del_binary(unsigned idx) { // TRACE("sat", display(tout << "Delete " << to_literal(idx) << "\n");); literal_vector & lits = m_binary[idx]; - if (lits.empty()) IF_VERBOSE(0, verbose_stream() << "empty literals\n";); + SASSERT(!lits.empty()); literal l = lits.back(); - lits.pop_back(); - if (m_binary[(~l).index()].back() != ~to_literal(idx)) { - IF_VERBOSE(0, verbose_stream() << "pop bad literal: " << idx << " " << (~l).index() << "\n";); - } - if (m_binary[(~l).index()].empty()) - IF_VERBOSE(0, verbose_stream() << "empty binary\n";); + lits.pop_back(); + SASSERT(!m_binary[(~l).index()].empty()); + IF_VERBOSE(0, if (m_binary[(~l).index()].back() != ~to_literal(idx)) verbose_stream() << "pop bad literal: " << idx << " " << (~l).index() << "\n";); + SASSERT(m_binary[(~l).index()].back() == ~to_literal(idx)); m_binary[(~l).index()].pop_back(); ++m_stats.m_del_binary; } - void validate_binary(literal l1, literal l2) { if (m_search_mode == lookahead_mode::searching) { m_assumptions.push_back(l1); @@ -1860,6 +1857,9 @@ namespace sat { return search(); } + /** + \brief simplify set of clauses by extracting units from a lookahead at base level. + */ void simplify() { SASSERT(m_prefix == 0); SASSERT(m_watches.empty()); @@ -1879,12 +1879,38 @@ namespace sat { ++num_units; } } - IF_VERBOSE(1, verbose_stream() << "units found: " << num_units << "\n";); + IF_VERBOSE(1, verbose_stream() << "(sat-lookahead :units " << num_units << ")\n";); m_s.m_simplifier.subsume(); m_lookahead.reset(); } + // + // there can be two sets of equivalence classes. + // example: + // a -> !b + // b -> !a + // c -> !a + // we pick as root the Boolean variable with the largest value. + // + literal get_root(bool_var v) { + literal lit(v, false); + literal r1 = get_parent(lit); + literal r2 = get_parent(literal(r1.var(), false)); + CTRACE("sat", r1 != get_parent(literal(r2.var(), false)), + tout << r1 << " " << r2 << "\n";); + SASSERT(r1.var() == get_parent(literal(r2.var(), false)).var()); + if (r1.var() >= r2.var()) { + return r1; + } + else { + return r1.sign() ? ~r2 : r2; + } + } + + /** + \brief extract equivalence classes of variables and simplify clauses using these. + */ void scc() { SASSERT(m_prefix == 0); SASSERT(m_watches.empty()); @@ -1905,8 +1931,7 @@ namespace sat { } for (unsigned i = 0; i < m_candidates.size(); ++i) { bool_var v = m_candidates[i].m_var; - literal lit = literal(v, false); - literal p = get_parent(lit); + literal p = get_root(v); if (p != null_literal && p.var() != v && !m_s.is_external(v) && !m_s.was_eliminated(v) && !m_s.was_eliminated(p.var())) { to_elim.push_back(v); roots[v] = p; @@ -1918,7 +1943,7 @@ namespace sat { } } } - IF_VERBOSE(1, verbose_stream() << "eliminate " << to_elim.size() << " variables\n";); + IF_VERBOSE(1, verbose_stream() << "(sat-lookahead :equivalences " << to_elim.size() << ")\n";); elim_eqs elim(m_s); elim(roots, to_elim); } diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index 226c79642..2c09ecdc1 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -32,5 +32,6 @@ def_module_params('sat', ('local_search_threads', UINT, 0, 'number of local search threads to find satisfiable solution'), ('local_search', BOOL, False, 'use local search instead of CDCL'), ('lookahead_search', BOOL, False, 'use lookahead solver'), + ('lookahead_simplify', BOOL, False, 'use lookahead solver during simplification'), ('ccc', BOOL, False, 'use Concurrent Cube and Conquer solver') )) diff --git a/src/sat/sat_simplifier.cpp b/src/sat/sat_simplifier.cpp index efb9f5498..2e3fbdd77 100644 --- a/src/sat/sat_simplifier.cpp +++ b/src/sat/sat_simplifier.cpp @@ -223,6 +223,13 @@ namespace sat { } } + if (!learned && s.m_config.m_lookahead_simplify) { + // perform lookahead simplification + lookahead lh(s); + lh.simplify(); + lh.collect_statistics(s.m_aux_stats); + } + CASSERT("sat_solver", s.check_invariant()); TRACE("after_simplifier", s.display(tout); tout << "model_converter:\n"; s.m_mc.display(tout);); diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index d42f4d41a..1374bf7c8 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -1342,7 +1342,11 @@ namespace sat { CASSERT("sat_simplify_bug", check_invariant()); } - lookahead(*this).scc(); + if (m_config.m_lookahead_simplify) { + lookahead lh(*this); + lh.scc(); + lh.collect_statistics(m_aux_stats); + } sort_watch_lits(); CASSERT("sat_simplify_bug", check_invariant());