From 87d2a3b4e55081955720e4081b320a1f61ce0650 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 4 May 2022 01:10:18 -0700 Subject: [PATCH] map/mapi/foldl/foldli Signed-off-by: Nikolaj Bjorner --- src/ast/array_decl_plugin.h | 2 + src/ast/rewriter/seq_rewriter.cpp | 121 ++++++++++++++++++++++++++++++ src/ast/rewriter/seq_rewriter.h | 4 + src/ast/seq_decl_plugin.cpp | 25 +++++- src/ast/seq_decl_plugin.h | 17 ++++- 5 files changed, 167 insertions(+), 2 deletions(-) diff --git a/src/ast/array_decl_plugin.h b/src/ast/array_decl_plugin.h index 7298a0e47..5a606a509 100644 --- a/src/ast/array_decl_plugin.h +++ b/src/ast/array_decl_plugin.h @@ -272,6 +272,8 @@ public: func_decl * mk_array_ext(sort* domain, unsigned i); sort * mk_array_sort(sort* dom, sort* range) { return mk_array_sort(1, &dom, range); } + sort * mk_array_sort(sort* a, sort* b, sort* range) { sort* dom[2] = { a, b }; return mk_array_sort(2, dom, range); } /* two-index convenience overload: packs (a, b) as the domain and forwards to the arity-based overload */ + sort * mk_array_sort(sort* a, sort* b, sort* c, sort* range) { sort* dom[3] = { a, b, c}; return mk_array_sort(3, dom, range); } /* three-index convenience overload: packs (a, b, c) as the domain and forwards to the arity-based overload */ sort * mk_array_sort(unsigned arity, sort* const* domain, sort* range); diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index 9fcc35b2d..877263e79 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -673,6 +673,22 @@ br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con SASSERT(num_args == 3); st = mk_seq_replace_all(args[0], args[1], args[2], result); break; + case OP_SEQ_MAP: + SASSERT(num_args == 2); + st = mk_seq_map(args[0], args[1], result); + break; + case OP_SEQ_MAPI: + SASSERT(num_args == 3); + st = mk_seq_mapi(args[0], args[1], args[2], result); + break; + case OP_SEQ_FOLDL: + SASSERT(num_args == 3); + st = mk_seq_foldl(args[0], args[1], args[2], result); + break; + case OP_SEQ_FOLDLI: + SASSERT(num_args == 4); + st = mk_seq_foldli(args[0], args[1], args[2], args[3], result); + break; case OP_SEQ_REPLACE_RE: 
SASSERT(num_args == 3); st = mk_seq_replace_re(args[0], args[1], args[2], result); @@ -850,6 +866,14 @@ br_status seq_rewriter::mk_seq_length(expr* a, expr_ref& result) { result = str().mk_length(x); return BR_REWRITE1; } + if (str().is_map(a, x, y)) { + result = str().mk_length(y); + return BR_REWRITE1; + } + if (str().is_mapi(a, x, y, z)) { + result = str().mk_length(z); + return BR_REWRITE1; + } #if 0 expr* s = nullptr, *offset = nullptr, *length = nullptr; if (str().is_extract(a, s, offset, length)) { @@ -1640,6 +1664,13 @@ br_status seq_rewriter::mk_seq_nth_i(expr* a, expr* b, expr_ref& result) { return BR_REWRITE1; } + expr* f, *s; + if (str().is_map(a, f, s)) { + expr* args[2] = { f, str().mk_nth_i(s, b) }; + result = array_util(m()).mk_select(2, args); + return BR_REWRITE1; + } + expr_ref_vector as(m()); str().get_concat_units(a, as); @@ -2008,6 +2039,96 @@ br_status seq_rewriter::mk_seq_replace_all(expr* a, expr* b, expr* c, expr_ref& return BR_FAILED; } +/** + rewrites for map(f, s): + + map(f, []) = [] + map(f, [x]) = [f(x)] + map(f, s + t) = map(f, s) + map(f, t) + len(map(f, s)) = len(s) + nth_i(map(f,s), i) = f(nth_i(s, i)) + + */ +br_status seq_rewriter::mk_seq_map(expr* f, expr* seqA, expr_ref& result) { + if (str().is_empty(seqA)) { + result = str().mk_empty(str().mk_seq(get_array_range(f->get_sort()))); /* seq.empty is typed by the sequence sort Seq[B]; the array range alone is only the element sort B */ + return BR_DONE; + } + expr* a, *s1, *s2; + if (str().is_unit(seqA, a)) { + array_util array(m()); + expr* args[2] = { f, a }; + result = str().mk_unit(array.mk_select(2, args)); + return BR_REWRITE2; + } + if (str().is_concat(seqA, s1, s2)) { + result = str().mk_concat(str().mk_map(f, s1), str().mk_map(f, s2)); + return BR_REWRITE2; + } + return BR_FAILED; +} + +br_status seq_rewriter::mk_seq_mapi(expr* f, expr* i, expr* seqA, expr_ref& result) { + if (str().is_empty(seqA)) { + result = str().mk_empty(str().mk_seq(get_array_range(f->get_sort()))); /* same as mk_seq_map: the empty result must have sort Seq[B], not B */ + return BR_DONE; + } + expr* a, *s1, *s2; + if (str().is_unit(seqA, a)) { + array_util array(m()); + expr* args[3] = { f, i, a 
}; + result = str().mk_unit(array.mk_select(3, args)); + return BR_REWRITE2; + } + if (str().is_concat(seqA, s1, s2)) { + expr_ref j(m_autil.mk_add(i, str().mk_length(s1)), m()); + result = str().mk_concat(str().mk_mapi(f, i, s1), str().mk_mapi(f, j, s2)); + return BR_REWRITE2; + } + return BR_FAILED; +} + +br_status seq_rewriter::mk_seq_foldl(expr* f, expr* b, expr* seqA, expr_ref& result) { + if (str().is_empty(seqA)) { + result = b; + return BR_DONE; + } + expr* a, *s1, *s2; + if (str().is_unit(seqA, a)) { + array_util array(m()); + expr* args[3] = { f, b, a }; + result = array.mk_select(3, args); + return BR_REWRITE1; + } + if (str().is_concat(seqA, s1, s2)) { + result = str().mk_foldl(f, b, s1); + result = str().mk_foldl(f, result, s2); + return BR_REWRITE3; + } + return BR_FAILED; +} + +br_status seq_rewriter::mk_seq_foldli(expr* f, expr* i, expr* b, expr* seqA, expr_ref& result) { + if (str().is_empty(seqA)) { + result = b; + return BR_DONE; + } + expr* a, *s1, *s2; + if (str().is_unit(seqA, a)) { + array_util array(m()); + expr* args[4] = { f, i, b, a }; + result = array.mk_select(4, args); + return BR_REWRITE1; + } + if (str().is_concat(seqA, s1, s2)) { + expr_ref j(m_autil.mk_add(i, str().mk_length(s1)), m()); + result = str().mk_foldli(f, i, b, s1); + result = str().mk_foldli(f, j, result, s2); + return BR_REWRITE3; + } + return BR_FAILED; +} + /* * Returns false if s is not a single unit value or concatenation of unit values. * Else extracts the units from s into vals and returns true. 
diff --git a/src/ast/rewriter/seq_rewriter.h b/src/ast/rewriter/seq_rewriter.h index f10532572..500972b1f 100644 --- a/src/ast/rewriter/seq_rewriter.h +++ b/src/ast/rewriter/seq_rewriter.h @@ -256,6 +256,10 @@ class seq_rewriter { br_status mk_seq_replace_re(expr* a, expr* b, expr* c, expr_ref& result); br_status mk_seq_prefix(expr* a, expr* b, expr_ref& result); br_status mk_seq_suffix(expr* a, expr* b, expr_ref& result); + br_status mk_seq_map(expr* f, expr* s, expr_ref& result); + br_status mk_seq_mapi(expr* f, expr* i, expr* s, expr_ref& result); + br_status mk_seq_foldl(expr* f, expr* b, expr* s, expr_ref& result); + br_status mk_seq_foldli(expr* f, expr* i, expr* b, expr* s, expr_ref& result); br_status mk_str_units(func_decl* f, expr_ref& result); br_status mk_str_itos(expr* a, expr_ref& result); br_status mk_str_stoi(expr* a, expr_ref& result); diff --git a/src/ast/seq_decl_plugin.cpp b/src/ast/seq_decl_plugin.cpp index f4cf2ecaa..77179a263 100644 --- a/src/ast/seq_decl_plugin.cpp +++ b/src/ast/seq_decl_plugin.cpp @@ -182,18 +182,26 @@ sort* seq_decl_plugin::apply_binding(ptr_vector<sort> const& binding, sort* s) { void seq_decl_plugin::init() { if (m_init) return; ast_manager& m = *m_manager; + array_util autil(m); m_init = true; sort* A = m.mk_uninterpreted_sort(symbol(0u)); + sort* B = m.mk_uninterpreted_sort(symbol(1u)); sort* strT = m_string; parameter paramA(A); + parameter paramB(B); parameter paramS(strT); sort* seqA = m.mk_sort(m_family_id, SEQ_SORT, 1, &paramA); + sort* seqB = m.mk_sort(m_family_id, SEQ_SORT, 1, &paramB); parameter paramSA(seqA); sort* reA = m.mk_sort(m_family_id, RE_SORT, 1, &paramSA); sort* reT = m.mk_sort(m_family_id, RE_SORT, 1, &paramS); sort* boolT = m.mk_bool_sort(); sort* intT = arith_util(m).mk_int(); - sort* predA = array_util(m).mk_array_sort(A, boolT); + sort* predA = autil.mk_array_sort(A, boolT); + sort* arrAB = autil.mk_array_sort(A, B); + sort* arrIAB = autil.mk_array_sort(intT, A, B); + sort* arrBAB = autil.mk_array_sort(B, A, B); + 
sort* arrIBAB = autil.mk_array_sort(intT, B, A, B); sort* seqAseqAseqA[3] = { seqA, seqA, seqA }; sort* seqAreAseqA[3] = { seqA, reA, seqA }; sort* seqAseqA[2] = { seqA, seqA }; @@ -209,6 +217,11 @@ void seq_decl_plugin::init() { sort* str2TintT[3] = { strT, strT, intT }; sort* seqAintT[2] = { seqA, intT }; sort* seq3A[3] = { seqA, seqA, seqA }; + sort* arrABseqA[2] = { arrAB, seqA }; + sort* arrIABintTseqA[3] = { arrIAB, intT, seqA }; + sort* arrBAB_BseqA[3] = { arrBAB, B,seqA }; + sort* arrIBABintTBseqA[4] = { arrIBAB, intT, B, seqA }; + m_sigs.resize(LAST_SEQ_OP); // TBD: have (par ..) construct and load parameterized signature from premable. m_sigs[OP_SEQ_UNIT] = alloc(psig, m, "seq.unit", 1, 1, &A, seqA); @@ -226,6 +239,10 @@ void seq_decl_plugin::init() { m_sigs[OP_SEQ_NTH_I] = alloc(psig, m, "seq.nth_i", 1, 2, seqAintT, A); m_sigs[OP_SEQ_NTH_U] = alloc(psig, m, "seq.nth_u", 1, 2, seqAintT, A); m_sigs[OP_SEQ_LENGTH] = alloc(psig, m, "seq.len", 1, 1, &seqA, intT); + m_sigs[OP_SEQ_MAP] = alloc(psig, m, "seq.map", 2, 2, arrABseqA, seqB); + m_sigs[OP_SEQ_MAPI] = alloc(psig, m, "seq.mapi", 2, 3, arrIABintTseqA, seqB); + m_sigs[OP_SEQ_FOLDL] = alloc(psig, m, "seq.fold_left", 2, 3, arrBAB_BseqA, B); + m_sigs[OP_SEQ_FOLDLI] = alloc(psig, m, "seq.fold_lefti", 2, 4, arrIBABintTBseqA, B); /* indexed fold: name is seq.fold_left plus the 'i' suffix, mirroring seq.map / seq.mapi */ m_sigs[OP_RE_PLUS] = alloc(psig, m, "re.+", 1, 1, &reA, reA); m_sigs[OP_RE_STAR] = alloc(psig, m, "re.*", 1, 1, &reA, reA); m_sigs[OP_RE_OPTION] = alloc(psig, m, "re.opt", 1, 1, &reA, reA); @@ -582,6 +599,12 @@ func_decl* seq_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, p case _OP_STRING_STRCTN: return mk_str_fun(k, arity, domain, range, OP_SEQ_CONTAINS); + case OP_SEQ_MAP: + case OP_SEQ_MAPI: + case OP_SEQ_FOLDL: + case OP_SEQ_FOLDLI: + return mk_str_fun(k, arity, domain, range, k); + case OP_SEQ_TO_RE: m_has_re = true; return mk_seq_fun(k, arity, domain, range, _OP_STRING_TO_REGEXP); diff --git a/src/ast/seq_decl_plugin.h b/src/ast/seq_decl_plugin.h index 
ddde7fa6a..03cfef033 100644 --- a/src/ast/seq_decl_plugin.h +++ b/src/ast/seq_decl_plugin.h @@ -55,7 +55,11 @@ enum seq_op_kind { OP_SEQ_REPLACE_RE_ALL, // Seq -> RegEx -> Seq -> Seq OP_SEQ_REPLACE_RE, // Seq -> RegEx -> Seq -> Seq OP_SEQ_REPLACE_ALL, // Seq -> Seq -> Seq -> Seq - + OP_SEQ_MAP, // Array[A,B] -> Seq[A] -> Seq[B] + OP_SEQ_MAPI, // Array[Int,A,B] -> Int -> Seq[A] -> Seq[B] + OP_SEQ_FOLDL, // Array[B,A,B] -> B -> Seq[A] -> B + OP_SEQ_FOLDLI, // Array[Int,B,A,B] -> Int -> B -> Seq[A] -> B + OP_RE_PLUS, OP_RE_STAR, OP_RE_OPTION, @@ -296,6 +300,10 @@ public: app* mk_nth_i(expr* s, expr* i) const { expr* es[2] = { s, i }; return m.mk_app(m_fid, OP_SEQ_NTH_I, 2, es); } app* mk_nth_u(expr* s, expr* i) const { expr* es[2] = { s, i }; return m.mk_app(m_fid, OP_SEQ_NTH_U, 2, es); } app* mk_nth_c(expr* s, unsigned i) const; + app* mk_map(expr* f, expr* s) const { expr* es[2] = { f, s }; return m.mk_app(m_fid, OP_SEQ_MAP, 2, es); } /* builds the OP_SEQ_MAP application over (f, s) */ + app* mk_mapi(expr* f, expr* i, expr* s) const { expr* es[3] = { f, i, s }; return m.mk_app(m_fid, OP_SEQ_MAPI, 3, es); } /* builds the OP_SEQ_MAPI application over (f, i, s) */ + app* mk_foldl(expr* f, expr* b, expr* s) const { expr* es[3] = { f, b, s }; return m.mk_app(m_fid, OP_SEQ_FOLDL, 3, es); } /* builds the OP_SEQ_FOLDL application over (f, b, s) */ + app* mk_foldli(expr* f, expr* i, expr* b, expr* s) const { expr* es[4] = { f, i, b, s }; return m.mk_app(m_fid, OP_SEQ_FOLDLI, 4, es); } /* builds the OP_SEQ_FOLDLI application over (f, i, b, s) */ app* mk_substr(expr* a, expr* b, expr* c) const { expr* es[3] = { a, b, c }; return m.mk_app(m_fid, OP_SEQ_EXTRACT, 3, es); } app* mk_contains(expr* a, expr* b) const { expr* es[2] = { a, b }; return m.mk_app(m_fid, OP_SEQ_CONTAINS, 2, es); } @@ -333,6 +341,10 @@ public: } bool is_concat(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_CONCAT); } bool is_length(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_LENGTH); } + bool is_map(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_MAP); } /* recognizer: defers to is_app_of with this family id and OP_SEQ_MAP */ + bool is_mapi(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_MAPI); } /* recognizer for OP_SEQ_MAPI */ + bool is_foldl(expr const* n) const { return is_app_of(n, 
m_fid, OP_SEQ_FOLDL); } /* recognizer for OP_SEQ_FOLDL */ + bool is_foldli(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_FOLDLI); } /* recognizer for OP_SEQ_FOLDLI */ bool is_extract(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_EXTRACT); } bool is_contains(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_CONTAINS); } bool is_at(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_AT); } @@ -384,6 +396,9 @@ public: MATCH_BINARY(is_nth_u); MATCH_BINARY(is_index); MATCH_TERNARY(is_index); + MATCH_BINARY(is_map); /* argument-extracting overload: is_map(n, f, s) */ + MATCH_TERNARY(is_mapi); /* argument-extracting overload: is_mapi(n, f, i, s) */ + MATCH_TERNARY(is_foldl); /* NOTE(review): no argument-extracting matcher is declared for the 4-argument is_foldli in this hunk - confirm whether one is needed */ MATCH_BINARY(is_last_index); MATCH_TERNARY(is_replace); MATCH_TERNARY(is_replace_re);