diff --git a/scripts/compare_seq_solvers.py b/scripts/compare_seq_solvers.py index aa7f3c04d..9519791b5 100644 --- a/scripts/compare_seq_solvers.py +++ b/scripts/compare_seq_solvers.py @@ -21,7 +21,7 @@ import time from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -TIMEOUT = 5 # seconds +DEFAULT_TIMEOUT = 5 # seconds COMMON_ARGS = ["model_validate=true"] SOLVERS = { @@ -71,27 +71,31 @@ def determine_status(res_nseq: str, res_seq: str, smtlib_status: str) -> str: return "unknown" -def run_z3(z3_bin: str, smt_file: Path, solver_arg: str) -> tuple[str, float]: +def run_z3(z3_bin: str, smt_file: Path, solver_arg: str, timeout_s: int = DEFAULT_TIMEOUT) -> tuple[str, float]: """Run z3 on a file with the given solver argument. Returns (result, elapsed) where result is 'sat', 'unsat', 'unknown', or 'timeout'/'error'. """ - cmd = [z3_bin, solver_arg] + COMMON_ARGS + [str(smt_file)] + timeout_ms = timeout_s * 1000 + cmd = [z3_bin, f"-t:{timeout_ms}", solver_arg] + COMMON_ARGS + [str(smt_file)] start = time.monotonic() try: proc = subprocess.run( cmd, capture_output=True, text=True, - timeout=TIMEOUT, + timeout=timeout_s + 5, # subprocess grace period beyond Z3's own timeout ) elapsed = time.monotonic() - start output = proc.stdout.strip() # Extract first meaningful line (sat/unsat/unknown) for line in output.splitlines(): line = line.strip() - if line in ("sat", "unsat", "unknown"): + if line in ("sat", "unsat"): return line, elapsed - return "unknown", elapsed + if line == "unknown": + # Z3 returns "unknown" when it hits -t: limit — treat as timeout + return "timeout", elapsed + return "timeout", elapsed except subprocess.TimeoutExpired: elapsed = time.monotonic() - start return "timeout", elapsed @@ -119,9 +123,9 @@ def classify(res_nseq: str, res_seq: str) -> str: return "diverge" -def process_file(z3_bin: str, smt_file: Path) -> dict: - res_nseq, t_nseq = run_z3(z3_bin, smt_file, SOLVERS["nseq"]) - res_seq, t_seq = run_z3(z3_bin, smt_file, SOLVERS["seq"]) +def process_file(z3_bin: str, smt_file: Path, timeout_s: int = DEFAULT_TIMEOUT) -> dict: + res_nseq, t_nseq = run_z3(z3_bin, smt_file, SOLVERS["nseq"], timeout_s) + res_seq, t_seq = run_z3(z3_bin, smt_file, SOLVERS["seq"], timeout_s) cat = classify(res_nseq, res_seq) smtlib_status = read_smtlib_status(smt_file) status = determine_status(res_nseq, res_seq, smtlib_status) @@ -143,10 +147,13 @@ def main(): parser.add_argument("--z3", required=True, metavar="PATH", help="Path to z3 binary") parser.add_argument("--ext", default=".smt2", help="File extension to search for (default: .smt2)") parser.add_argument("--jobs", type=int, default=4, help="Parallel workers (default: 4)") + parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, metavar="SEC", + help=f"Per-solver timeout in seconds (default: {DEFAULT_TIMEOUT})") parser.add_argument("--csv", metavar="FILE", help="Also write results to a CSV file") args = parser.parse_args() z3_bin = args.z3 + timeout_s = args.timeout root = Path(args.path) if not root.exists(): @@ -158,11 +165,11 @@ def main(): print(f"No {args.ext} files found under {root}", file=sys.stderr) sys.exit(1) - print(f"Found {len(files)} files. Running with {args.jobs} parallel workers …\n") + print(f"Found {len(files)} files. Running with {args.jobs} parallel workers, timeout={timeout_s}s …\n") results = [] with ThreadPoolExecutor(max_workers=args.jobs) as pool: - futures = {pool.submit(process_file, z3_bin, f): f for f in files} + futures = {pool.submit(process_file, z3_bin, f, timeout_s): f for f in files} done = 0 for fut in as_completed(futures): done += 1 diff --git a/src/smt/seq/seq_nielsen.cpp b/src/smt/seq/seq_nielsen.cpp index a2ebc8d6e..a0ec09c2d 100644 --- a/src/smt/seq/seq_nielsen.cpp +++ b/src/smt/seq/seq_nielsen.cpp @@ -3307,40 +3307,90 @@ namespace seq { expr* exp_n = get_power_exponent(power); expr* zero = arith.mk_int(0); - // Branch 1: x = base^m · prefix where 0 <= m < n - // Side constraints: m >= 0, m < n (i.e., n >= m + 1) + // Branch 1: enumerate all decompositions of the base. + // x = base^m · prefix_i(base) where 0 <= m < n + // Uses the same GetDecompose pattern as fire_gpower_intro. { + euf::snode_vector base_toks; + base->collect_tokens(base_toks); + unsigned base_len = base_toks.size(); + expr* base_expr = get_power_base_expr(power); + if (!base_expr || base_len == 0) + return false; + expr_ref fresh_m = mk_fresh_int_var(); - euf::snode* fresh_power = mk_fresh_var(); // represents base^m - euf::snode* fresh_suffix = mk_fresh_var(); // represents prefix(base) - euf::snode* replacement = m_sg.mk_concat(fresh_power, fresh_suffix); - nielsen_node* child = mk_child(node); - nielsen_edge* e = mk_edge(node, child, true); - nielsen_subst s(var_head, replacement, eq->m_dep); - e->add_subst(s); - child->apply_subst(m_sg, s); - // m >= 0 - e->add_side_int(mk_int_constraint(fresh_m, zero, int_constraint_kind::ge, eq->m_dep)); - // m < n ⟺ n >= m + 1 - if (exp_n) { - expr_ref m_plus_1(arith.mk_add(fresh_m, arith.mk_int(1)), m); - e->add_side_int(mk_int_constraint(exp_n, m_plus_1, int_constraint_kind::ge, eq->m_dep)); + expr_ref power_m_expr(seq.str.mk_power(base_expr, fresh_m), m); + euf::snode* power_m_sn = m_sg.mk(power_m_expr); + if (!power_m_sn) + return false; + + for (unsigned i = 0; i < base_len; ++i) { + euf::snode* tok = base_toks[i]; + + // Skip char position when preceding token is a power: + // the power case at i-1 with 0 <= m' <= exp already covers m' = exp. + if (!tok->is_power() && i > 0 && base_toks[i - 1]->is_power()) + continue; + + // Build full-token prefix: base_toks[0..i-1] + euf::snode* prefix_sn = nullptr; + for (unsigned j = 0; j < i; ++j) + prefix_sn = (j == 0) ? base_toks[0] : m_sg.mk_concat(prefix_sn, base_toks[j]); + + euf::snode* replacement; + expr_ref fresh_inner_m(m); + + if (tok->is_power()) { + // Token is a power u^exp: decompose with fresh m', 0 <= m' <= exp + expr* inner_exp = get_power_exponent(tok); + expr* inner_base_e = get_power_base_expr(tok); + if (inner_exp && inner_base_e) { + fresh_inner_m = mk_fresh_int_var(); + expr_ref partial_pow(seq.str.mk_power(inner_base_e, fresh_inner_m), m); + euf::snode* partial_sn = m_sg.mk(partial_pow); + euf::snode* suffix_sn = prefix_sn ? m_sg.mk_concat(prefix_sn, partial_sn) : partial_sn; + replacement = m_sg.mk_concat(power_m_sn, suffix_sn); + } else { + euf::snode* suffix_sn = prefix_sn ? m_sg.mk_concat(prefix_sn, tok) : tok; + replacement = m_sg.mk_concat(power_m_sn, suffix_sn); + } + } else { + // P(char) = ε, suffix is just the prefix + replacement = prefix_sn ? m_sg.mk_concat(power_m_sn, prefix_sn) : power_m_sn; + } + + nielsen_node* child = mk_child(node); + nielsen_edge* e = mk_edge(node, child, true); + nielsen_subst s(var_head, replacement, eq->m_dep); + e->add_subst(s); + child->apply_subst(m_sg, s); + + // m >= 0 + e->add_side_int(mk_int_constraint(fresh_m, zero, int_constraint_kind::ge, eq->m_dep)); + // m < n ⟺ n >= m + 1 + if (exp_n) { + expr_ref m_plus_1(arith.mk_add(fresh_m, arith.mk_int(1)), m); + e->add_side_int(mk_int_constraint(exp_n, m_plus_1, int_constraint_kind::ge, eq->m_dep)); + } + + // Inner power constraints: 0 <= m' <= inner_exp + if (fresh_inner_m.get()) { + expr* inner_exp = get_power_exponent(tok); + e->add_side_int(mk_int_constraint(fresh_inner_m, zero, int_constraint_kind::ge, eq->m_dep)); + e->add_side_int(mk_int_constraint(inner_exp, fresh_inner_m, int_constraint_kind::ge, eq->m_dep)); + } } } // Branch 2: x = u^n · x' (variable extends past full power, non-progress) - // Side constraint: n >= 0 { euf::snode* fresh_tail = mk_fresh_var(); - // Peel one base unit (approximation of extending past the power) - euf::snode* replacement = m_sg.mk_concat(base, fresh_tail); + euf::snode* replacement = m_sg.mk_concat(power, fresh_tail); nielsen_node* child = mk_child(node); nielsen_edge* e = mk_edge(node, child, false); nielsen_subst s(var_head, replacement, eq->m_dep); e->add_subst(s); child->apply_subst(m_sg, s); - if (exp_n) - e->add_side_int(mk_int_constraint(exp_n, zero, int_constraint_kind::ge, eq->m_dep)); } return true;