sls seq plugin: score string-repair candidates by edit distance

Summary of the changes carried by this patch:

- seq_plugin::concats(x): returns the cached ev.lhs of x, filling it through
  seq.str.get_concat(x, ev.lhs) when it is still empty.
- seq_plugin::edit_distance(a, b): table-based edit distance with unit cost
  for insertion, deletion and substitution.
- seq_plugin::add_edit_updates(w, chars): for every term in w that is not a
  value, queues single-character edits of its current string in
  m_str_updates (append, prepend, drop first/last character, replace
  first/last character), using the characters in chars.
- seq_plugin::repair_down_str_eq_edit_distance(eq): queues edits for both
  sides of eq and scores each queued update as (diff - local_diff)^2, or 0.1
  when local_diff >= diff. It is not called yet; the call site in
  repair_down_str_eq is commented out.
- repair_down_str_eq: the unify path is now taken when ctx.rand(2) != 0
  (previously ctx.rand(10) != 0).
- repair_down_str_concat: new edit-distance based body selected by `#if 1`;
  the previous body is kept under `#else`.
- zstring: adds reset(), operator+=, begin() and end().

NOTE(review): indentation and the position of blank context lines were
reconstructed from the hunk line counts; confirm against the target tree
(or apply with whitespace tolerance) before relying on this patch.

diff --git a/src/ast/sls/sls_seq_plugin.cpp b/src/ast/sls/sls_seq_plugin.cpp
index 1f7d927e2..3fcb0adff 100644
--- a/src/ast/sls/sls_seq_plugin.cpp
+++ b/src/ast/sls/sls_seq_plugin.cpp
@@ -252,6 +252,13 @@ namespace sls {
         return ev.lhs;
     }
 
