mirror of
https://github.com/Z3Prover/z3
synced 2025-11-05 05:49:13 +00:00
updates to param tuning
This commit is contained in:
parent
33060f7b97
commit
57d7e9fcf5
4 changed files with 62 additions and 86 deletions
|
|
@ -137,7 +137,7 @@ namespace smt {
|
||||||
scoped_ptr<base_dependent_expr_state> m_fmls;
|
scoped_ptr<base_dependent_expr_state> m_fmls;
|
||||||
|
|
||||||
svector<double> m_lit_scores[2];
|
svector<double> m_lit_scores[2];
|
||||||
vector<literal_vector> m_recorded_clauses;
|
vector<expr_ref_vector> m_recorded_cubes;
|
||||||
|
|
||||||
|
|
||||||
// -----------------------------------
|
// -----------------------------------
|
||||||
|
|
@ -1302,7 +1302,7 @@ namespace smt {
|
||||||
|
|
||||||
void add_scores(unsigned n, literal const *lits);
|
void add_scores(unsigned n, literal const *lits);
|
||||||
|
|
||||||
void record_clause(unsigned n, literal const * lits);
|
void record_cube(unsigned n, literal const * lits);
|
||||||
|
|
||||||
|
|
||||||
// -----------------------------------
|
// -----------------------------------
|
||||||
|
|
|
||||||
|
|
@ -966,10 +966,17 @@ namespace smt {
|
||||||
}
|
}
|
||||||
|
|
||||||
// following the pattern of solver::persist_clause in src/sat/smt/user_solver.cpp
|
// following the pattern of solver::persist_clause in src/sat/smt/user_solver.cpp
|
||||||
void context::record_clause(unsigned num_lits, literal const *lits) {
|
void context::record_cube(unsigned num_lits, literal const *lits) {
|
||||||
literal_vector clause;
|
expr_ref_vector cube(m);
|
||||||
clause.append(num_lits, lits);
|
for (unsigned i = 0; i < num_lits; ++i) {
|
||||||
m_recorded_clauses.push_back(clause);
|
literal lit = lits[i];
|
||||||
|
expr* e = bool_var2expr(lit.var());
|
||||||
|
if (!e) continue;
|
||||||
|
if (!lit.sign())
|
||||||
|
e = m.mk_not(e); // only negate positive literal
|
||||||
|
cube.push_back(e);
|
||||||
|
}
|
||||||
|
m_recorded_cubes.push_back(cube);
|
||||||
}
|
}
|
||||||
|
|
||||||
void context::add_scores(unsigned n, literal const *lits) {
|
void context::add_scores(unsigned n, literal const *lits) {
|
||||||
|
|
@ -1440,7 +1447,7 @@ namespace smt {
|
||||||
case CLS_LEARNED:
|
case CLS_LEARNED:
|
||||||
dump_lemma(num_lits, lits);
|
dump_lemma(num_lits, lits);
|
||||||
add_scores(num_lits, lits);
|
add_scores(num_lits, lits);
|
||||||
record_clause(num_lits, lits);
|
record_cube(num_lits, lits);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
|
|
|
||||||
|
|
@ -83,36 +83,34 @@ namespace smt {
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned parallel::param_generator::replay_proof_prefixes(vector<param_values> const& candidate_param_states, unsigned max_conflicts_epsilon=200) {
|
std::pair<parallel::param_generator::param_values, bool> parallel::param_generator::replay_proof_prefixes(unsigned max_conflicts_epsilon=200) {
|
||||||
unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon;
|
unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon;
|
||||||
unsigned best_param_state_idx;
|
param_values best_param_state;
|
||||||
double best_score;
|
double best_score;
|
||||||
|
bool found_better_params = false;
|
||||||
|
|
||||||
for (unsigned i = 0; i < m_param_probe_contexts.size(); ++i) {
|
for (unsigned i = 0; i < N; ++i) {
|
||||||
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: replaying proof prefix in param probe context " << i << "\n");
|
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: replaying proof prefix in param probe context " << i << "\n");
|
||||||
context *probe_ctx = m_param_probe_contexts[i];
|
|
||||||
|
// copy prefix solver context to a new probe_ctx for next replay with candidate mutation
|
||||||
|
scoped_ptr<context> probe_ctx = alloc(context, m, ctx->get_fparams(), m_p);
|
||||||
|
context::copy(*ctx, *probe_ctx, true);
|
||||||
|
|
||||||
|
// apply a candidate (mutated) param state to probe_ctx
|
||||||
|
// (except for the first iteration, use the current param state)
|
||||||
|
param_values mutated_param_state = m_param_state;
|
||||||
|
if (i > 0) {
|
||||||
|
mutated_param_state = mutate_param_state();
|
||||||
|
params_ref p = apply_param_values(mutated_param_state);
|
||||||
|
probe_ctx->updt_params(p);
|
||||||
|
}
|
||||||
|
|
||||||
probe_ctx->get_fparams().m_max_conflicts = conflict_budget;
|
probe_ctx->get_fparams().m_max_conflicts = conflict_budget;
|
||||||
double score = 0.0;
|
double score = 0.0;
|
||||||
|
|
||||||
// apply the ith param state to probe_ctx
|
// replay the cube (negation of the clause)
|
||||||
params_ref p = apply_param_values(candidate_param_states[i]);
|
for (expr_ref_vector const& cube : probe_ctx->m_recorded_cubes) {
|
||||||
probe_ctx->updt_params(p);
|
lbool r = probe_ctx->check(cube.size(), cube.data());
|
||||||
|
|
||||||
// todo: m_recorded_cubes as a expr_ref_vector
|
|
||||||
|
|
||||||
for (auto const& clause : probe_ctx->m_recorded_clauses) {
|
|
||||||
expr_ref_vector negated_lits(probe_ctx->m);
|
|
||||||
for (literal lit : clause) {
|
|
||||||
expr* e = probe_ctx->bool_var2expr(lit.var());
|
|
||||||
if (!e) continue; // skip if var not yet mapped
|
|
||||||
if (!lit.sign())
|
|
||||||
e = probe_ctx->m.mk_not(e); // since bool_var2expr discards sign
|
|
||||||
negated_lits.push_back(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replay the negated clause
|
|
||||||
|
|
||||||
lbool r = probe_ctx->check(negated_lits.size(), negated_lits.data());
|
|
||||||
|
|
||||||
unsigned conflicts = probe_ctx->m_stats.m_num_conflicts;
|
unsigned conflicts = probe_ctx->m_stats.m_num_conflicts;
|
||||||
unsigned decisions = probe_ctx->m_stats.m_num_decisions;
|
unsigned decisions = probe_ctx->m_stats.m_num_decisions;
|
||||||
|
|
@ -120,13 +118,16 @@ namespace smt {
|
||||||
score += conflicts + decisions;
|
score += conflicts + decisions;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (i == 0 || score < best_score) {
|
if (i > 0 && score < best_score) {
|
||||||
|
found_better_params = true;
|
||||||
|
best_param_state = mutated_param_state;
|
||||||
|
best_score = score;
|
||||||
|
} else {
|
||||||
best_score = score;
|
best_score = score;
|
||||||
best_param_state_idx = i;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return best_param_state_idx;
|
return {best_param_state, found_better_params};
|
||||||
}
|
}
|
||||||
|
|
||||||
void parallel::param_generator::init_param_state() {
|
void parallel::param_generator::init_param_state() {
|
||||||
|
|
@ -147,7 +148,6 @@ namespace smt {
|
||||||
};
|
};
|
||||||
|
|
||||||
parallel::param_generator::param_values parallel::param_generator::mutate_param_state() {
|
parallel::param_generator::param_values parallel::param_generator::mutate_param_state() {
|
||||||
|
|
||||||
param_values new_param_values(m_param_state);
|
param_values new_param_values(m_param_state);
|
||||||
unsigned index = ctx->get_random_value() % new_param_values.size();
|
unsigned index = ctx->get_random_value() % new_param_values.size();
|
||||||
auto ¶m = new_param_values[index];
|
auto ¶m = new_param_values[index];
|
||||||
|
|
@ -168,41 +168,23 @@ namespace smt {
|
||||||
|
|
||||||
void parallel::param_generator::protocol_iteration() {
|
void parallel::param_generator::protocol_iteration() {
|
||||||
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER running protocol iteration\n");
|
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER running protocol iteration\n");
|
||||||
|
|
||||||
ctx->get_fparams().m_max_conflicts = m_max_prefix_conflicts;
|
ctx->get_fparams().m_max_conflicts = m_max_prefix_conflicts;
|
||||||
|
|
||||||
// copy current param state to all param probe contexts, before running the next prefix step
|
|
||||||
// this ensures that each param probe context replays the prefix from the same configuration
|
|
||||||
|
|
||||||
// instead just one one context and reset it each time before copy.
|
|
||||||
for (unsigned i = 0; i < m_param_probe_contexts.size(); ++i) {
|
|
||||||
context::copy(*ctx, *m_param_probe_contexts[i], true);
|
|
||||||
}
|
|
||||||
|
|
||||||
lbool r = run_prefix_step();
|
lbool r = run_prefix_step();
|
||||||
|
|
||||||
switch (r) {
|
switch (r) {
|
||||||
case l_undef: {
|
case l_undef: {
|
||||||
// TODO, change from smt_params to a generic param state representation based on params_ref
|
auto [best_param_state, found_better_params] = replay_proof_prefixes();
|
||||||
// only params_ref have effect on updates.
|
|
||||||
param_values best_param_state = m_param_state;
|
|
||||||
vector<param_values> candidate_param_states;
|
|
||||||
|
|
||||||
// you can create the mutations on the fly and get the scores
|
// NOTE: we either need to return a pair from replay_proof_prefixes so we can return a boolean flag indicating whether better params were found.
|
||||||
// you don't have to copy all over each tester.
|
// or, we have to implement a comparison operator for param_values
|
||||||
|
// or, we update the param state every single time even if it hasn't changed
|
||||||
|
// for now, I went with option 1
|
||||||
candidate_param_states.push_back(best_param_state); // first candidate param state is current best
|
if (found_better_params) {
|
||||||
while (candidate_param_states.size() <= N) {
|
m_param_state = best_param_state;
|
||||||
candidate_param_states.push_back(mutate_param_state());
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned best_param_state_idx = replay_proof_prefixes(candidate_param_states);
|
|
||||||
|
|
||||||
if (best_param_state_idx != 0) {
|
|
||||||
m_param_state = candidate_param_states[best_param_state_idx];
|
|
||||||
auto p = apply_param_values(m_param_state);
|
auto p = apply_param_values(m_param_state);
|
||||||
b.set_param_state(p);
|
b.set_param_state(p);
|
||||||
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found better param state at index " << best_param_state_idx << "\n");
|
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found better param state\n");
|
||||||
} else {
|
} else {
|
||||||
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER retained current param state\n");
|
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER retained current param state\n");
|
||||||
}
|
}
|
||||||
|
|
@ -315,11 +297,6 @@ namespace smt {
|
||||||
: p(p), b(p.m_batch_manager), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) {
|
: p(p), b(p.m_batch_manager), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) {
|
||||||
ctx = alloc(context, m, p.ctx.get_fparams(), m_p);
|
ctx = alloc(context, m, p.ctx.get_fparams(), m_p);
|
||||||
context::copy(p.ctx, *ctx, true);
|
context::copy(p.ctx, *ctx, true);
|
||||||
|
|
||||||
for (unsigned i = 0; i < N; ++i) {
|
|
||||||
m_param_probe_contexts.push_back(alloc(context, m, ctx->get_fparams(), m_p));
|
|
||||||
}
|
|
||||||
|
|
||||||
// don't share initial units
|
// don't share initial units
|
||||||
ctx->pop_to_base_lvl();
|
ctx->pop_to_base_lvl();
|
||||||
init_param_state();
|
init_param_state();
|
||||||
|
|
|
||||||
|
|
@ -27,23 +27,6 @@ Revision History:
|
||||||
|
|
||||||
namespace smt {
|
namespace smt {
|
||||||
|
|
||||||
inline bool operator==(const smt_params& a, const smt_params& b) {
|
|
||||||
return a.m_nl_arith_branching == b.m_nl_arith_branching &&
|
|
||||||
a.m_nl_arith_cross_nested == b.m_nl_arith_cross_nested &&
|
|
||||||
a.m_nl_arith_delay == b.m_nl_arith_delay &&
|
|
||||||
a.m_nl_arith_expensive_patching == b.m_nl_arith_expensive_patching &&
|
|
||||||
a.m_nl_arith_gb == b.m_nl_arith_gb &&
|
|
||||||
a.m_nl_arith_horner == b.m_nl_arith_horner &&
|
|
||||||
a.m_nl_arith_horner_frequency == b.m_nl_arith_horner_frequency &&
|
|
||||||
a.m_nl_arith_optimize_bounds == b.m_nl_arith_optimize_bounds &&
|
|
||||||
a.m_nl_arith_propagate_linear_monomials == b.m_nl_arith_propagate_linear_monomials &&
|
|
||||||
a.m_nl_arith_tangents == b.m_nl_arith_tangents;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool operator!=(const smt_params& a, const smt_params& b) {
|
|
||||||
return !(a == b);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct cube_config {
|
struct cube_config {
|
||||||
using literal = expr_ref;
|
using literal = expr_ref;
|
||||||
static bool literal_is_null(expr_ref const& l) { return l == nullptr; }
|
static bool literal_is_null(expr_ref const& l) { return l == nullptr; }
|
||||||
|
|
@ -150,9 +133,18 @@ namespace smt {
|
||||||
param_values m_param_state;
|
param_values m_param_state;
|
||||||
|
|
||||||
params_ref apply_param_values(param_values const &pv) {
|
params_ref apply_param_values(param_values const &pv) {
|
||||||
return m_p;
|
params_ref p = m_p;
|
||||||
|
for (auto const& [k, v] : pv) {
|
||||||
|
if (std::holds_alternative<unsigned_value>(v)) {
|
||||||
|
unsigned_value uv = std::get<unsigned_value>(v);
|
||||||
|
p.set_uint(k, uv.value);
|
||||||
|
} else if (std::holds_alternative<bool>(v)) {
|
||||||
|
bool bv = std::get<bool>(v);
|
||||||
|
p.set_bool(k, bv);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p;
|
||||||
}
|
}
|
||||||
// todo
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void init_param_state();
|
void init_param_state();
|
||||||
|
|
@ -163,7 +155,7 @@ namespace smt {
|
||||||
param_generator(parallel &p);
|
param_generator(parallel &p);
|
||||||
lbool run_prefix_step();
|
lbool run_prefix_step();
|
||||||
void protocol_iteration();
|
void protocol_iteration();
|
||||||
unsigned replay_proof_prefixes(vector<param_values> const& candidate_param_states, unsigned max_conflicts_epsilon);
|
std::pair<parallel::param_generator::param_values, bool> replay_proof_prefixes(unsigned max_conflicts_epsilon);
|
||||||
|
|
||||||
reslimit &limit() {
|
reslimit &limit() {
|
||||||
return m.limit();
|
return m.limit();
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue