From 8b32c15ac9483c26d3ae74023c9072c8629b0dd0 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 10 Oct 2017 11:49:31 -0700 Subject: [PATCH] use clause structure for nary Signed-off-by: Nikolaj Bjorner --- src/sat/sat_elim_eqs.cpp | 2 +- src/sat/sat_lookahead.cpp | 240 +++++++++++++++++++++-- src/sat/sat_lookahead.h | 53 ++++- src/tactic/portfolio/parallel_tactic.cpp | 178 ++++++++++++++++- 4 files changed, 455 insertions(+), 18 deletions(-) diff --git a/src/sat/sat_elim_eqs.cpp b/src/sat/sat_elim_eqs.cpp index 424de0e7c..7eb307f85 100644 --- a/src/sat/sat_elim_eqs.cpp +++ b/src/sat/sat_elim_eqs.cpp @@ -98,7 +98,7 @@ namespace sat { // apply substitution for (i = 0; i < sz; i++) { c[i] = norm(roots, c[i]); - SASSERT(!m_solver.was_eliminated(c[i].var())); + VERIFY(!m_solver.was_eliminated(c[i].var())); } std::sort(c.begin(), c.end()); for (literal l : c) VERIFY(l == norm(roots, l)); diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp index 11fb7f008..7ce2b53b9 100644 --- a/src/sat/sat_lookahead.cpp +++ b/src/sat/sat_lookahead.cpp @@ -312,10 +312,11 @@ namespace sat { } bool lookahead::is_unsat() const { - bool all_false = true; - bool first = true; // check if there is a clause whose literals are false. // every clause is terminated by a null-literal. +#if OLD_NARY + bool all_false = true; + bool first = true; for (unsigned l_idx : m_nary_literals) { literal l = to_literal(l_idx); if (first) { @@ -332,6 +333,15 @@ namespace sat { all_false &= is_false(l); } } +#else + for (nary* n : m_nary_clauses) { + bool all_false = true; + for (literal l : *n) { + all_false &= is_false(l); + } + if (all_false) return true; + } +#endif // check if there is a ternary whose literals are false. for (unsigned idx = 0; idx < m_ternary.size(); ++idx) { literal lit = to_literal(idx); @@ -366,10 +376,11 @@ namespace sat { } } } - bool no_true = true; - bool first = true; // check if there is a clause whose literals are false. // every clause is terminated by a null-literal. +#if OLD_NARY + bool no_true = true; + bool first = true; for (unsigned l_idx : m_nary_literals) { literal l = to_literal(l_idx); if (first) { @@ -385,6 +396,15 @@ namespace sat { no_true &= !is_true(l); } } +#else + for (nary * n : m_nary_clauses) { + bool no_true = true; + for (literal l : *n) { + no_true &= !is_true(l); + } + if (no_true) return false; + } +#endif // check if there is a ternary whose literals are false. for (unsigned idx = 0; idx < m_ternary.size(); ++idx) { literal lit = to_literal(idx); @@ -457,6 +477,7 @@ namespace sat { sum += (literal_occs(b.m_u) + literal_occs(b.m_v)) / 8.0; } sz = m_nary_count[(~l).index()]; +#if OLD_NARY for (unsigned idx : m_nary[(~l).index()]) { if (sz-- == 0) break; literal lit; @@ -470,6 +491,9 @@ namespace sat { unsigned len = m_nary_literals[idx]; sum += pow(0.5, len) * to_add / len; } +#else + +#endif return sum; } @@ -488,10 +512,17 @@ namespace sat { } sum += 0.25 * m_ternary_count[(~l).index()]; unsigned sz = m_nary_count[(~l).index()]; +#if OLD_NARY for (unsigned cls_idx : m_nary[(~l).index()]) { if (sz-- == 0) break; sum += pow(0.5, m_nary_literals[cls_idx]); } +#else + for (nary * n : m_nary[(~l).index()]) { + if (sz-- == 0) break; + sum += pow(0.5, n->size()); + } +#endif return sum; } @@ -866,8 +897,13 @@ namespace sat { m_ternary.push_back(svector()); m_ternary_count.push_back(0); m_ternary_count.push_back(0); +#if OLD_NARY m_nary.push_back(unsigned_vector()); m_nary.push_back(unsigned_vector()); +#else + m_nary.push_back(ptr_vector()); + m_nary.push_back(ptr_vector()); +#endif m_nary_count.push_back(0); m_nary_count.push_back(0); m_bstamp.push_back(0); @@ -1254,8 +1290,10 @@ namespace sat { // new n-ary clause managment void lookahead::add_clause(clause const& c) { + SASSERT(c.size() > 3); + +#if OLD_NARY unsigned sz = c.size(); - SASSERT(sz > 3); unsigned idx = m_nary_literals.size(); m_nary_literals.push_back(sz); for (literal l : c) { @@ -1264,7 +1302,15 @@ namespace sat { m_nary[l.index()].push_back(idx); SASSERT(m_nary_count[l.index()] == m_nary[l.index()].size()); } - m_nary_literals.push_back(null_literal.index()); + m_nary_literals.push_back(null_literal.index()); +#else + void * mem = m_allocator.allocate(nary::get_obj_size(c.size())); + nary * n = new (mem) nary(c.size(), c.begin()); + m_nary_clauses.push_back(n); + for (literal l : c) { + m_nary[l.index()].push_back(n); + } +#endif } @@ -1274,6 +1320,7 @@ namespace sat { literal lit; SASSERT(m_search_mode == lookahead_mode::searching); +#if OLD_NARY for (unsigned idx : m_nary[(~l).index()]) { if (sz-- == 0) break; unsigned len = --m_nary_literals[idx]; @@ -1323,12 +1370,69 @@ namespace sat { } } } +#else + for (nary * n : m_nary[(~l).index()]) { + if (sz-- == 0) break; + unsigned len = n->dec_size(); + if (m_inconsistent) continue; + if (len <= 1) continue; // already processed + // find the two unassigned literals, if any + if (len == 2) { + literal l1 = null_literal; + literal l2 = null_literal; + bool found_true = false; + for (literal lit : *n) { + if (!is_fixed(lit)) { + if (l1 == null_literal) { + l1 = lit; + } + else { + SASSERT(l2 == null_literal); + l2 = lit; + break; + } + } + else if (is_true(lit)) { + n->set_head(lit); + found_true = true; + break; + } + } + if (found_true) { + // skip, the clause will be removed when propagating on 'lit' + } + else if (l1 == null_literal) { + set_conflict(); + } + else if (l2 == null_literal) { + // clause may get revisited during propagation, when l2 is true in this clause. + // m_removed_clauses.push_back(std::make_pair(~l, idx)); + // remove_clause_at(~l, idx); + propagated(l1); + } + else { + // extract binary clause. A unary or empty clause may get revisited, + // but we skip it then because it is already handled as a binary clause. + // m_removed_clauses.push_back(std::make_pair(~l, idx)); // need to restore this clause. + // remove_clause_at(~l, idx); + try_add_binary(l1, l2); + } + } + } +#endif // clauses where l is positive: sz = m_nary_count[l.index()]; +#if OLD_NARY for (unsigned idx : m_nary[l.index()]) { if (sz-- == 0) break; remove_clause_at(l, idx); } +#else + for (nary* n : m_nary[l.index()]) { + if (sz-- == 0) break; + remove_clause_at(l, *n); + } +#endif } void lookahead::propagate_clauses_lookahead(literal l) { @@ -1338,6 +1442,7 @@ namespace sat { SASSERT(m_search_mode == lookahead_mode::lookahead1 || m_search_mode == lookahead_mode::lookahead2); +#if OLD_NARY for (unsigned idx : m_nary[(~l).index()]) { if (sz-- == 0) break; literal l1 = null_literal; @@ -1404,9 +1509,75 @@ namespace sat { } } } +#else + for (nary* n : m_nary[(~l).index()]) { + if (sz-- == 0) break; + literal l1 = null_literal; + literal l2 = null_literal; + bool found_true = false; + unsigned nonfixed = 0; + for (literal lit : *n) { + if (!is_fixed(lit)) { + ++nonfixed; + if (l1 == null_literal) { + l1 = lit; + } + else if (l2 == null_literal) { + l2 = lit; + } + } + else if (is_true(lit)) { + found_true = true; + break; + } + } + if (found_true) { + // skip, the clause will be removed when propagating on 'lit' + } + else if (l1 == null_literal) { + set_conflict(); + return; + } + else if (l2 == null_literal) { + propagated(l1); + } + else if (m_search_mode == lookahead_mode::lookahead2) { + continue; + } + else { + SASSERT(nonfixed >= 2); + SASSERT(m_search_mode == lookahead_mode::lookahead1); + switch (m_config.m_reward_type) { + case heule_schur_reward: { + double to_add = 0; + for (literal lit : *n) { + if (!is_fixed(lit)) { + to_add += literal_occs(lit); + } + } + m_lookahead_reward += pow(0.5, nonfixed) * to_add / nonfixed; + break; + } + case heule_unit_reward: + m_lookahead_reward += pow(0.5, nonfixed); + break; + case ternary_reward: + if (nonfixed == 2) { + m_lookahead_reward += (*m_heur)[l1.index()] * (*m_heur)[l2.index()]; + } + else { + m_lookahead_reward += (double)0.001; + } + break; + case unit_literal_reward: + break; + } + } + } +#endif } - +#if OLD_NARY void lookahead::remove_clause_at(literal l, unsigned clause_idx) { unsigned j = clause_idx; literal lit; @@ -1429,21 +1600,50 @@ namespace sat { } UNREACHABLE(); } +#else + + void lookahead::remove_clause_at(literal l, nary& n) { + for (literal lit : n) { + if (lit != l) { + remove_clause(lit, n); + } + } + } + + void lookahead::remove_clause(literal l, nary& n) { + ptr_vector& pclauses = m_nary[l.index()]; + unsigned sz = m_nary_count[l.index()]--; + for (unsigned i = sz; i > 0; ) { + --i; + if (&n == pclauses[i]) { + std::swap(pclauses[i], pclauses[sz-1]); + return; + } + } + UNREACHABLE(); + } +#endif void lookahead::restore_clauses(literal l) { SASSERT(m_search_mode == lookahead_mode::searching); - // increase the length of clauses where l is negative unsigned sz = m_nary_count[(~l).index()]; +#if OLD_NARY for (unsigned idx : m_nary[(~l).index()]) { if (sz-- == 0) break; ++m_nary_literals[idx]; } - +#else + for (nary* n : m_nary[(~l).index()]) { + if (sz-- == 0) break; + n->inc_size(); + } +#endif // add idx back to clause list where l is positive // add them back in the same order as they were inserted // in this way we can check that the clauses are the same. sz = m_nary_count[l.index()]; +#if OLD_NARY unsigned_vector const& pclauses = m_nary[l.index()]; for (unsigned i = sz; i > 0; ) { --i; @@ -1456,6 +1656,17 @@ namespace sat { } } } +#else + ptr_vector& pclauses = m_nary[l.index()]; + for (unsigned i = sz; i-- > 0; ) { + for (literal lit : *pclauses[i]) { + if (lit != l) { + // SASSERT(m_nary[lit.index()] == pclauses[i]); + m_nary_count[lit.index()]++; + } + } + } +#endif } void lookahead::propagate_clauses(literal l) { @@ -1527,7 +1738,7 @@ namespace sat { // Sum_{ clause C that contains ~l } 1 double lookahead::literal_occs(literal l) { double result = m_binary[l.index()].size(); - unsigned_vector const& nclauses = m_nary[(~l).index()]; + // unsigned_vector const& nclauses = m_nary[(~l).index()]; result += m_nary_count[(~l).index()]; result += m_ternary_count[(~l).index()]; return result; @@ -1684,7 +1895,7 @@ namespace sat { return false; #if 0 // no propagations are allowed to reduce clauses. - for (clause * cp : m_full_watches[l.index()]) { + for (nary * cp : m_nary[(~l).index()]) { clause& c = *cp; unsigned sz = c.size(); bool found = false; @@ -2026,6 +2237,7 @@ namespace sat { } } +#if OLD_NARY for (unsigned l_idx : m_nary_literals) { literal l = to_literal(l_idx); if (first) { @@ -2041,6 +2253,12 @@ namespace sat { out << l << " "; } } +#else + for (nary * n : m_nary_clauses) { + for (literal l : *n) out << l << " "; + out << "\n"; + } +#endif return out; } diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index 9a50dceed..2972bc167 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -20,6 +20,7 @@ Notes: #ifndef _SAT_LOOKAHEAD_H_ #define _SAT_LOOKAHEAD_H_ +#define OLD_NARY 0 #include "sat_elim_eqs.h" @@ -129,6 +130,36 @@ namespace sat { literal m_u, m_v; }; + class nary { + unsigned m_size; // number of non-false literals + size_t m_obj_size; // object size (counting all literals) + literal m_head; // head literal + literal m_literals[0]; // list of literals, put any true literal in head. + size_t num_lits() const { + return (m_obj_size - sizeof(nary)) / sizeof(literal); + } + public: + static size_t get_obj_size(unsigned sz) { return sizeof(nary) + sz * sizeof(literal); } + size_t obj_size() const { return m_obj_size; } + nary(unsigned sz, literal const* lits): + m_size(sz), + m_obj_size(get_obj_size(sz)) { + for (unsigned i = 0; i < sz; ++i) m_literals[i] = lits[i]; + m_head = lits[0]; + } + unsigned size() const { return m_size; } + unsigned dec_size() { SASSERT(m_size > 0); return --m_size; } + void inc_size() { SASSERT(m_size < num_lits()); ++m_size; } + literal get_head() const { return m_head; } + void set_head(literal l) { m_head = l; } + + literal operator[](unsigned i) { SASSERT(i < num_lits()); return m_literals[i]; } + literal const* begin() const { return m_literals; } + literal const* end() const { return m_literals + num_lits(); } + // swap the true literal to the head. + // void swap(unsigned i, unsigned j) { SASSERT(i < num_lits() && j < num_lits()); std::swap(m_literals[i], m_literals[j]); } + }; + struct cube_state { bool m_first; svector m_is_decision; @@ -160,11 +191,18 @@ namespace sat { vector> m_ternary; // lit |-> vector of ternary clauses unsigned_vector m_ternary_count; // lit |-> current number of active ternary clauses for lit +#if OLD_NARY vector m_nary; // lit |-> vector of clause_id - unsigned_vector m_nary_count; // lit |-> number of valid clause_id in m_clauses2[lit] unsigned_vector m_nary_literals; // the actual literals, clauses start at offset clause_id, // the first entry is the current length, clauses are separated by a null_literal +#else + small_object_allocator m_allocator; + vector> m_nary; // lit |-> vector of nary clauses + ptr_vector m_nary_clauses; // vector of all nary clauses +#endif + unsigned_vector m_nary_count; // lit |-> number of valid clause_id in m_nary[lit] + unsigned m_num_tc1; unsigned_vector m_num_tc1_lim; unsigned m_qhead; // propagation queue head @@ -410,15 +448,20 @@ namespace sat { void propagate_clauses_searching(literal l); void propagate_clauses_lookahead(literal l); void restore_clauses(literal l); +#if OLD_NARY void remove_clause(literal l, unsigned clause_idx); void remove_clause_at(literal l, unsigned clause_idx); - +#else + void remove_clause(literal l, nary& n); + void remove_clause_at(literal l, nary& n); +#endif // ------------------------------------ // initialization void init_var(bool_var v); void init(); void copy_clauses(clause_vector const& clauses, bool learned); + nary * copy_clause(clause const& c); // ------------------------------------ // search @@ -499,6 +542,12 @@ namespace sat { ~lookahead() { m_s.rlimit().pop_child(); +#if OLD_NARY +#else + for (nary* n : m_nary_clauses) { + m_allocator.deallocate(n->obj_size(), n); + } +#endif } diff --git a/src/tactic/portfolio/parallel_tactic.cpp b/src/tactic/portfolio/parallel_tactic.cpp index 497587cd3..c8cb7e2ee 100644 --- a/src/tactic/portfolio/parallel_tactic.cpp +++ b/src/tactic/portfolio/parallel_tactic.cpp @@ -3,11 +3,11 @@ Copyright (c) 2017 Microsoft Corporation Module Name: - parallel_solver.cpp + parallel_tactic.cpp Abstract: - Parallel solver in the style of Treengeling. + Parallel tactic in the style of Treengeling. It assumes a solver that supports good lookaheads. @@ -20,13 +20,183 @@ Notes: --*/ +#include "util/scoped_ptr_vector.h" #include "solver/solver.h" #include "tactic/tactic.h" class parallel_tactic : public tactic { - ref m_solver; + + // parameters + unsigned m_conflicts_lower_bound; + unsigned m_conflicts_upper_bound; + unsigned m_conflicts_growth_rate; + unsigned m_conflicts_decay_rate; + unsigned m_num_threads; + + unsigned m_max_conflicts; + + sref_vector m_solvers; + scoped_ptr_vector m_managers; + + void init() { + m_conflicts_lower_bound = 1000; + m_conflicts_upper_bound = 10000; + m_conflicts_growth_rate = 150; + m_conflicts_decay_rate = 75; + m_max_conflicts = m_conflicts_lower_bound; + m_num_threads = omp_get_num_threads(); + } + + unsigned get_max_conflicts() { + return m_max_conflicts; + } + + void set_max_conflicts(unsigned c) { + m_max_conflicts = c; + } + + bool should_increase_conflicts() { + NOT_IMPLEMENTED_YET(); + return false; + } + + int pick_solvers() { + NOT_IMPLEMENTED_YET(); + return 1; + } + + void update_max_conflicts() { + if (should_increase_conflicts()) { + set_max_conflicts(std::min(m_conflicts_upper_bound, m_conflicts_growth_rate * get_max_conflicts() / 100)); + } + else { + set_max_conflicts(std::max(m_conflicts_lower_bound, m_conflicts_decay_rate * get_max_conflicts() / 100)); + } + } + + lbool simplify(solver& s) { + params_ref p; + p.set_uint("sat.max_conflicts", 10); + p.set_bool("sat.lookahead_simplify", true); + s.updt_params(p); + lbool is_sat = s.check_sat(0,0); + p.set_uint("sat.max_conflicts", get_max_conflicts()); + p.set_bool("sat.lookahead_simplify", false); + s.updt_params(p); + return is_sat; + } + + lbool lookahead(solver& s) { + ast_manager& m = s.get_manager(); + params_ref p; + p.set_uint("sat.lookahead.cube.cutoff", 1); + expr_ref_vector cubes(m); + while (true) { + expr_ref c = s.cube(); + if (m.is_false(c)) { + break; + } + cubes.push_back(c); + } + if (cubes.empty()) { + return l_false; + } + for (unsigned i = 1; i < cubes.size(); ++i) { + ast_manager * new_m = alloc(ast_manager, m, !m.proof_mode()); + solver* s1 = s.translate(*new_m, params_ref()); + ast_translation translate(m, *new_m); + expr_ref cube(translate(cubes[i].get()), *new_m); + s1->assert_expr(cube); + + #pragma omp critical (_solvers) + { + m_managers.push_back(new_m); + m_solvers.push_back(s1); + } + } + s.assert_expr(cubes[0].get()); + return l_true; + } + + lbool solve(solver& s) { + params_ref p; + p.set_uint("sat.max_conflicts", get_max_conflicts()); + s.updt_params(p); + lbool is_sat = s.check_sat(0, 0); + return is_sat; + } + + void remove_unsat(svector& unsat) { + std::sort(unsat.begin(), unsat.end()); + unsat.reverse(); + DEBUG_CODE(for (unsigned i = 0; i + 1 < unsat.size(); ++i) SASSERT(unsat[i] > unsat[i+1]);); + for (int i : unsat) { + m_solvers.erase(i); + } + unsat.reset(); + } + + lbool solve() { + while (true) { + int sz = pick_solvers(); + + if (sz == 0) { + return l_false; + } + svector unsat; + int sat_index = -1; + + // Simplify phase. + #pragma omp parallel for + for (int i = 0; i < sz; ++i) { + lbool is_sat = simplify(*m_solvers[i]); + switch (is_sat) { + case l_false: unsat.push_back(i); break; + case l_true: sat_index = i; break; + case l_undef: break; + } + } + if (sat_index != -1) return l_true; // TBD: extact model + sz -= unsat.size(); + remove_unsat(unsat); + if (sz == 0) continue; + + // Solve phase. + #pragma omp parallel for + for (int i = 0; i < sz; ++i) { + lbool is_sat = solve(*m_solvers[i]); + switch (is_sat) { + case l_false: unsat.push_back(i); break; + case l_true: sat_index = i; break; + case l_undef: break; + } + } + if (sat_index != -1) return l_true; // TBD: extact model + sz -= unsat.size(); + remove_unsat(unsat); + if (sz == 0) continue; + + // Split phase. + #pragma omp parallel for + for (int i = 0; i < sz; ++i) { + lbool is_sat = lookahead(*m_solvers[i]); + switch (is_sat) { + case l_false: unsat.push_back(i); break; + case l_true: break; + case l_undef: break; + } + } + remove_unsat(unsat); + + update_max_conflicts(); + } + return l_undef; + } + public: - parallel_tactic(solver* s) : m_solver(s) {} + parallel_tactic(solver* s) { + m_solvers.push_back(s); // clone it? + } void operator ()(const goal_ref & g,goal_ref_buffer & result,model_converter_ref & mc,proof_converter_ref & pc,expr_dependency_ref & dep) { NOT_IMPLEMENTED_YET();