diff --git a/src/ast/rewriter/seq_rewriter.h b/src/ast/rewriter/seq_rewriter.h
index 1ae892a10..0d21f94ba 100644
--- a/src/ast/rewriter/seq_rewriter.h
+++ b/src/ast/rewriter/seq_rewriter.h
@@ -408,6 +408,15 @@ public:
 
     void simplify_split(split_set& s) { m_split.simplify(s); }
 
+    /**
+     * Decompose the membership constraint (str in regex) into pairs of regex
+     * splits that are stored in result. Thin wrapper that forwards all
+     * arguments to m_split.split_membership and returns its result unchanged.
+     */
+    std::pair<expr*, expr*> split_membership(expr* str, expr* regex, unsigned threshold, split_set& result) const {
+        return m_split.split_membership(str, regex, threshold, result);
+    }
+
     /**
      * check if regular expression is of the form all ++ s ++ all ++ t + u ++ all, where, s, t, u are sequences
      */
diff --git a/src/ast/rewriter/seq_split.cpp b/src/ast/rewriter/seq_split.cpp
index f2026b67e..8095cff95 100644
--- a/src/ast/rewriter/seq_split.cpp
+++ b/src/ast/rewriter/seq_split.cpp
@@ -122,24 +122,41 @@ bool seq_split::compute(expr* r, split_set& result, unsigned threshold, split_mo
     expr* s = nullptr;
     if (rex.is_to_re(r, s)) {
         zstring str;
-        if (sq.str.is_string(s, str)) {
-            for (unsigned i = 0; i <= str.length(); ++i) {
-                const expr_ref p(rex.mk_to_re(sq.str.mk_string(str.extract(0, i))), mm);
-                const expr_ref q(rex.mk_to_re(sq.str.mk_string(str.extract(i, str.length() - i))), mm);
-                push(result, oracle, p, q);
+        // Flatten the argument of to_re from left to right into one constant
+        // string; a leaf that is neither a constant character nor a string
+        // literal makes this case unsupported.
+        vector<expr*> stack;
+        stack.push_back(s);
+
+        while (!stack.empty()) {
+            expr* cur = stack.back();
+            stack.pop_back();
+            if (seq().str.is_concat(cur, a, b)) {
+                stack.push_back(b); // right child first, so the left child is popped first
+                stack.push_back(a);
+            }
+            else {
+                expr* ch = nullptr;
+                unsigned cv = 0;
+                if (seq().str.is_unit(cur, ch) && seq().is_const_char(ch, cv)) {
+                    str += zstring(cv);
+                    continue;
+                }
+                zstring str2;
+                if (sq.str.is_string(cur, str2)) { // test the current leaf, not the root s
+                    str = str + str2; // append: keep the characters collected so far
+                    continue;
+                }
+                // not a constant string; unsupported for now
+                return false;
             }
-            return true;
         }
-        // a single symbolic unit behaves like one token: { , }
-        if (sq.str.is_unit(s)) {
-            const expr_ref ex(r, mm);
-            const expr_ref eps(rex.mk_epsilon(seq_sort), mm);
-            push(result, oracle, eps, ex);
-            push(result, oracle, ex, eps);
-            return true;
+        for (unsigned i = 0; i <= str.length(); ++i) {
+            const expr_ref p(rex.mk_to_re(sq.str.mk_string(str.extract(0, i))), mm);
+            const expr_ref q(rex.mk_to_re(sq.str.mk_string(str.extract(i, str.length() - i))), mm);
+            push(result, oracle, p, q);
         }
-        // to_re over a non-literal sequence: not handled.
-        return false;
+        return true;
     }
 
     // single-character class alpha (., [lo-hi], of_pred):
@@ -391,8 +408,8 @@ void seq_split::simplify(split_set& pairs) const {
 
 std::pair<expr*, expr*> seq_split::split_membership(expr* str, expr* regex, unsigned threshold, split_set& result) const {
     ast_manager& m = this->m();
-    expr_ref_vector tokens(m);
 
+    expr_ref_vector tokens(m);
     vector<expr*> stack;
     stack.push_back(str);
 
@@ -408,9 +425,39 @@ std::pair<expr*, expr*> seq_split::split_membership(expr* str, expr* regex, unsi
         tokens.push_back(expr_ref(cur, m));
     }
 
-    SASSERT(!tokens.empty());
-    expr* first = tokens.get(0);
-    SASSERT(seq().is_char(first)); // constants are consumed earlier
+    expr* ch = nullptr;
+    unsigned i = 0;
+
+    // Consume leading constant characters: each one is dropped from the
+    // token list and regex is replaced by its derivative w.r.t. that character.
+    while (i < tokens.size() && (seq().str.is_string(tokens.get(i)) || (seq().str.is_unit(tokens.get(i), ch) && seq().is_const_char(ch)))) {
+        zstring s;
+        if (seq().str.is_string(tokens.get(i), s)) {
+            if (s.empty()) {
+                i++;
+                continue;
+            }
+            ch = seq().mk_char(s[0]);
+            tokens[i] = seq().str.mk_string(s.extract(1, s.length() - 1));
+        }
+        else
+            i++;
+        regex = m_rw.mk_derivative(ch, regex);
+    }
+
+    if (i > 0) {
+        // move the remaining tokens to the front and cut off the consumed prefix
+        unsigned j = 0;
+        for (; i < tokens.size(); i++, j++) {
+            tokens[j] = tokens.get(i);
+        }
+        tokens.shrink(j);
+    }
+
+    // TODO: Do this for the back as well (also, why did no rule before do that?)
+
+    if (tokens.empty())
+        return { nullptr, nullptr }; // no token left to factor
 
     // Choose the factorization boundary so the tail starts with the
     // longest run of concrete characters c.
@@ -418,13 +465,13 @@ std::pair<expr*, expr*> seq_split::split_membership(expr* str, expr* regex, unsi
     // head = u' (tokens before the run), tail = c · u''' (tokens from the run onward).
     const unsigned total = tokens.size();
     unsigned run_start = 0, run_len = 0;
-    for (unsigned i = 0; i < total; ) {
-        if (!seq().is_char(tokens.get(i))) {
+    for (i = 1; i < total; ) { // start at 1 so that the head keeps at least one token
+        if (!(seq().str.is_unit(tokens.get(i), ch) && seq().is_const_char(ch))) {
             i++;
             continue;
         }
         unsigned j = i;
-        while (j < total && seq().is_char(tokens.get(j))) {
+        while (j < total && seq().str.is_unit(tokens.get(j), ch) && seq().is_const_char(ch)) {
             j++;
         }
         if (j - i > run_len) {
@@ -436,21 +483,24 @@ std::pair<expr*, expr*> seq_split::split_membership(expr* str, expr* regex, unsi
     // No constant run => fall back to splitting off the first token.
     const unsigned p = run_len == 0 ? 1 : run_start;
     SASSERT(p >= 1);
-    expr* head = first;
-    for (unsigned i = 1; i < p; i++) {
+    expr* head = tokens.get(0);
+    for (i = 1; i < p; i++) {
         head = seq().str.mk_concat(head, tokens.get(i));
     }
-    expr* tail = tokens.get(p);
-    for (unsigned i = p + run_len; i < tokens.size(); i++) {
-        tail = seq().str.mk_concat(tail, tokens.get(i));
+    expr* tail = seq().str.mk_empty(head->get_sort()); // stays empty when no token follows the run
+    if (tokens.size() > p + run_len) {
+        tail = tokens.get(p + run_len);
+        for (i = p + run_len + 1; i < tokens.size(); i++) {
+            tail = seq().str.mk_concat(tail, tokens.get(i));
+        }
     }
     SASSERT(head && tail);
 
     // Build the constant lookahead c and (if non-empty) an oracle that
     // prunes splits whose postfix cannot match c.
     zstring c;
-    for (unsigned i = 0; i < run_len; ++i) {
-        expr* ch; unsigned cv;
+    for (i = 0; i < run_len; ++i) {
+        unsigned cv = 0;
         VERIFY(seq().str.is_unit(tokens.get(run_start + i), ch));
         VERIFY(seq().is_const_char(ch, cv));
         c = c + zstring(cv);
@@ -471,7 +521,7 @@ std::pair<expr*, expr*> seq_split::split_membership(expr* str, expr* regex, unsi
     // of each postfix
     if (!c.empty()) {
         unsigned w = 0;
-        for (unsigned i = 0; i < result.size(); ++i) {
+        for (i = 0; i < result.size(); ++i) {
             expr* d = result[i].m_n;
             for (unsigned k = 0; d && !seq().re.is_empty(d) && k < c.length(); ++k) {
                 d = m_rw.mk_derivative(seq().mk_char(c[k]), d);
diff --git a/src/params/smt_params_helper.pyg b/src/params/smt_params_helper.pyg
index cd52b989c..d3f164f3b 100644
--- a/src/params/smt_params_helper.pyg
+++ b/src/params/smt_params_helper.pyg
@@ -138,6 +138,8 @@ def_module_params(module_name='smt',
                           ('seq.validate', BOOL, False, 'enable self-validation of theory axioms created by seq theory'),
                           ('seq.max_unfolding', UINT, 1000000000, 'maximal unfolding depth for checking string equations and regular expressions'),
                           ('seq.min_unfolding', UINT, 1, 'initial bound for strings whose lengths are bounded by iterative deepening. Set this to a higher value if there are only models with larger string lengths'),
+                          ('seq.regex_factorization_threshold', UINT, 10, 'maximum number of cases to factor a regex into in a single step'),
+                          ('seq.regex_factorization_enabled', BOOL, False, 'apply regex factorization (sigma splitting)'),
                           ('theory_aware_branching', BOOL, False, 'Allow the context to use extra information from theory solvers regarding literal branching prioritization.'),
                           ('sls.enable', BOOL, False, 'enable sls co-processor with SMT engine'),
                           ('sls.parallel', BOOL, True, 'use sls co-processor in parallel or sequential with SMT engine'),
diff --git a/src/params/theory_seq_params.cpp b/src/params/theory_seq_params.cpp
index 54bf69162..960f145a6 100644
--- a/src/params/theory_seq_params.cpp
+++ b/src/params/theory_seq_params.cpp
@@ -23,4 +23,6 @@ void theory_seq_params::updt_params(params_ref const & _p) {
     m_seq_validate = p.seq_validate();
     m_seq_max_unfolding = p.seq_max_unfolding();
     m_seq_min_unfolding = p.seq_min_unfolding();
+    m_seq_regex_factorization_enabled = p.seq_regex_factorization_enabled();
+    m_seq_regex_factorization_threshold = p.seq_regex_factorization_threshold();
 }
diff --git a/src/params/theory_seq_params.h b/src/params/theory_seq_params.h
index f964088eb..067a65a66 100644
--- a/src/params/theory_seq_params.h
+++ b/src/params/theory_seq_params.h
@@ -26,6 +26,8 @@ struct theory_seq_params {
     bool m_seq_validate = false;
     unsigned m_seq_max_unfolding = UINT_MAX/4;
     unsigned m_seq_min_unfolding = 1;
+    bool m_seq_regex_factorization_enabled = false;
+    unsigned m_seq_regex_factorization_threshold = 10; // same default as seq.regex_factorization_threshold
 
     theory_seq_params(params_ref const & p = params_ref()) {
         updt_params(p);
diff --git a/src/smt/seq_regex.cpp b/src/smt/seq_regex.cpp
index 0e9a03b63..af8b280d8 100644
--- a/src/smt/seq_regex.cpp
+++ b/src/smt/seq_regex.cpp
@@ -128,6 +128,31 @@ namespace smt {
             return;
         }
 
+        // Optional regex factorization: replace the membership by the disjunction
+        // of (head in pre) && (tail in post) over all computed split pairs.
+        if (th.get_fparams().m_seq_regex_factorization_enabled) {
+            unsigned threshold = th.get_fparams().m_seq_regex_factorization_threshold;
+            if (threshold == 0)
+                threshold = UINT_MAX; // 0 lifts the limit
+            split_set result;
+            auto [head, tail] = seq_rw().split_membership(s, r, threshold, result);
+            if (head) {
+                SASSERT(tail);
+                // propagate all cases
+                expr_ref_vector cases(m);
+                for (const auto& [pre, post] : result) {
+                    expr_ref mem_head(re().mk_in_re(head, pre), m);
+                    expr_ref mem_tail(re().mk_in_re(tail, post), m);
+                    cases.push_back(m.mk_and(mem_head, mem_tail));
+                }
+                const expr_ref cases_expr(m.mk_or(cases), m);
+                ctx.internalize(cases_expr, false);
+                th.propagate_lit(nullptr, 1, &lit, ctx.get_literal(cases_expr));
+                return;
+            }
+            // fallthrough; decomposition failed
+        }
+
         // Convert a non-ground sequence into an additional regex and
         // strengthen the original regex constraint into an intersection
         // for example: