# Patch summary (text before the first "diff --git" header is ignored by patch tools)
#
# Purpose: introduce a regex "Unroll" operator to the string theory.
#   * str_decl_plugin.{h,cpp}: declare OP_RE_UNROLL and register the function
#     symbol "Unroll" with signature (re, i) -> s, reference-counted like the
#     other regex declarations (inc_ref in set_manager, DEC_REF in finalize).
#   * theory_str.{h,cpp}:
#       - mk_unroll_bound_var(): fresh integer variable named "unroll".
#       - mk_unroll(n, bound): builds Unroll(n, bound) and asserts
#           (bound = 0) <=> (Unroll = ""),  bound >= 0,  len(Unroll) >= 0.
#       - instantiate_axiom_RegexIn(): the RegexStar case no longer hits
#         NOT_IMPLEMENTED_YET(); it asserts  expr <=> (str = Unroll(regex1, cnt)).
#       - get_eqc_all_unroll(): walks the equivalence class of n, collecting a
#         string constant (if any) and every Unroll term.
#       - process_unroll_eq_const_str() / unroll_str2reg_constStr(): react to
#         an Unroll term being equal to a string constant.
#       - handle_equality(): calls the above for both sides of the equality.
#
# Reconstruction notes:
#   * The patch had been collapsed onto a few long lines; it is restored here
#     to one patch line per line. Hunk line counts and offsets were re-checked
#     against the "@@" headers and are consistent.
#   * Indentation and the exact position of blank context lines are
#     reconstructed, not original. Verify against the tree, or apply with
#     whitespace-insensitive matching.
#   * Stripped template arguments were restored: std::set<expr*>,
#     std::set<Z3_ast> (commented-out legacy code only), svector<builtin_name>.
#
# Review notes (behaviour left exactly as in the patch):
#   * unroll_str2reg_constStr() ends in NOT_IMPLEMENTED_YET(); handle_equality()
#     reaches it whenever an Unroll(Str2Reg(..)) term shares an equivalence
#     class with a non-empty string constant.
#   * The RegexStar branch re-asserts (cnt = 0) <=> (Unroll = ""), which
#     mk_unroll() has already asserted for the same term.
#   * 'ctx' is unused in process_unroll_eq_const_str(); every local in
#     unroll_str2reg_constStr() is unused until the TODO is implemented.
#
diff --git a/src/ast/str_decl_plugin.cpp b/src/ast/str_decl_plugin.cpp
index b140e11c3..ef94272c7 100644
--- a/src/ast/str_decl_plugin.cpp
+++ b/src/ast/str_decl_plugin.cpp
@@ -41,6 +41,7 @@ str_decl_plugin::str_decl_plugin():
     m_re_regexconcat_decl(0),
     m_re_regexstar_decl(0),
     m_re_regexunion_decl(0),