+    ptr_vector<expr> const& seq_plugin::concats(expr* x) {
+        auto& ev = get_eval(x);
+        if (ev.lhs.empty())
+            seq.str.get_concat(x, ev.lhs);
+        return ev.lhs;
+    }
+
     ptr_vector<expr> const& seq_plugin::rhs(expr* eq) {
         lhs(eq);
         auto& e = get_eval(eq);
@@ -593,7 +600,8 @@ namespace sls {
         VERIFY(m.is_eq(e, x, y));
         IF_VERBOSE(3, verbose_stream() << is_true << ": " << mk_bounded_pp(e, m, 3) << "\n");
         if (ctx.is_true(e)) {
-            if (ctx.rand(10) != 0)
+            //return repair_down_str_eq_edit_distance(e);
+            if (ctx.rand(2) != 0)
                 return repair_down_str_eq_unify(e);
             if (!is_value(x))
                 m_str_updates.push_back({ x, strval1(y), 1 });
@@ -619,6 +627,100 @@ namespace sls {
         return apply_update();
     }
 
+    /**
+    * \brief compute the edit distance between two strings.
+    */
+    unsigned seq_plugin::edit_distance(zstring const& a, zstring const& b) {
+        unsigned n = a.length();
+        unsigned m = b.length();
+        vector<unsigned_vector> d(n + 1);
+        for (unsigned i = 0; i <= n; ++i)
+            d[i].resize(m + 1, 0);
+        for (unsigned i = 0; i <= n; ++i)
+            d[i][0] = i;
+        for (unsigned j = 0; j <= m; ++j)
+            d[0][j] = j;
+        for (unsigned j = 1; j <= m; ++j) {
+            for (unsigned i = 1; i <= n; ++i) {
+                if (a[i - 1] == b[j - 1])
+                    d[i][j] = d[i - 1][j - 1];
+                else
+                    d[i][j] = std::min(std::min(d[i - 1][j] + 1, d[i][j - 1] + 1), d[i - 1][j - 1] + 1);
+            }
+        }
+        return d[n][m];
+    }
+
+    void seq_plugin::add_edit_updates(ptr_vector<expr> const& w, uint_set const& chars) {
+        for (auto x : w) {
+            if (is_value(x))
+                continue;
+            zstring const & a = strval0(x);
+            for (auto ch : chars)
+                m_str_updates.push_back({ x, a + zstring(ch), 1 });
+            for (auto ch : chars)
+                m_str_updates.push_back({ x, zstring(ch) + a, 1 });
+            if (a.length() > 0) {
+                zstring b = a.extract(0, a.length() - 1);
+                m_str_updates.push_back({ x, b, 1 }); // drop last character of a
+                for (auto ch : chars)
+                    m_str_updates.push_back({ x, b + zstring(ch), 1 }); // replace last character in a by ch
+                b = a.extract(1, a.length() - 1);
+                m_str_updates.push_back({ x, b, 1 }); // drop first character of a
+                for (auto ch : chars)
+                    m_str_updates.push_back({ x, zstring(ch) + b, 1 }); // replace first character in a by ch
+            }
+        }
+    }
+
+    bool seq_plugin::repair_down_str_eq_edit_distance(app* eq) {
+        auto const& L = lhs(eq);
+        auto const& R = rhs(eq);
+        zstring a, b;
+        uint_set a_chars, b_chars;
+
+        for (auto x : L) {
+            for (auto ch : strval0(x))
+                a_chars.insert(ch);
+            a += strval0(x);
+        }
+        for (auto y : R) {
+            for (auto ch : strval0(y))
+                b_chars.insert(ch);
+            b += strval0(y);
+        }
+        if (a == b)
+            return update(eq->get_arg(0), a) && update(eq->get_arg(1), b);
+
+        unsigned diff = a.length() + b.length() + L.size() + R.size();
+
+        add_edit_updates(L, b_chars);
+        add_edit_updates(R, a_chars);
+
+        for (auto& [x, s, score] : m_str_updates) {
+            a.reset();
+            b.reset();
+            for (auto z : L) {
+                if (z == x)
+                    a += s;
+                else
+                    a += strval0(z);
+            }
+            for (auto z : R) {
+                if (z == x)
+                    b += s;
+                else
+                    b += strval0(z);
+            }
+            unsigned local_diff = edit_distance(a, b);
+            if (local_diff >= diff)
+                score = 0.1;
+            else
+                score = (diff - local_diff) * (diff - local_diff);
+        }
+        return apply_update();
+    }
+
     bool seq_plugin::repair_down_str_eq_unify(app* eq) {
         auto const& L = lhs(eq);
         auto const& R = rhs(eq);
@@ -1081,6 +1183,42 @@ namespace sls {
         return apply_update();
     }
 
+#if 1
+    bool seq_plugin::repair_down_str_concat(app* e) {
+        auto const& es = concats(e);
+        zstring value;
+        zstring value0 = strval0(e);
+        for (auto const& arg : es)
+            value += strval0(arg);
+        if (value == value0)
+            return true;
+        uint_set chars;
+
+        for (auto ch : value0)
+            chars.insert(ch);
+
+        add_edit_updates(es, chars);
+
+        unsigned diff = edit_distance(value, value0);
+        for (auto& [x, s, score] : m_str_updates) {
+            value.reset();
+            for (auto z : es) {
+                if (z == x)
+                    value += s;
+                else
+                    value += strval0(z);
+            }
+            unsigned local_diff = edit_distance(value, value0);
+            if (local_diff >= diff)
+                score = 0.1;
+            else
+                score = (diff - local_diff) * (diff - local_diff);
+        }
+        return apply_update();
+
+    }
+#else
+
     bool seq_plugin::repair_down_str_concat(app* e) {
         zstring val_e = strval0(e);
         unsigned len_e = val_e.length();
@@ -1125,6 +1263,7 @@ namespace sls {
         }
         return true;
     }
+#endif
 
 
 
diff --git a/src/ast/sls/sls_seq_plugin.h b/src/ast/sls/sls_seq_plugin.h
index 22363e21f..ca93b23b4 100644
--- a/src/ast/sls/sls_seq_plugin.h
+++ b/src/ast/sls/sls_seq_plugin.h
@@ -71,6 +71,7 @@ namespace sls {
         bool repair_down_seq(app* e);
         bool repair_down_eq(app* e);
         bool repair_down_str_eq_unify(app* e);
+        bool repair_down_str_eq_edit_distance(app* e);
         bool repair_down_str_eq(app* e);
         bool repair_down_str_extract(app* e);
         bool repair_down_str_contains(expr* e);
@@ -90,6 +91,9 @@ namespace sls {
         void repair_up_str_itos(app* e);
         void repair_up_str_stoi(app* e);
 
+        unsigned edit_distance(zstring const& a, zstring const& b);
+        void add_edit_updates(ptr_vector<expr> const& w, uint_set const& chars);
+
         // regex functionality
 
         // enumerate set of strings that can match a prefix of regex r.
@@ -111,6 +115,7 @@ namespace sls {
         eval* get_eval(expr* e) const;
         ptr_vector<expr> const& lhs(expr* eq);
         ptr_vector<expr> const& rhs(expr* eq);
+        ptr_vector<expr> const& concats(expr* x);
         bool is_value(expr* e);
 
     public:
diff --git a/src/util/zstring.h b/src/util/zstring.h
index e661b5389..44dfcd36c 100644
--- a/src/util/zstring.h
+++ b/src/util/zstring.h
@@ -85,6 +85,11 @@ public:
     bool operator!=(const zstring& other) const;
     unsigned hash() const;
 
+    void reset() { m_buffer.reset(); }
+    zstring& operator+=(zstring const& other) { m_buffer.append(other.m_buffer); return *this; }
+    uint32_t const* begin() const { return m_buffer.begin(); }
+    uint32_t const* end() const { return m_buffer.end(); }
+
     friend std::ostream& operator<<(std::ostream &os, const zstring &str);
     friend bool operator<(const zstring& lhs, const zstring& rhs);
 };