mirror of
				https://github.com/Z3Prover/z3
				synced 2025-11-04 05:19:11 +00:00 
			
		
		
		
	updates to param tuning
This commit is contained in:
		
							parent
							
								
									33060f7b97
								
							
						
					
					
						commit
						57d7e9fcf5
					
				
					 4 changed files with 62 additions and 86 deletions
				
			
		| 
						 | 
				
			
			@ -137,7 +137,7 @@ namespace smt {
 | 
			
		|||
        scoped_ptr<base_dependent_expr_state> m_fmls;
 | 
			
		||||
 | 
			
		||||
        svector<double> m_lit_scores[2];
 | 
			
		||||
        vector<literal_vector> m_recorded_clauses;
 | 
			
		||||
        vector<expr_ref_vector> m_recorded_cubes;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        // -----------------------------------
 | 
			
		||||
| 
						 | 
				
			
			@ -1302,7 +1302,7 @@ namespace smt {
 | 
			
		|||
 | 
			
		||||
        void add_scores(unsigned n, literal const *lits);
 | 
			
		||||
 | 
			
		||||
        void record_clause(unsigned n, literal const * lits);
 | 
			
		||||
        void record_cube(unsigned n, literal const * lits);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        // -----------------------------------
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -966,10 +966,17 @@ namespace smt {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    // following the pattern of solver::persist_clause in src/sat/smt/user_solver.cpp
 | 
			
		||||
    void context::record_clause(unsigned num_lits, literal const *lits) {
 | 
			
		||||
        literal_vector clause;
 | 
			
		||||
        clause.append(num_lits, lits);
 | 
			
		||||
        m_recorded_clauses.push_back(clause);
 | 
			
		||||
    void context::record_cube(unsigned num_lits, literal const *lits) {
 | 
			
		||||
        expr_ref_vector cube(m);
 | 
			
		||||
        for (unsigned i = 0; i < num_lits; ++i) {
 | 
			
		||||
            literal lit = lits[i];
 | 
			
		||||
            expr* e = bool_var2expr(lit.var());
 | 
			
		||||
            if (!e) continue;
 | 
			
		||||
            if (!lit.sign())
 | 
			
		||||
                e = m.mk_not(e);  // only negate positive literal
 | 
			
		||||
            cube.push_back(e);
 | 
			
		||||
        }
 | 
			
		||||
        m_recorded_cubes.push_back(cube);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void context::add_scores(unsigned n, literal const *lits) {
 | 
			
		||||
| 
						 | 
				
			
			@ -1440,7 +1447,7 @@ namespace smt {
 | 
			
		|||
        case CLS_LEARNED:
 | 
			
		||||
            dump_lemma(num_lits, lits);
 | 
			
		||||
            add_scores(num_lits, lits);
 | 
			
		||||
            record_clause(num_lits, lits);
 | 
			
		||||
            record_cube(num_lits, lits);
 | 
			
		||||
            break;
 | 
			
		||||
        default:
 | 
			
		||||
            break;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -83,36 +83,34 @@ namespace smt {
 | 
			
		|||
        return r;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    unsigned parallel::param_generator::replay_proof_prefixes(vector<param_values> const& candidate_param_states, unsigned max_conflicts_epsilon=200) {
 | 
			
		||||
    std::pair<parallel::param_generator::param_values, bool> parallel::param_generator::replay_proof_prefixes(unsigned max_conflicts_epsilon=200) {
 | 
			
		||||
        unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon;
 | 
			
		||||
        unsigned best_param_state_idx;
 | 
			
		||||
        param_values best_param_state;
 | 
			
		||||
        double best_score;
 | 
			
		||||
        bool found_better_params = false;
 | 
			
		||||
 | 
			
		||||
        for (unsigned i = 0; i < m_param_probe_contexts.size(); ++i) {
 | 
			
		||||
        for (unsigned i = 0; i < N; ++i) {
 | 
			
		||||
            IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: replaying proof prefix in param probe context " << i << "\n");
 | 
			
		||||
            context *probe_ctx = m_param_probe_contexts[i];
 | 
			
		||||
 | 
			
		||||
            // copy prefix solver context to a new probe_ctx for next replay with candidate mutation
 | 
			
		||||
            scoped_ptr<context> probe_ctx = alloc(context, m, ctx->get_fparams(), m_p);
 | 
			
		||||
            context::copy(*ctx, *probe_ctx, true);
 | 
			
		||||
 | 
			
		||||
            // apply a candidate (mutated) param state to probe_ctx
 | 
			
		||||
            // (except for the first iteration, use the current param state)
 | 
			
		||||
            param_values mutated_param_state = m_param_state;
 | 
			
		||||
            if (i > 0) {
 | 
			
		||||
                mutated_param_state = mutate_param_state();
 | 
			
		||||
                params_ref p = apply_param_values(mutated_param_state);
 | 
			
		||||
                probe_ctx->updt_params(p);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            probe_ctx->get_fparams().m_max_conflicts = conflict_budget;
 | 
			
		||||
            double score = 0.0;
 | 
			
		||||
 | 
			
		||||
            // apply the ith param state to probe_ctx
 | 
			
		||||
            params_ref p = apply_param_values(candidate_param_states[i]);
 | 
			
		||||
            probe_ctx->updt_params(p);
 | 
			
		||||
 | 
			
		||||
            // todo: m_recorded_cubes as a expr_ref_vector
 | 
			
		||||
 | 
			
		||||
            for (auto const& clause : probe_ctx->m_recorded_clauses) {
 | 
			
		||||
                expr_ref_vector negated_lits(probe_ctx->m);
 | 
			
		||||
                for (literal lit : clause) {
 | 
			
		||||
                    expr* e = probe_ctx->bool_var2expr(lit.var());
 | 
			
		||||
                    if (!e) continue;  // skip if var not yet mapped
 | 
			
		||||
                    if (!lit.sign())
 | 
			
		||||
                        e = probe_ctx->m.mk_not(e); // since bool_var2expr discards sign
 | 
			
		||||
                    negated_lits.push_back(e);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                // Replay the negated clause
 | 
			
		||||
     
 | 
			
		||||
                lbool r = probe_ctx->check(negated_lits.size(), negated_lits.data());
 | 
			
		||||
            // replay the cube (negation of the clause)
 | 
			
		||||
            for (expr_ref_vector const& cube : probe_ctx->m_recorded_cubes) {
 | 
			
		||||
                lbool r = probe_ctx->check(cube.size(), cube.data());
 | 
			
		||||
 | 
			
		||||
                unsigned conflicts = probe_ctx->m_stats.m_num_conflicts;                
 | 
			
		||||
                unsigned decisions = probe_ctx->m_stats.m_num_decisions;
 | 
			
		||||
| 
						 | 
				
			
			@ -120,13 +118,16 @@ namespace smt {
 | 
			
		|||
                score += conflicts + decisions;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if (i == 0 || score < best_score) {
 | 
			
		||||
            if (i > 0 && score < best_score) {
 | 
			
		||||
                found_better_params = true;
 | 
			
		||||
                best_param_state = mutated_param_state;
 | 
			
		||||
                best_score = score;
 | 
			
		||||
            } else {
 | 
			
		||||
                best_score = score;
 | 
			
		||||
                best_param_state_idx = i;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return best_param_state_idx;
 | 
			
		||||
        return {best_param_state, found_better_params};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void parallel::param_generator::init_param_state() {
 | 
			
		||||
| 
						 | 
				
			
			@ -147,7 +148,6 @@ namespace smt {
 | 
			
		|||
    };
 | 
			
		||||
 | 
			
		||||
    parallel::param_generator::param_values parallel::param_generator::mutate_param_state() {
 | 
			
		||||
 | 
			
		||||
        param_values new_param_values(m_param_state);
 | 
			
		||||
        unsigned index = ctx->get_random_value() % new_param_values.size();
 | 
			
		||||
        auto ¶m = new_param_values[index];
 | 
			
		||||
| 
						 | 
				
			
			@ -168,41 +168,23 @@ namespace smt {
 | 
			
		|||
 | 
			
		||||
    void parallel::param_generator::protocol_iteration() {
 | 
			
		||||
        IF_VERBOSE(1, verbose_stream() << " PARAM TUNER running protocol iteration\n");
 | 
			
		||||
        ctx->get_fparams().m_max_conflicts = m_max_prefix_conflicts;
 | 
			
		||||
 | 
			
		||||
        // copy current param state to all param probe contexts, before running the next prefix step
 | 
			
		||||
        // this ensures that each param probe context replays the prefix from the same configuration
 | 
			
		||||
 | 
			
		||||
        // instead just one one context and reset it each time before copy.
 | 
			
		||||
        for (unsigned i = 0; i < m_param_probe_contexts.size(); ++i) {
 | 
			
		||||
            context::copy(*ctx, *m_param_probe_contexts[i], true);
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        ctx->get_fparams().m_max_conflicts = m_max_prefix_conflicts;
 | 
			
		||||
        lbool r = run_prefix_step();
 | 
			
		||||
 | 
			
		||||
        switch (r) {
 | 
			
		||||
            case l_undef: {
 | 
			
		||||
            // TODO, change from smt_params to a generic param state representation based on params_ref
 | 
			
		||||
                // only params_ref have effect on updates.
 | 
			
		||||
                param_values best_param_state = m_param_state;
 | 
			
		||||
                vector<param_values> candidate_param_states;
 | 
			
		||||
                auto [best_param_state, found_better_params] = replay_proof_prefixes();
 | 
			
		||||
 | 
			
		||||
                // you can create the mutations on the fly and get the scores 
 | 
			
		||||
                // you don't have to copy all over each tester.
 | 
			
		||||
                
 | 
			
		||||
 | 
			
		||||
                candidate_param_states.push_back(best_param_state); // first candidate param state is current best
 | 
			
		||||
                while (candidate_param_states.size() <= N) {
 | 
			
		||||
                    candidate_param_states.push_back(mutate_param_state());
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                unsigned best_param_state_idx = replay_proof_prefixes(candidate_param_states);
 | 
			
		||||
 | 
			
		||||
                if (best_param_state_idx != 0) {
 | 
			
		||||
                    m_param_state = candidate_param_states[best_param_state_idx];
 | 
			
		||||
                // NOTE: we either need to return a pair from replay_proof_prefixes so we can return a boolean flag indicating whether better params were found.
 | 
			
		||||
                // or, we have to implement a comparison operator for param_values
 | 
			
		||||
                // or, we update the param state every single time even if it hasn't changed
 | 
			
		||||
                // for now, I went with option 1
 | 
			
		||||
                if (found_better_params) {
 | 
			
		||||
                    m_param_state = best_param_state;
 | 
			
		||||
                    auto p = apply_param_values(m_param_state);
 | 
			
		||||
                    b.set_param_state(p);
 | 
			
		||||
                    IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found better param state at index " << best_param_state_idx << "\n");
 | 
			
		||||
                    IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found better param state\n");
 | 
			
		||||
                } else {
 | 
			
		||||
                    IF_VERBOSE(1, verbose_stream() << " PARAM TUNER retained current param state\n");
 | 
			
		||||
                }
 | 
			
		||||
| 
						 | 
				
			
			@ -315,11 +297,6 @@ namespace smt {
 | 
			
		|||
        : p(p), b(p.m_batch_manager), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) {
 | 
			
		||||
        ctx = alloc(context, m, p.ctx.get_fparams(), m_p);
 | 
			
		||||
        context::copy(p.ctx, *ctx, true);
 | 
			
		||||
 | 
			
		||||
        for (unsigned i = 0; i < N; ++i) {
 | 
			
		||||
            m_param_probe_contexts.push_back(alloc(context, m, ctx->get_fparams(), m_p));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // don't share initial units
 | 
			
		||||
        ctx->pop_to_base_lvl();
 | 
			
		||||
        init_param_state();
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -27,23 +27,6 @@ Revision History:
 | 
			
		|||
 | 
			
		||||
namespace smt {
 | 
			
		||||
 | 
			
		||||
  inline bool operator==(const smt_params& a, const smt_params& b) {
 | 
			
		||||
      return a.m_nl_arith_branching == b.m_nl_arith_branching &&
 | 
			
		||||
            a.m_nl_arith_cross_nested == b.m_nl_arith_cross_nested &&
 | 
			
		||||
            a.m_nl_arith_delay == b.m_nl_arith_delay &&
 | 
			
		||||
            a.m_nl_arith_expensive_patching == b.m_nl_arith_expensive_patching &&
 | 
			
		||||
            a.m_nl_arith_gb == b.m_nl_arith_gb &&
 | 
			
		||||
            a.m_nl_arith_horner == b.m_nl_arith_horner &&
 | 
			
		||||
            a.m_nl_arith_horner_frequency == b.m_nl_arith_horner_frequency &&
 | 
			
		||||
            a.m_nl_arith_optimize_bounds == b.m_nl_arith_optimize_bounds &&
 | 
			
		||||
            a.m_nl_arith_propagate_linear_monomials == b.m_nl_arith_propagate_linear_monomials &&
 | 
			
		||||
            a.m_nl_arith_tangents == b.m_nl_arith_tangents;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  inline bool operator!=(const smt_params& a, const smt_params& b) {
 | 
			
		||||
      return !(a == b);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
    struct cube_config {
 | 
			
		||||
        using literal = expr_ref;
 | 
			
		||||
        static bool literal_is_null(expr_ref const& l) { return l == nullptr; }
 | 
			
		||||
| 
						 | 
				
			
			@ -150,9 +133,18 @@ namespace smt {
 | 
			
		|||
            param_values m_param_state;
 | 
			
		||||
 | 
			
		||||
            params_ref apply_param_values(param_values const &pv) {
 | 
			
		||||
                return m_p;
 | 
			
		||||
                params_ref p = m_p;
 | 
			
		||||
                for (auto const& [k, v] : pv) {
 | 
			
		||||
                    if (std::holds_alternative<unsigned_value>(v)) {
 | 
			
		||||
                        unsigned_value uv = std::get<unsigned_value>(v);
 | 
			
		||||
                        p.set_uint(k, uv.value);
 | 
			
		||||
                    } else if (std::holds_alternative<bool>(v)) {
 | 
			
		||||
                        bool bv = std::get<bool>(v);
 | 
			
		||||
                        p.set_bool(k, bv);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                return p;
 | 
			
		||||
            }
 | 
			
		||||
            // todo
 | 
			
		||||
 | 
			
		||||
        private:
 | 
			
		||||
            void init_param_state();
 | 
			
		||||
| 
						 | 
				
			
			@ -163,7 +155,7 @@ namespace smt {
 | 
			
		|||
            param_generator(parallel &p);
 | 
			
		||||
            lbool run_prefix_step();
 | 
			
		||||
            void protocol_iteration();
 | 
			
		||||
            unsigned replay_proof_prefixes(vector<param_values> const& candidate_param_states, unsigned max_conflicts_epsilon);
 | 
			
		||||
            std::pair<parallel::param_generator::param_values, bool> replay_proof_prefixes(unsigned max_conflicts_epsilon);
 | 
			
		||||
 | 
			
		||||
            reslimit &limit() {
 | 
			
		||||
                return m.limit();
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue