#!/usr/bin/env python3
"""
Compare z3 string solvers: smt.string_solver=nseq (new) vs smt.string_solver=seq (old).

Usage:
    python compare_solvers.py <path-to-smtlib-files> --z3 /path/to/z3 [--ext .smt2]
                              [--jobs N] [--timeout SEC] [--csv FILE]

Finds all files with the given extension (default: .smt2) under the given path,
runs both solvers on each file with a per-solver timeout (default: 5s, see
--timeout), and reports:
  - Files where neither solver terminates (both timeout)
  - Files where only one solver terminates (and which one)
  - Files where both terminate
  - Files where results diverge (sat vs unsat)
plus sat/unsat/unknown problem-status statistics and an optional CSV dump.
"""


import argparse
import csv
import re
import shutil
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

# Per-solver time budget in seconds used when --timeout is not given.
DEFAULT_TIMEOUT: int = 5

# Extra z3 parameters appended to every invocation, whichever solver is chosen.
COMMON_ARGS: list[str] = ["model_validate=true"]

# Short solver label -> z3 command-line parameter that selects that string solver.
SOLVERS: dict[str, str] = dict(
    nseq="smt.string_solver=nseq",
    seq="smt.string_solver=seq",
)

# Matches an SMT-LIB "(set-info :status <value>)" directive; group 1 is the value.
_STATUS_RE: re.Pattern[str] = re.compile(r'\(\s*set-info\s+:status\s+(sat|unsat|unknown)\s*\)')


def read_smtlib_status(smt_file: Path) -> str:
    """Return the status declared by the benchmark's ``(set-info :status ...)``.

    The file is decoded as UTF-8 with undecodable bytes replaced, and the first
    matching directive wins.  The result is ``'sat'``, ``'unsat'`` or
    ``'unknown'``; ``'unknown'`` is also returned when no directive is present
    or the file cannot be read (any ``OSError``).
    """
    try:
        contents = smt_file.read_text(encoding="utf-8", errors="replace")
    except OSError:
        return "unknown"
    found = _STATUS_RE.search(contents)
    return found.group(1) if found else "unknown"


def determine_status(res_nseq: str, res_seq: str, smtlib_status: str) -> str:
    """Guess the ground-truth status of a problem from both runs and its annotation.

    Rules, in order:
      1. both solvers answered the same sat/unsat  -> that answer;
      2. exactly one solver answered sat/unsat     -> that answer;
      3. otherwise (no answer, or sat vs unsat)    -> the SMT-LIB annotation if
         it is sat/unsat, else ``'unknown'``.
    """
    decisive = ("sat", "unsat")
    nseq_decided = res_nseq in decisive
    seq_decided = res_seq in decisive

    if nseq_decided and seq_decided:
        if res_nseq == res_seq:
            return res_nseq
        # sat vs unsat: neither run can be trusted, defer to the annotation below.
    elif nseq_decided:
        return res_nseq
    elif seq_decided:
        return res_seq

    return smtlib_status if smtlib_status in decisive else "unknown"


def run_z3(z3_bin: str, smt_file: Path, solver_arg: str, timeout_s: int = DEFAULT_TIMEOUT) -> tuple[str, float]:
    """Run z3 on a file with the given solver argument.

    z3 gets a soft limit of ``timeout_s`` seconds via ``-t:``.  The subprocess
    itself is killed if it is still running 5 seconds after that.

    Args:
        z3_bin: Path to (or PATH name of) the z3 executable.
        smt_file: SMT-LIB2 benchmark to solve.
        solver_arg: z3 parameter selecting the string solver (a SOLVERS value).
        timeout_s: Per-run time limit in seconds.

    Returns:
        ``(result, elapsed_seconds)`` where ``result`` is one of:
          - ``'sat'`` / ``'unsat'``: the first such answer line on stdout;
          - ``'timeout'``: z3 answered ``unknown`` (treated as hitting the
            ``-t:`` limit), the subprocess was killed after the grace period,
            or z3 exited normally without printing any answer or error;
          - ``'error:<detail>'``: no answer was printed and z3 reported an
            ``(error ...)``, was killed by a signal, or exited non-zero; or z3
            could not be launched at all.
    """
    timeout_ms = timeout_s * 1000
    cmd = [z3_bin, f"-t:{timeout_ms}", solver_arg, *COMMON_ARGS, str(smt_file)]
    start = time.monotonic()
    try:
        proc = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            # Decode deterministically and never fail the run over stray bytes.
            encoding="utf-8",
            errors="replace",
            timeout=timeout_s + 5,  # subprocess grace period beyond Z3's own timeout
        )
    except subprocess.TimeoutExpired:
        return "timeout", time.monotonic() - start
    except Exception as e:  # e.g. binary missing or not executable
        return f"error:{e}", time.monotonic() - start
    elapsed = time.monotonic() - start

    lines = [line.strip() for line in proc.stdout.splitlines()]
    # The first answer line wins (sat/unsat/unknown).
    for line in lines:
        if line in ("sat", "unsat"):
            return line, elapsed
        if line == "unknown":
            # Z3 returns "unknown" when it hits the -t: limit — treat as timeout.
            return "timeout", elapsed

    # No answer at all. Do not report crashes and parse/solver errors as
    # timeouts: that would inflate the timeout lists and hide real failures.
    error_line = next((line for line in lines if line.startswith("(error")), None)
    if error_line is not None:
        return f"error:{error_line}", elapsed
    if proc.returncode < 0:
        # POSIX: a negative return code means "terminated by this signal".
        return f"error:signal {-proc.returncode}", elapsed
    if proc.returncode != 0:
        return f"error:exit {proc.returncode}", elapsed
    return "timeout", elapsed


