From 766e979641fdc07755929d957bcbcb106eda09c7 Mon Sep 17 00:00:00 2001 From: calebstanford-msr Date: Mon, 8 Jun 2020 13:17:33 -0400 Subject: [PATCH] fix derivative interaction with reverse; add flags for left/right derivative and lifting over union/intersection --- src/ast/rewriter/seq_rewriter.cpp | 181 ++++++++++++++++++++++-------- src/ast/rewriter/seq_rewriter.h | 10 +- src/ast/seq_decl_plugin.h | 1 + 3 files changed, 145 insertions(+), 47 deletions(-) diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index e1237fcda..9036a409c 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -2319,21 +2319,35 @@ br_status seq_rewriter::mk_re_derivative(expr* ele, expr* r, expr_ref& result) { } /* - Recursive implementation of the symbolic derivative such that + Memoized, recursive implementation of the symbolic derivative such that the result is in an optimized BDD form. + flags: + - lift_over_union, lift_over_inter (default true) + If false, then preserve unions, intersections (respectively) + at the top level. + Note that memoization ignores these flags, so if called + on the same expression with different flags, will get the same + result. + - left (default true) + Take a left-derivative. If false take a right-derivative. + Definition of BDD form: if-then-elses are pushed outwards and sorted by condition ID (cond->get_id()), from largest on the outside to smallest on the inside. Duplicate nested conditions are eliminated. - */ -expr_ref seq_rewriter::mk_derivative(expr* ele, expr* r) { - expr_ref result(m_op_cache.find(OP_RE_DERIVATIVE, ele, r, nullptr), m()); +expr_ref seq_rewriter::mk_derivative(expr* ele, expr* r, + bool left, + bool lift_over_union, + bool lift_over_inter) { + decl_kind k = left ? OP_RE_DERIVATIVE : _OP_RE_RIGHT_DERIVATIVE; + expr_ref result(m_op_cache.find(k, ele, r, nullptr), m()); if (!result) { - result = mk_derivative_rec(ele, r); - m_op_cache.insert(OP_RE_DERIVATIVE, ele, r, nullptr, result); + result = mk_derivative_rec(ele, r, + lift_over_union, lift_over_inter, left); + m_op_cache.insert(k, ele, r, nullptr, result); } return result; } @@ -2351,10 +2365,10 @@ expr_ref seq_rewriter::mk_der_concat(expr* r1, expr* r2) { } /* - Form a derivative by combining two if-then-else expressions in BDD form. + Apply a binary operation, preserving BDD normal form on derivative expressions. Preconditions: - - k is a binary op code on REs (re.union, re.inter, etc.) + - k is a binary op code on REs (concat, intersection, or union) - a and b are in BDD form Postcondition: @@ -2426,18 +2440,37 @@ expr_ref seq_rewriter::mk_der_op(decl_kind k, expr* a, expr* b) { expr_ref seq_rewriter::mk_der_compl(expr* r) { expr_ref result(m_op_cache.find(OP_RE_COMPLEMENT, r, nullptr, nullptr), m()); if (!result) { - expr* c = nullptr, * r1 = nullptr, * r2 = nullptr; + expr *c = nullptr, *r1 = nullptr, *r2 = nullptr; if (m().is_ite(r, c, r1, r2)) { result = m().mk_ite(c, mk_der_compl(r1), mk_der_compl(r2)); } - else if (BR_FAILED == mk_re_complement(r, result)) + else if (BR_FAILED == mk_re_complement(r, result)) { result = re().mk_complement(r); + } + m_op_cache.insert(OP_RE_COMPLEMENT, r, nullptr, nullptr, result); } - m_op_cache.insert(OP_RE_COMPLEMENT, r, nullptr, nullptr, result); return result; } -expr_ref seq_rewriter::mk_derivative_rec(expr* ele, expr* r) { +expr_ref seq_rewriter::mk_der_reverse(expr* r) { + expr_ref result(m_op_cache.find(OP_RE_REVERSE, r, nullptr, nullptr), m()); + if (!result) { + expr *c = nullptr, *r1 = nullptr, *r2 = nullptr; + if (m().is_ite(r, c, r1, r2)) { + result = m().mk_ite(c, mk_der_reverse(r1), mk_der_reverse(r2)); + } + else if (BR_FAILED == mk_re_reverse(r, result)) { + result = re().mk_reverse(r); + } + m_op_cache.insert(OP_RE_REVERSE, r, nullptr, nullptr, result); + } + return result; +} + +expr_ref seq_rewriter::mk_derivative_rec(expr* ele, expr* r, + bool left, + bool lift_over_union, + bool lift_over_inter) { expr_ref result(m()); sort* seq_sort = nullptr, *ele_sort = nullptr; VERIFY(m_util.is_re(r, seq_sort)); @@ -2446,7 +2479,7 @@ expr_ref seq_rewriter::mk_derivative_rec(expr* ele, expr* r) { expr* r1 = nullptr, *r2 = nullptr, *p = nullptr; auto mk_empty = [&]() { return expr_ref(re().mk_empty(m().get_sort(r)), m()); }; unsigned lo = 0, hi = 0; - if (re().is_concat(r, r1, r2)) { + if (re().is_concat(r, r1, r2) && left) { expr_ref is_n = is_nullable(r1); expr_ref dr1 = mk_derivative(ele, r1); result = mk_der_concat(dr1, r2); @@ -2455,40 +2488,91 @@ expr_ref seq_rewriter::mk_derivative_rec(expr* ele, expr* r) { } expr_ref dr2 = mk_derivative(ele, r2); is_n = re_predicate(is_n, seq_sort); - return mk_der_union(result, mk_der_concat(is_n, dr2)); + if (lift_over_union) { + return mk_der_union(result, mk_der_concat(is_n, dr2)); + } + else { + return expr_ref(re().mk_union(result, mk_der_concat(is_n, dr2)), m()); + } + } + else if (re().is_concat(r, r1, r2) && !left) { + expr_ref is_n = is_nullable(r2); + expr_ref dr2 = mk_derivative(ele, r2, left); + result = mk_der_concat(r1, dr2); + if (m().is_false(is_n)) { + return result; + } + expr_ref dr1 = mk_derivative(ele, r1, left); + is_n = re_predicate(is_n, seq_sort); + if (lift_over_union) { + return mk_der_union(result, mk_der_concat(dr1, is_n)); + } + else { + return expr_ref(re().mk_union(result, mk_der_concat(is_n, dr2)), m()); + } } else if (re().is_star(r, r1)) { - return mk_der_concat(mk_derivative(ele, r1), r); + if (left) { + return mk_der_concat(mk_derivative(ele, r1, left), r); + } + else { + return mk_der_concat(r, mk_derivative(ele, r1, left)); + } } else if (re().is_plus(r, r1)) { expr_ref star(re().mk_star(r1), m()); - return mk_derivative(ele, star); + return mk_derivative(ele, star, left); } else if (re().is_union(r, r1, r2)) { - return mk_der_union(mk_derivative(ele, r1), mk_derivative(ele, r2)); + if (!lift_over_union) { + return expr_ref(re().mk_union( + mk_derivative(ele, r1, left, lift_over_union, lift_over_inter), + mk_derivative(ele, r2, left, lift_over_union, lift_over_inter) + ), m()); + } else { + return mk_der_union(mk_derivative(ele, r1, left), + mk_derivative(ele, r2, left)); + } } else if (re().is_intersection(r, r1, r2)) { - return mk_der_inter(mk_derivative(ele, r1), mk_derivative(ele, r2)); + if (!lift_over_inter) { + return expr_ref(re().mk_inter( + mk_derivative(ele, r1, left, lift_over_union, lift_over_inter), + mk_derivative(ele, r2, left, lift_over_union, lift_over_inter) + ), m()); + } else { + return mk_der_inter(mk_derivative(ele, r1, left), + mk_derivative(ele, r2, left)); + } } else if (re().is_diff(r, r1, r2)) { - return mk_der_inter(mk_derivative(ele, r1), mk_der_compl(mk_derivative(ele, r2))); + return mk_derivative(ele, re().mk_inter(r1, re().mk_complement(r2)), + left, lift_over_union, lift_over_inter); } else if (m().is_ite(r, p, r1, r2)) { // there is no BDD normalization here - result = m().mk_ite(p, mk_derivative(ele, r1), mk_derivative(ele, r2)); + result = m().mk_ite(p, mk_derivative(ele, r1, left), + mk_derivative(ele, r2, left)); return result; } else if (re().is_opt(r, r1)) { - return mk_derivative(ele, r1); + return mk_derivative(ele, r1, left, lift_over_union, lift_over_inter); } else if (re().is_complement(r, r1)) { - return mk_der_compl(mk_derivative(ele, r1)); + // If lift_over_union and lift_over_inter are false, this stops + // lifting. It would be possible to do smarter lifting here + return mk_der_compl(mk_derivative(ele, r1, left)); } else if (re().is_loop(r, r1, lo)) { if (lo > 0) { lo--; } - return mk_der_concat(mk_derivative(ele, r1), re().mk_loop(r1, lo)); + if (left) { + return mk_der_concat(mk_derivative(ele, r1), re().mk_loop(r1, lo)); + } else { + return mk_der_concat(re().mk_loop(r1, lo), + mk_derivative(ele, r1, left)); + } } else if (re().is_loop(r, r1, lo, hi)) { if (hi == 0) { @@ -2498,7 +2582,12 @@ expr_ref seq_rewriter::mk_derivative_rec(expr* ele, expr* r) { if (lo > 0) { lo--; } - return mk_der_concat(mk_derivative(ele, r1), re().mk_loop(r1, lo, hi)); + if (left) { + return mk_der_concat(mk_derivative(ele, r1), re().mk_loop(r1, lo, hi)); + } else { + return mk_der_concat(re().mk_loop(r1, lo, hi), + mk_derivative(ele, r1, left)); + } } else if (re().is_full_seq(r) || re().is_empty(r)) { @@ -2507,31 +2596,25 @@ expr_ref seq_rewriter::mk_derivative_rec(expr* ele, expr* r) { else if (re().is_to_re(r, r1)) { // r1 is a string here (not a regexp) expr_ref hd(m()), tl(m()); - if (get_head_tail(r1, hd, tl)) { + if (left && get_head_tail(r1, hd, tl)) { // head must be equal; if so, derivative is tail return re_and(m().mk_eq(ele, hd), re().mk_to_re(tl)); } + else if (!left && get_head_tail_reversed(r1, hd, tl)) { + return re_and(m().mk_eq(ele, tl), re().mk_to_re(hd)); + } else if (str().is_empty(r1)) { return mk_empty(); } - else { - return expr_ref(re().mk_derivative(ele, r), m()); - } + // (Otherwise, falls back to default case) } - else if (re().is_reverse(r, r1) && re().is_to_re(r1, r2)) { - // Reverses are rewritten so that the only derivative case is - // derivative of a reverse of a string. (All other cases stuck) - // This is analagous to the previous is_to_re case. - expr_ref hd(m()), tl(m()); - if (get_head_tail_reversed(r2, hd, tl)) { - return re_and(m().mk_eq(ele, tl), re().mk_reverse(re().mk_to_re(hd))); - } - else if (str().is_empty(r2)) { - return mk_empty(); - } - else { - return expr_ref(re().mk_derivative(ele, r), m()); - } + else if (re().is_reverse(r, r1)) { + // Push derivative inside and flip direction. + // If lift_over_union and lift_over_inter are false, this stops + // lifting. It may be possible to do smarter lifting here. + return mk_der_reverse(mk_derivative(ele, r1, !left, + lift_over_union, + lift_over_inter)); } else if (re().is_range(r, r1, r2)) { // r1, r2 are sequences. @@ -2562,9 +2645,17 @@ expr_ref seq_rewriter::mk_derivative_rec(expr* ele, expr* r) { result = array.mk_select(2, args); return re_predicate(result, seq_sort); } - // stuck cases: re().is_derivative, variable, ... - // and re().is_reverse if the reverse is not applied to a string - return expr_ref(re().mk_derivative(ele, r), m()); + // stuck cases: re().is_derivative, variable, and + // to_re if the string can't be rewritten as empty or head/tail + if (left) { + return expr_ref(re().mk_derivative(ele, r), m()); + } + else { + return expr_ref( + re().mk_reverse(re().mk_derivative(ele, re().mk_reverse(r))), m() + ); + } + } /* diff --git a/src/ast/rewriter/seq_rewriter.h b/src/ast/rewriter/seq_rewriter.h index 208261474..e0316dd62 100644 --- a/src/ast/rewriter/seq_rewriter.h +++ b/src/ast/rewriter/seq_rewriter.h @@ -182,14 +182,20 @@ class seq_rewriter { expr_ref mk_seq_concat(expr* a, expr* b); + // Calculate derivative, memoized and enforcing a normal form + expr_ref mk_derivative(expr* ele, expr* r, bool left = true, + bool lift_over_union = true, + bool lift_over_inter = true); + expr_ref mk_derivative_rec(expr* ele, expr* r, bool left, + bool lift_over_union, + bool lift_over_inter); expr_ref mk_der_op(decl_kind k, expr* a, expr* b); expr_ref mk_der_op_rec(decl_kind k, expr* a, expr* b); expr_ref mk_der_concat(expr* a, expr* b); expr_ref mk_der_union(expr* a, expr* b); expr_ref mk_der_inter(expr* a, expr* b); expr_ref mk_der_compl(expr* a); - expr_ref mk_derivative(expr* ele, expr* r); - expr_ref mk_derivative_rec(expr* ele, expr* r); + expr_ref mk_der_reverse(expr* a); bool are_complements(expr* r1, expr* r2) const; bool is_subset(expr* r1, expr* r2) const; diff --git a/src/ast/seq_decl_plugin.h b/src/ast/seq_decl_plugin.h index 17438a054..ae83f29b1 100644 --- a/src/ast/seq_decl_plugin.h +++ b/src/ast/seq_decl_plugin.h @@ -109,6 +109,7 @@ enum seq_op_kind { _OP_REGEXP_EMPTY, _OP_REGEXP_FULL_CHAR, _OP_RE_IS_NULLABLE, + _OP_RE_RIGHT_DERIVATIVE, _OP_SEQ_SKOLEM, LAST_SEQ_OP };