diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index 3f82e489b..29c9b4a16 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -409,7 +409,91 @@ void seq_rewriter::get_param_descrs(param_descrs & r) { seq_rewriter_params::collect_param_descrs(r); } +br_status seq_rewriter::mk_bool_app(func_decl* f, unsigned n, expr* const* args, expr_ref& result) { + switch (f->get_decl_kind()) { + case OP_AND: + return mk_bool_app_helper(true, n, args, result); + case OP_OR: + return mk_bool_app_helper(false, n, args, result); + default: + return BR_FAILED; + } +} +br_status seq_rewriter::mk_bool_app_helper(bool is_and, unsigned n, expr* const* args, expr_ref& result) { + bool found = false; + expr* arg = nullptr; + + for (unsigned i = 0; i < n && !found; ++i) { + found = m_util.str.is_in_re(args[i]) || (m().is_not(args[i], arg) && m_util.str.is_in_re(arg)); + } + if (!found) return BR_FAILED; + + obj_map in_re, not_in_re; + bool found_pair = false; + + for (unsigned i = 0; i < n; ++i) { + expr* args_i = args[i]; + expr* x = nullptr; + expr* y = nullptr; + expr* z = nullptr; + if (m_util.str.is_in_re(args_i, x, y)) { + if (in_re.find(x, z)) { + in_re[x] = is_and ? m_util.re.mk_inter(z, y) : m_util.re.mk_union(z, y); + found_pair = true; + } + else { + in_re.insert(x, y); + } + found_pair |= not_in_re.contains(x); + } + else if (m().is_not(args_i, arg) && m_util.str.is_in_re(arg, x, y)) { + if (not_in_re.find(x, z)) { + not_in_re[x] = is_and ? m_util.re.mk_union(z, y) : m_util.re.mk_inter(z, y); + found_pair = true; + } + else { + not_in_re.insert(x, y); + } + found_pair |= in_re.contains(x); + } + } + + if (!found_pair) { + return BR_FAILED; + } + + ptr_buffer new_args; + for (auto const & kv : in_re) { + expr* x = kv.m_key; + expr* y = kv.m_value; + expr* z = nullptr; + if (not_in_re.find(x, z)) { + expr* z_c = m_util.re.mk_complement(z); + expr* w = is_and ? m_util.re.mk_inter(y, z_c) : m_util.re.mk_union(y, z_c); + new_args.push_back(m_util.re.mk_in_re(x, w)); + } + else { + new_args.push_back(m_util.re.mk_in_re(x, y)); + } + } + for (auto const& kv : not_in_re) { + expr* x = kv.m_key; + expr* y = kv.m_value; + if (!in_re.contains(x)) { + new_args.push_back(m_util.re.mk_in_re(x, m_util.re.mk_complement(y))); + } + } + for (unsigned i = 0; i < n; ++i) { + expr* arg = args[i], * x; + if (!m_util.str.is_in_re(arg) && !(m().is_not(arg, x) && m_util.str.is_in_re(x))) { + new_args.push_back(arg); + } + } + + result = is_and ? m().mk_and(new_args.size(), new_args.c_ptr()) : m().mk_or(new_args.size(), new_args.c_ptr()); + return BR_REWRITE_FULL; +} br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) { SASSERT(f->get_family_id() == get_fid()); diff --git a/src/ast/rewriter/seq_rewriter.h b/src/ast/rewriter/seq_rewriter.h index daff907cb..ab8f3563d 100644 --- a/src/ast/rewriter/seq_rewriter.h +++ b/src/ast/rewriter/seq_rewriter.h @@ -179,6 +179,9 @@ class seq_rewriter { br_status mk_re_range(expr* lo, expr* hi, expr_ref& result); br_status lift_ite(func_decl* f, unsigned n, expr* const* args, expr_ref& result); + + br_status mk_bool_app_helper(bool is_and, unsigned n, expr* const* args, expr_ref& result); + bool cannot_contain_prefix(expr* a, expr* b); bool cannot_contain_suffix(expr* a, expr* b); expr_ref zero() { return expr_ref(m_autil.mk_int(0), m()); } @@ -233,6 +236,8 @@ public: void add_seqs(expr_ref_vector const& ls, expr_ref_vector const& rs, expr_ref_pair_vector& new_eqs); + br_status mk_bool_app(func_decl* f, unsigned n, expr* const* args, expr_ref& result); + }; diff --git a/src/ast/rewriter/th_rewriter.cpp b/src/ast/rewriter/th_rewriter.cpp index 040412894..4e48a81cd 100644 --- a/src/ast/rewriter/th_rewriter.cpp +++ b/src/ast/rewriter/th_rewriter.cpp @@ -203,6 +203,11 @@ struct th_rewriter_cfg : public default_rewriter_cfg { if (st != BR_FAILED) return st; } + if (k == OP_AND || k == OP_OR) { + st = m_seq_rw.mk_bool_app(f, num, args, result); + if (st != BR_FAILED) + return st; + } return m_b_rw.mk_app_core(f, num, args, result); } if (fid == m_a_rw.get_fid())