3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-13 12:28:44 +00:00

flip tabu on predicate being repaired, add model rotation code

This commit is contained in:
Nikolaj Bjorner 2025-01-02 14:39:36 -08:00
parent f67e1b8b8b
commit 70f7feabc8
8 changed files with 85 additions and 25 deletions

View file

@ -329,7 +329,7 @@ namespace sat {
void ddfw::init_clause_data() {
for (unsigned v = 0; v < num_vars(); ++v) {
make_count(v) = 0;
reward(v) = 0;
m_vars[v].m_reward = 0;
}
m_unsat_vars.reset();
m_unsat.reset();
@ -590,6 +590,44 @@ namespace sat {
m_use_list[(~unit).index()].reset();
}
bool ddfw::try_rotate(bool_var v, bool_var_set& rotated, unsigned& budget) {
if (m_rotate_tabu.contains(v))
return false;
if (budget == 0)
return false;
--budget;
rotated.insert(v);
m_rotate_tabu.insert(v);
flip(v);
switch (m_unsat.size()) {
case 0:
m_rotate_tabu.reset();
m_new_tabu_vars.reset();
return true;
case 1:
for (unsigned cl : m_unsat) {
unsigned sz = m_new_tabu_vars.size();
for (literal lit : get_clause(cl)) {
if (m_rotate_tabu.contains(lit.var()))
continue;
if (try_rotate(lit.var(), rotated, budget))
return true;
m_rotate_tabu.insert(lit.var());
m_new_tabu_vars.push_back(lit.var());
}
while (m_new_tabu_vars.size() > sz)
m_rotate_tabu.remove(m_new_tabu_vars.back()), m_new_tabu_vars.pop_back();
}
break;
default:
break;
}
rotated.remove(v);
m_rotate_tabu.remove(v);
flip(v);
return false;
}
std::ostream& ddfw::display(std::ostream& out) const {
unsigned num_cls = m_clauses.size();
for (unsigned i = 0; i < num_cls; ++i) {
@ -598,7 +636,7 @@ namespace sat {
out << ci.m_num_trues << " w: " << ci.m_weight << "\n";
}
for (unsigned v = 0; v < num_vars(); ++v)
out << (is_true(literal(v, false)) ? "" : "-") << v << " rw: " << get_reward(v) << "\n";
out << (is_true(literal(v, false)) ? "" : "-") << v << " rw: " << reward(v) << "\n";
out << "unsat vars: ";
for (bool_var v : m_unsat_vars)
out << v << " ";

View file

@ -123,7 +123,7 @@ namespace sat {
inline bool value(bool_var v) const { return m_vars[v].m_value; }
inline double& reward(bool_var v) { return m_vars[v].m_reward; }
// inline double reward(bool_var v) { return m_vars[v].m_reward; }
unsigned value_hash() const;
@ -150,9 +150,9 @@ namespace sat {
if (--make_count(v) == 0) m_unsat_vars.remove(v);
}
inline void inc_reward(literal lit, double w) { reward(lit.var()) += w; }
inline void inc_reward(literal lit, double w) { m_vars[lit.var()].m_reward += w; }
inline void dec_reward(literal lit, double w) { reward(lit.var()) -= w; }
inline void dec_reward(literal lit, double w) { m_vars[lit.var()].m_reward -= w; }
void check_with_plugin();
void check_without_plugin();
@ -201,6 +201,9 @@ namespace sat {
inline bool disregard_neighbor();
bool_var_set m_rotate_tabu;
bool_var_vector m_new_tabu_vars;
public:
ddfw() {}
@ -248,7 +251,9 @@ namespace sat {
void flip(bool_var v);
inline double get_reward(bool_var v) const { return m_vars[v].m_reward; }
inline double reward(bool_var v) const { return m_vars[v].m_reward; }
void set_reward(bool_var v, double r) { m_vars[v].m_reward = r; }
double get_reward_avg(bool_var v) const { return m_vars[v].m_reward_avg; }
@ -268,6 +273,7 @@ namespace sat {
void simplify();
bool try_rotate(bool_var v, bool_var_set& rotated, unsigned& budget);
ptr_iterator<unsigned> use_list(literal lit) {
flatten_use_list();

View file

@ -95,7 +95,7 @@ namespace sls {
for (unsigned i = 0; i < sz; ++i)
add_updates(vars[(start + i) % sz]);
CTRACE("bv", !m_best_expr, tout << "no guided move\n";);
return apply_update(m_best_expr, m_best_value, "increasing move");
return apply_update(m_last_atom, m_best_expr, m_best_value, "increasing move");
}
/**
@ -117,7 +117,7 @@ namespace sls {
auto& v = wval(e);
m_v_updated.set_bw(v.bw);
v.get_variant(m_v_updated, m_ev.m_rand);
return apply_update(e, m_v_updated, "random update");
return apply_update(nullptr, e, m_v_updated, "random update");
}
/**
@ -153,7 +153,7 @@ namespace sls {
v.sub1(m_v_updated);
break;
}
return apply_update(e, m_v_updated, "random move");
return apply_update(nullptr, e, m_v_updated, "random move");
}
/**
@ -243,7 +243,7 @@ namespace sls {
auto& v = wval(e);
m_v_updated.set_bw(v.bw);
m_v_updated.set_zero();
apply_update(e, m_v_updated, "reset");
apply_update(nullptr, e, m_v_updated, "reset");
}
}
}
@ -517,20 +517,20 @@ namespace sls {
* The update is committed.
*/
bool bv_lookahead::apply_update(expr* e, bvect const& new_value, char const* reason) {
if (!e || m.is_bool(e) || !wval(e).can_set(new_value))
bool bv_lookahead::apply_update(expr* p, expr* t, bvect const& new_value, char const* reason) {
if (!t || m.is_bool(t) || !wval(t).can_set(new_value))
return false;
SASSERT(is_uninterp(e));
SASSERT(is_uninterp(t));
SASSERT(m_restore.empty());
if (bv.is_bv(e)) {
wval(e).eval = new_value;
VERIFY(wval(e).commit_eval_check_tabu());
if (bv.is_bv(t)) {
wval(t).eval = new_value;
VERIFY(wval(t).commit_eval_check_tabu());
}
insert_update_stack(e);
unsigned max_depth = get_depth(e);
insert_update_stack(t);
unsigned max_depth = get_depth(t);
for (unsigned depth = max_depth; depth <= max_depth; ++depth) {
for (unsigned i = 0; i < m_update_stack[depth].size(); ++i) {
auto e = m_update_stack[depth][i];
@ -553,11 +553,27 @@ namespace sls {
continue;
if (ctx.is_true(v) == v1)
continue;
if (!p || e == p)
continue;
TRACE("bv", tout << "updated truth value " << v << ": " << mk_bounded_pp(e, m) << "\n";);
#if 0
unsigned num_unsat = ctx.unsat().size();
TRACE("bv", tout << "update flip " << mk_bounded_pp(e, m) << "\n";);
auto r = ctx.reward(v);
auto lit = sat::literal(v, !ctx.is_true(v));
bool is_bv_lit = is_bv_literal(lit);
verbose_stream() << "flip " << is_bv_literal(lit) << " " << mk_bounded_pp(e, m) << " " << lit << " " << r << " num unsat " << ctx.unsat().size() << "\n";
ctx.flip(v);
if (num_unsat < ctx.unsat().size())
verbose_stream() << "new unsat " << ctx.unsat().size() << "\n";
if (num_unsat < ctx.unsat().size()) {
verbose_stream() << "flip back\n";
ctx.flip(v);
}
#endif
}
m_ev.set_bool_value(to_app(e), v1);
}
@ -573,7 +589,7 @@ namespace sls {
}
m_in_update_stack.reset();
m_ev.clear_bool_values();
TRACE("bv", tout << reason << " " << mk_bounded_pp(e, m)
TRACE("bv", tout << reason << " " << mk_bounded_pp(t, m)
<< " := " << new_value
<< " score " << m_top_score << "\n";);
return true;

View file

@ -112,7 +112,7 @@ namespace sls {
void try_set(expr* u, bvect const& new_value);
void try_flip(expr* u);
void add_updates(expr* u);
bool apply_update(expr* e, bvect const& new_value, char const* reason);
bool apply_update(expr* p, expr* t, bvect const& new_value, char const* reason);
bool apply_random_move(ptr_vector<expr> const& vars);
bool apply_guided_move(ptr_vector<expr> const& vars);
bool apply_random_update(ptr_vector<expr> const& vars);

View file

@ -148,7 +148,7 @@ namespace sls {
void flip(sat::bool_var v) override {
m_ddfw->flip(v);
}
double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); }
double reward(sat::bool_var v) override { return m_ddfw->reward(v); }
double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; }
bool is_true(sat::literal lit) override {
return m_ddfw->get_value(lit.var()) != lit.sign();

View file

@ -85,7 +85,7 @@ namespace sls {
sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw.get_clause_info(idx); }
ptr_iterator<unsigned> get_use_list(sat::literal lit) override { return m_ddfw.use_list(lit); }
void flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; m_ddfw.flip(v); }
double reward(sat::bool_var v) override { return m_ddfw.get_reward(v); }
double reward(sat::bool_var v) override { return m_ddfw.reward(v); }
double get_weigth(unsigned clause_idx) override { return m_ddfw.get_clause_info(clause_idx).m_weight; }
bool is_true(sat::literal lit) override { return m_ddfw.get_value(lit.var()) != lit.sign(); }
unsigned num_vars() const override { return m_ddfw.num_vars(); }

View file

@ -48,7 +48,7 @@ namespace sat {
m_ddfw.add_assumptions();
for (unsigned v = 0; v < phase.size(); ++v) {
m_ddfw.value(v) = phase[v];
m_ddfw.reward(v) = 0;
m_ddfw.set_reward(v, 0);
m_ddfw.make_count(v) = 0;
}
m_ddfw.init_clause_data();

View file

@ -77,7 +77,7 @@ namespace sat {
void flip(bool_var v) { m_ddfw.flip(v); }
inline double get_reward(bool_var v) const { return m_ddfw.get_reward(v); }
inline double get_reward(bool_var v) const { return m_ddfw.reward(v); }
void add(unsigned sz, literal const* c) { m_ddfw.add(sz, c); }