+    m_re_unroll_decl(0),
     m_arith_plugin(0),
     m_arith_fid(0),
     m_int_sort(0){
@@ -69,6 +70,7 @@ void str_decl_plugin::finalize(void) {
     DEC_REF(m_re_regexconcat_decl);
     DEC_REF(m_re_regexstar_decl);
     DEC_REF(m_re_regexunion_decl);
+    DEC_REF(m_re_unroll_decl);
     DEC_REF(m_int_sort);
 }
 
@@ -154,6 +156,9 @@ void str_decl_plugin::set_manager(ast_manager * m, family_id id) {
     m_re_regexunion_decl = m->mk_func_decl(symbol("RegexUnion"), re, re, re, func_decl_info(id, OP_RE_REGEXUNION));
     m_manager->inc_ref(m_re_regexunion_decl);
 
+    m_re_unroll_decl = m->mk_func_decl(symbol("Unroll"), re, i, s, func_decl_info(id, OP_RE_UNROLL));
+    m_manager->inc_ref(m_re_unroll_decl);
+
 }
 
 decl_plugin * str_decl_plugin::mk_fresh() {
@@ -186,6 +191,7 @@ func_decl * str_decl_plugin::mk_func_decl(decl_kind k) {
     case OP_RE_REGEXCONCAT: return m_re_regexconcat_decl;
     case OP_RE_REGEXSTAR: return m_re_regexstar_decl;
     case OP_RE_REGEXUNION: return m_re_regexunion_decl;
+    case OP_RE_UNROLL: return m_re_unroll_decl;
     default: return 0;
     }
 }
@@ -256,6 +262,7 @@ void str_decl_plugin::get_op_names(svector<builtin_name> & op_names, symbol cons
     op_names.push_back(builtin_name("RegexConcat", OP_RE_REGEXCONCAT));
     op_names.push_back(builtin_name("RegexStar", OP_RE_REGEXSTAR));
     op_names.push_back(builtin_name("RegexUnion", OP_RE_REGEXUNION));
+    op_names.push_back(builtin_name("Unroll", OP_RE_UNROLL));
 }
 
 void str_decl_plugin::get_sort_names(svector<builtin_name> & sort_names, symbol const & logic) {
diff --git a/src/ast/str_decl_plugin.h b/src/ast/str_decl_plugin.h
index ccd2915af..c2ad088a4 100644
--- a/src/ast/str_decl_plugin.h
+++ b/src/ast/str_decl_plugin.h
@@ -47,6 +47,7 @@ enum str_op_kind {
     OP_RE_REGEXCONCAT,
     OP_RE_REGEXSTAR,
     OP_RE_REGEXUNION,
+    OP_RE_UNROLL,
     // end
     LAST_STR_OP
 };
@@ -75,6 +76,7 @@ protected:
     func_decl * m_re_regexconcat_decl;
     func_decl * m_re_regexstar_decl;
     func_decl * m_re_regexunion_decl;
+    func_decl * m_re_unroll_decl;
 
     arith_decl_plugin * m_arith_plugin;
     family_id           m_arith_fid;
diff --git a/src/smt/theory_str.cpp b/src/smt/theory_str.cpp
index 46248abd2..947c35f98 100644
--- a/src/smt/theory_str.cpp
+++ b/src/smt/theory_str.cpp
@@ -454,6 +454,10 @@ app * theory_str::mk_int_var(std::string name) {
     return a;
 }
 
+app * theory_str::mk_unroll_bound_var() {
+    return mk_int_var("unroll");
+}
+
 app * theory_str::mk_str_var(std::string name) {
     context & ctx = get_context();
     ast_manager & m = get_manager();
@@ -545,6 +549,24 @@ app * theory_str::mk_nonempty_str_var() {
     return a;
 }
 
+app * theory_str::mk_unroll(expr * n, expr * bound) {
+    context & ctx = get_context();
+    ast_manager & m = get_manager();
+
+    expr * args[2] = {n, bound};
+    app * unrollFunc = get_manager().mk_app(get_id(), OP_RE_UNROLL, 0, 0, 2, args);
+
+    expr_ref_vector items(m);
+    items.push_back(ctx.mk_eq_atom(ctx.mk_eq_atom(bound, mk_int(0)), ctx.mk_eq_atom(unrollFunc, m_strutil.mk_string(""))));
+    items.push_back(m_autil.mk_ge(bound, mk_int(0)));
+    items.push_back(m_autil.mk_ge(mk_strlen(unrollFunc), mk_int(0)));
+
+    expr_ref finalAxiom(mk_and(items), m);
+    SASSERT(finalAxiom);
+    assert_axiom(finalAxiom);
+    return unrollFunc;
+}
+
 app * theory_str::mk_contains(expr * haystack, expr * needle) {
     expr * args[2] = {haystack, needle};
     app * contains = get_manager().mk_app(get_id(), OP_STR_CONTAINS, 0, 0, 2, args);
@@ -1342,7 +1364,16 @@ void theory_str::instantiate_axiom_RegexIn(enode * e) {
         items.push_back(ctx.mk_eq_atom(expr, orVar));
         assert_axiom(mk_and(items));
     } else if (is_RegexStar(regex)) {
-        NOT_IMPLEMENTED_YET();
+        // slightly more complex due to the unrolling step.
+        expr_ref regex1(regex->get_arg(0), m);
+        expr_ref unrollCount(mk_unroll_bound_var(), m);
+        expr_ref unrollFunc(mk_unroll(regex1, unrollCount), m);
+        expr_ref_vector items(m);
+        items.push_back(ctx.mk_eq_atom(expr, ctx.mk_eq_atom(str, unrollFunc)));
+        items.push_back(ctx.mk_eq_atom(ctx.mk_eq_atom(unrollCount, mk_int(0)), ctx.mk_eq_atom(unrollFunc, m_strutil.mk_string(""))));
+        expr_ref finalAxiom(mk_and(items), m);
+        SASSERT(finalAxiom);
+        assert_axiom(finalAxiom);
     } else {
         TRACE("t_str_detail", tout << "ERROR: unknown regex expression " << mk_pp(regex, m) << "!" << std::endl;);
         NOT_IMPLEMENTED_YET();
@@ -3368,6 +3399,63 @@ void theory_str::process_concat_eq_type6(expr * concatAst1, expr * concatAst2) {
     assert_implication(ctx.mk_eq_atom(concatAst1, concatAst2), implyR);
 }
 
+void theory_str::process_unroll_eq_const_str(expr * unrollFunc, expr * constStr) {
+    context & ctx = get_context();
+    ast_manager & m = get_manager();
+
+    if (!is_Unroll(to_app(unrollFunc))) {
+        return;
+    }
+    if (!m_strutil.is_string(constStr)) {
+        return;
+    }
+
+    expr * funcInUnroll = to_app(unrollFunc)->get_arg(0);
+    std::string strValue = m_strutil.get_string_constant_value(constStr);
+
+    TRACE("t_str_detail", tout << "unrollFunc: " << mk_pp(unrollFunc, m) << std::endl
+            << "constStr: " << mk_pp(constStr, m) << std::endl;);
+
+    if (strValue == "") {
+        return;
+    }
+
+    if (is_Str2Reg(to_app(funcInUnroll))) {
+        unroll_str2reg_constStr(unrollFunc, constStr);
+        return;
+    }
+}
+
+void theory_str::unroll_str2reg_constStr(expr * unrollFunc, expr * eqConstStr) {
+    context & ctx = get_context();
+    expr * str2RegFunc = to_app(unrollFunc)->get_arg(0);
+    expr * strInStr2RegFunc = to_app(str2RegFunc)->get_arg(0);
+    expr * oriCnt = to_app(unrollFunc)->get_arg(1);
+
+    // TODO NEXT
+    NOT_IMPLEMENTED_YET();
+
+    /*
+    Z3_context ctx = Z3_theory_get_context(t);
+    Z3_ast str2RegFunc = Z3_get_app_arg(ctx, Z3_to_app(ctx, unrollFunc), 0);
+    Z3_ast strInStr2RegFunc = Z3_get_app_arg(ctx, Z3_to_app(ctx, str2RegFunc), 0);
+    Z3_ast oriCnt = Z3_get_app_arg(ctx, Z3_to_app(ctx, unrollFunc), 1);
+
+    std::string strValue = getConstStrValue(t, eqConstStr);
+    std::string regStrValue = getConstStrValue(t, strInStr2RegFunc);
+    int strLen = strValue.length();
+    int regStrLen = regStrValue.length();
+    int cnt = strLen / regStrLen;
+
+    Z3_ast implyL = Z3_mk_eq(ctx, unrollFunc, eqConstStr);
+    Z3_ast implyR1 = Z3_mk_eq(ctx, oriCnt, mk_int(ctx, cnt));
+    Z3_ast implyR2 = Z3_mk_eq(ctx, mk_length(t, unrollFunc), mk_int(ctx, strLen));
+    Z3_ast toAssert = Z3_mk_implies(ctx, implyL, mk_2_and(t, implyR1, implyR2));
+
+    addAxiom(t, toAssert, __LINE__);
+    */
+}
+
 /*
  * Look through the equivalence class of n to find a string constant.
  * Return that constant if it is found, and set hasEqcValue to true.
@@ -3392,6 +3480,26 @@ expr * theory_str::get_eqc_value(expr * n, bool & hasEqcValue) {
     return n;
 }
 
+void theory_str::get_eqc_all_unroll(expr * n, expr * & constStr, std::set<expr*> & unrollFuncSet) {
+    context & ctx = get_context();
+
+    constStr = NULL;
+    unrollFuncSet.clear();
+
+    // iterate over the eqc of 'n'
+    enode * n_enode = ctx.get_enode(n);
+    enode * e_curr = n_enode;
+    do {
+        app * curr = e_curr->get_owner();
+        if (m_strutil.is_string(curr)) {
+            constStr = curr;
+        } else if (is_Unroll(curr)) {
+            unrollFuncSet.insert(curr);
+        }
+        e_curr = e_curr->get_next();
+    } while (e_curr != n_enode);
+}
+
 // from Z3: theory_seq.cpp
 
 static theory_mi_arith* get_th_arith(context& ctx, theory_id afid, expr* e) {
@@ -4198,7 +4306,45 @@ void theory_str::handle_equality(expr * lhs, expr * rhs) {
         simplify_parent(lhs, rhs_value);
     }
 
-    // TODO regex unroll? (much later)
+    // regex unroll
+    /*
+    Z3_ast nn1EqConst = NULL;
+    std::set<Z3_ast> nn1EqUnrollFuncs;
+    get_eqc_allUnroll(t, nn1, nn1EqConst, nn1EqUnrollFuncs);
+    Z3_ast nn2EqConst = NULL;
+    std::set<Z3_ast> nn2EqUnrollFuncs;
+    get_eqc_allUnroll(t, nn2, nn2EqConst, nn2EqUnrollFuncs);
+
+    if (nn2EqConst != NULL) {
+      for (std::set<Z3_ast>::iterator itor1 = nn1EqUnrollFuncs.begin(); itor1 != nn1EqUnrollFuncs.end(); itor1++) {
+        processUnrollEqConstStr(t, *itor1, nn2EqConst);
+      }
+    }
+
+    if (nn1EqConst != NULL) {
+      for (std::set<Z3_ast>::iterator itor2 = nn2EqUnrollFuncs.begin(); itor2 != nn2EqUnrollFuncs.end(); itor2++) {
+        processUnrollEqConstStr(t, *itor2, nn1EqConst);
+      }
+    }
+    */
+    expr * nn1EqConst = NULL;
+    std::set<expr*> nn1EqUnrollFuncs;
+    get_eqc_all_unroll(lhs, nn1EqConst, nn1EqUnrollFuncs);
+    expr * nn2EqConst = NULL;
+    std::set<expr*> nn2EqUnrollFuncs;
+    get_eqc_all_unroll(rhs, nn2EqConst, nn2EqUnrollFuncs);
+
+    if (nn2EqConst != NULL) {
+        for (std::set<expr*>::iterator itor1 = nn1EqUnrollFuncs.begin(); itor1 != nn1EqUnrollFuncs.end(); itor1++) {
+            process_unroll_eq_const_str(*itor1, nn2EqConst);
+        }
+    }
+
+    if (nn1EqConst != NULL) {
+        for (std::set<expr*>::iterator itor2 = nn2EqUnrollFuncs.begin(); itor2 != nn2EqUnrollFuncs.end(); itor2++) {
+            process_unroll_eq_const_str(*itor2, nn1EqConst);
+        }
+    }
 }
 
 void theory_str::set_up_axioms(expr * ex) {
diff --git a/src/smt/theory_str.h b/src/smt/theory_str.h
index 9aead1105..5bf30a266 100644
--- a/src/smt/theory_str.h
+++ b/src/smt/theory_str.h
@@ -183,6 +183,7 @@ namespace smt {
         app * mk_internal_xor_var();
         expr * mk_internal_valTest_var(expr * node, int len, int vTries);
         app * mk_regex_rep_var();
+        app * mk_unroll_bound_var();
 
         bool is_concat(app const * a) const { return a->is_app_of(get_id(), OP_STRCAT); }
         bool is_concat(enode const * n) const { return is_concat(n->get_owner()); }
@@ -219,7 +220,8 @@ namespace smt {
         bool is_RegexUnion(enode const * n) const { return is_RegexUnion(n->get_owner()); }
         bool is_Str2Reg(app const * a) const { return a->is_app_of(get_id(), OP_RE_STR2REGEX); }
         bool is_Str2Reg(enode const * n) const { return is_Str2Reg(n->get_owner()); }
-
+        bool is_Unroll(app const * a) const { return a->is_app_of(get_id(), OP_RE_UNROLL); }
+        bool is_Unroll(enode const * n) const { return is_Unroll(n->get_owner()); }
 
         void instantiate_concat_axiom(enode * cat);
         void instantiate_basic_string_axioms(enode * str);
@@ -237,6 +239,11 @@ namespace smt {
 
         expr * mk_RegexIn(expr * str, expr * regexp);
         void instantiate_axiom_RegexIn(enode * e);
+        app * mk_unroll(expr * n, expr * bound);
+
+        void get_eqc_all_unroll(expr * n, expr * & constStr, std::set<expr*> & unrollFuncSet);
+        void process_unroll_eq_const_str(expr * unrollFunc, expr * constStr);
+        void unroll_str2reg_constStr(expr * unrollFunc, expr * eqConstStr);
 
         void set_up_axioms(expr * ex);
         void handle_equality(expr * lhs, expr * rhs);