mirror of
				https://github.com/Z3Prover/z3
				synced 2025-10-31 19:52:29 +00:00 
			
		
		
		
	setting up python tuning experiment, not done
This commit is contained in:
		
							parent
							
								
									b9fb032a67
								
							
						
					
					
						commit
						193845c753
					
				
					 1 changed files with 199 additions and 0 deletions
				
			
		
							
								
								
									
										199
									
								
								param-tuning-experiment.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										199
									
								
								param-tuning-experiment.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,199 @@ | |||
| import os | ||||
| import z3 | ||||
| 
 | ||||
| MAX_CONFLICTS = 1000 | ||||
| MAX_EXAMPLES = 5 | ||||
| bench_dir = "C:/tmp/parameter-tuning" | ||||
| 
 | ||||
| # Baseline parameter candidates (you can grow this) | ||||
| BASE_PARAM_CANDIDATES = [ | ||||
|     ("smt.arith.eager_eq_axioms", False), | ||||
|     ("smt.restart_factor", 1.2), | ||||
|     ("smt.relevancy", 0), | ||||
|     ("smt.phase_caching_off", 200), | ||||
|     ("smt.phase_caching_on", 600), | ||||
| ] | ||||
| 
 | ||||
| # -------------------------- | ||||
| # One class: BatchManager | ||||
| # -------------------------- | ||||
| class BatchManager: | ||||
|     def __init__(self): | ||||
|         self.best_param_state = None | ||||
|         self.best_score = (10**9, 10**9, 10**9)  # (conflicts, decisions, rlimit) | ||||
|         self.search_complete = False | ||||
| 
 | ||||
|     def mark_complete(self): | ||||
|         self.search_complete = True | ||||
| 
 | ||||
|     def maybe_update_best(self, param_state, triple): | ||||
|         if self._better(triple, self.best_score): | ||||
|             self.best_param_state = list(param_state) | ||||
|             self.best_score = triple | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _better(a, b): | ||||
|         return a < b  # lexicographic | ||||
| 
 | ||||
| # ------------------- | ||||
| # Helpers | ||||
| # ------------------- | ||||
| 
 | ||||
| def get_stat_int(st, key): | ||||
|     try: | ||||
|         v = st.get_key_value(key) | ||||
|         if isinstance(v, (int, float)): | ||||
|             return int(v) | ||||
|     except Exception: | ||||
|         pass | ||||
|     if key == "decisions" and hasattr(st, "decisions"): | ||||
|         try: | ||||
|             return int(st.decisions()) | ||||
|         except Exception: | ||||
|             return 0 | ||||
|     return 0 | ||||
| 
 | ||||
| def solver_from_file(filepath): | ||||
|     s = z3.Solver() | ||||
|     s.set("smt.auto_config", False) | ||||
|     s.from_file(filepath) | ||||
|     return s | ||||
| 
 | ||||
| def apply_param_state(s, param_state): | ||||
|     for name, value in param_state: | ||||
|         s.set(name, value) | ||||
| 
 | ||||
| def stats_tuple(st): | ||||
|     return ( | ||||
|         get_stat_int(st, "conflicts"), | ||||
|         get_stat_int(st, "decisions"), | ||||
|         get_stat_int(st, "rlimit count"), | ||||
|     ) | ||||
| 
 | ||||
| # -------------------------- | ||||
| # Protocol steps | ||||
| # -------------------------- | ||||
| 
 | ||||
| def run_prefix_step(S, K): | ||||
|     S.set("smt.K", K) | ||||
|     r = S.check() | ||||
|     return r, S.statistics() | ||||
| 
 | ||||
| def collect_conflict_clauses_placeholder(S, limit = 4): | ||||
|     return [] | ||||
| 
 | ||||
| def replay_prefix_on_pps(filepath, clauses, param_state, budget): | ||||
|     if not clauses: | ||||
|         s = solver_from_file(filepath) | ||||
|         apply_param_state(s, param_state) | ||||
|         s.set("smt.K", budget) | ||||
|         _ = s.check() | ||||
|         st = s.statistics() | ||||
|         return stats_tuple(st) | ||||
| 
 | ||||
|     total_conflicts = 0 | ||||
|     total_decisions = 0 | ||||
|     total_rlimit = 0 | ||||
| 
 | ||||
|     PPS = solver_from_file(filepath) | ||||
|     apply_param_state(PPS, param_state) | ||||
| 
 | ||||
|     for Cj in clauses: | ||||
|         PPS.set("smt.K", budget) | ||||
|         assumption = z3.Not(Cj) | ||||
|         PPS.check([assumption]) | ||||
|         st = PPS.statistics() | ||||
|         c, d, rl = stats_tuple(st) | ||||
|         total_conflicts += c | ||||
|         total_decisions += d | ||||
|         total_rlimit += rl | ||||
| 
 | ||||
