mirror of
https://github.com/Z3Prover/z3
synced 2025-12-19 18:53:43 +00:00
update pythonnn prototyping experiment, need to add a couple more things
This commit is contained in:
parent
193845c753
commit
86d7790c42
1 changed files with 99 additions and 75 deletions
|
|
@ -1,11 +1,12 @@
|
||||||
import os
|
import os
|
||||||
import z3
|
from z3 import *
|
||||||
|
import threading
|
||||||
|
import math
|
||||||
|
|
||||||
MAX_CONFLICTS = 1000
|
MAX_CONFLICTS = 1000
|
||||||
MAX_EXAMPLES = 5
|
MAX_EXAMPLES = 5
|
||||||
bench_dir = "C:/tmp/parameter-tuning"
|
bench_dir = "C:/tmp/parameter-tuning"
|
||||||
|
|
||||||
# Baseline parameter candidates (you can grow this)
|
|
||||||
BASE_PARAM_CANDIDATES = [
|
BASE_PARAM_CANDIDATES = [
|
||||||
("smt.arith.eager_eq_axioms", False),
|
("smt.arith.eager_eq_axioms", False),
|
||||||
("smt.restart_factor", 1.2),
|
("smt.restart_factor", 1.2),
|
||||||
|
|
@ -20,7 +21,7 @@ BASE_PARAM_CANDIDATES = [
|
||||||
class BatchManager:
|
class BatchManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.best_param_state = None
|
self.best_param_state = None
|
||||||
self.best_score = (10**9, 10**9, 10**9) # (conflicts, decisions, rlimit)
|
self.best_score = (math.inf, math.inf, math.inf)
|
||||||
self.search_complete = False
|
self.search_complete = False
|
||||||
|
|
||||||
def mark_complete(self):
|
def mark_complete(self):
|
||||||
|
|
@ -33,99 +34,81 @@ class BatchManager:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _better(a, b):
|
def _better(a, b):
|
||||||
return a < b # lexicographic
|
return a < b # lexicographic compare
|
||||||
|
|
||||||
|
|
||||||
# -------------------
|
# -------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
# -------------------
|
# -------------------
|
||||||
|
|
||||||
def get_stat_int(st, key):
|
|
||||||
try:
|
|
||||||
v = st.get_key_value(key)
|
|
||||||
if isinstance(v, (int, float)):
|
|
||||||
return int(v)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
if key == "decisions" and hasattr(st, "decisions"):
|
|
||||||
try:
|
|
||||||
return int(st.decisions())
|
|
||||||
except Exception:
|
|
||||||
return 0
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def solver_from_file(filepath):
|
def solver_from_file(filepath):
|
||||||
s = z3.Solver()
|
s = Solver()
|
||||||
s.set("smt.auto_config", False)
|
s.set("smt.auto_config", False)
|
||||||
s.from_file(filepath)
|
s.from_file(filepath)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
def apply_param_state(s, param_state):
|
def apply_param_state(s, param_state):
|
||||||
for name, value in param_state:
|
for name, value in param_state:
|
||||||
s.set(name, value)
|
s.set(name, value)
|
||||||
|
|
||||||
|
|
||||||
def stats_tuple(st):
|
def stats_tuple(st):
|
||||||
return (
|
return (
|
||||||
get_stat_int(st, "conflicts"),
|
int(st["conflicts"]),
|
||||||
get_stat_int(st, "decisions"),
|
int(st["decisions"]),
|
||||||
get_stat_int(st, "rlimit count"),
|
int(st["rlimit count"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# --------------------------
|
# --------------------------
|
||||||
# Protocol steps
|
# Protocol steps
|
||||||
# --------------------------
|
# --------------------------
|
||||||
|
|
||||||
def run_prefix_step(S, K):
|
def run_prefix_step(S, K):
|
||||||
S.set("smt.K", K)
|
S.set("smt.max_conflicts", K)
|
||||||
r = S.check()
|
r = S.check()
|
||||||
return r, S.statistics()
|
return r, S.statistics()
|
||||||
|
|
||||||
|
|
||||||
def collect_conflict_clauses_placeholder(S, limit=4):
|
def collect_conflict_clauses_placeholder(S, limit=4):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def replay_prefix_on_pps(filepath, clauses, param_state, budget):
|
# Replay proof prefix on an existing PPS_solver (no solver recreation)
|
||||||
if not clauses:
|
# Solver continues from its current state.
|
||||||
s = solver_from_file(filepath)
|
def replay_prefix_on_pps(PPS_solver, clauses, param_state, budget):
|
||||||
apply_param_state(s, param_state)
|
apply_param_state(PPS_solver, param_state)
|
||||||
s.set("smt.K", budget)
|
PPS_solver.set("smt.max_conflicts", budget)
|
||||||
_ = s.check()
|
|
||||||
st = s.statistics()
|
asms = []
|
||||||
|
for Cj in clauses:
|
||||||
|
PPS_solver.set("smt.max_conflicts", budget)
|
||||||
|
asms.append(Not(Cj))
|
||||||
|
|
||||||
|
PPS_solver.check(asms)
|
||||||
|
st = PPS_solver.statistics()
|
||||||
|
|
||||||
return stats_tuple(st)
|
return stats_tuple(st)
|
||||||
|
|
||||||
total_conflicts = 0
|
|
||||||
total_decisions = 0
|
|
||||||
total_rlimit = 0
|
|
||||||
|
|
||||||
PPS = solver_from_file(filepath)
|
# For each PPS_i, replay the proof prefix of S
|
||||||
apply_param_state(PPS, param_state)
|
def replay_proof_prefixes(clauses, param_states, PPS_solvers, K, eps=200):
|
||||||
|
|
||||||
for Cj in clauses:
|
|
||||||
PPS.set("smt.K", budget)
|
|
||||||
assumption = z3.Not(Cj)
|
|
||||||
PPS.check([assumption])
|
|
||||||
st = PPS.statistics()
|
|
||||||
c, d, rl = stats_tuple(st)
|
|
||||||
total_conflicts += c
|
|
||||||
total_decisions += d
|
|
||||||
total_rlimit += rl
|
|
||||||
|
|
||||||
return (total_conflicts, total_decisions, total_rlimit)
|
|
||||||
|
|
||||||
def choose_best_pps(filepath, clauses, base_param_state, candidate_param_states, K, eps = 200):
|
|
||||||
budget = K + eps
|
budget = K + eps
|
||||||
best_param_state = base_param_state
|
base_param_state, candidate_param_states = param_states[0], param_states[1:]
|
||||||
best_score = (10**9, 10**9, 10**9)
|
|
||||||
|
|
||||||
score0 = replay_prefix_on_pps(filepath, clauses, base_param_state, budget)
|
# PPS_0 (baseline)
|
||||||
if score0 < best_score:
|
score0 = replay_prefix_on_pps(PPS_solvers[0], clauses, base_param_state, budget)
|
||||||
best_param_state, best_score = base_param_state, score0
|
best_param_state, best_score = base_param_state, score0
|
||||||
|
|
||||||
for p_state in candidate_param_states:
|
# PPS_i, i > 0
|
||||||
sc = replay_prefix_on_pps(filepath, clauses, p_state, budget)
|
for i, p_state in enumerate(candidate_param_states, start=1):
|
||||||
if sc < best_score:
|
score = replay_prefix_on_pps(PPS_solvers[i], clauses, p_state, budget)
|
||||||
best_param_state, best_score = p_state, sc
|
if score < best_score:
|
||||||
|
best_param_state, best_score = p_state, score
|
||||||
|
|
||||||
return best_param_state, best_score
|
return best_param_state, best_score
|
||||||
|
|
||||||
|
|
||||||
def next_perturbations(around_state):
|
def next_perturbations(around_state):
|
||||||
outs = []
|
outs = []
|
||||||
for name, val in around_state:
|
for name, val in around_state:
|
||||||
|
|
@ -136,8 +119,7 @@ def next_perturbations(around_state):
|
||||||
k = max(1, int(val))
|
k = max(1, int(val))
|
||||||
outs.append([(name, k // 2)])
|
outs.append([(name, k // 2)])
|
||||||
outs.append([(name, k * 2)])
|
outs.append([(name, k * 2)])
|
||||||
else:
|
elif name == "smt.relevancy":
|
||||||
if name == "smt.relevancy":
|
|
||||||
outs.extend([[(name, 0)], [(name, 1)], [(name, 2)]])
|
outs.extend([[(name, 0)], [(name, 1)], [(name, 2)]])
|
||||||
return outs or [around_state]
|
return outs or [around_state]
|
||||||
|
|
||||||
|
|
@ -145,28 +127,27 @@ def next_perturbations(around_state):
|
||||||
# Protocol iteration
|
# Protocol iteration
|
||||||
# --------------------------
|
# --------------------------
|
||||||
|
|
||||||
def protocol_iteration(filepath, manager, K, eps=200):
|
def protocol_iteration(filepath, manager, S, PPS_solvers, PPS_states, K, eps=200):
|
||||||
S = solver_from_file(filepath) # Proof Prefix solver
|
# --- Proof Prefix Solver (S) ---
|
||||||
P = manager.best_param_state or BASE_PARAM_CANDIDATES # current optimal parameter setting
|
P = manager.best_param_state or BASE_PARAM_CANDIDATES
|
||||||
apply_param_state(S, P)
|
apply_param_state(S, P)
|
||||||
|
|
||||||
# Run S with max conflicts K
|
# Run S with max conflicts K
|
||||||
r, st = run_prefix_step(S, K)
|
r, st = run_prefix_step(S, K)
|
||||||
|
|
||||||
# If S returns SAT, or UNSAT we have a verdict. Tell the central dispatch that search is complete. Exit.
|
# If S returns SAT or UNSAT we have a verdict
|
||||||
if r == z3.sat or r == z3.unsat:
|
# Tell the central dispatch that search is complete and exit
|
||||||
|
if r == sat or r == unsat:
|
||||||
print(f"[S] {os.path.basename(filepath)} → {r} (within max_conflicts={K}). Search complete.")
|
print(f"[S] {os.path.basename(filepath)} → {r} (within max_conflicts={K}). Search complete.")
|
||||||
manager.mark_complete()
|
manager.mark_complete()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Collect a subset of conflict clauses from the bounded run of S. Call these clauses C1, ..., Cl.
|
# Collect subset of conflict clauses from the bounded run of S. Call these clauses C1...Cl
|
||||||
C_list = collect_conflict_clauses_placeholder(S)
|
C_list = collect_conflict_clauses_placeholder(S)
|
||||||
print(f"[S] collected {len(C_list)} conflict clauses for replay")
|
print(f"[S] collected {len(C_list)} conflict clauses for replay")
|
||||||
|
|
||||||
PPS0 = P
|
# For each PPS_i, replay the proof prefix of S
|
||||||
PPS_perturb = next_perturbations(P)
|
best_state, best_score = replay_proof_prefixes(C_list, PPS_states, PPS_solvers, K, eps)
|
||||||
|
|
||||||
best_state, best_score = choose_best_pps(filepath, C_list, PPS0, PPS_perturb, K, eps)
|
|
||||||
print(f"[Replay] best={best_state} score(conf, dec, rlim)={best_score}")
|
print(f"[Replay] best={best_state} score(conf, dec, rlim)={best_score}")
|
||||||
|
|
||||||
if best_state != P:
|
if best_state != P:
|
||||||
|
|
@ -174,9 +155,43 @@ def protocol_iteration(filepath, manager, K, eps=200):
|
||||||
manager.maybe_update_best(best_state, best_score)
|
manager.maybe_update_best(best_state, best_score)
|
||||||
P = best_state
|
P = best_state
|
||||||
|
|
||||||
PPS0 = P
|
# Update PPS_0 to use P (if it changed), and update all PPS_i > 0 with new perturbations of P
|
||||||
PPS_perturb = next_perturbations(P)
|
PPS_states[0] = P
|
||||||
print(f"[Dispatch] PPS_0 := {PPS0}, new perturbations: {PPS_perturb}")
|
new_perturb = next_perturbations(P)
|
||||||
|
for i in range(1, len(PPS_states)):
|
||||||
|
PPS_states[i] = new_perturb[i - 1]
|
||||||
|
print(f"[Dispatch] PPS_0 := {PPS_states[0]}, new perturbations: {new_perturb}")
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------
|
||||||
|
# Prefix probing thread
|
||||||
|
# --------------------------
|
||||||
|
|
||||||
|
def prefix_probe_thread(filepath, manager):
|
||||||
|
# Proof prefix solver S
|
||||||
|
S = solver_from_file(filepath)
|
||||||
|
apply_param_state(S, BASE_PARAM_CANDIDATES)
|
||||||
|
|
||||||
|
PPS_solvers = []
|
||||||
|
PPS_states = []
|
||||||
|
PPS_states.append(list(BASE_PARAM_CANDIDATES))
|
||||||
|
perturbations = next_perturbations(BASE_PARAM_CANDIDATES)
|
||||||
|
|
||||||
|
for i in range(4):
|
||||||
|
st = perturbations[i] if i < len(perturbations) else BASE_PARAM_CANDIDATES
|
||||||
|
PPS_solver = Solver()
|
||||||
|
apply_param_state(PPS_solver, st)
|
||||||
|
PPS_solvers.append(PPS_solver)
|
||||||
|
PPS_states.append(st)
|
||||||
|
print(f"[Init] PPS_{i} initialized with params={st}")
|
||||||
|
|
||||||
|
# Reuse the same solvers each iteration
|
||||||
|
for iteration in range(3): # run a few iterations
|
||||||
|
if manager.search_complete:
|
||||||
|
break
|
||||||
|
print(f"\n[PrefixThread] Iteration {iteration}")
|
||||||
|
protocol_iteration(filepath, manager, S, PPS_solvers, PPS_states, K=MAX_CONFLICTS, eps=200)
|
||||||
|
|
||||||
|
|
||||||
# --------------------------
|
# --------------------------
|
||||||
# Main
|
# Main
|
||||||
|
|
@ -190,10 +205,19 @@ def main():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
filepath = os.path.join(bench_dir, benchmark)
|
filepath = os.path.join(bench_dir, benchmark)
|
||||||
protocol_iteration(filepath, manager, K=MAX_CONFLICTS, eps=200)
|
prefix_thread = threading.Thread(target=prefix_probe_thread, args=(filepath, manager))
|
||||||
|
prefix_thread.start()
|
||||||
|
|
||||||
|
# main thread can perform monitoring or waiting
|
||||||
|
while prefix_thread.is_alive():
|
||||||
|
if manager.search_complete:
|
||||||
|
break
|
||||||
|
|
||||||
|
prefix_thread.join()
|
||||||
|
|
||||||
if manager.best_param_state:
|
if manager.best_param_state:
|
||||||
print(f"\n[GLOBAL] Best parameter state: {manager.best_param_state} with score {manager.best_score}")
|
print(f"\n[GLOBAL] Best parameter state: {manager.best_param_state} with score {manager.best_score}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue