3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-07-24 05:08:55 +00:00

local search updates

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-03-05 14:40:58 -08:00
parent a7db118ebc
commit fda5809c89
7 changed files with 364 additions and 249 deletions

View file

@ -77,8 +77,6 @@ namespace sat {
literal_vector m_trail; // trail of units
unsigned_vector m_trail_lim;
literal_vector m_units; // units learned during lookahead
unsigned_vector m_units_lim;
vector<literal_vector> m_binary; // literal: binary clauses
unsigned_vector m_binary_trail; // trail of added binary clauses
unsigned_vector m_binary_trail_lim;
@ -96,7 +94,9 @@ namespace sat {
svector<float> m_rating; // var: pre-selection rating
unsigned m_bstamp_id; // unique id for binary implication.
unsigned m_istamp_id; // unique id for managing double lookaheads
char_vector m_assignment; // literal: assignment
unsigned_vector m_stamp; // var: timestamp with truth value
unsigned m_level; // current level, = 2 * m_trail_lim.size()
const unsigned c_fixed_truth = UINT_MAX - 1;
vector<watch_list> m_watches; // literal: watch structure
svector<lit_info> m_lits; // literal: attributes.
float m_weighted_new_binaries; // metric associated with current lookahead1 literal.
@ -106,6 +106,33 @@ namespace sat {
search_mode m_search_mode; // mode of search
statistics m_stats;
// ---------------------------------------
// truth values
inline bool is_fixed(literal l) const { return m_stamp[l.var()] >= m_level; }
inline bool is_undef(literal l) const { return !is_fixed(l); }
inline bool is_undef(bool_var v) const { return m_stamp[v] < m_level; }
inline bool is_false(literal l) const { return is_fixed(l) && (bool)((m_stamp[l.var()] & 0x1) ^ l.sign()); } // even iff l.sign()
inline bool is_true(literal l) const { return is_fixed(l) && !(bool)((m_stamp[l.var()] & 0x1) ^ l.sign()); }
inline void set_true(literal l) { m_stamp[l.var()] = m_level + l.sign(); }
inline void set_undef(literal l) { m_stamp[l.var()] = 0; }
// set the level within a scope of the search.
class scoped_level {
lookahead& m_parent;
unsigned m_save;
public:
scoped_level(lookahead& p, unsigned l):
m_parent(p), m_save(p.m_level) {
p.m_level = l;
}
~scoped_level() {
m_parent.m_level = m_save;
}
};
// ----------------------------------------
void add_binary(literal l1, literal l2) {
SASSERT(l1 != l2);
SASSERT(~l1 != l2);
@ -182,6 +209,7 @@ namespace sat {
\brief main routine for adding a new binary clause dynamically.
*/
void try_add_binary(literal u, literal v) {
SASSERT(m_search_mode == searching);
SASSERT(u.var() != v.var());
set_bstamps(~u);
if (is_stamped(~v)) {
@ -292,7 +320,7 @@ namespace sat {
m_candidates.reset();
float sum = 0;
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
SASSERT(l_undef == value(*it));
SASSERT(is_undef(*it));
bool_var x = *it;
if (!newbies) {
// TBD filter out candidates based on prefix strings or similar method
@ -376,20 +404,13 @@ namespace sat {
float sum = 0, tsum = 0;
literal_vector::iterator it = m_binary[l.index()].begin(), end = m_binary[l.index()].end();
for (; it != end; ++it) {
if (is_free(*it)) sum += h[it->index()];
if (is_undef(*it)) sum += h[it->index()];
}
// TBD walk ternary clauses.
sum = (float)(0.1 + afactor*sum + sqfactor*tsum);
return std::min(m_config.m_max_score, sum);
}
bool is_free(literal l) const {
return !is_unit(l);
}
bool is_unit(literal l) const {
return false; // TBD track variables that are units
}
// ------------------------------------
// Implication graph
// Compute implication ordering and strongly connected components.
@ -638,7 +659,7 @@ namespace sat {
}
SASSERT(2*m_lookahead.size() == offset);
TRACE("sat", for (unsigned i = 0; i < m_lookahead.size(); ++i)
tout << m_lookahead[i].m_lit << " : " << m_lookahead[i].m_offset << "\n";);
tout << m_lookahead[i].m_lit << " : " << m_lookahead[i].m_offset << "\n";);
}
// ------------------------------------
@ -655,8 +676,8 @@ namespace sat {
clause_offset cls_off = m_cls_allocator.get_offset(&c);
m_watches[(~c[0]).index()].push_back(watched(block_lit, cls_off));
m_watches[(~c[1]).index()].push_back(watched(block_lit, cls_off));
SASSERT(value(c[0]) == l_undef);
SASSERT(value(c[1]) == l_undef);
SASSERT(is_undef(c[0]));
SASSERT(is_undef(c[1]));
}
}
@ -668,6 +689,9 @@ namespace sat {
}
void detach_ternary(literal l1, literal l2, literal l3) {
NOT_IMPLEMENTED_YET();
// there is a clause corresponding to a ternary watch group.
// the clause could be retired / detached.
m_retired_ternary.push_back(ternary(l1, l2, l3));
erase_ternary_watch(get_wlist(~l1), l2, l3);
erase_ternary_watch(get_wlist(~l2), l1, l3);
@ -680,8 +704,6 @@ namespace sat {
// initialization
void init_var(bool_var v) {
m_assignment.push_back(l_undef);
m_assignment.push_back(l_undef);
m_binary.push_back(literal_vector());
m_binary.push_back(literal_vector());
m_watches.push_back(watch_list());
@ -738,7 +760,6 @@ namespace sat {
unsigned trail_sz = s.init_trail_size();
for (unsigned i = 0; i < trail_sz; ++i) {
literal l = s.m_trail[i];
m_units.push_back(l);
assign(l);
}
}
@ -746,15 +767,15 @@ namespace sat {
// ------------------------------------
// search
void push(literal lit, search_mode mode) {
void push(literal lit, unsigned level) {
m_binary_trail_lim.push_back(m_binary_trail.size());
m_units_lim.push_back(m_units.size());
m_trail_lim.push_back(m_trail.size());
m_retired_clause_lim.push_back(m_retired_clauses.size());
m_qhead_lim.push_back(m_qhead);
m_trail.push_back(lit);
m_search_modes.push_back(m_search_mode);
m_search_mode = mode;
m_search_mode = searching;
scoped_level _sl(*this, level);
assign(lit);
propagate();
}
@ -766,6 +787,7 @@ namespace sat {
m_search_mode = m_search_modes.back();
m_search_modes.pop_back();
// not for lookahead
// unretire clauses
unsigned rsz = m_retired_clause_lim.back();
for (unsigned i = rsz; i < m_retired_clauses.size(); ++i) {
@ -774,6 +796,7 @@ namespace sat {
m_retired_clauses.resize(rsz);
m_retired_clause_lim.pop_back();
// m_search_mode == searching
// remove local binary clauses
unsigned old_sz = m_binary_trail_lim.back();
m_binary_trail_lim.pop_back();
@ -781,30 +804,31 @@ namespace sat {
del_binary(m_binary_trail[i]);
}
// not for lookahead.
// m_freevars only for main search
// undo assignments
for (unsigned i = m_trail.size(); i > m_trail_lim.size(); ) {
--i;
literal l = m_trail[i];
set_undef(l);
m_freevars.insert(l.var());
m_assignment[l.index()] = l_undef;
m_assignment[(~l).index()] = l_undef;
}
m_trail.shrink(m_trail_lim.size()); // reset assignment.
m_trail_lim.pop_back();
// add implied binary clauses
unsigned new_unit_sz = m_units_lim.back();
for (unsigned i = new_unit_sz; i < m_units.size(); ++i) {
try_add_binary(~m_trail.back(), m_units[i]);
}
m_units.shrink(new_unit_sz);
m_units_lim.pop_back();
// reset propagation queue
m_qhead_lim.pop_back();
m_qhead = m_qhead_lim.back();
}
void push_lookahead2(literal lit) {
}
void pop_lookahead2() {
}
float mix_diff(float l, float r) const { return l + r + (1 << 10) * l * r; }
clause const& get_clause(watch_list::iterator it) const {
@ -813,14 +837,13 @@ namespace sat {
}
bool is_nary_propagation(clause const& c, literal l) const {
bool r = c.size() > 2 && ((c[0] == l && value(c[1]) == l_false) || (c[1] == l && value(c[0]) == l_false));
DEBUG_CODE(if (r) for (unsigned j = 2; j < c.size(); ++j) SASSERT(value(c[j]) == l_false););
bool r = c.size() > 2 && ((c[0] == l && is_false(c[1])) || (c[1] == l && is_false(c[0])));
DEBUG_CODE(if (r) for (unsigned j = 2; j < c.size(); ++j) SASSERT(is_false(c[j])););
return r;
}
void propagate_clauses(literal l) {
SASSERT(value(l) == l_true);
SASSERT(value(~l) == l_false);
SASSERT(is_true(l));
if (inconsistent()) return;
watch_list& wlist = m_watches[l.index()];
watch_list::iterator it = wlist.begin(), it2 = it, end = wlist.end();
@ -833,20 +856,24 @@ namespace sat {
UNREACHABLE(); // we avoid adding ternary clauses for now.
literal l1 = it->get_literal1();
literal l2 = it->get_literal2();
lbool val1 = value(l1);
lbool val2 = value(l2);
if (val1 == l_false && val2 == l_undef) {
m_stats.m_propagations++;
assign(l2);
if (is_fixed(l1)) {
if (is_false(l1)) {
if (is_undef(l2)) {
m_stats.m_propagations++;
assign(l2);
}
else if (is_false(l2)) {
set_conflict();
}
}
}
else if (val1 == l_undef && val2 == l_false) {
m_stats.m_propagations++;
assign(l1);
else if (is_fixed(l2)) {
if (is_false(l2)) {
m_stats.m_propagations++;
assign(l1);
}
}
else if (val1 == l_false && val2 == l_false) {
set_conflict();
}
else if (val1 == l_undef && val2 == l_undef) {
else {
switch (m_search_mode) {
case searching:
detach_ternary(l, l1, l2);
@ -866,10 +893,10 @@ namespace sat {
case watched::CLAUSE: {
clause_offset cls_off = it->get_clause_offset();
clause & c = *(s.m_cls_allocator.get_clause(cls_off));
TRACE("propagate_clause_bug", tout << "processing... " << c << "\nwas_removed: " << c.was_removed() << "\n";);
TRACE("sat", tout << "propagating " << c << "\n";);
if (c[0] == ~l)
std::swap(c[0], c[1]);
if (value(c[0]) == l_true) {
if (is_true(c[0])) {
it2->set_clause(c[0], cls_off);
it2++;
break;
@ -878,7 +905,7 @@ namespace sat {
literal * l_end = c.end();
bool found = false;
for (; l_it != l_end && !found; ++l_it) {
if (value(*l_it) != l_false) {
if (!is_false(*l_it)) {
found = true;
c[1] = *l_it;
*l_it = ~l;
@ -888,10 +915,10 @@ namespace sat {
if (found) {
found = false;
for (; l_it != l_end && !found; ++l_it) {
found = value(*l_it) != l_false;
found = !is_false(*l_it);
}
// normal clause was converted to a binary clause.
if (!found && value(c[1]) == l_undef && value(c[0]) == l_undef) {
if (!found && is_undef(c[1]) && is_undef(c[0])) {
switch (m_search_mode) {
case searching:
detach_clause(c);
@ -906,11 +933,11 @@ namespace sat {
}
break;
}
if (value(c[0]) == l_false) {
if (is_false(c[0])) {
set_conflict();
}
else {
SASSERT(value(c[0]) == l_undef);
SASSERT(is_undef(c[0]));
*it2 = *it;
it2++;
m_stats.m_propagations++;
@ -929,14 +956,7 @@ namespace sat {
for (; it != end; ++it, ++it2) {
*it2 = *it;
}
wlist.set_end(it2);
//
// TBD: count binary clauses created by propagation.
// They used to be in the watch list of l.index(),
// both new literals in watch list should be unassigned.
//
wlist.set_end(it2);
}
void propagate_binary(literal l) {
@ -958,78 +978,100 @@ namespace sat {
}
literal choose() {
literal l;
while (!choose1(l)) {};
literal l = null_literal;
while (l == null_literal) {
pre_select();
if (m_lookahead.empty()) {
break;
}
compute_wnb();
if (inconsistent()) {
break;
}
l = select_literal();
}
return l;
}
// TBD:
// Handle scope properly for nested implications.
// Suppose u -> v, and u -> w and we process v first, then the
// consequences of v should remain when processing u.
// March and sat11.w solve this by introducing timestamps on truth values.
// regular push/pop doesn't really work here: we basically need a traversal of the
// lookahead tree and push/pop according to that (or adapt timestamps)
//
bool choose1(literal& l) {
pre_select();
l = null_literal;
if (m_lookahead.empty()) {
return true;
}
float h = 0;
unsigned count = 1;
void compute_wnb() {
init_wnb();
for (unsigned i = 0; !inconsistent() && i < m_lookahead.size(); ++i) {
literal lit = m_lookahead[i].m_lit;
if (value(lit) != l_undef) {
if (!is_undef(lit)) {
continue;
}
SASSERT(value(lit) == l_undef);
SASSERT(!inconsistent());
reset_wnb(lit);
push(lit, lookahead1);
do_double(lit);
if (inconsistent()) {
pop();
push_lookahead1(lit, 2 + m_lookahead[i].m_offset);
bool unsat = inconsistent();
// TBD do_double(lit);
pop_lookahead1();
update_wnb(lit);
if (unsat) {
reset_wnb();
assign(~lit);
propagate();
continue;
init_wnb();
}
update_wnb(lit);
float diff1 = m_weighted_new_binaries;
pop();
reset_wnb(~lit);
push(~lit, lookahead1);
do_double(~lit);
if (inconsistent()) {
pop();
assign(lit);
propagate();
continue;
}
update_wnb(~lit);
float diff2 = m_weighted_new_binaries;
pop();
}
reset_wnb();
}
void init_wnb() {
m_qhead_lim.push_back(m_qhead);
m_trail_lim.push_back(m_trail.size());
}
void reset_wnb() {
m_qhead = m_qhead_lim.back();
unsigned old_sz = m_trail_lim.back();
for (unsigned i = old_sz; i < m_trail.size(); ++i) {
set_undef(m_trail[i]);
}
m_trail.shrink(old_sz);
m_trail_lim.pop_back();
m_qhead_lim.pop_back();
}
literal select_literal() {
literal l = null_literal;
float h = 0;
unsigned count = 1;
for (unsigned i = 0; i < m_lookahead.size(); ++i) {
literal lit = m_lookahead[i].m_lit;
if (lit.sign() || !is_undef(lit)) {
continue;
}
float diff1 = get_wnb(lit), diff2 = get_wnb(~lit);
float mixd = mix_diff(diff1, diff2);
if (mixd == h) ++count;
if (mixd > h || (mixd == h && s.m_rand(count) == 0)) {
CTRACE("sat", l != null_literal, tout << lit << " diff1: " << diff1 << " diff2: " << diff2 << "\n";);
if (mixd > h) count = 1; else ++count;
CTRACE("sat", l != null_literal, tout << lit << " " << mixd << "\n";);
if (mixd > h) count = 1;
h = mixd;
l = diff1 < diff2 ? lit : ~lit;
}
}
return l != null_literal || inconsistent();
return l;
}
void push_lookahead1(literal lit, unsigned level) {
m_search_modes.push_back(m_search_mode);
m_search_mode = lookahead1;
scoped_level _sl(*this, level);
assign(lit);
propagate();
}
void pop_lookahead1() {
SASSERT(!inconsistent());
m_search_mode = m_search_modes.back();
m_search_modes.pop_back();
}
void set_wnb(literal l, float f) { m_lits[l.index()].m_wnb = f; }
void inc_wnb(literal l, float f) { m_lits[l.index()].m_wnb += f; }
float get_wnb(literal l) const { return m_lits[l.index()].m_wnb; }
bool dl_enabled(literal l) const { return m_lits[l.index()].m_double_lookahead != m_istamp_id; }
void dl_disable(literal l) { m_lits[l.index()].m_double_lookahead = m_istamp_id; }
void reset_wnb(literal l) {
m_weighted_new_binaries = 0;
@ -1048,24 +1090,27 @@ namespace sat {
}
}
bool dl_enabled(literal l) const { return m_lits[l.index()].m_double_lookahead != m_istamp_id; }
void dl_disable(literal l) { m_lits[l.index()].m_double_lookahead = m_istamp_id; }
void double_look() {
bool unsat;
for (unsigned i = 0; !inconsistent() && i < m_lookahead.size(); ++i) {
literal lit = m_lookahead[i].m_lit;
if (value(lit) != l_undef) continue;
if (!is_undef(lit)) continue;
push(lit, lookahead2);
push_lookahead2(lit);
unsat = inconsistent();
pop();
pop_lookahead2();
if (unsat) {
TRACE("sat", tout << "unit: " << ~lit << "\n";);
assign(~lit);
continue;
}
push(~lit, lookahead2);
push_lookahead2(~lit);
unsat = inconsistent();
pop();
pop_lookahead2();
if (unsat) {
TRACE("sat", tout << "unit: " << lit << "\n";);
assign(lit);
@ -1073,31 +1118,25 @@ namespace sat {
}
}
bool is_fixed(literal l) const { return value(l) != l_undef; }
bool is_contrary(literal l) const { return value(l) == l_false; }
bool is_true(literal l) const { return value(l) == l_true; }
void set_conflict() { m_inconsistent = true; }
lbool value(literal l) const { return static_cast<lbool>(m_assignment[l.index()]); }
lbool value(bool_var v) const { return value(literal(v, false)); }
bool inconsistent() { return m_inconsistent; }
unsigned scope_lvl() const { return m_trail_lim.size(); }
void assign(literal l) {
switch (value(l)) {
case l_true:
break;
case l_false:
set_conflict();
break;
default:
m_assignment[l.index()] = l.sign() ? l_false : l_true;
m_assignment[(~l).index()] = l.sign() ? l_false : l_true;
if (is_undef(l)) {
set_true(l);
m_trail.push_back(l);
m_freevars.remove(l.var());
break;
if (m_search_mode == searching) {
m_freevars.remove(l.var());
}
}
else if (is_false(l)) {
set_conflict();
}
}
bool inconsistent() { return m_inconsistent; }
void do_double(literal l) {
if (!inconsistent() && scope_lvl() > 0 && dl_enabled(l)) {
@ -1115,7 +1154,8 @@ namespace sat {
bool backtrack(literal_vector& trail) {
if (trail.empty()) return false;
pop();
assign(~trail.back());
assign(~trail.back());
propagate();
trail.pop_back();
return true;
}
@ -1124,6 +1164,7 @@ namespace sat {
literal_vector trail;
m_search_mode = searching;
while (true) {
TRACE("sat", display(tout););
inc_istamp();
s.checkpoint();
literal l = choose();
@ -1135,19 +1176,55 @@ namespace sat {
return l_true;
}
TRACE("sat", tout << "choose: " << l << " " << trail << "\n";);
push(l, searching);
push(l, c_fixed_truth);
trail.push_back(l);
}
}
std::ostream& display_binary(std::ostream& out) const {
for (unsigned i = 0; i < m_binary.size(); ++i) {
literal_vector const& lits = m_binary[i];
if (!lits.empty()) {
out << to_literal(i) << " -> " << lits << "\n";
}
}
return out;
}
std::ostream& display_clauses(std::ostream& out) const {
for (unsigned i = 0; i < m_clauses.size(); ++i) {
out << *m_clauses[i] << "\n";
}
return out;
}
std::ostream& display_values(std::ostream& out) const {
for (unsigned i = 0; i < m_trail.size(); ++i) {
literal l = m_trail[i];
out << l << " " << m_stamp[l.var()] << "\n";
}
return out;
}
public:
lookahead(solver& s) : s(s) {
lookahead(solver& s) :
s(s),
m_level(0) {
scoped_level _sl(*this, c_fixed_truth);
init();
}
lbool check() {
return search();
}
std::ostream& display(std::ostream& out) const {
display_values(out);
display_binary(out);
display_clauses(out);
return out;
}
};
}