From f315cac0cda385a56cc856746e25508321184aa5 Mon Sep 17 00:00:00 2001 From: Ilana Shapiro Date: Wed, 29 Oct 2025 15:15:36 -0700 Subject: [PATCH] setting up the param probe solvers and mutation generator --- src/smt/smt_parallel.cpp | 42 ++++++++++++++++++---- src/smt/smt_parallel.h | 77 ++++++++++++++++++++++++++++++++-------- 2 files changed, 98 insertions(+), 21 deletions(-) diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index 917884193..e619a4107 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -80,17 +80,42 @@ namespace smt { return r; } + void parallel::param_generator::replay_proof_prefixes(unsigned max_conflicts_epsilon=200) { + unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon; + + } + void parallel::param_generator::protocol_iteration() { - IF_VERBOSE(1, verbose_stream() << " Param generator running protocol iteration\n"); + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER running protocol iteration\n"); ctx->get_fparams().m_max_conflicts = m_max_prefix_conflicts; + + // copy current param state to all param probe contexts, before running the next prefix step + // this ensures that each param probe context replays the prefix from the same configuration + for (unsigned i = 0; i < m_param_probe_contexts.size(); ++i) { + context::copy(*ctx, *m_param_probe_contexts[i], true); + } + lbool r = run_prefix_step(); switch (r) { case l_undef: { - return; + smt_params best_param_state = m_param_state; + vector candidate_param_states; + + candidate_param_states.push_back(best_param_state); // first candidate param state is current best + while (candidate_param_states.size() <= N) { + candidate_param_states.push_back(mutate_param_state()); + } + + replay_proof_prefixes(); + + if (best_param_state != m_param_state) { + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: no parameter mutation occurred, skipping update\n"); + return; + } } case l_true: { - IF_VERBOSE(1, verbose_stream() << " Param tuning thread found formula sat\n"); + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found formula sat\n"); model_ref mdl; ctx->get_model(mdl); b.set_sat(m_l2g, *mdl); @@ -100,7 +125,7 @@ namespace smt { expr_ref_vector const &unsat_core = ctx->unsat_core(); IF_VERBOSE(2, verbose_stream() << " unsat core:\n"; for (auto c : unsat_core) verbose_stream() << mk_bounded_pp(c, m, 3) << "\n"); - IF_VERBOSE(1, verbose_stream() << " Param tuning thread determined formula unsat\n"); + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER determined formula unsat\n"); b.set_unsat(m_l2g, unsat_core); return; } @@ -187,9 +212,14 @@ namespace smt { } parallel::param_generator::param_generator(parallel& p) - : p(p), b(p.m_batch_manager), m_best_param_state(p.ctx.get_fparams()), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) { - ctx = alloc(context, m, m_best_param_state, m_p); + : p(p), b(p.m_batch_manager), m_param_state(p.ctx.get_fparams()), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) { + ctx = alloc(context, m, m_param_state, m_p); context::copy(p.ctx, *ctx, true); + + for (unsigned i = 0; i < N; ++i) { + m_param_probe_contexts.push_back(alloc(context, m, m_param_state, m_p)); + } + // don't share initial units ctx->pop_to_base_lvl(); init_param_state(); diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index d8da5e34a..69e0a0f81 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -20,12 +20,30 @@ Revision History: #include "smt/smt_context.h" #include "util/search_tree.h" +// #include "util/util.h" #include #include namespace smt { + inline bool operator==(const smt_params& a, const smt_params& b) { + return a.m_nl_arith_branching == b.m_nl_arith_branching && + a.m_nl_arith_cross_nested == b.m_nl_arith_cross_nested && + a.m_nl_arith_delay == b.m_nl_arith_delay && + a.m_nl_arith_expensive_patching == b.m_nl_arith_expensive_patching && + a.m_nl_arith_gb == b.m_nl_arith_gb && + a.m_nl_arith_horner == b.m_nl_arith_horner && + a.m_nl_arith_horner_frequency == b.m_nl_arith_horner_frequency && + a.m_nl_arith_optimize_bounds == b.m_nl_arith_optimize_bounds && + a.m_nl_arith_propagate_linear_monomials == b.m_nl_arith_propagate_linear_monomials && + a.m_nl_arith_tangents == b.m_nl_arith_tangents; + } + + inline bool operator!=(const smt_params& a, const smt_params& b) { + return !(a == b); + } + struct cube_config { using literal = expr_ref; static bool literal_is_null(expr_ref const& l) { return l == nullptr; } @@ -112,35 +130,64 @@ namespace smt { scoped_ptr ctx; ast_translation m_l2g; - unsigned N = 4; // number of prefix permutation testers + unsigned N = 4; // number of prefix permutations to test (including current) unsigned m_max_prefix_conflicts = 1000; scoped_ptr m_prefix_solver; - scoped_ptr_vector m_testers; // N testers - smt_params m_best_param_state; + scoped_ptr_vector m_param_probe_contexts; + smt_params m_param_state; params_ref m_p; private: void init_param_state() { - m_best_param_state.m_nl_arith_branching = true; - m_best_param_state.m_nl_arith_cross_nested = true; - m_best_param_state.m_nl_arith_delay = 10; - m_best_param_state.m_nl_arith_expensive_patching = false; - m_best_param_state.m_nl_arith_gb = true; - m_best_param_state.m_nl_arith_horner = true; - m_best_param_state.m_nl_arith_horner_frequency = 4; - m_best_param_state.m_nl_arith_optimize_bounds = true; - m_best_param_state.m_nl_arith_propagate_linear_monomials = true; - m_best_param_state.m_nl_arith_tangents = true; + m_param_state.m_nl_arith_branching = true; + m_param_state.m_nl_arith_cross_nested = true; + m_param_state.m_nl_arith_delay = 10; + m_param_state.m_nl_arith_expensive_patching = false; + m_param_state.m_nl_arith_gb = true; + m_param_state.m_nl_arith_horner = true; + m_param_state.m_nl_arith_horner_frequency = 4; + m_param_state.m_nl_arith_optimize_bounds = true; + m_param_state.m_nl_arith_propagate_linear_monomials = true; + m_param_state.m_nl_arith_tangents = true; - m_best_param_state.updt_params(m_p); + m_param_state.updt_params(m_p); ctx->updt_params(m_p); }; + + smt_params mutate_param_state() { + smt_params p = m_param_state; + random_gen m_rand; + + auto flip_bool = [&](bool &x) { + if ((m_rand() % 2) == 0) + x = !x; + }; + + auto mutate_uint = [&](unsigned &x, unsigned lo, unsigned hi) { + if ((m_rand() % 2) == 0) + x = lo + (m_rand() % (hi - lo + 1)); + }; + + flip_bool(p.m_nl_arith_branching); + flip_bool(p.m_nl_arith_cross_nested); + mutate_uint(p.m_nl_arith_delay, 5, 20); + flip_bool(p.m_nl_arith_expensive_patching); + flip_bool(p.m_nl_arith_gb); + flip_bool(p.m_nl_arith_horner); + mutate_uint(p.m_nl_arith_horner_frequency, 2, 6); + flip_bool(p.m_nl_arith_optimize_bounds); + flip_bool(p.m_nl_arith_propagate_linear_monomials); + flip_bool(p.m_nl_arith_tangents); + + return p; + } + public: param_generator(parallel& p); lbool run_prefix_step(); void protocol_iteration(); - void replay_proof_prefixes(); + void replay_proof_prefixes(unsigned max_conflicts_epsilon); reslimit& limit() { return m.limit();