diff --git a/src/smt/tactic/smt_tactic_core.cpp b/src/smt/tactic/smt_tactic_core.cpp index 7d01ce2ce..bbaf99e4d 100644 --- a/src/smt/tactic/smt_tactic_core.cpp +++ b/src/smt/tactic/smt_tactic_core.cpp @@ -32,6 +32,7 @@ Notes: #include "solver/mus.h" #include "solver/parallel_tactical.h" #include "solver/parallel_params.hpp" +#include typedef obj_map expr2expr_map; @@ -43,11 +44,12 @@ class smt_tactic : public tactic { expr_ref_vector m_vars; vector> m_values; statistics m_stats; - smt::kernel* m_ctx = nullptr; + std::atomic m_ctx = nullptr; symbol m_logic; progress_callback* m_callback = nullptr; bool m_candidate_models = false; bool m_fail_if_inconclusive = false; + mutable std::mutex m_mutex; public: smt_tactic(ast_manager& m, params_ref const & p): @@ -63,7 +65,7 @@ public: } ~smt_tactic() override { - SASSERT(m_ctx == nullptr); + SASSERT(m_ctx.load() == nullptr); } char const* name() const override { return "smt"; } @@ -88,8 +90,8 @@ public: fparams().updt_params(p); m_params_ref.copy(p); m_logic = p.get_sym(symbol("logic"), m_logic); - if (m_logic != symbol::null && m_ctx) { - m_ctx->set_logic(m_logic); + if (m_logic != symbol::null && m_ctx.load()) { + m_ctx.load()->set_logic(m_logic); } SASSERT(p.get_bool("auto_config", fparams().m_auto_config) == fparams().m_auto_config); } @@ -99,12 +101,15 @@ public: smt_params_helper::collect_param_descrs(r); } - void collect_statistics(statistics & st) const override { - if (m_ctx) - m_ctx->collect_statistics(st); // ctx is still running... - else - st.copy(m_stats); + if (m_ctx.load()) { + std::scoped_lock lock(m_mutex); + if (m_ctx.load()) { + m_ctx.load()->collect_statistics(st); // ctx is still running... + return; + } + } + st.copy(m_stats); } void cleanup() override { @@ -141,10 +146,13 @@ public: } ~scoped_init_ctx() { - smt::kernel * d = m_owner.m_ctx; - m_owner.m_ctx = nullptr; - m_owner.m_user_ctx = nullptr; - + smt::kernel* d = nullptr; + { + std::scoped_lock lock(m_owner.m_mutex); + d = m_owner.m_ctx.load(); + m_owner.m_ctx = nullptr; + m_owner.m_user_ctx = nullptr; + } if (d) dealloc(d); } @@ -169,7 +177,7 @@ public: TRACE(smt_tactic_detail, in->display(tout);); TRACE(smt_tactic_memory, tout << "wasted_size: " << m.get_allocator().get_wasted_size() << "\n";); scoped_init_ctx init(*this, m); - SASSERT(m_ctx); + SASSERT(m_ctx.load()); expr_ref_vector clauses(m); expr2expr_map bool2dep; @@ -182,22 +190,22 @@ public: if (in->proofs_enabled() && !assumptions.empty()) throw tactic_exception("smt tactic does not support simultaneous generation of proofs and unsat cores"); for (unsigned i = 0; i < clauses.size(); ++i) { - m_ctx->assert_expr(clauses[i].get()); + m_ctx.load()->assert_expr(clauses[i].get()); } } else if (in->proofs_enabled()) { unsigned sz = in->size(); for (unsigned i = 0; i < sz; i++) { - m_ctx->assert_expr(in->form(i), in->pr(i)); + m_ctx.load()->assert_expr(in->form(i), in->pr(i)); } } else { unsigned sz = in->size(); for (unsigned i = 0; i < sz; i++) { - m_ctx->assert_expr(in->form(i)); + m_ctx.load()->assert_expr(in->form(i)); } } - if (m_ctx->canceled()) { + if (m_ctx.load()->canceled()) { throw tactic_exception(Z3_CANCELED_MSG); } user_propagate_delay_init(); @@ -205,18 +213,17 @@ public: lbool r; try { if (assumptions.empty() && !m_user_ctx) - r = m_ctx->setup_and_check(); + r = m_ctx.load()->setup_and_check(); else - r = m_ctx->check(assumptions.size(), assumptions.data()); + r = m_ctx.load()->check(assumptions.size(), assumptions.data()); } catch(...) { TRACE(smt_tactic, tout << "exception\n";); - m_ctx->collect_statistics(m_stats); + m_ctx.load()->collect_statistics(m_stats); throw; } - SASSERT(m_ctx); - m_ctx->collect_statistics(m_stats); - proof_ref pr(m_ctx->get_proof(), m); + m_ctx.load()->collect_statistics(m_stats); + proof_ref pr(m_ctx.load()->get_proof(), m); TRACE(smt_tactic, tout << r << " " << pr << "\n";); switch (r) { case l_true: { @@ -228,9 +235,9 @@ public: // store the model in a no-op model converter, and filter fresh Booleans if (in->models_enabled()) { model_ref md; - m_ctx->get_model(md); + m_ctx.load()->get_model(md); buffer r; - m_ctx->get_relevant_labels(nullptr, r); + m_ctx.load()->get_relevant_labels(nullptr, r); labels_vec rv; rv.append(r.size(), r.data()); model_converter_ref mc; @@ -238,7 +245,7 @@ public: mc = concat(fmc.get(), mc.get()); in->add(mc.get()); } - if (m_ctx->canceled()) + if (m_ctx.load()->canceled()) throw tactic_exception(Z3_CANCELED_MSG); return; } @@ -251,9 +258,9 @@ public: in->reset(); expr_dependency * lcore = nullptr; if (in->unsat_core_enabled()) { - unsigned sz = m_ctx->get_unsat_core_size(); + unsigned sz = m_ctx.load()->get_unsat_core_size(); for (unsigned i = 0; i < sz; i++) { - expr * b = m_ctx->get_unsat_core_expr(i); + expr * b = m_ctx.load()->get_unsat_core_expr(i); SASSERT(is_uninterp_const(b) && m.is_bool(b)); expr * d = bool2dep.find(b); lcore = m.mk_join(lcore, m.mk_leaf(d)); @@ -269,13 +276,13 @@ public: } case l_undef: - if (m_ctx->canceled() && !pr) { + if (m_ctx.load()->canceled() && !pr) { throw tactic_exception(Z3_CANCELED_MSG); } if (m_fail_if_inconclusive && !m_candidate_models && !pr) { std::stringstream strm; - strm << "smt tactic failed to show goal to be sat/unsat " << m_ctx->last_failure_as_string(); + strm << "smt tactic failed to show goal to be sat/unsat " << m_ctx.load()->last_failure_as_string(); throw tactic_exception(strm.str()); } result.push_back(in.get()); @@ -285,15 +292,15 @@ public: in->updt_prec(goal::UNDER_OVER); } if (m_candidate_models) { - switch (m_ctx->last_failure()) { + switch (m_ctx.load()->last_failure()) { case smt::NUM_CONFLICTS: case smt::THEORY: case smt::QUANTIFIERS: if (in->models_enabled()) { model_ref md; - m_ctx->get_model(md); + m_ctx.load()->get_model(md); buffer r; - m_ctx->get_relevant_labels(nullptr, r); + m_ctx.load()->get_relevant_labels(nullptr, r); labels_vec rv; rv.append(r.size(), r.data()); in->add(model_and_labels2model_converter(md.get(), rv)); @@ -306,7 +313,7 @@ public: if (pr) { return; } - throw tactic_exception(m_ctx->last_failure_as_string()); + throw tactic_exception(m_ctx.load()->last_failure_as_string()); } } catch (rewriter_exception & ex) { @@ -329,24 +336,24 @@ public: void on_clause_delay_init() { if (m_on_clause_eh) - m_ctx->register_on_clause(m_on_clause_ctx, m_on_clause_eh); + m_ctx.load()->register_on_clause(m_on_clause_ctx, m_on_clause_eh); } void user_propagate_delay_init() { if (!m_user_ctx) return; - m_ctx->user_propagate_init(m_user_ctx, m_push_eh, m_pop_eh, m_fresh_eh); - if (m_fixed_eh) m_ctx->user_propagate_register_fixed(m_fixed_eh); - if (m_final_eh) m_ctx->user_propagate_register_final(m_final_eh); - if (m_eq_eh) m_ctx->user_propagate_register_eq(m_eq_eh); - if (m_diseq_eh) m_ctx->user_propagate_register_diseq(m_diseq_eh); - if (m_created_eh) m_ctx->user_propagate_register_created(m_created_eh); - if (m_decide_eh) m_ctx->user_propagate_register_decide(m_decide_eh); + m_ctx.load()->user_propagate_init(m_user_ctx, m_push_eh, m_pop_eh, m_fresh_eh); + if (m_fixed_eh) m_ctx.load()->user_propagate_register_fixed(m_fixed_eh); + if (m_final_eh) m_ctx.load()->user_propagate_register_final(m_final_eh); + if (m_eq_eh) m_ctx.load()->user_propagate_register_eq(m_eq_eh); + if (m_diseq_eh) m_ctx.load()->user_propagate_register_diseq(m_diseq_eh); + if (m_created_eh) m_ctx.load()->user_propagate_register_created(m_created_eh); + if (m_decide_eh) m_ctx.load()->user_propagate_register_decide(m_decide_eh); for (expr* v : m_vars) - m_ctx->user_propagate_register_expr(v); + m_ctx.load()->user_propagate_register_expr(v); for (auto& [var, value] : m_values) - m_ctx->user_propagate_initialize_value(var, value); + m_ctx.load()->user_propagate_initialize_value(var, value); } void user_propagate_clear() override {