diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index c0a2f3a4d..87878ab6e 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -3617,15 +3617,15 @@ namespace smt { \remark A logical context can only be configured at scope level 0, and before internalizing any formulas. */ - lbool context::setup_and_check(bool reset_cancel) { + lbool context::setup_and_check(bool reset_cancel, bool enable_parallel_param_tuning) { if (!check_preamble(reset_cancel)) return l_undef; SASSERT(m_scope_lvl == 0); SASSERT(!m_setup.already_configured()); setup_context(m_fparams.m_auto_config); if (m_fparams.m_threads > 1 && !m.has_trace_stream()) { - parallel p(*this); expr_ref_vector asms(m); + parallel p(*this, enable_parallel_param_tuning); return p(asms); } @@ -3685,14 +3685,15 @@ namespace smt { } } - lbool context::check(unsigned num_assumptions, expr * const * assumptions, bool reset_cancel) { + lbool context::check(unsigned num_assumptions, expr * const * assumptions, bool reset_cancel, bool enable_parallel_param_tuning) { if (!check_preamble(reset_cancel)) return l_undef; SASSERT(at_base_level()); setup_context(false); search_completion sc(*this); if (m_fparams.m_threads > 1 && !m.has_trace_stream()) { expr_ref_vector asms(m, num_assumptions, assumptions); - parallel p(*this); + IF_VERBOSE(1, verbose_stream() << "Starting parallel check with " << asms.size() << " assumptions and param tuning enabled: " << enable_parallel_param_tuning << "\n"); + parallel p(*this, enable_parallel_param_tuning); return p(asms); } lbool r = l_undef; diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 0e4cdde44..a363548e4 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -132,6 +132,7 @@ namespace smt { unsigned m_par_index = 0; bool m_internalizing_assertions = false; lbool m_internal_completed = l_undef; + bool m_in_parallel = false; scoped_ptr m_simplifier; scoped_ptr m_fmls; @@ -1689,7 +1690,7 @@ namespace smt { void pop(unsigned num_scopes); - lbool check(unsigned num_assumptions = 0, expr * const * assumptions = nullptr, bool reset_cancel = true); + lbool check(unsigned num_assumptions = 0, expr * const * assumptions = nullptr, bool reset_cancel = true, bool enable_parallel_param_tuning = true); lbool check(expr_ref_vector const& cube, vector const& clauses); @@ -1699,7 +1700,7 @@ namespace smt { lbool preferred_sat(expr_ref_vector const& asms, vector& cores); - lbool setup_and_check(bool reset_cancel = true); + lbool setup_and_check(bool reset_cancel = true, bool enable_parallel_param_tuning = true); void reduce_assertions(); diff --git a/src/smt/smt_internalizer.cpp b/src/smt/smt_internalizer.cpp index e1eda24f9..b86af05ac 100644 --- a/src/smt/smt_internalizer.cpp +++ b/src/smt/smt_internalizer.cpp @@ -990,7 +990,8 @@ namespace smt { void context::undo_mk_bool_var() { - SASSERT(!m_b_internalized_stack.empty()); + SASSERT(!m_b_internalized_stack.empty(ue key per literal + m_lit_scores[lit.sign()][v] += 1.)); m_stats.m_num_del_bool_var++; expr * n = m_b_internalized_stack.back(); unsigned n_id = n->get_id(); diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index 8611f57bd..5de53176e 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -66,13 +66,18 @@ namespace smt { namespace smt { lbool parallel::param_generator::run_prefix_step() { - IF_VERBOSE(1, verbose_stream() << " Param generator running prefix step\n"); + if (m.limit().is_canceled()) + return l_undef; + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER running prefix step\n"); ctx->get_fparams().m_max_conflicts = m_max_prefix_conflicts; + ctx->get_fparams().m_threads = 1; + m_recorded_cubes.reset(); ctx->m_recorded_cubes = &m_recorded_cubes; lbool r = l_undef; try { - r = ctx->check(); + r = ctx->check(0, nullptr, true, false); + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: prefix step result " << r << "\n"); } catch (z3_error &err) { b.set_exception(err.error_code()); @@ -93,8 +98,11 @@ namespace smt { bool found_better_params = false; for (unsigned i = 0; i <= N; ++i) { - IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: replaying proof prefix in param probe context " << i << "\n"); + if (m.limit().is_canceled()) + return; + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: replaying proof prefix in param probe context " << i << "\n"); + // copy prefix solver context to a new probe_ctx for next replay with candidate mutation smt_params smtp(m_p); scoped_ptr probe_ctx = alloc(context, m, smtp, m_p); @@ -110,11 +118,16 @@ namespace smt { } probe_ctx->get_fparams().m_max_conflicts = conflict_budget; + probe_ctx->get_fparams().m_threads = 1; // replay the cube (negation of the clause) + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: begin replay of " << m_recorded_cubes.size() << " cubes\n"); for (expr_ref_vector const& cube : m_recorded_cubes) { - lbool r = probe_ctx->check(cube.size(), cube.data()); - IF_VERBOSE(1, verbose_stream() << " PARAM TUNER " << i << ": cube replay result " << r << "\n"); + if (m.limit().is_canceled()) + return; + // the conflicts and decisions are cumulative over all cube replays inside the probe_ctx + lbool r = probe_ctx->check(cube.size(), cube.data(), true, false); + IF_VERBOSE(2, verbose_stream() << " PARAM TUNER " << i << ": cube replay result " << r << "\n"); } unsigned conflicts = probe_ctx->m_stats.m_num_conflicts; unsigned decisions = probe_ctx->m_stats.m_num_decisions; @@ -130,10 +143,7 @@ namespace smt { best_score = score; } } - // NOTE: we either need to apply the best params found that are better than base line - // or, we have to implement a comparison operator for param_values (what would this do?) - // or, we update the param state every single time even if it hasn't changed (what would this do?) - // for now, I went with option 1 + if (found_better_params) { m_param_state = best_param_state; auto p = apply_param_values(m_param_state); @@ -192,7 +202,8 @@ namespace smt { } std::get(param.second).value = new_value; } - IF_VERBOSE(0, + IF_VERBOSE(1, + verbose_stream() << "Mutating param: "; for (auto const &[name, val] : new_param_values) { if (std::holds_alternative(val)) { verbose_stream() << name << " = " << std::get(val) << "\n"; @@ -211,6 +222,9 @@ namespace smt { ctx->get_fparams().m_max_conflicts = m_max_prefix_conflicts; lbool r = run_prefix_step(); + if (m.limit().is_canceled()) + return; + switch (r) { case l_undef: { replay_proof_prefixes(); @@ -234,6 +248,11 @@ namespace smt { } } + void parallel::param_generator::cancel() { + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER cancelling\n"); + m.limit().cancel(); + } + void parallel::worker::run() { search_tree::node *node = nullptr; expr_ref_vector cube(m); @@ -255,7 +274,7 @@ namespace smt { lbool r = check_cube(cube); - if (!m.inc()) { + if (m.limit().is_canceled()) { b.set_exception("context cancelled"); return; } @@ -320,7 +339,11 @@ namespace smt { } parallel::param_generator::param_generator(parallel& p) - : p(p), b(p.m_batch_manager), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) { + : b(p.m_batch_manager), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) { + // patch fix so that ctx = alloc(context, m, p.ctx.get_fparams(), m_p); doesn't crash due to some issue with default construction of m + ast_translation m_g2l(p.ctx.m, m); + m_g2l(p.ctx.m.mk_true()); + ctx = alloc(context, m, p.ctx.get_fparams(), m_p); context::copy(p.ctx, *ctx, true); // don't share initial units @@ -450,7 +473,7 @@ namespace smt { IF_VERBOSE(1, m_search_tree.display(verbose_stream() << bounded_pp_exprs(core) << "\n");); if (m_search_tree.is_closed()) { m_state = state::is_unsat; - cancel_workers(); + cancel_background_threads(); } } @@ -516,7 +539,7 @@ namespace smt { << bounded_pp_exprs(cube) << "with max_conflicts: " << ctx->get_fparams().m_max_conflicts << "\n";); try { - r = ctx->check(asms.size(), asms.data()); + r = ctx->check(asms.size(), asms.data(), true, false); } catch (z3_error &err) { b.set_exception(err.error_code()); } catch (z3_exception &ex) { @@ -561,7 +584,7 @@ namespace smt { return; m_state = state::is_sat; p.ctx.set_model(m.translate(l2g)); - cancel_workers(); + cancel_background_threads(); } void parallel::batch_manager::set_unsat(ast_translation &l2g, expr_ref_vector const &unsat_core) { @@ -575,7 +598,7 @@ namespace smt { SASSERT(p.ctx.m_unsat_core.empty()); for (expr *e : unsat_core) p.ctx.m_unsat_core.push_back(l2g(e)); - cancel_workers(); + cancel_background_threads(); } void parallel::batch_manager::set_exception(unsigned error_code) { @@ -585,7 +608,7 @@ namespace smt { return; m_state = state::is_exception_code; m_exception_code = error_code; - cancel_workers(); + cancel_background_threads(); } void parallel::batch_manager::set_exception(std::string const &msg) { @@ -595,7 +618,7 @@ namespace smt { return; m_state = state::is_exception_msg; m_exception_msg = msg; - cancel_workers(); + cancel_background_threads(); } lbool parallel::batch_manager::get_result() const { @@ -675,6 +698,7 @@ namespace smt { m_batch_manager.initialize(); m_workers.reset(); + scoped_limits sl(m.limit()); flet _nt(ctx.m_fparams.m_threads, 1); SASSERT(num_threads > 1); @@ -684,15 +708,16 @@ namespace smt { for (auto w : m_workers) sl.push_child(&(w->limit())); - sl.push_child(&(m_param_generator.limit())); + sl.push_child(&(m_param_generator->limit())); // Launch threads - vector threads(num_threads + 1); // +1 for parameter generator - for (unsigned i = 0; i < num_threads - 1; ++i) { + vector threads(m_enable_param_tuner ? num_threads + 1 : num_threads); // +1 for param generator + for (unsigned i = 0; i < num_threads; ++i) { threads[i] = std::thread([&, i]() { m_workers[i]->run(); }); } // the final thread runs the parameter generator - threads[num_threads - 1] = std::thread([&]() { m_param_generator.protocol_iteration(); }); + if (m_enable_param_tuner) + threads[num_threads] = std::thread([&]() { m_param_generator->protocol_iteration(); }); // Wait for all threads to finish for (auto &th : threads) diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index 27111c0bd..ea75fc498 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -36,6 +36,7 @@ namespace smt { class parallel { context& ctx; unsigned num_threads; + bool m_enable_param_tuner; struct shared_clause { unsigned source_worker_id; @@ -57,7 +58,6 @@ namespace smt { unsigned m_num_cubes = 0; }; - ast_manager& m; parallel& p; std::mutex mux; @@ -72,6 +72,11 @@ namespace smt { vector shared_clause_trail; // store all shared clauses with worker IDs obj_hashtable shared_clause_set; // for duplicate filtering on per-thread clause expressions + void cancel_background_threads() { + cancel_workers(); + cancel_param_generator(); + } + // called from batch manager to cancel other workers if we've reached a verdict void cancel_workers() { IF_VERBOSE(1, verbose_stream() << "Canceling workers\n"); @@ -79,6 +84,11 @@ namespace smt { w->cancel(); } + void cancel_param_generator() { + IF_VERBOSE(1, verbose_stream() << "Canceling param generator\n"); + p.m_param_generator->cancel(); + } + public: batch_manager(ast_manager& m, parallel& p) : m(m), p(p), m_search_tree(expr_ref(m)) { } @@ -88,7 +98,10 @@ namespace smt { void set_sat(ast_translation& l2g, model& m); void set_exception(std::string const& msg); void set_exception(unsigned error_code); - void set_param_state(params_ref const& p) { m_param_state.copy(p); } + void set_param_state(params_ref const& p) { + m_param_state.copy(p); + IF_VERBOSE(1, verbose_stream() << "Batch manager updated param state\n"); + } void get_param_state(params_ref &p); void collect_statistics(::statistics& st) const; @@ -118,11 +131,9 @@ namespace smt { using param_value = std::variant; using param_values = vector>; - parallel &p; batch_manager &b; ast_manager m; scoped_ptr ctx; - ast_translation m_l2g; unsigned N = 4; // number of prefix permutations to test (including current) unsigned m_max_prefix_conflicts = 1000; @@ -131,6 +142,7 @@ namespace smt { vector m_recorded_cubes; params_ref m_p; param_values m_param_state; + ast_translation m_l2g; params_ref apply_param_values(param_values const &pv); void init_param_state(); @@ -141,6 +153,7 @@ namespace smt { lbool run_prefix_step(); void protocol_iteration(); void replay_proof_prefixes(unsigned max_conflicts_epsilon); + void cancel(); reslimit &limit() { return m.limit(); @@ -206,16 +219,17 @@ namespace smt { batch_manager m_batch_manager; scoped_ptr_vector m_workers; - param_generator m_param_generator; + scoped_ptr m_param_generator; public: - parallel(context& ctx) : + parallel(context& ctx, bool enable_param_tuner = true) : ctx(ctx), num_threads(std::min( (unsigned)std::thread::hardware_concurrency(), ctx.get_fparams().m_threads)), + m_enable_param_tuner(enable_param_tuner), m_batch_manager(ctx.m, *this), - m_param_generator(*this) {} + m_param_generator(enable_param_tuner ? alloc(param_generator, *this) : nullptr) {} lbool operator()(expr_ref_vector const& asms); };