mirror of
				https://github.com/Z3Prover/z3
				synced 2025-10-31 11:42:28 +00:00 
			
		
		
		
	update pythonnn prototyping experiment, need to add a couple more things
This commit is contained in:
		
							parent
							
								
									193845c753
								
							
						
					
					
						commit
						86d7790c42
					
				
					 1 changed files with 99 additions and 75 deletions
				
			
		|  | @ -1,11 +1,12 @@ | |||
| import os | ||||
| import z3 | ||||
| from z3 import * | ||||
| import threading | ||||
| import math | ||||
| 
 | ||||
| MAX_CONFLICTS = 1000 | ||||
| MAX_EXAMPLES = 5 | ||||
| bench_dir = "C:/tmp/parameter-tuning" | ||||
| 
 | ||||
| # Baseline parameter candidates (you can grow this) | ||||
| BASE_PARAM_CANDIDATES = [ | ||||
|     ("smt.arith.eager_eq_axioms", False), | ||||
|     ("smt.restart_factor", 1.2), | ||||
|  | @ -20,7 +21,7 @@ BASE_PARAM_CANDIDATES = [ | |||
| class BatchManager: | ||||
|     def __init__(self): | ||||
|         self.best_param_state = None | ||||
|         self.best_score = (10**9, 10**9, 10**9)  # (conflicts, decisions, rlimit) | ||||
|         self.best_score = (math.inf, math.inf, math.inf) | ||||
|         self.search_complete = False | ||||
| 
 | ||||
|     def mark_complete(self): | ||||
|  | @ -33,99 +34,81 @@ class BatchManager: | |||
| 
 | ||||
|     @staticmethod | ||||
|     def _better(a, b): | ||||
|         return a < b  # lexicographic | ||||
|         return a < b  # lexicographic compare | ||||
| 
 | ||||
| 
 | ||||
| # ------------------- | ||||
| # Helpers | ||||
| # ------------------- | ||||
| 
 | ||||
| def get_stat_int(st, key): | ||||
|     try: | ||||
|         v = st.get_key_value(key) | ||||
|         if isinstance(v, (int, float)): | ||||
|             return int(v) | ||||
|     except Exception: | ||||
|         pass | ||||
|     if key == "decisions" and hasattr(st, "decisions"): | ||||
|         try: | ||||
|             return int(st.decisions()) | ||||
|         except Exception: | ||||
|             return 0 | ||||
|     return 0 | ||||
| 
 | ||||
| def solver_from_file(filepath): | ||||
|     s = z3.Solver() | ||||
|     s = Solver() | ||||
|     s.set("smt.auto_config", False) | ||||
|     s.from_file(filepath) | ||||
|     return s | ||||
| 
 | ||||
| 
 | ||||
| def apply_param_state(s, param_state): | ||||
|     for name, value in param_state: | ||||
|         s.set(name, value) | ||||
| 
 | ||||
| 
 | ||||
| def stats_tuple(st): | ||||
|     return ( | ||||
|         get_stat_int(st, "conflicts"), | ||||
|         get_stat_int(st, "decisions"), | ||||
|         get_stat_int(st, "rlimit count"), | ||||
|         int(st["conflicts"]), | ||||
|         int(st["decisions"]), | ||||
|         int(st["rlimit count"]), | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| # -------------------------- | ||||
| # Protocol steps | ||||
| # -------------------------- | ||||
| 
 | ||||
| def run_prefix_step(S, K): | ||||
|     S.set("smt.K", K) | ||||
|     S.set("smt.max_conflicts", K) | ||||
|     r = S.check() | ||||
|     return r, S.statistics() | ||||
| 
 | ||||
| def collect_conflict_clauses_placeholder(S, limit = 4): | ||||
| 
 | ||||
| def collect_conflict_clauses_placeholder(S, limit=4): | ||||
|     return [] | ||||
| 
 | ||||
| def replay_prefix_on_pps(filepath, clauses, param_state, budget): | ||||
|     if not clauses: | ||||
|         s = solver_from_file(filepath) | ||||
|         apply_param_state(s, param_state) | ||||
|         s.set("smt.K", budget) | ||||
|         _ = s.check() | ||||
|         st = s.statistics() | ||||
|         return stats_tuple(st) | ||||
| 
 | ||||
|     total_conflicts = 0 | ||||
|     total_decisions = 0 | ||||
|     total_rlimit = 0 | ||||
| 
 | ||||
|     PPS = solver_from_file(filepath) | ||||
|     apply_param_state(PPS, param_state) | ||||
| # Replay proof prefix on an existing PPS_solver (no solver recreation) | ||||
| # Solver continues from its current state. | ||||
| def replay_prefix_on_pps(PPS_solver, clauses, param_state, budget): | ||||
|     apply_param_state(PPS_solver, param_state) | ||||
|     PPS_solver.set("smt.max_conflicts", budget) | ||||
| 
 | ||||
|     asms = [] | ||||
|     for Cj in clauses: | ||||
|         PPS.set("smt.K", budget) | ||||
|         assumption = z3.Not(Cj) | ||||
|         PPS.check([assumption]) | ||||
|         st = PPS.statistics() | ||||
|         c, d, rl = stats_tuple(st) | ||||
|         total_conflicts += c | ||||
|         total_decisions += d | ||||
|         total_rlimit += rl | ||||
|       PPS_solver.set("smt.max_conflicts", budget) | ||||
|       asms.append(Not(Cj)) | ||||
| 
 | ||||
|     return (total_conflicts, total_decisions, total_rlimit) | ||||
|     PPS_solver.check(asms) | ||||
|     st = PPS_solver.statistics() | ||||
| 
 | ||||
| def choose_best_pps(filepath, clauses, base_param_state, candidate_param_states, K, eps = 200): | ||||
|     return stats_tuple(st) | ||||
| 
 | ||||
| 
 | ||||
| # For each PPS_i, replay the proof prefix of S | ||||
| def replay_proof_prefixes(clauses, param_states, PPS_solvers, K, eps=200): | ||||
|     budget = K + eps | ||||
|     best_param_state = base_param_state | ||||
|     best_score = (10**9, 10**9, 10**9) | ||||
|     base_param_state, candidate_param_states = param_states[0], param_states[1:] | ||||
| 
 | ||||
|     score0 = replay_prefix_on_pps(filepath, clauses, base_param_state, budget) | ||||
|     if score0 < best_score: | ||||
|         best_param_state, best_score = base_param_state, score0 | ||||
|     # PPS_0 (baseline) | ||||
|     score0 = replay_prefix_on_pps(PPS_solvers[0], clauses, base_param_state, budget) | ||||
|     best_param_state, best_score = base_param_state, score0 | ||||
| 
 | ||||
|     for p_state in candidate_param_states: | ||||
|         sc = replay_prefix_on_pps(filepath, clauses, p_state, budget) | ||||
|         if sc < best_score: | ||||
|             best_param_state, best_score = p_state, sc | ||||
|     # PPS_i, i > 0 | ||||
|     for i, p_state in enumerate(candidate_param_states, start=1): | ||||
|         score = replay_prefix_on_pps(PPS_solvers[i], clauses, p_state, budget) | ||||
|         if score < best_score: | ||||
|             best_param_state, best_score = p_state, score | ||||
| 
 | ||||
|     return best_param_state, best_score | ||||
| 
 | ||||
| 
 | ||||
| def next_perturbations(around_state): | ||||
|     outs = [] | ||||
|     for name, val in around_state: | ||||
|  | @ -136,37 +119,35 @@ def next_perturbations(around_state): | |||
|             k = max(1, int(val)) | ||||
|             outs.append([(name, k // 2)]) | ||||
|             outs.append([(name, k * 2)]) | ||||
|         else: | ||||
|             if name == "smt.relevancy": | ||||
|                 outs.extend([[(name, 0)], [(name, 1)], [(name, 2)]]) | ||||
|         elif name == "smt.relevancy": | ||||
|             outs.extend([[(name, 0)], [(name, 1)], [(name, 2)]]) | ||||
|     return outs or [around_state] | ||||
| 
 | ||||
| # -------------------------- | ||||
| # Protocol iteration | ||||
| # -------------------------- | ||||
| 
 | ||||
| def protocol_iteration(filepath, manager, K, eps=200): | ||||
|     S = solver_from_file(filepath) # Proof Prefix solver | ||||
|     P = manager.best_param_state or BASE_PARAM_CANDIDATES # current optimal parameter setting | ||||
| def protocol_iteration(filepath, manager, S, PPS_solvers, PPS_states, K, eps=200): | ||||
|     # --- Proof Prefix Solver (S) --- | ||||
|     P = manager.best_param_state or BASE_PARAM_CANDIDATES | ||||
|     apply_param_state(S, P) | ||||
| 
 | ||||
|     # Run S with max conflicts K | ||||
|     r, st = run_prefix_step(S, K) | ||||
| 
 | ||||
|     # If S returns SAT, or UNSAT we have a verdict. Tell the central dispatch that search is complete. Exit. | ||||
|     if r == z3.sat or r == z3.unsat: | ||||
|     # If S returns SAT or UNSAT we have a verdict | ||||
|     # Tell the central dispatch that search is complete and exit | ||||
|     if r == sat or r == unsat: | ||||
|         print(f"[S] {os.path.basename(filepath)} → {r} (within max_conflicts={K}). Search complete.") | ||||
|         manager.mark_complete() | ||||
|         return | ||||
| 
 | ||||
|     # Collect a subset of conflict clauses from the bounded run of S. Call these clauses C1, ..., Cl. | ||||
|     # Collect subset of conflict clauses from the bounded run of S. Call these clauses C1...Cl | ||||
|     C_list = collect_conflict_clauses_placeholder(S) | ||||
|     print(f"[S] collected {len(C_list)} conflict clauses for replay") | ||||
| 
 | ||||
|     PPS0 = P | ||||
|     PPS_perturb = next_perturbations(P) | ||||
| 
 | ||||
|     best_state, best_score = choose_best_pps(filepath, C_list, PPS0, PPS_perturb, K, eps) | ||||
|     # For each PPS_i, replay the proof prefix of S | ||||
|     best_state, best_score = replay_proof_prefixes(C_list, PPS_states, PPS_solvers, K, eps) | ||||
|     print(f"[Replay] best={best_state} score(conf, dec, rlim)={best_score}") | ||||
| 
 | ||||
|     if best_state != P: | ||||
|  | @ -174,9 +155,43 @@ def protocol_iteration(filepath, manager, K, eps=200): | |||
|         manager.maybe_update_best(best_state, best_score) | ||||
|         P = best_state | ||||
| 
 | ||||
|     PPS0 = P | ||||
|     PPS_perturb = next_perturbations(P) | ||||
|     print(f"[Dispatch] PPS_0 := {PPS0}, new perturbations: {PPS_perturb}") | ||||
|     # Update PPS_0 to use P (if it changed), and update all PPS_i > 0 with new perturbations of P | ||||
|     PPS_states[0] = P | ||||
|     new_perturb = next_perturbations(P) | ||||
|     for i in range(1, len(PPS_states)): | ||||
|         PPS_states[i] = new_perturb[i - 1] | ||||
|     print(f"[Dispatch] PPS_0 := {PPS_states[0]}, new perturbations: {new_perturb}") | ||||
| 
 | ||||
| 
 | ||||
| # -------------------------- | ||||
| # Prefix probing thread | ||||
| # -------------------------- | ||||
| 
 | ||||
| def prefix_probe_thread(filepath, manager): | ||||
|     # Proof prefix solver S | ||||
|     S = solver_from_file(filepath) | ||||
|     apply_param_state(S, BASE_PARAM_CANDIDATES) | ||||
| 
 | ||||
|     PPS_solvers = [] | ||||
|     PPS_states = [] | ||||
|     PPS_states.append(list(BASE_PARAM_CANDIDATES)) | ||||
|     perturbations = next_perturbations(BASE_PARAM_CANDIDATES) | ||||
| 
 | ||||
|     for i in range(4): | ||||
|         st = perturbations[i] if i < len(perturbations) else BASE_PARAM_CANDIDATES | ||||
|         PPS_solver = Solver() | ||||
|         apply_param_state(PPS_solver, st) | ||||
|         PPS_solvers.append(PPS_solver) | ||||
|         PPS_states.append(st) | ||||
|         print(f"[Init] PPS_{i} initialized with params={st}") | ||||
| 
 | ||||
|     # Reuse the same solvers each iteration | ||||
|     for iteration in range(3):  # run a few iterations | ||||
|         if manager.search_complete: | ||||
|             break | ||||
|         print(f"\n[PrefixThread] Iteration {iteration}") | ||||
|         protocol_iteration(filepath, manager, S, PPS_solvers, PPS_states, K=MAX_CONFLICTS, eps=200) | ||||
| 
 | ||||
| 
 | ||||
| # -------------------------- | ||||
| # Main | ||||
|  | @ -190,10 +205,19 @@ def main(): | |||
|             continue | ||||
| 
 | ||||
|         filepath = os.path.join(bench_dir, benchmark) | ||||
|         protocol_iteration(filepath, manager, K=MAX_CONFLICTS, eps=200) | ||||
|         prefix_thread = threading.Thread(target=prefix_probe_thread, args=(filepath, manager)) | ||||
|         prefix_thread.start() | ||||
| 
 | ||||
|         # main thread can perform monitoring or waiting | ||||
|         while prefix_thread.is_alive(): | ||||
|             if manager.search_complete: | ||||
|                 break | ||||
| 
 | ||||
|         prefix_thread.join() | ||||
| 
 | ||||
|     if manager.best_param_state: | ||||
|         print(f"\n[GLOBAL] Best parameter state: {manager.best_param_state} with score {manager.best_score}") | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue