3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-05 17:14:07 +00:00

use clause structure for nary

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-10-10 11:49:31 -07:00
parent a0cd6e0fca
commit 8b32c15ac9
4 changed files with 455 additions and 18 deletions

View file

@ -98,7 +98,7 @@ namespace sat {
// apply substitution
for (i = 0; i < sz; i++) {
c[i] = norm(roots, c[i]);
SASSERT(!m_solver.was_eliminated(c[i].var()));
VERIFY(!m_solver.was_eliminated(c[i].var()));
}
std::sort(c.begin(), c.end());
for (literal l : c) VERIFY(l == norm(roots, l));

View file

@ -312,10 +312,11 @@ namespace sat {
}
bool lookahead::is_unsat() const {
bool all_false = true;
bool first = true;
// check if there is a clause whose literals are false.
// every clause is terminated by a null-literal.
#if OLD_NARY
bool all_false = true;
bool first = true;
for (unsigned l_idx : m_nary_literals) {
literal l = to_literal(l_idx);
if (first) {
@ -332,6 +333,15 @@ namespace sat {
all_false &= is_false(l);
}
}
#else
for (nary* n : m_nary_clauses) {
bool all_false = true;
for (literal l : *n) {
all_false &= is_false(l);
}
if (all_false) return true;
}
#endif
// check if there is a ternary whose literals are false.
for (unsigned idx = 0; idx < m_ternary.size(); ++idx) {
literal lit = to_literal(idx);
@ -366,10 +376,11 @@ namespace sat {
}
}
}
bool no_true = true;
bool first = true;
// check if there is a clause whose literals are false.
// every clause is terminated by a null-literal.
#if OLD_NARY
bool no_true = true;
bool first = true;
for (unsigned l_idx : m_nary_literals) {
literal l = to_literal(l_idx);
if (first) {
@ -385,6 +396,15 @@ namespace sat {
no_true &= !is_true(l);
}
}
#else
for (nary * n : m_nary_clauses) {
bool no_true = true;
for (literal l : *n) {
no_true &= !is_true(l);
}
if (no_true) return false;
}
#endif
// check if there is a ternary whose literals are false.
for (unsigned idx = 0; idx < m_ternary.size(); ++idx) {
literal lit = to_literal(idx);
@ -457,6 +477,7 @@ namespace sat {
sum += (literal_occs(b.m_u) + literal_occs(b.m_v)) / 8.0;
}
sz = m_nary_count[(~l).index()];
#if OLD_NARY
for (unsigned idx : m_nary[(~l).index()]) {
if (sz-- == 0) break;
literal lit;
@ -470,6 +491,9 @@ namespace sat {
unsigned len = m_nary_literals[idx];
sum += pow(0.5, len) * to_add / len;
}
#else
#endif
return sum;
}
@ -488,10 +512,17 @@ namespace sat {
}
sum += 0.25 * m_ternary_count[(~l).index()];
unsigned sz = m_nary_count[(~l).index()];
#if OLD_NARY
for (unsigned cls_idx : m_nary[(~l).index()]) {
if (sz-- == 0) break;
sum += pow(0.5, m_nary_literals[cls_idx]);
}
#else
for (nary * n : m_nary[(~l).index()]) {
if (sz-- == 0) break;
sum += pow(0.5, n->size());
}
#endif
return sum;
}
@ -866,8 +897,13 @@ namespace sat {
m_ternary.push_back(svector<binary>());
m_ternary_count.push_back(0);
m_ternary_count.push_back(0);
#if OLD_NARY
m_nary.push_back(unsigned_vector());
m_nary.push_back(unsigned_vector());
#else
m_nary.push_back(ptr_vector<nary>());
m_nary.push_back(ptr_vector<nary>());
#endif
m_nary_count.push_back(0);
m_nary_count.push_back(0);
m_bstamp.push_back(0);
@ -1254,8 +1290,10 @@ namespace sat {
// new n-ary clause managment
void lookahead::add_clause(clause const& c) {
SASSERT(c.size() > 3);
#if OLD_NARY
unsigned sz = c.size();
SASSERT(sz > 3);
unsigned idx = m_nary_literals.size();
m_nary_literals.push_back(sz);
for (literal l : c) {
@ -1264,7 +1302,15 @@ namespace sat {
m_nary[l.index()].push_back(idx);
SASSERT(m_nary_count[l.index()] == m_nary[l.index()].size());
}
m_nary_literals.push_back(null_literal.index());
m_nary_literals.push_back(null_literal.index());
#else
void * mem = m_allocator.allocate(nary::get_obj_size(c.size()));
nary * n = new (mem) nary(c.size(), c.begin());
m_nary_clauses.push_back(n);
for (literal l : c) {
m_nary[l.index()].push_back(n);
}
#endif
}
@ -1274,6 +1320,7 @@ namespace sat {
literal lit;
SASSERT(m_search_mode == lookahead_mode::searching);
#if OLD_NARY
for (unsigned idx : m_nary[(~l).index()]) {
if (sz-- == 0) break;
unsigned len = --m_nary_literals[idx];
@ -1323,12 +1370,69 @@ namespace sat {
}
}
}
#else
for (nary * n : m_nary[(~l).index()]) {
if (sz-- == 0) break;
unsigned len = n->dec_size();
if (m_inconsistent) continue;
if (len <= 1) continue; // already processed
// find the two unassigned literals, if any
if (len == 2) {
literal l1 = null_literal;
literal l2 = null_literal;
bool found_true = false;
for (literal lit : *n) {
if (!is_fixed(lit)) {
if (l1 == null_literal) {
l1 = lit;
}
else {
SASSERT(l2 == null_literal);
l2 = lit;
break;
}
}
else if (is_true(lit)) {
n->set_head(lit);
found_true = true;
break;
}
}
if (found_true) {
// skip, the clause will be removed when propagating on 'lit'
}
else if (l1 == null_literal) {
set_conflict();
}
else if (l2 == null_literal) {
// clause may get revisited during propagation, when l2 is true in this clause.
// m_removed_clauses.push_back(std::make_pair(~l, idx));
// remove_clause_at(~l, idx);
propagated(l1);
}
else {
// extract binary clause. A unary or empty clause may get revisited,
// but we skip it then because it is already handled as a binary clause.
// m_removed_clauses.push_back(std::make_pair(~l, idx)); // need to restore this clause.
// remove_clause_at(~l, idx);
try_add_binary(l1, l2);
}
}
}
#endif
// clauses where l is positive:
sz = m_nary_count[l.index()];
#if OLD_NARY
for (unsigned idx : m_nary[l.index()]) {
if (sz-- == 0) break;
remove_clause_at(l, idx);
}
#else
for (nary* n : m_nary[l.index()]) {
if (sz-- == 0) break;
remove_clause_at(l, *n);
}
#endif
}
void lookahead::propagate_clauses_lookahead(literal l) {
@ -1338,6 +1442,7 @@ namespace sat {
SASSERT(m_search_mode == lookahead_mode::lookahead1 ||
m_search_mode == lookahead_mode::lookahead2);
#if OLD_NARY
for (unsigned idx : m_nary[(~l).index()]) {
if (sz-- == 0) break;
literal l1 = null_literal;
@ -1404,9 +1509,75 @@ namespace sat {
}
}
}
#else
for (nary* n : m_nary[(~l).index()]) {
if (sz-- == 0) break;
literal l1 = null_literal;
literal l2 = null_literal;
bool found_true = false;
unsigned nonfixed = 0;
for (literal lit : *n) {
if (!is_fixed(lit)) {
++nonfixed;
if (l1 == null_literal) {
l1 = lit;
}
else if (l2 == null_literal) {
l2 = lit;
}
}
else if (is_true(lit)) {
found_true = true;
break;
}
}
if (found_true) {
// skip, the clause will be removed when propagating on 'lit'
}
else if (l1 == null_literal) {
set_conflict();
return;
}
else if (l2 == null_literal) {
propagated(l1);
}
else if (m_search_mode == lookahead_mode::lookahead2) {
continue;
}
else {
SASSERT(nonfixed >= 2);
SASSERT(m_search_mode == lookahead_mode::lookahead1);
switch (m_config.m_reward_type) {
case heule_schur_reward: {
double to_add = 0;
for (literal lit : *n) {
if (!is_fixed(lit)) {
to_add += literal_occs(lit);
}
}
m_lookahead_reward += pow(0.5, nonfixed) * to_add / nonfixed;
break;
}
case heule_unit_reward:
m_lookahead_reward += pow(0.5, nonfixed);
break;
case ternary_reward:
if (nonfixed == 2) {
m_lookahead_reward += (*m_heur)[l1.index()] * (*m_heur)[l2.index()];
}
else {
m_lookahead_reward += (double)0.001;
}
break;
case unit_literal_reward:
break;
}
}
}
#endif
}
#if OLD_NARY
void lookahead::remove_clause_at(literal l, unsigned clause_idx) {
unsigned j = clause_idx;
literal lit;
@ -1429,21 +1600,50 @@ namespace sat {
}
UNREACHABLE();
}
#else
void lookahead::remove_clause_at(literal l, nary& n) {
for (literal lit : n) {
if (lit != l) {
remove_clause(lit, n);
}
}
}
void lookahead::remove_clause(literal l, nary& n) {
ptr_vector<nary>& pclauses = m_nary[l.index()];
unsigned sz = m_nary_count[l.index()]--;
for (unsigned i = sz; i > 0; ) {
--i;
if (&n == pclauses[i]) {
std::swap(pclauses[i], pclauses[sz-1]);
return;
}
}
UNREACHABLE();
}
#endif
void lookahead::restore_clauses(literal l) {
SASSERT(m_search_mode == lookahead_mode::searching);
// increase the length of clauses where l is negative
unsigned sz = m_nary_count[(~l).index()];
#if OLD_NARY
for (unsigned idx : m_nary[(~l).index()]) {
if (sz-- == 0) break;
++m_nary_literals[idx];
}
#else
for (nary* n : m_nary[(~l).index()]) {
if (sz-- == 0) break;
n->inc_size();
}
#endif
// add idx back to clause list where l is positive
// add them back in the same order as they were inserted
// in this way we can check that the clauses are the same.
sz = m_nary_count[l.index()];
#if OLD_NARY
unsigned_vector const& pclauses = m_nary[l.index()];
for (unsigned i = sz; i > 0; ) {
--i;
@ -1456,6 +1656,17 @@ namespace sat {
}
}
}
#else
ptr_vector<nary>& pclauses = m_nary[l.index()];
for (unsigned i = sz; i-- > 0; ) {
for (literal lit : *pclauses[i]) {
if (lit != l) {
// SASSERT(m_nary[lit.index()] == pclauses[i]);
m_nary_count[lit.index()]++;
}
}
}
#endif
}
void lookahead::propagate_clauses(literal l) {
@ -1527,7 +1738,7 @@ namespace sat {
// Sum_{ clause C that contains ~l } 1
double lookahead::literal_occs(literal l) {
double result = m_binary[l.index()].size();
unsigned_vector const& nclauses = m_nary[(~l).index()];
// unsigned_vector const& nclauses = m_nary[(~l).index()];
result += m_nary_count[(~l).index()];
result += m_ternary_count[(~l).index()];
return result;
@ -1684,7 +1895,7 @@ namespace sat {
return false;
#if 0
// no propagations are allowed to reduce clauses.
for (clause * cp : m_full_watches[l.index()]) {
for (nary * cp : m_nary[(~l).index()]) {
clause& c = *cp;
unsigned sz = c.size();
bool found = false;
@ -2026,6 +2237,7 @@ namespace sat {
}
}
#if OLD_NARY
for (unsigned l_idx : m_nary_literals) {
literal l = to_literal(l_idx);
if (first) {
@ -2041,6 +2253,12 @@ namespace sat {
out << l << " ";
}
}
#else
for (nary * n : m_nary_clauses) {
for (literal l : *n) out << l << " ";
out << "\n";
}
#endif
return out;
}

View file

@ -20,6 +20,7 @@ Notes:
#ifndef _SAT_LOOKAHEAD_H_
#define _SAT_LOOKAHEAD_H_
#define OLD_NARY 0
#include "sat_elim_eqs.h"
@ -129,6 +130,36 @@ namespace sat {
literal m_u, m_v;
};
class nary {
unsigned m_size; // number of non-false literals
size_t m_obj_size; // object size (counting all literals)
literal m_head; // head literal
literal m_literals[0]; // list of literals, put any true literal in head.
size_t num_lits() const {
return (m_obj_size - sizeof(nary)) / sizeof(literal);
}
public:
static size_t get_obj_size(unsigned sz) { return sizeof(nary) + sz * sizeof(literal); }
size_t obj_size() const { return m_obj_size; }
nary(unsigned sz, literal const* lits):
m_size(sz),
m_obj_size(get_obj_size(sz)) {
for (unsigned i = 0; i < sz; ++i) m_literals[i] = lits[i];
m_head = lits[0];
}
unsigned size() const { return m_size; }
unsigned dec_size() { SASSERT(m_size > 0); return --m_size; }
void inc_size() { SASSERT(m_size < num_lits()); ++m_size; }
literal get_head() const { return m_head; }
void set_head(literal l) { m_head = l; }
literal operator[](unsigned i) { SASSERT(i < num_lits()); return m_literals[i]; }
literal const* begin() const { return m_literals; }
literal const* end() const { return m_literals + num_lits(); }
// swap the true literal to the head.
// void swap(unsigned i, unsigned j) { SASSERT(i < num_lits() && j < num_lits()); std::swap(m_literals[i], m_literals[j]); }
};
struct cube_state {
bool m_first;
svector<bool> m_is_decision;
@ -160,11 +191,18 @@ namespace sat {
vector<svector<binary>> m_ternary; // lit |-> vector of ternary clauses
unsigned_vector m_ternary_count; // lit |-> current number of active ternary clauses for lit
#if OLD_NARY
vector<unsigned_vector> m_nary; // lit |-> vector of clause_id
unsigned_vector m_nary_count; // lit |-> number of valid clause_id in m_clauses2[lit]
unsigned_vector m_nary_literals; // the actual literals, clauses start at offset clause_id,
// the first entry is the current length, clauses are separated by a null_literal
#else
small_object_allocator m_allocator;
vector<ptr_vector<nary>> m_nary; // lit |-> vector of nary clauses
ptr_vector<nary> m_nary_clauses; // vector of all nary clauses
#endif
unsigned_vector m_nary_count; // lit |-> number of valid clause_id in m_nary[lit]
unsigned m_num_tc1;
unsigned_vector m_num_tc1_lim;
unsigned m_qhead; // propagation queue head
@ -410,15 +448,20 @@ namespace sat {
void propagate_clauses_searching(literal l);
void propagate_clauses_lookahead(literal l);
void restore_clauses(literal l);
#if OLD_NARY
void remove_clause(literal l, unsigned clause_idx);
void remove_clause_at(literal l, unsigned clause_idx);
#else
void remove_clause(literal l, nary& n);
void remove_clause_at(literal l, nary& n);
#endif
// ------------------------------------
// initialization
void init_var(bool_var v);
void init();
void copy_clauses(clause_vector const& clauses, bool learned);
nary * copy_clause(clause const& c);
// ------------------------------------
// search
@ -499,6 +542,12 @@ namespace sat {
~lookahead() {
m_s.rlimit().pop_child();
#if OLD_NARY
#else
for (nary* n : m_nary_clauses) {
m_allocator.deallocate(n->obj_size(), n);
}
#endif
}

View file

@ -3,11 +3,11 @@ Copyright (c) 2017 Microsoft Corporation
Module Name:
parallel_solver.cpp
parallel_tactic.cpp
Abstract:
Parallel solver in the style of Treengeling.
Parallel tactic in the style of Treengeling.
It assumes a solver that supports good lookaheads.
@ -20,13 +20,183 @@ Notes:
--*/
#include "util/scoped_ptr_vector.h"
#include "solver/solver.h"
#include "tactic/tactic.h"
class parallel_tactic : public tactic {
ref<solver> m_solver;
// parameters
unsigned m_conflicts_lower_bound;
unsigned m_conflicts_upper_bound;
unsigned m_conflicts_growth_rate;
unsigned m_conflicts_decay_rate;
unsigned m_num_threads;
unsigned m_max_conflicts;
sref_vector<solver> m_solvers;
scoped_ptr_vector<ast_manager> m_managers;
void init() {
m_conflicts_lower_bound = 1000;
m_conflicts_upper_bound = 10000;
m_conflicts_growth_rate = 150;
m_conflicts_decay_rate = 75;
m_max_conflicts = m_conflicts_lower_bound;
m_num_threads = omp_get_num_threads();
}
unsigned get_max_conflicts() {
return m_max_conflicts;
}
void set_max_conflicts(unsigned c) {
m_max_conflicts = c;
}
bool should_increase_conflicts() {
NOT_IMPLEMENTED_YET();
return false;
}
int pick_solvers() {
NOT_IMPLEMENTED_YET();
return 1;
}
void update_max_conflicts() {
if (should_increase_conflicts()) {
set_max_conflicts(std::min(m_conflicts_upper_bound, m_conflicts_growth_rate * get_max_conflicts() / 100));
}
else {
set_max_conflicts(std::max(m_conflicts_lower_bound, m_conflicts_decay_rate * get_max_conflicts() / 100));
}
}
lbool simplify(solver& s) {
params_ref p;
p.set_uint("sat.max_conflicts", 10);
p.set_bool("sat.lookahead_simplify", true);
s.updt_params(p);
lbool is_sat = s.check_sat(0,0);
p.set_uint("sat.max_conflicts", get_max_conflicts());
p.set_bool("sat.lookahead_simplify", false);
s.updt_params(p);
return is_sat;
}
lbool lookahead(solver& s) {
ast_manager& m = s.get_manager();
params_ref p;
p.set_uint("sat.lookahead.cube.cutoff", 1);
expr_ref_vector cubes(m);
while (true) {
expr_ref c = s.cube();
if (m.is_false(c)) {
break;
}
cubes.push_back(c);
}
if (cubes.empty()) {
return l_false;
}
for (unsigned i = 1; i < cubes.size(); ++i) {
ast_manager * new_m = alloc(ast_manager, m, !m.proof_mode());
solver* s1 = s.translate(*new_m, params_ref());
ast_translation translate(m, *new_m);
expr_ref cube(translate(cubes[i].get()), *new_m);
s1->assert_expr(cube);
#pragma omp critical (_solvers)
{
m_managers.push_back(new_m);
m_solvers.push_back(s1);
}
}
s.assert_expr(cubes[0].get());
return l_true;
}
lbool solve(solver& s) {
params_ref p;
p.set_uint("sat.max_conflicts", get_max_conflicts());
s.updt_params(p);
lbool is_sat = s.check_sat(0, 0);
return is_sat;
}
void remove_unsat(svector<int>& unsat) {
std::sort(unsat.begin(), unsat.end());
unsat.reverse();
DEBUG_CODE(for (unsigned i = 0; i + 1 < unsat.size(); ++i) SASSERT(unsat[i] > unsat[i+1]););
for (int i : unsat) {
m_solvers.erase(i);
}
unsat.reset();
}
lbool solve() {
while (true) {
int sz = pick_solvers();
if (sz == 0) {
return l_false;
}
svector<int> unsat;
int sat_index = -1;
// Simplify phase.
#pragma omp parallel for
for (int i = 0; i < sz; ++i) {
lbool is_sat = simplify(*m_solvers[i]);
switch (is_sat) {
case l_false: unsat.push_back(i); break;
case l_true: sat_index = i; break;
case l_undef: break;
}
}
if (sat_index != -1) return l_true; // TBD: extact model
sz -= unsat.size();
remove_unsat(unsat);
if (sz == 0) continue;
// Solve phase.
#pragma omp parallel for
for (int i = 0; i < sz; ++i) {
lbool is_sat = solve(*m_solvers[i]);
switch (is_sat) {
case l_false: unsat.push_back(i); break;
case l_true: sat_index = i; break;
case l_undef: break;
}
}
if (sat_index != -1) return l_true; // TBD: extact model
sz -= unsat.size();
remove_unsat(unsat);
if (sz == 0) continue;
// Split phase.
#pragma omp parallel for
for (int i = 0; i < sz; ++i) {
lbool is_sat = lookahead(*m_solvers[i]);
switch (is_sat) {
case l_false: unsat.push_back(i); break;
case l_true: break;
case l_undef: break;
}
}
remove_unsat(unsat);
update_max_conflicts();
}
return l_undef;
}
public:
parallel_tactic(solver* s) : m_solver(s) {}
parallel_tactic(solver* s) {
m_solvers.push_back(s); // clone it?
}
void operator ()(const goal_ref & g,goal_ref_buffer & result,model_converter_ref & mc,proof_converter_ref & pc,expr_dependency_ref & dep) {
NOT_IMPLEMENTED_YET();