3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-27 08:28:44 +00:00

lookahead

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-06-09 09:17:20 -07:00
parent c870b77366
commit a55416351f
2 changed files with 50 additions and 41 deletions

View file

@ -211,7 +211,7 @@ namespace sat {
unsigned max_num_cand = level == 0 ? m_freevars.size() : level_cand / level;
max_num_cand = std::max(m_config.m_min_cutoff, max_num_cand);
float sum = 0;
double sum = 0;
for (bool newbies = false; ; newbies = true) {
sum = init_candidates(level, newbies);
if (!m_candidates.empty()) break;
@ -231,7 +231,7 @@ namespace sat {
bool progress = true;
while (progress && m_candidates.size() >= max_num_cand * 2) {
progress = false;
float mean = sum / (float)(m_candidates.size() + 0.0001);
double mean = sum / (double)(m_candidates.size() + 0.0001);
sum = 0;
for (unsigned i = 0; i < m_candidates.size() && m_candidates.size() >= max_num_cand * 2; ++i) {
if (m_candidates[i].m_rating >= mean) {
@ -279,14 +279,15 @@ namespace sat {
if (i > j) m_candidates[i] = c;
}
float lookahead::init_candidates(unsigned level, bool newbies) {
double lookahead::init_candidates(unsigned level, bool newbies) {
m_candidates.reset();
float sum = 0;
double sum = 0;
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
SASSERT(is_undef(*it));
bool_var x = *it;
if (!m_select_lookahead_vars.empty()) {
if (m_select_lookahead_vars.contains(x)) {
// IF_VERBOSE(1, verbose_stream() << x << " " << m_rating[x] << "\n";);
m_candidates.push_back(candidate(x, m_rating[x]));
sum += m_rating[x];
}
@ -296,6 +297,7 @@ namespace sat {
sum += m_rating[x];
}
}
IF_VERBOSE(1, verbose_stream() << " " << sum << " " << m_candidates.size() << "\n";);
TRACE("sat", display_candidates(tout << "sum: " << sum << "\n"););
return sum;
}
@ -378,32 +380,34 @@ namespace sat {
void lookahead::ensure_H(unsigned level) {
while (m_H.size() <= level) {
m_H.push_back(svector<float>());
m_H.push_back(svector<double>());
m_H.back().resize(m_num_vars * 2, 0);
}
}
void lookahead::h_scores(svector<float>& h, svector<float>& hp) {
float sum = 0;
void lookahead::h_scores(svector<double>& h, svector<double>& hp) {
double sum = 0;
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
literal l(*it, false);
sum += h[l.index()] + h[(~l).index()];
}
float factor = 2 * m_freevars.size() / sum;
float sqfactor = factor * factor;
float afactor = factor * m_config.m_alpha;
if (sum == 0) sum = 0.0001;
double factor = 2 * m_freevars.size() / sum;
double sqfactor = factor * factor;
double afactor = factor * m_config.m_alpha;
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
literal l(*it, false);
float pos = l_score(l, h, factor, sqfactor, afactor);
float neg = l_score(~l, h, factor, sqfactor, afactor);
double pos = l_score(l, h, factor, sqfactor, afactor);
double neg = l_score(~l, h, factor, sqfactor, afactor);
hp[l.index()] = pos;
hp[(~l).index()] = neg;
// std::cout << "h_scores: " << pos << " " << neg << "\n";
m_rating[l.var()] = pos * neg;
}
}
float lookahead::l_score(literal l, svector<float> const& h, float factor, float sqfactor, float afactor) {
float sum = 0, tsum = 0;
double lookahead::l_score(literal l, svector<double> const& h, double factor, double sqfactor, double afactor) {
double sum = 0, tsum = 0;
literal_vector::iterator it = m_binary[l.index()].begin(), end = m_binary[l.index()].end();
for (; it != end; ++it) {
bool_var v = it->var();
@ -412,6 +416,7 @@ namespace sat {
if (is_undef(*it)) sum += h[it->index()];
// if (m_freevars.contains(it->var())) sum += h[it->index()];
}
// std::cout << "sum: " << sum << "\n";
watch_list& wlist = m_watches[l.index()];
watch_list::iterator wit = wlist.begin(), wend = wlist.end();
for (; wit != wend; ++wit) {
@ -440,9 +445,13 @@ namespace sat {
}
break;
}
// case watched::EXTERNAL:
}
// std::cout << "tsum: " << tsum << "\n";
}
sum = (float)(0.1 + afactor*sum + sqfactor*tsum);
// std::cout << "sum: " << sum << " afactor " << afactor << " sqfactor " << sqfactor << " tsum " << tsum << "\n";
sum = (double)(0.1 + afactor*sum + sqfactor*tsum);
// std::cout << "sum: " << sum << " max score " << m_config.m_max_score << "\n";
return std::min(m_config.m_max_score, sum);
}
@ -545,7 +554,7 @@ namespace sat {
literal t = m_active;
m_active = get_link(v);
literal best = v;
float best_rating = get_rating(v);
double best_rating = get_rating(v);
set_rank(v, UINT_MAX);
set_link(v, m_settled); m_settled = t;
while (t != v) {
@ -556,7 +565,7 @@ namespace sat {
}
set_rank(t, UINT_MAX);
set_parent(t, v);
float t_rating = get_rating(t);
double t_rating = get_rating(t);
if (t_rating > best_rating) {
best = t;
best_rating = t_rating;
@ -1124,7 +1133,7 @@ namespace sat {
found = false;
for (; l_it != l_end && !found; found = is_true(*l_it), ++l_it) ;
if (!found) {
m_weighted_new_binaries = (float)0.001;
m_weighted_new_binaries = (double)0.001;
}
}
break;
@ -1272,15 +1281,15 @@ namespace sat {
literal lookahead::select_literal() {
literal l = null_literal;
float h = 0;
double h = 0;
unsigned count = 1;
for (unsigned i = 0; i < m_lookahead.size(); ++i) {
literal lit = m_lookahead[i].m_lit;
if (lit.sign() || !is_undef(lit)) {
continue;
}
float diff1 = get_wnb(lit), diff2 = get_wnb(~lit);
float mixd = mix_diff(diff1, diff2);
double diff1 = get_wnb(lit), diff2 = get_wnb(~lit);
double mixd = mix_diff(diff1, diff2);
if (mixd == h) ++count;
if (mixd > h || (mixd == h && m_s.m_rand(count) == 0)) {