diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 198c099a0..63df29d29 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -137,7 +137,7 @@ namespace smt { scoped_ptr m_fmls; svector m_lit_scores[2]; - vector m_recorded_clauses; + vector m_recorded_cubes; // ----------------------------------- @@ -1302,7 +1302,7 @@ namespace smt { void add_scores(unsigned n, literal const *lits); - void record_clause(unsigned n, literal const * lits); + void record_cube(unsigned n, literal const * lits); // ----------------------------------- diff --git a/src/smt/smt_internalizer.cpp b/src/smt/smt_internalizer.cpp index be884fd18..6aa17d0b9 100644 --- a/src/smt/smt_internalizer.cpp +++ b/src/smt/smt_internalizer.cpp @@ -966,10 +966,17 @@ namespace smt { } // following the pattern of solver::persist_clause in src/sat/smt/user_solver.cpp - void context::record_clause(unsigned num_lits, literal const *lits) { - literal_vector clause; - clause.append(num_lits, lits); - m_recorded_clauses.push_back(clause); + void context::record_cube(unsigned num_lits, literal const *lits) { + expr_ref_vector cube(m); + for (unsigned i = 0; i < num_lits; ++i) { + literal lit = lits[i]; + expr* e = bool_var2expr(lit.var()); + if (!e) continue; + if (!lit.sign()) + e = m.mk_not(e); // only negate positive literal + cube.push_back(e); + } + m_recorded_cubes.push_back(cube); } void context::add_scores(unsigned n, literal const *lits) { @@ -1440,7 +1447,7 @@ namespace smt { case CLS_LEARNED: dump_lemma(num_lits, lits); add_scores(num_lits, lits); - record_clause(num_lits, lits); + record_cube(num_lits, lits); break; default: break; diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index 96de76217..59c01a30d 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -83,36 +83,34 @@ namespace smt { return r; } - unsigned parallel::param_generator::replay_proof_prefixes(vector const& candidate_param_states, unsigned max_conflicts_epsilon=200) { + std::pair parallel::param_generator::replay_proof_prefixes(unsigned max_conflicts_epsilon=200) { unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon; - unsigned best_param_state_idx; + param_values best_param_state; double best_score; + bool found_better_params = false; - for (unsigned i = 0; i < m_param_probe_contexts.size(); ++i) { + for (unsigned i = 0; i < N; ++i) { IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: replaying proof prefix in param probe context " << i << "\n"); - context *probe_ctx = m_param_probe_contexts[i]; + + // copy prefix solver context to a new probe_ctx for next replay with candidate mutation + scoped_ptr probe_ctx = alloc(context, m, ctx->get_fparams(), m_p); + context::copy(*ctx, *probe_ctx, true); + + // apply a candidate (mutated) param state to probe_ctx + // (except for the first iteration, use the current param state) + param_values mutated_param_state = m_param_state; + if (i > 0) { + mutated_param_state = mutate_param_state(); + params_ref p = apply_param_values(mutated_param_state); + probe_ctx->updt_params(p); + } + probe_ctx->get_fparams().m_max_conflicts = conflict_budget; double score = 0.0; - // apply the ith param state to probe_ctx - params_ref p = apply_param_values(candidate_param_states[i]); - probe_ctx->updt_params(p); - - // todo: m_recorded_cubes as a expr_ref_vector - - for (auto const& clause : probe_ctx->m_recorded_clauses) { - expr_ref_vector negated_lits(probe_ctx->m); - for (literal lit : clause) { - expr* e = probe_ctx->bool_var2expr(lit.var()); - if (!e) continue; // skip if var not yet mapped - if (!lit.sign()) - e = probe_ctx->m.mk_not(e); // since bool_var2expr discards sign - negated_lits.push_back(e); - } - - // Replay the negated clause - - lbool r = probe_ctx->check(negated_lits.size(), negated_lits.data()); + // replay the cube (negation of the clause) + for (expr_ref_vector const& cube : probe_ctx->m_recorded_cubes) { + lbool r = probe_ctx->check(cube.size(), cube.data()); unsigned conflicts = probe_ctx->m_stats.m_num_conflicts; unsigned decisions = probe_ctx->m_stats.m_num_decisions; @@ -120,13 +118,16 @@ namespace smt { score += conflicts + decisions; } - if (i == 0 || score < best_score) { + if (i > 0 && score < best_score) { + found_better_params = true; + best_param_state = mutated_param_state; + best_score = score; + } else { best_score = score; - best_param_state_idx = i; } } - return best_param_state_idx; + return {best_param_state, found_better_params}; } void parallel::param_generator::init_param_state() { @@ -147,7 +148,6 @@ namespace smt { }; parallel::param_generator::param_values parallel::param_generator::mutate_param_state() { - param_values new_param_values(m_param_state); unsigned index = ctx->get_random_value() % new_param_values.size(); auto ¶m = new_param_values[index]; @@ -168,41 +168,23 @@ namespace smt { void parallel::param_generator::protocol_iteration() { IF_VERBOSE(1, verbose_stream() << " PARAM TUNER running protocol iteration\n"); - ctx->get_fparams().m_max_conflicts = m_max_prefix_conflicts; - - // copy current param state to all param probe contexts, before running the next prefix step - // this ensures that each param probe context replays the prefix from the same configuration - - // instead just one one context and reset it each time before copy. - for (unsigned i = 0; i < m_param_probe_contexts.size(); ++i) { - context::copy(*ctx, *m_param_probe_contexts[i], true); - } + ctx->get_fparams().m_max_conflicts = m_max_prefix_conflicts; lbool r = run_prefix_step(); switch (r) { case l_undef: { - // TODO, change from smt_params to a generic param state representation based on params_ref - // only params_ref have effect on updates. - param_values best_param_state = m_param_state; - vector candidate_param_states; + auto [best_param_state, found_better_params] = replay_proof_prefixes(); - // you can create the mutations on the fly and get the scores - // you don't have to copy all over each tester. - - - candidate_param_states.push_back(best_param_state); // first candidate param state is current best - while (candidate_param_states.size() <= N) { - candidate_param_states.push_back(mutate_param_state()); - } - - unsigned best_param_state_idx = replay_proof_prefixes(candidate_param_states); - - if (best_param_state_idx != 0) { - m_param_state = candidate_param_states[best_param_state_idx]; + // NOTE: we either need to return a pair from replay_proof_prefixes so we can return a boolean flag indicating whether better params were found. + // or, we have to implement a comparison operator for param_values + // or, we update the param state every single time even if it hasn't changed + // for now, I went with option 1 + if (found_better_params) { + m_param_state = best_param_state; auto p = apply_param_values(m_param_state); b.set_param_state(p); - IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found better param state at index " << best_param_state_idx << "\n"); + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found better param state\n"); } else { IF_VERBOSE(1, verbose_stream() << " PARAM TUNER retained current param state\n"); } @@ -315,11 +297,6 @@ namespace smt { : p(p), b(p.m_batch_manager), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) { ctx = alloc(context, m, p.ctx.get_fparams(), m_p); context::copy(p.ctx, *ctx, true); - - for (unsigned i = 0; i < N; ++i) { - m_param_probe_contexts.push_back(alloc(context, m, ctx->get_fparams(), m_p)); - } - // don't share initial units ctx->pop_to_base_lvl(); init_param_state(); diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index ec2967ed2..fab6029ab 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -27,23 +27,6 @@ Revision History: namespace smt { - inline bool operator==(const smt_params& a, const smt_params& b) { - return a.m_nl_arith_branching == b.m_nl_arith_branching && - a.m_nl_arith_cross_nested == b.m_nl_arith_cross_nested && - a.m_nl_arith_delay == b.m_nl_arith_delay && - a.m_nl_arith_expensive_patching == b.m_nl_arith_expensive_patching && - a.m_nl_arith_gb == b.m_nl_arith_gb && - a.m_nl_arith_horner == b.m_nl_arith_horner && - a.m_nl_arith_horner_frequency == b.m_nl_arith_horner_frequency && - a.m_nl_arith_optimize_bounds == b.m_nl_arith_optimize_bounds && - a.m_nl_arith_propagate_linear_monomials == b.m_nl_arith_propagate_linear_monomials && - a.m_nl_arith_tangents == b.m_nl_arith_tangents; - } - - inline bool operator!=(const smt_params& a, const smt_params& b) { - return !(a == b); - } - struct cube_config { using literal = expr_ref; static bool literal_is_null(expr_ref const& l) { return l == nullptr; } @@ -150,9 +133,18 @@ namespace smt { param_values m_param_state; params_ref apply_param_values(param_values const &pv) { - return m_p; + params_ref p = m_p; + for (auto const& [k, v] : pv) { + if (std::holds_alternative(v)) { + unsigned_value uv = std::get(v); + p.set_uint(k, uv.value); + } else if (std::holds_alternative(v)) { + bool bv = std::get(v); + p.set_bool(k, bv); + } + } + return p; } - // todo private: void init_param_state(); @@ -163,7 +155,7 @@ namespace smt { param_generator(parallel &p); lbool run_prefix_step(); void protocol_iteration(); - unsigned replay_proof_prefixes(vector const& candidate_param_states, unsigned max_conflicts_epsilon); + std::pair replay_proof_prefixes(unsigned max_conflicts_epsilon); reslimit &limit() { return m.limit();