3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-12-19 18:53:43 +00:00

update pythonnn prototyping experiment, need to add a couple more things

This commit is contained in:
Ilana Shapiro 2025-10-19 00:01:49 -07:00
parent 193845c753
commit 86d7790c42

View file

@ -1,11 +1,12 @@
import os import os
import z3 from z3 import *
import threading
import math
MAX_CONFLICTS = 1000 MAX_CONFLICTS = 1000
MAX_EXAMPLES = 5 MAX_EXAMPLES = 5
bench_dir = "C:/tmp/parameter-tuning" bench_dir = "C:/tmp/parameter-tuning"
# Baseline parameter candidates (you can grow this)
BASE_PARAM_CANDIDATES = [ BASE_PARAM_CANDIDATES = [
("smt.arith.eager_eq_axioms", False), ("smt.arith.eager_eq_axioms", False),
("smt.restart_factor", 1.2), ("smt.restart_factor", 1.2),
@ -20,7 +21,7 @@ BASE_PARAM_CANDIDATES = [
class BatchManager: class BatchManager:
def __init__(self): def __init__(self):
self.best_param_state = None self.best_param_state = None
self.best_score = (10**9, 10**9, 10**9) # (conflicts, decisions, rlimit) self.best_score = (math.inf, math.inf, math.inf)
self.search_complete = False self.search_complete = False
def mark_complete(self): def mark_complete(self):
@ -33,99 +34,81 @@ class BatchManager:
@staticmethod @staticmethod
def _better(a, b): def _better(a, b):
return a < b # lexicographic return a < b # lexicographic compare
# ------------------- # -------------------
# Helpers # Helpers
# ------------------- # -------------------
def get_stat_int(st, key):
try:
v = st.get_key_value(key)
if isinstance(v, (int, float)):
return int(v)
except Exception:
pass
if key == "decisions" and hasattr(st, "decisions"):
try:
return int(st.decisions())
except Exception:
return 0
return 0
def solver_from_file(filepath): def solver_from_file(filepath):
s = z3.Solver() s = Solver()
s.set("smt.auto_config", False) s.set("smt.auto_config", False)
s.from_file(filepath) s.from_file(filepath)
return s return s
def apply_param_state(s, param_state): def apply_param_state(s, param_state):
for name, value in param_state: for name, value in param_state:
s.set(name, value) s.set(name, value)
def stats_tuple(st): def stats_tuple(st):
return ( return (
get_stat_int(st, "conflicts"), int(st["conflicts"]),
get_stat_int(st, "decisions"), int(st["decisions"]),
get_stat_int(st, "rlimit count"), int(st["rlimit count"]),
) )
# -------------------------- # --------------------------
# Protocol steps # Protocol steps
# -------------------------- # --------------------------
def run_prefix_step(S, K): def run_prefix_step(S, K):
S.set("smt.K", K) S.set("smt.max_conflicts", K)
r = S.check() r = S.check()
return r, S.statistics() return r, S.statistics()
def collect_conflict_clauses_placeholder(S, limit=4): def collect_conflict_clauses_placeholder(S, limit=4):
return [] return []
def replay_prefix_on_pps(filepath, clauses, param_state, budget): # Replay proof prefix on an existing PPS_solver (no solver recreation)
if not clauses: # Solver continues from its current state.
s = solver_from_file(filepath) def replay_prefix_on_pps(PPS_solver, clauses, param_state, budget):
apply_param_state(s, param_state) apply_param_state(PPS_solver, param_state)
s.set("smt.K", budget) PPS_solver.set("smt.max_conflicts", budget)
_ = s.check()
st = s.statistics() asms = []
for Cj in clauses:
PPS_solver.set("smt.max_conflicts", budget)
asms.append(Not(Cj))
PPS_solver.check(asms)
st = PPS_solver.statistics()
return stats_tuple(st) return stats_tuple(st)
total_conflicts = 0
total_decisions = 0
total_rlimit = 0
PPS = solver_from_file(filepath) # For each PPS_i, replay the proof prefix of S
apply_param_state(PPS, param_state) def replay_proof_prefixes(clauses, param_states, PPS_solvers, K, eps=200):
for Cj in clauses:
PPS.set("smt.K", budget)
assumption = z3.Not(Cj)
PPS.check([assumption])
st = PPS.statistics()
c, d, rl = stats_tuple(st)
total_conflicts += c
total_decisions += d
total_rlimit += rl
return (total_conflicts, total_decisions, total_rlimit)
def choose_best_pps(filepath, clauses, base_param_state, candidate_param_states, K, eps = 200):
budget = K + eps budget = K + eps
best_param_state = base_param_state base_param_state, candidate_param_states = param_states[0], param_states[1:]
best_score = (10**9, 10**9, 10**9)
score0 = replay_prefix_on_pps(filepath, clauses, base_param_state, budget) # PPS_0 (baseline)
if score0 < best_score: score0 = replay_prefix_on_pps(PPS_solvers[0], clauses, base_param_state, budget)
best_param_state, best_score = base_param_state, score0 best_param_state, best_score = base_param_state, score0
for p_state in candidate_param_states: # PPS_i, i > 0
sc = replay_prefix_on_pps(filepath, clauses, p_state, budget) for i, p_state in enumerate(candidate_param_states, start=1):
if sc < best_score: score = replay_prefix_on_pps(PPS_solvers[i], clauses, p_state, budget)
best_param_state, best_score = p_state, sc if score < best_score:
best_param_state, best_score = p_state, score
return best_param_state, best_score return best_param_state, best_score
def next_perturbations(around_state): def next_perturbations(around_state):
outs = [] outs = []
for name, val in around_state: for name, val in around_state:
@ -136,8 +119,7 @@ def next_perturbations(around_state):
k = max(1, int(val)) k = max(1, int(val))
outs.append([(name, k // 2)]) outs.append([(name, k // 2)])
outs.append([(name, k * 2)]) outs.append([(name, k * 2)])
else: elif name == "smt.relevancy":
if name == "smt.relevancy":
outs.extend([[(name, 0)], [(name, 1)], [(name, 2)]]) outs.extend([[(name, 0)], [(name, 1)], [(name, 2)]])
return outs or [around_state] return outs or [around_state]
@ -145,28 +127,27 @@ def next_perturbations(around_state):
# Protocol iteration # Protocol iteration
# -------------------------- # --------------------------
def protocol_iteration(filepath, manager, K, eps=200): def protocol_iteration(filepath, manager, S, PPS_solvers, PPS_states, K, eps=200):
S = solver_from_file(filepath) # Proof Prefix solver # --- Proof Prefix Solver (S) ---
P = manager.best_param_state or BASE_PARAM_CANDIDATES # current optimal parameter setting P = manager.best_param_state or BASE_PARAM_CANDIDATES
apply_param_state(S, P) apply_param_state(S, P)
# Run S with max conflicts K # Run S with max conflicts K
r, st = run_prefix_step(S, K) r, st = run_prefix_step(S, K)
# If S returns SAT, or UNSAT we have a verdict. Tell the central dispatch that search is complete. Exit. # If S returns SAT or UNSAT we have a verdict
if r == z3.sat or r == z3.unsat: # Tell the central dispatch that search is complete and exit
if r == sat or r == unsat:
print(f"[S] {os.path.basename(filepath)}{r} (within max_conflicts={K}). Search complete.") print(f"[S] {os.path.basename(filepath)}{r} (within max_conflicts={K}). Search complete.")
manager.mark_complete() manager.mark_complete()
return return
# Collect a subset of conflict clauses from the bounded run of S. Call these clauses C1, ..., Cl. # Collect subset of conflict clauses from the bounded run of S. Call these clauses C1...Cl
C_list = collect_conflict_clauses_placeholder(S) C_list = collect_conflict_clauses_placeholder(S)
print(f"[S] collected {len(C_list)} conflict clauses for replay") print(f"[S] collected {len(C_list)} conflict clauses for replay")
PPS0 = P # For each PPS_i, replay the proof prefix of S
PPS_perturb = next_perturbations(P) best_state, best_score = replay_proof_prefixes(C_list, PPS_states, PPS_solvers, K, eps)
best_state, best_score = choose_best_pps(filepath, C_list, PPS0, PPS_perturb, K, eps)
print(f"[Replay] best={best_state} score(conf, dec, rlim)={best_score}") print(f"[Replay] best={best_state} score(conf, dec, rlim)={best_score}")
if best_state != P: if best_state != P:
@ -174,9 +155,43 @@ def protocol_iteration(filepath, manager, K, eps=200):
manager.maybe_update_best(best_state, best_score) manager.maybe_update_best(best_state, best_score)
P = best_state P = best_state
PPS0 = P # Update PPS_0 to use P (if it changed), and update all PPS_i > 0 with new perturbations of P
PPS_perturb = next_perturbations(P) PPS_states[0] = P
print(f"[Dispatch] PPS_0 := {PPS0}, new perturbations: {PPS_perturb}") new_perturb = next_perturbations(P)
for i in range(1, len(PPS_states)):
PPS_states[i] = new_perturb[i - 1]
print(f"[Dispatch] PPS_0 := {PPS_states[0]}, new perturbations: {new_perturb}")
# --------------------------
# Prefix probing thread
# --------------------------
def prefix_probe_thread(filepath, manager):
# Proof prefix solver S
S = solver_from_file(filepath)
apply_param_state(S, BASE_PARAM_CANDIDATES)
PPS_solvers = []
PPS_states = []
PPS_states.append(list(BASE_PARAM_CANDIDATES))
perturbations = next_perturbations(BASE_PARAM_CANDIDATES)
for i in range(4):
st = perturbations[i] if i < len(perturbations) else BASE_PARAM_CANDIDATES
PPS_solver = Solver()
apply_param_state(PPS_solver, st)
PPS_solvers.append(PPS_solver)
PPS_states.append(st)
print(f"[Init] PPS_{i} initialized with params={st}")
# Reuse the same solvers each iteration
for iteration in range(3): # run a few iterations
if manager.search_complete:
break
print(f"\n[PrefixThread] Iteration {iteration}")
protocol_iteration(filepath, manager, S, PPS_solvers, PPS_states, K=MAX_CONFLICTS, eps=200)
# -------------------------- # --------------------------
# Main # Main
@ -190,10 +205,19 @@ def main():
continue continue
filepath = os.path.join(bench_dir, benchmark) filepath = os.path.join(bench_dir, benchmark)
protocol_iteration(filepath, manager, K=MAX_CONFLICTS, eps=200) prefix_thread = threading.Thread(target=prefix_probe_thread, args=(filepath, manager))
prefix_thread.start()
# main thread can perform monitoring or waiting
while prefix_thread.is_alive():
if manager.search_complete:
break
prefix_thread.join()
if manager.best_param_state: if manager.best_param_state:
print(f"\n[GLOBAL] Best parameter state: {manager.best_param_state} with score {manager.best_score}") print(f"\n[GLOBAL] Best parameter state: {manager.best_param_state} with score {manager.best_score}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()