3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-15 13:28:47 +00:00

fix local search

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-03-14 13:47:01 -07:00
parent 51951a3683
commit 5c6cef4735
3 changed files with 81 additions and 37 deletions

View file

@ -67,6 +67,16 @@ namespace sat {
lookahead2 // double lookahead
};
std::ostream& display(std::ostream& out, search_mode m) const {
switch (m) {
case search_mode::searching: return out << "searching";
case search_mode::lookahead1: return out << "lookahead1";
case search_mode::lookahead2: return out << "lookahead2";
default: break;
}
return out;
}
struct ternary {
ternary(literal u, literal v, literal w): m_u(u), m_v(v), m_w(w) {}
literal m_u, m_v, m_w;
@ -117,6 +127,7 @@ namespace sat {
inline bool is_true(literal l) const { return is_fixed(l) && !(bool)((m_stamp[l.var()] & 0x1) ^ l.sign()); }
inline void set_true(literal l) { m_stamp[l.var()] = m_level + l.sign(); }
inline void set_undef(literal l) { m_stamp[l.var()] = 0; }
lbool value(literal l) const { return is_undef(l) ? l_undef : is_true(l) ? l_true : l_false; }
// set the level within a scope of the search.
class scoped_level {
@ -172,6 +183,7 @@ namespace sat {
// ----------------------------------------
void add_binary(literal l1, literal l2) {
TRACE("sat", tout << "binary: " << l1 << " " << l2 << "\n";);
SASSERT(l1 != l2);
SASSERT(~l1 != l2);
m_binary[(~l1).index()].push_back(l2);
@ -180,6 +192,7 @@ namespace sat {
}
void del_binary(unsigned idx) {
// TRACE("sat", display(tout << "Delete " << to_literal(idx) << "\n"););
literal_vector & lits = m_binary[idx];
literal l = lits.back();
lits.pop_back();
@ -520,7 +533,7 @@ namespace sat {
literal m_active;
unsigned m_rank;
literal m_settled;
literal_vector m_settled;
vector<dfs_info> m_dfs;
void get_scc() {
@ -547,6 +560,7 @@ namespace sat {
// set nextp = 0?
m_rank = 0;
m_active = null_literal;
m_settled.reset();
TRACE("sat", display_dfs(tout););
}
void init_dfs_info(literal l) {
@ -631,6 +645,7 @@ namespace sat {
literal best = v;
float best_rating = get_rating(v);
set_rank(v, UINT_MAX);
m_settled.push_back(t);
while (t != v) {
SASSERT(t != ~v);
set_rank(t, UINT_MAX);
@ -671,6 +686,7 @@ namespace sat {
out << l << " := " << get_parent(l) << "\n";
out << ~l << " := " << get_parent(~l) << "\n";
}
return out;
}
// ------------------------------------
@ -693,7 +709,8 @@ namespace sat {
set_child(pp, null_literal);
unsigned h = 0;
literal w;
for (literal u = m_settled; u != null_literal; u = get_link(u)) {
for (unsigned i = 0; i < m_settled.size(); ++i) {
literal u = m_settled[i];
literal p = get_parent(u);
if (p != pp) {
h = 0;
@ -790,6 +807,15 @@ namespace sat {
erase_clause_watch(get_wlist(~c[1]), cls_off);
}
void del_clauses() {
clause * const* end = m_clauses.end();
clause * const * it = m_clauses.begin();
for (; it != end; ++it) {
m_cls_allocator.del_clause(*it);
}
}
void detach_ternary(literal l1, literal l2, literal l3) {
NOT_IMPLEMENTED_YET();
// there is a clause corresponding to a ternary watch group.
@ -812,6 +838,7 @@ namespace sat {
m_watches.push_back(watch_list());
m_bstamp.push_back(0);
m_bstamp.push_back(0);
m_stamp.push_back(0);
m_dfs.push_back(dfs_info());
m_dfs.push_back(dfs_info());
m_lits.push_back(lit_info());
@ -864,6 +891,7 @@ namespace sat {
literal l = s.m_trail[i];
assign(l);
}
TRACE("sat", s.display(tout); display(tout););
}
// ------------------------------------
@ -874,7 +902,6 @@ namespace sat {
m_trail_lim.push_back(m_trail.size());
m_retired_clause_lim.push_back(m_retired_clauses.size());
m_qhead_lim.push_back(m_qhead);
m_trail.push_back(lit);
m_search_modes.push_back(m_search_mode);
m_search_mode = searching;
scoped_level _sl(*this, level);
@ -889,38 +916,42 @@ namespace sat {
m_search_mode = m_search_modes.back();
m_search_modes.pop_back();
// not for lookahead
// unretire clauses
unsigned rsz = m_retired_clause_lim.back();
for (unsigned i = rsz; i < m_retired_clauses.size(); ++i) {
attach_clause(*m_retired_clauses[i]);
}
m_retired_clauses.resize(rsz);
m_retired_clause_lim.pop_back();
// m_search_mode == searching
// remove local binary clauses
unsigned old_sz = m_binary_trail_lim.back();
m_binary_trail_lim.pop_back();
for (unsigned i = old_sz; i < m_binary_trail.size(); ++i) {
del_binary(m_binary_trail[i]);
}
// not for lookahead.
// m_freevars only for main search
// undo assignments
for (unsigned i = m_trail.size(); i > m_trail_lim.size(); ) {
unsigned old_sz = m_trail_lim.back();
for (unsigned i = m_trail.size(); i > old_sz; ) {
--i;
literal l = m_trail[i];
set_undef(l);
TRACE("sat", tout << "inserting free var v" << l.var() << "\n";);
m_freevars.insert(l.var());
}
m_trail.shrink(m_trail_lim.size()); // reset assignment.
m_trail.shrink(old_sz); // reset assignment.
m_trail_lim.pop_back();
// not for lookahead
// unretire clauses
old_sz = m_retired_clause_lim.back();
for (unsigned i = old_sz; i < m_retired_clauses.size(); ++i) {
attach_clause(*m_retired_clauses[i]);
}
m_retired_clauses.resize(old_sz);
m_retired_clause_lim.pop_back();
// m_search_mode == searching
// remove local binary clauses
old_sz = m_binary_trail_lim.back();
for (unsigned i = m_binary_trail.size(); i > old_sz; ) {
del_binary(m_binary_trail[--i]);
}
m_binary_trail.shrink(old_sz);
m_binary_trail_lim.pop_back();
// reset propagation queue
m_qhead_lim.pop_back();
m_qhead = m_qhead_lim.back();
m_qhead_lim.pop_back();
}
void push_lookahead2(literal lit) {
@ -995,7 +1026,6 @@ namespace sat {
case watched::CLAUSE: {
clause_offset cls_off = it->get_clause_offset();
clause & c = *(s.m_cls_allocator.get_clause(cls_off));
TRACE("sat", tout << "propagating " << c << "\n";);
if (c[0] == ~l)
std::swap(c[0], c[1]);
if (is_true(c[0])) {
@ -1021,6 +1051,7 @@ namespace sat {
}
// normal clause was converted to a binary clause.
if (!found && is_undef(c[1]) && is_undef(c[0])) {
TRACE("sat", tout << "got binary " << l << ": " << c << "\n";);
switch (m_search_mode) {
case searching:
detach_clause(c);
@ -1036,9 +1067,11 @@ namespace sat {
break;
}
if (is_false(c[0])) {
TRACE("sat", tout << "conflict " << l << ": " << c << "\n";);
set_conflict();
}
else {
TRACE("sat", tout << "propagating " << l << ": " << c << "\n";);
SASSERT(is_undef(c[0]));
*it2 = *it;
it2++;
@ -1076,7 +1109,7 @@ namespace sat {
propagate_binary(l);
propagate_clauses(l);
}
TRACE("sat", s.display(tout << scope_lvl() << " " << (inconsistent()?"unsat":"sat") << "\n"););
TRACE("sat", display(tout << scope_lvl() << " " << (inconsistent()?"unsat":"sat") << "\n"););
}
literal choose() {
@ -1100,9 +1133,7 @@ namespace sat {
TRACE("sat", display_lookahead(tout); );
for (unsigned i = 0; !inconsistent() && i < m_lookahead.size(); ++i) {
literal lit = m_lookahead[i].m_lit;
if (!is_undef(lit)) {
continue;
}
TRACE("sat", tout << "lookahead " << lit << "\n";);
reset_wnb(lit);
push_lookahead1(lit, 2 + m_lookahead[i].m_offset);
bool unsat = inconsistent();
@ -1110,6 +1141,7 @@ namespace sat {
pop_lookahead1();
update_wnb(lit);
if (unsat) {
TRACE("sat", tout << "backtracking and settting " << ~lit << "\n";);
reset_wnb();
assign(~lit);
propagate();
@ -1150,13 +1182,13 @@ namespace sat {
if (mixd == h) ++count;
if (mixd > h || (mixd == h && s.m_rand(count) == 0)) {
CTRACE("sat", l != null_literal, tout << lit << " " << mixd << "\n";);
CTRACE("sat", l != null_literal, tout << lit << " mix diff: " << mixd << "\n";);
if (mixd > h) count = 1;
h = mixd;
l = diff1 < diff2 ? lit : ~lit;
}
}
TRACE("sat", tout << l << "\n";);
TRACE("sat", tout << "selected: " << l << "\n";);
return l;
}
@ -1169,7 +1201,7 @@ namespace sat {
}
void pop_lookahead1() {
SASSERT(!inconsistent());
m_inconsistent = false;
m_search_mode = m_search_modes.back();
m_search_modes.pop_back();
}
@ -1230,16 +1262,18 @@ namespace sat {
unsigned scope_lvl() const { return m_trail_lim.size(); }
void assign(literal l) {
TRACE("sat", tout << "assign: " << l << "\n";);
TRACE("sat", tout << "assign: " << l << " := " << value(l) << " @ " << m_level << " "; display(tout, m_search_mode) << "\n";);
SASSERT(m_level > 0);
if (is_undef(l)) {
set_true(l);
m_trail.push_back(l);
if (m_search_mode == searching) {
TRACE("sat", tout << "removing free var v" << l.var() << "\n";);
m_freevars.remove(l.var());
}
}
else if (is_false(l)) {
SASSERT(!is_true(l));
set_conflict();
}
}
@ -1269,6 +1303,7 @@ namespace sat {
}
lbool search() {
scoped_level _sl(*this, c_fixed_truth);
literal_vector trail;
m_search_mode = searching;
while (true) {
@ -1318,7 +1353,7 @@ namespace sat {
for (unsigned i = 0; i < m_lookahead.size(); ++i) {
literal lit = m_lookahead[i].m_lit;
unsigned offset = m_lookahead[i].m_offset;
out << lit << " offset: " << offset;
out << lit << "\toffset: " << offset;
out << (is_undef(lit)?" undef": (is_true(lit) ? " true": " false"));
out << " wnb: " << get_wnb(lit);
out << "\n";
@ -1330,10 +1365,15 @@ namespace sat {
public:
lookahead(solver& s) :
s(s),
m_level(0) {
m_level(2),
m_prefix(0) {
scoped_level _sl(*this, c_fixed_truth);
init();
}
~lookahead() {
del_clauses();
}
lbool check() {
return search();
@ -1341,6 +1381,7 @@ namespace sat {
std::ostream& display(std::ostream& out) const {
out << std::hex << "Prefix: " << m_prefix << std::dec << "\n";
out << "Level: " << m_level << "\n";
display_values(out);
display_binary(out);
display_clauses(out);

View file

@ -11,7 +11,6 @@ void tst_sat_lookahead(char ** argv, int argc, int& i) {
reslimit limit;
params_ref params;
sat::solver solver(params, limit);
sat::lookahead lh(solver);
char const* file_name = argv[i + 1];
++i;
@ -24,6 +23,8 @@ void tst_sat_lookahead(char ** argv, int argc, int& i) {
parse_dimacs(in, solver);
}
sat::lookahead lh(solver);
IF_VERBOSE(20, solver.display_status(verbose_stream()););
std::cout << lh.check() << "\n";

View file

@ -344,11 +344,12 @@ public:
void insert(unsigned x) {
SASSERT(!contains(x));
m_index.resize(x + 1, UINT_MAX);
m_elems.resize(m_size + 1);
m_index.reserve(x + 1, UINT_MAX);
m_elems.reserve(m_size + 1);
m_index[x] = m_size;
m_elems[m_size] = x;
m_size++;
SASSERT(contains(x));
}
void remove(unsigned x) {
@ -361,6 +362,7 @@ public:
m_index[x] = m_size;
m_elems[m_size] = x;
}
SASSERT(!contains(x));
}
bool contains(unsigned x) const { return x < m_index.size() && m_index[x] < m_size && m_elems[m_index[x]] == x; }