diff --git a/scripts/compare_seq_solvers.py b/scripts/compare_seq_solvers.py index db2b9abd8..1f7cc9d9f 100644 --- a/scripts/compare_seq_solvers.py +++ b/scripts/compare_seq_solvers.py @@ -11,10 +11,12 @@ and reports timeouts, divergences, and summary statistics. """ import argparse +import random import re import subprocess import sys import time +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path @@ -191,7 +193,11 @@ def main(): parser.add_argument("path", help="Directory containing SMT-LIB2 files") parser.add_argument("--z3", required=True, metavar="PATH", help="Path to z3 binary") parser.add_argument("--ext", default=".smt2", help="File extension to search for (default: .smt2)") - parser.add_argument("--jobs", type=int, default=4, help="Parallel workers (default: 4)") + parser.add_argument("--jobs", type=int, default=4, metavar="N", + help="Number of parallel worker threads (default: 4)") + parser.add_argument("--max-per-folder", type=int, default=None, metavar="N", + help="Max SMT2 files to benchmark per subfolder; " + "excess files are randomly sampled down to this limit") parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, metavar="SEC", help=f"Per-solver timeout in seconds (default: {DEFAULT_TIMEOUT})") parser.add_argument("--zipt", metavar="PATH", default=None, @@ -215,6 +221,20 @@ def main(): print(f"No {args.ext} files found under {root}", file=sys.stderr) sys.exit(1) + if args.max_per_folder is not None: + by_folder: dict[Path, list[Path]] = defaultdict(list) + for f in files: + by_folder[f.parent].append(f) + sampled: list[Path] = [] + for folder_files in by_folder.values(): + if len(folder_files) > args.max_per_folder: + sampled.extend(random.sample(folder_files, args.max_per_folder)) + else: + sampled.extend(folder_files) + files = sorted(sampled) + print(f"Sampling: {len(files)} files selected " + f"(max {args.max_per_folder} per subfolder, {len(by_folder)} subfolder(s))") + solvers_label = "nseq, seq" if args.parikh: solvers_label += ", nseq_p" if zipt_bin: solvers_label += ", zipt" @@ -222,9 +242,10 @@ def main(): f"Workers: {args.jobs}, timeout: {timeout_s}s\n") results = [] - with ThreadPoolExecutor(max_workers=args.jobs) as pool: - futures = {pool.submit(process_file, z3_bin, f, timeout_s, zipt_bin, args.parikh): f for f in files} - done = 0 + pool = ThreadPoolExecutor(max_workers=args.jobs) + futures = {pool.submit(process_file, z3_bin, f, timeout_s, zipt_bin, args.parikh): f for f in files} + done = 0 + try: for fut in as_completed(futures): done += 1 r = fut.result() @@ -237,6 +258,12 @@ def main(): zipt_part = f" zipt={r['zipt']:8s} ({r['t_zipt']:.1f}s)" print(f"[{done:4d}/{len(files)}] {r['category']:35s} nseq={r['nseq']:8s} ({r['t_nseq']:.1f}s) " f"seq={r['seq']:8s} ({r['t_seq']:.1f}s){np_part}{zipt_part} {r['file'].name}") + except KeyboardInterrupt: + print("\nInterrupted — cancelling pending tasks.", file=sys.stderr) + pool.shutdown(wait=False, cancel_futures=True) + sys.exit(130) + else: + pool.shutdown(wait=True) # ── Summary ────────────────────────────────────────────────────────────── categories = {