From b0eee161099b4baff1048774f099db74bbe49025 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 27 Dec 2024 12:26:11 -0800 Subject: [PATCH] fix double override bug in bv_lookahead, integrate with bv_eval --- src/ast/sls/sls_arith_base.cpp | 4 +- src/ast/sls/sls_bv_eval.cpp | 6 +- src/ast/sls/sls_bv_eval.h | 1 + src/ast/sls/sls_bv_lookahead.cpp | 127 +++++++++++++++++++++---------- src/ast/sls/sls_bv_lookahead.h | 34 +++++++-- src/ast/sls/sls_bv_plugin.cpp | 4 + src/ast/sls/sls_bv_plugin.h | 2 +- src/ast/sls/sls_bv_valuation.cpp | 2 +- src/ast/sls/sls_context.cpp | 4 +- 9 files changed, 128 insertions(+), 56 deletions(-) diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index a31304831..7515c58de 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -1656,9 +1656,9 @@ namespace sls { } if (result < 0) - return 0.1; + return 0.0000001; else if (result == 0) - return 0.2; + return 0.000002; for (int i = m_prob_break.size(); i <= breaks; ++i) m_prob_break.push_back(std::pow(m_config.cb, -i)); return m_prob_break[breaks]; diff --git a/src/ast/sls/sls_bv_eval.cpp b/src/ast/sls/sls_bv_eval.cpp index 997b3c4b3..d2332deee 100644 --- a/src/ast/sls/sls_bv_eval.cpp +++ b/src/ast/sls/sls_bv_eval.cpp @@ -679,7 +679,7 @@ namespace sls { expr* arg = e->get_arg(i); if (m.is_value(arg)) return false; - if (m.is_bool(e) && false && m_rand(10) == 0 && m_lookahead.try_repair_down(e)) + if (false && m.is_bool(e) && ctx.rand(10) == 0 && m_lookahead.try_repair_down(e)) return true; if (e->get_family_id() == bv.get_family_id() && try_repair_bv(e, i)) { commit_eval(e, to_app(arg)); @@ -2024,6 +2024,10 @@ namespace sls { return expr_ref(m); } + void bv_eval::collect_statistics(statistics& st) const { + m_lookahead.collect_statistics(st); + } + std::ostream& bv_eval::display(std::ostream& out) const { auto& terms = ctx.subterms(); for (expr* e : terms) { diff --git a/src/ast/sls/sls_bv_eval.h b/src/ast/sls/sls_bv_eval.h index 6dc2d35e4..2314c2895 100644 --- a/src/ast/sls/sls_bv_eval.h +++ b/src/ast/sls/sls_bv_eval.h @@ -190,6 +190,7 @@ namespace sls { */ bool repair_up(expr* e); + void collect_statistics(statistics& st) const; std::ostream& display(std::ostream& out) const; diff --git a/src/ast/sls/sls_bv_lookahead.cpp b/src/ast/sls/sls_bv_lookahead.cpp index 6341f1e0e..2ad3cc535 100644 --- a/src/ast/sls/sls_bv_lookahead.cpp +++ b/src/ast/sls/sls_bv_lookahead.cpp @@ -36,25 +36,45 @@ namespace sls { auto const& uninterp = m_ev.terms.uninterp_occurs(e); if (uninterp.empty()) return false; + + if (false && ctx.rand(10) == 0 && apply_random_update(uninterp)) + return true; + reset_updates(); - IF_VERBOSE(4, - verbose_stream() << mk_bounded_pp(e, m) << "\n"; - for (auto e : uninterp) - verbose_stream() << mk_bounded_pp(e, m) << " "; - verbose_stream() << "\n"); + TRACE("sls", tout << mk_bounded_pp(e, m) << " contains "; + for (auto e : uninterp) + tout << mk_bounded_pp(e, m) << " "; + tout << "\n";); - for (auto e : uninterp) + for (auto e : uninterp) add_updates(e); -#if 0 + m_stats.m_num_lookahead += 1; + m_stats.m_num_updates += m_num_updates; + + TRACE("sls", display_updates(tout)); + + if (apply_update()) + return true; + + return apply_random_update(uninterp); + } + + void bv_lookahead::display_updates(std::ostream& out) { for (unsigned i = 0; i < m_num_updates; ++i) { auto const& [e, score, new_value] = m_updates[i]; - verbose_stream() << mk_bounded_pp(e, m) << " " << new_value << " score: " << score << "\n"; + out << mk_bounded_pp(e, m) << " " << new_value << " score: " << score << "\n"; } -#endif - - return apply_update(); + } + + bool bv_lookahead::apply_random_update(ptr_vector const& vars) { + expr* e = vars[ctx.rand(vars.size())]; + auto& v = wval(e); + m_v_updated.set_bw(v.bw); + v.get_variant(m_v_updated, m_ev.m_rand); + apply_update(e, m_v_updated); + return true; } double bv_lookahead::lookahead(expr* e, bvect const& new_value) { @@ -63,22 +83,23 @@ namespace sls { SASSERT(m_restore.empty()); bool has_tabu = false; - double break_count = 0, make_count = 0; + int result = 0; + int breaks = 0; wval(e).eval = new_value; if (!insert_update(e)) { restore_lookahead(); + m_in_update_stack.reset(); return -1000000; } insert_update_stack(e); unsigned max_depth = get_depth(e); for (unsigned depth = max_depth; depth <= max_depth; ++depth) { for (unsigned i = 0; !has_tabu && i < m_update_stack[depth].size(); ++i) { - auto e = m_update_stack[depth][i]; - if (bv.is_bv(e)) { - auto& v = m_ev.eval(to_app(e)); - if (insert_update(e)) { - for (auto p : ctx.parents(e)) { + auto a = m_update_stack[depth][i]; + if (bv.is_bv(a)) { + if (a == e || (m_ev.eval(a), insert_update(a))) { // do not insert e twice + for (auto p : ctx.parents(a)) { insert_update_stack(p); max_depth = std::max(max_depth, get_depth(p)); } @@ -86,32 +107,43 @@ namespace sls { else has_tabu = true; } - else if (m.is_bool(e) && m_ev.can_eval1(to_app(e))) { - if (!ctx.is_relevant(e)) + else if (m.is_bool(a) && m_ev.can_eval1(a)) { + if (!ctx.is_relevant(a)) continue; - bool is_true = ctx.is_true(e); - bool is_true_new = m_ev.bval1(to_app(e)); - bool is_true_old = m_ev.bval1_tmp(to_app(e)); + bool is_true = ctx.is_true(a); + bool is_true_new = m_ev.bval1(a); + bool is_true_old = m_ev.bval1_tmp(a); + TRACE("sls_verbose", tout << mk_bounded_pp(a, m) << " " << is_true << " " << is_true_new << " " << is_true_old << "\n"); if (is_true_new == is_true_old) continue; if (is_true == is_true_new) - ++make_count; - if (is_true == is_true_old) - ++break_count; + ++result; + if (is_true == is_true_old) { + --result; + ++breaks; + } } else { - IF_VERBOSE(1, verbose_stream() << "skipping " << mk_bounded_pp(e, m) << "\n"); + IF_VERBOSE(1, verbose_stream() << "skipping " << mk_bounded_pp(a, m) << "\n"); has_tabu = true; } } m_update_stack[depth].reset(); } + m_in_update_stack.reset(); restore_lookahead(); - // verbose_stream() << has_tabu << " " << new_value << " " << make_count << " " << break_count << "\n"; + + TRACE("sls_verbose", tout << mk_bounded_pp(e, m) << " " << new_value << " " << result << " " << breaks << "\n"); if (has_tabu) return -10000; - return make_count - break_count; + if (result < 0) + return 0.0000001; + else if (result == 0) + return 0.000002; + for (int i = m_prob_break.size(); i <= breaks; ++i) + m_prob_break.push_back(std::pow(m_config.cb, -i)); + return m_prob_break[breaks]; } void bv_lookahead::try_set(expr* e, bvect const& new_value) { @@ -125,7 +157,6 @@ namespace sls { void bv_lookahead::add_updates(expr* e) { SASSERT(bv.is_bv(e)); auto& v = wval(e); - double d = 0; while (m_v_saved.size() < v.bits().size()) { m_v_saved.push_back(0); m_v_updated.push_back(0); @@ -161,9 +192,9 @@ namespace sls { v.sub1(m_v_updated); try_set(e, m_v_updated); - // random - v.get_variant(m_v_updated, m_ev.m_rand); - try_set(e, m_v_updated); + // random, deffered to failure path + // v.get_variant(m_v_updated, m_ev.m_rand); + // try_set(e, m_v_updated); } bool bv_lookahead::apply_update() { @@ -174,12 +205,13 @@ namespace sls { for (unsigned i = 0; i < m_num_updates; ++i) { auto const& [e, score, new_value] = m_updates[i]; pos -= score; - if (pos <= 0) { - //verbose_stream() << "apply " << mk_bounded_pp(e, m) << " new value " << new_value << " " << score << "\n"; + if (pos <= 0.00000000001) { + TRACE("sls", tout << "apply " << mk_bounded_pp(e, m) << " new value " << new_value << " " << score << "\n"); apply_update(e, new_value); return true; } } + TRACE("sls", tout << "no update " << m_num_updates << "\n"); return false; } @@ -195,14 +227,18 @@ namespace sls { for (unsigned i = 0; i < m_update_stack[depth].size(); ++i) { auto e = m_update_stack[depth][i]; if (bv.is_bv(e)) { - m_ev.eval(to_app(e)); // updates wval(e).eval - VERIFY(wval(e).commit_eval()); + m_ev.eval(e); // updates wval(e).eval + if (!wval(e).commit_eval()) { + TRACE("sls", tout << "failed to commit " << mk_bounded_pp(e, m) << " " << wval(e) << "\n"); + // bv_plugin::is_sat picks up discrepancies + continue; + } for (auto p : ctx.parents(e)) { insert_update_stack(p); max_depth = std::max(max_depth, get_depth(p)); } } - else if (m.is_bool(e) && m_ev.can_eval1(to_app(e))) { + else if (m.is_bool(e) && m_ev.can_eval1(e)) { VERIFY(m_ev.repair_up(e)); } else { @@ -215,9 +251,10 @@ namespace sls { } bool bv_lookahead::insert_update(expr* e) { + auto& v = wval(e); m_restore.push_back(e); m_on_restore.mark(e); - auto& v = wval(e); + TRACE("sls_verbose", tout << "insert update " << mk_bounded_pp(e, m) << " " << v << "\n"); v.save_value(); return v.commit_eval(); } @@ -225,18 +262,19 @@ namespace sls { void bv_lookahead::insert_update_stack(expr* e) { unsigned depth = get_depth(e); m_update_stack.reserve(depth + 1); - if (!m_in_update_stack.is_marked(e)) { + if (!m_in_update_stack.is_marked(e) && is_app(e)) { m_in_update_stack.mark(e); - m_update_stack[depth].push_back(e); + m_update_stack[depth].push_back(to_app(e)); } } void bv_lookahead::restore_lookahead() { - for (auto e : m_restore) + for (auto e : m_restore) { wval(e).restore_value(); + TRACE("sls_verbose", tout << "restore value " << mk_bounded_pp(e, m) << " " << wval(e) << "\n"); + } m_restore.reset(); m_on_restore.reset(); - m_in_update_stack.reset(); } sls::bv_valuation& bv_lookahead::wval(expr* e) const { @@ -246,4 +284,9 @@ namespace sls { bool bv_lookahead::on_restore(expr* e) const { return m_on_restore.is_marked(e); } + + void bv_lookahead::collect_statistics(statistics& st) const { + st.update("sls-bv-lookahead", m_stats.m_num_lookahead); + st.update("sls-bv-updates", m_stats.m_num_updates); + } } \ No newline at end of file diff --git a/src/ast/sls/sls_bv_lookahead.h b/src/ast/sls/sls_bv_lookahead.h index 15b810ec0..c60ad1362 100644 --- a/src/ast/sls/sls_bv_lookahead.h +++ b/src/ast/sls/sls_bv_lookahead.h @@ -24,23 +24,38 @@ namespace sls { class bv_eval; class bv_lookahead { + struct config { + double cb = 2.85; + }; + + struct update { + expr* e; + double score; + bvect value; + }; + + struct stats { + unsigned m_num_lookahead = 0; + unsigned m_num_updates = 0; + }; + + bv_util bv; bv_eval& m_ev; context& ctx; ast_manager& m; + config m_config; + stats m_stats; bvect m_v_saved, m_v_updated; - + svector m_prob_break; ptr_vector m_restore; - vector> m_update_stack; + vector> m_update_stack; expr_mark m_on_restore, m_in_update_stack; - struct update { - expr* e; - double score; - bvect value; - }; vector m_updates; unsigned m_num_updates = 0; + void reset_updates() { m_num_updates = 0; } + void add_update(double score, expr* e, bvect const& value) { if (m_num_updates == m_updates.size()) m_updates.push_back({ e, score, value }); @@ -65,13 +80,18 @@ namespace sls { void add_updates(expr* e); void apply_update(expr* e, bvect const& new_value); bool apply_update(); + bool apply_random_update(ptr_vector const& vars); + void display_updates(std::ostream& out); public: bv_lookahead(bv_eval& ev); + bool on_restore(expr* e) const; bool try_repair_down(app* e); + void collect_statistics(statistics& st) const; + }; } \ No newline at end of file diff --git a/src/ast/sls/sls_bv_plugin.cpp b/src/ast/sls/sls_bv_plugin.cpp index 4c2877235..78d367770 100644 --- a/src/ast/sls/sls_bv_plugin.cpp +++ b/src/ast/sls/sls_bv_plugin.cpp @@ -167,6 +167,10 @@ namespace sls { ctx.flip(lit.var()); } + void bv_plugin::collect_statistics(statistics& st) const { + m_eval.collect_statistics(st); + } + std::ostream& bv_plugin::trace_repair(bool down, expr* e) { verbose_stream() << (down ? "d #" : "u #") << e->get_id() << ": " diff --git a/src/ast/sls/sls_bv_plugin.h b/src/ast/sls/sls_bv_plugin.h index ea750163e..7bcc4c329 100644 --- a/src/ast/sls/sls_bv_plugin.h +++ b/src/ast/sls/sls_bv_plugin.h @@ -53,7 +53,7 @@ namespace sls { void on_restart() override {} std::ostream& display(std::ostream& out) const override; bool set_value(expr* e, expr* v) override; - void collect_statistics(statistics& st) const override {} + void collect_statistics(statistics& st) const override; void reset_statistics() override {} }; diff --git a/src/ast/sls/sls_bv_valuation.cpp b/src/ast/sls/sls_bv_valuation.cpp index 29fc6b517..d23471a2c 100644 --- a/src/ast/sls/sls_bv_valuation.cpp +++ b/src/ast/sls/sls_bv_valuation.cpp @@ -33,7 +33,7 @@ namespace sls { bool operator==(bvect const& a, bvect const& b) { SASSERT(a.nw > 0); - return 0 == mpn_manager().compare(a.data(), a.nw, b.data(), a.nw); + return 0 == memcmp(a.data(), b.data(), a.nw * sizeof(digit_t)); } bool operator<(bvect const& a, bvect const& b) { diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index b177bfb0c..cab031d52 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -565,8 +565,8 @@ namespace sls { SASSERT(m.is_true(get_value(e)) == is_true(v)); } } - ); - + ); + m_repair_down.reserve(e->get_id() + 1); m_repair_up.reserve(e->get_id() + 1); if (!term(e->get_id()))