3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

working on lookahead

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-02-27 10:59:59 -08:00
parent 388b025d9e
commit 88e7c240b7
4 changed files with 280 additions and 87 deletions

View file

@ -219,6 +219,10 @@ namespace sat {
lbool operator()();
lbool check(unsigned sz, literal const* assumptions) { return l_undef; } // TBD
void cancel() {} // TBD
local_search_config& config() { return m_config; }
local_search_config m_config;

View file

@ -31,6 +31,7 @@ namespace sat {
unsigned m_max_hlevel;
unsigned m_min_cutoff;
unsigned m_level_cand;
float m_delta_rho;
config() {
m_max_hlevel = 50;
@ -38,15 +39,39 @@ namespace sat {
m_max_score = 20.0;
m_min_cutoff = 30;
m_level_cand = 600;
m_delta_rho = (float)0.9995;
}
};
struct prefix {
unsigned m_prefix;
unsigned m_length;
prefix(): m_prefix(0), m_length(0) {}
};
struct lit_info {
float m_wnb;
unsigned m_double_lookahead;
lit_info(): m_wnb(0), m_double_lookahead(0) {}
};
struct statistics {
unsigned m_propagations;
statistics() { reset(); }
void reset() { memset(this, 0, sizeof(*this)); }
};
enum search_mode {
searching, // normal search
lookahead1, // lookahead mode
lookahead2 // double lookahead
};
struct ternary {
ternary(literal u, literal v, literal w): m_u(u), m_v(v), m_w(w) {}
literal m_u, m_v, m_w;
};
config m_config;
double m_delta_trigger;
@ -60,15 +85,25 @@ namespace sat {
unsigned m_qhead; // propagation queue head
unsigned_vector m_qhead_lim;
clause_vector m_clauses; // non-binary clauses
clause_vector m_retired_clauses; // clauses that were removed during search
svector<ternary> m_retired_ternary; //
unsigned_vector m_retired_clause_lim;
clause_allocator m_cls_allocator;
bool m_inconsistent;
unsigned_vector m_bstamp; // literal: timestamp for binary implication
vector<svector<float> > m_H; // literal: fitness score
svector<float>* m_heur; // current fitness
svector<float> m_rating; // var: pre-selection rating
unsigned m_bstamp_id; // unique id for binary implication.
unsigned m_istamp_id; // unique id for managing double lookaheads
char_vector m_assignment; // literal: assignment
vector<watch_list> m_watches; // literal: watch structure
svector<lit_info> m_lits; // literal: attributes.
float m_weighted_new_binaries; // metric associated with current lookahead1 literal.
svector<prefix> m_prefix; // var: prefix where variable participates in propagation
indexed_uint_set m_freevars;
svector<search_mode> m_search_modes; // stack of modes
search_mode m_search_mode; // mode of search
statistics m_stats;
void add_binary(literal l1, literal l2) {
@ -97,6 +132,15 @@ namespace sat {
m_bstamp.fill(0);
}
}
void inc_istamp() {
++m_istamp_id;
if (m_istamp_id == 0) {
++m_istamp_id;
for (unsigned i = 0; i < m_lits.size(); ++i) {
m_lits[i].m_double_lookahead = 0;
}
}
}
void set_bstamp(literal l) {
m_bstamp[l.index()] = m_bstamp_id;
}
@ -159,6 +203,15 @@ namespace sat {
// pre-selection
// see also 91 - 102 sat11.w
void pre_select() {
m_lookahead.reset();
if (select(scope_lvl())) {
get_scc();
find_heights();
construct_lookahead_table();
}
}
struct candidate {
bool_var m_var;
float m_rating;
@ -169,7 +222,7 @@ namespace sat {
float get_rating(bool_var v) const { return m_rating[v]; }
float get_rating(literal l) const { return get_rating(l.var()); }
bool_var select(unsigned level) {
bool select(unsigned level) {
init_pre_selection(level);
unsigned max_num_cand = level == 0 ? m_freevars.size() : m_config.m_level_cand / level;
max_num_cand = std::max(m_config.m_min_cutoff, max_num_cand);
@ -179,7 +232,7 @@ namespace sat {
sum = init_candidates(level, newbies);
if (!m_candidates.empty()) break;
if (is_sat()) {
return null_bool_var;
return false;
}
}
SASSERT(!m_candidates.empty());
@ -218,6 +271,7 @@ namespace sat {
}
}
SASSERT(!m_candidates.empty() && m_candidates.size() <= max_num_cand);
return true;
}
void sift_up(unsigned j) {
@ -238,12 +292,13 @@ namespace sat {
m_candidates.reset();
float sum = 0;
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
SASSERT(l_undef == value(*it));
bool_var x = *it;
if (!newbies) {
// TBD filter out candidates based on prefix strings or similar method
}
m_candidates.push_back(candidate(x, m_rating[x]));
sum += m_rating[x];
sum += m_rating[x];
}
return sum;
}
@ -277,17 +332,17 @@ namespace sat {
h_scores(m_H[i + 1], m_H[(i + 2) % 3]);
}
}
// heur = m_H[1];
m_heur = &m_H[1];
}
else if (level < max_level) {
ensure_H(level);
h_scores(m_H[level-1], m_H[level]);
// heur = m_H[level];
m_heur = &m_H[level];
}
else {
ensure_H(max_level);
h_scores(m_H[max_level-1], m_H[max_level]);
// heur = m_H[max_level];
m_heur = &m_H[max_level];
}
}
@ -415,9 +470,7 @@ namespace sat {
}
}
void add_arc(literal u, literal v) { m_dfs[u.index()].m_next.push_back(v); }
bool has_arc(literal v) const {
return m_dfs[v.index()].m_next.size() > m_dfs[v.index()].m_nextp;
}
bool has_arc(literal v) const { return m_dfs[v.index()].m_next.size() > m_dfs[v.index()].m_nextp; }
literal pop_arc(literal u) { return m_dfs[u.index()].m_next[m_dfs[u.index()].m_nextp++]; }
unsigned num_next(literal u) const { return m_dfs[u.index()].m_next.size(); }
literal get_next(literal u, unsigned i) const { return m_dfs[u.index()].m_next[i]; }
@ -512,10 +565,6 @@ namespace sat {
else m_dfs[v.index()].m_min = u;
}
void construct_forest() {
find_heights();
construct_lookahead_table();
}
void find_heights() {
literal pp = null_literal;
set_child(pp, null_literal);
@ -562,7 +611,7 @@ namespace sat {
void construct_lookahead_table() {
literal u = get_child(null_literal), v = null_literal;
unsigned offset = 0;
m_lookahead.reset();
SASSERT(m_lookahead.empty());
while (u != null_literal) {
set_rank(u, m_lookahead.size());
set_lookahead(get_vcomp(u));
@ -591,7 +640,41 @@ namespace sat {
TRACE("sat", for (unsigned i = 0; i < m_lookahead.size(); ++i)
tout << m_lookahead[i].m_lit << " : " << m_lookahead[i].m_offset << "\n";);
}
// ------------------------------------
// clause management
void attach_clause(clause& c) {
if (false && c.size() == 3) { // disable ternary clauses
m_watches[(~c[0]).index()].push_back(watched(c[1], c[2]));
m_watches[(~c[1]).index()].push_back(watched(c[0], c[2]));
m_watches[(~c[2]).index()].push_back(watched(c[0], c[1]));
}
else {
literal block_lit = c[c.size() >> 2];
clause_offset cls_off = m_cls_allocator.get_offset(&c);
m_watches[(~c[0]).index()].push_back(watched(block_lit, cls_off));
m_watches[(~c[1]).index()].push_back(watched(block_lit, cls_off));
SASSERT(value(c[0]) == l_undef);
SASSERT(value(c[1]) == l_undef);
}
}
void detach_clause(clause& c) {
clause_offset cls_off = m_cls_allocator.get_offset(&c);
m_retired_clauses.push_back(&c);
erase_clause_watch(get_wlist(~c[0]), cls_off);
erase_clause_watch(get_wlist(~c[1]), cls_off);
}
void detach_ternary(literal l1, literal l2, literal l3) {
m_retired_ternary.push_back(ternary(l1, l2, l3));
erase_ternary_watch(get_wlist(~l1), l2, l3);
erase_ternary_watch(get_wlist(~l2), l1, l3);
erase_ternary_watch(get_wlist(~l3), l1, l2);
}
watch_list& get_wlist(literal l) { return m_watches[l.index()]; }
// ------------------------------------
// initialization
@ -608,6 +691,10 @@ namespace sat {
m_rating.push_back(0);
m_dfs.push_back(dfs_info());
m_dfs.push_back(dfs_info());
m_lits.push_back(lit_info());
m_lits.push_back(lit_info());
m_prefix.push_back(prefix());
m_freevars.insert(v);
}
void init() {
@ -642,8 +729,9 @@ namespace sat {
clause_vector::const_iterator end = s.m_clauses.end();
for (; it != end; ++it) {
clause& c = *(*it);
m_clauses.push_back(m_cls_allocator.mk_clause(c.size(), c.begin(), false));
// TBD: add watch
clause* c1 = m_cls_allocator.mk_clause(c.size(), c.begin(), false);
m_clauses.push_back(c1);
attach_clause(c);
}
// copy units
@ -654,18 +742,38 @@ namespace sat {
assign(l);
}
}
// ------------------------------------
// search
void push(literal lit) {
void push(literal lit, search_mode mode) {
m_binary_trail_lim.push_back(m_binary_trail.size());
m_units_lim.push_back(m_units.size());
m_trail_lim.push_back(m_trail.size());
m_retired_clause_lim.push_back(m_retired_clauses.size());
m_qhead_lim.push_back(m_qhead);
m_trail.push_back(lit);
m_search_modes.push_back(m_search_mode);
m_search_mode = mode;
assign(lit);
propagate();
}
void pop() {
m_inconsistent = false;
// search mode
m_search_mode = m_search_modes.back();
m_search_modes.pop_back();
// unretire clauses
unsigned rsz = m_retired_clause_lim.back();
for (unsigned i = rsz; i < m_retired_clauses.size(); ++i) {
attach_clause(*m_retired_clauses[i]);
}
m_retired_clauses.resize(rsz);
m_retired_clause_lim.pop_back();
// remove local binary clauses
unsigned old_sz = m_binary_trail_lim.back();
m_binary_trail_lim.pop_back();
@ -673,24 +781,31 @@ namespace sat {
del_binary(m_binary_trail[i]);
}
// undo assignments
for (unsigned i = m_trail.size(); i > m_trail_lim.size(); ) {
--i;
literal l = m_trail[i];
m_freevars.insert(l.var());
m_assignment[l.index()] = l_undef;
m_assignment[(~l).index()] = l_undef;
}
m_trail.shrink(m_trail_lim.size()); // reset assignment.
m_trail_lim.pop_back();
// add implied binary clauses
unsigned new_unit_sz = m_units_lim.back();
for (unsigned i = new_unit_sz; i < m_units.size(); ++i) {
add_binary(~m_trail.back(), m_units[i]);
try_add_binary(~m_trail.back(), m_units[i]);
}
m_units.shrink(new_unit_sz);
m_units_lim.pop_back();
m_trail.shrink(m_trail_lim.size()); // reset assignment.
m_trail_lim.pop_back();
// reset propagation queue
m_qhead_lim.pop_back();
m_qhead = m_qhead_lim.back();
}
m_inconsistent = false;
}
unsigned diff() const { return m_units.size() - m_units_lim.back(); }
unsigned mix_diff(unsigned l, unsigned r) const { return l + r + (1 << 10) * l * r; }
float mix_diff(float l, float r) const { return l + r + (1 << 10) * l * r; }
clause const& get_clause(watch_list::iterator it) const {
clause_offset cls_off = it->get_clause_offset();
@ -715,6 +830,7 @@ namespace sat {
UNREACHABLE();
break;
case watched::TERNARY: {
UNREACHABLE(); // we avoid adding ternary clauses for now.
literal l1 = it->get_literal1();
literal l2 = it->get_literal2();
lbool val1 = value(l1);
@ -731,7 +847,17 @@ namespace sat {
set_conflict();
}
else if (val1 == l_undef && val2 == l_undef) {
// TBD: the clause has become binary.
switch (m_search_mode) {
case searching:
detach_ternary(l, l1, l2);
try_add_binary(l1, l2);
break;
case lookahead1:
m_weighted_new_binaries += (*m_heur)[l1.index()] * (*m_heur)[l2.index()];
break;
case lookahead2:
break;
}
}
*it2 = *it;
it2++;
@ -750,29 +876,37 @@ namespace sat {
}
literal * l_it = c.begin() + 2;
literal * l_end = c.end();
unsigned found = 0;
for (; l_it != l_end && found < 2; ++l_it) {
bool found = false;
for (; l_it != l_end && !found; ++l_it) {
if (value(*l_it) != l_false) {
++found;
if (found == 2) {
break;
}
else {
c[1] = *l_it;
*l_it = ~l;
m_watches[(~c[1]).index()].push_back(watched(c[0], cls_off));
}
found = true;
c[1] = *l_it;
*l_it = ~l;
m_watches[(~c[1]).index()].push_back(watched(c[0], cls_off));
}
}
if (found == 1) {
// TBD: clause has become binary
if (found) {
found = false;
for (; l_it != l_end && !found; ++l_it) {
found = value(*l_it) != l_false;
}
// normal clause was converted to a binary clause.
if (!found && value(c[1]) == l_undef && value(c[0]) == l_undef) {
switch (m_search_mode) {
case searching:
detach_clause(c);
try_add_binary(c[0], c[1]);
break;
case lookahead1:
m_weighted_new_binaries += (*m_heur)[c[0].index()]* (*m_heur)[c[1].index()];
break;
case lookahead2:
break;
}
}
break;
}
if (found > 1) {
// not a binary clause
break;
}
else if (value(c[0]) == l_false) {
if (value(c[0]) == l_false) {
set_conflict();
}
else {
@ -829,41 +963,57 @@ namespace sat {
return l;
}
// TBD:
// Handle scope properly for nested implications.
// Suppose u -> v, and u -> w and we process v first, then the
// consequences of v should remain when processing u.
// March and sat11.w solve this by introducing timestamps on truth values.
// regular push/pop doesn't really work here: we basically need a traversal of the
// lookahead tree and push/pop according to that (or adapt timestamps)
//
bool choose1(literal& l) {
pre_select();
l = null_literal;
if (m_lookahead.empty()) {
return true;
}
unsigned h = 0, count = 1;
for (unsigned i = 0; i < m_lookahead.size(); ++i) {
float h = 0;
unsigned count = 1;
for (unsigned i = 0; !inconsistent() && i < m_lookahead.size(); ++i) {
literal lit = m_lookahead[i].m_lit;
if (value(lit) != l_undef) {
continue;
}
SASSERT(value(lit) == l_undef);
SASSERT(!inconsistent());
push(lit);
if (do_double()) double_look();
reset_wnb(lit);
push(lit, lookahead1);
do_double(lit);
if (inconsistent()) {
pop();
assign(~lit);
if (do_double()) double_look();
if (inconsistent()) return true;
propagate();
continue;
}
unsigned diff1 = diff();
update_wnb(lit);
float diff1 = m_weighted_new_binaries;
pop();
push(~lit);
if (do_double()) double_look();
bool unsat2 = inconsistent();
unsigned diff2 = diff();
pop();
if (unsat2) {
reset_wnb(~lit);
push(~lit, lookahead1);
do_double(~lit);
if (inconsistent()) {
pop();
assign(lit);
propagate();
continue;
}
update_wnb(~lit);
float diff2 = m_weighted_new_binaries;
pop();
unsigned mixd = mix_diff(diff1, diff2);
float mixd = mix_diff(diff1, diff2);
if (mixd > h || (mixd == h && s.m_rand(count) == 0)) {
CTRACE("sat", l != null_literal, tout << lit << " diff1: " << diff1 << " diff2: " << diff2 << "\n";);
@ -872,7 +1022,30 @@ namespace sat {
l = diff1 < diff2 ? lit : ~lit;
}
}
return l != null_literal;
return l != null_literal || inconsistent();
}
void set_wnb(literal l, float f) { m_lits[l.index()].m_wnb = f; }
void inc_wnb(literal l, float f) { m_lits[l.index()].m_wnb += f; }
float get_wnb(literal l) const { return m_lits[l.index()].m_wnb; }
bool dl_enabled(literal l) const { return m_lits[l.index()].m_double_lookahead != m_istamp_id; }
void dl_disable(literal l) { m_lits[l.index()].m_double_lookahead = m_istamp_id; }
void reset_wnb(literal l) {
m_weighted_new_binaries = 0;
// inherit propagation effect from parent.
literal p = get_parent(l);
set_wnb(l, p == null_literal ? 0 : get_wnb(p));
}
void update_wnb(literal l) {
if (m_weighted_new_binaries == 0) {
// TBD autarky
}
else {
inc_wnb(l, m_weighted_new_binaries);
}
}
void double_look() {
@ -881,7 +1054,7 @@ namespace sat {
literal lit = m_lookahead[i].m_lit;
if (value(lit) != l_undef) continue;
push(lit);
push(lit, lookahead2);
unsat = inconsistent();
pop();
if (unsat) {
@ -890,7 +1063,7 @@ namespace sat {
continue;
}
push(~lit);
push(~lit, lookahead2);
unsat = inconsistent();
pop();
if (unsat) {
@ -898,7 +1071,6 @@ namespace sat {
assign(lit);
}
}
update_delta_trigger();
}
bool is_fixed(literal l) const { return value(l) != l_undef; }
@ -906,6 +1078,7 @@ namespace sat {
bool is_true(literal l) const { return value(l) == l_true; }
void set_conflict() { m_inconsistent = true; }
lbool value(literal l) const { return static_cast<lbool>(m_assignment[l.index()]); }
lbool value(bool_var v) const { return value(literal(v, false)); }
unsigned scope_lvl() const { return m_trail_lim.size(); }
void assign(literal l) {
@ -919,14 +1092,13 @@ namespace sat {
m_assignment[l.index()] = l.sign() ? l_false : l_true;
m_assignment[(~l).index()] = l.sign() ? l_false : l_true;
m_trail.push_back(l);
m_freevars.remove(l.var());
break;
}
}
void set_inconsistent() { m_inconsistent = true; }
bool inconsistent() { return m_inconsistent; }
void select_variables(literal_vector& P) {
for (unsigned i = 0; i < s.num_vars(); ++i) {
if (value(literal(i,false)) == l_undef) {
@ -935,19 +1107,16 @@ namespace sat {
}
}
bool do_double() {
return !inconsistent() && diff() > m_delta_trigger;
}
void update_delta_trigger() {
if (inconsistent()) {
m_delta_trigger -= (1 - m_config.m_dl_success) / m_config.m_dl_success;
}
else {
m_delta_trigger += 1;
}
if (m_delta_trigger >= s.num_vars()) {
// reset it.
void do_double(literal l) {
if (!inconsistent() && scope_lvl() > 0 && dl_enabled(l)) {
if (get_wnb(l) > m_delta_trigger) {
double_look();
m_delta_trigger = get_wnb(l);
dl_disable(l);
}
else {
m_delta_trigger *= m_config.m_delta_rho;
}
}
}
@ -961,8 +1130,9 @@ namespace sat {
lbool search() {
literal_vector trail;
m_search_mode = searching;
while (true) {
inc_istamp();
s.checkpoint();
literal l = choose();
if (inconsistent()) {
@ -973,7 +1143,7 @@ namespace sat {
return l_true;
}
TRACE("sat", tout << "choose: " << l << " " << trail << "\n";);
push(l);
push(l, searching);
trail.push_back(l);
}
}

View file

@ -782,7 +782,7 @@ namespace sat {
pop_to_base_level();
IF_VERBOSE(2, verbose_stream() << "(sat.sat-solver)\n";);
SASSERT(at_base_lvl());
if (m_config.m_num_threads > 1 && !m_par) {
if ((m_config.m_num_threads > 1 || m_local_search) && !m_par) {
return check_par(num_lits, lits);
}
flet<bool> _searching(m_searching, true);
@ -854,9 +854,17 @@ namespace sat {
ERROR_EX
};
local_search& solver::init_local_search() {
if (!m_local_search) {
m_local_search = alloc(local_search, *this);
}
return *m_local_search.get();
}
lbool solver::check_par(unsigned num_lits, literal const* lits) {
int num_threads = static_cast<int>(m_config.m_num_threads);
int num_extra_solvers = num_threads - 1;
int num_extra_solvers = num_threads - 1 + (m_local_search ? 1 : 0);
sat::parallel par(*this);
par.reserve(num_threads, 1 << 12);
par.init_solvers(*this, num_extra_solvers);
@ -870,7 +878,10 @@ namespace sat {
for (int i = 0; i < num_threads; ++i) {
try {
lbool r = l_undef;
if (i < num_extra_solvers) {
if (m_local_search && i + 1 == num_extra_solvers) {
r = m_local_search->check(num_lits, lits);
}
else if (i < num_extra_solvers) {
r = par.get_solver(i).check(num_lits, lits);
}
else {
@ -886,6 +897,9 @@ namespace sat {
}
}
if (first) {
if (m_local_search) {
m_local_search->cancel();
}
for (int j = 0; j < num_extra_solvers; ++j) {
if (i != j) {
par.cancel_solver(j);

View file

@ -35,6 +35,7 @@ Revision History:
#include"sat_mus.h"
#include"sat_drat.h"
#include"sat_parallel.h"
#include"sat_local_search.h"
#include"params.h"
#include"statistics.h"
#include"stopwatch.h"
@ -89,6 +90,7 @@ namespace sat {
probing m_probing;
mus m_mus; // MUS for minimal core extraction
drat m_drat; // DRAT for generating proofs
scoped_ptr<local_search> m_local_search;
bool m_inconsistent;
bool m_searching;
// A conflict is usually a single justification. That is, a justification
@ -460,6 +462,9 @@ namespace sat {
lbool get_consequences(literal_vector const& assms, bool_var_vector const& vars, vector<literal_vector>& conseq);
// initialize and retrieve local search.
local_search& init_local_search();
private:
typedef hashtable<unsigned, u_hash, u_eq> index_set;