diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp index a289a3814..12018f083 100644 --- a/src/sat/sat_lookahead.cpp +++ b/src/sat/sat_lookahead.cpp @@ -317,7 +317,7 @@ namespace sat { bool first = true; // check if there is a clause whose literals are false. // every clause is terminated by a null-literal. - for (unsigned l_idx : m_clause_literals) { + for (unsigned l_idx : m_nary_literals) { literal l = to_literal(l_idx); if (first) { // skip the first entry, the length indicator. @@ -385,7 +385,7 @@ namespace sat { bool first = true; // check if there is a clause whose literals are false. // every clause is terminated by a null-literal. - for (unsigned l_idx : m_clause_literals) { + for (unsigned l_idx : m_nary_literals) { literal l = to_literal(l_idx); if (first) { // skip the first entry, the length indicator. @@ -484,18 +484,18 @@ namespace sat { if (sz-- == 0) break; sum += (literal_occs(b.m_u) + literal_occs(b.m_v)) / 8.0; } - sz = m_clause_count[(~l).index()]; - for (unsigned idx : m_clauses[(~l).index()]) { + sz = m_nary_count[(~l).index()]; + for (unsigned idx : m_nary[(~l).index()]) { if (sz-- == 0) break; literal lit; unsigned j = idx; double to_add = 0; - while ((lit = to_literal(m_clause_literals[++j])) != null_literal) { + while ((lit = to_literal(m_nary_literals[++j])) != null_literal) { if (!is_fixed(lit) && lit != ~l) { to_add += literal_occs(lit); } } - unsigned len = m_clause_literals[idx]; + unsigned len = m_nary_literals[idx]; sum += pow(0.5, len) * to_add / len; } #else @@ -551,10 +551,10 @@ namespace sat { } #ifdef NEW_CLAUSE sum += 0.25 * m_ternary_count[(~l).index()]; - unsigned sz = m_clause_count[(~l).index()]; - for (unsigned cls_idx : m_clauses[(~l).index()]) { + unsigned sz = m_nary_count[(~l).index()]; + for (unsigned cls_idx : m_nary[(~l).index()]) { if (sz-- == 0) break; - sum += pow(0.5, m_clause_literals[cls_idx]); + sum += pow(0.5, m_nary_literals[cls_idx]); } #else watch_list& wlist = m_watches[l.index()]; @@ -973,10 +973,11 @@ namespace sat { } +#ifndef NEW_CLAUSE + // ------------------------------------ // clause management -#ifndef NEW_CLAUSE void lookahead::attach_clause(clause& c) { if (c.size() == 3) { attach_ternary(c[0], c[1], c[2]); @@ -1036,6 +1037,15 @@ namespace sat { #ifndef NEW_CLAUSE m_full_watches.push_back(clause_vector()); m_full_watches.push_back(clause_vector()); +#else + m_ternary.push_back(svector); + m_ternary.push_back(svector); + m_ternary_count.push_back(0); + m_ternary_count.push_back(0); + m_nary.push_back(unsigned_vector()); + m_nary.push_back(unsigned_vector()); + m_nary_count.push_back(0); + m_nary_count.push_back(0); #endif m_bstamp.push_back(0); m_bstamp.push_back(0); @@ -1076,8 +1086,8 @@ namespace sat { } } - copy_clauses(m_s.m_clauses); - copy_clauses(m_s.m_learned); + copy_clauses(m_s.m_clauses, false); + copy_clauses(m_s.m_learned, true); // copy units unsigned trail_sz = m_s.init_trail_size(); @@ -1103,7 +1113,7 @@ namespace sat { TRACE("sat", m_s.display(tout); display(tout);); } - void lookahead::copy_clauses(clause_vector const& clauses) { + void lookahead::copy_clauses(clause_vector const& clauses, bool learned) { // copy clauses #ifdef NEW_CLAUSE for (clause* cp : clauses) { @@ -1121,7 +1131,7 @@ namespace sat { case 1: assign(c[0]); break; case 2: add_binary(c[0],c[1]); break; case 3: add_ternary(c[0],c[1],c[2]); break; - default: add_clause(c); break; + default: if (!learned) add_clause(c); break; } if (m_s.m_config.m_drat) m_drat.add(c, false); } @@ -1184,8 +1194,6 @@ namespace sat { TRACE("sat", tout << "inserting free var v" << l.var() << "\n";); m_freevars.insert(l.var()); } - m_trail.shrink(old_sz); // reset assignment. - m_trail_lim.pop_back(); m_num_tc1 = m_num_tc1_lim.back(); m_num_tc1_lim.pop_back(); @@ -1213,6 +1221,9 @@ namespace sat { } #endif + m_trail.shrink(old_sz); // reset assignment. + m_trail_lim.pop_back(); + // remove local binary clauses old_sz = m_binary_trail_lim.back(); for (unsigned i = m_binary_trail.size(); i > old_sz; ) { @@ -1358,15 +1369,13 @@ namespace sat { void lookahead::propagate_ternary(literal l) { unsigned sz = m_ternary_count[(~l).index()]; - svector const& negs = m_ternary[(~l).index()]; switch (m_search_mode) { case lookahead_mode::searching: { // ternary clauses where l is negative become binary - - for (unsigned i = 0; i < sz; ++i) { - binary const& b = negs[i]; + for (binary const& b : m_ternary[(~l).index()]) { + if (sz-- == 0) break; // this could create a conflict from propagation, but we complete the transaction. literal l1 = b.m_u; literal l2 = b.m_v; @@ -1375,18 +1384,17 @@ namespace sat { try_add_binary(l1, l2); break; default: - // propagated or tautology. + // propagated or tautology or conflict break; } remove_ternary(l1, l2, l); remove_ternary(l2, l, l1); } - sz = m_ternary_count[l.index()]; - svector const& poss = m_ternary[l.index()]; - + + sz = m_ternary_count[l.index()]; // ternary clauses where l is positive are tautologies - for (unsigned i = 0; i < sz; ++i) { - binary const& b = poss[i]; + for (binary const& b : m_ternary[l.index()]) { + if (sz-- == 0) break; remove_ternary(b.m_u, b.m_v, l); remove_ternary(b.m_v, l, b.m_u); } @@ -1394,8 +1402,8 @@ namespace sat { } case lookahead_mode::lookahead1: // this could create a conflict from propagation, but we complete the loop. - for (unsigned i = 0; i < sz; ++i) { - binary const& b = negs[i]; + for (binary const& b : m_ternary[(~l).index()]) { + if (sz-- == 0) break; literal l1 = b.m_u; literal l2 = b.m_v; switch (propagate_ternary(l1, l2)) { @@ -1409,8 +1417,8 @@ namespace sat { break; case lookahead2: // this could create a conflict from propagation, but we complete the loop. - for (unsigned i = 0; i < sz; ++i) { - binary const& b = negs[i]; + for (binary const& b : m_ternary[(~l).index()]) { + if (sz-- == 0) break; propagate_ternary(b.m_u, b.m_v); } break; @@ -1435,7 +1443,7 @@ namespace sat { void lookahead::restore_ternary(literal l) { unsigned sz = m_ternary_count[(~l).index()]; for (binary const& b : m_ternary[(~l).index()]) { - if (sz-- == 0) break; + if (sz-- == 0) break; m_ternary_count[b.m_u.index()]++; m_ternary_count[b.m_v.index()]++; } @@ -1459,7 +1467,7 @@ namespace sat { lookahead_literal_occs_fun literal_occs_fn(*this); m_lookahead_reward += m_s.m_ext->get_reward(l, it->get_ext_constraint_idx(), literal_occs_fn); } - if (m_inconsistent) { + if (inconsistent()) { if (!keep) ++it; } else if (keep) { @@ -1479,26 +1487,59 @@ namespace sat { void lookahead::add_clause(clause const& c) { unsigned sz = c.size(); SASSERT(sz > 3); - unsigned idx = m_clause_literals.size(); - m_clause_literals.push_back(sz); + unsigned idx = m_nary_literals.size(); + m_nary_literals.push_back(sz); for (literal l : c) { - m_clause_literals.push_back(l.index()); - m_clause_count[l.index()]++; - m_clauses[l.index()].push_back(idx); + m_nary_literals.push_back(l.index()); + m_nary_count[l.index()]++; + m_nary[l.index()].push_back(idx); + SASSERT(m_nary_count[l.index()] == m_nary[l.index()].size()); } - m_clause_literals.push_back(null_literal.index()); + m_nary_literals.push_back(null_literal.index()); } +#if 0 + // split large clauses into smaller ones to avoid overhead during propagation. + + void lookahead::add_clause(unsigned sz, literal const * lits) { + if (sz > 6) { + bool_var v = m_s.mk_var(false); + ++m_num_vars; + init_var(v); + literal lit(v, false); + unsigned mid = sz / 2; + literal_vector lits1(mid, lits); + lits1.push_back(lit); + add_clause(lits1.size(), lits1.c_ptr()); + lit.neg(); + literal_vector lits2(sz - mid, lits + mid); + lits2.push_back(lit); + add_clause(lits2.size(), lits2.c_ptr()); + } + else { + unsigned idx = m_nary_literals.size(); + m_nary_literals.push_back(sz); + for (unsigned i = 0; i < sz; ++i) { + literal l = lits[i]; + m_nary_literals.push_back(l.index()); + m_nary_count[l.index()]++; + m_nary[l.index()].push_back(idx); + SASSERT(m_nary_count[l.index()] == m_nary[l.index()].size()); + } + m_nary_literals.push_back(null_literal.index()); + } + } +#endif + void lookahead::propagate_clauses_searching(literal l) { // clauses where l is negative - unsigned_vector const& nclauses = m_clauses[(~l).index()]; - unsigned sz = m_clause_count[(~l).index()]; + unsigned sz = m_nary_count[(~l).index()]; literal lit; SASSERT(m_search_mode == lookahead_mode::searching); - for (unsigned i = 0; i < sz; ++i) { - unsigned idx = nclauses[i]; - unsigned len = --m_clause_literals[idx]; + for (unsigned idx : m_nary[(~l).index()]) { + if (sz-- == 0) break; + unsigned len = --m_nary_literals[idx]; if (len <= 1) continue; // already processed // find the two unassigned literals, if any if (len == 2) { @@ -1506,7 +1547,7 @@ namespace sat { literal l2 = null_literal; unsigned j = idx; bool found_true = false; - while ((lit = to_literal(m_clause_literals[++j])) != null_literal) { + while ((lit = to_literal(m_nary_literals[++j])) != null_literal) { if (!is_fixed(lit)) { if (l1 == null_literal) { l1 = lit; @@ -1518,7 +1559,7 @@ namespace sat { } } else if (is_true(lit)) { - // can't swap with idx. std::swap(m_clause_literals[j], m_clause_literals[idx]); + // can't swap with idx. std::swap(m_nary_literals[j], m_nary_literals[idx]); found_true = true; break; } @@ -1529,7 +1570,7 @@ namespace sat { else if (l1 == null_literal) { set_conflict(); for (++i; i < sz; ++i) { - --m_clause_literals[nclauses[i]]; + --m_nary_literals[nclauses[i]]; } } else if (l2 == null_literal) { @@ -1548,29 +1589,28 @@ namespace sat { } } // clauses where l is positive: - unsigned_vector const& pclauses = m_clauses[l.index()]; - sz = m_clause_count[l.index()]; - for (unsigned i = 0; i < sz; ++i) { - remove_clause_at(l, pclauses[i]); + sz = m_nary_count[l.index()]; + for (unsigned idx : m_nary[l.index())) { + if (sz-- == 0) break; + remove_clause_at(l, idx); } } void lookahead::propagate_clauses_lookahead(literal l) { // clauses where l is negative - unsigned_vector const& nclauses = m_clauses[(~l).index()]; - unsigned sz = m_clause_count[(~l).index()]; + unsigned sz = m_nary_count[(~l).index()]; literal lit; SASSERT(m_search_mode == lookahead_mode::lookahead1 || m_search_mode == lookahead_mode::lookahead2); - - for (unsigned i = 0; i < sz; ++i) { - unsigned idx = nclauses[i]; + + for (unsigned idx : m_nary[(~l).index()]) { + if (sz-- == 0) break; literal l1 = null_literal; literal l2 = null_literal; unsigned j = idx; bool found_true = false; unsigned nonfixed = 0; - while ((lit = to_literal(m_clause_literals[++j])) != null_literal) { + while ((lit = to_literal(m_nary_literals[++j])) != null_literal) { if (!is_fixed(lit)) { ++nonfixed; if (l1 == null_literal) { @@ -1599,12 +1639,13 @@ namespace sat { continue; } else { - SASSERT (m_search_mode == lookahead_mode::lookahead1); + SASSERT(nonfixed >= 2); + SASSERT(m_search_mode == lookahead_mode::lookahead1); switch (m_config.m_reward_type) { case heule_schur_reward: { j = idx; double to_add = 0; - while ((lit = to_literal(m_clause_literals[++j])) != null_literal) { + while ((lit = to_literal(m_nary_literals[++j])) != null_literal) { if (!is_fixed(lit)) { to_add += literal_occs(lit); } @@ -1634,7 +1675,7 @@ namespace sat { void lookahead::remove_clause_at(literal l, unsigned clause_idx) { unsigned j = clause_idx; literal lit; - while ((lit = to_literal(m_clause_literals[++j])) != null_literal) { + while ((lit = to_literal(m_nary_literals[++j])) != null_literal) { if (lit != l) { remove_clause(lit, clause_idx); } @@ -1642,8 +1683,8 @@ namespace sat { } void lookahead::remove_clause(literal l, unsigned clause_idx) { - unsigned_vector& pclauses = m_clauses[l.index()]; - unsigned sz = m_clause_count[l.index()]--; + unsigned_vector& pclauses = m_nary[l.index()]; + unsigned sz = m_nary_count[l.index()]--; for (unsigned i = sz; i > 0; ) { --i; if (clause_idx == pclauses[i]) { @@ -1658,22 +1699,25 @@ namespace sat { SASSERT(m_search_mode == lookahead_mode::searching); // increase the length of clauses where l is negative - unsigned_vector const& nclauses = m_clauses[(~l).index()]; - unsigned sz = m_clause_count[(~l).index()]; - for (unsigned i = 0; i < sz; ++i) { - ++m_clause_literals[nclauses[i]]; + unsigned sz = m_nary_count[(~l).index()]; + for (unsigned idx : m_nary[(~l).index()]) { + if (sz-- == 0) break; + ++m_nary_literals[idx]; } // add idx back to clause list where l is positive - unsigned_vector const& pclauses = m_clauses[l.index()]; - sz = m_clause_count[l.index()]; - for (unsigned i = 0; i < sz; ++i) { - unsigned idx = pclauses[i]; - unsigned j = idx; + // add them back in the same order as they were inserted + // in this way we can check that the clauses are the same. + sz = m_nary_count[l.index()]; + unsigned_vector const& pclauses = m_nary[l.index()]; + for (unsigned i = sz; i > 0; ) { + --i; + unsigned j = pclauses[i]; literal lit; - while ((lit = to_literal(m_clause_literals[++j])) != null_literal) { + while ((lit = to_literal(m_nary_literals[++j])) != null_literal) { if (lit != l) { - m_clause_count[lit.index()]++; + SASSERT(m_nary[lit.index()] == pclauses[i]); + m_nary_count[lit.index()]++; } } } @@ -1905,8 +1949,8 @@ namespace sat { double lookahead::literal_occs(literal l) { double result = m_binary[l.index()].size(); #ifdef NEW_CLAUSE - unsigned_vector const& nclauses = m_clauses[(~l).index()]; - result += m_clause_count[(~l).index()]; + unsigned_vector const& nclauses = m_nary[(~l).index()]; + result += m_nary_count[(~l).index()]; result += m_ternary_count[(~l).index()]; #else for (clause const* c : m_full_watches[l.index()]) { @@ -2117,7 +2161,7 @@ namespace sat { #ifdef NEW_CLAUSE TRACE("sat", tout << "autarky: " << l << " @ " << m_stamp[l.var()] << " " - << (!m_binary[l.index()].empty() || m_clause_count[l.index()] != 0) << "\n";); + << (!m_binary[l.index()].empty() || m_nary_count[l.index()] != 0) << "\n";); #endif reset_lookahead_reward(); assign(l); @@ -2379,20 +2423,6 @@ namespace sat { std::ostream& lookahead::display_clauses(std::ostream& out) const { #ifdef NEW_CLAUSE bool first = true; - for (unsigned l_idx : m_clause_literals) { - literal l = to_literal(l_idx); - if (first) { - // skip the first entry, the length indicator. - first = false; - } - else if (l == null_literal) { - first = true; - out << "\n"; - } - else { - out << l << " "; - } - } for (unsigned idx = 0; idx < m_ternary.size(); ++idx) { literal lit = to_literal(idx); @@ -2405,6 +2435,22 @@ namespace sat { } } + for (unsigned l_idx : m_nary_literals) { + literal l = to_literal(l_idx); + if (first) { + // the first entry is a length indicator of non-false literals. + out << l_idx << ": "; + first = false; + } + else if (l == null_literal) { + first = true; + out << "\n"; + } + else { + out << l << " "; + } + } + #else for (unsigned i = 0; i < m_clauses.size(); ++i) { out << *m_clauses[i] << "\n"; @@ -2414,8 +2460,7 @@ namespace sat { } std::ostream& lookahead::display_values(std::ostream& out) const { - for (unsigned i = 0; i < m_trail.size(); ++i) { - literal l = m_trail[i]; + for (literal l : m_trail) { out << l << "\n"; } return out; diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index efd650e59..a3777ea44 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -153,15 +153,14 @@ namespace sat { #ifdef NEW_CLAUSE // specialized clause managemet uses ternary clauses and dedicated clause data-structure. - // this will replace m_clauses below + // this replaces m_clauses below vector> m_ternary; // lit |-> vector of ternary clauses - unsigned_vector m_ternary_count; // lit |-> current number of active ternary clauses for lit - unsigned_vector m_ternary_trail_lim; // limit for ternary vectors. + unsigned_vector m_ternary_count; // lit |-> current number of active ternary clauses for lit - vector m_clauses; // lit |-> vector of clause_id - unsigned_vector m_clause_count; // lit |-> number of valid clause_id in m_clauses2[lit] - unsigned_vector m_clause_literals; // the actual literals, clauses start at offset clause_id, - // the first entry is the current length, clauses are separated by a null_literal + vector m_nary; // lit |-> vector of clause_id + unsigned_vector m_nary_count; // lit |-> number of valid clause_id in m_clauses2[lit] + unsigned_vector m_nary_literals; // the actual literals, clauses start at offset clause_id, + // the first entry is the current length, clauses are separated by a null_literal #endif @@ -436,7 +435,7 @@ namespace sat { void init_var(bool_var v); void init(); - void copy_clauses(clause_vector const& clauses); + void copy_clauses(clause_vector const& clauses, bool learned); // ------------------------------------ // search