3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

Track shared variables using a unit set

This commit is contained in:
Nikolaj Bjorner 2024-10-20 17:54:44 -07:00
parent 59b0e46d99
commit 185ddd6488
2 changed files with 25 additions and 14 deletions

View file

@ -69,6 +69,8 @@ namespace sls {
register_plugin(alloc(array_plugin, *this));
else if (fid == datatype_util(m).get_family_id())
register_plugin(alloc(datatype_plugin, *this));
else if (fid == null_family_id)
;
else
verbose_stream() << "did not find plugin for " << fid << "\n";
}
@ -242,7 +244,7 @@ namespace sls {
fid = to_app(e)->get_arg(0)->get_sort()->get_family_id();
if (m.is_distinct(e))
fid = to_app(e)->get_arg(0)->get_sort()->get_family_id();
if (fid == null_family_id || fid == model_value_family_id)
if ((fid == null_family_id && to_app(e)->get_num_args() > 0) || fid == model_value_family_id)
fid = user_sort_family_id;
return fid;
}
@ -489,9 +491,9 @@ namespace sls {
m_unit_indices.insert(lit.index());
verbose_stream() << "UNITS " << m_unit_literals << "\n";
for (auto a : m_atoms)
if (a)
register_terms(a);
for (unsigned i = 0; i < m_atoms.size(); ++i)
if (m_atoms.get(i))
register_terms(m_atoms.get(i));
for (auto p : m_plugins)
if (p)
p->initialize();

View file

@ -44,8 +44,7 @@ namespace sls {
solver& s;
sat::ddfw* m_ddfw;
sls::context m_context;
bool m_new_clause_added = false;
unsigned m_num_shared_vars = 0;
bool m_new_clause_added = false;
unsigned m_min_unsat_size = UINT_MAX;
ast_manager m_sync_manager;
obj_map<expr, expr*> m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp
@ -65,7 +64,7 @@ namespace sls {
std::lock_guard<std::mutex> lock(s.m_mutex);
IF_VERBOSE(1, verbose_stream() << "SMT -> SLS units " << s.m_units << "\n");
for (auto lit : s.m_units)
if (lit.var() < m_num_shared_vars)
if (m_shared_vars.contains(lit.var()))
m_ddfw->add(1, &lit);
s.m_has_units = false;
s.m_units.reset();
@ -77,7 +76,7 @@ namespace sls {
return false;
std::lock_guard<std::mutex> lock(s.m_mutex);
IF_VERBOSE(1, verbose_stream() << "SMT -> SLS phase\n");
for (unsigned i = 0; i < m_sat_phase.size(); ++i) {
for (auto i : m_shared_vars) {
if (m_sat_phase[i] != is_true(sat::literal(i, false)))
flip(i);
m_ddfw->bias(i) = m_sat_phase[i] ? 1 : -1;
@ -88,7 +87,7 @@ namespace sls {
bool export_to_sls() {
bool updated = false;
if (export_units_to_sls())
if (false && export_units_to_sls())
updated = true;
if (export_phase_to_sls())
updated = true;
@ -103,8 +102,8 @@ namespace sls {
if (unsat().size() > m_min_unsat_size)
return;
m_min_unsat_size = unsat().size();
std::lock_guard<std::mutex> lock(s.m_mutex);
for (unsigned v = 0; v < m_num_shared_vars; ++v) {
std::lock_guard<std::mutex> lock(s.m_mutex);
for (auto v : m_shared_vars) {
m_rewards[v] = m_ddfw->get_reward_avg(v);
m_sls_phase[v] = l_true == m_ddfw->get_model()[v];
m_has_new_sls_phase = true;
@ -128,8 +127,10 @@ namespace sls {
smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) :
m(m), s(s), m_ddfw(d), m_context(m, *this),
m_sync_uninterp(m_sync_manager),
m_sync_values(m_sync_manager) {}
m_sync_values(m_sync_manager) {
}
uint_set m_shared_vars;
svector<bool> m_sat_phase;
std::atomic<bool> m_has_new_sat_phase = false;
@ -148,6 +149,11 @@ namespace sls {
m_sls2sync_uninterp.insert(sls_t, sync_t);
}
void add_shared_var(sat::bool_var v) {
m_sls_phase.reserve(v + 1);
m_rewards.reserve(v + 1);
m_shared_vars.insert(v);
}
void init_search() override {}
@ -301,9 +307,12 @@ namespace sls {
m_smt_plugin = alloc(smt_plugin, *m_slsm, *this, m_ddfw.get());
m_ddfw->set_plugin(m_smt_plugin);
m_ddfw->updt_params(s().params());
for (auto const& clause : ctx.top_level_clauses())
for (auto const& clause : ctx.top_level_clauses()) {
m_ddfw->add(clause.size(), clause.data());
for (sat::bool_var v = 0; v < s().num_vars(); ++v) {
for (auto lit : clause)
m_smt_plugin->add_shared_var(lit.var());
}
for (auto v : m_smt_plugin->m_shared_vars) {
expr* e = ctx.bool_var2expr(v);
if (!e)
continue;