mirror of
				https://github.com/Z3Prover/z3
				synced 2025-10-31 03:32:28 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			235 lines
		
	
	
	
		
			7.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			235 lines
		
	
	
	
		
			7.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| from more_itertools import iterate
 | |
| from z3 import *
 | |
| from multiprocessing import Process
 | |
| import math, random
 | |
| 
 | |
# Conflict budget K handed to the proof-prefix solver on each iteration.
MAX_CONFLICTS = 100
# Not referenced by any code visible in this script.
MAX_EXAMPLES = 5
# Directory that main() lists to find benchmark files.
bench_dir = "../z3-poly-testing/inputs/QF_NIA_small"

# Starting parameter state: (z3 parameter name, value) pairs, applied in order.
BASE_PARAM_CANDIDATES = [
    ("smt.arith.eager_eq_axioms", False),
    ("smt.restart_factor", 1.2),
    ("smt.relevancy", 0),
    ("smt.phase_caching_off", 200),
    ("smt.phase_caching_on", 600),
]
 | |
| 
 | |
| # --------------------------
 | |
| # One class: BatchManager
 | |
| # --------------------------
 | |
class BatchManager:
    """Keep the best parameter state seen so far and a search-finished flag.

    Scores are triples compared lexicographically; a smaller triple is better.
    """

    def __init__(self):
        # Nothing recorded yet: the initial score loses to any finite triple.
        self.best_param_state = None
        self.best_score = (math.inf,) * 3
        self.search_complete = False

    def mark_complete(self):
        """Flag the search as finished."""
        self.search_complete = True

    def maybe_update_best(self, param_state, triple):
        """Store a list copy of ``param_state`` if ``triple`` beats the best score."""
        if not self._better(triple, self.best_score):
            return
        self.best_param_state = list(param_state)
        self.best_score = triple

    @staticmethod
    def _better(a, b):
        """Return True when ``a`` is strictly less than ``b``."""
        return a < b
 | |
| 
 | |
| 
 | |
| # -------------------
 | |
| # Helpers
 | |
| # -------------------
 | |
| 
 | |
def solver_from_file(filepath):
    """Build a solver with ``smt.auto_config`` off and load ``filepath`` into it."""
    solver = Solver()
    solver.set("smt.auto_config", False)
    solver.from_file(filepath)
    return solver
 | |
| 
 | |
| 
 | |
def apply_param_state(s, param_state):
    """Log ``param_state``, then set each (name, value) pair on ``s`` in order."""
    print(f"Applying param state: {param_state}")
    for option, setting in param_state:
        s.set(option, setting)
 | |
| 
 | |
| 
 | |
def stats_tuple(st):
    """Reduce a statistics object to a (conflicts, decisions, rlimit count) triple.

    Args:
        st: Object exposing ``keys()`` and ``get_key_value(key)``, such as the
            value returned by a z3 solver's ``statistics()``.

    Returns:
        A tuple of three ints. A counter that ``st`` does not list is reported
        as 0 rather than raising, because ``get_key_value`` fails on a missing
        key and a run without conflicts or decisions may not list them.
    """
    available = set(st.keys())

    def counter(key):
        # Only query keys that exist; an absent counter means "never bumped".
        return int(st.get_key_value(key)) if key in available else 0

    return (
        counter("conflicts"),
        counter("decisions"),
        counter("rlimit count"),
    )
 | |
| 
 | |
| 
 | |
| # --------------------------
 | |
| # Protocol steps
 | |
| # --------------------------
 | |
| 
 | |
def run_prefix_step(S, K):
    """Check ``S`` under a ``max_conflicts`` limit of ``K``.

    Returns the check result paired with the solver's statistics object.
    """
    S.set("max_conflicts", K)
    verdict = S.check()
    stats = S.statistics()
    return verdict, stats
 | |
| 
 | |
| 
 | |
def collect_conflict_clauses_placeholder(S, limit=4):
    """Placeholder: ignores ``S`` and ``limit`` and returns a new empty list."""
    no_clauses = []
    return no_clauses
 | |
| 
 | |
| # Replay proof prefix on an existing PPS_solver (no solver recreation)
 | |
| # Solver continues from its current state.
 | |
def replay_prefix_on_pps(PPS_solver, clauses, param_state, budget):
    """Replay ``clauses`` on an existing probe solver and score the run.

    Applies ``param_state`` to ``PPS_solver``, caps it at ``budget`` conflicts,
    checks it under the negation of every clause as assumptions, and returns
    the ``stats_tuple`` of the solver's statistics afterwards. The solver is
    reused as-is; no new solver is created here.
    """
    print(f"[Replaying] on PPS with params={param_state} and budget={budget}")
    apply_param_state(PPS_solver, param_state)
    PPS_solver.set("max_conflicts", budget)

    # One assumption per clause: the clause's negation.
    negated = [Not(clause) for clause in clauses]
    PPS_solver.check(negated)

    return stats_tuple(PPS_solver.statistics())
 | |
