mirror of
				https://github.com/Z3Prover/z3
				synced 2025-11-04 05:19:11 +00:00 
			
		
		
		
	merge
This commit is contained in:
		
						commit
						33060f7b97
					
				
					 2 changed files with 72 additions and 65 deletions
				
			
		| 
						 | 
				
			
			@ -83,7 +83,7 @@ namespace smt {
 | 
			
		|||
        return r;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    unsigned parallel::param_generator::replay_proof_prefixes(vector<smt_params> candidate_param_states, unsigned max_conflicts_epsilon=200) {
 | 
			
		||||
    unsigned parallel::param_generator::replay_proof_prefixes(vector<param_values> const& candidate_param_states, unsigned max_conflicts_epsilon=200) {
 | 
			
		||||
        unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon;
 | 
			
		||||
        unsigned best_param_state_idx;
 | 
			
		||||
        double best_score;
 | 
			
		||||
| 
						 | 
				
			
			@ -95,11 +95,11 @@ namespace smt {
 | 
			
		|||
            double score = 0.0;
 | 
			
		||||
 | 
			
		||||
            // apply the ith param state to probe_ctx
 | 
			
		||||
            smt_params params = candidate_param_states[i];
 | 
			
		||||
            params_ref p;
 | 
			
		||||
            params.updt_params(p);
 | 
			
		||||
            params_ref p = apply_param_values(candidate_param_states[i]);
 | 
			
		||||
            probe_ctx->updt_params(p);
 | 
			
		||||
 | 
			
		||||
            // todo: m_recorded_cubes as a expr_ref_vector
 | 
			
		||||
 | 
			
		||||
            for (auto const& clause : probe_ctx->m_recorded_clauses) {
 | 
			
		||||
                expr_ref_vector negated_lits(probe_ctx->m);
 | 
			
		||||
                for (literal lit : clause) {
 | 
			
		||||
| 
						 | 
				
			
			@ -111,15 +111,13 @@ namespace smt {
 | 
			
		|||
                }
 | 
			
		||||
 | 
			
		||||
                // Replay the negated clause
 | 
			
		||||
     
 | 
			
		||||
                lbool r = probe_ctx->check(negated_lits.size(), negated_lits.data());
 | 
			
		||||
 | 
			
		||||
                ::statistics st;
 | 
			
		||||
                probe_ctx->collect_statistics(st);
 | 
			
		||||
                unsigned conflicts = 0, decisions = 0, rlimit = 0;
 | 
			
		||||
                conflicts = st.get_val("conflicts");
 | 
			
		||||
                decisions = st.get_val("decisions");
 | 
			
		||||
                rlimit = st.get_val("rlimit count");
 | 
			
		||||
                score += conflicts + decisions + rlimit;
 | 
			
		||||
                unsigned conflicts = probe_ctx->m_stats.m_num_conflicts;                
 | 
			
		||||
                unsigned decisions = probe_ctx->m_stats.m_num_decisions;
 | 
			
		||||
 | 
			
		||||
                score += conflicts + decisions;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if (i == 0 || score < best_score) {
 | 
			
		||||
| 
						 | 
				
			
			@ -132,49 +130,40 @@ namespace smt {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    void parallel::param_generator::init_param_state() {
 | 
			
		||||
        // param_descrs smt_desc;
 | 
			
		||||
        // smt_params_helper::collect_param_descrs(smt_desc);
 | 
			
		||||
        smt_params_helper smtp(m_p);
 | 
			
		||||
        m_my_param_state.insert(symbol("smt.arith.nl.branching"), smtp.arith_nl_branching());
 | 
			
		||||
        m_my_param_state.insert(symbol("smt.arith.nl.cross_nested"), smtp.arith_nl_cross_nested());
 | 
			
		||||
        m_my_param_state.insert(symbol("smt.arith.nl.delay"), smtp.arith_nl_delay());
 | 
			
		||||
        m_my_param_state.insert(symbol("smt.arith.nl.expensive_patching"), smtp.arith_nl_expensive_patching());
 | 
			
		||||
        // m_my_param_state.insert(symbol("smt.arith.nl.gb"), smtp.arith_nl_gb());
 | 
			
		||||
        m_my_param_state.insert(symbol("smt.arith.nl.horner"), smtp.arith_nl_horner());
 | 
			
		||||
        m_my_param_state.insert(symbol("smt.arith.nl.horner_frequency"), smtp.arith_nl_horner_frequency());
 | 
			
		||||
        m_my_param_state.insert(symbol("smt.arith.nl.optimize_bounds"), smtp.arith_nl_optimize_bounds());
 | 
			
		||||
        m_my_param_state.insert(symbol("smt.arith.nl.propagate_linear_monomials"), smtp.arith_nl_propagate_linear_monomials());
 | 
			
		||||
        m_my_param_state.insert(symbol("smt.arith.nl.tangents"), smtp.arith_nl_tangents());
 | 
			
		||||
        m_param_state.push_back({symbol("smt.arith.nl.branching"), smtp.arith_nl_branching()});
 | 
			
		||||
        m_param_state.push_back({symbol("smt.arith.nl.cross_nested"), smtp.arith_nl_cross_nested()});
 | 
			
		||||
        m_param_state.push_back({symbol("smt.arith.nl.delay"), unsigned_value({smtp.arith_nl_delay(), 5, 10})});
 | 
			
		||||
        m_param_state.push_back({symbol("smt.arith.nl.expensive_patching"), smtp.arith_nl_expensive_patching()});
 | 
			
		||||
        m_param_state.push_back({symbol("smt.arith.nl.gb"), smtp.arith_nl_grobner()});
 | 
			
		||||
        m_param_state.push_back({symbol("smt.arith.nl.horner"), smtp.arith_nl_horner()});
 | 
			
		||||
        m_param_state.push_back({symbol("smt.arith.nl.horner_frequency"), unsigned_value({smtp.arith_nl_horner_frequency(), 2, 6})
 | 
			
		||||
        });
 | 
			
		||||
        m_param_state.push_back({symbol("smt.arith.nl.optimize_bounds"), smtp.arith_nl_optimize_bounds()});
 | 
			
		||||
        m_param_state.push_back(
 | 
			
		||||
            {symbol("smt.arith.nl.propagate_linear_monomials"), smtp.arith_nl_propagate_linear_monomials()});
 | 
			
		||||
        m_param_state.push_back({symbol("smt.arith.nl.tangents"), smtp.arith_nl_tangents()});
 | 
			
		||||
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // TODO: this should mutate only one field at a time an mutate it based on m_my_param_state to keep it generic.
 | 
			
		||||
    parallel::param_generator::param_values parallel::param_generator::mutate_param_state() {
 | 
			
		||||
 | 
			
		||||
    smt_params parallel::param_generator::mutate_param_state() {
 | 
			
		||||
        smt_params p = m_param_state;
 | 
			
		||||
        random_gen m_rand;
 | 
			
		||||
 | 
			
		||||
        auto flip_bool = [&](bool &x) {
 | 
			
		||||
            if (m_rand(2) == 0)
 | 
			
		||||
                x = !x;
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        auto mutate_uint = [&](unsigned &x, unsigned lo, unsigned hi) {
 | 
			
		||||
            if ((m_rand() % 2) == 0)
 | 
			
		||||
                x = lo + (m_rand((hi - lo + 1)));
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        flip_bool(p.m_nl_arith_branching);
 | 
			
		||||
        flip_bool(p.m_nl_arith_cross_nested);
 | 
			
		||||
        mutate_uint(p.m_nl_arith_delay, 5, 20);
 | 
			
		||||
        flip_bool(p.m_nl_arith_expensive_patching);
 | 
			
		||||
        flip_bool(p.m_nl_arith_gb);
 | 
			
		||||
        flip_bool(p.m_nl_arith_horner);
 | 
			
		||||
        mutate_uint(p.m_nl_arith_horner_frequency, 2, 6);
 | 
			
		||||
        flip_bool(p.m_nl_arith_optimize_bounds);
 | 
			
		||||
        flip_bool(p.m_nl_arith_propagate_linear_monomials);
 | 
			
		||||
        flip_bool(p.m_nl_arith_tangents);
 | 
			
		||||
 | 
			
		||||
        return p;
 | 
			
		||||
        param_values new_param_values(m_param_state);
 | 
			
		||||
        unsigned index = ctx->get_random_value() % new_param_values.size();
 | 
			
		||||
        auto ¶m = new_param_values[index];
 | 
			
		||||
        if (std::holds_alternative<bool>(param.second)) {
 | 
			
		||||
            bool value = *std::get_if<bool>(¶m.second);
 | 
			
		||||
            param.second = !value;
 | 
			
		||||
        } 
 | 
			
		||||
        else if (std::holds_alternative<unsigned_value>(param.second)) {
 | 
			
		||||
            auto [value, lo, hi] = *std::get_if<unsigned_value>(¶m.second);
 | 
			
		||||
            unsigned new_value = value;
 | 
			
		||||
            while (new_value == value) {
 | 
			
		||||
                new_value = lo + ctx->get_random_value() % (hi - lo + 1);
 | 
			
		||||
            }
 | 
			
		||||
            std::get<unsigned_value>(param.second).value = new_value;
 | 
			
		||||
        }
 | 
			
		||||
        return new_param_values;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void parallel::param_generator::protocol_iteration() {
 | 
			
		||||
| 
						 | 
				
			
			@ -183,6 +172,8 @@ namespace smt {
 | 
			
		|||
 | 
			
		||||
        // copy current param state to all param probe contexts, before running the next prefix step
 | 
			
		||||
        // this ensures that each param probe context replays the prefix from the same configuration
 | 
			
		||||
 | 
			
		||||
        // instead just one one context and reset it each time before copy.
 | 
			
		||||
        for (unsigned i = 0; i < m_param_probe_contexts.size(); ++i) {
 | 
			
		||||
            context::copy(*ctx, *m_param_probe_contexts[i], true);
 | 
			
		||||
        }
 | 
			
		||||
| 
						 | 
				
			
			@ -193,8 +184,12 @@ namespace smt {
 | 
			
		|||
            case l_undef: {
 | 
			
		||||
            // TODO, change from smt_params to a generic param state representation based on params_ref
 | 
			
		||||
                // only params_ref have effect on updates.
 | 
			
		||||
                smt_params best_param_state = m_param_state;
 | 
			
		||||
                vector<smt_params> candidate_param_states;
 | 
			
		||||
                param_values best_param_state = m_param_state;
 | 
			
		||||
                vector<param_values> candidate_param_states;
 | 
			
		||||
 | 
			
		||||
                // you can create the mutations on the fly and get the scores 
 | 
			
		||||
                // you don't have to copy all over each tester.
 | 
			
		||||
                
 | 
			
		||||
 | 
			
		||||
                candidate_param_states.push_back(best_param_state); // first candidate param state is current best
 | 
			
		||||
                while (candidate_param_states.size() <= N) {
 | 
			
		||||
| 
						 | 
				
			
			@ -205,7 +200,8 @@ namespace smt {
 | 
			
		|||
 | 
			
		||||
                if (best_param_state_idx != 0) {
 | 
			
		||||
                    m_param_state = candidate_param_states[best_param_state_idx];
 | 
			
		||||
                    b.set_param_state(m_param_state);
 | 
			
		||||
                    auto p = apply_param_values(m_param_state);
 | 
			
		||||
                    b.set_param_state(p);
 | 
			
		||||
                    IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found better param state at index " << best_param_state_idx << "\n");
 | 
			
		||||
                } else {
 | 
			
		||||
                    IF_VERBOSE(1, verbose_stream() << " PARAM TUNER retained current param state\n");
 | 
			
		||||
| 
						 | 
				
			
			@ -316,12 +312,12 @@ namespace smt {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    parallel::param_generator::param_generator(parallel& p)
 | 
			
		||||
        : p(p), b(p.m_batch_manager), m_param_state(p.ctx.get_fparams()), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) {
 | 
			
		||||
        ctx = alloc(context, m, m_param_state, m_p);
 | 
			
		||||
        : p(p), b(p.m_batch_manager), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) {
 | 
			
		||||
        ctx = alloc(context, m, p.ctx.get_fparams(), m_p);
 | 
			
		||||
        context::copy(p.ctx, *ctx, true);
 | 
			
		||||
 | 
			
		||||
        for (unsigned i = 0; i < N; ++i) {
 | 
			
		||||
            m_param_probe_contexts.push_back(alloc(context, m, m_param_state, m_p));
 | 
			
		||||
            m_param_probe_contexts.push_back(alloc(context, m, ctx->get_fparams(), m_p));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // don't share initial units
 | 
			
		||||
| 
						 | 
				
			
			@ -481,7 +477,8 @@ namespace smt {
 | 
			
		|||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    smt_params parallel::batch_manager::get_best_param_state() {
 | 
			
		||||
    // todo make this thread safe by not using reference counts implicit in params ref but instead copying the entire structure.
 | 
			
		||||
    params_ref parallel::batch_manager::get_best_param_state() {
 | 
			
		||||
        std::scoped_lock lock(mux);
 | 
			
		||||
        return m_param_state;
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -80,7 +80,7 @@ namespace smt {
 | 
			
		|||
            std::mutex mux;
 | 
			
		||||
            state m_state = state::is_running;
 | 
			
		||||
            stats m_stats;
 | 
			
		||||
            smt_params m_param_state;
 | 
			
		||||
            params_ref m_param_state;
 | 
			
		||||
            using node = search_tree::node<cube_config>;
 | 
			
		||||
            search_tree::tree<cube_config> m_search_tree;
 | 
			
		||||
            
 | 
			
		||||
| 
						 | 
				
			
			@ -105,10 +105,10 @@ namespace smt {
 | 
			
		|||
            void set_sat(ast_translation& l2g, model& m);
 | 
			
		||||
            void set_exception(std::string const& msg);
 | 
			
		||||
            void set_exception(unsigned error_code);
 | 
			
		||||
            void set_param_state(smt_params const& p) { m_param_state = p; }
 | 
			
		||||
            void set_param_state(params_ref const& p) { m_param_state.copy(p); }
 | 
			
		||||
            void collect_statistics(::statistics& st) const;
 | 
			
		||||
            
 | 
			
		||||
            smt_params get_best_param_state();
 | 
			
		||||
            params_ref get_best_param_state();
 | 
			
		||||
            bool get_cube(ast_translation& g2l, unsigned id, expr_ref_vector& cube, node*& n);
 | 
			
		||||
            void backtrack(ast_translation& l2g, expr_ref_vector const& core, node* n);
 | 
			
		||||
            void split(ast_translation& l2g, unsigned id, node* n, expr* atom);
 | 
			
		||||
| 
						 | 
				
			
			@ -138,22 +138,32 @@ namespace smt {
 | 
			
		|||
 | 
			
		||||
            scoped_ptr<context> m_prefix_solver;
 | 
			
		||||
            scoped_ptr_vector<context> m_param_probe_contexts;
 | 
			
		||||
            smt_params m_param_state;
 | 
			
		||||
            params_ref m_p;
 | 
			
		||||
 | 
			
		||||
            using param_value = std::variant<unsigned, bool, double>;
 | 
			
		||||
            symbol_table<param_value> m_my_param_state;
 | 
			
		||||
            struct unsigned_value {
 | 
			
		||||
                unsigned value;
 | 
			
		||||
                unsigned min_value;
 | 
			
		||||
                unsigned max_value;
 | 
			
		||||
            };
 | 
			
		||||
            using param_value = std::variant<unsigned_value, bool>;
 | 
			
		||||
            using param_values = vector<std::pair<symbol, param_value>>;
 | 
			
		||||
            param_values m_param_state;
 | 
			
		||||
 | 
			
		||||
            params_ref apply_param_values(param_values const &pv) {
 | 
			
		||||
                return m_p;
 | 
			
		||||
            }
 | 
			
		||||
            // todo
 | 
			
		||||
 | 
			
		||||
        private:
 | 
			
		||||
            void init_param_state();
 | 
			
		||||
 | 
			
		||||
            smt_params mutate_param_state();
 | 
			
		||||
            param_values mutate_param_state();
 | 
			
		||||
 | 
			
		||||
        public:
 | 
			
		||||
            param_generator(parallel &p);
 | 
			
		||||
            lbool run_prefix_step();
 | 
			
		||||
            void protocol_iteration();
 | 
			
		||||
            unsigned replay_proof_prefixes(vector<smt_params> candidate_param_states, unsigned max_conflicts_epsilon);
 | 
			
		||||
            unsigned replay_proof_prefixes(vector<param_values> const& candidate_param_states, unsigned max_conflicts_epsilon);
 | 
			
		||||
 | 
			
		||||
            reslimit &limit() {
 | 
			
		||||
                return m.limit();
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue