diff --git a/src/ast/sls/sls_smt_plugin.cpp b/src/ast/sls/sls_smt_plugin.cpp index a160f8ef4..67049d781 100644 --- a/src/ast/sls/sls_smt_plugin.cpp +++ b/src/ast/sls/sls_smt_plugin.cpp @@ -33,6 +33,7 @@ namespace sls { m_smt2sls_tr(m, m_sls), m_sls2sync_tr(m_sls, m_sync), m_sls2smt_tr(m_sls, m), + m_sync2sls_tr(m_sync, m_sls), m_sync_uninterp(m_sync), m_sls_uninterp(m_sls), m_sync_values(m_sync), @@ -47,7 +48,7 @@ namespace sls { void smt_plugin::check(expr_ref_vector const& fmls, vector const& clauses) { SASSERT(!m_ddfw); // set up state for local search theory_sls here - + m_result = l_undef; m_completed = false; m_units.reset(); @@ -225,6 +226,12 @@ namespace sls { m_has_new_sat_phase = false; updated = true; } + if (m_has_new_smt_values) { + std::lock_guard lock(m_mutex); + export_values_to_sls(); + m_has_new_smt_values = false; + updated = true; + } return updated; } @@ -240,6 +247,18 @@ namespace sls { #endif } + void smt_plugin::export_values_to_sls() { + IF_VERBOSE(2, verbose_stream() << "SMT -> SLS values\n"); + for (auto [var, value] : m_sync_var_values) { + expr_ref var1(m_sls), value1(m_sls); + var1 = m_sync2sls_tr(var.get()); + value1 = m_sync2sls_tr(value.get()); + if (!var1 || !value1) + continue; + m_context.set_value(var1, value1); + } + } + void smt_plugin::smt_phase_to_sls() { #if 0 IF_VERBOSE(2, verbose_stream() << "SMT -> SLS phase\n"); @@ -254,6 +273,21 @@ namespace sls { } void smt_plugin::smt_values_to_sls() { + + if (ctx.parallel_mode()) { + std::scoped_lock lock(m_mutex); + m_sync_var_values.reset(); + for (auto const& [t, t_sync] : m_smt2sync_uninterp) { + expr_ref val_t(m); + if (!ctx.get_smt_value(t, val_t)) + continue; + auto t_sls = expr_ref(m_smt2sls_tr(t), m_sls); + auto val_sls = expr_ref(m_smt2sls_tr(val_t.get()), m_sls); + m_sync_var_values.push_back({ t_sls, val_sls }); + } + m_has_new_smt_values = true; + return; + } #if 0 if (m_value_smt2sls_delay < m_value_smt2sls_delay_threshold) { m_value_smt2sls_delay++; @@ -267,7 +301,7 @@ namespace sls { expr_ref val_t(m); if (!ctx.get_smt_value(t, val_t)) continue; - expr* t_sls = m_smt2sls_tr(t); + auto t_sls = expr_ref(m_smt2sls_tr(t), m_sls); auto val_sls = expr_ref(m_smt2sls_tr(val_t.get()), m_sls); m_context.set_value(t_sls, val_sls); } diff --git a/src/ast/sls/sls_smt_plugin.h b/src/ast/sls/sls_smt_plugin.h index d3213c3a5..c91b5c90f 100644 --- a/src/ast/sls/sls_smt_plugin.h +++ b/src/ast/sls/sls_smt_plugin.h @@ -53,7 +53,7 @@ namespace sls { ast_manager& m; ast_manager m_sls; ast_manager m_sync; - ast_translation m_smt2sync_tr, m_smt2sls_tr, m_sls2sync_tr, m_sls2smt_tr; + ast_translation m_smt2sync_tr, m_smt2sls_tr, m_sls2sync_tr, m_sls2smt_tr, m_sync2sls_tr; expr_ref_vector m_sync_uninterp, m_sls_uninterp; expr_ref_vector m_sync_values; sat::ddfw* m_ddfw = nullptr; @@ -73,11 +73,13 @@ namespace sls { unsigned m_min_unsat_size = UINT_MAX; obj_map m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp obj_map m_smt2sync_uninterp; // hashtable from external uninterp to sync uninterp + vector> m_sync_var_values; std::atomic m_has_new_sls_values = false; uint_set m_shared_bool_vars, m_shared_terms; svector m_sat_phase; std::atomic m_has_new_sat_phase = false; std::atomic m_has_new_sls_phase = false; + std::atomic m_has_new_smt_values = false; svector m_sls_phase; svector m_rewards; svector m_smt_bool_var2sls_bool_var, m_sls_bool_var2smt_bool_var; @@ -91,6 +93,7 @@ namespace sls { void import_phase_from_smt(); void import_values_from_sls(); + void export_values_to_sls(); void export_values_from_sls(); void export_phase_from_sls(); void import_activity_from_sls(); diff --git a/src/smt/theory_sls.cpp b/src/smt/theory_sls.cpp index 8a9e1a581..da2207e56 100644 --- a/src/smt/theory_sls.cpp +++ b/src/smt/theory_sls.cpp @@ -169,10 +169,14 @@ namespace smt { } void theory_sls::run_guided_sls() { + m_smt_plugin->smt_values_to_sls(); + if (m_parallel_mode) + return; + ++m_stats.m_num_guided_sls; m_smt_plugin->smt_phase_to_sls(); m_smt_plugin->smt_units_to_sls(); - m_smt_plugin->smt_values_to_sls(); + bounded_run(m_final_check_ls_steps); dec_final_check_ls_steps(); if (m_smt_plugin) { @@ -225,7 +229,7 @@ namespace smt { } final_check_status theory_sls::final_check_eh() { - if (m_parallel_mode || !m_smt_plugin) + if (!m_smt_plugin) return FC_DONE; ++m_after_resolve_decide_count; if (m_after_resolve_decide_gap > m_after_resolve_decide_count)