diff --git a/scripts/compare_seq_solvers.py b/scripts/compare_seq_solvers.py index 9519791b5..3c9d91969 100644 --- a/scripts/compare_seq_solvers.py +++ b/scripts/compare_seq_solvers.py @@ -1,16 +1,13 @@ #!/usr/bin/env python3 """ -Compare z3 string solvers: smt.string_solver=nseq (new) vs smt.string_solver=seq (old). +Compare z3 string solvers: smt.string_solver=nseq (new) vs smt.string_solver=seq (old), +and optionally against an external ZIPT solver. Usage: - python compare_solvers.py --z3 /path/to/z3 [--ext .smt2] + python compare_solvers.py --z3 /path/to/z3 [--zipt /path/to/zipt] [--ext .smt2] -Finds all .smt2 files under the given path, runs both solvers with a 5s timeout, -and reports: - - Files where neither solver terminates (both timeout) - - Files where only one solver terminates (and which one) - - Files where both terminate - - Files where results diverge (sat vs unsat) +Finds all .smt2 files under the given path, runs the solvers with a configurable timeout, +and reports timeouts, divergences, and summary statistics. """ import argparse @@ -71,31 +68,48 @@ def determine_status(res_nseq: str, res_seq: str, smtlib_status: str) -> str: return "unknown" +def _parse_result(output: str) -> str: + """Extract the first sat/unsat/unknown line from solver output.""" + for line in output.splitlines(): + tok = line.strip().lower() + if tok in ("sat", "unsat"): + return tok + if tok == "unknown": + return "timeout" + return "timeout" + + def run_z3(z3_bin: str, smt_file: Path, solver_arg: str, timeout_s: int = DEFAULT_TIMEOUT) -> tuple[str, float]: """Run z3 on a file with the given solver argument. - Returns (result, elapsed) where result is 'sat', 'unsat', 'unknown', or 'timeout'/'error'. + Returns (result, elapsed) where result is 'sat', 'unsat', or 'timeout'/'error'. """ timeout_ms = timeout_s * 1000 cmd = [z3_bin, f"-t:{timeout_ms}", solver_arg] + COMMON_ARGS + [str(smt_file)] start = time.monotonic() try: - proc = subprocess.run( - cmd, - capture_output=True, - text=True, - timeout=timeout_s + 5, # subprocess grace period beyond Z3's own timeout - ) + proc = subprocess.run(cmd, capture_output=True, text=True, + timeout=timeout_s + 5) + elapsed = time.monotonic() - start + return _parse_result(proc.stdout.strip()), elapsed + except subprocess.TimeoutExpired: elapsed = time.monotonic() - start - output = proc.stdout.strip() - # Extract first meaningful line (sat/unsat/unknown) - for line in output.splitlines(): - line = line.strip() - if line in ("sat", "unsat"): - return line, elapsed - if line == "unknown": - # Z3 returns "unknown" when it hits -t: limit — treat as timeout - return "timeout", elapsed return "timeout", elapsed + except Exception as e: + elapsed = time.monotonic() - start + return f"error:{e}", elapsed + + +def run_zipt(zipt_bin: str, smt_file: Path, timeout_s: int = DEFAULT_TIMEOUT) -> tuple[str, float]: + """Run ZIPT on a file. Returns (result, elapsed).""" + timeout_ms = timeout_s * 1000 + cmd = [zipt_bin, f"-t:{timeout_ms}", str(smt_file)] + start = time.monotonic() + try: + proc = subprocess.run(cmd, capture_output=True, text=True, + timeout=timeout_s + 5) + elapsed = time.monotonic() - start + out = proc.stdout.strip() + return _parse_result(out), elapsed except subprocess.TimeoutExpired: elapsed = time.monotonic() - start return "timeout", elapsed @@ -123,13 +137,14 @@ def classify(res_nseq: str, res_seq: str) -> str: return "diverge" -def process_file(z3_bin: str, smt_file: Path, timeout_s: int = DEFAULT_TIMEOUT) -> dict: +def process_file(z3_bin: str, smt_file: Path, timeout_s: int = DEFAULT_TIMEOUT, + zipt_bin: str | None = None) -> dict: res_nseq, t_nseq = run_z3(z3_bin, smt_file, SOLVERS["nseq"], timeout_s) res_seq, t_seq = run_z3(z3_bin, smt_file, SOLVERS["seq"], timeout_s) cat = classify(res_nseq, res_seq) smtlib_status = read_smtlib_status(smt_file) status = determine_status(res_nseq, res_seq, smtlib_status) - return { + result = { "file": smt_file, "nseq": res_nseq, "seq": res_seq, @@ -138,7 +153,14 @@ def process_file(z3_bin: str, smt_file: Path, timeout_s: int = DEFAULT_TIMEOUT) "category": cat, "smtlib_status": smtlib_status, "status": status, + "zipt": None, + "t_zipt": None, } + if zipt_bin: + res_zipt, t_zipt = run_zipt(zipt_bin, smt_file, timeout_s) + result["zipt"] = res_zipt + result["t_zipt"] = t_zipt + return result def main(): @@ -149,10 +171,13 @@ def main(): parser.add_argument("--jobs", type=int, default=4, help="Parallel workers (default: 4)") parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, metavar="SEC", help=f"Per-solver timeout in seconds (default: {DEFAULT_TIMEOUT})") + parser.add_argument("--zipt", metavar="PATH", default=None, + help="Path to ZIPT binary (optional; if omitted, ZIPT is not benchmarked)") parser.add_argument("--csv", metavar="FILE", help="Also write results to a CSV file") args = parser.parse_args() z3_bin = args.z3 + zipt_bin = args.zipt timeout_s = args.timeout root = Path(args.path) @@ -165,18 +190,23 @@ def main(): print(f"No {args.ext} files found under {root}", file=sys.stderr) sys.exit(1) - print(f"Found {len(files)} files. Running with {args.jobs} parallel workers, timeout={timeout_s}s …\n") + solvers_label = "nseq, seq" + (", zipt" if zipt_bin else "") + print(f"Found {len(files)} files. Solvers: {solvers_label}. " + f"Workers: {args.jobs}, timeout: {timeout_s}s\n") results = [] with ThreadPoolExecutor(max_workers=args.jobs) as pool: - futures = {pool.submit(process_file, z3_bin, f, timeout_s): f for f in files} + futures = {pool.submit(process_file, z3_bin, f, timeout_s, zipt_bin): f for f in files} done = 0 for fut in as_completed(futures): done += 1 r = fut.result() results.append(r) + zipt_part = "" + if zipt_bin: + zipt_part = f" zipt={r['zipt']:8s} ({r['t_zipt']:.1f}s)" print(f"[{done:4d}/{len(files)}] {r['category']:35s} nseq={r['nseq']:8s} ({r['t_nseq']:.1f}s) " - f"seq={r['seq']:8s} ({r['t_seq']:.1f}s) {r['file'].name}") + f"seq={r['seq']:8s} ({r['t_seq']:.1f}s){zipt_part} {r['file'].name}") # ── Summary ────────────────────────────────────────────────────────────── categories = { @@ -213,10 +243,27 @@ def main(): _print_file_list("NSEQ TIMES OUT", nseq_timeouts) if seq_timeouts: _print_file_list("SEQ TIMES OUT", seq_timeouts) + if zipt_bin: + zipt_timeouts = [r for r in results if r["zipt"] == "timeout"] + if zipt_timeouts: + _print_file_list("ZIPT TIMES OUT", zipt_timeouts) if both_to: - _print_file_list("BOTH TIME OUT", both_to) + _print_file_list("BOTH Z3 SOLVERS TIME OUT", both_to) + if zipt_bin: + all_to = [r for r in results + if r["nseq"] == "timeout" and r["seq"] == "timeout" and r["zipt"] == "timeout"] + if all_to: + _print_file_list("ALL THREE TIME OUT", all_to) if diverged: - _print_file_list("DIVERGE (sat vs unsat)", diverged) + _print_file_list("DIVERGE nseq vs seq (sat vs unsat)", diverged) + if zipt_bin: + definite = {"sat", "unsat"} + zipt_diverged = [r for r in results + if r["zipt"] in definite + and ((r["nseq"] in definite and r["nseq"] != r["zipt"]) + or (r["seq"] in definite and r["seq"] != r["zipt"]))] + if zipt_diverged: + _print_file_list("DIVERGE involving ZIPT", zipt_diverged) print() @@ -236,8 +283,11 @@ def main(): if args.csv: import csv csv_path = Path(args.csv) + fieldnames = ["file", "nseq", "seq", "t_nseq", "t_seq", "category", "smtlib_status", "status"] + if zipt_bin: + fieldnames[4:4] = ["zipt", "t_zipt"] with csv_path.open("w", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=["file", "nseq", "seq", "t_nseq", "t_seq", "category", "smtlib_status", "status"]) + writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") writer.writeheader() for r in sorted(results, key=lambda x: x["file"]): writer.writerow({**r, "file": str(r["file"])})