diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index e466249a3..b00ec1405 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -500,7 +500,7 @@ br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con SASSERT(f->get_family_id() == get_fid()); br_status st = BR_FAILED; switch(f->get_decl_kind()) { - + case OP_SEQ_UNIT: SASSERT(num_args == 1); st = mk_seq_unit(args[0], result); @@ -529,17 +529,17 @@ br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con break; case OP_RE_CONCAT: if (num_args == 1) { - result = args[0]; + result = args[0]; st = BR_DONE; } else { SASSERT(num_args == 2); - st = mk_re_concat(args[0], args[1], result); + st = mk_re_concat(args[0], args[1], result); } break; case OP_RE_UNION: if (num_args == 1) { - result = args[0]; + result = args[0]; st = BR_DONE; } else { @@ -551,13 +551,13 @@ br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con SASSERT(num_args == 2); st = mk_re_range(args[0], args[1], result); break; - case OP_RE_DIFF: + case OP_RE_DIFF: if (num_args == 2) st = mk_re_diff(args[0], args[1], result); else if (num_args == 1) { result = args[0]; st = BR_DONE; - } + } break; case OP_RE_INTERSECT: if (num_args == 1) { @@ -580,16 +580,16 @@ br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con st = mk_re_power(f, args[0], result); break; case OP_RE_EMPTY_SET: - return BR_FAILED; + return BR_FAILED; case OP_RE_FULL_SEQ_SET: - return BR_FAILED; + return BR_FAILED; case OP_RE_FULL_CHAR_SET: - return BR_FAILED; + return BR_FAILED; case OP_RE_OF_PRED: - return BR_FAILED; + return BR_FAILED; case _OP_SEQ_SKOLEM: - return BR_FAILED; - case OP_SEQ_CONCAT: + return BR_FAILED; + case OP_SEQ_CONCAT: if (num_args == 1) { result = args[0]; st = BR_DONE; @@ -607,25 +607,25 @@ br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con SASSERT(num_args == 3); 
st = mk_seq_extract(args[0], args[1], args[2], result); break; - case OP_SEQ_CONTAINS: + case OP_SEQ_CONTAINS: SASSERT(num_args == 2); st = mk_seq_contains(args[0], args[1], result); break; case OP_SEQ_AT: SASSERT(num_args == 2); - st = mk_seq_at(args[0], args[1], result); + st = mk_seq_at(args[0], args[1], result); break; case OP_SEQ_NTH: SASSERT(num_args == 2); - return mk_seq_nth(args[0], args[1], result); + return mk_seq_nth(args[0], args[1], result); case OP_SEQ_NTH_I: SASSERT(num_args == 2); - return mk_seq_nth_i(args[0], args[1], result); - case OP_SEQ_PREFIX: + return mk_seq_nth_i(args[0], args[1], result); + case OP_SEQ_PREFIX: SASSERT(num_args == 2); st = mk_seq_prefix(args[0], args[1], result); break; - case OP_SEQ_SUFFIX: + case OP_SEQ_SUFFIX: SASSERT(num_args == 2); st = mk_seq_suffix(args[0], args[1], result); break; @@ -690,17 +690,25 @@ br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con SASSERT(num_args == 1); st = mk_str_stoi(args[0], result); break; + case OP_ITE: + // Rewrite ITEs in the case of regexes + SASSERT(num_args == 3); + if (m_util.is_re(args[1])) { + SASSERT(m_util.is_re(args[2])); + st = mk_re_ite(args[0], args[1], args[2], result); + } + break; case _OP_STRING_CONCAT: case _OP_STRING_PREFIX: case _OP_STRING_SUFFIX: case _OP_STRING_STRCTN: case _OP_STRING_LENGTH: case _OP_STRING_CHARAT: - case _OP_STRING_IN_REGEXP: - case _OP_STRING_TO_REGEXP: - case _OP_STRING_SUBSTR: + case _OP_STRING_IN_REGEXP: + case _OP_STRING_TO_REGEXP: + case _OP_STRING_SUBSTR: case _OP_STRING_STRREPL: - case _OP_STRING_STRIDOF: + case _OP_STRING_STRIDOF: UNREACHABLE(); } // if (st == BR_FAILED) { @@ -2218,14 +2226,70 @@ expr_ref seq_rewriter::is_nullable(expr* r) { return result; } +/* + Symbolic derivative: regex -> seq -> regex + Recursive version. + seq should be single char. 
+    Uses BDD representation (and aux functions bdd_union, bdd_concat, etc)
+    to enable efficiently handling a set of constraints on the input character (TODO confirm: original sentence stopped at "on the").
+*/
+// expr_ref seq_rewriter::derivative()
+
+/*
+    Optimizations for ITEs of regexes, since they come up frequently
+    in calculating derivatives.  cond is the ite condition, r1/r2 its branches;
+    returns BR_FAILED and leaves result untouched when no rule below applies.
+    ite(not c, r1, r2) -> ite(c, r2, r1)
+    ite(c, ite(c, r1, r2), r3) -> ite(c, r1, r3)
+    ite(c, r1, ite(c, r2, r3)) -> ite(c, r1, r3)
+    ite(c1, ite(c2, r1, r2), r3) where id of c1 < id of c2 ->
+        ite(c2, ite(c1, r1, r3), ite(c1, r2, r3))
+    ite(c1, r1, ite(c2, r2, r3)) where id of c1 < id of c2 ->
+        ite(c2, ite(c1, r1, r2), ite(c1, r1, r3))
+*/
+br_status seq_rewriter::mk_re_ite(expr* cond, expr* r1, expr* r2, expr_ref& result) {
+    VERIFY(m_util.is_re(r1));
+    VERIFY(m_util.is_re(r2));
+    expr *c = nullptr, *ra = nullptr, *rb = nullptr;
+    if (m().is_not(cond, c)) {  // strip the negation by swapping the branches
+        result = m().mk_ite(c, r2, r1);
+        return BR_REWRITE1;
+    }
+    if (m().is_ite(r1, c, ra, rb)) {
+        if (m().are_equal(c, cond)) {  // same test nested in the then-branch: keep its then-part
+            result = m().mk_ite(cond, ra, r2);
+            return BR_REWRITE1;
+        }
+        if (cond->get_id() < c->get_id()) {  // hoist c, the condition with the larger id
+            expr *result1 = m().mk_ite(cond, ra, r2);
+            expr *result2 = m().mk_ite(cond, rb, r2);
+            result = m().mk_ite(c, result1, result2);
+            return BR_REWRITE2;
+        }
+    }
+    if (m().is_ite(r2, c, ra, rb)) {
+        if (m().are_equal(c, cond)) {  // same test nested in the else-branch: keep its else-part
+            result = m().mk_ite(cond, r1, rb);
+            return BR_REWRITE1;
+        }
+        if (cond->get_id() < c->get_id()) {  // hoist c, the condition with the larger id
+            expr *result1 = m().mk_ite(cond, r1, ra);
+            expr *result2 = m().mk_ite(cond, r1, rb);
+            result = m().mk_ite(c, result1, result2);  // was missing: result stayed unset
+            return BR_REWRITE2;
+        }
+    }
+    return BR_FAILED;  // nothing fired; BR_DONE would claim a result that was never set
+}
+
 /*
     Push reverse inwards (whenever possible).
*/ br_status seq_rewriter::mk_re_reverse(expr* r, expr_ref& result) { sort* seq_sort = nullptr; VERIFY(m_util.is_re(r, seq_sort)); - expr* r1 = nullptr, *r2 = nullptr, *p = nullptr, *s = nullptr; - expr* s1 = nullptr, *s2 = nullptr; + expr *r1 = nullptr, *r2 = nullptr, *p = nullptr, *s = nullptr; + expr *s1 = nullptr, *s2 = nullptr; zstring zs; unsigned lo = 0, hi = 0; if (re().is_concat(r, r1, r2)) { @@ -2308,11 +2372,11 @@ br_status seq_rewriter::mk_re_reverse(expr* r, expr_ref& result) { seq should be single char */ br_status seq_rewriter::mk_re_derivative(expr* ele, expr* r, expr_ref& result) { - sort* seq_sort = nullptr, *ele_sort = nullptr; + sort *seq_sort = nullptr, *ele_sort = nullptr; VERIFY(m_util.is_re(r, seq_sort)); VERIFY(m_util.is_seq(seq_sort, ele_sort)); SASSERT(ele_sort == m().get_sort(ele)); - expr* r1 = nullptr, *r2 = nullptr, *p = nullptr; + expr *r1 = nullptr, *r2 = nullptr, *p = nullptr; unsigned lo = 0, hi = 0; if (re().is_concat(r, r1, r2)) { std::cout << "is_nullable -- from concat" << std::endl; @@ -2949,9 +3013,12 @@ br_status seq_rewriter::mk_re_concat(expr* a, expr* b, expr_ref& result) { return BR_FAILED; } /* - (a + a) = a - (a + eps) = a - (eps + a) = a + (a + a) = a + (a + eps) = a + (eps + a) = a + + if-then-else lifting: + (ite p r1 r2) + b -> ite p (r1 + b) (r2 + b) */ br_status seq_rewriter::mk_re_union(expr* a, expr* b, expr_ref& result) { if (a == b) { @@ -2982,9 +3049,30 @@ br_status seq_rewriter::mk_re_union(expr* a, expr* b, expr_ref& result) { result = b; return BR_DONE; } + expr *a1 = nullptr, *a2 = nullptr, + *b1 = nullptr, *b2 = nullptr, *cond = nullptr; + if (m().is_ite(a, cond, a1, a2)) { + result = m().mk_ite(cond, re().mk_union(a1, b), + re().mk_union(a2, b)); + return BR_REWRITE2; + } + if (m().is_ite(b, cond, b1, b2)) { + result = m().mk_ite(cond, re().mk_union(a, b1), + re().mk_union(a, b2)); + return BR_REWRITE2; + } return BR_FAILED; } +/* + comp(intersect e1 e2) -> union comp(e1) comp(e2) + comp(union e1 e2) 
-> intersect comp(e1) comp(e2) + comp(none) = all + comp(all) = none + + if-then-else lifting: + comp(ite p e1 e2) -> ite p comp(e1) comp(e2) +*/ br_status seq_rewriter::mk_re_complement(expr* a, expr_ref& result) { expr* e1, *e2; if (re().is_intersection(a, e1, e2)) { @@ -3003,15 +3091,28 @@ br_status seq_rewriter::mk_re_complement(expr* a, expr_ref& result) { result = re().mk_empty(m().get_sort(a)); return BR_DONE; } + expr *a1 = nullptr, *a2 = nullptr, *cond = nullptr; + if (m().is_ite(a, cond, a1, a2)) { + result = m().mk_ite(cond, re().mk_complement(a1), + re().mk_complement(a2)); + return BR_REWRITE2; + } return BR_FAILED; } /** - (emp n r) = emp - (r n emp) = emp - (all n r) = r - (r n all) = r - (r n r) = r + (r n r) = r + (emp n r) = emp + (r n emp) = emp + (all n r) = r + (r n all) = r + (r n comp(r)) = emp + (comp(r) n r) = emp + (r n to_re(s)) = ite (s in r) to_re(s) emp + (to_re(s) n r) = " + + if-then-else lifting: + (ite p r1 r2) n b -> ite p (r1 n b) (r2 n b) */ br_status seq_rewriter::mk_re_inter(expr* a, expr* b, expr_ref& result) { if (a == b) { @@ -3040,6 +3141,18 @@ br_status seq_rewriter::mk_re_inter(expr* a, expr* b, expr_ref& result) { result = re().mk_empty(m().get_sort(a)); return BR_DONE; } + expr *a1 = nullptr, *a2 = nullptr, + *b1 = nullptr, *b2 = nullptr, *cond = nullptr; + if (m().is_ite(a, cond, a1, a2)) { + result = m().mk_ite(cond, re().mk_inter(a1, b), + re().mk_inter(a2, b)); + return BR_REWRITE2; + } + if (m().is_ite(b, cond, b1, b2)) { + result = m().mk_ite(cond, re().mk_inter(a, b1), + re().mk_inter(a, b2)); + return BR_REWRITE2; + } if (re().is_to_re(b)) std::swap(a, b); expr* s = nullptr; @@ -3055,13 +3168,12 @@ br_status seq_rewriter::mk_re_diff(expr* a, expr* b, expr_ref& result) { return BR_REWRITE2; } - br_status seq_rewriter::mk_re_loop(func_decl* f, unsigned num_args, expr* const* args, expr_ref& result) { rational n1, n2; unsigned lo, hi, lo2, hi2, np; expr* a = nullptr; switch (num_args) { - case 1: + case 1: np = 
f->get_num_parameters();
         lo2 = np > 0 ? f->get_parameter(0).get_int() : 0;
         hi2 = np > 1 ? f->get_parameter(1).get_int() : lo2;
@@ -3089,10 +3201,31 @@ br_status seq_rewriter::mk_re_loop(func_decl* f, unsigned num_args, expr* const*
             result = args[0];
             return BR_DONE;
         }
-        // (loop a 0) = a*
-        if (np == 1 && lo2 == 0) {
+        // (loop a) = (loop a 0) = a*
+        if ((np == 0) ||
+            (np == 1 && lo2 == 0)) {
             result = re().mk_star(args[0]);
-            return BR_DONE;
+            return BR_REWRITE1;
+        }
+        // if-then-else lifting: loop (ite p r1 r2) -> ite p (loop r1) (loop r2)
+        // np == 0 cannot reach this point (it is rewritten to a* above), so only
+        // the one- and two-parameter forms are lifted.  The enclosing block keeps
+        // the initialized locals out of the scope that 'case 2' jumps into.
+        {
+            expr *cond = nullptr, *a1 = nullptr, *a2 = nullptr;
+            if ((np == 1 || np == 2) && m().is_ite(args[0], cond, a1, a2)) {
+                expr_ref result1(m()), result2(m());
+                if (np == 1) {
+                    result1 = re().mk_loop(a1, lo2);
+                    result2 = re().mk_loop(a2, lo2);
+                }
+                else {
+                    result1 = re().mk_loop(a1, lo2, hi2);
+                    result2 = re().mk_loop(a2, lo2, hi2);
+                }
+                result = m().mk_ite(cond, result1, result2);
+                return BR_REWRITE2;
+            }
         }
         break;
     case 2:
@@ -3122,13 +3255,16 @@ br_status seq_rewriter::mk_re_power(func_decl* f, expr* a, expr_ref& result) {
 
 
 /*
-  a** = a*
-  (a* + b)* = (a + b)*
-  (a + b*)* = (a + b)*
-  (a*b*)* = (a + b)*
-  a+* = a*
-  emp* = ""
-  all* = all
+    a** = a*
+    (a* + b)* = (a + b)*
+    (a + b*)* = (a + b)*
+    (a*b*)* = (a + b)*
+    a+* = a*
+    emp* = ""
+    all* = all
+
+    if-then-else lifting:
+    (ite p r1 r2)* -> ite p (r1)* (r2)*
 */
 br_status seq_rewriter::mk_re_star(expr* a, expr_ref& result) {
     expr* b, *c, *b1, *c1;
@@ -3173,6 +3309,12 @@ br_status seq_rewriter::mk_re_star(expr* a, expr_ref& result) {
         result = re().mk_star(re().mk_union(b1, c1));
         return BR_REWRITE2;
     }
+    expr *a1 = nullptr, *a2 = nullptr, *cond = nullptr;
+    if (m().is_ite(a, cond, a1, a2)) {
+        result = m().mk_ite(cond, re().mk_star(a1),
+                                  re().mk_star(a2));
+        return BR_REWRITE2;
+    }
     return BR_FAILED;
 }
 
diff --git a/src/ast/rewriter/seq_rewriter.h b/src/ast/rewriter/seq_rewriter.h
index
873f88581..ce4b5bedb 100644
--- a/src/ast/rewriter/seq_rewriter.h
+++ b/src/ast/rewriter/seq_rewriter.h
@@ -219,6 +219,8 @@ class seq_rewriter {
     br_status mk_re_range(expr* lo, expr* hi, expr_ref& result);
     br_status mk_re_reverse(expr* r, expr_ref& result);
     br_status mk_re_derivative(expr* ele, expr* r, expr_ref& result);
+    // Simplifies ite(cond, r1, r2) whose branches r1 and r2 are regexes.
+    br_status mk_re_ite(expr* cond, expr* r1, expr* r2, expr_ref& result);
     br_status lift_ite(func_decl* f, unsigned n, expr* const* args, expr_ref& result);
     br_status reduce_re_eq(expr* a, expr* b, expr_ref& result);
     br_status reduce_re_is_empty(expr* r, expr_ref& result);
@@ -310,8 +312,16 @@ public:
 
     void add_seqs(expr_ref_vector const& ls, expr_ref_vector const& rs, expr_ref_pair_vector& new_eqs);
 
+    // Support for regular expression derivatives
     expr_ref is_nullable(expr* r);
     expr_ref is_nullable_rec(expr* r);
+    // expr_ref derivative(expr* r);
+    // expr_ref derivative_rec(expr* r);
+    // expr_ref bdd_union(expr* r);
+    // expr_ref bdd_inter(expr* r);
+    // expr_ref bdd_comp(expr* r);
+    // expr_ref bdd_concat(expr* r);
+    // expr_ref bdd_star(expr* r);
 
     bool has_cofactor(expr* r, expr_ref& cond, expr_ref& th, expr_ref& el);
 