3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-13 20:38:43 +00:00

Cuber fixes. Added March_CU heuristics

This commit is contained in:
Miguel Neves 2017-10-06 16:10:05 -07:00
parent 133f376172
commit 4d91169118
6 changed files with 130 additions and 59 deletions

View file

@ -120,7 +120,7 @@ def _get_args(args):
try:
if len(args) == 1 and (isinstance(args[0], tuple) or isinstance(args[0], list)):
return args[0]
elif len(args) == 1 and isinstance(args[0], set):
elif len(args) == 1 and (isinstance(args[0], set) or isinstance(args[0], AstVector)):
return [arg for arg in args[0]]
else:
return args

View file

@ -100,6 +100,9 @@ namespace sat {
else if (p.lookahead_reward() == symbol("unit")) {
m_lookahead_reward = unit_literal_reward;
}
else if (p.lookahead_reward() == symbol("march_cu")) {
m_lookahead_reward = march_cu_reward;
}
else {
throw sat_param_exception("invalid reward type supplied: accepted heuristics are 'ternary', 'heuleu', 'unit' or 'heule_schur'");
}

View file

@ -61,7 +61,8 @@ namespace sat {
ternary_reward,
unit_literal_reward,
heule_schur_reward,
heule_unit_reward
heule_unit_reward,
march_cu_reward
};
struct config {

View file

@ -217,7 +217,7 @@ namespace sat {
bool lookahead::select(unsigned level) {
init_pre_selection(level);
unsigned level_cand = std::max(m_config.m_level_cand, m_freevars.size() / 50);
unsigned max_num_cand = level == 0 ? m_freevars.size() : level_cand / level;
unsigned max_num_cand = (level > 0 && m_config.m_preselect) ? level_cand / level : m_freevars.size();
max_num_cand = std::max(m_config.m_min_cutoff, max_num_cand);
double sum = 0;
@ -251,30 +251,40 @@ namespace sat {
}
TRACE("sat", display_candidates(tout););
SASSERT(!m_candidates.empty());
if (m_candidates.size() > max_num_cand) {
unsigned j = m_candidates.size()/2;
while (j > 0) {
--j;
sift_up(j);
}
while (true) {
m_candidates[0] = m_candidates.back();
m_candidates.pop_back();
if (m_candidates.size() == max_num_cand) break;
sift_up(0);
}
heap_sort();
while (m_candidates.size() > max_num_cand) {
m_candidates.pop_back();
}
SASSERT(!m_candidates.empty() && m_candidates.size() <= max_num_cand);
TRACE("sat", display_candidates(tout););
return true;
}
void lookahead::sift_up(unsigned j) {
void lookahead::heap_sort() {
if (m_candidates.size() > 1) {
heapify();
for (unsigned i = m_candidates.size() - 1; i > 0; --i) {
candidate c = m_candidates[i];
m_candidates[i] = m_candidates[0];
m_candidates[0] = c;
sift_down(0, i);
}
}
}
void lookahead::heapify() {
unsigned i = 1 + (m_candidates.size() - 2) / 2;
while(i > 0) {
sift_down(--i, m_candidates.size());
}
}
void lookahead::sift_down(unsigned j, unsigned sz) {
unsigned i = j;
candidate c = m_candidates[j];
for (unsigned k = 2*j + 1; k < m_candidates.size(); i = k, k = 2*k + 1) {
// pick largest parent
if (k + 1 < m_candidates.size() && m_candidates[k].m_rating < m_candidates[k+1].m_rating) {
for (unsigned k = 2 * j + 1; k < sz; i = k, k = 2 * k + 1) {
// pick smallest child
if (k + 1 < sz && m_candidates[k].m_rating > m_candidates[k + 1].m_rating) {
++k;
}
if (c.m_rating <= m_candidates[k].m_rating) break;
@ -432,6 +442,9 @@ namespace sat {
case heule_unit_reward:
heule_unit_scores();
break;
case march_cu_reward:
march_cu_scores();
break;
case unit_literal_reward:
heule_schur_scores();
break;
@ -478,7 +491,7 @@ namespace sat {
for (bool_var x : m_freevars) {
literal l(x, false);
m_rating[l.var()] = heule_unit_score(l) * heule_unit_score(~l);
}
}
}
double lookahead::heule_unit_score(literal l) {
@ -493,7 +506,23 @@ namespace sat {
sum += pow(0.5, m_nary_literals[cls_idx]);
}
return sum;
}
}
void lookahead::march_cu_scores() {
for (bool_var x : m_freevars) {
literal l(x, false);
double pos = march_cu_score(l), neg = march_cu_score(~l);
m_rating[l.var()] = 1024 * pos * neg + pos + neg + 1;
}
}
double lookahead::march_cu_score(literal l) {
double sum = 1.0 + literal_big_occs(~l);
for (literal lit : m_binary[l.index()]) {
if (is_undef(lit)) sum += literal_big_occs(lit);
}
return sum;
}
void lookahead::ensure_H(unsigned level) {
while (m_H.size() <= level) {
@ -884,7 +913,7 @@ namespace sat {
}
void lookahead::init() {
m_delta_trigger = m_num_vars/10;
m_delta_trigger = 0.0;
m_config.m_dl_success = 0.8;
m_inconsistent = false;
m_qhead = 0;
@ -1068,6 +1097,7 @@ namespace sat {
m_lookahead_reward += num_units;
break;
case heule_unit_reward:
case march_cu_reward:
case heule_schur_reward:
break;
default:
@ -1391,6 +1421,9 @@ namespace sat {
case heule_unit_reward:
m_lookahead_reward += pow(0.5, nonfixed);
break;
case march_cu_reward:
m_lookahead_reward += 3.3 * pow(0.5, nonfixed - 2);
break;
case ternary_reward:
if (nonfixed == 2) {
m_lookahead_reward += (*m_heur)[l1.index()] * (*m_heur)[l2.index()];
@ -1486,6 +1519,9 @@ namespace sat {
case heule_unit_reward:
m_lookahead_reward += 0.25;
break;
case march_cu_reward:
m_lookahead_reward += 3.3;
break;
case unit_literal_reward:
break;
}
@ -1516,6 +1552,9 @@ namespace sat {
case heule_unit_reward:
m_lookahead_reward += pow(0.5, sz);
break;
case march_cu_reward:
m_lookahead_reward += 3.3 * pow(0.5, sz - 2);
break;
case ternary_reward:
m_lookahead_reward = (double)0.001;
break;
@ -1525,10 +1564,18 @@ namespace sat {
}
// Sum_{ clause C that contains ~l } 1
// FIXME: counts occurences of ~l; misleading
double lookahead::literal_occs(literal l) {
double result = m_binary[l.index()].size();
unsigned_vector const& nclauses = m_nary[(~l).index()];
result += m_nary_count[(~l).index()];
//unsigned_vector const& nclauses = m_nary[(~l).index()];
result += literal_big_occs(l);
return result;
}
// Sum_{ clause C that contains ~l such that |C| > 2} 1
// FIXME: counts occurences of ~l; misleading
double lookahead::literal_big_occs(literal l) {
double result = m_nary_count[(~l).index()];
result += m_ternary_count[(~l).index()];
return result;
}
@ -1542,22 +1589,19 @@ namespace sat {
}
}
void lookahead::propagate() {
while (!inconsistent() && m_qhead < m_trail.size()) {
unsigned i = m_qhead;
unsigned sz = m_trail.size();
for (; i < sz && !inconsistent(); ++i) {
for (; i < m_trail.size() && !inconsistent(); ++i) {
literal l = m_trail[i];
TRACE("sat", tout << "propagate " << l << " @ " << m_level << "\n";);
propagate_binary(l);
}
while (m_qhead < sz && !inconsistent()) {
while (m_qhead < m_trail.size() && !inconsistent()) {
propagate_clauses(m_trail[m_qhead++]);
}
SASSERT(m_qhead == sz || (inconsistent() && m_qhead < sz));
SASSERT(m_qhead == m_trail.size() || (inconsistent() && m_qhead < m_trail.size()));
}
TRACE("sat_verbose", display(tout << scope_lvl() << " " << (inconsistent()?"unsat":"sat") << "\n"););
}
@ -1566,12 +1610,15 @@ namespace sat {
TRACE("sat", display_lookahead(tout); );
unsigned base = 2;
bool change = true;
bool first = true;
literal last_changed = null_literal;
while (change && !inconsistent()) {
change = false;
for (unsigned i = 0; !inconsistent() && i < m_lookahead.size(); ++i) {
checkpoint();
literal lit = m_lookahead[i].m_lit;
if (lit == last_changed) {
break;
}
if (is_fixed_at(lit, c_fixed_truth)) continue;
unsigned level = base + m_lookahead[i].m_offset;
if (m_stamp[lit.var()] >= level) {
@ -1584,7 +1631,7 @@ namespace sat {
unsigned old_trail_sz = m_trail.size();
reset_lookahead_reward(lit);
push_lookahead1(lit, level);
if (!first) do_double(lit, base);
do_double(lit, base);
bool unsat = inconsistent();
unsigned num_units = m_trail.size() - old_trail_sz;
pop_lookahead1(lit, num_units);
@ -1595,6 +1642,7 @@ namespace sat {
propagate();
init_lookahead_reward();
change = true;
last_changed = lit;
}
else {
update_lookahead_reward(lit, level);
@ -1604,13 +1652,9 @@ namespace sat {
if (c_fixed_truth - 2 * m_lookahead.size() < base) {
break;
}
if (first && !change) {
first = false;
change = true;
}
reset_lookahead_reward();
init_lookahead_reward();
// base += 2 * m_lookahead.size();
base += 2 * m_lookahead.size();
}
reset_lookahead_reward();
TRACE("sat", display_lookahead(tout); );
@ -1664,6 +1708,7 @@ namespace sat {
case ternary_reward: return l + r + (1 << 10) * l * r;
case heule_schur_reward: return l * r;
case heule_unit_reward: return l * r;
case march_cu_reward: return 1024 * (1024 * l * r + l + r);
case unit_literal_reward: return l * r;
default: UNREACHABLE(); return l * r;
}
@ -1753,14 +1798,16 @@ namespace sat {
}
}
void lookahead::do_double(literal l, unsigned& base) {
if (!inconsistent() && scope_lvl() > 1 && dl_enabled(l)) {
void lookahead::do_double(literal l, unsigned& base) {
if (!inconsistent() && dl_enabled(l)) {
if (get_lookahead_reward(l) > m_delta_trigger) {
if (dl_no_overflow(base)) {
++m_stats.m_double_lookahead_rounds;
double_look(l, base);
m_delta_trigger = get_lookahead_reward(l);
dl_disable(l);
if (!inconsistent()) {
m_delta_trigger = get_lookahead_reward(l);
dl_disable(l);
}
}
}
else {
@ -1772,21 +1819,30 @@ namespace sat {
void lookahead::double_look(literal l, unsigned& base) {
SASSERT(!inconsistent());
SASSERT(dl_no_overflow(base));
unsigned dl_truth = base + 2 * m_lookahead.size() * (m_config.m_dl_max_iterations + 1);
base += m_lookahead.size();
unsigned dl_truth = base + m_lookahead.size() * m_config.m_dl_max_iterations;
scoped_level _sl(*this, dl_truth);
IF_VERBOSE(2, verbose_stream() << "double: " << l << " depth: " << m_trail_lim.size() << "\n";);
init_lookahead_reward();
assign(l);
propagate();
bool change = true;
literal last_changed = null_literal;
unsigned num_iterations = 0;
while (change && num_iterations < m_config.m_dl_max_iterations && !inconsistent()) {
change = false;
num_iterations++;
base += 2*m_lookahead.size();
for (unsigned i = 0; !inconsistent() && i < m_lookahead.size(); ++i) {
literal lit = m_lookahead[i].m_lit;
if (is_fixed_at(lit, dl_truth)) continue;
if (lit == last_changed) {
SASSERT(change == false);
break;
}
if (is_fixed_at(lit, dl_truth)) continue;
if (base + m_lookahead.size() + m_lookahead[i].m_offset >= dl_truth) {
change = false;
break;
}
if (push_lookahead2(lit, base + m_lookahead[i].m_offset)) {
TRACE("sat", tout << "unit: " << ~lit << "\n";);
++m_stats.m_double_lookahead_propagations;
@ -1795,10 +1851,12 @@ namespace sat {
assign(~lit);
propagate();
change = true;
last_changed = lit;
init_lookahead_reward();
}
}
SASSERT(dl_truth - 2 * m_lookahead.size() > base);
base += 2 * m_lookahead.size();
SASSERT(dl_truth >= base);
}
reset_lookahead_reward();
SASSERT(m_level == dl_truth);
@ -1838,14 +1896,13 @@ namespace sat {
void lookahead::propagated(literal l) {
assign(l);
switch (m_search_mode) {
case lookahead_mode::searching:
break;
case lookahead_mode::lookahead1:
for (unsigned i = m_trail.size() - 1; i < m_trail.size() && !inconsistent(); ++i) {
literal l = m_trail[i];
TRACE("sat", tout << "propagate " << l << " @ " << m_level << "\n";);
propagate_binary(l);
}
if (m_search_mode == lookahead_mode::lookahead1) {
m_wstack.push_back(l);
break;
case lookahead_mode::lookahead2:
break;
}
}
@ -1953,10 +2010,11 @@ namespace sat {
backtrack(m_cube_state.m_cube, m_cube_state.m_is_decision);
return l_undef;
}
int prev_nfreevars = m_freevars.size();
literal lit = choose();
if (inconsistent()) {
TRACE("sat", tout << "inconsistent: " << m_cube_state.m_cube << "\n";);
m_cube_state.m_freevars_threshold = m_freevars.size();
m_cube_state.m_freevars_threshold = prev_nfreevars;
if (!backtrack(m_cube_state.m_cube, m_cube_state.m_is_decision)) return l_false;
continue;
}

View file

@ -73,6 +73,7 @@ namespace sat {
double m_max_score;
unsigned m_max_hlevel;
unsigned m_min_cutoff;
bool m_preselect;
unsigned m_level_cand;
double m_delta_rho;
unsigned m_dl_max_iterations;
@ -86,9 +87,10 @@ namespace sat {
m_alpha = 3.5;
m_max_score = 20.0;
m_min_cutoff = 30;
m_preselect = false;
m_level_cand = 600;
m_delta_rho = (double)0.9995;
m_dl_max_iterations = 32;
m_delta_rho = (double)0.99995;
m_dl_max_iterations = 2;
m_tc1_limit = 10000000;
m_reward_type = ternary_reward;
m_cube_cutoff = 0;
@ -289,7 +291,10 @@ namespace sat {
double get_rating(bool_var v) const { return m_rating[v]; }
double get_rating(literal l) const { return get_rating(l.var()); }
bool select(unsigned level);
void sift_up(unsigned j);
//void sift_up(unsigned j);
void heap_sort();
void heapify();
void sift_down(unsigned j, unsigned sz);
double init_candidates(unsigned level, bool newbies);
std::ostream& display_candidates(std::ostream& out) const;
bool is_unsat() const;
@ -301,6 +306,8 @@ namespace sat {
double heule_schur_score(literal l);
void heule_unit_scores();
double heule_unit_score(literal l);
void march_cu_scores();
double march_cu_score(literal l);
double l_score(literal l, svector<double> const& h, double factor, double sqfactor, double afactor);
// ------------------------------------
@ -460,6 +467,7 @@ namespace sat {
void do_double(literal l, unsigned& base);
void double_look(literal l, unsigned& base);
void set_conflict() { TRACE("sat", tout << "conflict\n";); m_inconsistent = true; }
//void set_conflict() { TRACE("sat", tout << "conflict\n";); printf("CONFLICT\n"); m_inconsistent = true; }
bool inconsistent() { return m_inconsistent; }
unsigned scope_lvl() const { return m_trail_lim.size(); }
@ -544,7 +552,8 @@ namespace sat {
void collect_statistics(statistics& st) const;
double literal_occs(literal l);
double literal_occs(literal l);
double literal_big_occs(literal l);
};
}

View file

@ -31,7 +31,7 @@ def_module_params('sat',
('drat.check', BOOL, False, 'build up internal proof and check'),
('cardinality.solver', BOOL, False, 'use cardinality solver'),
('pb.solver', SYMBOL, 'circuit', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), solver (use SMT solver)'),
('xor.solver', BOOL, False, 'use xor solver'),
('xor.solver', BOOL, False, 'use xor solver'),
('atmost1_encoding', SYMBOL, 'grouped', 'encoding used for at-most-1 constraints grouped, bimander, ordered'),
('local_search_threads', UINT, 0, 'number of local search threads to find satisfiable solution'),
('local_search', BOOL, False, 'use local search instead of CDCL'),
@ -40,7 +40,7 @@ def_module_params('sat',
('lookahead.cube.cutoff', UINT, 0, 'cut-off depth to create cubes. Only enabled when non-zero. Used when lookahead_cube is true.'),
('lookahead_search', BOOL, False, 'use lookahead solver'),
('lookahead_simplify', BOOL, False, 'use lookahead solver during simplification'),
('lookahead.reward', SYMBOL, 'heuleu', 'select lookahead heuristic: ternary, heule_schur (Heule Schur), heuleu (Heule Unit), or unit'),
('lookahead.reward', SYMBOL, 'march_cu', 'select lookahead heuristic: ternary, heule_schur (Heule Schur), heuleu (Heule Unit), unit, or march_cu'),
('dimacs.display', BOOL, False, 'display SAT instance in DIMACS format and return unknown instead of solving'),
('dimacs.inprocess.display', BOOL, False, 'display SAT instance in DIMACS format if unsolved after inprocess.max inprocessing passes')))