def classify(res_nseq: str, res_seq: str) -> str:
    """Classify a pair of results into a category.

    Categories:
      - ``both_timeout``, ``only_seq_terminates``, ``only_nseq_terminates``:
        decided by which results are ``'timeout'``;
      - ``both_terminate_unknown_involved``: both finished and at least one
        result is ``'unknown'``;
      - ``error_involved``: both finished but at least one result is not a
        sat/unsat answer (for example an ``'error:...'`` string);
      - ``both_agree`` / ``diverge``: both answered sat/unsat, same / different.
    """
    timed_nseq = res_nseq == "timeout"
    timed_seq = res_seq == "timeout"

    if timed_nseq and timed_seq:
        return "both_timeout"
    if timed_nseq:
        return "only_seq_terminates"
    if timed_seq:
        return "only_nseq_terminates"

    # Both terminated — check agreement.
    if "unknown" in (res_nseq, res_seq):
        return "both_terminate_unknown_involved"
    definite = {"sat", "unsat"}
    if res_nseq not in definite or res_seq not in definite:
        # An error is not an answer. Two identical error messages are not
        # "agreement", and error vs. sat is not a sat/unsat divergence.
        return "error_involved"
    if res_nseq == res_seq:
        return "both_agree"
    return "diverge"


def process_file(z3_bin: str, smt_file: Path, timeout_s: int = DEFAULT_TIMEOUT) -> dict:
    """Benchmark a single SMT-LIB file with both string solvers.

    The ``nseq`` solver is run first and the ``seq`` solver second, one after
    the other. Their answers are then combined with the file's
    ``(set-info :status ...)`` annotation.

    Returns a dict with the keys:
      ``file`` (the Path), ``nseq`` / ``seq`` (run_z3 results),
      ``t_nseq`` / ``t_seq`` (elapsed seconds), ``category`` (from classify),
      ``smtlib_status`` (the annotation) and ``status`` (from determine_status).
    """
    answer: dict[str, str] = {}
    seconds: dict[str, float] = {}
    for label in ("nseq", "seq"):
        answer[label], seconds[label] = run_z3(z3_bin, smt_file, SOLVERS[label], timeout_s)

    category = classify(answer["nseq"], answer["seq"])
    annotated = read_smtlib_status(smt_file)
    verdict = determine_status(answer["nseq"], answer["seq"], annotated)

    return dict(
        file=smt_file,
        nseq=answer["nseq"],
        seq=answer["seq"],
        t_nseq=seconds["nseq"],
        t_seq=seconds["seq"],
        category=category,
        smtlib_status=annotated,
        status=verdict,
    )


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and return validated arguments (exits on bad input)."""
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("path", help="SMT-LIB2 file, or directory searched recursively for SMT-LIB2 files")
    parser.add_argument("--z3", required=True, metavar="PATH", help="Path to z3 binary")
    parser.add_argument("--ext", default=".smt2", help="File extension to search for (default: .smt2)")
    parser.add_argument("--jobs", type=int, default=4, help="Parallel workers (default: 4)")
    parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, metavar="SEC",
                        help=f"Per-solver timeout in seconds (default: {DEFAULT_TIMEOUT})")
    parser.add_argument("--csv", metavar="FILE", help="Also write results to a CSV file")
    args = parser.parse_args()
    # ThreadPoolExecutor rejects max_workers <= 0, and a non-positive
    # timeout would make the -t: limit meaningless.
    if args.jobs < 1:
        parser.error("--jobs must be at least 1")
    if args.timeout < 1:
        parser.error("--timeout must be at least 1 second")
    return args


def _collect_files(root: Path, ext: str) -> list[Path]:
    """Return the benchmarks to run: *root* itself if it is a file, else all ``*ext`` files under it, sorted."""
    if root.is_file():
        return [root]
    return sorted(root.rglob(f"*{ext}"))


def _run_all(z3_bin: str, files: list[Path], jobs: int, timeout_s: int) -> list[dict]:
    """Run process_file on every file in a thread pool, printing a progress line as each finishes."""
    results = []
    total = len(files)
    with ThreadPoolExecutor(max_workers=jobs) as pool:
        futures = [pool.submit(process_file, z3_bin, f, timeout_s) for f in files]
        for done, fut in enumerate(as_completed(futures), start=1):
            r = fut.result()
            results.append(r)
            print(f"[{done:4d}/{total}] {r['category']:35s} nseq={r['nseq']:8s} ({r['t_nseq']:.1f}s) "
                  f"seq={r['seq']:8s} ({r['t_seq']:.1f}s) {r['file'].name}")
    return results


def _print_file_list(label: str, items: list[dict]) -> None:
    """Print a headed list of the file paths in *items*, sorted by path."""
    print(f"\n{'─'*70}")
    print(f" {label} ({len(items)} files)")
    print(f"{'─'*70}")
    for r in sorted(items, key=lambda x: x["file"]):
        print(f" {r['file']}")


def _percent(count: int, total: int) -> float:
    """Return *count* as a percentage of *total*, or 0.0 when *total* is 0."""
    return 100 * count / total if total else 0.0


def _print_summary(results: list[dict]) -> None:
    """Print category totals, timeout/error/divergence file lists and problem-status statistics."""
    # ── Summary ──────────────────────────────────────────────────────────────
    categories = {
        "both_timeout": [],
        "only_seq_terminates": [],
        "only_nseq_terminates": [],
        "both_agree": [],
        "both_terminate_unknown_involved": [],
        "error_involved": [],
        "diverge": [],
    }
    for r in results:
        # setdefault keeps any category not pre-seeded above.
        categories.setdefault(r["category"], []).append(r)

    print("\n" + "="*70)
    print("TOTALS")
    for cat, items in categories.items():
        print(f" {cat:40s}: {len(items)}")
    print(f"{'='*70}")

    # ── Per-solver timeout, error & divergence file lists ───────────────────
    file_lists = [
        ("NSEQ TIMES OUT", [r for r in results if r["nseq"] == "timeout"]),
        ("SEQ TIMES OUT", [r for r in results if r["seq"] == "timeout"]),
        ("BOTH TIME OUT", categories["both_timeout"]),
        ("ERROR INVOLVED", categories["error_involved"]),
        ("DIVERGE (sat vs unsat)", categories["diverge"]),
    ]
    for label, items in file_lists:
        if items:
            _print_file_list(label, items)

    print()

    # ── Problem status statistics ────────────────────────────────────────────
    total = len(results)
    status_counts = {"sat": 0, "unsat": 0, "unknown": 0}
    for r in results:
        status_counts[r["status"]] = status_counts.get(r["status"], 0) + 1

    print(f"\nPROBLEM STATUS (total {total} files)")
    print(f"{'─'*40}")
    for status in ("sat", "unsat", "unknown"):
        count = status_counts[status]
        print(f" {status:12s}: {count:5d} ({_percent(count, total):.1f}%)")
    print(f"{'='*70}\n")


def _write_csv(csv_path: Path, results: list[dict]) -> None:
    """Write one CSV row per result, sorted by file path, to *csv_path*."""
    fieldnames = ["file", "nseq", "seq", "t_nseq", "t_seq", "category", "smtlib_status", "status"]
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for r in sorted(results, key=lambda x: x["file"]):
            writer.writerow({**r, "file": str(r["file"])})
    print(f"Results written to: {csv_path}")


def main():
    """Run both string solvers over the benchmarks and print a comparison report.

    Exits with status 1 if the z3 binary or the input path cannot be found, or
    if no matching files exist. Invalid arguments exit via argparse.
    """
    args = _parse_args()
    z3_bin = args.z3
    timeout_s = args.timeout

    # Fail fast: a missing binary would otherwise turn every run into an error.
    if shutil.which(z3_bin) is None:
        print(f"Error: z3 binary not found or not executable: {z3_bin}", file=sys.stderr)
        sys.exit(1)

    root = Path(args.path)
    if not root.exists():
        print(f"Error: path does not exist: {root}", file=sys.stderr)
        sys.exit(1)

    files = _collect_files(root, args.ext)
    if not files:
        print(f"No {args.ext} files found under {root}", file=sys.stderr)
        sys.exit(1)

    print(f"Found {len(files)} files. Running with {args.jobs} parallel workers, timeout={timeout_s}s …\n")

    results = _run_all(z3_bin, files, args.jobs, timeout_s)
    _print_summary(results)

    # ── Optional CSV output ───────────────────────────────────────────────────
    if args.csv:
        _write_csv(Path(args.csv), results)


if __name__ == "__main__":
    main()