From 8901f8f44cdbcbda9d3f344c8be60da42a9729b0 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 30 Oct 2025 15:59:13 -0700 Subject: [PATCH] some comments and change to how parameter variants are stored Signed-off-by: Nikolaj Bjorner --- src/smt/smt_parallel.cpp | 111 +++++++++++++++++++-------------------- src/smt/smt_parallel.h | 26 ++++++--- 2 files changed, 71 insertions(+), 66 deletions(-) diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index 0df7a4c53..96de76217 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -83,7 +83,7 @@ namespace smt { return r; } - unsigned parallel::param_generator::replay_proof_prefixes(vector candidate_param_states, unsigned max_conflicts_epsilon=200) { + unsigned parallel::param_generator::replay_proof_prefixes(vector const& candidate_param_states, unsigned max_conflicts_epsilon=200) { unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon; unsigned best_param_state_idx; double best_score; @@ -95,11 +95,11 @@ namespace smt { double score = 0.0; // apply the ith param state to probe_ctx - smt_params params = candidate_param_states[i]; - params_ref p; - params.updt_params(p); + params_ref p = apply_param_values(candidate_param_states[i]); probe_ctx->updt_params(p); + // todo: m_recorded_cubes as a expr_ref_vector + for (auto const& clause : probe_ctx->m_recorded_clauses) { expr_ref_vector negated_lits(probe_ctx->m); for (literal lit : clause) { @@ -111,17 +111,13 @@ namespace smt { } // Replay the negated clause + lbool r = probe_ctx->check(negated_lits.size(), negated_lits.data()); - ::statistics st; - probe_ctx->collect_statistics(st); - unsigned conflicts = 0, decisions = 0, rlimit = 0; + unsigned conflicts = probe_ctx->m_stats.m_num_conflicts; + unsigned decisions = probe_ctx->m_stats.m_num_decisions; - // I can't figure out how to access the statistics fields, I only see an update method - // st.get_uint("conflicts", conflicts); - // st.get_uint("decisions", decisions); - // st.get_uint("rlimit count", rlimit); - score += conflicts + decisions + rlimit; + score += conflicts + decisions; } if (i == 0 || score < best_score) { @@ -134,49 +130,40 @@ namespace smt { } void parallel::param_generator::init_param_state() { - // param_descrs smt_desc; - // smt_params_helper::collect_param_descrs(smt_desc); smt_params_helper smtp(m_p); - m_my_param_state.insert(symbol("smt.arith.nl.branching"), smtp.arith_nl_branching()); - m_my_param_state.insert(symbol("smt.arith.nl.cross_nested"), smtp.arith_nl_cross_nested()); - m_my_param_state.insert(symbol("smt.arith.nl.delay"), smtp.arith_nl_delay()); - m_my_param_state.insert(symbol("smt.arith.nl.expensive_patching"), smtp.arith_nl_expensive_patching()); - m_my_param_state.insert(symbol("smt.arith.nl.gb"), smtp.arith_nl_gb()); - m_my_param_state.insert(symbol("smt.arith.nl.horner"), smtp.arith_nl_horner()); - m_my_param_state.insert(symbol("smt.arith.nl.horner_frequency"), smtp.arith_nl_horner_frequency()); - m_my_param_state.insert(symbol("smt.arith.nl.optimize_bounds"), smtp.arith_nl_optimize_bounds()); - m_my_param_state.insert(symbol("smt.arith.nl.propagate_linear_monomials"), smtp.arith_nl_propagate_linear_monomials()); - m_my_param_state.insert(symbol("smt.arith.nl.tangents"), smtp.arith_nl_tangents()); + m_param_state.push_back({symbol("smt.arith.nl.branching"), smtp.arith_nl_branching()}); + m_param_state.push_back({symbol("smt.arith.nl.cross_nested"), smtp.arith_nl_cross_nested()}); + m_param_state.push_back({symbol("smt.arith.nl.delay"), unsigned_value({smtp.arith_nl_delay(), 5, 10})}); + m_param_state.push_back({symbol("smt.arith.nl.expensive_patching"), smtp.arith_nl_expensive_patching()}); + m_param_state.push_back({symbol("smt.arith.nl.gb"), smtp.arith_nl_grobner()}); + m_param_state.push_back({symbol("smt.arith.nl.horner"), smtp.arith_nl_horner()}); + m_param_state.push_back({symbol("smt.arith.nl.horner_frequency"), unsigned_value({smtp.arith_nl_horner_frequency(), 2, 6}) + }); + m_param_state.push_back({symbol("smt.arith.nl.optimize_bounds"), smtp.arith_nl_optimize_bounds()}); + m_param_state.push_back( + {symbol("smt.arith.nl.propagate_linear_monomials"), smtp.arith_nl_propagate_linear_monomials()}); + m_param_state.push_back({symbol("smt.arith.nl.tangents"), smtp.arith_nl_tangents()}); + }; - // TODO: this should mutate only one field at a time an mutate it based on m_my_param_state to keep it generic. + parallel::param_generator::param_values parallel::param_generator::mutate_param_state() { - smt_params parallel::param_generator::mutate_param_state() { - smt_params p = m_param_state; - random_gen m_rand; - - auto flip_bool = [&](bool &x) { - if (m_rand(2) == 0) - x = !x; - }; - - auto mutate_uint = [&](unsigned &x, unsigned lo, unsigned hi) { - if ((m_rand() % 2) == 0) - x = lo + (m_rand((hi - lo + 1))); - }; - - flip_bool(p.m_nl_arith_branching); - flip_bool(p.m_nl_arith_cross_nested); - mutate_uint(p.m_nl_arith_delay, 5, 20); - flip_bool(p.m_nl_arith_expensive_patching); - flip_bool(p.m_nl_arith_gb); - flip_bool(p.m_nl_arith_horner); - mutate_uint(p.m_nl_arith_horner_frequency, 2, 6); - flip_bool(p.m_nl_arith_optimize_bounds); - flip_bool(p.m_nl_arith_propagate_linear_monomials); - flip_bool(p.m_nl_arith_tangents); - - return p; + param_values new_param_values(m_param_state); + unsigned index = ctx->get_random_value() % new_param_values.size(); + auto ¶m = new_param_values[index]; + if (std::holds_alternative(param.second)) { + bool value = *std::get_if(¶m.second); + param.second = !value; + } + else if (std::holds_alternative(param.second)) { + auto [value, lo, hi] = *std::get_if(¶m.second); + unsigned new_value = value; + while (new_value == value) { + new_value = lo + ctx->get_random_value() % (hi - lo + 1); + } + std::get(param.second).value = new_value; + } + return new_param_values; } void parallel::param_generator::protocol_iteration() { @@ -185,6 +172,8 @@ namespace smt { // copy current param state to all param probe contexts, before running the next prefix step // this ensures that each param probe context replays the prefix from the same configuration + + // instead just one one context and reset it each time before copy. for (unsigned i = 0; i < m_param_probe_contexts.size(); ++i) { context::copy(*ctx, *m_param_probe_contexts[i], true); } @@ -195,8 +184,12 @@ namespace smt { case l_undef: { // TODO, change from smt_params to a generic param state representation based on params_ref // only params_ref have effect on updates. - smt_params best_param_state = m_param_state; - vector candidate_param_states; + param_values best_param_state = m_param_state; + vector candidate_param_states; + + // you can create the mutations on the fly and get the scores + // you don't have to copy all over each tester. + candidate_param_states.push_back(best_param_state); // first candidate param state is current best while (candidate_param_states.size() <= N) { @@ -207,7 +200,8 @@ namespace smt { if (best_param_state_idx != 0) { m_param_state = candidate_param_states[best_param_state_idx]; - b.set_param_state(m_param_state); + auto p = apply_param_values(m_param_state); + b.set_param_state(p); IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found better param state at index " << best_param_state_idx << "\n"); } else { IF_VERBOSE(1, verbose_stream() << " PARAM TUNER retained current param state\n"); @@ -318,12 +312,12 @@ namespace smt { } parallel::param_generator::param_generator(parallel& p) - : p(p), b(p.m_batch_manager), m_param_state(p.ctx.get_fparams()), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) { - ctx = alloc(context, m, m_param_state, m_p); + : p(p), b(p.m_batch_manager), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) { + ctx = alloc(context, m, p.ctx.get_fparams(), m_p); context::copy(p.ctx, *ctx, true); for (unsigned i = 0; i < N; ++i) { - m_param_probe_contexts.push_back(alloc(context, m, m_param_state, m_p)); + m_param_probe_contexts.push_back(alloc(context, m, ctx->get_fparams(), m_p)); } // don't share initial units @@ -483,7 +477,8 @@ namespace smt { } } - smt_params parallel::batch_manager::get_best_param_state() { + // todo make this thread safe by not using reference counts implicit in params ref but instead copying the entire structure. + params_ref parallel::batch_manager::get_best_param_state() { std::scoped_lock lock(mux); return m_param_state; } diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index 834011c3d..1e9c6cdaf 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -81,7 +81,7 @@ namespace smt { std::mutex mux; state m_state = state::is_running; stats m_stats; - smt_params m_param_state; + params_ref m_param_state; using node = search_tree::node; search_tree::tree m_search_tree; @@ -106,10 +106,10 @@ namespace smt { void set_sat(ast_translation& l2g, model& m); void set_exception(std::string const& msg); void set_exception(unsigned error_code); - void set_param_state(smt_params const& p) { m_param_state = p; } + void set_param_state(params_ref const& p) { m_param_state.copy(p); } void collect_statistics(::statistics& st) const; - smt_params get_best_param_state(); + params_ref get_best_param_state(); bool get_cube(ast_translation& g2l, unsigned id, expr_ref_vector& cube, node*& n); void backtrack(ast_translation& l2g, expr_ref_vector const& core, node* n); void split(ast_translation& l2g, unsigned id, node* n, expr* atom); @@ -139,22 +139,32 @@ namespace smt { scoped_ptr m_prefix_solver; scoped_ptr_vector m_param_probe_contexts; - smt_params m_param_state; params_ref m_p; - using param_value = std::variant; - symbol_table m_my_param_state; + struct unsigned_value { + unsigned value; + unsigned min_value; + unsigned max_value; + }; + using param_value = std::variant; + using param_values = vector>; + param_values m_param_state; + + params_ref apply_param_values(param_values const &pv) { + return m_p; + } + // todo private: void init_param_state(); - smt_params mutate_param_state(); + param_values mutate_param_state(); public: param_generator(parallel &p); lbool run_prefix_step(); void protocol_iteration(); - unsigned replay_proof_prefixes(vector candidate_param_states, unsigned max_conflicts_epsilon); + unsigned replay_proof_prefixes(vector const& candidate_param_states, unsigned max_conflicts_epsilon); reslimit &limit() { return m.limit();