From 9b357373b0624f4b4809d1cd1e69e525884e2127 Mon Sep 17 00:00:00 2001 From: CEisenhofer Date: Tue, 9 Jun 2026 14:44:41 +0200 Subject: [PATCH] Preprocess away leading characters; otw. we are unsound --- scripts/compare_seq_solvers.py | 188 ++++++++++++++------------------- src/smt/theory_nseq.cpp | 43 ++++++++ 2 files changed, 124 insertions(+), 107 deletions(-) diff --git a/scripts/compare_seq_solvers.py b/scripts/compare_seq_solvers.py index 1f7cc9d9f..8f577a7de 100644 --- a/scripts/compare_seq_solvers.py +++ b/scripts/compare_seq_solvers.py @@ -1,10 +1,14 @@ #!/usr/bin/env python3 """ -Compare z3 string solvers: smt.string_solver=nseq (new) vs smt.string_solver=seq (old), -and optionally against an external ZIPT solver. +Compare z3 string solver configurations, and optionally against an external ZIPT solver. + +We always run three z3 configurations: + nseq_md monadic decomposition (parikh off, eager regex factorization) + nseq_pa parikh (parikh on, no regex factorization) + seq the old/baseline string solver Usage: - python compare_solvers.py --z3 /path/to/z3 [--zipt /path/to/zipt] [--ext .smt2] + python compare_seq_solvers.py --z3 /path/to/z3 [--zipt /path/to/zipt] [--ext .smt2] Finds all .smt2 files under the given path, runs the solvers with a configurable timeout, and reports timeouts, divergences, and summary statistics. @@ -23,12 +27,18 @@ from pathlib import Path DEFAULT_TIMEOUT = 5 # seconds COMMON_ARGS = ["model_validate=true"] +# All three configurations are always run. SOLVERS = { - "nseq": ["smt.string_solver=nseq", "smt.nseq.parikh=false"], - "nseq_p": ["smt.string_solver=nseq", "smt.nseq.parikh=true"], - "seq": ["smt.string_solver=seq"], + "nseq_md": ["smt.string_solver=nseq", "smt.nseq.parikh=false", + "smt.nseq.regex_factorization_threshold=100", "smt.nseq.regex_factorization_eager=true"], + "nseq_pa": ["smt.string_solver=nseq", "smt.nseq.parikh=true", + "smt.nseq.regex_factorization_threshold=0", "smt.nseq.regex_factorization_eager=false"], + "seq": ["smt.string_solver=seq"], } +# Ordered list of the z3 configuration names (excludes the external zipt solver). +SOLVER_NAMES = list(SOLVERS.keys()) + _STATUS_RE = re.compile(r'\(\s*set-info\s+:status\s+(sat|unsat|unknown)\s*\)') @@ -47,25 +57,16 @@ def read_smtlib_status(smt_file: Path) -> str: return "unknown" -def determine_status(res_nseq: str, res_seq: str, smtlib_status: str) -> str: +def determine_status(solver_results: dict[str, str], smtlib_status: str) -> str: """Determine the ground-truth status of a problem. - Priority: if both solvers agree on sat/unsat, use that; otherwise if one - solver gives sat/unsat, use that; otherwise fall back to the SMT-LIB - annotation; otherwise 'unknown'. + If the solvers that gave a definite (sat/unsat) answer all agree, use that; + if they disagree, fall back to the SMT-LIB annotation; otherwise 'unknown'. """ definite = {"sat", "unsat"} - if res_nseq in definite and res_nseq == res_seq: - return res_nseq - if res_nseq in definite and res_seq not in definite: - return res_nseq - if res_seq in definite and res_nseq not in definite: - return res_seq - # Disagreement (sat vs unsat) — fall back to SMTLIB annotation - if res_nseq in definite and res_seq in definite and res_nseq != res_seq: - if smtlib_status in definite: - return smtlib_status - return "unknown" - # Neither solver gave a definite answer + found = {r for r in solver_results.values() if r in definite} + if len(found) == 1: + return found.pop() + # No definite answer, or a disagreement among definite answers. if smtlib_status in definite: return smtlib_status return "unknown" @@ -124,63 +125,54 @@ def run_zipt(zipt_bin: str, smt_file: Path, timeout_s: int = DEFAULT_TIMEOUT) -> return f"error:{e}", elapsed -def classify(res_nseq: str, res_seq: str, res_nseq_p: str | None = None) -> str: - """Classify a pair of results into a category.""" - if res_nseq == "invalid_model" or res_seq == "invalid_model" or res_nseq_p == "invalid_model": +def classify(solver_results: dict[str, str]) -> str: + """Classify the results of all z3 configurations into a category.""" + definite = {"sat", "unsat"} + results = list(solver_results.values()) + + if any(r == "invalid" for r in results): return "invalid_model" - timed_nseq = res_nseq == "timeout" - timed_seq = res_seq == "timeout" - - if res_nseq_p: - timed_nseq_p = res_nseq_p == "timeout" - if timed_nseq and timed_seq and timed_nseq_p: return "all_timeout" - if not timed_nseq and not timed_seq and not timed_nseq_p: - if res_nseq == "unknown" or res_seq == "unknown" or res_nseq_p == "unknown": - return "all_terminate_unknown_involved" - if res_nseq == res_seq == res_nseq_p: return "all_agree" - return "diverge" + n_timeout = sum(1 for r in results if r == "timeout") + if n_timeout == len(results): + return "all_timeout" - if timed_nseq and timed_seq: - return "both_timeout" - if timed_nseq: - return "only_seq_terminates" - if timed_seq: - return "only_nseq_terminates" - # Both terminated — check agreement - if res_nseq == "unknown" or res_seq == "unknown": - return "both_terminate_unknown_involved" - if res_nseq == res_seq: - return "both_agree" - return "diverge" + # Among the configurations that terminated: + terminated = [r for r in results if r != "timeout"] + found = {r for r in terminated if r in definite} + if len(found) > 1: + return "diverge" + if any(r == "unknown" for r in terminated): + return "unknown_involved" + # All terminating configurations agree on a definite answer. + if n_timeout == 0: + return "all_agree" + return "agree_some_timeout" def process_file(z3_bin: str, smt_file: Path, timeout_s: int = DEFAULT_TIMEOUT, - zipt_bin: str | None = None, run_nseq_p: bool = False) -> dict: - res_nseq, t_nseq = run_z3(z3_bin, smt_file, SOLVERS["nseq"], timeout_s) - res_seq, t_seq = run_z3(z3_bin, smt_file, SOLVERS["seq"], timeout_s) - - res_nseq_p, t_nseq_p = None, None - if run_nseq_p: - res_nseq_p, t_nseq_p = run_z3(z3_bin, smt_file, SOLVERS["nseq_p"], timeout_s) + zipt_bin: str | None = None) -> dict: + solver_results: dict[str, str] = {} + solver_times: dict[str, float] = {} + for name in SOLVER_NAMES: + res, t = run_z3(z3_bin, smt_file, SOLVERS[name], timeout_s) + solver_results[name] = res + solver_times[name] = t - cat = classify(res_nseq, res_seq, res_nseq_p) + cat = classify(solver_results) smtlib_status = read_smtlib_status(smt_file) - status = determine_status(res_nseq, res_seq, smtlib_status) + status = determine_status(solver_results, smtlib_status) result = { "file": smt_file, - "nseq": res_nseq, - "seq": res_seq, - "t_nseq": t_nseq, - "t_seq": t_seq, "category": cat, "smtlib_status": smtlib_status, "status": status, "zipt": None, "t_zipt": None, - "nseq_p": res_nseq_p, - "t_nseq_p": t_nseq_p, } + for name in SOLVER_NAMES: + result[name] = solver_results[name] + result[f"t_{name}"] = solver_times[name] if zipt_bin: res_zipt, t_zipt = run_zipt(zipt_bin, smt_file, timeout_s) result["zipt"] = res_zipt @@ -202,8 +194,6 @@ def main(): help=f"Per-solver timeout in seconds (default: {DEFAULT_TIMEOUT})") parser.add_argument("--zipt", metavar="PATH", default=None, help="Path to ZIPT binary (optional; if omitted, ZIPT is not benchmarked)") - parser.add_argument("--parikh", action="store_true", - help="Also run nseq with nseq.parikh=true") parser.add_argument("--csv", metavar="FILE", help="Also write results to a CSV file") args = parser.parse_args() @@ -235,29 +225,24 @@ def main(): print(f"Sampling: {len(files)} files selected " f"(max {args.max_per_folder} per subfolder, {len(by_folder)} subfolder(s))") - solvers_label = "nseq, seq" - if args.parikh: solvers_label += ", nseq_p" + solvers_label = ", ".join(SOLVER_NAMES) if zipt_bin: solvers_label += ", zipt" print(f"Found {len(files)} files. Solvers: {solvers_label}. " f"Workers: {args.jobs}, timeout: {timeout_s}s\n") results = [] pool = ThreadPoolExecutor(max_workers=args.jobs) - futures = {pool.submit(process_file, z3_bin, f, timeout_s, zipt_bin, args.parikh): f for f in files} + futures = {pool.submit(process_file, z3_bin, f, timeout_s, zipt_bin): f for f in files} done = 0 try: for fut in as_completed(futures): done += 1 r = fut.result() - results.append(r) - np_part = "" - if args.parikh: - np_part = f" nseq_p={r['nseq_p']:8s} ({r['t_nseq_p']:.1f}s)" + solver_part = " ".join(f"{name}={r[name]:8s} ({r[f't_{name}']:.1f}s)" for name in SOLVER_NAMES) zipt_part = "" if zipt_bin: zipt_part = f" zipt={r['zipt']:8s} ({r['t_zipt']:.1f}s)" - print(f"[{done:4d}/{len(files)}] {r['category']:35s} nseq={r['nseq']:8s} ({r['t_nseq']:.1f}s) " - f"seq={r['seq']:8s} ({r['t_seq']:.1f}s){np_part}{zipt_part} {r['file'].name}") + print(f"[{done:4d}/{len(files)}] {r['category']:20s} {solver_part}{zipt_part} {r['file'].parent.name}/{r['file'].name}") except KeyboardInterrupt: print("\nInterrupted — cancelling pending tasks.", file=sys.stderr) pool.shutdown(wait=False, cancel_futures=True) @@ -267,16 +252,12 @@ def main(): # ── Summary ────────────────────────────────────────────────────────────── categories = { - "invalid_model": [], - "all_timeout": [], - "both_timeout": [], - "only_seq_terminates": [], - "only_nseq_terminates": [], - "both_agree": [], - "both_terminate_unknown_involved":[], - "all_terminate_unknown_involved":[], - "all_agree": [], - "diverge": [], + "invalid_model": [], + "all_timeout": [], + "agree_some_timeout": [], + "all_agree": [], + "unknown_involved": [], + "diverge": [], } for r in results: categories.setdefault(r["category"], []).append(r) @@ -288,10 +269,7 @@ def main(): print(f"{'='*70}") # ── Per-solver timeout & divergence file lists ───────────────────────── - nseq_timeouts = [r for r in results if r["nseq"] == "timeout"] - seq_timeouts = [r for r in results if r["seq"] == "timeout"] - both_to = categories["both_timeout"] - diverged = categories["diverge"] + diverged = categories["diverge"] def _print_file_list(label: str, items: list[dict]): print(f"\n{'─'*70}") @@ -300,32 +278,33 @@ def main(): for r in sorted(items, key=lambda x: x["file"]): print(f" {r['file']}") - if nseq_timeouts: - _print_file_list("NSEQ TIMES OUT", nseq_timeouts) - if seq_timeouts: - _print_file_list("SEQ TIMES OUT", seq_timeouts) + for name in SOLVER_NAMES: + timeouts = [r for r in results if r[name] == "timeout"] + if timeouts: + _print_file_list(f"{name.upper()} TIMES OUT", timeouts) if zipt_bin: zipt_timeouts = [r for r in results if r["zipt"] == "timeout"] if zipt_timeouts: _print_file_list("ZIPT TIMES OUT", zipt_timeouts) - if both_to: - _print_file_list("BOTH Z3 SOLVERS TIME OUT", both_to) + + all_to = categories["all_timeout"] + if all_to: + _print_file_list("ALL Z3 CONFIGURATIONS TIME OUT", all_to) invalid_models = categories.get("invalid_model", []) if invalid_models: _print_file_list("INVALID MODEL GENERATED", invalid_models) if zipt_bin: - all_to = [r for r in results - if r["nseq"] == "timeout" and r["seq"] == "timeout" and r["zipt"] == "timeout"] - if all_to: - _print_file_list("ALL THREE TIME OUT", all_to) + all_three_to = [r for r in results + if all(r[name] == "timeout" for name in SOLVER_NAMES) and r["zipt"] == "timeout"] + if all_three_to: + _print_file_list("ALL SOLVERS (INCL. ZIPT) TIME OUT", all_three_to) if diverged: - _print_file_list("DIVERGE nseq vs seq (sat vs unsat)", diverged) + _print_file_list("DIVERGE among z3 configurations (sat vs unsat)", diverged) if zipt_bin: definite = {"sat", "unsat"} zipt_diverged = [r for r in results if r["zipt"] in definite - and ((r["nseq"] in definite and r["nseq"] != r["zipt"]) - or (r["seq"] in definite and r["seq"] != r["zipt"]))] + and any(r[name] in definite and r[name] != r["zipt"] for name in SOLVER_NAMES)] if zipt_diverged: _print_file_list("DIVERGE involving ZIPT", zipt_diverged) @@ -347,12 +326,7 @@ def main(): if args.csv: import csv csv_path = Path(args.csv) - fieldnames = ["file", "nseq", "seq"] - if args.parikh: - fieldnames.append("nseq_p") - fieldnames.extend(["t_nseq", "t_seq"]) - if args.parikh: - fieldnames.append("t_nseq_p") + fieldnames = ["file"] + SOLVER_NAMES + [f"t_{name}" for name in SOLVER_NAMES] fieldnames.extend(["category", "smtlib_status", "status"]) if zipt_bin: fieldnames.extend(["zipt", "t_zipt"]) diff --git a/src/smt/theory_nseq.cpp b/src/smt/theory_nseq.cpp index c5ed38a5f..3aa83a69f 100644 --- a/src/smt/theory_nseq.cpp +++ b/src/smt/theory_nseq.cpp @@ -457,6 +457,40 @@ namespace smt { expr* const s = mem.m_str->get_expr(); std::cout << "Propagating: " << seq::mem_pp(mem, m) << std::endl; + if (!mem.m_str->is_empty()) { + if (mem.m_str->first()->is_char()) { + euf::snode* re_node = mem.m_regex; + euf::snode* str_node = mem.m_str; + do { + // eliminate leading character by derivatives + re_node = m_sgraph.brzozowski_deriv(re_node, mem.m_str->first()); + str_node = m_sgraph.drop_first(str_node); + } while (!str_node->is_empty() && str_node->first()->is_char()); + + if (re_node->is_fail()) { + literal_vector lits; + lits.push_back(mem.lit); + set_conflict(lits); + return; + } + const expr_ref e(m_seq.re.mk_in_re(str_node->get_expr(), re_node->get_expr()), m); + ctx.mk_th_axiom(get_id(), ~mem.lit, mk_literal(e)); + m_ignored_mem.insert(mem.lit); + ctx.push_trail(insert_map(m_ignored_mem, mem.lit)); + return; + } + } + else { + // check nullability + if (m_sgraph.re_nullable(mem.m_regex) == l_true) { + // empty string in nullable regex → trivially satisfied + m_ignored_mem.insert(mem.lit); + ctx.push_trail(insert_map(m_ignored_mem, mem.lit)); + return; + } + return; + } + if (mem.m_regex->is_full_seq()) { // u \in .* can be ignored m_ignored_mem.insert(mem.lit); @@ -510,6 +544,11 @@ namespace smt { if (!get_fparams().m_nseq_regex_factorization_threshold) return; + SASSERT(!mem.m_str->is_empty()); + SASSERT(!mem.m_str->first()->is_char()); + if (!mem.m_str->first()->is_var()) + return; + // Eager sigma factorization (token-level): when enabled, split a non-primitive // membership s ∈ r at the boundary between the first concat argument (head) and // the rest (tail), using compute_sigma. This mirrors the lazy Nielsen @@ -553,12 +592,16 @@ namespace smt { // forward direction; mk_literal Tseitin-encodes each conjunction literal_vector lits; lits.push_back(~mem.lit); + std::cout << "Decomposing into:\n"; for (auto const& sp : pairs) { expr_ref mem_head(m_seq.re.mk_in_re(head, sp.m_p), m); expr_ref mem_tail(m_seq.re.mk_in_re(tail, sp.m_q), m); expr_ref conj(m.mk_and(mem_head, mem_tail), m); lits.push_back(mk_literal(conj)); + seq::dep_tracker dep = nullptr; + std::cout << seq::mem_pp(seq::str_mem(m_sgraph.mk(head), m_sgraph.mk(sp.m_p), dep), m) << " && " << seq::mem_pp(seq::str_mem(m_sgraph.mk(tail), m_sgraph.mk(sp.m_q), dep), m) << "\n"; } + std::cout << std::endl; ctx.mk_th_axiom(get_id(), lits.size(), lits.data()); m_ignored_mem.insert(mem.lit); ctx.push_trail(insert_map(m_ignored_mem, mem.lit));