3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-01 11:51:20 +00:00

add N-ary clause reward heuristic based on discussions with Heule

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-06-21 10:30:30 -07:00
parent c6fbe38f78
commit 5f93b9a081
2 changed files with 108 additions and 53 deletions

View file

@ -367,6 +367,7 @@ namespace sat {
} }
void lookahead::init_pre_selection(unsigned level) { void lookahead::init_pre_selection(unsigned level) {
if (!m_config.m_use_ternary_reward) return;
unsigned max_level = m_config.m_max_hlevel; unsigned max_level = m_config.m_max_hlevel;
if (level <= 1) { if (level <= 1) {
ensure_H(2); ensure_H(2);
@ -888,6 +889,12 @@ namespace sat {
for (; it != end; ++it) { for (; it != end; ++it) {
clause& c = *(*it); clause& c = *(*it);
if (c.was_removed()) continue; if (c.was_removed()) continue;
#if 0
// enable when there is a non-ternary reward system.
if (c.size() > 3) {
m_config.m_use_ternary_reward = false;
}
#endif
clause* c1 = m_cls_allocator.mk_clause(c.size(), c.begin(), false); clause* c1 = m_cls_allocator.mk_clause(c.size(), c.begin(), false);
m_clauses.push_back(c1); m_clauses.push_back(c1);
attach_clause(*c1); attach_clause(*c1);
@ -1004,7 +1011,7 @@ namespace sat {
TRACE("sat", tout << "windfall: " << nlit << " " << l2 << "\n";); TRACE("sat", tout << "windfall: " << nlit << " " << l2 << "\n";);
// if we use try_add_binary, then this may produce new assignments // if we use try_add_binary, then this may produce new assignments
// these assignments get put on m_trail, and they are cleared by // these assignments get put on m_trail, and they are cleared by
// reset_wnb. We would need to distinguish the trail that comes // reset_lookahead_reward. We would need to distinguish the trail that comes
// from lookahead levels and the main search level for this to work. // from lookahead levels and the main search level for this to work.
add_binary(nlit, l2); add_binary(nlit, l2);
} }
@ -1079,7 +1086,7 @@ namespace sat {
skip = true; skip = true;
break; break;
case lookahead_mode::lookahead1: case lookahead_mode::lookahead1:
m_weighted_new_binaries += (*m_heur)[l1.index()] * (*m_heur)[l2.index()]; update_binary_clause_reward(l1, l2);
break; break;
case lookahead2: case lookahead2:
break; break;
@ -1133,20 +1140,14 @@ namespace sat {
try_add_binary(c[0], c[1]); try_add_binary(c[0], c[1]);
break; break;
case lookahead_mode::lookahead1: case lookahead_mode::lookahead1:
m_weighted_new_binaries += (*m_heur)[c[0].index()]* (*m_heur)[c[1].index()]; update_binary_clause_reward(c[0], c[1]);
break; break;
case lookahead_mode::lookahead2: case lookahead_mode::lookahead2:
break; break;
} }
} }
else if (found && m_search_mode == lookahead_mode::lookahead1 && m_weighted_new_binaries == 0) { else if (found && m_search_mode == lookahead_mode::lookahead1) {
// leave a trail that some clause was reduced but potentially not an autarky update_nary_clause_reward(c);
l_it = c.begin() + 2;
found = false;
for (; l_it != l_end && !found; found = is_true(*l_it), ++l_it) ;
if (!found) {
m_weighted_new_binaries = (double)0.001;
}
} }
break; break;
} }
@ -1192,6 +1193,53 @@ namespace sat {
wlist.set_end(it2); wlist.set_end(it2);
} }
void lookahead::update_binary_clause_reward(literal l1, literal l2) {
SASSERT(!is_false(l1));
SASSERT(!is_false(l2));
if (m_config.m_use_ternary_reward) {
m_lookahead_reward += (*m_heur)[l1.index()] * (*m_heur)[l2.index()];
}
else {
m_lookahead_reward += 0.5 * (literal_occs(l1) + literal_occs(l2));
}
}
void lookahead::update_nary_clause_reward(clause const& c) {
if (m_config.m_use_ternary_reward && m_lookahead_reward != 0) {
return;
}
literal const * l_it = c.begin() + 2, *l_end = c.end();
unsigned sz = 0;
for (; l_it != l_end; ++l_it) {
if (is_true(*l_it)) return;
if (!is_false(*l_it)) ++sz;
}
if (!m_config.m_use_ternary_reward) {
SASSERT(sz > 0);
double to_add = 0;
for (literal l : c) {
if (!is_false(l)) {
to_add += literal_occs(l);
}
}
m_lookahead_reward += pow(sz, -2) * to_add;
}
else {
m_lookahead_reward = (double)0.001;
}
}
// Sum_{ clause C that contains ~l } 1 / |C|
double lookahead::literal_occs(literal l) {
double result = m_binary[l.index()].size();
for (clause const* c : m_full_watches[l.index()]) {
if (!is_true((*c)[0]) && !is_true((*c)[1])) {
result += 1.0 / c->size();
}
}
return result;
}
void lookahead::propagate_binary(literal l) { void lookahead::propagate_binary(literal l) {
literal_vector const& lits = m_binary[l.index()]; literal_vector const& lits = m_binary[l.index()];
TRACE("sat", tout << l << " => " << lits << "\n";); TRACE("sat", tout << l << " => " << lits << "\n";);
@ -1201,6 +1249,7 @@ namespace sat {
} }
} }
void lookahead::propagate() { void lookahead::propagate() {
while (!inconsistent() && m_qhead < m_trail.size()) { while (!inconsistent() && m_qhead < m_trail.size()) {
unsigned i = m_qhead; unsigned i = m_qhead;
@ -1220,8 +1269,8 @@ namespace sat {
TRACE("sat_verbose", display(tout << scope_lvl() << " " << (inconsistent()?"unsat":"sat") << "\n");); TRACE("sat_verbose", display(tout << scope_lvl() << " " << (inconsistent()?"unsat":"sat") << "\n"););
} }
void lookahead::compute_wnb() { void lookahead::compute_lookahead_reward() {
init_wnb(); init_lookahead_reward();
TRACE("sat", display_lookahead(tout); ); TRACE("sat", display_lookahead(tout); );
unsigned base = 2; unsigned base = 2;
bool change = true; bool change = true;
@ -1240,21 +1289,21 @@ namespace sat {
IF_VERBOSE(30, verbose_stream() << scope_lvl() << " " << lit << " binary: " << m_binary_trail.size() << " trail: " << m_trail_lim.back() << "\n";); IF_VERBOSE(30, verbose_stream() << scope_lvl() << " " << lit << " binary: " << m_binary_trail.size() << " trail: " << m_trail_lim.back() << "\n";);
} }
TRACE("sat", tout << "lookahead: " << lit << " @ " << m_lookahead[i].m_offset << "\n";); TRACE("sat", tout << "lookahead: " << lit << " @ " << m_lookahead[i].m_offset << "\n";);
reset_wnb(lit); reset_lookahead_reward(lit);
push_lookahead1(lit, level); push_lookahead1(lit, level);
if (!first) do_double(lit, base); if (!first) do_double(lit, base);
bool unsat = inconsistent(); bool unsat = inconsistent();
pop_lookahead1(lit); pop_lookahead1(lit);
if (unsat) { if (unsat) {
TRACE("sat", tout << "backtracking and settting " << ~lit << "\n";); TRACE("sat", tout << "backtracking and settting " << ~lit << "\n";);
reset_wnb(); reset_lookahead_reward();
assign(~lit); assign(~lit);
propagate(); propagate();
init_wnb(); init_lookahead_reward();
change = true; change = true;
} }
else { else {
update_wnb(lit, level); update_lookahead_reward(lit, level);
} }
SASSERT(inconsistent() || !is_unsat()); SASSERT(inconsistent() || !is_unsat());
} }
@ -1265,23 +1314,23 @@ namespace sat {
first = false; first = false;
change = true; change = true;
} }
reset_wnb(); reset_lookahead_reward();
init_wnb(); init_lookahead_reward();
// base += 2 * m_lookahead.size(); // base += 2 * m_lookahead.size();
} }
reset_wnb(); reset_lookahead_reward();
TRACE("sat", display_lookahead(tout); ); TRACE("sat", display_lookahead(tout); );
} }
void lookahead::init_wnb() { void lookahead::init_lookahead_reward() {
TRACE("sat", tout << "init_wnb: " << m_qhead << "\n";); TRACE("sat", tout << "init_lookahead_reward: " << m_qhead << "\n";);
m_qhead_lim.push_back(m_qhead); m_qhead_lim.push_back(m_qhead);
m_trail_lim.push_back(m_trail.size()); m_trail_lim.push_back(m_trail.size());
} }
void lookahead::reset_wnb() { void lookahead::reset_lookahead_reward() {
m_qhead = m_qhead_lim.back(); m_qhead = m_qhead_lim.back();
TRACE("sat", tout << "reset_wnb: " << m_qhead << "\n";); TRACE("sat", tout << "reset_lookahead_reward: " << m_qhead << "\n";);
unsigned old_sz = m_trail_lim.back(); unsigned old_sz = m_trail_lim.back();
for (unsigned i = old_sz; i < m_trail.size(); ++i) { for (unsigned i = old_sz; i < m_trail.size(); ++i) {
set_undef(m_trail[i]); set_undef(m_trail[i]);
@ -1300,7 +1349,7 @@ namespace sat {
if (lit.sign() || !is_undef(lit)) { if (lit.sign() || !is_undef(lit)) {
continue; continue;
} }
double diff1 = get_wnb(lit), diff2 = get_wnb(~lit); double diff1 = get_lookahead_reward(lit), diff2 = get_lookahead_reward(~lit);
double mixd = mix_diff(diff1, diff2); double mixd = mix_diff(diff1, diff2);
if (mixd == h) ++count; if (mixd == h) ++count;
@ -1317,12 +1366,13 @@ namespace sat {
} }
void lookahead::reset_wnb(literal l) { void lookahead::reset_lookahead_reward(literal l) {
m_weighted_new_binaries = 0;
m_lookahead_reward = 0;
// inherit propagation effect from parent. // inherit propagation effect from parent.
literal p = get_parent(l); literal p = get_parent(l);
set_wnb(l, p == null_literal ? 0 : get_wnb(p)); set_lookahead_reward(l, p == null_literal ? 0 : get_lookahead_reward(p));
} }
bool lookahead::check_autarky(literal l, unsigned level) { bool lookahead::check_autarky(literal l, unsigned level) {
@ -1361,22 +1411,22 @@ namespace sat {
} }
void lookahead::update_wnb(literal l, unsigned level) { void lookahead::update_lookahead_reward(literal l, unsigned level) {
if (m_weighted_new_binaries == 0) { if (m_lookahead_reward == 0) {
if (!check_autarky(l, level)) { if (!check_autarky(l, level)) {
// skip // skip
} }
else if (get_wnb(l) == 0) { else if (get_lookahead_reward(l) == 0) {
++m_stats.m_autarky_propagations; ++m_stats.m_autarky_propagations;
IF_VERBOSE(1, verbose_stream() << "(sat.lookahead autarky " << l << ")\n";); IF_VERBOSE(1, verbose_stream() << "(sat.lookahead autarky " << l << ")\n";);
TRACE("sat", tout << "autarky: " << l << " @ " << m_stamp[l.var()] TRACE("sat", tout << "autarky: " << l << " @ " << m_stamp[l.var()]
<< " " << " "
<< (!m_binary[l.index()].empty() || !m_full_watches[l.index()].empty()) << "\n";); << (!m_binary[l.index()].empty() || !m_full_watches[l.index()].empty()) << "\n";);
reset_wnb(); reset_lookahead_reward();
assign(l); assign(l);
propagate(); propagate();
init_wnb(); init_lookahead_reward();
} }
else { else {
++m_stats.m_autarky_equivalences; ++m_stats.m_autarky_equivalences;
@ -1396,17 +1446,17 @@ namespace sat {
} }
} }
else { else {
inc_wnb(l, m_weighted_new_binaries); inc_lookahead_reward(l, m_lookahead_reward);
} }
} }
void lookahead::do_double(literal l, unsigned& base) { void lookahead::do_double(literal l, unsigned& base) {
if (!inconsistent() && scope_lvl() > 1 && dl_enabled(l)) { if (!inconsistent() && scope_lvl() > 1 && dl_enabled(l)) {
if (get_wnb(l) > m_delta_trigger) { if (get_lookahead_reward(l) > m_delta_trigger) {
if (dl_no_overflow(base)) { if (dl_no_overflow(base)) {
++m_stats.m_double_lookahead_rounds; ++m_stats.m_double_lookahead_rounds;
double_look(l, base); double_look(l, base);
m_delta_trigger = get_wnb(l); m_delta_trigger = get_lookahead_reward(l);
dl_disable(l); dl_disable(l);
} }
} }
@ -1422,7 +1472,7 @@ namespace sat {
unsigned dl_truth = base + 2 * m_lookahead.size() * (m_config.m_dl_max_iterations + 1); unsigned dl_truth = base + 2 * m_lookahead.size() * (m_config.m_dl_max_iterations + 1);
scoped_level _sl(*this, dl_truth); scoped_level _sl(*this, dl_truth);
IF_VERBOSE(2, verbose_stream() << "double: " << l << "\n";); IF_VERBOSE(2, verbose_stream() << "double: " << l << "\n";);
init_wnb(); init_lookahead_reward();
assign(l); assign(l);
propagate(); propagate();
bool change = true; bool change = true;
@ -1438,16 +1488,16 @@ namespace sat {
TRACE("sat", tout << "unit: " << ~lit << "\n";); TRACE("sat", tout << "unit: " << ~lit << "\n";);
++m_stats.m_double_lookahead_propagations; ++m_stats.m_double_lookahead_propagations;
SASSERT(m_level == dl_truth); SASSERT(m_level == dl_truth);
reset_wnb(); reset_lookahead_reward();
assign(~lit); assign(~lit);
propagate(); propagate();
change = true; change = true;
init_wnb(); init_lookahead_reward();
} }
} }
SASSERT(dl_truth - 2 * m_lookahead.size() > base); SASSERT(dl_truth - 2 * m_lookahead.size() > base);
} }
reset_wnb(); reset_lookahead_reward();
SASSERT(m_level == dl_truth); SASSERT(m_level == dl_truth);
base = dl_truth; base = dl_truth;
} }
@ -1585,7 +1635,7 @@ namespace sat {
unsigned offset = m_lookahead[i].m_offset; unsigned offset = m_lookahead[i].m_offset;
out << lit << "\toffset: " << offset; out << lit << "\toffset: " << offset;
out << (is_undef(lit)?" undef": (is_true(lit) ? " true": " false")); out << (is_undef(lit)?" undef": (is_true(lit) ? " true": " false"));
out << " wnb: " << get_wnb(lit); out << " lookahead_reward: " << get_lookahead_reward(lit);
out << "\n"; out << "\n";
} }
return out; return out;
@ -1613,7 +1663,7 @@ namespace sat {
if (m_lookahead.empty()) { if (m_lookahead.empty()) {
break; break;
} }
compute_wnb(); compute_lookahead_reward();
if (inconsistent()) { if (inconsistent()) {
break; break;
} }

View file

@ -76,6 +76,7 @@ namespace sat {
double m_delta_rho; double m_delta_rho;
unsigned m_dl_max_iterations; unsigned m_dl_max_iterations;
unsigned m_tc1_limit; unsigned m_tc1_limit;
bool m_use_ternary_reward;
config() { config() {
m_max_hlevel = 50; m_max_hlevel = 50;
@ -86,6 +87,7 @@ namespace sat {
m_delta_rho = (double)0.9995; m_delta_rho = (double)0.9995;
m_dl_max_iterations = 32; m_dl_max_iterations = 32;
m_tc1_limit = 10000000; m_tc1_limit = 10000000;
m_use_ternary_reward = true;
} }
}; };
@ -96,9 +98,9 @@ namespace sat {
}; };
struct lit_info { struct lit_info {
double m_wnb; double m_lookahead_reward;
unsigned m_double_lookahead; unsigned m_double_lookahead;
lit_info(): m_wnb(0), m_double_lookahead(0) {} lit_info(): m_lookahead_reward(0), m_double_lookahead(0) {}
}; };
struct stats { struct stats {
@ -156,7 +158,7 @@ namespace sat {
vector<watch_list> m_watches; // literal: watch structure vector<watch_list> m_watches; // literal: watch structure
svector<lit_info> m_lits; // literal: attributes. svector<lit_info> m_lits; // literal: attributes.
vector<clause_vector> m_full_watches; // literal: full watch list, used to ensure that autarky reduction is sound vector<clause_vector> m_full_watches; // literal: full watch list, used to ensure that autarky reduction is sound
double m_weighted_new_binaries; // metric associated with current lookahead1 literal. double m_lookahead_reward; // metric associated with current lookahead1 literal.
literal_vector m_wstack; // windofall stack that is populated in lookahead1 mode literal_vector m_wstack; // windofall stack that is populated in lookahead1 mode
uint64 m_prefix; // where we are in search tree uint64 m_prefix; // where we are in search tree
svector<prefix> m_vprefix; // var: prefix where variable participates in propagation svector<prefix> m_vprefix; // var: prefix where variable participates in propagation
@ -393,18 +395,21 @@ namespace sat {
void propagate_binary(literal l); void propagate_binary(literal l);
void propagate(); void propagate();
literal choose(); literal choose();
void compute_wnb(); void compute_lookahead_reward();
void init_wnb(); void init_lookahead_reward();
void reset_wnb(); void reset_lookahead_reward();
literal select_literal(); literal select_literal();
void update_binary_clause_reward(literal l1, literal l2);
void update_nary_clause_reward(clause const& c);
double literal_occs(literal l);
void set_wnb(literal l, double f) { m_lits[l.index()].m_wnb = f; } void set_lookahead_reward(literal l, double f) { m_lits[l.index()].m_lookahead_reward = f; }
void inc_wnb(literal l, double f) { m_lits[l.index()].m_wnb += f; } void inc_lookahead_reward(literal l, double f) { m_lits[l.index()].m_lookahead_reward += f; }
double get_wnb(literal l) const { return m_lits[l.index()].m_wnb; } double get_lookahead_reward(literal l) const { return m_lits[l.index()].m_lookahead_reward; }
void reset_wnb(literal l); void reset_lookahead_reward(literal l);
bool check_autarky(literal l, unsigned level); bool check_autarky(literal l, unsigned level);
void update_wnb(literal l, unsigned level); void update_lookahead_reward(literal l, unsigned level);
bool dl_enabled(literal l) const { return m_lits[l.index()].m_double_lookahead != m_istamp_id; } bool dl_enabled(literal l) const { return m_lits[l.index()].m_double_lookahead != m_istamp_id; }
void dl_disable(literal l) { m_lits[l.index()].m_double_lookahead = m_istamp_id; } void dl_disable(literal l) { m_lits[l.index()].m_double_lookahead = m_istamp_id; }
bool dl_no_overflow(unsigned base) const { return base + 2 * m_lookahead.size() * static_cast<uint64>(m_config.m_dl_max_iterations + 1) < c_fixed_truth; } bool dl_no_overflow(unsigned base) const { return base + 2 * m_lookahead.size() * static_cast<uint64>(m_config.m_dl_max_iterations + 1) < c_fixed_truth; }