From 1fd567d1e9a8bfeea2f7a72de7011b6e8d06edfa Mon Sep 17 00:00:00 2001 From: calebstanford-msr <65260146+calebstanford-msr@users.noreply.github.com> Date: Tue, 9 Jun 2020 14:36:31 -0400 Subject: [PATCH] fix bug in seq rewriter op_cache::find (#4509) * remove unneeded reverse case in derivative; placeholder for generalized lifted derivative * experimental tweaks to RE rewriter to improve performance * if-then-else lifting (broken code -- preserving this commit in case this idea is useful later) * if-then-else derivative optimizations: new approach templates * implement if-then-else BDD normal form for derivatives (code compiles but is still buggy) * remove std::cout debugging for PR * Revert "remove std::cout debugging for PR" This reverts commit c7bdc44d319ea02735c7d8f1076c01acb29ddc91. * debugging * fix derivative interaction with reverse; add flags for left/right derivative and lifting over union/intersection * remove debugging statements for PR * Revert "remove debugging statements for PR" This reverts commit 38e85a72881d46153bd8561e454599bdf851689f. * revert some purely cosmetic changes from upstream; fix a bug * revert unnecessary changes * remove some redundant rewrites and add a new one for str.in_re(s, comp(r)) * add disabled rewrite for complement * fix bug in op cache find (result was not saved) * remove debugging std::cout for PR --- src/ast/rewriter/seq_rewriter.cpp | 87 +++++++++++++++++++++++++------ src/ast/rewriter/seq_rewriter.h | 16 +++--- src/smt/seq_regex.cpp | 22 +++++--- src/smt/seq_regex.h | 2 + 4 files changed, 99 insertions(+), 28 deletions(-) diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index f7518f919..e3fbe8ca0 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -727,7 +727,7 @@ br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con UNREACHABLE(); } if (st == BR_FAILED) { - st = lift_ite(f, num_args, args, result); + st = lift_ites_throttled(f, num_args, args, result); } CTRACE("seq_verbose", st != BR_FAILED, tout << expr_ref(m().mk_app(f, num_args, args), m()) << " -> " << result << "\n";); SASSERT(st == BR_FAILED || m().get_sort(result) == f->get_range()); @@ -830,7 +830,14 @@ br_status seq_rewriter::mk_seq_length(expr* a, expr_ref& result) { return BR_FAILED; } -br_status seq_rewriter::lift_ite(func_decl* f, unsigned n, expr* const* args, expr_ref& result) { +/* + Lift all ite expressions to the top level, safely + throttled to not blowup the size of the expression. + + Note: this function does not ensure the same BDD form that is + used in the normal form for derivatives in mk_re_derivative. +*/ +br_status seq_rewriter::lift_ites_throttled(func_decl* f, unsigned n, expr* const* args, expr_ref& result) { expr* c = nullptr, *t = nullptr, *e = nullptr; for (unsigned i = 0; i < n; ++i) { if (m().is_ite(args[i], c, t, e) && @@ -2239,13 +2246,13 @@ expr_ref seq_rewriter::is_nullable(expr* r) { } /* - Push reverse inwards (gets stuck at variables and strings). + Push reverse inwards (whenever possible). */ br_status seq_rewriter::mk_re_reverse(expr* r, expr_ref& result) { sort* seq_sort = nullptr; VERIFY(m_util.is_re(r, seq_sort)); - expr* r1 = nullptr, *r2 = nullptr, *p = nullptr, *s = nullptr; - expr* s1 = nullptr, *s2 = nullptr; + expr *r1 = nullptr, *r2 = nullptr, *p = nullptr, *s = nullptr; + expr *s1 = nullptr, *s2 = nullptr; zstring zs; unsigned lo = 0, hi = 0; if (re().is_concat(r, r1, r2)) { @@ -2318,7 +2325,7 @@ br_status seq_rewriter::mk_re_reverse(expr* r, expr_ref& result) { return BR_REWRITE3; } else { - // stuck cases: variable, re().is_to_re, re().is_derivative, ... + // stuck cases: variable, re().is_derivative, ... return BR_FAILED; } } @@ -2334,6 +2341,16 @@ br_status seq_rewriter::mk_re_derivative(expr* ele, expr* r, expr_ref& result) { return re().is_derivative(result) ? BR_DONE : BR_REWRITE_FULL; } +/* + Memoized, recursive implementation of the symbolic derivative such that + the result is in an optimized BDD form. + + Definition of BDD form: + if-then-elses are pushed outwards + and sorted by condition ID (cond->get_id()), from largest on + the outside to smallest on the inside. + Duplicate nested conditions are eliminated. +*/ expr_ref seq_rewriter::mk_derivative(expr* ele, expr* r) { expr_ref result(m_op_cache.find(OP_RE_DERIVATIVE, ele, r), m()); if (!result) { @@ -2355,6 +2372,16 @@ expr_ref seq_rewriter::mk_der_concat(expr* r1, expr* r2) { return mk_der_op(OP_RE_CONCAT, r1, r2); } +/* + Apply a binary operation, preserving BDD normal form on derivative expressions. + + Preconditions: + - k is a binary op code on REs (concat, intersection, or union) + - a and b are in BDD form + + Postcondition: + - result is in BDD form +*/ expr_ref seq_rewriter::mk_der_op_rec(decl_kind k, expr* a, expr* b) { expr* ca = nullptr, *a1 = nullptr, *a2 = nullptr; expr* cb = nullptr, *b1 = nullptr, *b2 = nullptr; @@ -2427,8 +2454,8 @@ expr_ref seq_rewriter::mk_der_compl(expr* r) { } else if (BR_FAILED == mk_re_complement(r, result)) result = re().mk_complement(r); + m_op_cache.insert(OP_RE_COMPLEMENT, r, nullptr, result); } - m_op_cache.insert(OP_RE_COMPLEMENT, r, nullptr, result); return result; } @@ -2778,6 +2805,12 @@ br_status seq_rewriter::mk_str_in_regexp(expr* a, expr* b, expr_ref& result) { re().mk_in_re(str().mk_substr(a, len_hd, len_tl), tl)); return BR_REWRITE_FULL; } + + // Disabled rewrites + if (false && re().is_complement(b, b1)) { + result = m().mk_not(re().mk_in_re(a, b1)); + return BR_REWRITE2; + } if (false && rewrite_contains_pattern(a, b, result)) return BR_REWRITE_FULL; @@ -2788,6 +2821,19 @@ br_status seq_rewriter::mk_str_to_regexp(expr* a, expr_ref& result) { return BR_FAILED; } +/* + easy cases: + .* ++ .* -> .* + [] ++ r -> [] + r ++ [] -> [] + r ++ "" -> r + "" ++ r -> r + + to_re and star: + (str.to_re s1) ++ (str.to_re s2) -> (str.to_re (s1 ++ s2)) + r* ++ r* -> r* + r* ++ r -> r ++ r* +*/ br_status seq_rewriter::mk_re_concat(expr* a, expr* b, expr_ref& result) { if (re().is_full_seq(a) && re().is_full_seq(b)) { result = a; @@ -2985,8 +3031,15 @@ br_status seq_rewriter::mk_re_union(expr* a, expr* b, expr_ref& result) { return BR_FAILED; } +/* + comp(intersect e1 e2) -> union comp(e1) comp(e2) + comp(union e1 e2) -> intersect comp(e1) comp(e2) + comp(none) -> all + comp(all) -> none + comp(comp(e1)) -> e1 +*/ br_status seq_rewriter::mk_re_complement(expr* a, expr_ref& result) { - expr* e1, *e2; + expr *e1 = nullptr, *e2 = nullptr; if (re().is_intersection(a, e1, e2)) { result = re().mk_union(re().mk_complement(e1), re().mk_complement(e2)); return BR_REWRITE2; @@ -3011,11 +3064,15 @@ br_status seq_rewriter::mk_re_complement(expr* a, expr_ref& result) { } /** - (emp n r) = emp - (r n emp) = emp - (all n r) = r - (r n all) = r - (r n r) = r + (r n r) = r + (emp n r) = emp + (r n emp) = emp + (all n r) = r + (r n all) = r + (r n comp(r)) = emp + (comp(r) n r) = emp + (r n to_re(s)) = ite (s in r) to_re(s) emp + (to_re(s) n r) = " */ br_status seq_rewriter::mk_re_inter(expr* a, expr* b, expr_ref& result) { if (a == b) { @@ -3384,7 +3441,7 @@ void seq_rewriter::elim_condition(expr* elem, expr_ref& cond) { } } - + br_status seq_rewriter::reduce_re_is_empty(expr* r, expr_ref& result) { expr* r1, *r2, *r3, *r4; zstring s1, s2; @@ -3965,7 +4022,7 @@ seq_rewriter::op_cache::op_cache(ast_manager& m): expr* seq_rewriter::op_cache::find(decl_kind op, expr* a, expr* b) { op_entry e(op, a, b, nullptr); - m_table.find(e); + m_table.find(e, e); return e.r; } diff --git a/src/ast/rewriter/seq_rewriter.h b/src/ast/rewriter/seq_rewriter.h index 541a08f9f..4499ef56e 100644 --- a/src/ast/rewriter/seq_rewriter.h +++ b/src/ast/rewriter/seq_rewriter.h @@ -171,23 +171,26 @@ class seq_rewriter { length_comparison compare_lengths(unsigned sza, expr* const* as, unsigned szb, expr* const* bs); - // Support for regular expression derivatives bool get_head_tail(expr* e, expr_ref& head, expr_ref& tail); bool get_head_tail_reversed(expr* e, expr_ref& head, expr_ref& tail); bool get_re_head_tail(expr* e, expr_ref& head, expr_ref& tail); bool get_re_head_tail_reversed(expr* e, expr_ref& head, expr_ref& tail); + expr_ref re_and(expr* cond, expr* r); expr_ref re_predicate(expr* cond, sort* seq_sort); + expr_ref mk_seq_concat(expr* a, expr* b); + // Calculate derivative, memoized and enforcing a normal form + expr_ref mk_derivative(expr* ele, expr* r); + expr_ref mk_derivative_rec(expr* ele, expr* r); expr_ref mk_der_op(decl_kind k, expr* a, expr* b); expr_ref mk_der_op_rec(decl_kind k, expr* a, expr* b); expr_ref mk_der_concat(expr* a, expr* b); expr_ref mk_der_union(expr* a, expr* b); expr_ref mk_der_inter(expr* a, expr* b); expr_ref mk_der_compl(expr* a); - expr_ref mk_derivative(expr* ele, expr* r); - expr_ref mk_derivative_rec(expr* ele, expr* r); + expr_ref mk_der_reverse(expr* a); bool are_complements(expr* r1, expr* r2) const; bool is_subset(expr* r1, expr* r2) const; @@ -231,11 +234,12 @@ class seq_rewriter { br_status mk_re_range(expr* lo, expr* hi, expr_ref& result); br_status mk_re_reverse(expr* r, expr_ref& result); br_status mk_re_derivative(expr* ele, expr* r, expr_ref& result); - br_status lift_ite(func_decl* f, unsigned n, expr* const* args, expr_ref& result); + + br_status lift_ites_throttled(func_decl* f, unsigned n, expr* const* args, expr_ref& result); + br_status reduce_re_eq(expr* a, expr* b, expr_ref& result); br_status reduce_re_is_empty(expr* r, expr_ref& result); - bool non_overlap(expr_ref_vector const& p1, expr_ref_vector const& p2) const; bool non_overlap(zstring const& p1, zstring const& p2) const; bool rewrite_contains_pattern(expr* a, expr* b, expr_ref& result); @@ -324,9 +328,9 @@ public: void add_seqs(expr_ref_vector const& ls, expr_ref_vector const& rs, expr_ref_pair_vector& new_eqs); + // Check for acceptance of the empty string expr_ref is_nullable(expr* r); - // heuristic elimination of element from condition that comes form a derivative. // special case optimization for conjunctions of equalities, disequalities and ranges. void elim_condition(expr* elem, expr_ref& cond); diff --git a/src/smt/seq_regex.cpp b/src/smt/seq_regex.cpp index 2a8c2e3a4..6044f0275 100644 --- a/src/smt/seq_regex.cpp +++ b/src/smt/seq_regex.cpp @@ -146,7 +146,6 @@ namespace smt { m_to_propagate.push_back(lit); } - /** * Propagate the atom (accept s i r) * @@ -222,9 +221,9 @@ namespace smt { // (accept s i R) & len(s) > i => (accept s (+ i 1) D(nth(s, i), R)) or conds expr_ref d(m); expr_ref head = th.mk_nth(s, i); - d = re().mk_derivative(m.mk_var(0, m.get_sort(head)), r); + + d = derivative_wrapper(m.mk_var(0, m.get_sort(head)), r); // timer tm; - rewrite(d); // std::cout << d->get_id() << " " << tm.get_seconds() << "\n"; // if (tm.get_seconds() > 1) // std::cout << d << "\n"; @@ -351,6 +350,17 @@ namespace smt { return r; } + /* + Wrapper around the regex symbolic derivative from the rewriter. + Ensures that the derivative is written in a normalized BDD form + with optimizations for if-then-else expressions involving the head. + */ + expr_ref seq_regex::derivative_wrapper(expr* hd, expr* r) { + expr_ref result = expr_ref(re().mk_derivative(hd, r), m); + rewrite(result); + return result; + } + void seq_regex::propagate_eq(expr* r1, expr* r2) { expr_ref r = symmetric_diff(r1, r2); expr_ref emp(re().mk_empty(m.get_sort(r)), m); @@ -392,8 +402,7 @@ namespace smt { literal null_lit = th.mk_literal(is_nullable); expr_ref hd = mk_first(r); expr_ref d(m); - d = re().mk_derivative(hd, r); - rewrite(d); + d = derivative_wrapper(hd, r); literal_vector lits; lits.push_back(~lit); if (null_lit != false_literal) @@ -450,8 +459,7 @@ namespace smt { th.add_axiom(~lit, ~th.mk_literal(is_nullable)); expr_ref hd = mk_first(r); expr_ref d(m); - d = re().mk_derivative(hd, r); - rewrite(d); + d = derivative_wrapper(hd, r); literal_vector lits; expr_ref_pair_vector cofactors(m); get_cofactors(d, cofactors); diff --git a/src/smt/seq_regex.h b/src/smt/seq_regex.h index 71bcd160d..56d3bab6b 100644 --- a/src/smt/seq_regex.h +++ b/src/smt/seq_regex.h @@ -71,6 +71,8 @@ namespace smt { expr_ref symmetric_diff(expr* r1, expr* r2); + expr_ref derivative_wrapper(expr* hd, expr* r); + void get_cofactors(expr* r, expr_ref_vector& conds, expr_ref_pair_vector& result); void get_cofactors(expr* r, expr_ref_pair_vector& result) {