mirror of
https://github.com/Z3Prover/z3
synced 2025-04-22 16:45:31 +00:00
Track shared variables using a unit set
This commit is contained in:
parent
59b0e46d99
commit
185ddd6488
2 changed files with 25 additions and 14 deletions
|
@ -69,6 +69,8 @@ namespace sls {
|
|||
register_plugin(alloc(array_plugin, *this));
|
||||
else if (fid == datatype_util(m).get_family_id())
|
||||
register_plugin(alloc(datatype_plugin, *this));
|
||||
else if (fid == null_family_id)
|
||||
;
|
||||
else
|
||||
verbose_stream() << "did not find plugin for " << fid << "\n";
|
||||
}
|
||||
|
@ -242,7 +244,7 @@ namespace sls {
|
|||
fid = to_app(e)->get_arg(0)->get_sort()->get_family_id();
|
||||
if (m.is_distinct(e))
|
||||
fid = to_app(e)->get_arg(0)->get_sort()->get_family_id();
|
||||
if (fid == null_family_id || fid == model_value_family_id)
|
||||
if ((fid == null_family_id && to_app(e)->get_num_args() > 0) || fid == model_value_family_id)
|
||||
fid = user_sort_family_id;
|
||||
return fid;
|
||||
}
|
||||
|
@ -489,9 +491,9 @@ namespace sls {
|
|||
m_unit_indices.insert(lit.index());
|
||||
|
||||
verbose_stream() << "UNITS " << m_unit_literals << "\n";
|
||||
for (auto a : m_atoms)
|
||||
if (a)
|
||||
register_terms(a);
|
||||
for (unsigned i = 0; i < m_atoms.size(); ++i)
|
||||
if (m_atoms.get(i))
|
||||
register_terms(m_atoms.get(i));
|
||||
for (auto p : m_plugins)
|
||||
if (p)
|
||||
p->initialize();
|
||||
|
|
|
@ -44,8 +44,7 @@ namespace sls {
|
|||
solver& s;
|
||||
sat::ddfw* m_ddfw;
|
||||
sls::context m_context;
|
||||
bool m_new_clause_added = false;
|
||||
unsigned m_num_shared_vars = 0;
|
||||
bool m_new_clause_added = false;
|
||||
unsigned m_min_unsat_size = UINT_MAX;
|
||||
ast_manager m_sync_manager;
|
||||
obj_map<expr, expr*> m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp
|
||||
|
@ -65,7 +64,7 @@ namespace sls {
|
|||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
IF_VERBOSE(1, verbose_stream() << "SMT -> SLS units " << s.m_units << "\n");
|
||||
for (auto lit : s.m_units)
|
||||
if (lit.var() < m_num_shared_vars)
|
||||
if (m_shared_vars.contains(lit.var()))
|
||||
m_ddfw->add(1, &lit);
|
||||
s.m_has_units = false;
|
||||
s.m_units.reset();
|
||||
|
@ -77,7 +76,7 @@ namespace sls {
|
|||
return false;
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
IF_VERBOSE(1, verbose_stream() << "SMT -> SLS phase\n");
|
||||
for (unsigned i = 0; i < m_sat_phase.size(); ++i) {
|
||||
for (auto i : m_shared_vars) {
|
||||
if (m_sat_phase[i] != is_true(sat::literal(i, false)))
|
||||
flip(i);
|
||||
m_ddfw->bias(i) = m_sat_phase[i] ? 1 : -1;
|
||||
|
@ -88,7 +87,7 @@ namespace sls {
|
|||
|
||||
bool export_to_sls() {
|
||||
bool updated = false;
|
||||
if (export_units_to_sls())
|
||||
if (false && export_units_to_sls())
|
||||
updated = true;
|
||||
if (export_phase_to_sls())
|
||||
updated = true;
|
||||
|
@ -103,8 +102,8 @@ namespace sls {
|
|||
if (unsat().size() > m_min_unsat_size)
|
||||
return;
|
||||
m_min_unsat_size = unsat().size();
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
for (unsigned v = 0; v < m_num_shared_vars; ++v) {
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
for (auto v : m_shared_vars) {
|
||||
m_rewards[v] = m_ddfw->get_reward_avg(v);
|
||||
m_sls_phase[v] = l_true == m_ddfw->get_model()[v];
|
||||
m_has_new_sls_phase = true;
|
||||
|
@ -128,8 +127,10 @@ namespace sls {
|
|||
smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) :
|
||||
m(m), s(s), m_ddfw(d), m_context(m, *this),
|
||||
m_sync_uninterp(m_sync_manager),
|
||||
m_sync_values(m_sync_manager) {}
|
||||
m_sync_values(m_sync_manager) {
|
||||
}
|
||||
|
||||
uint_set m_shared_vars;
|
||||
svector<bool> m_sat_phase;
|
||||
std::atomic<bool> m_has_new_sat_phase = false;
|
||||
|
||||
|
@ -148,6 +149,11 @@ namespace sls {
|
|||
m_sls2sync_uninterp.insert(sls_t, sync_t);
|
||||
}
|
||||
|
||||
void add_shared_var(sat::bool_var v) {
|
||||
m_sls_phase.reserve(v + 1);
|
||||
m_rewards.reserve(v + 1);
|
||||
m_shared_vars.insert(v);
|
||||
}
|
||||
|
||||
void init_search() override {}
|
||||
|
||||
|
@ -301,9 +307,12 @@ namespace sls {
|
|||
m_smt_plugin = alloc(smt_plugin, *m_slsm, *this, m_ddfw.get());
|
||||
m_ddfw->set_plugin(m_smt_plugin);
|
||||
m_ddfw->updt_params(s().params());
|
||||
for (auto const& clause : ctx.top_level_clauses())
|
||||
for (auto const& clause : ctx.top_level_clauses()) {
|
||||
m_ddfw->add(clause.size(), clause.data());
|
||||
for (sat::bool_var v = 0; v < s().num_vars(); ++v) {
|
||||
for (auto lit : clause)
|
||||
m_smt_plugin->add_shared_var(lit.var());
|
||||
}
|
||||
for (auto v : m_smt_plugin->m_shared_vars) {
|
||||
expr* e = ctx.bool_var2expr(v);
|
||||
if (!e)
|
||||
continue;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue