diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp
index 44744c821b..3dd2d9a364 100644
--- a/src/ast/rewriter/seq_rewriter.cpp
+++ b/src/ast/rewriter/seq_rewriter.cpp
@@ -1855,6 +1855,151 @@ br_status seq_rewriter::mk_seq_replace_all(expr* a, expr* b, expr* c, expr_ref& 
     return BR_FAILED;
 }
 
+
+/**
+ * Pre-image of a regex under replace_all(., a, b) for distinct single characters a, b:
+ *     replace_all(x, a, b) in R  <=>  x in replace_char(R, a, b)
+ *
+ * replace_char("ab", "a", "b") = empty
+ * replace_char("bc", "a", "b") = {"a", "b"}"c"
+ * replace_char(R u R', "a", "b") = replace_char(R, "a", "b") u replace_char(R', "a", "b")
+ * replace_char(R n R', "a", "b") = replace_char(R, "a", "b") n replace_char(R', "a", "b")
+ * replace_char(R*, "a", "b") = replace_char(R, "a", "b")*
+ * replace_char(R R', "a", "b") = replace_char(R, "a", "b") replace_char(R', "a", "b")
+ *
+ * a_ch and b_ch are the two characters; a_str and b_str are the string terms used to
+ * build the {a, b} alternative. Returns a null reference when r contains a construct
+ * that is not translated, in which case the caller must not rewrite.
+ */
+expr_ref seq_rewriter::re_replace_char(expr *r, unsigned a_ch, unsigned b_ch, expr *a_str, expr *b_str) {
+    expr *r1 = nullptr, *r2 = nullptr, *s = nullptr;
+    zstring str_val;
+
+    auto rec = [&](expr *e) { return re_replace_char(e, a_ch, b_ch, a_str, b_str); };
+    // A b in the output came from either an a or a b in x. The operands are built in a
+    // fixed order so that term creation does not depend on argument evaluation order.
+    auto mk_a_or_b = [&]() {
+        expr_ref a_re(re().mk_to_re(a_str), m());
+        expr_ref b_re(re().mk_to_re(b_str), m());
+        return expr_ref(re().mk_union(a_re, b_re), m());
+    };
+
+    if (re().is_to_re(r, s)) {
+        // Only string literals can be inspected character by character.
+        if (!str().is_string(s, str_val))
+            return expr_ref(m());
+        expr_ref_vector parts(m());
+        for (unsigned i = 0; i < str_val.length(); ++i) {
+            if (str_val[i] == a_ch) {
+                // replace_all never outputs a_ch, so this position is impossible
+                return expr_ref(re().mk_empty(r->get_sort()), m());
+            }
+            if (str_val[i] == b_ch) {
+                parts.push_back(mk_a_or_b());
+            }
+            else {
+                zstring ch(str_val[i]);
+                parts.push_back(re().mk_to_re(str().mk_string(ch)));
+            }
+        }
+        if (parts.empty())
+            return expr_ref(re().mk_epsilon(s->get_sort()), m());
+        // Fold the per-character pieces into a right-nested concatenation.
+        expr_ref result(parts.back(), m());
+        for (unsigned i = parts.size() - 1; i-- > 0;)
+            result = re().mk_concat(parts.get(i), result);
+        return result;
+    }
+
+    if (re().is_range(r, r1, r2)) {
+        zstring lo_s, hi_s;
+        if (!str().is_string(r1, lo_s) || !str().is_string(r2, hi_s) || lo_s.length() != 1 || hi_s.length() != 1)
+            return expr_ref(m());
+        unsigned lo = lo_s[0], hi = hi_s[0];
+        expr_ref_vector parts(m());
+        auto add_range = [&](unsigned from, unsigned to) {
+            zstring from_z(from), to_z(to);
+            expr_ref from_e(str().mk_string(from_z), m());
+            expr_ref to_e(str().mk_string(to_z), m());
+            parts.push_back(re().mk_range(from_e, to_e));
+        };
+        // Split [lo, hi] at a_ch and b_ch: a_ch is dropped (impossible in the output), b_ch
+        // becomes {a, b}, and the remaining characters are kept as sub-ranges. Only the two
+        // special characters are visited, in increasing order, instead of every code point.
+        unsigned const cuts[2] = {a_ch < b_ch ? a_ch : b_ch, a_ch < b_ch ? b_ch : a_ch};
+        unsigned prev = lo;
+        for (unsigned ch : cuts) {
+            if (ch < prev || ch > hi)
+                continue;
+            if (prev < ch)
+                add_range(prev, ch - 1);
+            if (ch == b_ch)
+                parts.push_back(mk_a_or_b());
+            prev = ch + 1;
+        }
+        if (prev <= hi)
+            add_range(prev, hi);
+        if (parts.empty())
+            return expr_ref(re().mk_empty(r->get_sort()), m());
+        expr_ref result(parts.get(0), m());
+        for (unsigned i = 1; i < parts.size(); ++i)
+            result = re().mk_union(result, parts.get(i));
+        return result;
+    }
+
+    if (re().is_union(r, r1, r2) || re().is_intersection(r, r1, r2) || re().is_concat(r, r1, r2)) {
+        // The pre-image distributes over these operators; operands are translated left to right.
+        expr_ref t1 = rec(r1);
+        if (!t1)
+            return t1;
+        expr_ref t2 = rec(r2);
+        if (!t2)
+            return t2;
+        if (re().is_union(r))
+            return expr_ref(re().mk_union(t1, t2), m());
+        if (re().is_intersection(r))
+            return expr_ref(re().mk_inter(t1, t2), m());
+        return expr_ref(re().mk_concat(t1, t2), m());
+    }
+    if (re().is_star(r, r1)) {
+        expr_ref t = rec(r1);
+        if (t)
+            t = re().mk_star(t);
+        return t;
+    }
+    if (re().is_plus(r, r1)) {
+        expr_ref t = rec(r1);
+        if (t)
+            t = re().mk_plus(t);
+        return t;
+    }
+    if (re().is_opt(r, r1)) {
+        expr_ref t = rec(r1);
+        if (t)
+            t = re().mk_opt(t);
+        return t;
+    }
+    unsigned lo = 0, hi = 0;
+    if (re().is_loop(r, r1, lo, hi)) {
+        expr_ref t = rec(r1);
+        if (t)
+            t = re().mk_loop(t, lo, hi);
+        return t;
+    }
+    if (re().is_loop(r, r1, lo)) {
+        expr_ref t = rec(r1);
+        if (t)
+            t = re().mk_loop(t, lo);
+        return t;
+    }
+    // all-char, full_seq, empty and epsilon are their own pre-image.
+    if (re().is_full_char(r) || re().is_full_seq(r) || re().is_empty(r) || re().is_epsilon(r))
+        return expr_ref(r, m());
+    // Anything else (complement, of_pred, reverse, ...) is not translated. Returning r
+    // unchanged would be unsound in general, so signal failure to the caller instead.
+    return expr_ref(m());
+}
+
 /**
    rewrites for map(f, s):
 
@@ -4444,6 +4589,27 @@ br_status seq_rewriter::mk_str_in_regexp(expr* a, expr* b, expr_ref& result) {
         }
     }
 
+
+    // replace_all(x, a, a) = x, so the membership test reduces to x in R.
+    expr *ra_x = nullptr, *ra_a = nullptr, *ra_b = nullptr;
+    zstring sa_val, sb_val;
+    if (str().is_replace_all(a, ra_x, ra_a, ra_b) && ra_a == ra_b) {
+        result = re().mk_in_re(ra_x, b);
+        return BR_REWRITE1;
+    }
+    // replace_all(x, a, b) in R where R is ground, a and b are distinct unit-length strings
+    // ==> x in R[b -> {a, b}, a -> empty]
+    if (str().is_replace_all(a, ra_x, ra_a, ra_b) && str().is_string(ra_a, sa_val) && sa_val.length() == 1 &&
+        str().is_string(ra_b, sb_val) && sb_val.length() == 1 && sa_val[0] != sb_val[0] && re().is_ground(b) &&
+        re().get_info(b).classical) {
+        expr_ref new_re = re_replace_char(b, sa_val[0], sb_val[0], ra_a, ra_b);
+        // A null result means R uses a construct re_replace_char does not translate.
+        if (new_re) {
+            result = re().mk_in_re(ra_x, new_re);
+            return BR_REWRITE_FULL;
+        }
+    }
+
     expr_ref b_s(m());
     if (lift_str_from_to_re(b, b_s)) {
         result = m_br.mk_eq_rw(a, b_s);
@@ -5475,6 +5641,25 @@ br_status seq_rewriter::mk_eq_core(expr * l, expr * r, expr_ref & result) {
     if (reduce_eq_empty(l, r, result))
         return BR_REWRITE_FULL;
 
+    // replace_all(x, a, b) = s, with a, b unit-length string literals and s a string literal ==> replace_all(x, a, b) in re.to_re(s)
+    {
+        expr *ra_x = nullptr, *ra_a = nullptr, *ra_b = nullptr;
+        zstring sa_val, sb_val, s_val;
+        expr *str_side = nullptr, *ra_side = nullptr;
+        if (str().is_replace_all(l))
+            ra_side = l, str_side = r;
+        else if (str().is_replace_all(r))
+            ra_side = r, str_side = l;
+        if (ra_side && str_side &&
+            str().is_replace_all(ra_side, ra_x, ra_a, ra_b) && str().is_string(ra_a, sa_val) &&
+            sa_val.length() == 1 &&
+            str().is_string(ra_b, sb_val) && sb_val.length() == 1 &&
+            str().is_string(str_side, s_val)) {
+            result = re().mk_in_re(ra_side, re().mk_to_re(str_side));
+            return BR_REWRITE_FULL;
+        }
+    }
+
 #if 0
     if (reduce_arith_eq(l, r, res) || reduce_arith_eq(r, l, res)) {
         result = mk_and(res);
diff --git a/src/ast/rewriter/seq_rewriter.h b/src/ast/rewriter/seq_rewriter.h
index 618124e10b..1b693ca3d6 100644
--- a/src/ast/rewriter/seq_rewriter.h
+++ b/src/ast/rewriter/seq_rewriter.h
@@ -173,6 +173,12 @@ class seq_rewriter {
     //replace b in a by c into result
     void replace_all_subvectors(expr_ref_vector const& as, expr_ref_vector const& bs, expr* c, expr_ref_vector& result);
 
+    // For replace_all(x, a, b) in R: transform R so that
+    // - occurrences of b_ch are replaced by union(to_re(a_str), to_re(b_str))
+    // - occurrences of a_ch are replaced by empty (replace_all never outputs a)
+    // Returns a null reference if R contains a construct that is not translated.
+    expr_ref re_replace_char(expr *r, unsigned a_ch, unsigned b_ch, expr *a_str, expr *b_str);
+
     // Calculate derivative, memoized and enforcing a normal form
     expr_ref is_nullable_rec(expr* r);
     expr_ref mk_derivative_rec(expr* ele, expr* r);