3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 10:25:18 +00:00

fix double override bug in bv_lookahead, integrate with bv_eval

This commit is contained in:
Nikolaj Bjorner 2024-12-27 12:26:11 -08:00
parent 8de0005ab3
commit b0eee16109
9 changed files with 128 additions and 56 deletions

View file

@ -1656,9 +1656,9 @@ namespace sls {
}
if (result < 0)
return 0.1;
return 0.0000001;
else if (result == 0)
return 0.2;
return 0.000002;
for (int i = m_prob_break.size(); i <= breaks; ++i)
m_prob_break.push_back(std::pow(m_config.cb, -i));
return m_prob_break[breaks];

View file

@ -679,7 +679,7 @@ namespace sls {
expr* arg = e->get_arg(i);
if (m.is_value(arg))
return false;
if (m.is_bool(e) && false && m_rand(10) == 0 && m_lookahead.try_repair_down(e))
if (false && m.is_bool(e) && ctx.rand(10) == 0 && m_lookahead.try_repair_down(e))
return true;
if (e->get_family_id() == bv.get_family_id() && try_repair_bv(e, i)) {
commit_eval(e, to_app(arg));
@ -2024,6 +2024,10 @@ namespace sls {
return expr_ref(m);
}
void bv_eval::collect_statistics(statistics& st) const {
m_lookahead.collect_statistics(st);
}
std::ostream& bv_eval::display(std::ostream& out) const {
auto& terms = ctx.subterms();
for (expr* e : terms) {

View file

@ -190,6 +190,7 @@ namespace sls {
*/
bool repair_up(expr* e);
void collect_statistics(statistics& st) const;
std::ostream& display(std::ostream& out) const;

View file

@ -36,25 +36,45 @@ namespace sls {
auto const& uninterp = m_ev.terms.uninterp_occurs(e);
if (uninterp.empty())
return false;
if (false && ctx.rand(10) == 0 && apply_random_update(uninterp))
return true;
reset_updates();
IF_VERBOSE(4,
verbose_stream() << mk_bounded_pp(e, m) << "\n";
for (auto e : uninterp)
verbose_stream() << mk_bounded_pp(e, m) << " ";
verbose_stream() << "\n");
TRACE("sls", tout << mk_bounded_pp(e, m) << " contains ";
for (auto e : uninterp)
tout << mk_bounded_pp(e, m) << " ";
tout << "\n";);
for (auto e : uninterp)
for (auto e : uninterp)
add_updates(e);
#if 0
m_stats.m_num_lookahead += 1;
m_stats.m_num_updates += m_num_updates;
TRACE("sls", display_updates(tout));
if (apply_update())
return true;
return apply_random_update(uninterp);
}
void bv_lookahead::display_updates(std::ostream& out) {
for (unsigned i = 0; i < m_num_updates; ++i) {
auto const& [e, score, new_value] = m_updates[i];
verbose_stream() << mk_bounded_pp(e, m) << " " << new_value << " score: " << score << "\n";
out << mk_bounded_pp(e, m) << " " << new_value << " score: " << score << "\n";
}
#endif
return apply_update();
}
bool bv_lookahead::apply_random_update(ptr_vector<expr> const& vars) {
expr* e = vars[ctx.rand(vars.size())];
auto& v = wval(e);
m_v_updated.set_bw(v.bw);
v.get_variant(m_v_updated, m_ev.m_rand);
apply_update(e, m_v_updated);
return true;
}
double bv_lookahead::lookahead(expr* e, bvect const& new_value) {
@ -63,22 +83,23 @@ namespace sls {
SASSERT(m_restore.empty());
bool has_tabu = false;
double break_count = 0, make_count = 0;
int result = 0;
int breaks = 0;
wval(e).eval = new_value;
if (!insert_update(e)) {
restore_lookahead();
m_in_update_stack.reset();
return -1000000;
}
insert_update_stack(e);
unsigned max_depth = get_depth(e);
for (unsigned depth = max_depth; depth <= max_depth; ++depth) {
for (unsigned i = 0; !has_tabu && i < m_update_stack[depth].size(); ++i) {
auto e = m_update_stack[depth][i];
if (bv.is_bv(e)) {
auto& v = m_ev.eval(to_app(e));
if (insert_update(e)) {
for (auto p : ctx.parents(e)) {
auto a = m_update_stack[depth][i];
if (bv.is_bv(a)) {
if (a == e || (m_ev.eval(a), insert_update(a))) { // do not insert e twice
for (auto p : ctx.parents(a)) {
insert_update_stack(p);
max_depth = std::max(max_depth, get_depth(p));
}
@ -86,32 +107,43 @@ namespace sls {
else
has_tabu = true;
}
else if (m.is_bool(e) && m_ev.can_eval1(to_app(e))) {
if (!ctx.is_relevant(e))
else if (m.is_bool(a) && m_ev.can_eval1(a)) {
if (!ctx.is_relevant(a))
continue;
bool is_true = ctx.is_true(e);
bool is_true_new = m_ev.bval1(to_app(e));
bool is_true_old = m_ev.bval1_tmp(to_app(e));
bool is_true = ctx.is_true(a);
bool is_true_new = m_ev.bval1(a);
bool is_true_old = m_ev.bval1_tmp(a);
TRACE("sls_verbose", tout << mk_bounded_pp(a, m) << " " << is_true << " " << is_true_new << " " << is_true_old << "\n");
if (is_true_new == is_true_old)
continue;
if (is_true == is_true_new)
++make_count;
if (is_true == is_true_old)
++break_count;
++result;
if (is_true == is_true_old) {
--result;
++breaks;
}
}
else {
IF_VERBOSE(1, verbose_stream() << "skipping " << mk_bounded_pp(e, m) << "\n");
IF_VERBOSE(1, verbose_stream() << "skipping " << mk_bounded_pp(a, m) << "\n");
has_tabu = true;
}
}
m_update_stack[depth].reset();
}
m_in_update_stack.reset();
restore_lookahead();
// verbose_stream() << has_tabu << " " << new_value << " " << make_count << " " << break_count << "\n";
TRACE("sls_verbose", tout << mk_bounded_pp(e, m) << " " << new_value << " " << result << " " << breaks << "\n");
if (has_tabu)
return -10000;
return make_count - break_count;
if (result < 0)
return 0.0000001;
else if (result == 0)
return 0.000002;
for (int i = m_prob_break.size(); i <= breaks; ++i)
m_prob_break.push_back(std::pow(m_config.cb, -i));
return m_prob_break[breaks];
}
void bv_lookahead::try_set(expr* e, bvect const& new_value) {
@ -125,7 +157,6 @@ namespace sls {
void bv_lookahead::add_updates(expr* e) {
SASSERT(bv.is_bv(e));
auto& v = wval(e);
double d = 0;
while (m_v_saved.size() < v.bits().size()) {
m_v_saved.push_back(0);
m_v_updated.push_back(0);
@ -161,9 +192,9 @@ namespace sls {
v.sub1(m_v_updated);
try_set(e, m_v_updated);
// random
v.get_variant(m_v_updated, m_ev.m_rand);
try_set(e, m_v_updated);
// random, deffered to failure path
// v.get_variant(m_v_updated, m_ev.m_rand);
// try_set(e, m_v_updated);
}
bool bv_lookahead::apply_update() {
@ -174,12 +205,13 @@ namespace sls {
for (unsigned i = 0; i < m_num_updates; ++i) {
auto const& [e, score, new_value] = m_updates[i];
pos -= score;
if (pos <= 0) {
//verbose_stream() << "apply " << mk_bounded_pp(e, m) << " new value " << new_value << " " << score << "\n";
if (pos <= 0.00000000001) {
TRACE("sls", tout << "apply " << mk_bounded_pp(e, m) << " new value " << new_value << " " << score << "\n");
apply_update(e, new_value);
return true;
}
}
TRACE("sls", tout << "no update " << m_num_updates << "\n");
return false;
}
@ -195,14 +227,18 @@ namespace sls {
for (unsigned i = 0; i < m_update_stack[depth].size(); ++i) {
auto e = m_update_stack[depth][i];
if (bv.is_bv(e)) {
m_ev.eval(to_app(e)); // updates wval(e).eval
VERIFY(wval(e).commit_eval());
m_ev.eval(e); // updates wval(e).eval
if (!wval(e).commit_eval()) {
TRACE("sls", tout << "failed to commit " << mk_bounded_pp(e, m) << " " << wval(e) << "\n");
// bv_plugin::is_sat picks up discrepancies
continue;
}
for (auto p : ctx.parents(e)) {
insert_update_stack(p);
max_depth = std::max(max_depth, get_depth(p));
}
}
else if (m.is_bool(e) && m_ev.can_eval1(to_app(e))) {
else if (m.is_bool(e) && m_ev.can_eval1(e)) {
VERIFY(m_ev.repair_up(e));
}
else {
@ -215,9 +251,10 @@ namespace sls {
}
bool bv_lookahead::insert_update(expr* e) {
auto& v = wval(e);
m_restore.push_back(e);
m_on_restore.mark(e);
auto& v = wval(e);
TRACE("sls_verbose", tout << "insert update " << mk_bounded_pp(e, m) << " " << v << "\n");
v.save_value();
return v.commit_eval();
}
@ -225,18 +262,19 @@ namespace sls {
void bv_lookahead::insert_update_stack(expr* e) {
unsigned depth = get_depth(e);
m_update_stack.reserve(depth + 1);
if (!m_in_update_stack.is_marked(e)) {
if (!m_in_update_stack.is_marked(e) && is_app(e)) {
m_in_update_stack.mark(e);
m_update_stack[depth].push_back(e);
m_update_stack[depth].push_back(to_app(e));
}
}
void bv_lookahead::restore_lookahead() {
for (auto e : m_restore)
for (auto e : m_restore) {
wval(e).restore_value();
TRACE("sls_verbose", tout << "restore value " << mk_bounded_pp(e, m) << " " << wval(e) << "\n");
}
m_restore.reset();
m_on_restore.reset();
m_in_update_stack.reset();
}
sls::bv_valuation& bv_lookahead::wval(expr* e) const {
@ -246,4 +284,9 @@ namespace sls {
bool bv_lookahead::on_restore(expr* e) const {
return m_on_restore.is_marked(e);
}
void bv_lookahead::collect_statistics(statistics& st) const {
st.update("sls-bv-lookahead", m_stats.m_num_lookahead);
st.update("sls-bv-updates", m_stats.m_num_updates);
}
}

View file

@ -24,23 +24,38 @@ namespace sls {
class bv_eval;
class bv_lookahead {
struct config {
double cb = 2.85;
};
struct update {
expr* e;
double score;
bvect value;
};
struct stats {
unsigned m_num_lookahead = 0;
unsigned m_num_updates = 0;
};
bv_util bv;
bv_eval& m_ev;
context& ctx;
ast_manager& m;
config m_config;
stats m_stats;
bvect m_v_saved, m_v_updated;
svector<double> m_prob_break;
ptr_vector<expr> m_restore;
vector<ptr_vector<expr>> m_update_stack;
vector<ptr_vector<app>> m_update_stack;
expr_mark m_on_restore, m_in_update_stack;
struct update {
expr* e;
double score;
bvect value;
};
vector<update> m_updates;
unsigned m_num_updates = 0;
void reset_updates() { m_num_updates = 0; }
void add_update(double score, expr* e, bvect const& value) {
if (m_num_updates == m_updates.size())
m_updates.push_back({ e, score, value });
@ -65,13 +80,18 @@ namespace sls {
void add_updates(expr* e);
void apply_update(expr* e, bvect const& new_value);
bool apply_update();
bool apply_random_update(ptr_vector<expr> const& vars);
void display_updates(std::ostream& out);
public:
bv_lookahead(bv_eval& ev);
bool on_restore(expr* e) const;
bool try_repair_down(app* e);
void collect_statistics(statistics& st) const;
};
}

View file

@ -167,6 +167,10 @@ namespace sls {
ctx.flip(lit.var());
}
void bv_plugin::collect_statistics(statistics& st) const {
m_eval.collect_statistics(st);
}
std::ostream& bv_plugin::trace_repair(bool down, expr* e) {
verbose_stream() << (down ? "d #" : "u #")
<< e->get_id() << ": "

View file

@ -53,7 +53,7 @@ namespace sls {
void on_restart() override {}
std::ostream& display(std::ostream& out) const override;
bool set_value(expr* e, expr* v) override;
void collect_statistics(statistics& st) const override {}
void collect_statistics(statistics& st) const override;
void reset_statistics() override {}
};

View file

@ -33,7 +33,7 @@ namespace sls {
bool operator==(bvect const& a, bvect const& b) {
SASSERT(a.nw > 0);
return 0 == mpn_manager().compare(a.data(), a.nw, b.data(), a.nw);
return 0 == memcmp(a.data(), b.data(), a.nw * sizeof(digit_t));
}
bool operator<(bvect const& a, bvect const& b) {

View file

@ -565,8 +565,8 @@ namespace sls {
SASSERT(m.is_true(get_value(e)) == is_true(v));
}
}
);
);
m_repair_down.reserve(e->get_id() + 1);
m_repair_up.reserve(e->get_id() + 1);
if (!term(e->get_id()))