diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index c0a2f3a4d..87878ab6e 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -3617,15 +3617,15 @@ namespace smt { \remark A logical context can only be configured at scope level 0, and before internalizing any formulas. */ - lbool context::setup_and_check(bool reset_cancel) { + lbool context::setup_and_check(bool reset_cancel, bool enable_parallel_param_tuning) { if (!check_preamble(reset_cancel)) return l_undef; SASSERT(m_scope_lvl == 0); SASSERT(!m_setup.already_configured()); setup_context(m_fparams.m_auto_config); if (m_fparams.m_threads > 1 && !m.has_trace_stream()) { - parallel p(*this); expr_ref_vector asms(m); + parallel p(*this, enable_parallel_param_tuning); return p(asms); } @@ -3685,14 +3685,15 @@ namespace smt { } } - lbool context::check(unsigned num_assumptions, expr * const * assumptions, bool reset_cancel) { + lbool context::check(unsigned num_assumptions, expr * const * assumptions, bool reset_cancel, bool enable_parallel_param_tuning) { if (!check_preamble(reset_cancel)) return l_undef; SASSERT(at_base_level()); setup_context(false); search_completion sc(*this); if (m_fparams.m_threads > 1 && !m.has_trace_stream()) { expr_ref_vector asms(m, num_assumptions, assumptions); - parallel p(*this); + IF_VERBOSE(1, verbose_stream() << "Starting parallel check with " << asms.size() << " assumptions and param tuning enabled: " << enable_parallel_param_tuning << "\n"); + parallel p(*this, enable_parallel_param_tuning); return p(asms); } lbool r = l_undef; diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 0e4cdde44..a363548e4 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -132,6 +132,7 @@ namespace smt { unsigned m_par_index = 0; bool m_internalizing_assertions = false; lbool m_internal_completed = l_undef; + bool m_in_parallel = false; scoped_ptr m_simplifier; scoped_ptr m_fmls; @@ -1689,7 +1690,7 @@ namespace smt { void pop(unsigned num_scopes); - lbool check(unsigned num_assumptions = 0, expr * const * assumptions = nullptr, bool reset_cancel = true); + lbool check(unsigned num_assumptions = 0, expr * const * assumptions = nullptr, bool reset_cancel = true, bool enable_parallel_param_tuning = true); lbool check(expr_ref_vector const& cube, vector const& clauses); @@ -1699,7 +1700,7 @@ namespace smt { lbool preferred_sat(expr_ref_vector const& asms, vector& cores); - lbool setup_and_check(bool reset_cancel = true); + lbool setup_and_check(bool reset_cancel = true, bool enable_parallel_param_tuning = true); void reduce_assertions(); diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index f6049405d..b88df8b8e 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -66,13 +66,16 @@ namespace smt { namespace smt { lbool parallel::param_generator::run_prefix_step() { - IF_VERBOSE(1, verbose_stream() << " Param generator running prefix step\n"); + if (m.limit().is_canceled()) + return l_undef; + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER running prefix step\n"); ctx->get_fparams().m_max_conflicts = m_max_prefix_conflicts; m_recorded_cubes.reset(); ctx->m_recorded_cubes = &m_recorded_cubes; lbool r = l_undef; try { - r = ctx->check(); + r = ctx->check(0, nullptr, true, false); + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: prefix step result " << r << "\n"); } catch (z3_error &err) { b.set_exception(err.error_code()); @@ -93,8 +96,11 @@ namespace smt { bool found_better_params = false; for (unsigned i = 0; i <= N; ++i) { - IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: replaying proof prefix in param probe context " << i << "\n"); + if (m.limit().is_canceled()) + return; + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: replaying proof prefix in param probe context " << i << "\n"); + // copy prefix solver context to a new probe_ctx for next replay with candidate mutation smt_params smtp(m_p); scoped_ptr probe_ctx = alloc(context, m, smtp, m_p); @@ -112,8 +118,12 @@ namespace smt { probe_ctx->get_fparams().m_max_conflicts = conflict_budget; // replay the cube (negation of the clause) + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: begin replay of " << m_recorded_cubes.size() << " cubes\n"); for (expr_ref_vector const& cube : m_recorded_cubes) { - lbool r = probe_ctx->check(cube.size(), cube.data()); + if (m.limit().is_canceled()) + return; + // the conflicts and decisions are cumulative over all cube replays inside the probe_ctx + lbool r = probe_ctx->check(cube.size(), cube.data(), true, false); IF_VERBOSE(1, verbose_stream() << " PARAM TUNER " << i << ": cube replay result " << r << "\n"); } unsigned conflicts = probe_ctx->m_stats.m_num_conflicts; @@ -189,7 +199,8 @@ namespace smt { } std::get(param.second).value = new_value; } - IF_VERBOSE(0, + IF_VERBOSE(1, + verbose_stream() << "Mutating param: "; for (auto const &[name, val] : new_param_values) { if (std::holds_alternative(val)) { verbose_stream() << name << " = " << std::get(val) << "\n"; @@ -208,6 +219,9 @@ namespace smt { ctx->get_fparams().m_max_conflicts = m_max_prefix_conflicts; lbool r = run_prefix_step(); + if (m.limit().is_canceled()) + return; + switch (r) { case l_undef: { replay_proof_prefixes(); @@ -231,6 +245,11 @@ namespace smt { } } + void parallel::param_generator::cancel() { + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER cancelling\n"); + m.limit().cancel(); + } + void parallel::worker::run() { search_tree::node *node = nullptr; expr_ref_vector cube(m); @@ -252,7 +271,7 @@ namespace smt { lbool r = check_cube(cube); - if (!m.inc()) { + if (m.limit().is_canceled()) { b.set_exception("context cancelled"); return; } @@ -451,7 +470,7 @@ namespace smt { IF_VERBOSE(1, m_search_tree.display(verbose_stream() << bounded_pp_exprs(core) << "\n");); if (m_search_tree.is_closed()) { m_state = state::is_unsat; - cancel_workers(); + cancel_background_threads(); } } @@ -517,7 +536,7 @@ namespace smt { << bounded_pp_exprs(cube) << "with max_conflicts: " << ctx->get_fparams().m_max_conflicts << "\n";); try { - r = ctx->check(asms.size(), asms.data()); + r = ctx->check(asms.size(), asms.data(), true, false); } catch (z3_error &err) { b.set_exception(err.error_code()); } catch (z3_exception &ex) { @@ -562,7 +581,7 @@ namespace smt { return; m_state = state::is_sat; p.ctx.set_model(m.translate(l2g)); - cancel_workers(); + cancel_background_threads(); } void parallel::batch_manager::set_unsat(ast_translation &l2g, expr_ref_vector const &unsat_core) { @@ -576,7 +595,7 @@ namespace smt { SASSERT(p.ctx.m_unsat_core.empty()); for (expr *e : unsat_core) p.ctx.m_unsat_core.push_back(l2g(e)); - cancel_workers(); + cancel_background_threads(); } void parallel::batch_manager::set_exception(unsigned error_code) { @@ -586,7 +605,7 @@ namespace smt { return; m_state = state::is_exception_code; m_exception_code = error_code; - cancel_workers(); + cancel_background_threads(); } void parallel::batch_manager::set_exception(std::string const &msg) { @@ -596,7 +615,7 @@ namespace smt { return; m_state = state::is_exception_msg; m_exception_msg = msg; - cancel_workers(); + cancel_background_threads(); } lbool parallel::batch_manager::get_result() const { @@ -676,6 +695,7 @@ namespace smt { m_batch_manager.initialize(); m_workers.reset(); + scoped_limits sl(m.limit()); flet _nt(ctx.m_fparams.m_threads, 1); SASSERT(num_threads > 1); @@ -685,15 +705,16 @@ namespace smt { for (auto w : m_workers) sl.push_child(&(w->limit())); - sl.push_child(&(m_param_generator.limit())); + sl.push_child(&(m_param_generator->limit())); // Launch threads - vector threads(num_threads + 1); // +1 for parameter generator - for (unsigned i = 0; i < num_threads - 1; ++i) { + vector threads(m_enable_param_tuner ? num_threads + 1 : num_threads); // +1 for param generator + for (unsigned i = 0; i < num_threads; ++i) { threads[i] = std::thread([&, i]() { m_workers[i]->run(); }); } // the final thread runs the parameter generator - threads[num_threads - 1] = std::thread([&]() { m_param_generator.protocol_iteration(); }); + if (m_enable_param_tuner) + threads[num_threads] = std::thread([&]() { m_param_generator->protocol_iteration(); }); // Wait for all threads to finish for (auto &th : threads) diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index d5980e8c3..ea75fc498 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -36,6 +36,7 @@ namespace smt { class parallel { context& ctx; unsigned num_threads; + bool m_enable_param_tuner; struct shared_clause { unsigned source_worker_id; @@ -57,7 +58,6 @@ namespace smt { unsigned m_num_cubes = 0; }; - ast_manager& m; parallel& p; std::mutex mux; @@ -72,6 +72,11 @@ namespace smt { vector shared_clause_trail; // store all shared clauses with worker IDs obj_hashtable shared_clause_set; // for duplicate filtering on per-thread clause expressions + void cancel_background_threads() { + cancel_workers(); + cancel_param_generator(); + } + // called from batch manager to cancel other workers if we've reached a verdict void cancel_workers() { IF_VERBOSE(1, verbose_stream() << "Canceling workers\n"); @@ -79,6 +84,11 @@ namespace smt { w->cancel(); } + void cancel_param_generator() { + IF_VERBOSE(1, verbose_stream() << "Canceling param generator\n"); + p.m_param_generator->cancel(); + } + public: batch_manager(ast_manager& m, parallel& p) : m(m), p(p), m_search_tree(expr_ref(m)) { } @@ -88,7 +98,10 @@ namespace smt { void set_sat(ast_translation& l2g, model& m); void set_exception(std::string const& msg); void set_exception(unsigned error_code); - void set_param_state(params_ref const& p) { m_param_state.copy(p); } + void set_param_state(params_ref const& p) { + m_param_state.copy(p); + IF_VERBOSE(1, verbose_stream() << "Batch manager updated param state\n"); + } void get_param_state(params_ref &p); void collect_statistics(::statistics& st) const; @@ -140,6 +153,7 @@ namespace smt { lbool run_prefix_step(); void protocol_iteration(); void replay_proof_prefixes(unsigned max_conflicts_epsilon); + void cancel(); reslimit &limit() { return m.limit(); @@ -205,16 +219,17 @@ namespace smt { batch_manager m_batch_manager; scoped_ptr_vector m_workers; - param_generator m_param_generator; + scoped_ptr m_param_generator; public: - parallel(context& ctx) : + parallel(context& ctx, bool enable_param_tuner = true) : ctx(ctx), num_threads(std::min( (unsigned)std::thread::hardware_concurrency(), ctx.get_fparams().m_threads)), + m_enable_param_tuner(enable_param_tuner), m_batch_manager(ctx.m, *this), - m_param_generator(*this) {} + m_param_generator(enable_param_tuner ? alloc(param_generator, *this) : nullptr) {} lbool operator()(expr_ref_vector const& asms); };