From 8fdb491c1b644f03b12bc415a1e1311ee64c09c2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Mar 2026 19:33:01 +0000 Subject: [PATCH] refactor: use arg(0)/arg(1) instead of seq_util in power accessors Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com> --- src/ast/euf/euf_snode.h | 14 +++++++++ src/smt/seq/seq_nielsen.cpp | 60 ++++++++++++++----------------------- 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/src/ast/euf/euf_snode.h b/src/ast/euf/euf_snode.h index 1028cdfa1..6ea8b488f 100644 --- a/src/ast/euf/euf_snode.h +++ b/src/ast/euf/euf_snode.h @@ -132,6 +132,20 @@ namespace euf { bool is_to_re() const { return m_kind == snode_kind::s_to_re; } bool is_in_re() const { return m_kind == snode_kind::s_in_re; } + // get the base expression of a power snode, e.g., s from s^n + expr* get_power_base(seq_util& seq) const { + if (!is_power()) return nullptr; + expr* base = nullptr, *exp = nullptr; + return (m_expr && seq.str.is_power(m_expr, base, exp)) ? base : nullptr; + } + + // get the exponent expression of a power snode, e.g., n from s^n + expr* get_power_exp(seq_util& seq) const { + if (!is_power()) return nullptr; + expr* base = nullptr, *exp = nullptr; + return (m_expr && seq.str.is_power(m_expr, base, exp)) ? exp : nullptr; + } + // is this a leaf token (analogous to ZIPT's StrToken as opposed to Str) bool is_token() const { switch (m_kind) { diff --git a/src/smt/seq/seq_nielsen.cpp b/src/smt/seq/seq_nielsen.cpp index 588bf66be..a78e42382 100644 --- a/src/smt/seq/seq_nielsen.cpp +++ b/src/smt/seq/seq_nielsen.cpp @@ -1132,22 +1132,6 @@ namespace seq { return false; } - // Get the base expression of a power snode. - static expr* get_power_base_expr(euf::snode* power, seq_util& seq) { - if (!power || !power->is_power()) return nullptr; - expr* e = power->get_expr(); - expr* base = nullptr, *exp = nullptr; - return (e && seq.str.is_power(e, base, exp)) ? base : nullptr; - } - - // Get the exponent expression of a power snode. - static expr* get_power_exp_expr(euf::snode* power, seq_util& seq) { - if (!power || !power->is_power()) return nullptr; - expr* e = power->get_expr(); - expr* base = nullptr, *exp = nullptr; - return (e && seq.str.is_power(e, base, exp)) ? exp : nullptr; - } - // Merge adjacent tokens with the same power base on one side of an equation. // Handles: char(c) · power(c^e) → power(c^(e+1)), // power(c^e) · char(c) → power(c^(e+1)), @@ -1177,8 +1161,8 @@ namespace seq { // cross-side cancellation works better with unmerged leading powers // (e.g., w^k trivially ≤ 1+k, but w^(2k) vs 1+k requires k ≥ 1). if (tok->is_power() && i > 0) { - expr* base_e = get_power_base_expr(tok, seq); - expr* exp_acc = get_power_exp_expr(tok, seq); + expr* base_e = tok->get_power_base(seq); + expr* exp_acc = tok->get_power_exp(seq); if (!base_e || !exp_acc) { result.push_back(tok); ++i; continue; } bool local_merged = false; @@ -1186,9 +1170,9 @@ namespace seq { while (j < tokens.size()) { euf::snode* next = tokens[j]; if (next->is_power()) { - expr* nb = get_power_base_expr(next, seq); + expr* nb = next->get_power_base(seq); if (nb == base_e) { - exp_acc = arith.mk_add(exp_acc, get_power_exp_expr(next, seq)); + exp_acc = arith.mk_add(exp_acc, next->get_power_exp(seq)); local_merged = true; ++j; continue; } } @@ -1215,17 +1199,17 @@ namespace seq { // unwind produces u · u^(n-1); merging it back to u^n creates an infinite cycle. if (i > 0 && tok->is_char() && tok->get_expr() && i + 1 < tokens.size()) { euf::snode* next = tokens[i + 1]; - if (next->is_power() && get_power_base_expr(next, seq) == tok->get_expr()) { + if (next->is_power() && next->get_power_base(seq) == tok->get_expr()) { expr* base_e = tok->get_expr(); // Use same arg order as Case 1: add(exp, 1), not add(1, exp), // so that merging "c · c^e" and "c^e · c" both produce add(e, 1) // and the resulting power expression is hash-consed identically. - expr* exp_acc = arith.mk_add(get_power_exp_expr(next, seq), arith.mk_int(1)); + expr* exp_acc = arith.mk_add(next->get_power_exp(seq), arith.mk_int(1)); unsigned j = i + 2; while (j < tokens.size()) { euf::snode* further = tokens[j]; - if (further->is_power() && get_power_base_expr(further, seq) == base_e) { - exp_acc = arith.mk_add(exp_acc, get_power_exp_expr(further, seq)); + if (further->is_power() && further->get_power_base(seq) == base_e) { + exp_acc = arith.mk_add(exp_acc, further->get_power_exp(seq)); ++j; continue; } if (further->is_char() && further->get_expr() == base_e) { @@ -1273,7 +1257,7 @@ namespace seq { for (euf::snode* tok : tokens) { if (tok->is_power()) { - expr* exp_e = get_power_exp_expr(tok, seq); + expr* exp_e = tok->get_power_exp(seq); rational val; if (exp_e && arith.is_numeral(exp_e, val)) { if (val.is_zero()) { @@ -1352,7 +1336,7 @@ namespace seq { for (unsigned j = 0; j < pb_tokens.size() && match; j++) match = (pb_tokens[j] == base_tokens[j]); if (match) { - expr* pow_exp = get_power_exp_expr(t, seq); + expr* pow_exp = t->get_power_exp(seq); if (pow_exp) { sum = sum ? arith.mk_add(sum, pow_exp) : pow_exp; continue; @@ -1501,7 +1485,7 @@ namespace seq { nielsen_subst s(pow_head, sg.mk_empty_seq(pow_head->get_sort()), eq.m_dep); e->add_subst(s); child->apply_subst(sg, s); - expr* pow_exp = get_power_exp_expr(pow_head, seq); + expr* pow_exp = pow_head->get_power_exp(seq); if (pow_exp) { expr* zero = arith.mk_numeral(rational(0), true); e->add_side_int(m_graph.mk_int_constraint( @@ -1555,7 +1539,7 @@ namespace seq { euf::snode* end_tok = dir_token(pow_side, fwd); if (!end_tok || !end_tok->is_power()) continue; euf::snode* base_sn = end_tok->arg(0); - expr* pow_exp = get_power_exp_expr(end_tok, seq); + expr* pow_exp = end_tok->get_power_exp(seq); if (!base_sn || !pow_exp) continue; auto [count, consumed] = comm_power(base_sn, other_side, m, fwd); @@ -1576,7 +1560,7 @@ namespace seq { pow_side = dir_drop(sg, pow_side, 1, fwd); other_side = dir_drop(sg, other_side, consumed, fwd); - expr* base_e = get_power_base_expr(end_tok, seq); + expr* base_e = end_tok->get_power_base(seq); if (pow_le_count && count_le_pow) { // equal: both cancel completely } @@ -1617,12 +1601,12 @@ namespace seq { euf::snode* rh = dir_token(eq.m_rhs, fwd); if (!(lh && rh && lh->is_power() && rh->is_power())) continue; - expr* lb = get_power_base_expr(lh, seq); - expr* rb = get_power_base_expr(rh, seq); + expr* lb = lh->get_power_base(seq); + expr* rb = rh->get_power_base(seq); if (!(lb && rb && lb == rb)) continue; - expr* lp = get_power_exp_expr(lh, seq); - expr* rp = get_power_exp_expr(rh, seq); + expr* lp = lh->get_power_exp(seq); + expr* rp = rh->get_power_exp(seq); rational diff; if (lp && rp && get_const_power_diff(rp, lp, arith, diff)) { // rp = lp + diff (constant difference) @@ -3177,7 +3161,7 @@ namespace seq { euf::snode* end_tok = dir_token(pow_side, fwd); if (!end_tok || !end_tok->is_power()) continue; euf::snode* base_sn = end_tok->arg(0); - expr* pow_exp = get_power_exp_expr(end_tok, seq); + expr* pow_exp = end_tok->get_power_exp(seq); if (!base_sn || !pow_exp) continue; auto [count, consumed] = comm_power(base_sn, other_side, m, fwd); @@ -3477,7 +3461,7 @@ namespace seq { // E.g., [(ab)^3] → [a, b] so we get (ab)^n instead of ((ab)^3)^n. // (mirrors ZIPT: if b.Length == 1 && b is PowerToken pt => b = pt.Base) if (ground_prefix.size() == 1 && ground_prefix[0]->is_power()) { - expr* base_e = get_power_base_expr(ground_prefix[0], seq); + expr* base_e = ground_prefix[0]->get_power_base(seq); if (base_e) { euf::snode* base_sn = m_sg.mk(base_e); if (base_sn) { @@ -3540,7 +3524,7 @@ namespace seq { if (tok->is_power()) { // Token is a power u^exp: use fresh m' with 0 ≤ m' ≤ exp expr* inner_exp = get_power_exponent(tok); - expr* inner_base = get_power_base_expr(tok, seq); + expr* inner_base = tok->get_power_base(seq); if (inner_exp && inner_base) { fresh_m = mk_fresh_int_var(); expr_ref partial_pow(seq.str.mk_power(inner_base, fresh_m), m); @@ -3680,7 +3664,7 @@ namespace seq { euf::snode_vector base_toks; collect_tokens_dir(base, fwd, base_toks); unsigned base_len = base_toks.size(); - expr* base_expr = get_power_base_expr(power, seq); + expr* base_expr = power->get_power_base(seq); if (!base_expr || base_len == 0) return false; @@ -3709,7 +3693,7 @@ namespace seq { if (tok->is_power()) { // Token is a power u^exp: decompose with fresh m', 0 <= m' <= exp expr* inner_exp = get_power_exponent(tok); - expr* inner_base_e = get_power_base_expr(tok, seq); + expr* inner_base_e = tok->get_power_base(seq); if (inner_exp && inner_base_e) { fresh_inner_m = mk_fresh_int_var(); expr_ref partial_pow(seq.str.mk_power(inner_base_e, fresh_inner_m), m);