From 6263391c11278ca6653d61f9cc059b9b9232b4e5 Mon Sep 17 00:00:00 2001 From: Murphy Berzish Date: Wed, 17 Aug 2016 20:58:57 -0400 Subject: [PATCH] fix out-of-range integer comparison bug in string NFA --- src/ast/rewriter/str_rewriter.cpp | 290 +++++++++++++----------------- src/ast/rewriter/str_rewriter.h | 47 +++++ src/smt/theory_str.cpp | 59 +++++- src/smt/theory_str.h | 12 ++ 4 files changed, 243 insertions(+), 165 deletions(-) diff --git a/src/ast/rewriter/str_rewriter.cpp b/src/ast/rewriter/str_rewriter.cpp index fe434575e..c644ecd46 100644 --- a/src/ast/rewriter/str_rewriter.cpp +++ b/src/ast/rewriter/str_rewriter.cpp @@ -26,188 +26,150 @@ Notes: #include #include -class nfa { -protected: - str_util & m_strutil; +// Convert a regular expression to an e-NFA using Thompson's construction +void nfa::convert_re(expr * e, unsigned & start, unsigned & end, str_util & m_strutil) { + start = next_id(); + end = next_id(); + if (m_strutil.is_re_Str2Reg(e)) { + app * a = to_app(e); + expr * arg_str = a->get_arg(0); + if (m_strutil.is_string(arg_str)) { + std::string str = m_strutil.get_string_constant_value(arg_str); + TRACE("t_str_rw", tout << "build NFA for '" << str << "'" << std::endl;); - bool m_valid; - unsigned m_next_id; - - unsigned next_id() { - unsigned retval = m_next_id; - ++m_next_id; - return retval; - } - - unsigned m_start_state; - unsigned m_end_state; - - std::map > transition_map; - std::map > epsilon_map; - - void make_transition(unsigned start, char symbol, unsigned end) { - transition_map[start][symbol] = end; - } - - void make_epsilon_move(unsigned start, unsigned end) { - epsilon_map[start].insert(end); - } - - // Convert a regular expression to an e-NFA using Thompson's construction - void convert_re(expr * e, unsigned & start, unsigned & end) { - start = next_id(); - end = next_id(); - if (m_strutil.is_re_Str2Reg(e)) { - app * a = to_app(e); - expr * arg_str = a->get_arg(0); - if (m_strutil.is_string(arg_str)) { - std::string str = m_strutil.get_string_constant_value(arg_str); - TRACE("t_str_rw", tout << "build NFA for '" << str << "'" << std::endl;); - - // TODO this assumes the string is not empty - /* - * For an n-character string, we make (n-1) intermediate states, - * labelled i_(0) through i_(n-2). - * Then we construct the following transitions: - * start --str[0]--> i_(0) --str[1]--> i_(1) --...--> i_(n-2) --str[n-1]--> final - */ - unsigned last = start; - for (unsigned i = 0; i <= str.length() - 2; ++i) { - unsigned i_state = next_id(); - make_transition(last, str.at(i), i_state); - TRACE("t_str_rw", tout << "string transition " << last << "--" << str.at(i) << "--> " << i_state << std::endl;); - last = i_state; - } - make_transition(last, str.at(str.length() - 1), end); - TRACE("t_str_rw", tout << "string transition " << last << "--" << str.at(str.length() - 1) << "--> " << end << std::endl;); - TRACE("t_str_rw", tout << "string NFA: start = " << start << ", end = " << end << std::endl;); - } else { - TRACE("t_str_rw", tout << "invalid string constant in Str2Reg" << std::endl;); - m_valid = false; - return; + // TODO this assumes the string is not empty + /* + * For an n-character string, we make (n-1) intermediate states, + * labelled i_(0) through i_(n-2). + * Then we construct the following transitions: + * start --str[0]--> i_(0) --str[1]--> i_(1) --...--> i_(n-2) --str[n-1]--> final + */ + unsigned last = start; + for (int i = 0; i <= ((int)str.length()) - 2; ++i) { + unsigned i_state = next_id(); + make_transition(last, str.at(i), i_state); + TRACE("t_str_rw", tout << "string transition " << last << "--" << str.at(i) << "--> " << i_state << std::endl;); + last = i_state; } - } else if (m_strutil.is_re_RegexConcat(e)){ - app * a = to_app(e); - expr * re1 = a->get_arg(0); - expr * re2 = a->get_arg(1); - unsigned start1, end1; - convert_re(re1, start1, end1); - unsigned start2, end2; - convert_re(re2, start2, end2); - // start --e--> start1 --...--> end1 --e--> start2 --...--> end2 --e--> end - make_epsilon_move(start, start1); - make_epsilon_move(end1, start2); - make_epsilon_move(end2, end); - TRACE("t_str_rw", tout << "concat NFA: start = " << start << ", end = " << end << std::endl;); - } else if (m_strutil.is_re_RegexUnion(e)) { - app * a = to_app(e); - expr * re1 = a->get_arg(0); - expr * re2 = a->get_arg(1); - unsigned start1, end1; - convert_re(re1, start1, end1); - unsigned start2, end2; - convert_re(re2, start2, end2); - - // start --e--> start1 ; start --e--> start2 - // end1 --e--> end ; end2 --e--> end - make_epsilon_move(start, start1); - make_epsilon_move(start, start2); - make_epsilon_move(end1, end); - make_epsilon_move(end2, end); - TRACE("t_str_rw", tout << "union NFA: start = " << start << ", end = " << end << std::endl;); - } else if (m_strutil.is_re_RegexStar(e)) { - app * a = to_app(e); - expr * subex = a->get_arg(0); - unsigned start_subex, end_subex; - convert_re(subex, start_subex, end_subex); - // start --e--> start_subex, start --e--> end - // end_subex --e--> start_subex, end_subex --e--> end - make_epsilon_move(start, start_subex); - make_epsilon_move(start, end); - make_epsilon_move(end_subex, start_subex); - make_epsilon_move(end_subex, end); - TRACE("t_str_rw", tout << "star NFA: start = " << start << ", end = " << end << std::endl;); + make_transition(last, str.at(str.length() - 1), end); + TRACE("t_str_rw", tout << "string transition " << last << "--" << str.at(str.length() - 1) << "--> " << end << std::endl;); + TRACE("t_str_rw", tout << "string NFA: start = " << start << ", end = " << end << std::endl;); } else { - TRACE("t_str_rw", tout << "invalid regular expression" << std::endl;); + TRACE("t_str_rw", tout << "invalid string constant in Str2Reg" << std::endl;); m_valid = false; return; } + } else if (m_strutil.is_re_RegexConcat(e)){ + app * a = to_app(e); + expr * re1 = a->get_arg(0); + expr * re2 = a->get_arg(1); + unsigned start1, end1; + convert_re(re1, start1, end1, m_strutil); + unsigned start2, end2; + convert_re(re2, start2, end2, m_strutil); + // start --e--> start1 --...--> end1 --e--> start2 --...--> end2 --e--> end + make_epsilon_move(start, start1); + make_epsilon_move(end1, start2); + make_epsilon_move(end2, end); + TRACE("t_str_rw", tout << "concat NFA: start = " << start << ", end = " << end << std::endl;); + } else if (m_strutil.is_re_RegexUnion(e)) { + app * a = to_app(e); + expr * re1 = a->get_arg(0); + expr * re2 = a->get_arg(1); + unsigned start1, end1; + convert_re(re1, start1, end1, m_strutil); + unsigned start2, end2; + convert_re(re2, start2, end2, m_strutil); + + // start --e--> start1 ; start --e--> start2 + // end1 --e--> end ; end2 --e--> end + make_epsilon_move(start, start1); + make_epsilon_move(start, start2); + make_epsilon_move(end1, end); + make_epsilon_move(end2, end); + TRACE("t_str_rw", tout << "union NFA: start = " << start << ", end = " << end << std::endl;); + } else if (m_strutil.is_re_RegexStar(e)) { + app * a = to_app(e); + expr * subex = a->get_arg(0); + unsigned start_subex, end_subex; + convert_re(subex, start_subex, end_subex, m_strutil); + // start --e--> start_subex, start --e--> end + // end_subex --e--> start_subex, end_subex --e--> end + make_epsilon_move(start, start_subex); + make_epsilon_move(start, end); + make_epsilon_move(end_subex, start_subex); + make_epsilon_move(end_subex, end); + TRACE("t_str_rw", tout << "star NFA: start = " << start << ", end = " << end << std::endl;); + } else { + TRACE("t_str_rw", tout << "invalid regular expression" << std::endl;); + m_valid = false; + return; } +} -public: - nfa(str_util & m_strutil, expr * e) -: m_strutil(m_strutil), - m_valid(true), m_next_id(0), m_start_state(0), m_end_state(0) { - convert_re(e, m_start_state, m_end_state); - } +void nfa::epsilon_closure(unsigned start, std::set & closure) { + std::deque worklist; + closure.insert(start); + worklist.push_back(start); - bool is_valid() const { - return m_valid; - } - - void epsilon_closure(unsigned start, std::set & closure) { - std::deque worklist; - closure.insert(start); - worklist.push_back(start); - - while(!worklist.empty()) { - unsigned state = worklist.front(); - worklist.pop_front(); - if (epsilon_map.find(state) != epsilon_map.end()) { - for (std::set::iterator it = epsilon_map[state].begin(); - it != epsilon_map[state].end(); ++it) { - unsigned new_state = *it; - if (closure.find(new_state) == closure.end()) { - closure.insert(new_state); - worklist.push_back(new_state); - } + while(!worklist.empty()) { + unsigned state = worklist.front(); + worklist.pop_front(); + if (epsilon_map.find(state) != epsilon_map.end()) { + for (std::set::iterator it = epsilon_map[state].begin(); + it != epsilon_map[state].end(); ++it) { + unsigned new_state = *it; + if (closure.find(new_state) == closure.end()) { + closure.insert(new_state); + worklist.push_back(new_state); } } } } +} - bool matches(std::string input) { - /* - * Keep a set of all states the NFA can currently be in. - * Initially this is the e-closure of m_start_state - * For each character A in the input string, - * the set of next states contains - * all states in transition_map[S][A] for each S in current_states, - * and all states in epsilon_map[S] for each S in current_states. - * After consuming the entire input string, - * the match is successful iff current_states contains m_end_state. - */ - std::set current_states; - epsilon_closure(m_start_state, current_states); - for (unsigned i = 0; i < input.length(); ++i) { - char A = input.at(i); - std::set next_states; - for (std::set::iterator it = current_states.begin(); - it != current_states.end(); ++it) { - unsigned S = *it; - // check transition_map - if (transition_map[S].find(A) != transition_map[S].end()) { - next_states.insert(transition_map[S][A]); - } +bool nfa::matches(std::string input) { + /* + * Keep a set of all states the NFA can currently be in. + * Initially this is the e-closure of m_start_state + * For each character A in the input string, + * the set of next states contains + * all states in transition_map[S][A] for each S in current_states, + * and all states in epsilon_map[S] for each S in current_states. + * After consuming the entire input string, + * the match is successful iff current_states contains m_end_state. + */ + std::set current_states; + epsilon_closure(m_start_state, current_states); + for (unsigned i = 0; i < input.length(); ++i) { + char A = input.at(i); + std::set next_states; + for (std::set::iterator it = current_states.begin(); + it != current_states.end(); ++it) { + unsigned S = *it; + // check transition_map + if (transition_map[S].find(A) != transition_map[S].end()) { + next_states.insert(transition_map[S][A]); } + } - // take e-closure over next_states to compute the actual next_states - std::set epsilon_next_states; - for (std::set::iterator it = next_states.begin(); it != next_states.end(); ++it) { - unsigned S = *it; - std::set closure; - epsilon_closure(S, closure); - epsilon_next_states.insert(closure.begin(), closure.end()); - } - current_states = epsilon_next_states; - } - if (current_states.find(m_end_state) != current_states.end()) { - return true; - } else { - return false; + // take e-closure over next_states to compute the actual next_states + std::set epsilon_next_states; + for (std::set::iterator it = next_states.begin(); it != next_states.end(); ++it) { + unsigned S = *it; + std::set closure; + epsilon_closure(S, closure); + epsilon_next_states.insert(closure.begin(), closure.end()); } + current_states = epsilon_next_states; } -}; + if (current_states.find(m_end_state) != current_states.end()) { + return true; + } else { + return false; + } +} + br_status str_rewriter::mk_str_Concat(expr * arg0, expr * arg1, expr_ref & result) { TRACE("t_str_rw", tout << "rewrite (Concat " << mk_pp(arg0, m()) << " " << mk_pp(arg1, m()) << ")" << std::endl;); diff --git a/src/ast/rewriter/str_rewriter.h b/src/ast/rewriter/str_rewriter.h index 862fc3e7e..c64d086f9 100644 --- a/src/ast/rewriter/str_rewriter.h +++ b/src/ast/rewriter/str_rewriter.h @@ -21,6 +21,8 @@ Notes: #include"arith_decl_plugin.h" #include"rewriter_types.h" #include"params.h" +#include +#include class str_rewriter { str_util m_strutil; @@ -61,3 +63,48 @@ public: bool reduce_eq(expr_ref_vector& ls, expr_ref_vector& rs, expr_ref_vector& lhs, expr_ref_vector& rhs, bool& change); }; + +class nfa { +protected: + bool m_valid; + unsigned m_next_id; + + unsigned next_id() { + unsigned retval = m_next_id; + ++m_next_id; + return retval; + } + + unsigned m_start_state; + unsigned m_end_state; + + std::map > transition_map; + std::map > epsilon_map; + + void make_transition(unsigned start, char symbol, unsigned end) { + transition_map[start][symbol] = end; + } + + void make_epsilon_move(unsigned start, unsigned end) { + epsilon_map[start].insert(end); + } + + // Convert a regular expression to an e-NFA using Thompson's construction + void convert_re(expr * e, unsigned & start, unsigned & end, str_util & m_strutil); + +public: + nfa(str_util & m_strutil, expr * e) +: m_valid(true), m_next_id(0), m_start_state(0), m_end_state(0) { + convert_re(e, m_start_state, m_end_state, m_strutil); + } + + nfa() : m_valid(false), m_next_id(0), m_start_state(0), m_end_state(0) {} + + bool is_valid() const { + return m_valid; + } + + void epsilon_closure(unsigned start, std::set & closure); + + bool matches(std::string input); +}; diff --git a/src/smt/theory_str.cpp b/src/smt/theory_str.cpp index 087bf6ad0..d249649c7 100644 --- a/src/smt/theory_str.cpp +++ b/src/smt/theory_str.cpp @@ -35,6 +35,7 @@ theory_str::theory_str(ast_manager & m): opt_LCMUnrollStep(2), opt_NoQuickReturn_IntegerTheory(false), opt_DisableIntegerTheoryIntegration(false), + opt_NoCheckRegexIn(false), /* Internal setup */ search_started(false), m_autil(m), @@ -1643,7 +1644,14 @@ bool theory_str::new_eq_check(expr * lhs, expr * rhs) { check_contain_in_new_eq(lhs, rhs); } - // TODO regexInBoolMap + if (!regex_in_bool_map.empty()) { + if (opt_NoCheckRegexIn) { + TRACE("t_str", tout << "WARNING: skipping check_regex_in()" << std::endl;); + } else { + TRACE("t_str", tout << "checking regex consistency" << std::endl;); + check_regex_in(lhs, rhs); + } + } // okay, all checks here passed return true; @@ -5213,6 +5221,55 @@ void theory_str::check_concat_len_in_eqc(expr * concat) { } while (eqc_it != eqc_base); } +void theory_str::check_regex_in(expr * nn1, expr * nn2) { + context & ctx = get_context(); + ast_manager & m = get_manager(); + + expr_ref_vector eqNodeSet(m); + expr * constStr = collect_eq_nodes(nn1, eqNodeSet); + + if (constStr == NULL) { + return; + } else { + expr_ref_vector::iterator itor = eqNodeSet.begin(); + for (; itor != eqNodeSet.end(); itor++) { + if (regex_in_var_reg_str_map.find(*itor) != regex_in_var_reg_str_map.end()) { + std::set::iterator strItor = regex_in_var_reg_str_map[*itor].begin(); + for (; strItor != regex_in_var_reg_str_map[*itor].end(); strItor++) { + std::string regStr = *strItor; + std::string constStrValue = m_strutil.get_string_constant_value(constStr); + std::pair key1 = std::make_pair(*itor, regStr); + if (regex_in_bool_map.find(key1) != regex_in_bool_map.end()) { + expr * boolVar = regex_in_bool_map[key1]; // actually the RegexIn term + app * a_regexIn = to_app(boolVar); + expr * regexTerm = a_regexIn->get_arg(1); + + if (regex_nfa_cache.find(regexTerm) == regex_nfa_cache.end()) { + TRACE("t_str_detail", tout << "regex_nfa_cache: cache miss" << std::endl;); + regex_nfa_cache[regexTerm] = nfa(m_strutil, regexTerm); + } else { + TRACE("t_str_detail", tout << "regex_nfa_cache: cache hit" << std::endl;); + } + + nfa regexNFA = regex_nfa_cache[regexTerm]; + ENSURE(regexNFA.is_valid()); + bool matchRes = regexNFA.matches(constStrValue); + + TRACE("t_str_detail", tout << mk_pp(*itor, m) << " in " << regStr << " : " << (matchRes ? "yes" : "no") << std::endl;); + + expr_ref implyL(ctx.mk_eq_atom(*itor, constStr), m); + if (matchRes) { + assert_implication(implyL, boolVar); + } else { + assert_implication(implyL, m.mk_not(boolVar)); + } + } + } + } + } + } +} + /* * strArgmt::solve_concat_eq_str() * Solve concatenations of the form: diff --git a/src/smt/theory_str.h b/src/smt/theory_str.h index 06a72c3e2..8acdb4f02 100644 --- a/src/smt/theory_str.h +++ b/src/smt/theory_str.h @@ -25,6 +25,7 @@ Revision History: #include"arith_decl_plugin.h" #include #include +#include"str_rewriter.h" namespace smt { @@ -137,6 +138,14 @@ namespace smt { */ bool opt_DisableIntegerTheoryIntegration; + /* + * If NoCheckRegexIn is set to true, + * an expensive regular expression membership test is skipped. + * This option is for experiment purposes only and should be set to 'false' + * as skipping this check impacts the correctness of the solver. + */ + bool opt_NoCheckRegexIn; + bool search_started; arith_util m_autil; str_util m_strutil; @@ -221,6 +230,8 @@ namespace smt { std::map, expr*> regex_in_bool_map; std::map > regex_in_var_reg_str_map; + std::map regex_nfa_cache; // Regex term --> NFA + char * char_set; std::map charSetLookupTable; int charSetSize; @@ -423,6 +434,7 @@ namespace smt { expr * gen_unroll_assign(expr * var, std::string lcmStr, expr * testerVar, int l, int h); void reduce_virtual_regex_in(expr * var, expr * regex, expr_ref_vector & items); std::string get_std_regex_str(expr * regex); + void check_regex_in(expr * nn1, expr * nn2); void dump_assignments(); void initialize_charset();