diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index 5b9e27847..f4d010632 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -690,14 +690,6 @@ br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con SASSERT(num_args == 1); st = mk_str_stoi(args[0], result); break; - // case OP_ITE: - // // Rewrite ITEs in the case of regexes - // SASSERT(num_args == 3); - // if (m_util.is_re(args[1])) { - // SASSERT(m_util.is_re(args[2])); - // st = mk_re_ite(args[0], args[1], args[2], result); - // } - // break; case _OP_STRING_CONCAT: case _OP_STRING_PREFIX: case _OP_STRING_SUFFIX: @@ -2140,18 +2132,19 @@ expr_ref seq_rewriter::re_predicate(expr* cond, sort* seq_sort) { } expr_ref seq_rewriter::is_nullable_rec(expr* r) { - std::cout << "is_nullable_rec" << std::endl; - expr_ref result(m_op_cache.find(_OP_RE_IS_NULLABLE, r, nullptr), m()); + std::cout << "n"; + expr_ref result(m_op_cache.find(_OP_RE_IS_NULLABLE, r, nullptr, nullptr), m()); if (!result) { + std::cout << "(m) "; result = is_nullable(r); - m_op_cache.insert(_OP_RE_IS_NULLABLE, r, nullptr, result); + m_op_cache.insert(_OP_RE_IS_NULLABLE, r, nullptr, nullptr, result); + } else { + std::cout << "(h) "; } return result; } expr_ref seq_rewriter::is_nullable(expr* r) { - std::cout << "is_nullable" << std::endl; - // std::cout << "call to is_nullable(" << expr_ref(r, m()) << ")" << std::endl; SASSERT(m_util.is_re(r)); expr* r1 = nullptr, *r2 = nullptr, *cond = nullptr; unsigned lo = 0, hi = 0; @@ -2303,7 +2296,6 @@ br_status seq_rewriter::mk_re_derivative(expr* ele, expr* r, expr_ref& result) { expr *r1 = nullptr, *r2 = nullptr, *p = nullptr; unsigned lo = 0, hi = 0; if (re().is_concat(r, r1, r2)) { - std::cout << "is_nullable -- from concat" << std::endl; expr_ref is_n = is_nullable(r1); expr_ref dr1(re().mk_derivative(ele, r1), m()); expr_ref dr2(re().mk_derivative(ele, r2), m()); @@ -2451,31 +2443,176 @@ br_status seq_rewriter::mk_re_derivative(expr* ele, expr* r, expr_ref& result) { } /* - Optimizations for ITEs of regexes, since they come up frequently - in calculating derivatives. + Combine two if-then-else expressions in BDD form. - lifting functions (lift_ites, lift_ites_throttled): - push all ite expressions to the top level. + Definition of BDD form: + if-then-elses are pushed outwards + and sorted by condition ID (cond->get_id()), from largest on + the outside to smallest on the inside. + Duplicate nested conditions are eliminated. - rewriting (mk_re_ite): - ite(not c, r1, r2) -> ite(c, r2, r1) - ite(c, ite(c, r1, r2), r3)) -> ite(c, r1, r3) - ite(c, r1, ite(c, r2, r3)) -> ite(c, r1, r3) - ite(c1, ite(c2, r1, r2), r3) where id of c1 < id of c2 -> - ite(c2, ite(c1, r1, r3), ite(c1, r2, r3)) - ite(c1, r1, ite(c2, r2, r3)) where id of c1 < id of c2 -> - ite(c2, ite(c1, r1, r2), ite(c1, r1, r3)) + Preconditions: + - EITHER k is a binary op code on REs (re.union, re.inter, etc.) + and cond is nullptr, + OR k is if-then-else (OP.ITE) and cond is the condition. + - a and b are in BDD form. + + Postcondition: result is in BDD form. + if-then-elses are pushed outwards + and sorted by condition ID (cond->get_id()), from largest on + the outside to smallest on the inside. + + Uses op cache (memoization) to avoid duplicating work for the same + pair of pointers. */ -expr_ref seq_rewriter::lift_ites(expr* a, bool lift_over_union, bool lift_over_inter) { - expr_ref result(m()); - result = a; +expr_ref seq_rewriter::combine_ites(decl_kind k, expr* a, expr* b, expr* cond) { + std::cout << "c"; + expr_ref result(m_op_cache.find(k, a, b, cond), m()); + if (result) { + std::cout << "(h) "; + return result; + } + std::cout << "(m) "; + SASSERT((k == OP_ITE) == (cond != nullptr)); + expr *acond = nullptr, *a1 = nullptr, *a2 = nullptr, + *bcond = nullptr, *b1 = nullptr, *b2 = nullptr; + expr_ref result1(m()), result2(m()); + if (k == OP_ITE) { + if (m().is_ite(a, acond, a1, a2) && + cond->get_id() < acond->get_id()) { + // Push ITE inwards on first arg + result1 = combine_ites(k, a1, b, cond); + result2 = combine_ites(k, a2, b, cond); + result = combine_ites(k, result1, result2, acond); + } + else if (m().is_ite(a, acond, a1, a2) && + cond == acond) { + // Collapse ITE on first arg + result = combine_ites(k, a1, b, cond); + } + else if (m().is_ite(b, bcond, b1, b2) && + cond->get_id() < bcond->get_id()) { + // Push ITE inwards on second arg + result1 = combine_ites(k, a, b1, cond); + result2 = combine_ites(k, a, b2, cond); + result = combine_ites(k, result1, result2, bcond); + } + else if (m().is_ite(b, bcond, b1, b2) && + cond == bcond) { + // Collapse ITE on second arg + result = combine_ites(k, a, b2, cond); + } + else { + // Apply ITE -- no simplification required + result = m().mk_ite(a, b, cond); + } + } + else if (m().is_ite(a, acond, a1, a2)) { + // Push binary op inwards on first arg + result1 = combine_ites(k, a1, b, nullptr); + result2 = combine_ites(k, a2, b, nullptr); + result = combine_ites(OP_ITE, result1, result2, acond); + } + else if (m().is_ite(b, bcond, b1, b2)) { + // Push binary op inwards on second arg + result1 = combine_ites(k, a, b1, nullptr); + result2 = combine_ites(k, a, b2, nullptr); + result = combine_ites(OP_ITE, result1, result2, bcond); + } + else { + // Apply binary op (a and b are free of ITE) + result = m().mk_app(get_fid(), k, a, b); + } + // Save result before returning + m_op_cache.insert(k, a, b, cond, result); return result; } +/* + Lift if-then-else expressions to the top level, enforcing a BDD form. + + Postcondition: result is in BDD form. + - Alternatively, if lift_over_union and/or lift_over_inter is false, + then result is a disjunction and/or conjunciton of expressions in + BDD form. (Even in this case, ITE is still lifted at lower levels, + just not at the top level.) + - Note that the result may not be fully simplified (particularly the + nested expressions inside if-then-else). Simplification should be + called afterwards. + + Cost: Causes potential blowup in the size of an expression (when + expanded out), but keeps the representation compact (subexpressions + are shared). + + Used by: the regex solver in seq_regex.cpp when dealing with + derivatives of a regex by a symbolic character. Enables efficient + representation in unfolding string in regex constraints. +*/ +expr_ref seq_rewriter::lift_ites(expr* r, bool lift_over_union, bool lift_over_inter) { + std::cout << "l "; + decl_kind k = to_app(r)->get_decl_kind(); + family_id fid = get_fid(); + expr *r1 = nullptr, *r2 = nullptr, *cond = nullptr, *ele = nullptr; + unsigned lo = 0, hi = 0; + expr_ref result(m()), result1(m()), result2(m()); + if ((re().is_union(r, r1, r2) && !lift_over_union) || + (re().is_intersection(r, r1, r2) && !lift_over_inter)) { + // Preserve unions and/or intersections + result1 = lift_ites(r1, lift_over_union, lift_over_inter); + result2 = lift_ites(r2, lift_over_union, lift_over_inter); + result = m().mk_app(fid, k, r1, r2); + } + else if (m().is_ite(r, cond, r1, r2) || + re().is_concat(r, r1, r2) || + re().is_union(r, r1, r2) || + re().is_intersection(r, r1, r2) || + re().is_diff(r, r1, r2)) { + // Use combine_ites on the subresults + // Stop preserving unions and intersections + result1 = lift_ites(r1, true, true); + result2 = lift_ites(r2, true, true); + result = combine_ites(k, r1, r2, cond); + } + else if (re().is_star(r, r1) || + re().is_plus(r, r1) || + re().is_opt(r, r1) || + re().is_complement(r, r1) || + re().is_reverse(r, r1)) { + // Stop preserving unions and intersections + result1 = lift_ites(r1, true, true); + result = m().mk_app(fid, k, r1); + } + else if (re().is_derivative(r, ele, r1)) { + result1 = lift_ites(r1, true, true); + result = m().mk_app(fid, k, ele, r1); + } + else if (re().is_loop(r, r1, lo)) { + result1 = lift_ites(r1, true, true); + result = re().mk_loop(result1, lo); + } + else if (re().is_loop(r, r1, lo, hi)) { + result1 = lift_ites(r1, true, true); + result = re().mk_loop(result1, lo, hi); + } + else { + // is_full_seq, is_empty, is_to_re, is_range, is_full_char, is_of_pred + result = r; + } + return result; +} + +/* + Lift all ite expressions to the top level, but + a different "safe" version which is throttled to not + blowup the size of the expression. + + Note: this function does not ensure the same BDD form that lift_ites + ensures. +*/ br_status seq_rewriter::lift_ites_throttled(func_decl* f, unsigned n, expr* const* args, expr_ref& result) { expr* c = nullptr, *t = nullptr, *e = nullptr; - for (unsigned i = 0; i < n; ++i) { - if (m().is_ite(args[i], c, t, e) && + for (unsigned i = 0; i < n; ++i) { + if (m().is_ite(args[i], c, t, e) && (get_depth(t) <= 2 || t->get_ref_count() == 1 || get_depth(e) <= 2 || e->get_ref_count() == 1)) { ptr_buffer new_args; @@ -2491,40 +2628,50 @@ br_status seq_rewriter::lift_ites_throttled(func_decl* f, unsigned n, expr* cons return BR_FAILED; } -br_status seq_rewriter::mk_re_ite(expr* cond, expr* r1, expr* r2, expr_ref& result) { - VERIFY(m_util.is_re(r1)); - VERIFY(m_util.is_re(r2)); - expr *c = nullptr, *ra = nullptr, *rb = nullptr; - if (m().is_not(cond, c)) { - result = m().mk_ite(c, r2, r1); - return BR_REWRITE1; - } - if (m().is_ite(r1, c, ra, rb)) { - if (m().are_equal(c, cond)) { - result = m().mk_ite(cond, ra, r2); - return BR_REWRITE1; - } - if (cond->get_id() < c->get_id()) { - expr *result1 = m().mk_ite(cond, ra, r2); - expr *result2 = m().mk_ite(cond, rb, r2); - result = m().mk_ite(c, result1, result2); - return BR_REWRITE2; - } - } - if (m().is_ite(r2, c, ra, rb)) { - if (m().are_equal(c, cond)) { - result = m().mk_ite(cond, r1, rb); - return BR_REWRITE1; - } - if (cond->get_id() < c->get_id()) { - expr *result1 = m().mk_ite(cond, r1, ra); - expr* result2 = m().mk_ite(cond, r1, rb); - result = m().mk_ite(c, result1, result2); - return BR_REWRITE2; - } - } - return BR_DONE; -} +// /* +// Rewrite rules for ITEs of regexes. +// ite(not c, r1, r2) -> ite(c, r2, r1) +// ite(c, ite(c, r1, r2), r3)) -> ite(c, r1, r3) +// ite(c, r1, ite(c, r2, r3)) -> ite(c, r1, r3) +// ite(c1, ite(c2, r1, r2), r3) where id of c1 < id of c2 -> +// ite(c2, ite(c1, r1, r3), ite(c1, r2, r3)) +// ite(c1, r1, ite(c2, r2, r3)) where id of c1 < id of c2 -> +// ite(c2, ite(c1, r1, r2), ite(c1, r1, r3)) +// */ +// br_status seq_rewriter::rewrite_re_ite(expr* cond, expr* r1, expr* r2, expr_ref& result) { +// VERIFY(m_util.is_re(r1)); +// VERIFY(m_util.is_re(r2)); +// expr *c = nullptr, *ra = nullptr, *rb = nullptr; +// if (m().is_not(cond, c)) { +// result = m().mk_ite(c, r2, r1); +// return BR_REWRITE1; +// } +// if (m().is_ite(r1, c, ra, rb)) { +// if (m().are_equal(c, cond)) { +// result = m().mk_ite(cond, ra, r2); +// return BR_REWRITE1; +// } +// if (cond->get_id() < c->get_id()) { +// expr *result1 = m().mk_ite(cond, ra, r2); +// expr *result2 = m().mk_ite(cond, rb, r2); +// result = m().mk_ite(c, result1, result2); +// return BR_REWRITE2; +// } +// } +// if (m().is_ite(r2, c, ra, rb)) { +// if (m().are_equal(c, cond)) { +// result = m().mk_ite(cond, r1, rb); +// return BR_REWRITE1; +// } +// if (cond->get_id() < c->get_id()) { +// expr *result1 = m().mk_ite(cond, r1, ra); +// expr* result2 = m().mk_ite(cond, r1, rb); +// result = m().mk_ite(c, result1, result2); +// return BR_REWRITE2; +// } +// } +// return BR_DONE; +// } /* * pattern match against all ++ "abc" ++ all ++ "def" ++ all regexes. @@ -2703,7 +2850,6 @@ br_status seq_rewriter::mk_str_in_regexp(expr* a, expr* b, expr_ref& result) { return BR_REWRITE1; } if (str().is_empty(a)) { - std::cout << "is_nullable -- from str.in_re" << std::endl; result = is_nullable(b); if (str().is_in_re(result)) return BR_DONE; @@ -3963,18 +4109,20 @@ seq_rewriter::op_cache::op_cache(ast_manager& m): m_trail(m) {} -expr* seq_rewriter::op_cache::find(decl_kind op, expr* a, expr* b) { - op_entry e(op, a, b, nullptr); +expr* seq_rewriter::op_cache::find(decl_kind op, expr* a, expr* b, expr* c) { + op_entry e(op, a, b, c, nullptr); m_table.find(e); return e.r; } -void seq_rewriter::op_cache::insert(decl_kind op, expr* a, expr* b, expr* r) { +void seq_rewriter::op_cache::insert(decl_kind op, + expr* a, expr* b, expr* c, expr* r) { cleanup(); if (a) m_trail.push_back(a); if (b) m_trail.push_back(b); + if (c) m_trail.push_back(c); if (r) m_trail.push_back(r); - m_table.insert(op_entry(op, a, b, r)); + m_table.insert(op_entry(op, a, b, c, r)); } void seq_rewriter::op_cache::cleanup() { diff --git a/src/ast/rewriter/seq_rewriter.h b/src/ast/rewriter/seq_rewriter.h index f41f11a5f..21ab76d07 100644 --- a/src/ast/rewriter/seq_rewriter.h +++ b/src/ast/rewriter/seq_rewriter.h @@ -118,20 +118,23 @@ class seq_rewriter { class op_cache { struct op_entry { decl_kind k; - expr* a, *b, *r; - op_entry(decl_kind k, expr* a, expr* b, expr* r): k(k), a(a), b(b), r(r) {} - op_entry():k(0), a(nullptr), b(nullptr), r(nullptr) {} + expr* a, *b, *c, *r; + op_entry(decl_kind k, expr* a, expr* b, expr* c, expr* r): + k(k), a(a), b(b), c(c), r(r) {} + op_entry():k(0), a(nullptr), b(nullptr), c(nullptr), r(nullptr) {} }; struct hash_entry { unsigned operator()(op_entry const& e) const { - return mk_mix(e.k, e.a ? e.a->get_id() : 0, e.b ? e.b->get_id() : 0); + return combine_hash(e.k, mk_mix(e.a ? e.a->get_id() : 0, + e.b ? e.b->get_id() : 0, + e.c ? e.c->get_id() : 0)); } }; struct eq_entry { bool operator()(op_entry const& a, op_entry const& b) const { - return a.k == b.k && a.a == b.a && a.b == b.b; + return a.k == b.k && a.a == b.a && a.b == b.b && a.c == b.c; } }; @@ -145,8 +148,8 @@ class seq_rewriter { public: op_cache(ast_manager& m); - expr* find(decl_kind op, expr* a, expr* b); - void insert(decl_kind op, expr* a, expr* b, expr* r); + expr* find(decl_kind op, expr* a, expr* b, expr* c); + void insert(decl_kind op, expr* a, expr* b, expr* c, expr* r); }; seq_util m_util; @@ -219,15 +222,9 @@ class seq_rewriter { br_status mk_re_reverse(expr* r, expr_ref& result); br_status mk_re_derivative(expr* ele, expr* r, expr_ref& result); - // if-then-else rewriting support (for REs) - br_status mk_re_ite(expr* cond, expr* r1, expr* r2, expr_ref& result); - expr_ref lift_ites(expr* a, bool lift_over_union = true, bool lift_over_inter = true); - br_status lift_ites_throttled(func_decl* f, unsigned n, expr* const* args, expr_ref& result); - br_status reduce_re_eq(expr* a, expr* b, expr_ref& result); br_status reduce_re_is_empty(expr* r, expr_ref& result); - bool non_overlap(expr_ref_vector const& p1, expr_ref_vector const& p2) const; bool non_overlap(zstring const& p1, zstring const& p2) const; bool rewrite_contains_pattern(expr* a, expr* b, expr_ref& result); @@ -271,6 +268,9 @@ class seq_rewriter { void get_cofactors(expr* r, expr_ref_vector& conds, expr_ref_pair_vector& result); void intersect(unsigned lo, unsigned hi, svector>& ranges); + expr_ref combine_ites(decl_kind k, expr* a, expr* b, expr* cond); + br_status lift_ites_throttled(func_decl* f, unsigned n, expr* const* args, expr_ref& result); + public: seq_rewriter(ast_manager & m, params_ref const & p = params_ref()): m_util(m), m_autil(m), m_re2aut(m), m_op_cache(m), m_es(m), m_lhs(m), m_rhs(m), m_coalesce_chars(true) { @@ -318,7 +318,7 @@ public: expr_ref is_nullable(expr* r); expr_ref is_nullable_rec(expr* r); - // utilities for cofactors: conditions that appear in if-then-else expressions + // utilities for cofactors of if-then-else expressions bool has_cofactor(expr* r, expr_ref& cond, expr_ref& th, expr_ref& el); void get_cofactors(expr* r, expr_ref_pair_vector& result) { expr_ref_vector conds(m()); @@ -329,6 +329,8 @@ public: // special case optimization for conjunctions of equalities, disequalities and ranges. void elim_condition(expr* elem, expr_ref& cond); + // if-then-else rewriting support (for REs) + expr_ref lift_ites(expr* r, bool lift_over_union = true, bool lift_over_inter = true); }; #endif diff --git a/src/smt/seq_regex.cpp b/src/smt/seq_regex.cpp index 1fb9ee2aa..675893381 100644 --- a/src/smt/seq_regex.cpp +++ b/src/smt/seq_regex.cpp @@ -101,7 +101,7 @@ namespace smt { expr* e = ctx.bool_var2expr(lit.var()); VERIFY(str().is_in_re(e, s, r)); - std::cout << "SEQ REGEX P_IN_RE" << std::endl; + std::cout << "PI "; TRACE("seq", tout << "propagate " << mk_pp(e, m) << "\n";); @@ -146,7 +146,7 @@ namespace smt { } void seq_regex::propagate_accept(literal lit) { - std::cout << "SEQ REGEX P_ACCEPT" << std::endl; + std::cout << "PA "; if (!propagate(lit)) m_to_propagate.push_back(lit); } @@ -200,7 +200,7 @@ namespace smt { TRACE("seq", tout << "propagate " << mk_pp(e, m) << "\n";); - std::cout << "SEQ REGEX P" << std::endl; + std::cout << "P "; // << mk_pp(e, m) << std::endl; if (block_unfolding(lit, idx)) @@ -222,8 +222,7 @@ namespace smt { case l_undef: ctx.mark_as_relevant(len_s_le_i); return false; - case l_true: - std::cout << "is_nullable -- from prop" << std::endl; + case l_true: is_nullable = seq_rw().is_nullable(d); rewrite(is_nullable); conds.push_back(~len_s_le_i); @@ -234,12 +233,9 @@ namespace smt { break; } - std::cout << "...MK DERIVATIVE" << std::endl; - // (accept s i R) & len(s) > i => (accept s (+ i 1) D(nth(s, i), R)) or conds expr_ref head = th.mk_nth(s, i); - d = re().mk_derivative(head, r); - rewrite(d); + d = derivative_wrapper(head, r); literal acc_next = th.mk_literal(sk().mk_accept(s, a().mk_int(idx + 1), d)); conds.push_back(len_s_le_i); @@ -319,14 +315,17 @@ namespace smt { } /* - Memoized wrapper around the regex symbolic derivative. + Memoized(TODO) wrapper around the regex symbolic derivative. Also ensures that the derivative is written in a normalized form with optimizations for if-then-else expressions involving the head. */ expr_ref seq_regex::derivative_wrapper(expr* hd, expr* r) { + std::cout << "D "; expr_ref result = expr_ref(re().mk_derivative(hd, r), m); rewrite(result); - // TODO + // don't lift over unions + result = seq_rw().lift_ites(result); // false, true); + rewrite(result); return result; } @@ -362,7 +361,7 @@ namespace smt { * */ void seq_regex::propagate_is_non_empty(literal lit) { - std::cout << "SEQ REGEX P_NE" << std::endl; + std::cout << "PN "; expr* e = ctx.bool_var2expr(lit.var()), *r = nullptr, *u = nullptr; VERIFY(sk().is_is_non_empty(e, r, u)); expr_ref is_nullable = seq_rw().is_nullable(r); @@ -403,7 +402,7 @@ namespace smt { is_empty(r, u) is true if r is a member of u */ void seq_regex::propagate_is_empty(literal lit) { - std::cout << "SEQ REGEX P_E" << std::endl; + std::cout << "PE "; expr* e = ctx.bool_var2expr(lit.var()), *r = nullptr, *u = nullptr; VERIFY(sk().is_is_empty(e, r, u)); expr_ref is_nullable = seq_rw().is_nullable(r);