|     return (total_conflicts, total_decisions, total_rlimit) | ||||
| 
 | ||||
| def choose_best_pps(filepath, clauses, base_param_state, candidate_param_states, K, eps = 200): | ||||
|     budget = K + eps | ||||
|     best_param_state = base_param_state | ||||
|     best_score = (10**9, 10**9, 10**9) | ||||
| 
 | ||||
|     score0 = replay_prefix_on_pps(filepath, clauses, base_param_state, budget) | ||||
|     if score0 < best_score: | ||||
|         best_param_state, best_score = base_param_state, score0 | ||||
| 
 | ||||
|     for p_state in candidate_param_states: | ||||
|         sc = replay_prefix_on_pps(filepath, clauses, p_state, budget) | ||||
|         if sc < best_score: | ||||
|             best_param_state, best_score = p_state, sc | ||||
| 
 | ||||
|     return best_param_state, best_score | ||||
| 
 | ||||
| def next_perturbations(around_state): | ||||
|     outs = [] | ||||
|     for name, val in around_state: | ||||
|         if isinstance(val, (int, float)) and "restart_factor" in name: | ||||
|             outs.append([(name, float(val) * 0.9)]) | ||||
|             outs.append([(name, float(val) * 1.1)]) | ||||
|         elif isinstance(val, int) and "phase_caching" in name: | ||||
|             k = max(1, int(val)) | ||||
|             outs.append([(name, k // 2)]) | ||||
|             outs.append([(name, k * 2)]) | ||||
|         else: | ||||
|             if name == "smt.relevancy": | ||||
|                 outs.extend([[(name, 0)], [(name, 1)], [(name, 2)]]) | ||||
|     return outs or [around_state] | ||||
| 
 | ||||
| # -------------------------- | ||||
| # Protocol iteration | ||||
| # -------------------------- | ||||
| 
 | ||||
| def protocol_iteration(filepath, manager, K, eps=200): | ||||
|     S = solver_from_file(filepath) # Proof Prefix solver | ||||
|     P = manager.best_param_state or BASE_PARAM_CANDIDATES # current optimal parameter setting | ||||
|     apply_param_state(S, P) | ||||
| 
 | ||||
|     # Run S with max conflicts K | ||||
|     r, st = run_prefix_step(S, K) | ||||
| 
 | ||||
|     # If S returns SAT, or UNSAT we have a verdict. Tell the central dispatch that search is complete. Exit. | ||||
|     if r == z3.sat or r == z3.unsat: | ||||
|         print(f"[S] {os.path.basename(filepath)} → {r} (within max_conflicts={K}). Search complete.") | ||||
|         manager.mark_complete() | ||||
|         return | ||||
| 
 | ||||
|     # Collect a subset of conflict clauses from the bounded run of S. Call these clauses C1, ..., Cl. | ||||
|     C_list = collect_conflict_clauses_placeholder(S) | ||||
|     print(f"[S] collected {len(C_list)} conflict clauses for replay") | ||||
| 
 | ||||
|     PPS0 = P | ||||
|     PPS_perturb = next_perturbations(P) | ||||
| 
 | ||||
|     best_state, best_score = choose_best_pps(filepath, C_list, PPS0, PPS_perturb, K, eps) | ||||
|     print(f"[Replay] best={best_state} score(conf, dec, rlim)={best_score}") | ||||
| 
 | ||||
|     if best_state != P: | ||||
|         print(f"[Dispatch] updating best param state") | ||||
|         manager.maybe_update_best(best_state, best_score) | ||||
|         P = best_state | ||||
| 
 | ||||
|     PPS0 = P | ||||
|     PPS_perturb = next_perturbations(P) | ||||
|     print(f"[Dispatch] PPS_0 := {PPS0}, new perturbations: {PPS_perturb}") | ||||
| 
 | ||||
| # -------------------------- | ||||
| # Main | ||||
| # -------------------------- | ||||
| 
 | ||||
| def main(): | ||||
|     manager = BatchManager() | ||||
| 
 | ||||
|     for benchmark in os.listdir(bench_dir): | ||||
|         if benchmark != "From_T2__hqr.t2_fixed__term_unfeasibility_1_0.smt2": | ||||
|             continue | ||||
| 
 | ||||
|         filepath = os.path.join(bench_dir, benchmark) | ||||
|         protocol_iteration(filepath, manager, K=MAX_CONFLICTS, eps=200) | ||||
| 
 | ||||
|     if manager.best_param_state: | ||||
|         print(f"\n[GLOBAL] Best parameter state: {manager.best_param_state} with score {manager.best_score}") | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue