diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index ad5ed75b3..4dcfa8bbb 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -69,6 +69,8 @@ namespace sls { register_plugin(alloc(array_plugin, *this)); else if (fid == datatype_util(m).get_family_id()) register_plugin(alloc(datatype_plugin, *this)); + else if (fid == null_family_id) + ; else verbose_stream() << "did not find plugin for " << fid << "\n"; } @@ -242,7 +244,7 @@ namespace sls { fid = to_app(e)->get_arg(0)->get_sort()->get_family_id(); if (m.is_distinct(e)) fid = to_app(e)->get_arg(0)->get_sort()->get_family_id(); - if (fid == null_family_id || fid == model_value_family_id) + if ((fid == null_family_id && to_app(e)->get_num_args() > 0) || fid == model_value_family_id) fid = user_sort_family_id; return fid; } @@ -489,9 +491,9 @@ namespace sls { m_unit_indices.insert(lit.index()); verbose_stream() << "UNITS " << m_unit_literals << "\n"; - for (auto a : m_atoms) - if (a) - register_terms(a); + for (unsigned i = 0; i < m_atoms.size(); ++i) + if (m_atoms.get(i)) + register_terms(m_atoms.get(i)); for (auto p : m_plugins) if (p) p->initialize(); diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index bd83552ac..1c5dbd79b 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -44,8 +44,7 @@ namespace sls { solver& s; sat::ddfw* m_ddfw; sls::context m_context; - bool m_new_clause_added = false; - unsigned m_num_shared_vars = 0; + bool m_new_clause_added = false; unsigned m_min_unsat_size = UINT_MAX; ast_manager m_sync_manager; obj_map m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp @@ -65,7 +64,7 @@ namespace sls { std::lock_guard lock(s.m_mutex); IF_VERBOSE(1, verbose_stream() << "SMT -> SLS units " << s.m_units << "\n"); for (auto lit : s.m_units) - if (lit.var() < m_num_shared_vars) + if (m_shared_vars.contains(lit.var())) m_ddfw->add(1, &lit); s.m_has_units = false; s.m_units.reset(); @@ -77,7 +76,7 @@ namespace sls { return false; std::lock_guard lock(s.m_mutex); IF_VERBOSE(1, verbose_stream() << "SMT -> SLS phase\n"); - for (unsigned i = 0; i < m_sat_phase.size(); ++i) { + for (auto i : m_shared_vars) { if (m_sat_phase[i] != is_true(sat::literal(i, false))) flip(i); m_ddfw->bias(i) = m_sat_phase[i] ? 1 : -1; @@ -88,7 +87,7 @@ namespace sls { bool export_to_sls() { bool updated = false; - if (export_units_to_sls()) + if (false && export_units_to_sls()) updated = true; if (export_phase_to_sls()) updated = true; @@ -103,8 +102,8 @@ namespace sls { if (unsat().size() > m_min_unsat_size) return; m_min_unsat_size = unsat().size(); - std::lock_guard lock(s.m_mutex); - for (unsigned v = 0; v < m_num_shared_vars; ++v) { + std::lock_guard lock(s.m_mutex); + for (auto v : m_shared_vars) { m_rewards[v] = m_ddfw->get_reward_avg(v); m_sls_phase[v] = l_true == m_ddfw->get_model()[v]; m_has_new_sls_phase = true; @@ -128,8 +127,10 @@ namespace sls { smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) : m(m), s(s), m_ddfw(d), m_context(m, *this), m_sync_uninterp(m_sync_manager), - m_sync_values(m_sync_manager) {} + m_sync_values(m_sync_manager) { + } + uint_set m_shared_vars; svector m_sat_phase; std::atomic m_has_new_sat_phase = false; @@ -148,6 +149,11 @@ namespace sls { m_sls2sync_uninterp.insert(sls_t, sync_t); } + void add_shared_var(sat::bool_var v) { + m_sls_phase.reserve(v + 1); + m_rewards.reserve(v + 1); + m_shared_vars.insert(v); + } void init_search() override {} @@ -301,9 +307,12 @@ namespace sls { m_smt_plugin = alloc(smt_plugin, *m_slsm, *this, m_ddfw.get()); m_ddfw->set_plugin(m_smt_plugin); m_ddfw->updt_params(s().params()); - for (auto const& clause : ctx.top_level_clauses()) + for (auto const& clause : ctx.top_level_clauses()) { m_ddfw->add(clause.size(), clause.data()); - for (sat::bool_var v = 0; v < s().num_vars(); ++v) { + for (auto lit : clause) + m_smt_plugin->add_shared_var(lit.var()); + } + for (auto v : m_smt_plugin->m_shared_vars) { expr* e = ctx.bool_var2expr(v); if (!e) continue;