From 2b3d2ea055fe37a30173fc1e423e2b9763e6c168 Mon Sep 17 00:00:00 2001 From: Lev Nachmanson Date: Wed, 2 Jul 2025 18:14:29 -0700 Subject: [PATCH] another try to avoid the race with atomic swap Signed-off-by: Lev Nachmanson --- src/smt/tactic/smt_tactic_core.cpp | 123 ++++++++++++++++------------- 1 file changed, 67 insertions(+), 56 deletions(-) diff --git a/src/smt/tactic/smt_tactic_core.cpp b/src/smt/tactic/smt_tactic_core.cpp index 89a6963b5..2dd82106d 100644 --- a/src/smt/tactic/smt_tactic_core.cpp +++ b/src/smt/tactic/smt_tactic_core.cpp @@ -44,8 +44,7 @@ class smt_tactic : public tactic { expr_ref_vector m_vars; vector> m_values; statistics m_stats; - smt::kernel* m_ctx = nullptr; - mutable std::atomic m_ctx_destroying; + mutable std::atomic m_ctx; symbol m_logic; progress_callback* m_callback = nullptr; bool m_candidate_models = false; @@ -56,7 +55,7 @@ public: m(m), m_params_ref(p), m_vars(m), - m_ctx_destroying(false) { + m_ctx(nullptr) { updt_params_core(p); TRACE(smt_tactic, tout << "p: " << p << "\n";); } @@ -66,7 +65,7 @@ public: } ~smt_tactic() override { - SASSERT(m_ctx == nullptr); + SASSERT(m_ctx.load() == nullptr); } char const* name() const override { return "smt"; } @@ -91,8 +90,8 @@ public: fparams().updt_params(p); m_params_ref.copy(p); m_logic = p.get_sym(symbol("logic"), m_logic); - if (m_logic != symbol::null && m_ctx) { - m_ctx->set_logic(m_logic); + if (m_logic != symbol::null && m_ctx.load()) { + m_ctx.load()->set_logic(m_logic); } SASSERT(p.get_bool("auto_config", fparams().m_auto_config) == fparams().m_auto_config); } @@ -104,15 +103,25 @@ public: void collect_statistics(statistics & st) const override { - // Check if context is being destroyed to avoid race condition - if (!m_ctx_destroying.load() && m_ctx) { - try { - m_ctx->collect_statistics(st); + // Use compare_exchange to safely extract the context pointer + // If we win the race, we're responsible for deallocation + smt::kernel* ctx = m_ctx.load(); + if (ctx) { + // Try to atomically swap the context with nullptr + // If successful, we own the context and must deallocate it + if (m_ctx.compare_exchange_strong(ctx, nullptr)) { + try { + ctx->collect_statistics(st); + } catch (...) { + // If exception occurs, fall back to cached stats + st.copy(m_stats); + } + // We won the race, so we're responsible for cleanup + dealloc(ctx); return; - } catch (...) { - // If exception occurs, fall back to cached stats } } + // Either no context or lost the race - use cached stats st.copy(m_stats); } @@ -131,6 +140,12 @@ public: m_callback = callback; } +private: + // Helper to safely get context pointer (only use within scoped_init_ctx lifetime) + smt::kernel* get_ctx() const { return m_ctx.load(); } + +public: + struct scoped_init_ctx { smt_tactic & m_owner; smt_params m_params; // smt-setup overwrites parameters depending on the current assertions. @@ -146,22 +161,18 @@ public: if (o.m_callback) { new_ctx->set_progress_callback(o.m_callback); } - o.m_ctx = new_ctx; + o.m_ctx.store(new_ctx); } ~scoped_init_ctx() { - // Signal that we're destroying the context - m_owner.m_ctx_destroying.store(true); - - smt::kernel * d = m_owner.m_ctx; - m_owner.m_ctx = nullptr; - m_owner.m_user_ctx = nullptr; - - if (d) - dealloc(d); - - // Reset the flag after cleanup - m_owner.m_ctx_destroying.store(false); + // Try to atomically swap the context with nullptr + smt::kernel* ctx = m_owner.m_ctx.load(); + if (ctx && m_owner.m_ctx.compare_exchange_strong(ctx, nullptr)) { + // We won the race, so we're responsible for cleanup + m_owner.m_user_ctx = nullptr; + dealloc(ctx); + } + // If we lost the race, collect_statistics() is handling the cleanup } }; @@ -184,7 +195,7 @@ public: TRACE(smt_tactic_detail, in->display(tout);); TRACE(smt_tactic_memory, tout << "wasted_size: " << m.get_allocator().get_wasted_size() << "\n";); scoped_init_ctx init(*this, m); - SASSERT(m_ctx); + SASSERT(m_ctx.load()); expr_ref_vector clauses(m); expr2expr_map bool2dep; @@ -197,22 +208,22 @@ public: if (in->proofs_enabled() && !assumptions.empty()) throw tactic_exception("smt tactic does not support simultaneous generation of proofs and unsat cores"); for (unsigned i = 0; i < clauses.size(); ++i) { - m_ctx->assert_expr(clauses[i].get()); + get_ctx()->assert_expr(clauses[i].get()); } } else if (in->proofs_enabled()) { unsigned sz = in->size(); for (unsigned i = 0; i < sz; i++) { - m_ctx->assert_expr(in->form(i), in->pr(i)); + get_ctx()->assert_expr(in->form(i), in->pr(i)); } } else { unsigned sz = in->size(); for (unsigned i = 0; i < sz; i++) { - m_ctx->assert_expr(in->form(i)); + get_ctx()->assert_expr(in->form(i)); } } - if (m_ctx->canceled()) { + if (get_ctx()->canceled()) { throw tactic_exception(Z3_CANCELED_MSG); } user_propagate_delay_init(); @@ -220,18 +231,18 @@ public: lbool r; try { if (assumptions.empty() && !m_user_ctx) - r = m_ctx->setup_and_check(); + r = get_ctx()->setup_and_check(); else - r = m_ctx->check(assumptions.size(), assumptions.data()); + r = get_ctx()->check(assumptions.size(), assumptions.data()); } catch(...) { TRACE(smt_tactic, tout << "exception\n";); - m_ctx->collect_statistics(m_stats); + get_ctx()->collect_statistics(m_stats); throw; } SASSERT(m_ctx); - m_ctx->collect_statistics(m_stats); - proof_ref pr(m_ctx->get_proof(), m); + get_ctx()->collect_statistics(m_stats); + proof_ref pr(get_ctx()->get_proof(), m); TRACE(smt_tactic, tout << r << " " << pr << "\n";); switch (r) { case l_true: { @@ -243,9 +254,9 @@ public: // store the model in a no-op model converter, and filter fresh Booleans if (in->models_enabled()) { model_ref md; - m_ctx->get_model(md); + get_ctx()->get_model(md); buffer r; - m_ctx->get_relevant_labels(nullptr, r); + get_ctx()->get_relevant_labels(nullptr, r); labels_vec rv; rv.append(r.size(), r.data()); model_converter_ref mc; @@ -253,7 +264,7 @@ public: mc = concat(fmc.get(), mc.get()); in->add(mc.get()); } - if (m_ctx->canceled()) + if (get_ctx()->canceled()) throw tactic_exception(Z3_CANCELED_MSG); return; } @@ -266,9 +277,9 @@ public: in->reset(); expr_dependency * lcore = nullptr; if (in->unsat_core_enabled()) { - unsigned sz = m_ctx->get_unsat_core_size(); + unsigned sz = get_ctx()->get_unsat_core_size(); for (unsigned i = 0; i < sz; i++) { - expr * b = m_ctx->get_unsat_core_expr(i); + expr * b = get_ctx()->get_unsat_core_expr(i); SASSERT(is_uninterp_const(b) && m.is_bool(b)); expr * d = bool2dep.find(b); lcore = m.mk_join(lcore, m.mk_leaf(d)); @@ -284,13 +295,13 @@ public: } case l_undef: - if (m_ctx->canceled() && !pr) { + if (get_ctx()->canceled() && !pr) { throw tactic_exception(Z3_CANCELED_MSG); } if (m_fail_if_inconclusive && !m_candidate_models && !pr) { std::stringstream strm; - strm << "smt tactic failed to show goal to be sat/unsat " << m_ctx->last_failure_as_string(); + strm << "smt tactic failed to show goal to be sat/unsat " << get_ctx()->last_failure_as_string(); throw tactic_exception(strm.str()); } result.push_back(in.get()); @@ -300,15 +311,15 @@ public: in->updt_prec(goal::UNDER_OVER); } if (m_candidate_models) { - switch (m_ctx->last_failure()) { + switch (get_ctx()->last_failure()) { case smt::NUM_CONFLICTS: case smt::THEORY: case smt::QUANTIFIERS: if (in->models_enabled()) { model_ref md; - m_ctx->get_model(md); + get_ctx()->get_model(md); buffer r; - m_ctx->get_relevant_labels(nullptr, r); + get_ctx()->get_relevant_labels(nullptr, r); labels_vec rv; rv.append(r.size(), r.data()); in->add(model_and_labels2model_converter(md.get(), rv)); @@ -321,7 +332,7 @@ public: if (pr) { return; } - throw tactic_exception(m_ctx->last_failure_as_string()); + throw tactic_exception(get_ctx()->last_failure_as_string()); } } catch (rewriter_exception & ex) { @@ -344,24 +355,24 @@ public: void on_clause_delay_init() { if (m_on_clause_eh) - m_ctx->register_on_clause(m_on_clause_ctx, m_on_clause_eh); + get_ctx()->register_on_clause(m_on_clause_ctx, m_on_clause_eh); } void user_propagate_delay_init() { if (!m_user_ctx) return; - m_ctx->user_propagate_init(m_user_ctx, m_push_eh, m_pop_eh, m_fresh_eh); - if (m_fixed_eh) m_ctx->user_propagate_register_fixed(m_fixed_eh); - if (m_final_eh) m_ctx->user_propagate_register_final(m_final_eh); - if (m_eq_eh) m_ctx->user_propagate_register_eq(m_eq_eh); - if (m_diseq_eh) m_ctx->user_propagate_register_diseq(m_diseq_eh); - if (m_created_eh) m_ctx->user_propagate_register_created(m_created_eh); - if (m_decide_eh) m_ctx->user_propagate_register_decide(m_decide_eh); + get_ctx()->user_propagate_init(m_user_ctx, m_push_eh, m_pop_eh, m_fresh_eh); + if (m_fixed_eh) get_ctx()->user_propagate_register_fixed(m_fixed_eh); + if (m_final_eh) get_ctx()->user_propagate_register_final(m_final_eh); + if (m_eq_eh) get_ctx()->user_propagate_register_eq(m_eq_eh); + if (m_diseq_eh) get_ctx()->user_propagate_register_diseq(m_diseq_eh); + if (m_created_eh) get_ctx()->user_propagate_register_created(m_created_eh); + if (m_decide_eh) get_ctx()->user_propagate_register_decide(m_decide_eh); for (expr* v : m_vars) - m_ctx->user_propagate_register_expr(v); + get_ctx()->user_propagate_register_expr(v); for (auto& [var, value] : m_values) - m_ctx->user_propagate_initialize_value(var, value); + get_ctx()->user_propagate_initialize_value(var, value); } void user_propagate_clear() override {