diff --git a/src/ast/rewriter/str_rewriter.cpp b/src/ast/rewriter/str_rewriter.cpp index fdb67f89e..54e0dd443 100644 --- a/src/ast/rewriter/str_rewriter.cpp +++ b/src/ast/rewriter/str_rewriter.cpp @@ -22,6 +22,192 @@ Notes: #include"ast_pp.h" #include"ast_util.h" #include"well_sorted.h" +#include +#include +#include + +class nfa { +protected: + str_util & m_strutil; + + bool m_valid; + unsigned m_next_id; + + unsigned next_id() { + unsigned retval = m_next_id; + ++m_next_id; + return retval; + } + + unsigned m_start_state; + unsigned m_end_state; + + std::map > transition_map; + std::map > epsilon_map; + + void make_transition(unsigned start, char symbol, unsigned end) { + transition_map[start][symbol] = end; + } + + void make_epsilon_move(unsigned start, unsigned end) { + epsilon_map[start].insert(end); + } + + // Convert a regular expression to an e-NFA using Thompson's construction + void convert_re(expr * e, unsigned & start, unsigned & end) { + start = next_id(); + end = next_id(); + if (m_strutil.is_re_Str2Reg(e)) { + app * a = to_app(e); + expr * arg_str = a->get_arg(0); + if (m_strutil.is_string(arg_str)) { + std::string str = m_strutil.get_string_constant_value(arg_str); + TRACE("t_str_rw", tout << "build NFA for '" << str << "'" << std::endl;); + + // TODO this assumes the string is not empty + /* + * For an n-character string, we make (n-1) intermediate states, + * labelled i_(0) through i_(n-2). + * Then we construct the following transitions: + * start --str[0]--> i_(0) --str[1]--> i_(1) --...--> i_(n-2) --str[n-1]--> final + */ + unsigned last = start; + for (unsigned i = 0; i <= str.length() - 2; ++i) { + unsigned i_state = next_id(); + make_transition(last, str.at(i), i_state); + TRACE("t_str_rw", tout << "string transition " << last << "--" << str.at(i) << "--> " << i_state << std::endl;); + last = i_state; + } + make_transition(last, str.at(str.length() - 1), end); + TRACE("t_str_rw", tout << "string transition " << last << "--" << str.at(str.length() - 1) << "--> " << end << std::endl;); + TRACE("t_str_rw", tout << "string NFA: start = " << start << ", end = " << end << std::endl;); + } else { + TRACE("t_str_rw", tout << "invalid string constant in Str2Reg" << std::endl;); + m_valid = false; + return; + } + } else if (m_strutil.is_re_RegexConcat(e)){ + app * a = to_app(e); + expr * re1 = a->get_arg(0); + expr * re2 = a->get_arg(1); + unsigned start1, end1; + convert_re(re1, start1, end1); + unsigned start2, end2; + convert_re(re2, start2, end2); + // start --e--> start1 --...--> end1 --e--> start2 --...--> end2 --e--> end + make_epsilon_move(start, start1); + make_epsilon_move(end1, start2); + make_epsilon_move(end2, end); + TRACE("t_str_rw", tout << "concat NFA: start = " << start << ", end = " << end << std::endl;); + } else if (m_strutil.is_re_RegexUnion(e)) { + app * a = to_app(e); + expr * re1 = a->get_arg(0); + expr * re2 = a->get_arg(1); + unsigned start1, end1; + convert_re(re1, start1, end1); + unsigned start2, end2; + convert_re(re2, start2, end2); + + // start --e--> start1 ; start --e--> start2 + // end1 --e--> end ; end2 --e--> end + make_epsilon_move(start, start1); + make_epsilon_move(start, start2); + make_epsilon_move(end1, end); + make_epsilon_move(end2, end); + TRACE("t_str_rw", tout << "union NFA: start = " << start << ", end = " << end << std::endl;); + } else if (m_strutil.is_re_RegexStar(e)) { + app * a = to_app(e); + expr * subex = a->get_arg(0); + unsigned start_subex, end_subex; + convert_re(subex, start_subex, end_subex); + // start --e--> start_subex, start --e--> end + // end_subex --e--> start_subex, end_subex --e--> end + make_epsilon_move(start, start_subex); + make_epsilon_move(start, end); + make_epsilon_move(end_subex, start_subex); + make_epsilon_move(end_subex, end); + TRACE("t_str_rw", tout << "star NFA: start = " << start << ", end = " << end << std::endl;); + } else { + TRACE("t_str_rw", tout << "invalid regular expression" << std::endl;); + m_valid = false; + return; + } + } + +public: + nfa(str_util & m_strutil, expr * e) +: m_strutil(m_strutil), + m_valid(true), m_next_id(0), m_start_state(0), m_end_state(0) { + convert_re(e, m_start_state, m_end_state); + } + + bool is_valid() const { + return m_valid; + } + + void epsilon_closure(unsigned start, std::set & closure) { + std::deque worklist; + closure.insert(start); + worklist.push_back(start); + + while(!worklist.empty()) { + unsigned state = worklist.front(); + worklist.pop_front(); + if (epsilon_map.find(state) != epsilon_map.end()) { + for (std::set::iterator it = epsilon_map[state].begin(); + it != epsilon_map[state].end(); ++it) { + unsigned new_state = *it; + if (closure.find(new_state) == closure.end()) { + closure.insert(new_state); + worklist.push_back(new_state); + } + } + } + } + } + + bool matches(std::string input) { + /* + * Keep a set of all states the NFA can currently be in. + * Initially this is the e-closure of m_start_state + * For each character A in the input string, + * the set of next states contains + * all states in transition_map[S][A] for each S in current_states, + * and all states in epsilon_map[S] for each S in current_states. + * After consuming the entire input string, + * the match is successful iff current_states contains m_end_state. + */ + std::set current_states; + epsilon_closure(m_start_state, current_states); + for (unsigned i = 0; i < input.length(); ++i) { + char A = input.at(i); + std::set next_states; + for (std::set::iterator it = current_states.begin(); + it != current_states.end(); ++it) { + unsigned S = *it; + // check transition_map + if (transition_map[S].find(A) != transition_map[S].end()) { + next_states.insert(transition_map[S][A]); + } + } + + // take e-closure over next_states to compute the actual next_states + std::set epsilon_next_states; + for (std::set::iterator it = next_states.begin(); it != next_states.end(); ++it) { + unsigned S = *it; + std::set closure; + epsilon_closure(S, closure); + epsilon_next_states.insert(closure.begin(), closure.end()); + } + current_states = epsilon_next_states; + } + if (current_states.find(m_end_state) != current_states.end()) { + return true; + } else { + return false; + } + } +}; br_status str_rewriter::mk_str_Concat(expr * arg0, expr * arg1, expr_ref & result) { TRACE("t_str_rw", tout << "rewrite (Concat " << mk_pp(arg0, m()) << " " << mk_pp(arg1, m()) << ")" << std::endl;); @@ -243,6 +429,20 @@ br_status str_rewriter::mk_re_RegexIn(expr * str, expr * re, expr_ref & result) return BR_REWRITE_FULL; } + // necessary for model validation + if (m_strutil.is_string(str)) { + TRACE("t_str_rw", tout << "RegexIn with constant string argument" << std::endl;); + nfa regex_nfa(m_strutil, re); + ENSURE(regex_nfa.is_valid()); + std::string input = m_strutil.get_string_constant_value(str); + if (regex_nfa.matches(input)) { + result = m().mk_true(); + } else { + result = m().mk_false(); + } + return BR_DONE; + } + return BR_FAILED; } diff --git a/src/ast/str_decl_plugin.h b/src/ast/str_decl_plugin.h index 4b7a8858e..5b0ca2a3a 100644 --- a/src/ast/str_decl_plugin.h +++ b/src/ast/str_decl_plugin.h @@ -111,7 +111,6 @@ public: virtual bool is_value(app * e) const; virtual bool is_unique_value(app * e) const { return is_value(e); } - // TODO }; class str_recognizers { @@ -125,11 +124,12 @@ public: bool is_string(expr const * n) const; bool is_re_Str2Reg(expr const * n) const { return is_app_of(n, get_fid(), OP_RE_STR2REGEX); } + bool is_re_RegexConcat(expr const * n) const { return is_app_of(n, get_fid(), OP_RE_REGEXCONCAT); } + bool is_re_RegexUnion(expr const * n) const { return is_app_of(n, get_fid(), OP_RE_REGEXUNION); } bool is_re_RegexStar(expr const * n) const { return is_app_of(n, get_fid(), OP_RE_REGEXSTAR); } bool is_re_RegexPlus(expr const * n) const { return is_app_of(n, get_fid(), OP_RE_REGEXPLUS); } std::string get_string_constant_value(expr const *n) const; - // TODO }; class str_util : public str_recognizers {