From e72cf2ec0909550a73a74e0314f0f5f603ffa767 Mon Sep 17 00:00:00 2001 From: Ilana Shapiro Date: Wed, 29 Oct 2025 22:49:13 -0700 Subject: [PATCH] score the param probes, but i can't figure out how to access the relevant solver statistics fields from the statistics obj --- src/smt/smt_parallel.cpp | 42 +++++++++++++++++++++++++++++++++------- src/smt/smt_parallel.h | 2 +- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index a31927934..017ee5024 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -80,14 +80,23 @@ namespace smt { return r; } - void parallel::param_generator::replay_proof_prefixes(unsigned max_conflicts_epsilon=200) { + unsigned parallel::param_generator::replay_proof_prefixes(vector candidate_param_states, unsigned max_conflicts_epsilon=200) { unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon; - // loop through m_param_probe_contexts + unsigned best_param_state_idx; + double best_score; + for (unsigned i = 0; i < m_param_probe_contexts.size(); ++i) { IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: replaying proof prefix in param probe context " << i << "\n"); context *probe_ctx = m_param_probe_contexts[i]; probe_ctx->get_fparams().m_max_conflicts = conflict_budget; - + double score = 0.0; + + // apply the ith param state to probe_ctx + smt_params params = candidate_param_states[i]; + params_ref p; + params.updt_params(p); + probe_ctx->updt_params(p); + for (auto const& clause : probe_ctx->m_recorded_clauses) { expr_ref_vector negated_lits(probe_ctx->m); for (literal lit : clause) { @@ -100,8 +109,25 @@ namespace smt { // Replay the negated clause lbool r = probe_ctx->check(negated_lits.size(), negated_lits.data()); + + ::statistics st; + probe_ctx->collect_statistics(st); + unsigned conflicts = 0, decisions = 0, rlimit = 0; + + // I can't figure out how to access the statistics fields, I only see an update method + // st.get_uint("conflicts", conflicts); + // st.get_uint("decisions", decisions); + // st.get_uint("rlimit count", rlimit); + score += conflicts + decisions + rlimit; + } + + if (i == 0 || score < best_score) { + best_score = score; + best_param_state_idx = i; } } + + return best_param_state_idx; } void parallel::param_generator::protocol_iteration() { @@ -126,11 +152,13 @@ namespace smt { candidate_param_states.push_back(mutate_param_state()); } - replay_proof_prefixes(); + unsigned best_param_state_idx = replay_proof_prefixes(candidate_param_states); - if (best_param_state != m_param_state) { - IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: no parameter mutation occurred, skipping update\n"); - return; + if (best_param_state_idx != 0) { + best_param_state = candidate_param_states[best_param_state_idx]; + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found better param state at index " << best_param_state_idx << "\n"); + } else { + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER retained current param state\n"); } } case l_true: { diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index 69e0a0f81..41dab51f1 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -187,7 +187,7 @@ namespace smt { param_generator(parallel& p); lbool run_prefix_step(); void protocol_iteration(); - void replay_proof_prefixes(unsigned max_conflicts_epsilon); + unsigned replay_proof_prefixes(vector candidate_param_states, unsigned max_conflicts_epsilon); reslimit& limit() { return m.limit();