diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index 8962b80ec..0e74a7db2 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -519,6 +519,14 @@ br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con SASSERT(num_args == 1); st = mk_re_opt(args[0], result); break; + case OP_RE_REVERSE: + SASSERT(num_args == 1); + st = mk_re_reverse(args[0], result); + break; + case OP_RE_DERIVATIVE: + SASSERT(num_args == 2); + st = mk_re_derivative(args[0], args[1], result); + break; case OP_RE_CONCAT: if (num_args == 1) { result = args[0]; @@ -2052,6 +2060,9 @@ bool seq_rewriter::is_sequence(expr* e, expr_ref_vector& seq) { return true; } +/* + s = head + tail where |head| = 1 +*/ bool seq_rewriter::get_head_tail(expr* s, expr_ref& head, expr_ref& tail) { expr* h = nullptr, *t = nullptr; zstring s1; @@ -2063,7 +2074,7 @@ bool seq_rewriter::get_head_tail(expr* s, expr_ref& head, expr_ref& tail) { if (m_util.str.is_string(s, s1) && s1.length() > 0) { head = m_util.mk_char(s1[0]); tail = m_util.str.mk_string(s1.extract(1, s1.length())); - return true; + return true; } if (m_util.str.is_concat(s, h, t) && get_head_tail(h, head, tail)) { tail = m_util.str.mk_concat(tail, t); @@ -2072,6 +2083,29 @@ bool seq_rewriter::get_head_tail(expr* s, expr_ref& head, expr_ref& tail) { return false; } +/* + s = head + tail where |tail| = 1 +*/ +bool seq_rewriter::get_head_tail_reversed(expr* s, expr_ref& head, expr_ref& tail) { + expr* h = nullptr, *t = nullptr; + zstring s1; + if (m_util.str.is_unit(s, t)) { + head = m_util.str.mk_empty(m().get_sort(s)); + tail = t; + return true; + } + if (m_util.str.is_string(s, s1) && s1.length() > 0) { + head = m_util.str.mk_string(s1.extract(0, s1.length() - 1)); + tail = m_util.mk_char(s1[s1.length() - 1]); + return true; + } + if (m_util.str.is_concat(s, h, t) && get_head_tail_reversed(t, head, tail)) { + head = m_util.str.mk_concat(h, head); + return true; + } 
+ return false; +} + expr_ref seq_rewriter::re_and(expr* cond, expr* r) { if (m().is_true(cond)) return expr_ref(r, m()); @@ -2117,7 +2151,8 @@ expr_ref seq_rewriter::is_nullable(expr* r) { } else if (re().is_plus(r, r1) || (re().is_loop(r, r1, lo) && lo > 0) || - (re().is_loop(r, r1, lo, hi) && lo > 0)) { + (re().is_loop(r, r1, lo, hi) && lo > 0) || + (re().is_reverse(r, r1))) { result = is_nullable(r1); } else if (re().is_complement(r, r1)) { @@ -2141,126 +2176,212 @@ expr_ref seq_rewriter::is_nullable(expr* r) { } /* - Symbolic derivative - Evaluates recursively. - Returns null expression `expr_ref(m())` on failure. + Push reverse inwards (gets stuck at variables and strings). */ -expr_ref seq_rewriter::derivative(expr* elem, expr* r) { - sort* seq_sort = nullptr, *elem_sort = nullptr; +br_status seq_rewriter::mk_re_reverse(expr* r, expr_ref& result) { + sort* seq_sort = nullptr; VERIFY(m_util.is_re(r, seq_sort)); - VERIFY(m_util.is_seq(seq_sort, elem_sort)); - SASSERT(elem_sort == m().get_sort(elem)); - expr* r1 = nullptr, * r2 = nullptr, *p = nullptr; - expr_ref dr1(m()), dr2(m()), result(m()); + expr* r1 = nullptr, *r2 = nullptr, *p = nullptr; + unsigned lo = 0, hi = 0; + if (re().is_concat(r, r1, r2)) { + result = re().mk_concat(re().mk_reverse(r2), re().mk_reverse(r1)); + return BR_REWRITE2; + } + else if (re().is_star(r, r1)) { + result = re().mk_star((re().mk_reverse(r1))); + return BR_REWRITE2; + } + else if (re().is_plus(r, r1)) { + result = re().mk_plus((re().mk_reverse(r1))); + return BR_REWRITE2; + } + else if (re().is_union(r, r1, r2)) { + result = re().mk_union(re().mk_reverse(r1), re().mk_reverse(r2)); + return BR_REWRITE2; + } + else if (re().is_intersection(r, r1, r2)) { + result = re().mk_inter(re().mk_reverse(r1), re().mk_reverse(r2)); + return BR_REWRITE2; + } + else if (re().is_diff(r, r1, r2)) { + result = re().mk_diff(re().mk_reverse(r1), re().mk_reverse(r2)); + return BR_REWRITE2; + } + else if (m().is_ite(r, p, r1, r2)) { + result = 
m().mk_ite(p, re().mk_reverse(r1), re().mk_reverse(r2)); + return BR_REWRITE2; + } + else if (re().is_opt(r, r1)) { + result = re().mk_opt(re().mk_reverse(r1)); + return BR_REWRITE2; + } + else if (re().is_complement(r, r1)) { + result = re().mk_complement(re().mk_reverse(r1)); + return BR_REWRITE2; + } + else if (re().is_loop(r, r1, lo)) { + result = re().mk_loop(re().mk_reverse(r1), lo); + return BR_REWRITE2; + } + else if (re().is_loop(r, r1, lo, hi)) { + result = re().mk_loop(re().mk_reverse(r1), lo, hi); + return BR_REWRITE2; + } + else if (re().is_reverse(r, r1)) { + result = r1; + return BR_DONE; + } + else if (re().is_full_seq(r) || + re().is_empty(r) || + re().is_range(r) || + re().is_full_char(r) || + re().is_of_pred(r)) { + result = r; + return BR_DONE; + } + else { + // stuck cases: variable, re().is_to_re, re().is_derivative, ... + return BR_FAILED; + } +} + +/* + Symbolic derivative: ele -> regex -> regex + ele is a single element (char) of the sequence sort, not a sequence +*/ +br_status seq_rewriter::mk_re_derivative(expr* ele, expr* r, expr_ref& result) { + sort* seq_sort = nullptr, *ele_sort = nullptr; + VERIFY(m_util.is_re(r, seq_sort)); + VERIFY(m_util.is_seq(seq_sort, ele_sort)); + SASSERT(ele_sort == m().get_sort(ele)); + expr* r1 = nullptr, *r2 = nullptr, *p = nullptr; unsigned lo = 0, hi = 0; if (re().is_concat(r, r1, r2)) { expr_ref is_n = is_nullable(r1); - dr1 = derivative(elem, r1); - if (!dr1) { - result = dr1; // failed + expr* dr1 = re().mk_derivative(ele, r1); + expr* dr2 = re().mk_derivative(ele, r2); + result = re().mk_concat(dr1, r2); + if (m().is_false(is_n)) { + return BR_REWRITE2; } - else if (m().is_false(is_n)) { - result = re().mk_concat(dr1, r2); + else if (m().is_true(is_n)) { + result = re().mk_union(result, dr2); + return BR_REWRITE3; } else { - dr2 = derivative(elem, r2); - if (!dr2) { - result = dr2; // failed - } - else if (m().is_true(is_n)) { - result = re().mk_union( - re().mk_concat(dr1, r2), - dr2 - ); - } - else { - result = re().mk_union( - 
re().mk_concat(dr1, r2), - re_and(is_n, dr2) - ); - } + result = re().mk_union(result, re_and(is_n, dr2)); + return BR_REWRITE3; } } else if (re().is_star(r, r1)) { - result = derivative(elem, r1); - if (result) { - result = re().mk_concat(result, r); - } + result = re().mk_concat(re().mk_derivative(ele, r1), r); + return BR_REWRITE2; } else if (re().is_plus(r, r1)) { - result = re().mk_star(r1); - result = derivative(elem, result); + result = re().mk_derivative(ele, re().mk_star(r1)); + return BR_REWRITE1; } else if (re().is_union(r, r1, r2)) { - dr1 = derivative(elem, r1); - dr2 = derivative(elem, r2); - if (dr1 && dr2) { - result = re().mk_union(dr1, dr2); - } + result = re().mk_union( + re().mk_derivative(ele, r1), + re().mk_derivative(ele, r2) + ); + return BR_REWRITE2; } else if (re().is_intersection(r, r1, r2)) { - dr1 = derivative(elem, r1); - dr2 = derivative(elem, r2); - if (dr1 && dr2) { - result = re().mk_inter(dr1, dr2); - } + result = re().mk_inter( + re().mk_derivative(ele, r1), + re().mk_derivative(ele, r2) + ); + return BR_REWRITE2; + } + else if (re().is_diff(r, r1, r2)) { + result = re().mk_diff( + re().mk_derivative(ele, r1), + re().mk_derivative(ele, r2) + ); + return BR_REWRITE2; + } + else if (m().is_ite(r, p, r1, r2)) { + result = m().mk_ite( + p, + re().mk_derivative(ele, r1), + re().mk_derivative(ele, r2) + ); + return BR_REWRITE2; } else if (re().is_opt(r, r1)) { - result = derivative(elem, r1); + result = re().mk_derivative(ele, r1); + return BR_REWRITE1; } else if (re().is_complement(r, r1)) { - result = derivative(elem, r1); - if (result) { - result = re().mk_complement(result); - } + result = re().mk_complement(re().mk_derivative(ele, r1)); + return BR_REWRITE2; } else if (re().is_loop(r, r1, lo)) { - result = derivative(elem, r1); - if (result) { - if (lo > 0) { - lo--; - } - result = re().mk_concat( - result, - re().mk_loop(r1, lo) - ); + if (lo > 0) { + lo--; } + result = re().mk_concat( + re().mk_derivative(ele, r1), + 
re().mk_loop(r1, lo) + ); + return BR_REWRITE2; } else if (re().is_loop(r, r1, lo, hi)) { if (hi == 0) { result = re().mk_empty(m().get_sort(r)); + return BR_DONE; } - else { - result = derivative(elem, r1); - if (result) { - hi--; - if (lo > 0) { - lo--; - } - result = re().mk_concat( - result, - re().mk_loop(r1, lo, hi) - ); - } + hi--; + if (lo > 0) { + lo--; } + result = re().mk_concat( + re().mk_derivative(ele, r1), + re().mk_loop(r1, lo, hi) + ); + return BR_REWRITE2; } else if (re().is_full_seq(r) || re().is_empty(r)) { result = r; + return BR_DONE; } else if (re().is_to_re(r, r1)) { // r1 is a string here (not a regexp) - expr_ref hd(m()); - expr_ref tl(m()); + expr_ref hd(m()), tl(m()); if (get_head_tail(r1, hd, tl)) { // head must be equal; if so, derivative is tail - result = re_and( - m().mk_eq(elem, hd), - re().mk_to_re(tl) - ); + result = re_and(m().mk_eq(ele, hd),re().mk_to_re(tl)); + return BR_REWRITE2; } else if (m_util.str.is_empty(r1)) { result = re().mk_empty(m().get_sort(r)); + return BR_DONE; + } + else { + return BR_FAILED; + } + } + else if (re().is_reverse(r, r1) && re().is_to_re(r1, r2)) { + // Reverses are rewritten so that the only derivative case is + // derivative of a reverse of a string. (All other cases stuck) + // This is analogous to the previous is_to_re case. 
+ expr_ref hd(m()), tl(m()); + if (get_head_tail_reversed(r2, hd, tl)) { + result = re_and( + m().mk_eq(ele, tl), + re().mk_reverse(re().mk_to_re(hd)) + ); + return BR_REWRITE3; + } + else if (m_util.str.is_empty(r2)) { + result = re().mk_empty(m().get_sort(r)); + return BR_DONE; + } + else { + return BR_FAILED; } } else if (re().is_range(r, r1, r2)) { @@ -2270,31 +2391,30 @@ expr_ref seq_rewriter::derivative(expr* elem, expr* r) { if (s1.length() == 1 && s2.length() == 1) { r1 = m_util.mk_char(s1[0]); r2 = m_util.mk_char(s2[0]); - result = m().mk_and(m_util.mk_le(r1, elem), m_util.mk_le(elem, r2)); + result = m().mk_and(m_util.mk_le(r1, ele), m_util.mk_le(ele, r2)); result = re_predicate(result, seq_sort); + return BR_REWRITE3; } else { result = re().mk_empty(m().get_sort(r)); + return BR_DONE; } } } else if (re().is_full_char(r)) { result = re().mk_to_re(m_util.str.mk_empty(seq_sort)); + return BR_DONE; } else if (re().is_of_pred(r, p)) { array_util array(m()); - expr* args[2] = { p, elem }; + expr* args[2] = { p, ele }; result = array.mk_select(2, args); result = re_predicate(result, seq_sort); + return BR_REWRITE2; } - else if (m().is_ite(r, p, r1, r2)) { - dr1 = derivative(elem, r1); - dr2 = derivative(elem, r2); - if (dr1 && dr2) { - result = m().mk_ite(p, dr1, dr2); - } - } - return result; + // stuck cases: re().is_derivative, variable, ... 
+ // and re().is_reverse if the reverse is not applied to a string + return BR_FAILED; } /* @@ -2409,7 +2529,6 @@ bool seq_rewriter::rewrite_contains_pattern(expr* a, expr* b, expr_ref& result) while (str().is_concat(u, z, u) && (str().is_unit(z) || str().is_string(z))) { m_lhs.push_back(z); } - bool no_overlaps = true; for (auto const& p : patterns) if (!non_overlap(p, m_lhs)) return false; @@ -2436,6 +2555,14 @@ bool seq_rewriter::rewrite_contains_pattern(expr* a, expr* b, expr_ref& result) return true; } +/* + a in empty -> false + a in full -> true + a in (str.to_re a') -> (a == a') + "" in b -> is_nullable(b) + (ele + tail) in b -> tail in (derivative ele b) + (head + ele) in b -> head in (reverse (derivative ele (reverse b))) +*/ br_status seq_rewriter::mk_str_in_regexp(expr* a, expr* b, expr_ref& result) { if (re().is_empty(b)) { @@ -2461,17 +2588,21 @@ br_status seq_rewriter::mk_str_in_regexp(expr* a, expr* b, expr_ref& result) { expr_ref hd(m()), tl(m()); if (get_head_tail(a, hd, tl)) { - expr_ref db = derivative(hd, b); // null if failed - if (db) { - result = re().mk_in_re(tl, db); - return BR_REWRITE_FULL; - } + result = re().mk_in_re(tl, re().mk_derivative(hd, b)); + return BR_REWRITE2; + } + else if (get_head_tail_reversed(a, hd, tl)) { + result = re().mk_in_re( + hd, + re().mk_reverse(re().mk_derivative(tl, re().mk_reverse(b))) + ); + return BR_REWRITE_FULL; } - if (rewrite_contains_pattern(a, b, result)) + if (rewrite_contains_pattern(a, b, result)) return BR_REWRITE_FULL; - return BR_FAILED; + return BR_FAILED; } br_status seq_rewriter::mk_str_to_regexp(expr* a, expr_ref& result) { diff --git a/src/ast/rewriter/seq_rewriter.h b/src/ast/rewriter/seq_rewriter.h index 43d51a994..da1a03193 100644 --- a/src/ast/rewriter/seq_rewriter.h +++ b/src/ast/rewriter/seq_rewriter.h @@ -135,6 +135,7 @@ class seq_rewriter { // Support for regular expression derivatives bool get_head_tail(expr* e, expr_ref& head, expr_ref& tail); + bool get_head_tail_reversed(expr* e, expr_ref& head, 
expr_ref& tail); expr_ref re_and(expr* cond, expr* r); expr_ref re_predicate(expr* cond, sort* seq_sort); @@ -175,6 +176,8 @@ class seq_rewriter { br_status mk_re_power(func_decl* f, expr* a, expr_ref& result); br_status mk_re_loop(func_decl* f, unsigned num_args, expr* const* args, expr_ref& result); br_status mk_re_range(expr* lo, expr* hi, expr_ref& result); + br_status mk_re_reverse(expr* r, expr_ref& result); + br_status mk_re_derivative(expr* ele, expr* r, expr_ref& result); br_status lift_ite(func_decl* f, unsigned n, expr* const* args, expr_ref& result); br_status reduce_re_eq(expr* a, expr* b, expr_ref& result); br_status reduce_re_is_empty(expr* r, expr_ref& result); @@ -265,8 +268,6 @@ public: void add_seqs(expr_ref_vector const& ls, expr_ref_vector const& rs, expr_ref_pair_vector& new_eqs); - expr_ref derivative(expr* hd, expr* r); - expr_ref is_nullable(expr* r); bool has_cofactor(expr* r, expr_ref& cond, expr_ref& th, expr_ref& el); diff --git a/src/ast/seq_decl_plugin.cpp b/src/ast/seq_decl_plugin.cpp index 1456676b4..9e112c175 100644 --- a/src/ast/seq_decl_plugin.cpp +++ b/src/ast/seq_decl_plugin.cpp @@ -566,6 +566,7 @@ void seq_decl_plugin::init() { sort* seqAseqA[2] = { seqA, seqA }; sort* seqAreA[2] = { seqA, reA }; sort* reAreA[2] = { reA, reA }; + sort* AreA[2] = { A, reA }; sort* seqAint2T[3] = { seqA, intT, intT }; sort* seq2AintT[3] = { seqA, seqA, intT }; sort* str2T[2] = { strT, strT }; @@ -607,6 +608,8 @@ void seq_decl_plugin::init() { m_sigs[OP_RE_FULL_SEQ_SET] = alloc(psig, m, "re.all", 1, 0, nullptr, reA); m_sigs[OP_RE_FULL_CHAR_SET] = alloc(psig, m, "re.allchar", 1, 0, nullptr, reA); m_sigs[OP_RE_OF_PRED] = alloc(psig, m, "re.of.pred", 1, 1, &predA, reA); + m_sigs[OP_RE_REVERSE] = alloc(psig, m, "re.reverse", 1, 1, &reA, reA); + m_sigs[OP_RE_DERIVATIVE] = alloc(psig, m, "re.derivative", 1, 2, AreA, reA); m_sigs[OP_SEQ_TO_RE] = alloc(psig, m, "seq.to.re", 1, 1, &seqA, reA); m_sigs[OP_SEQ_IN_RE] = alloc(psig, m, "seq.in.re", 1, 2, 
seqAreA, boolT); m_sigs[OP_SEQ_REPLACE_RE_ALL] = alloc(psig, m, "str.replace_re_all", 1, 3, seqAreAseqA, seqA); @@ -748,6 +751,8 @@ func_decl * seq_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, case OP_RE_RANGE: case OP_RE_OF_PRED: case OP_RE_COMPLEMENT: + case OP_RE_REVERSE: + case OP_RE_DERIVATIVE: m_has_re = true; // fall-through case OP_SEQ_UNIT: diff --git a/src/ast/seq_decl_plugin.h b/src/ast/seq_decl_plugin.h index 0f245a03e..466688a44 100644 --- a/src/ast/seq_decl_plugin.h +++ b/src/ast/seq_decl_plugin.h @@ -75,6 +75,8 @@ enum seq_op_kind { OP_RE_FULL_SEQ_SET, OP_RE_FULL_CHAR_SET, OP_RE_OF_PRED, + OP_RE_REVERSE, + OP_RE_DERIVATIVE, // Char -> RegEx -> RegEx // string specific operators. @@ -427,6 +429,8 @@ public: app* mk_full_seq(sort* s); app* mk_empty(sort* s); app* mk_of_pred(expr* p); + app* mk_reverse(expr* r) { return m.mk_app(m_fid, OP_RE_REVERSE, r); } + app* mk_derivative(expr* ele, expr* r) { return m.mk_app(m_fid, OP_RE_DERIVATIVE, ele, r); } bool is_to_re(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_TO_RE); } bool is_concat(expr const* n) const { return is_app_of(n, m_fid, OP_RE_CONCAT); } @@ -443,6 +447,8 @@ public: bool is_full_char(expr const* n) const { return is_app_of(n, m_fid, OP_RE_FULL_CHAR_SET); } bool is_full_seq(expr const* n) const { return is_app_of(n, m_fid, OP_RE_FULL_SEQ_SET); } bool is_of_pred(expr const* n) const { return is_app_of(n, m_fid, OP_RE_OF_PRED); } + bool is_reverse(expr const* n) const { return is_app_of(n, m_fid, OP_RE_REVERSE); } + bool is_derivative(expr const* n) const { return is_app_of(n, m_fid, OP_RE_DERIVATIVE); } MATCH_UNARY(is_to_re); MATCH_BINARY(is_concat); MATCH_BINARY(is_union); @@ -454,6 +460,8 @@ public: MATCH_UNARY(is_plus); MATCH_UNARY(is_opt); MATCH_UNARY(is_of_pred); + MATCH_UNARY(is_reverse); + MATCH_BINARY(is_derivative); bool is_loop(expr const* n, expr*& body, unsigned& lo, unsigned& hi); bool is_loop(expr const* n, expr*& body, unsigned& lo); bool is_loop(expr 
const* n, expr*& body, expr*& lo, expr*& hi); diff --git a/src/smt/seq_regex.cpp b/src/smt/seq_regex.cpp index 64c1737ab..1440d2db9 100644 --- a/src/smt/seq_regex.cpp +++ b/src/smt/seq_regex.cpp @@ -216,10 +216,9 @@ namespace smt { return false; // (accept s i R) & len(s) > i => (accept s (+ i 1) D(nth(s, i), R)) or conds - expr_ref head = th.mk_nth(s, i); - d = seq_rw().derivative(head, d); - if (!d) - throw default_exception("unable to expand derivative"); + expr_ref head = th.mk_nth(s, i); + d = re().mk_derivative(head, r); + rewrite(d); literal acc_next = th.mk_literal(sk().mk_accept(s, a().mk_int(idx + 1), d)); conds.push_back(~lit); @@ -339,9 +338,9 @@ namespace smt { return; literal null_lit = th.mk_literal(is_nullable); expr_ref hd = mk_first(r); - expr_ref d = seq_rw().derivative(hd, r); - if (!d) - throw default_exception("derivative was not defined"); + expr_ref d(m); + d = re().mk_derivative(hd, r); + rewrite(d); literal_vector lits; lits.push_back(~lit); if (null_lit != false_literal) @@ -382,9 +381,9 @@ namespace smt { } th.add_axiom(~lit, ~th.mk_literal(is_nullable)); expr_ref hd = mk_first(r); - expr_ref d = seq_rw().derivative(hd, r); - if (!d) - throw default_exception("derivative was not defined"); + expr_ref d(m); + d = re().mk_derivative(hd, r); + rewrite(d); literal_vector lits; expr_ref_pair_vector cofactors(m); seq_rw().get_cofactors(d, cofactors);