diff --git a/src/ast/sls/sat_ddfw.cpp b/src/ast/sls/sat_ddfw.cpp index 4a4c27159..415368e2c 100644 --- a/src/ast/sls/sat_ddfw.cpp +++ b/src/ast/sls/sat_ddfw.cpp @@ -329,7 +329,7 @@ namespace sat { void ddfw::init_clause_data() { for (unsigned v = 0; v < num_vars(); ++v) { make_count(v) = 0; - reward(v) = 0; + m_vars[v].m_reward = 0; } m_unsat_vars.reset(); m_unsat.reset(); @@ -590,6 +590,44 @@ namespace sat { m_use_list[(~unit).index()].reset(); } + bool ddfw::try_rotate(bool_var v, bool_var_set& rotated, unsigned& budget) { + if (m_rotate_tabu.contains(v)) + return false; + if (budget == 0) + return false; + --budget; + rotated.insert(v); + m_rotate_tabu.insert(v); + flip(v); + switch (m_unsat.size()) { + case 0: + m_rotate_tabu.reset(); + m_new_tabu_vars.reset(); + return true; + case 1: + for (unsigned cl : m_unsat) { + unsigned sz = m_new_tabu_vars.size(); + for (literal lit : get_clause(cl)) { + if (m_rotate_tabu.contains(lit.var())) + continue; + if (try_rotate(lit.var(), rotated, budget)) + return true; + m_rotate_tabu.insert(lit.var()); + m_new_tabu_vars.push_back(lit.var()); + } + while (m_new_tabu_vars.size() > sz) + m_rotate_tabu.remove(m_new_tabu_vars.back()), m_new_tabu_vars.pop_back(); + } + break; + default: + break; + } + rotated.remove(v); + m_rotate_tabu.remove(v); + flip(v); + return false; + } + std::ostream& ddfw::display(std::ostream& out) const { unsigned num_cls = m_clauses.size(); for (unsigned i = 0; i < num_cls; ++i) { @@ -598,7 +636,7 @@ namespace sat { out << ci.m_num_trues << " w: " << ci.m_weight << "\n"; } for (unsigned v = 0; v < num_vars(); ++v) - out << (is_true(literal(v, false)) ? "" : "-") << v << " rw: " << get_reward(v) << "\n"; + out << (is_true(literal(v, false)) ? "" : "-") << v << " rw: " << reward(v) << "\n"; out << "unsat vars: "; for (bool_var v : m_unsat_vars) out << v << " "; diff --git a/src/ast/sls/sat_ddfw.h b/src/ast/sls/sat_ddfw.h index 6f3386a05..5c027d759 100644 --- a/src/ast/sls/sat_ddfw.h +++ b/src/ast/sls/sat_ddfw.h @@ -123,7 +123,7 @@ namespace sat { inline bool value(bool_var v) const { return m_vars[v].m_value; } - inline double& reward(bool_var v) { return m_vars[v].m_reward; } + // inline double reward(bool_var v) { return m_vars[v].m_reward; } unsigned value_hash() const; @@ -150,9 +150,9 @@ namespace sat { if (--make_count(v) == 0) m_unsat_vars.remove(v); } - inline void inc_reward(literal lit, double w) { reward(lit.var()) += w; } + inline void inc_reward(literal lit, double w) { m_vars[lit.var()].m_reward += w; } - inline void dec_reward(literal lit, double w) { reward(lit.var()) -= w; } + inline void dec_reward(literal lit, double w) { m_vars[lit.var()].m_reward -= w; } void check_with_plugin(); void check_without_plugin(); @@ -201,6 +201,9 @@ namespace sat { inline bool disregard_neighbor(); + bool_var_set m_rotate_tabu; + bool_var_vector m_new_tabu_vars; + public: ddfw() {} @@ -248,7 +251,9 @@ namespace sat { void flip(bool_var v); - inline double get_reward(bool_var v) const { return m_vars[v].m_reward; } + inline double reward(bool_var v) const { return m_vars[v].m_reward; } + + void set_reward(bool_var v, double r) { m_vars[v].m_reward = r; } double get_reward_avg(bool_var v) const { return m_vars[v].m_reward_avg; } @@ -268,6 +273,7 @@ namespace sat { void simplify(); + bool try_rotate(bool_var v, bool_var_set& rotated, unsigned& budget); ptr_iterator use_list(literal lit) { flatten_use_list(); diff --git a/src/ast/sls/sls_bv_lookahead.cpp b/src/ast/sls/sls_bv_lookahead.cpp index 4198b949b..f9f0c7ffb 100644 --- a/src/ast/sls/sls_bv_lookahead.cpp +++ b/src/ast/sls/sls_bv_lookahead.cpp @@ -95,7 +95,7 @@ namespace sls { for (unsigned i = 0; i < sz; ++i) add_updates(vars[(start + i) % sz]); CTRACE("bv", !m_best_expr, tout << "no guided move\n";); - return apply_update(m_best_expr, m_best_value, "increasing move"); + return apply_update(m_last_atom, m_best_expr, m_best_value, "increasing move"); } /** @@ -117,7 +117,7 @@ namespace sls { auto& v = wval(e); m_v_updated.set_bw(v.bw); v.get_variant(m_v_updated, m_ev.m_rand); - return apply_update(e, m_v_updated, "random update"); + return apply_update(nullptr, e, m_v_updated, "random update"); } /** @@ -153,7 +153,7 @@ namespace sls { v.sub1(m_v_updated); break; } - return apply_update(e, m_v_updated, "random move"); + return apply_update(nullptr, e, m_v_updated, "random move"); } /** @@ -243,7 +243,7 @@ namespace sls { auto& v = wval(e); m_v_updated.set_bw(v.bw); m_v_updated.set_zero(); - apply_update(e, m_v_updated, "reset"); + apply_update(nullptr, e, m_v_updated, "reset"); } } } @@ -517,20 +517,20 @@ namespace sls { * The update is committed. */ - bool bv_lookahead::apply_update(expr* e, bvect const& new_value, char const* reason) { - if (!e || m.is_bool(e) || !wval(e).can_set(new_value)) + bool bv_lookahead::apply_update(expr* p, expr* t, bvect const& new_value, char const* reason) { + if (!t || m.is_bool(t) || !wval(t).can_set(new_value)) return false; - SASSERT(is_uninterp(e)); + SASSERT(is_uninterp(t)); SASSERT(m_restore.empty()); - if (bv.is_bv(e)) { - wval(e).eval = new_value; - VERIFY(wval(e).commit_eval_check_tabu()); + if (bv.is_bv(t)) { + wval(t).eval = new_value; + VERIFY(wval(t).commit_eval_check_tabu()); } - insert_update_stack(e); - unsigned max_depth = get_depth(e); + insert_update_stack(t); + unsigned max_depth = get_depth(t); for (unsigned depth = max_depth; depth <= max_depth; ++depth) { for (unsigned i = 0; i < m_update_stack[depth].size(); ++i) { auto e = m_update_stack[depth][i]; @@ -553,11 +553,27 @@ namespace sls { continue; if (ctx.is_true(v) == v1) continue; + if (!p || e == p) + continue; + TRACE("bv", tout << "updated truth value " << v << ": " << mk_bounded_pp(e, m) << "\n";); +#if 0 unsigned num_unsat = ctx.unsat().size(); TRACE("bv", tout << "update flip " << mk_bounded_pp(e, m) << "\n";); + auto r = ctx.reward(v); + auto lit = sat::literal(v, !ctx.is_true(v)); + bool is_bv_lit = is_bv_literal(lit); + verbose_stream() << "flip " << is_bv_literal(lit) << " " << mk_bounded_pp(e, m) << " " << lit << " " << r << " num unsat " << ctx.unsat().size() << "\n"; + + ctx.flip(v); - if (num_unsat < ctx.unsat().size()) + + verbose_stream() << "new unsat " << ctx.unsat().size() << "\n"; + + if (num_unsat < ctx.unsat().size()) { + verbose_stream() << "flip back\n"; ctx.flip(v); + } +#endif } m_ev.set_bool_value(to_app(e), v1); } @@ -573,7 +589,7 @@ namespace sls { } m_in_update_stack.reset(); m_ev.clear_bool_values(); - TRACE("bv", tout << reason << " " << mk_bounded_pp(e, m) + TRACE("bv", tout << reason << " " << mk_bounded_pp(t, m) << " := " << new_value << " score " << m_top_score << "\n";); return true; diff --git a/src/ast/sls/sls_bv_lookahead.h b/src/ast/sls/sls_bv_lookahead.h index 5cc951ee3..e91ed5455 100644 --- a/src/ast/sls/sls_bv_lookahead.h +++ b/src/ast/sls/sls_bv_lookahead.h @@ -112,7 +112,7 @@ namespace sls { void try_set(expr* u, bvect const& new_value); void try_flip(expr* u); void add_updates(expr* u); - bool apply_update(expr* e, bvect const& new_value, char const* reason); + bool apply_update(expr* p, expr* t, bvect const& new_value, char const* reason); bool apply_random_move(ptr_vector const& vars); bool apply_guided_move(ptr_vector const& vars); bool apply_random_update(ptr_vector const& vars); diff --git a/src/ast/sls/sls_smt_plugin.h b/src/ast/sls/sls_smt_plugin.h index 4fcc2f730..cd044d2b2 100644 --- a/src/ast/sls/sls_smt_plugin.h +++ b/src/ast/sls/sls_smt_plugin.h @@ -148,7 +148,7 @@ namespace sls { void flip(sat::bool_var v) override { m_ddfw->flip(v); } - double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); } + double reward(sat::bool_var v) override { return m_ddfw->reward(v); } double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; } bool is_true(sat::literal lit) override { return m_ddfw->get_value(lit.var()) != lit.sign(); diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp index 87115bdd6..621092bdf 100644 --- a/src/ast/sls/sls_smt_solver.cpp +++ b/src/ast/sls/sls_smt_solver.cpp @@ -85,7 +85,7 @@ namespace sls { sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw.get_clause_info(idx); } ptr_iterator get_use_list(sat::literal lit) override { return m_ddfw.use_list(lit); } void flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; m_ddfw.flip(v); } - double reward(sat::bool_var v) override { return m_ddfw.get_reward(v); } + double reward(sat::bool_var v) override { return m_ddfw.reward(v); } double get_weigth(unsigned clause_idx) override { return m_ddfw.get_clause_info(clause_idx).m_weight; } bool is_true(sat::literal lit) override { return m_ddfw.get_value(lit.var()) != lit.sign(); } unsigned num_vars() const override { return m_ddfw.num_vars(); } diff --git a/src/sat/sat_ddfw_wrapper.cpp b/src/sat/sat_ddfw_wrapper.cpp index 2fba213de..e5c6a399d 100644 --- a/src/sat/sat_ddfw_wrapper.cpp +++ b/src/sat/sat_ddfw_wrapper.cpp @@ -48,7 +48,7 @@ namespace sat { m_ddfw.add_assumptions(); for (unsigned v = 0; v < phase.size(); ++v) { m_ddfw.value(v) = phase[v]; - m_ddfw.reward(v) = 0; + m_ddfw.set_reward(v, 0); m_ddfw.make_count(v) = 0; } m_ddfw.init_clause_data(); diff --git a/src/sat/sat_ddfw_wrapper.h b/src/sat/sat_ddfw_wrapper.h index 6c87c72bd..720b71c03 100644 --- a/src/sat/sat_ddfw_wrapper.h +++ b/src/sat/sat_ddfw_wrapper.h @@ -77,7 +77,7 @@ namespace sat { void flip(bool_var v) { m_ddfw.flip(v); } - inline double get_reward(bool_var v) const { return m_ddfw.get_reward(v); } + inline double get_reward(bool_var v) const { return m_ddfw.reward(v); } void add(unsigned sz, literal const* c) { m_ddfw.add(sz, c); }