From 7aeaf11ee44f03bb7b9ad72f729c92e36a01fba6 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 5 Feb 2017 22:24:20 -0800 Subject: [PATCH] adding clause sharing to par mode Signed-off-by: Nikolaj Bjorner --- src/sat/card_extension.cpp | 112 ++++++++++++------------------- src/sat/card_extension.h | 1 - src/sat/sat_clause.cpp | 4 +- src/sat/sat_par.cpp | 133 ++++++++++++++++++++++++++++++++++++- src/sat/sat_par.h | 40 ++++++++++- src/sat/sat_solver.cpp | 43 +++++++----- src/sat/sat_solver.h | 7 +- 7 files changed, 249 insertions(+), 91 deletions(-) diff --git a/src/sat/card_extension.cpp b/src/sat/card_extension.cpp index 1273cfc29..c051ded15 100644 --- a/src/sat/card_extension.cpp +++ b/src/sat/card_extension.cpp @@ -181,35 +181,6 @@ namespace sat { s().set_conflict(justification::mk_ext_justification(c.index()), ~lit); SASSERT(s().inconsistent()); } - - literal card_extension::last_false_literal(card& c) { - while (!m_active_var_set.empty()) m_active_var_set.erase(); - reset_coeffs(); - for (unsigned i = 0; i < c.size(); ++i) { - bool_var v = c[i].var(); - m_active_var_set.insert(v); - m_active_vars.push_back(v); - m_coeffs.setx(v, c[i].sign() ? -1 : 1, 0); - } - literal_vector const& lits = s().m_trail; - for (unsigned i = lits.size(); i > 0; ) { - --i; - literal lit = lits[i]; - bool_var v = lit.var(); - if (m_active_var_set.contains(v) && - (m_coeffs[v] > 0 == lits[i].sign())) { - //std::cout << "last literal: " << lit << "\n"; - for (unsigned j = 0; j < c.size(); ++j) { - if (~lit == c[j] && j != c.k()-1) { - // std::cout << "POSITION " << j << " bound " << c.k() << "\n"; - } - } - return ~lit; - } - } - UNREACHABLE(); - return null_literal; - } void card_extension::normalize_active_coeffs() { while (!m_active_var_set.empty()) m_active_var_set.erase(); @@ -252,11 +223,9 @@ namespace sat { } // reduce coefficient to be no larger than bound. if (coeff1 > m_bound) { - //if (m_bound > 1) std::cout << m_bound << " " << coeff1 << "\n"; m_coeffs[v] = m_bound; } else if (coeff1 < 0 && -coeff1 > m_bound) { - //if (m_bound > 1) std::cout << m_bound << " " << coeff1 << "\n"; m_coeffs[v] = -m_bound; } } @@ -302,19 +271,19 @@ namespace sat { DEBUG_CODE(active2pb(m_A);); do { - // TRACE("sat", display(tout, m_A);); if (offset == 0) { goto process_next_resolvent; } // TBD: need proper check for overflow. if (offset > (1 << 12)) { - // std::cout << "offset: " << offset << "\n"; goto bail_out; } ++num_steps; + // TRACE("sat", display(tout, m_A);); + TRACE("sat", tout << "process consequent: " << consequent << ":\n"; s().display_justification(tout, js) << "\n";); SASSERT(offset > 0); SASSERT(m_bound >= 0); @@ -366,7 +335,7 @@ namespace sat { card& c = *m_constraints[index]; m_bound += offset * c.k(); if (!process_card(c, offset)) { - std::cout << "failed to process card\n"; + TRACE("sat", tout << "failed to process card\n";); goto bail_out; } break; @@ -378,11 +347,13 @@ namespace sat { SASSERT(validate_lemma()); + DEBUG_CODE( active2pb(m_C); - SASSERT(validate_resolvent()); + //SASSERT(validate_resolvent()); m_A = m_C;); + TRACE("sat", display(tout << "conflict:\n", m_A);); // cut(); process_next_resolvent: @@ -404,31 +375,32 @@ namespace sat { --m_num_marks; js = s().m_justification[v]; offset = get_abs_coeff(v); + if (offset > m_bound) { + m_coeffs[v] = (get_coeff(v) < 0) ? -m_bound : m_bound; + offset = m_bound; + // TBD: also adjust coefficient in m_A. + } SASSERT(value(consequent) == l_true); + } while (m_num_marks > 0); - - std::cout << m_num_propagations_since_pop << " " << num_steps << " " << num_card << "\n"; - // std::cout << consequent << "\n"; DEBUG_CODE(for (bool_var i = 0; i < static_cast(s().num_vars()); ++i) SASSERT(!s().is_marked(i));); SASSERT(validate_lemma()); normalize_active_coeffs(); + if (consequent == null_literal) { + return false; + } int slack = -m_bound; for (unsigned i = 0; i < m_active_vars.size(); ++i) { bool_var v = m_active_vars[i]; slack += get_abs_coeff(v); } - - TRACE("sat", display(tout, m_A);); - ++idx; consequent = null_literal; - - // std::cout << c.size() << " >= " << c.k() << "\n"; - // std::cout << m_active_vars.size() << ": " << slack + m_bound << " >= " << m_bound << "\n"; + ++idx; while (0 <= slack) { literal lit = lits[idx]; @@ -450,6 +422,11 @@ namespace sat { } else { m_lemma.push_back(~lit); + if (lvl(lit) == m_conflict_lvl) { + TRACE("sat", tout << "Bail out on no progress " << lit << "\n";); + IF_VERBOSE(1, verbose_stream() << "bail cardinality lemma\n";); + return false; + } } } } @@ -460,7 +437,6 @@ namespace sat { SASSERT(slack < 0); if (consequent == null_literal) { - std::cout << "null literal: " << m_lemma.empty() << "\n"; if (!m_lemma.empty()) return false; } else { @@ -473,8 +449,7 @@ namespace sat { svector ps; // TBD fill in s().m_drat.add(m_lemma, ps); } - - // std::cout << m_lemma << "\n"; + s().m_lemma.reset(); s().m_lemma.append(m_lemma); for (unsigned i = 1; i < m_lemma.size(); ++i) { @@ -575,7 +550,6 @@ namespace sat { } SASSERT(found);); - // std::cout << "antecedents: " << idx << ": " << l << " " << c.size() - c.k() + 1 << "\n"; r.push_back(c.lit()); SASSERT(value(c.lit()) == l_true); for (unsigned i = c.k(); i < c.size(); ++i) { @@ -788,17 +762,12 @@ namespace sat { } std::ostream& card_extension::display_justification(std::ostream& out, ext_justification_idx idx) const { - if (idx == 0) { - out << "conflict: " << m_lemma; - } - else { - card& c = *m_constraints[idx]; - out << "bound " << c.lit() << ": "; - for (unsigned i = c.k(); i < c.size(); ++i) { - out << c[i] << " "; - } - out << ">= " << c.k(); + card& c = *m_constraints[idx]; + out << "bound " << c.lit() << ": "; + for (unsigned i = 0; i < c.size(); ++i) { + out << c[i] << " "; } + out << ">= " << c.k(); return out; } @@ -932,22 +901,29 @@ namespace sat { literal lit = m_C.m_lits[i]; unsigned coeff; if (coeffs.find(lit.index(), coeff)) { - SASSERT(coeff <= m_C.m_coeffs[i] || m_C.m_coeffs[i] == m_C.m_k); + if (coeff > m_C.m_coeffs[i] && m_C.m_coeffs[i] < m_C.m_k) { + std::cout << i << ": " << m_C.m_coeffs[i] << " " << m_C.m_k << "\n"; + goto violated; + } coeffs.remove(lit.index()); } } - if (!coeffs.empty() || m_C.m_k > k) { - display(std::cout, m_A); - display(std::cout, m_B); - display(std::cout, m_C); - u_map::iterator it = coeffs.begin(), end = coeffs.end(); - for (; it != end; ++it) { - std::cout << to_literal(it->m_key) << ": " << it->m_value << "\n"; - } - } + if (!coeffs.empty()) goto violated; + if (m_C.m_k > k) goto violated; SASSERT(coeffs.empty()); SASSERT(m_C.m_k <= k); return true; + + violated: + display(std::cout, m_A); + display(std::cout, m_B); + display(std::cout, m_C); + u_map::iterator it = coeffs.begin(), end = coeffs.end(); + for (; it != end; ++it) { + std::cout << to_literal(it->m_key) << ": " << it->m_value << "\n"; + } + + return false; } bool card_extension::validate_conflict(literal_vector const& lits, ineq& p) { diff --git a/src/sat/card_extension.h b/src/sat/card_extension.h index 3434a280b..dbc8f7b07 100644 --- a/src/sat/card_extension.h +++ b/src/sat/card_extension.h @@ -105,7 +105,6 @@ namespace sat { lbool add_assign(card& c, literal lit); void watch_literal(card& c, literal lit); void set_conflict(card& c, literal lit); - literal last_false_literal(card& c); void clear_watch(card& c); void reset_coeffs(); diff --git a/src/sat/sat_clause.cpp b/src/sat/sat_clause.cpp index 1efbd6758..68af09ec7 100644 --- a/src/sat/sat_clause.cpp +++ b/src/sat/sat_clause.cpp @@ -198,13 +198,13 @@ namespace sat { size_t size = clause::get_obj_size(num_lits); void * mem = m_allocator.allocate(size); clause * cls = new (mem) clause(m_id_gen.mk(), num_lits, lits, learned); - TRACE("sat", tout << "alloc: " << cls->id() << " " << *cls << " " << (learned?"l":"a") << "\n";); + TRACE("sat_clause", tout << "alloc: " << cls->id() << " " << *cls << " " << (learned?"l":"a") << "\n";); SASSERT(!learned || cls->is_learned()); return cls; } void clause_allocator::del_clause(clause * cls) { - TRACE("sat", tout << "delete: " << cls->id() << " " << *cls << "\n";); + TRACE("sat_clause", tout << "delete: " << cls->id() << " " << *cls << "\n";); m_id_gen.recycle(cls->id()); #if defined(_AMD64_) #if defined(Z3DEBUG) diff --git a/src/sat/sat_par.cpp b/src/sat/sat_par.cpp index 7a185a3b5..585bd8bcc 100644 --- a/src/sat/sat_par.cpp +++ b/src/sat/sat_par.cpp @@ -17,13 +17,84 @@ Revision History: --*/ #include "sat_par.h" +#include "sat_clause.h" +#include "sat_solver.h" namespace sat { + void par::vector_pool::next(unsigned& index) { + SASSERT(index < m_size); + unsigned n = index + 2 + get_length(index); + if (n >= m_size) { + index = 0; + } + else { + index = n; + } + } + + void par::vector_pool::reserve(unsigned num_threads, unsigned sz) { + m_vectors.reset(); + m_vectors.resize(sz, 0); + m_heads.reset(); + m_heads.resize(num_threads, 0); + m_tail = 0; + m_size = sz; + } + + void par::vector_pool::begin_add_vector(unsigned owner, unsigned n) { + unsigned capacity = n + 2; + m_vectors.reserve(m_size + capacity, 0); + IF_VERBOSE(3, verbose_stream() << owner << ": begin-add " << n << " tail: " << m_tail << " size: " << m_size << "\n";); + if (m_tail >= m_size) { + // move tail to the front. + for (unsigned i = 0; i < m_heads.size(); ++i) { + while (m_heads[i] < capacity) { + next(m_heads[i]); + } + IF_VERBOSE(3, verbose_stream() << owner << ": head: " << m_heads[i] << "\n";); + } + m_tail = 0; + } + else { + for (unsigned i = 0; i < m_heads.size(); ++i) { + while (m_tail < m_heads[i] && m_heads[i] < m_tail + capacity) { + next(m_heads[i]); + } + IF_VERBOSE(3, verbose_stream() << owner << ": head: " << m_heads[i] << "\n";); + } + } + m_vectors[m_tail++] = owner; + m_vectors[m_tail++] = n; + } + + void par::vector_pool::add_vector_elem(unsigned e) { + m_vectors[m_tail++] = e; + } + + bool par::vector_pool::get_vector(unsigned owner, unsigned& n, unsigned const*& ptr) { + unsigned head = m_heads[owner]; + SASSERT(head < m_size); + while (head != m_tail) { + IF_VERBOSE(3, verbose_stream() << owner << ": head: " << head << " tail: " << m_tail << "\n";); + bool is_self = owner == get_owner(head); + next(m_heads[owner]); + if (!is_self) { + n = get_length(head); + ptr = get_ptr(head); + return true; + } + head = m_heads[owner]; + } + return false; + } + par::par() {} - void par::exchange(literal_vector const& in, unsigned& limit, literal_vector& out) { + void par::exchange(solver& s, literal_vector const& in, unsigned& limit, literal_vector& out) { + if (s.m_par_syncing_clauses) return; + flet _disable_sync_clause(s.m_par_syncing_clauses, true); #pragma omp critical (par_solver) { if (limit < m_units.size()) { @@ -40,6 +111,64 @@ namespace sat { limit = m_units.size(); } } - + + void par::share_clause(solver& s, literal l1, literal l2) { + if (s.m_par_syncing_clauses) return; + flet _disable_sync_clause(s.m_par_syncing_clauses, true); + #pragma omp critical (par_solver) + { + IF_VERBOSE(3, verbose_stream() << s.m_par_id << ": share " << l1 << " " << l2 << "\n";); + m_pool.begin_add_vector(s.m_par_id, 2); + m_pool.add_vector_elem(l1.index()); + m_pool.add_vector_elem(l2.index()); + } + } + + void par::share_clause(solver& s, clause const& c) { + if (s.m_par_syncing_clauses) return; + flet _disable_sync_clause(s.m_par_syncing_clauses, true); + unsigned n = c.size(); + unsigned owner = s.m_par_id; + #pragma omp critical (par_solver) + { + if (enable_add(c)) { + IF_VERBOSE(3, verbose_stream() << owner << ": share " << c << "\n";); + m_pool.begin_add_vector(owner, n); + for (unsigned i = 0; i < n; ++i) { + m_pool.add_vector_elem(c[i].index()); + } + } + } + } + + void par::get_clauses(solver& s) { + if (s.m_par_syncing_clauses) return; + flet _disable_sync_clause(s.m_par_syncing_clauses, true); + #pragma omp critical (par_solver) + { + _get_clauses(s); + } + } + + void par::_get_clauses(solver& s) { + unsigned n; + unsigned const* ptr; + unsigned owner = s.m_par_id; + while (m_pool.get_vector(owner, n, ptr)) { + m_lits.reset(); + for (unsigned i = 0; i < n; ++i) { + m_lits.push_back(to_literal(ptr[i])); + } + IF_VERBOSE(3, verbose_stream() << s.m_par_id << ": retrieve " << m_lits << "\n";); + SASSERT(n >= 2); + s.mk_clause_core(m_lits.size(), m_lits.c_ptr(), true); + } + } + + bool par::enable_add(clause const& c) const { + // plingeling, glucose heuristic: + return (c.size() <= 40 && c.glue() <= 8) || c.glue() <= 2; + } + }; diff --git a/src/sat/sat_par.h b/src/sat/sat_par.h index 2b2592de7..f76b47536 100644 --- a/src/sat/sat_par.h +++ b/src/sat/sat_par.h @@ -26,12 +26,50 @@ Revision History: namespace sat { class par { + + // shared pool of learned clauses. + class vector_pool { + unsigned_vector m_vectors; + unsigned m_size; + unsigned m_tail; + unsigned_vector m_heads; + void next(unsigned& index); + unsigned get_owner(unsigned index) const { return m_vectors[index]; } + unsigned get_length(unsigned index) const { return m_vectors[index+1]; } + unsigned const* get_ptr(unsigned index) const { return m_vectors.c_ptr() + index + 2; } + public: + vector_pool() {} + void reserve(unsigned num_owners, unsigned sz); + void begin_add_vector(unsigned owner, unsigned n); + void add_vector_elem(unsigned e); + bool get_vector(unsigned owner, unsigned& n, unsigned const*& ptr); + }; + + bool enable_add(clause const& c) const; + void _get_clauses(solver& s); + typedef hashtable index_set; literal_vector m_units; index_set m_unit_set; + literal_vector m_lits; + vector_pool m_pool; public: + par(); - void exchange(literal_vector const& in, unsigned& limit, literal_vector& out); + + // reserve space + void reserve(unsigned num_owners, unsigned sz) { m_pool.reserve(num_owners, sz); } + + // exchange unit literals + void exchange(solver& s, literal_vector const& in, unsigned& limit, literal_vector& out); + + // add clause to shared clause pool + void share_clause(solver& s, clause const& c); + + void share_clause(solver& s, literal l1, literal l2); + + // receive clauses from shared clause pool + void get_clauses(solver& s); }; }; diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index dc9622a6b..beb53fd93 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -36,6 +36,8 @@ namespace sat { m_config(p), m_ext(ext), m_par(0), + m_par_syncing_clauses(false), + m_par_id(0), m_cleaner(*this), m_simplifier(*this, p), m_scc(*this, p), @@ -234,6 +236,7 @@ namespace sat { return 0; case 2: mk_bin_clause(lits[0], lits[1], learned); + if (learned && m_par) m_par->share_clause(*this, lits[0], lits[1]); return 0; case 3: return mk_ter_clause(lits, learned); @@ -836,6 +839,7 @@ namespace sat { vector rlims(num_extra_solvers); ptr_vector solvers(num_extra_solvers); sat::par par; + par.reserve(num_threads, 1 << 9); symbol saved_phase = m_params.get_sym("phase", symbol("caching")); for (int i = 0; i < num_extra_solvers; ++i) { m_params.set_uint("random_seed", m_rand()); @@ -844,10 +848,10 @@ namespace sat { } solvers[i] = alloc(sat::solver, m_params, rlims[i], 0); solvers[i]->copy(*this); - solvers[i]->set_par(&par); + solvers[i]->set_par(&par, i); scoped_rlimit.push_child(&solvers[i]->rlimit()); } - set_par(&par); + set_par(&par, num_extra_solvers); m_params.set_sym("phase", saved_phase); int finished_id = -1; std::string ex_msg; @@ -901,7 +905,7 @@ namespace sat { } } } - set_par(0); + set_par(0, 0); if (finished_id != -1 && finished_id < num_extra_solvers) { m_stats = solvers[finished_id]->m_stats; } @@ -923,7 +927,9 @@ namespace sat { \brief import lemmas/units from parallel sat solvers. */ void solver::exchange_par() { - if (m_par) { + if (m_par && at_base_lvl()) m_par->get_clauses(*this); + if (m_par && at_base_lvl()) { + // std::cout << scope_lvl() << " " << search_lvl() << "\n"; SASSERT(scope_lvl() == search_lvl()); // TBD: import also dependencies of assumptions. unsigned sz = init_trail_size(); @@ -937,7 +943,7 @@ namespace sat { } } m_par_limit_out = sz; - m_par->exchange(out, m_par_limit_in, in); + m_par->exchange(*this, out, m_par_limit_in, in); for (unsigned i = 0; !inconsistent() && i < in.size(); ++i) { literal lit = in[i]; SASSERT(lit.var() < m_par_num_vars); @@ -952,11 +958,13 @@ namespace sat { } } - void solver::set_par(par* p) { + void solver::set_par(par* p, unsigned id) { m_par = p; m_par_num_vars = num_vars(); m_par_limit_in = 0; m_par_limit_out = 0; + m_par_id = id; + m_par_syncing_clauses = false; } bool_var solver::next_var() { @@ -1855,10 +1863,11 @@ namespace sat { unsigned glue = num_diff_levels(m_lemma.size(), m_lemma.c_ptr()); pop_reinit(m_scope_lvl - new_scope_lvl); - TRACE("sat_conflict_detail", display(tout); tout << "assignment:\n"; display_assignment(tout);); + TRACE("sat_conflict_detail", tout << new_scope_lvl << "\n"; display(tout);); clause * lemma = mk_clause_core(m_lemma.size(), m_lemma.c_ptr(), true); if (lemma) { lemma->set_glue(glue); + if (m_par) m_par->share_clause(*this, *lemma); } decay_activity(); updt_phase_counters(); @@ -1881,8 +1890,7 @@ namespace sat { TRACE("sat", tout << "processing consequent: "; if (consequent == null_literal) tout << "null\n"; else tout << consequent << "\n"; - display_justification(tout << "js kind: ", js); - tout << "\n";); + display_justification(tout << "js kind: ", js) << "\n";); switch (js.get_kind()) { case justification::NONE: break; @@ -1962,8 +1970,7 @@ namespace sat { if (m_not_l != null_literal) { justification js = m_justification[m_not_l.var()]; TRACE("sat", tout << "not_l: " << m_not_l << "\n"; - display_justification(tout, js); - tout << "\n";); + display_justification(tout, js) << "\n";); process_antecedent_for_unsat_core(m_not_l); if (is_assumption(~m_not_l)) { @@ -2774,10 +2781,15 @@ namespace sat { void solver::display_units(std::ostream & out) const { unsigned end = m_trail.size(); // init_trail_size(); + unsigned level = 0; for (unsigned i = 0; i < end; i++) { - out << m_trail[i] << " "; - display_justification(out, m_justification[m_trail[i].var()]); - out << "\n"; + literal lit = m_trail[i]; + if (lvl(lit) > level) { + level = lvl(lit); + out << "level: " << level << " - "; + } + out << lit << " "; + display_justification(out, m_justification[lit.var()]) << "\n"; } //if (end != 0) // out << "\n"; @@ -2794,7 +2806,7 @@ namespace sat { out << ")\n"; } - void solver::display_justification(std::ostream & out, justification const& js) const { + std::ostream& solver::display_justification(std::ostream & out, justification const& js) const { out << js; if (js.is_clause()) { out << *(m_cls_allocator.get_clause(js.get_clause_offset())); @@ -2802,6 +2814,7 @@ namespace sat { else if (js.is_ext_justification() && m_ext) { m_ext->display_justification(out << " ", js.get_ext_justification_idx()); } + return out; } unsigned solver::num_clauses() const { diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index f43418d22..af1b8f213 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -134,9 +134,11 @@ namespace sat { literal_set m_assumption_set; // set of enabled assumptions literal_vector m_core; // unsat core + unsigned m_par_id; unsigned m_par_limit_in; unsigned m_par_limit_out; unsigned m_par_num_vars; + bool m_par_syncing_clauses; void del_clauses(clause * const * begin, clause * const * end); @@ -151,6 +153,7 @@ namespace sat { friend class mus; friend class drat; friend class card_extension; + friend class par; friend struct mk_stat; public: solver(params_ref const & p, reslimit& l, extension * ext); @@ -257,7 +260,7 @@ namespace sat { m_num_checkpoints = 0; if (memory::get_allocation_size() > m_config.m_max_memory) throw solver_exception(Z3_MAX_MEMORY_MSG); } - void set_par(par* p); + void set_par(par* p, unsigned id); bool canceled() { return !m_rlimit.inc(); } config const& get_config() { return m_config; } extension* get_extension() const { return m_ext.get(); } @@ -525,7 +528,7 @@ namespace sat { void display_dimacs(std::ostream & out) const; void display_wcnf(std::ostream & out, unsigned sz, literal const* lits, unsigned const* weights) const; void display_assignment(std::ostream & out) const; - void display_justification(std::ostream & out, justification const& j) const; + std::ostream& display_justification(std::ostream & out, justification const& j) const; protected: void display_binary(std::ostream & out) const;