diff --git a/src/params/smt_parallel_params.pyg b/src/params/smt_parallel_params.pyg index 9a81f6c25..fc6cd8063 100644 --- a/src/params/smt_parallel_params.pyg +++ b/src/params/smt_parallel_params.pyg @@ -21,5 +21,7 @@ def_module_params('smt_parallel', ('cubetree', BOOL, False, 'use cube tree data structure for storing cubes'), ('searchtree', BOOL, False, 'use search tree implementation (parallel2)'), ('inprocessing', BOOL, False, 'integrate in-processing as a heuristic simplification'), - ('inprocessing_delay', UINT, 0, 'number of undef before invoking simplification') + ('inprocessing_delay', UINT, 0, 'number of undef before invoking simplification'), + ('param_tuning', BOOL, False, 'whether to tune params online during solving'), + ('tunable_params', STRING, '', 'comma-separated key=value list for online param tuning, e.g. \"smt.arith.nl.horner=false,smt.arith.nl.delay=8\"') )) \ No newline at end of file diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index ee2d06cdb..65ac8385a 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -163,11 +163,9 @@ namespace smt { m_param_state.push_back({symbol("smt.arith.nl.expensive_patching"), smtp.arith_nl_expensive_patching()}); m_param_state.push_back({symbol("smt.arith.nl.gb"), smtp.arith_nl_grobner()}); m_param_state.push_back({symbol("smt.arith.nl.horner"), smtp.arith_nl_horner()}); - m_param_state.push_back({symbol("smt.arith.nl.horner_frequency"), unsigned_value({smtp.arith_nl_horner_frequency(), 2, 6}) - }); + m_param_state.push_back({symbol("smt.arith.nl.horner_frequency"), unsigned_value({smtp.arith_nl_horner_frequency(), 2, 6})}); m_param_state.push_back({symbol("smt.arith.nl.optimize_bounds"), smtp.arith_nl_optimize_bounds()}); - m_param_state.push_back( - {symbol("smt.arith.nl.propagate_linear_monomials"), smtp.arith_nl_propagate_linear_monomials()}); + m_param_state.push_back({symbol("smt.arith.nl.propagate_linear_monomials"), smtp.arith_nl_propagate_linear_monomials()}); m_param_state.push_back({symbol("smt.arith.nl.tangents"), smtp.arith_nl_tangents()}); }; @@ -698,27 +696,76 @@ namespace smt { m_batch_manager.initialize(); m_workers.reset(); + + smt_parallel_params pp(ctx.m_params); + m_should_tune_params = pp.param_tuning(); scoped_limits sl(m.limit()); flet _nt(ctx.m_fparams.m_threads, 1); - m_param_generator = alloc(param_generator, *this); SASSERT(num_threads > 1); for (unsigned i = 0; i < num_threads; ++i) m_workers.push_back(alloc(worker, i, *this, asms)); - + for (auto w : m_workers) sl.push_child(&(w->limit())); - - sl.push_child(&(m_param_generator->limit())); + + if (m_should_tune_params) { + m_param_generator = alloc(param_generator, *this); + sl.push_child(&(m_param_generator->limit())); + } + + std::string tuned = pp.tunable_params(); + if (!tuned.empty()) { + auto trim = [](std::string &s) { + s.erase(0, s.find_first_not_of(" \t\n\r")); + s.erase(s.find_last_not_of(" \t\n\r") + 1); + }; + + std::stringstream ss(tuned); + std::string kv; + + while (std::getline(ss, kv, ',')) { + size_t eq = kv.find('='); + if (eq == std::string::npos) + continue; + + std::string key = kv.substr(0, eq); + std::string val = kv.substr(eq + 1); + trim(key); + trim(val); + + if (val == "true" || val == "1") { + ctx.m_params.set_bool(symbol(key.c_str()), true); + } + else if (val == "false" || val == "0") { + ctx.m_params.set_bool(symbol(key.c_str()), false); + } + else if (std::all_of(val.begin(), val.end(), ::isdigit)) { + ctx.m_params.set_uint(symbol(key.c_str()), + static_cast(std::stoul(val))); + } + else { + // if non-numeric and non-bool, just store as string/symbol + ctx.m_params.set_str(symbol(key.c_str()), val.c_str()); + } + } + + IF_VERBOSE(1, + verbose_stream() << "Applied parameter overrides:\n"; + ctx.m_params.display(verbose_stream()); + ); + } // Launch threads - vector threads(num_threads + 1); // +1 for param generator + vector threads(m_should_tune_params ? num_threads + 1 : num_threads); // +1 for param generator for (unsigned i = 0; i < num_threads; ++i) { threads[i] = std::thread([&, i]() { m_workers[i]->run(); }); } // the final thread runs the parameter generator - threads[num_threads] = std::thread([&]() { m_param_generator->protocol_iteration(); }); + if (m_should_tune_params) { + threads[num_threads] = std::thread([&]() { m_param_generator->protocol_iteration(); }); + } // Wait for all threads to finish for (auto &th : threads) diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index 9d7834a25..f0fc43844 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -36,6 +36,7 @@ namespace smt { class parallel { context& ctx; unsigned num_threads; + bool m_should_tune_params; struct shared_clause { unsigned source_worker_id; @@ -73,7 +74,7 @@ namespace smt { void cancel_background_threads() { cancel_workers(); - cancel_param_generator(); + if (p.m_should_tune_params) cancel_param_generator(); } // called from batch manager to cancel other workers if we've reached a verdict