diff --git a/src/ast/rewriter/bool_rewriter.cpp b/src/ast/rewriter/bool_rewriter.cpp index 014e563a5b..945aa297ee 100644 --- a/src/ast/rewriter/bool_rewriter.cpp +++ b/src/ast/rewriter/bool_rewriter.cpp @@ -1196,15 +1196,15 @@ bool bool_rewriter::decompose_ite(expr *r, expr_ref &c, expr_ref &th, expr_ref & } for (expr *e : subterms::ground(expr_ref(r, m()))) { if (m().is_ite(e, cond, r1, r2)) { - expr_safe_replace rep1(m()); - expr_safe_replace rep2(m()); - rep1.insert(e, r1); - rep2.insert(e, r2); + m_rep1.reset(); + m_rep2.reset(); + m_rep1.insert(e, r1); + m_rep2.insert(e, r2); c = cond; th = r; el = r; - rep1(th); - rep2(el); + m_rep1(th); + m_rep2(el); return true; } } diff --git a/src/ast/rewriter/bool_rewriter.h b/src/ast/rewriter/bool_rewriter.h index 27c971f787..87c50f171e 100644 --- a/src/ast/rewriter/bool_rewriter.h +++ b/src/ast/rewriter/bool_rewriter.h @@ -20,6 +20,7 @@ Notes: #include "ast/ast.h" #include "ast/rewriter/rewriter.h" +#include "ast/rewriter/expr_safe_replace.h" #include "util/params.h" /** @@ -64,6 +65,7 @@ class bool_rewriter { ptr_vector m_todo1, m_todo2; unsigned_vector m_counts1, m_counts2; expr_mark m_marked; + expr_safe_replace m_rep1, m_rep2; br_status mk_flat_and_core(unsigned num_args, expr * const * args, expr_ref & result); br_status mk_flat_or_core(unsigned num_args, expr * const * args, expr_ref & result); @@ -87,7 +89,7 @@ class bool_rewriter { expr_ref simplify_eq_ite(expr* value, expr* ite); public: - bool_rewriter(ast_manager & m, params_ref const & p = params_ref()):m_manager(m), m_local_ctx_cost(0) { + bool_rewriter(ast_manager & m, params_ref const & p = params_ref()):m_manager(m), m_local_ctx_cost(0), m_rep1(m), m_rep2(m) { updt_params(p); } ast_manager & m() const { return m_manager; } diff --git a/src/ast/rewriter/seq_derive.cpp b/src/ast/rewriter/seq_derive.cpp index 7447e09113..64248fa2ca 100644 --- a/src/ast/rewriter/seq_derive.cpp +++ b/src/ast/rewriter/seq_derive.cpp @@ -45,12 +45,11 @@ namespace seq { } void derive::reset() { - m_cache.reset(); - m_top_cache.reset(); - m_union_cache.reset(); - m_inter_cache.reset(); - m_concat_cache.reset(); - m_complement_cache.reset(); + m_acache.reset(); + m_bcache.reset(); + m_atop_cache.reset(); + m_btop_cache.reset(); + reset_op_caches(); m_trail.reset(); m_ele = nullptr; } @@ -59,14 +58,19 @@ namespace seq { // while preserving derivative caches (m_cache, m_top_cache) // The op cache does index on m_ele so it has to be reset if m_ele changes. void derive::reset_op_caches() { - m_union_cache.reset(); - m_inter_cache.reset(); - m_concat_cache.reset(); - m_complement_cache.reset(); + m_aunion_cache.reset(); + m_ainter_cache.reset(); + m_aconcat_cache.reset(); + m_acomplement_cache.reset(); + m_bunion_cache.reset(); + m_binter_cache.reset(); + m_bconcat_cache.reset(); + m_bcomplement_cache.reset(); m_ele = nullptr; } - expr_ref derive::operator()(expr* ele, expr* r) { + expr_ref derive::operator()(derivative_kind k, expr* ele, expr* r) { + m_derivative_kind = k; SASSERT(m_util.is_re(r)); if (m_trail.size() > 500000) reset(); @@ -78,7 +82,7 @@ namespace seq { // Check top-level cache (post-simplify result) expr* cached = nullptr; expr_ref result(m); - if (m_top_cache.find(ele, r, cached)) { + if (top_cache().find(ele, r, cached)) { result = cached; return result; } @@ -100,20 +104,20 @@ namespace seq { m_intervals_start = 0; m_path_expr = m.mk_true(); result = derive_rec(r); - m_top_cache.insert(ele, r, result); + top_cache().insert(ele, r, result); // pin the final result m_trail.push_back(result); return result; } - expr_ref derive::operator()(expr* r) { + expr_ref derive::operator()(derivative_kind k, expr* r) { SASSERT(m_util.is_re(r)); sort* seq_sort = nullptr, * ele_sort = nullptr; VERIFY(m_util.is_re(r, seq_sort)); VERIFY(m_util.is_seq(seq_sort, ele_sort)); expr_ref v(m.mk_var(0, ele_sort), m); - return (*this)(v, r); + return (*this)(k,v, r); } // ------------------------------------------------------- @@ -125,7 +129,7 @@ namespace seq { // Check cache (indexed by both m_ele and r) expr* cached = nullptr; - if (m_cache.find(m_ele, r, cached)) + if (cache().find(m_ele, r, cached)) return expr_ref(cached, m); // Depth check @@ -138,7 +142,7 @@ namespace seq { expr_ref result = derive_core(r); // Cache the result - m_cache.insert(m_ele, r, result); + cache().insert(m_ele, r, result); m_trail.push_back(m_ele); m_trail.push_back(r); m_trail.push_back(result); @@ -667,7 +671,7 @@ namespace seq { expr_ref derive::mk_core(decl_kind k, expr* a, expr* b) { expr *pe = get_path_expr(); expr *cached = nullptr; - auto& cache = k == OP_RE_UNION ? m_union_cache : k == OP_RE_INTERSECT ? m_inter_cache : m_xor_cache; + auto& cache = k == OP_RE_UNION ? union_cache() : k == OP_RE_INTERSECT ? inter_cache() : xor_cache(); if (cache.find(a, b, pe, cached)) return expr_ref(cached, m); expr_ref result(m); @@ -677,7 +681,8 @@ namespace seq { auto xor_op = [&](expr *x, expr *y) { return mk_xor(x, y); }; switch (k) { case OP_RE_UNION: - //result = hoist_ite(a, b, union_op); + if (m_derivative_kind == derivative_kind::brzozowski_t) + result = hoist_ite(a, b, union_op); if (!result) result = mk_union_core(a, b); break; @@ -828,13 +833,13 @@ namespace seq { // Check path-aware op cache expr* pe = get_path_expr(); expr* cached = nullptr; - if (m_complement_cache.find(a, pe, cached)) + if (complement_cache().find(a, pe, cached)) return expr_ref(cached, m); expr_ref result = mk_complement_core(a); // Store in cache - m_complement_cache.insert(a, pe, result); + complement_cache().insert(a, pe, result); m_trail.push_back(a); m_trail.push_back(pe); m_trail.push_back(result); @@ -892,13 +897,13 @@ namespace seq { expr_ref derive::mk_deriv_concat(expr* d, expr* tail) { // Check op cache expr* cached = nullptr; - if (m_concat_cache.find(d, tail, cached)) + if (concat_cache().find(d, tail, cached)) return expr_ref(cached, m); expr_ref result = mk_deriv_concat_core(d, tail); // Store in cache - m_concat_cache.insert(d, tail, result); + concat_cache().insert(d, tail, result); m_trail.push_back(d); m_trail.push_back(tail); m_trail.push_back(result); @@ -921,7 +926,7 @@ namespace seq { } // (t ∪ e) · tail → (t · tail) ∪ (e · tail) - if (m_antimirov_derivative && re().is_union(d, t, e)) { + if (m_derivative_kind == derivative_kind::antimirov_t && re().is_union(d, t, e)) { expr_ref left = mk_deriv_concat(t, tail); expr_ref right = mk_deriv_concat(e, tail); return mk_union(left, right); @@ -1381,7 +1386,7 @@ namespace seq { void derive::derivative_cofactors(expr* r, expr_ref_pair_vector& result) { // Compute the symbolic derivative wrt the canonical variable // (:var 0); operator() sets m_ele to that variable. - expr_ref d = (*this)(r); + expr_ref d = (*this)(derivative_kind::brzozowski_t, r); // Enumerate the reachable, fully ITE-hoisted leaves of the // transition regex. get_cofactors uses the SAME m_ele set above, // so the (:var 0) conditions in d are matched and pruned. diff --git a/src/ast/rewriter/seq_derive.h b/src/ast/rewriter/seq_derive.h index 4a228b6bca..d3288dccf8 100644 --- a/src/ast/rewriter/seq_derive.h +++ b/src/ast/rewriter/seq_derive.h @@ -36,6 +36,7 @@ class seq_rewriter; namespace seq { + enum class derivative_kind { antimirov_t, brzozowski_t }; /** * Symbolic derivative engine for regular expressions. * @@ -60,15 +61,15 @@ namespace seq { seq_rewriter& m_re; // Cache: maps (ele, regex) pair to its derivative - obj_pair_map m_cache; - obj_pair_map m_top_cache; // post-simplify cache + obj_pair_map m_acache, m_bcache; + obj_pair_map m_atop_cache, m_btop_cache; // post-simplify cache expr_ref_vector m_trail; // pin cached results // Op cache for ITE-hoisting operations (union, inter, concat, complement) // Path-aware caches: key is (a, b, path_expr) for binary ops, (a, path_expr) for complement - obj_triple_map m_union_cache, m_inter_cache, m_xor_cache; - obj_pair_map m_concat_cache; - obj_pair_map m_complement_cache; + obj_triple_map m_aunion_cache, m_bunion_cache, m_ainter_cache, m_binter_cache, m_axor_cache, m_bxor_cache; + obj_pair_map m_aconcat_cache, m_bconcat_cache; + obj_pair_map m_acomplement_cache, m_bcomplement_cache; // Depth limiting unsigned m_depth { 0 }; @@ -77,7 +78,7 @@ namespace seq { seq_util::rex& re() { return m_util.re; } seq_util& u() { return m_util; } - bool m_antimirov_derivative = true; + derivative_kind m_derivative_kind = derivative_kind::antimirov_t; // The element (character) for the current derivative computation expr_ref m_ele; @@ -101,6 +102,34 @@ namespace seq { void pop(); // restore state to matching push expr* get_path_expr() { return m_path_expr; } + obj_pair_map &cache() { + return m_derivative_kind == derivative_kind::antimirov_t ? m_acache : m_bcache; + } + + obj_pair_map &top_cache() { + return m_derivative_kind == derivative_kind::antimirov_t ? m_atop_cache : m_btop_cache; + } + + obj_triple_map &union_cache() { + return m_derivative_kind == derivative_kind::antimirov_t ? m_aunion_cache : m_bunion_cache; + } + + obj_triple_map &inter_cache() { + return m_derivative_kind == derivative_kind::antimirov_t ? m_ainter_cache : m_binter_cache; + } + + obj_triple_map &xor_cache() { + return m_derivative_kind == derivative_kind::antimirov_t ? m_axor_cache : m_bxor_cache; + } + + obj_pair_map &concat_cache() { + return m_derivative_kind == derivative_kind::antimirov_t ? m_aconcat_cache : m_bconcat_cache; + } + + obj_pair_map &complement_cache() { + return m_derivative_kind == derivative_kind::antimirov_t ? m_acomplement_cache : m_bcomplement_cache; + } + // Hoist ITE: apply_op through ite(c, t, e) with path pruning expr_ref apply_ite(expr* c, expr* t, expr* e, expr* r, std::function apply_op); expr_ref apply_ite(expr* c, expr* t1, expr* e1, expr* t2, expr* e2, std::function apply_op); @@ -189,12 +218,12 @@ namespace seq { * When ele is a de Bruijn variable, produces a symbolic ITE-tree. * When ele is a concrete character, produces the concrete derivative. */ - expr_ref operator()(expr* ele, expr* r); + expr_ref operator()(derivative_kind k, expr* ele, expr* r); /** * Convenience: symbolic derivative using de Bruijn var 0. */ - expr_ref operator()(expr* r); + expr_ref operator()(derivative_kind k, expr* r); /** * Nullable check: returns a Boolean expression that is true iff r accepts the empty string. @@ -231,9 +260,6 @@ namespace seq { */ void derivative_cofactors(expr* r, expr_ref_pair_vector& result); - void set_antimirov(bool flag) { - m_antimirov_derivative = flag; - } }; } diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index a084a5af20..fe932ec782 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -2837,13 +2837,13 @@ bool seq_rewriter::check_deriv_normal_form(expr* r, int level) { #endif expr_ref seq_rewriter::mk_derivative(expr* r) { - auto result = m_derive(r); + auto result = m_derive(seq::derivative_kind::antimirov_t, r); TRACE(seq, tout << "Derivative of " << mk_pp(r, m()) << "\nis\n" << result << std::endl;); return result; } expr_ref seq_rewriter::mk_derivative(expr* ele, expr* r) { - auto result = m_derive(ele, r); + auto result = m_derive(seq::derivative_kind::antimirov_t, ele, r); TRACE(seq, tout << "Derivative of " << mk_pp(r, m()) << " w.r.t. " << mk_pp(ele, m()) << "\nis\n" << result << std::endl;); return result;