| 
 | |
| 
 | |
| # For each PPS_i, replay the proof prefix of S
 | |
def replay_proof_prefixes(clauses, param_states, PPS_solvers, K, eps=200):
    """Replay ``clauses`` on every probe solver and pick the best-scoring state.

    ``param_states[i]`` is replayed on ``PPS_solvers[i]`` with a conflict
    budget of ``K + eps``. Index 0 is the baseline; a later candidate replaces
    the current best only when its score is strictly smaller.

    Returns:
        ``(best_param_state, best_score)``.
    """
    budget = K + eps
    baseline, challengers = param_states[0], param_states[1:]

    # Baseline run on PPS_0 seeds the comparison.
    winner = baseline
    winner_score = replay_prefix_on_pps(PPS_solvers[0], clauses, baseline, budget)

    # Remaining candidates run on PPS_1, PPS_2, ...
    for idx, candidate in enumerate(challengers, start=1):
        candidate_score = replay_prefix_on_pps(PPS_solvers[idx], clauses, candidate, budget)
        if candidate_score < winner_score:
            winner, winner_score = candidate, candidate_score

    return winner, winner_score
 | |
| 
 | |
| # return a variant of the given param state
 | |
def perturbate(param_state):
    """Return a new list of (name, value) pairs derived from ``param_state``.

    Per entry, the first matching rule applies:
      * numeric value and "restart_factor" in the name: scaled by 0.9 or 1.1
        (chosen at random) and rounded to 3 decimals;
      * int value and "phase_caching" in the name: halved (never below 1) or
        doubled, chosen at random;
      * name equal to "smt.relevancy": a random pick from 0, 1, 2 (which may
        equal the current value);
      * anything else is carried over unchanged.
    """

    def vary(name, val):
        if isinstance(val, (int, float)) and "restart_factor" in name:
            scale = random.choice([0.9, 1.1])
            return round(val * scale, 3)
        if isinstance(val, int) and "phase_caching" in name:
            return random.choice([max(1, val // 2), val * 2])
        if name == "smt.relevancy":
            return random.choice([0, 1, 2])
        return val

    return [(name, vary(name, val)) for name, val in param_state]
 | |
| 
 | |
| # --------------------------
 | |
| # Protocol iteration
 | |
| # --------------------------
 | |
| 
 | |
def protocol_iteration(filepath, manager, S, PPS_solvers, PPS_states, K, eps=200):
    """Run one round of the prefix-probing protocol.

    Steps:
      1. Apply the manager's best parameter state (or the base candidates when
         none is recorded) to ``S`` and check it with at most ``K`` conflicts.
      2. If ``S`` answers sat or unsat, mark the search complete and stop.
      3. Otherwise collect clauses from ``S``, replay them on every probe
         solver with budget ``K + eps``, and offer the winning state to the
         manager when it differs from the one ``S`` just used.
      4. Store the chosen state in ``PPS_states[0]`` and refill the remaining
         slots with fresh perturbations of it.

    Args:
        filepath: Benchmark path; only its basename is used, for logging.
        manager: Object providing ``best_param_state``, ``mark_complete()``
            and ``maybe_update_best()``.
        S: The proof-prefix solver.
        PPS_solvers: Probe solvers, index-aligned with ``PPS_states``.
        PPS_states: Parameter state per probe solver; mutated in place.
        K: Conflict budget for ``S``.
        eps: Extra conflict budget granted to the replays.

    Returns:
        ``PPS_states`` on every path. The early exit returns it unchanged so
        the caller, which assigns the result back, never receives ``None``.
    """
    # --- Proof Prefix Solver (S) ---
    P = manager.best_param_state or BASE_PARAM_CANDIDATES
    apply_param_state(S, P)

    # Run S with max conflicts K
    print(f"[S] Running proof prefix solver with params={P} and max_conflicts={K}")
    r, _ = run_prefix_step(S, K)

    # A sat/unsat answer is a verdict: report completion and leave the states as they are.
    if r == sat or r == unsat:
        print(f"[S] {os.path.basename(filepath)} → {r} (within max_conflicts={K}). Search complete.")
        manager.mark_complete()
        return PPS_states

    # Collect a subset of conflict clauses from the bounded run of S (C1...Cl)
    print("[S] collecting conflict clauses for replay")
    C_list = collect_conflict_clauses_placeholder(S)

    # For each PPS_i, replay the proof prefix of S
    print(f"[Replaying] Replaying proof prefix on PPS solvers with budget={K + eps}")
    best_state, best_score = replay_proof_prefixes(C_list, PPS_states, PPS_solvers, K, eps)

    if best_state != P:
        print("[Dispatch] updating best param state")
        manager.maybe_update_best(best_state, best_score)
        P = best_state

    # PPS_0 tracks P; every other slot gets a new perturbation of P.
    PPS_states[0] = P
    for i in range(1, len(PPS_states)):
        PPS_states[i] = perturbate(P)

    return PPS_states
 | |
| 
 | |
| 
 | |
| # --------------------------
 | |
| # Prefix probing thread
 | |
| # --------------------------
 | |
| 
 | |
def prefix_probe_thread(filepath, manager):
    """Drive the prefix-probing loop for one benchmark until the manager is done.

    Builds the proof-prefix solver ``S`` from ``filepath``, clones it into four
    probe solvers (each in its own context), then calls ``protocol_iteration``
    repeatedly with the same solvers until ``manager.search_complete`` is set.
    """
    # Proof prefix solver S
    S = solver_from_file(filepath)
    apply_param_state(S, BASE_PARAM_CANDIDATES)

    PPS_solvers, PPS_states = [], []

    # Four probe solvers PPS_0 ... PPS_3: PPS_0 keeps the base parameters,
    # the others start from a perturbation of them.
    for i in range(4):
        if i == 0:
            st = BASE_PARAM_CANDIDATES
        else:
            st = perturbate(BASE_PARAM_CANDIDATES)
        ctx = Context()
        PPS_solver = S.translate(ctx)  # clone of S living in the new context
        apply_param_state(PPS_solver, st)
        PPS_solvers.append(PPS_solver)
        PPS_states.append(st)
        print(f"[Init] PPS_{i} inherited prefix in new context with params={st}")

    # The same solvers are reused on every iteration.
    iteration = 0
    while True:
        if manager.search_complete:
            break
        print(f"\n[PrefixThread] Iteration {iteration}")
        PPS_states = protocol_iteration(
            filepath, manager, S, PPS_solvers, PPS_states, K=MAX_CONFLICTS, eps=200
        )
        iteration += 1
 | |
| 
 | |
| 
 | |
| # --------------------------
 | |
| # Main
 | |
| # --------------------------
 | |
| 
 | |
def run_main_solver(filepath):
    """Solve ``filepath`` with ``parallel.enable`` switched on and print the result.

    The solver is loaded from the file and given the base parameter candidates
    before a single ``check()`` call.
    """
    set_param("parallel.enable", True)

    solver = solver_from_file(filepath)
    apply_param_state(solver, BASE_PARAM_CANDIDATES)
    print(f"[Main] Started main solver on {os.path.basename(filepath)} with parallel.enable=True")

    r = solver.check()
    print(f"[Main] {os.path.basename(filepath)} → {r}")
 | |
| 
 | |
def _prefix_probe_worker(filepath, manager, conn):
    """Child-process entry point: run the prefix probe, then report its best result.

    The child works on its own copy of ``manager``, so whatever it records
    would be lost when the process exits. The (best_param_state, best_score)
    pair is therefore sent back to the parent through ``conn``.
    """
    try:
        prefix_probe_thread(filepath, manager)
        conn.send((manager.best_param_state, manager.best_score))
    finally:
        conn.close()


def main():
    """Run the prefix probe and the main solver side by side on the benchmark.

    For the selected benchmark in ``bench_dir`` two processes are started: one
    running the prefix-probing loop and one running the main solver. The probe
    process reports its best parameter state over a pipe, which is merged into
    the parent's ``BatchManager`` so the final report can be printed.
    """
    # Function-scope import: the module itself only imports Process.
    from multiprocessing import Pipe

    manager = BatchManager()
    for benchmark in os.listdir(bench_dir):
        if benchmark != "From_T2__hqr.t2_fixed__term_unfeasibility_1_0.smt2":
            continue

        filepath = os.path.join(bench_dir, benchmark)
        recv_conn, send_conn = Pipe(duplex=False)
        prefix_proc = Process(target=_prefix_probe_worker, args=(filepath, manager, send_conn))
        main_proc = Process(target=run_main_solver, args=(filepath,))
        prefix_proc.start()
        # Close the parent's write end before anything else is started, so
        # recv() below sees EOF if the probe process dies without reporting.
        send_conn.close()
        main_proc.start()

        # Receive before joining: blocks until the probe reports or exits.
        try:
            param_state, score = recv_conn.recv()
        except EOFError:
            param_state = None
        finally:
            recv_conn.close()
        if param_state is not None:
            manager.maybe_update_best(param_state, score)

        prefix_proc.join()
        main_proc.join()

    if manager.best_param_state:
        print(f"\n[GLOBAL] Best parameter state: {manager.best_param_state} with score {manager.best_score}")


if __name__ == "__main__":
    main()
 |