From b529a58b91570618d11a95a284165a624820bf0d Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner
Date: Sun, 15 Dec 2024 05:53:28 -0800
Subject: [PATCH] add unit test for incremental equation edit distance with repair

---
Review notes (this section sits after the "---" separator and is not part of
the commit message):

 * Bug fix: in the new seq_plugin::edit_distance_with_updates the relaxation
   for the d[i][j - 1] neighbour tested b.can_add(i - 1).  The right-hand
   string b is indexed by j everywhere else in the loop (and in the copy of
   the routine in src/test/sls_seq_plugin.cpp), while i runs up to the length
   of a, so the original index can read past the end of b's vectors when a is
   longer than b.  It now tests b.can_add(j - 1).  The "index" line of
   sls_seq_plugin.cpp was not regenerated after this one-line change.
 * What the patch does, as far as the hunks show: the per-character
   is_value vectors are replaced by a string_instance record (concatenated
   value plus is_value / prev_is_var / next_is_var flags) filled in by
   init_string_instance, edit_distance_with_updates is moved below it and
   takes two string_instance arguments, and a stand-alone copy of the
   routine is added as a unit test.
 * NOTE(review): this text was re-flowed from a whitespace-collapsed copy.
   Indentation and blank context lines are reconstructed, and template
   arguments that had been stripped (vector<unsigned_vector>,
   ptr_vector<expr>, svector<string_update>, scoped_ptr_vector<eval>,
   vector<str_update>) are presumed from usage - confirm against the tree
   before applying.

 src/ast/sls/sls_seq_plugin.cpp | 209 ++++++++++---------
 src/ast/sls/sls_seq_plugin.h   |  13 +-
 src/test/CMakeLists.txt        |   1 +
 src/test/main.cpp              |   1 +
 src/test/sls_seq_plugin.cpp    | 354 +++++++++++++++++++++++++++++++++
 5 files changed, 469 insertions(+), 109 deletions(-)
 create mode 100644 src/test/sls_seq_plugin.cpp

diff --git a/src/ast/sls/sls_seq_plugin.cpp b/src/ast/sls/sls_seq_plugin.cpp
index 6a77dbb95..3c54e549c 100644
--- a/src/ast/sls/sls_seq_plugin.cpp
+++ b/src/ast/sls/sls_seq_plugin.cpp
@@ -662,67 +662,6 @@ namespace sls {
         return d[n][m];
     }
 
-    /**
-    * \brief edit distance with update calculation
-    */
-    unsigned seq_plugin::edit_distance_with_updates(zstring const& a, bool_vector const& a_is_value, zstring const& b, bool_vector const& b_is_value) {
-        unsigned n = a.length();
-        unsigned m = b.length();
-        vector<unsigned_vector> d(n + 1); // edit distance
-        vector<unsigned_vector> u(n + 1); // edit distance with updates.
-        m_string_updates.reset();
-        for (unsigned i = 0; i <= n; ++i) {
-            d[i].resize(m + 1, 0);
-            u[i].resize(m + 1, 0);
-        }
-        for (unsigned i = 0; i <= n; ++i)
-            d[i][0] = i, u[i][0] = i;
-        for (unsigned j = 0; j <= m; ++j)
-            d[0][j] = j, u[0][j] = j;
-        for (unsigned j = 1; j <= m; ++j) {
-            for (unsigned i = 1; i <= n; ++i) {
-                if (a[i - 1] == b[j - 1]) {
-                    d[i][j] = d[i - 1][j - 1];
-                    u[i][j] = u[i - 1][j - 1];
-                }
-                else {
-                    u[i][j] = 1 + std::min(u[i - 1][j], std::min(u[i][j - 1], u[i - 1][j - 1]));
-                    d[i][j] = 1 + std::min(d[i - 1][j], std::min(d[i][j - 1], d[i - 1][j - 1]));
-
-                    // TODO: take into account for a_is_value[i - 1] and b_is_value[j - 1]
-                    // and whether index i-1, j-1 is at the boundary of an empty string variable.
-
-                    if (d[i - 1][j] < u[i][j] && !a_is_value[i - 1]) {
-                        m_string_updates.reset();
-                        u[i][j] = d[i - 1][j];
-                    }
-                    if (d[i][j - 1] < u[i][j] && !b_is_value[i - 1]) {
-                        m_string_updates.reset();
-                        u[i][j] = d[i][j - 1];
-                    }
-                    if (d[i - 1][j - 1] < u[i][j] && (!a_is_value[i - 1] || !b_is_value[j - 1])) {
-                        m_string_updates.reset();
-                        u[i][j] = d[i - 1][j - 1];
-                    }
-                    if (d[i - 1][j] == u[i][j] && !a_is_value[i - 1]) {
-                        add_string_update(side_t::left, op_t::del, i - 1, 0);
-                        add_string_update(side_t::left, op_t::add, j - 1, i - 1);
-                    }
-                    if (d[i][j - 1] == u[i][j] && !b_is_value[j - 1]) {
-                        add_string_update(side_t::right, op_t::del, j - 1, 0);
-                        add_string_update(side_t::right, op_t::add, i - 1, j - 1);
-                    }
-                    if (d[i - 1][j - 1] == u[i][j] && !a_is_value[i - 1])
-                        add_string_update(side_t::left, op_t::copy, j - 1, i - 1);
-
-                    if (d[i - 1][j - 1] == u[i][j] && !b_is_value[j - 1])
-                        add_string_update(side_t::right, op_t::copy, i - 1, j - 1);
-
-                }
-            }
-        }
-        return u[n][m];
-    }
 
     void seq_plugin::add_edit_updates(ptr_vector<expr> const& w, zstring const& val, zstring const& val_other, uint_set const& chars) {
         for (auto x : w) {
@@ -793,67 +732,124 @@ namespace sls {
 #endif
     }
 
+    void seq_plugin::init_string_instance(ptr_vector<expr> const& es, string_instance& a) {
+        bool prev_is_var = false;
+        for (auto x : es) {
+            auto const& val = strval0(x);
+            auto len = val.length();
+            bool is_val = is_value(x);
+            a.s += val;
+            if (!prev_is_var && !is_val && !a.next_is_var.empty())
+                a.next_is_var.back() = true;
+            for (unsigned i = 0; i < len; ++i) {
+                a.is_value.push_back(is_val);
+                a.prev_is_var.push_back(false);
+                a.next_is_var.push_back(false);
+            }
+            if (len > 0 && is_val && prev_is_var && !a.is_value.empty())
+                a.prev_is_var[a.prev_is_var.size() - len] = true;
+            prev_is_var = !is_val;
+        }
+    }
+
+
+    /**
+    * \brief edit distance with update calculation
+    */
+    unsigned seq_plugin::edit_distance_with_updates(string_instance const& a, string_instance const& b) {
+        unsigned n = a.s.length();
+        unsigned m = b.s.length();
+        vector<unsigned_vector> d(n + 1); // edit distance
+        vector<unsigned_vector> u(n + 1); // edit distance with updates.
+        m_string_updates.reset();
+        for (unsigned i = 0; i <= n; ++i) {
+            d[i].resize(m + 1, 0);
+            u[i].resize(m + 1, 0);
+        }
+        for (unsigned i = 0; i <= n; ++i)
+            d[i][0] = i, u[i][0] = i;
+        for (unsigned j = 0; j <= m; ++j)
+            d[0][j] = j, u[0][j] = j;
+        for (unsigned j = 1; j <= m; ++j) {
+            for (unsigned i = 1; i <= n; ++i) {
+                if (a.s[i - 1] == b.s[j - 1]) {
+                    d[i][j] = d[i - 1][j - 1];
+                    u[i][j] = u[i - 1][j - 1];
+                }
+                else {
+                    u[i][j] = 1 + std::min(u[i - 1][j], std::min(u[i][j - 1], u[i - 1][j - 1]));
+                    d[i][j] = 1 + std::min(d[i - 1][j], std::min(d[i][j - 1], d[i - 1][j - 1]));
+
+                    if (d[i - 1][j] < u[i][j] && a.can_add(i - 1)) {
+                        m_string_updates.reset();
+                        u[i][j] = d[i - 1][j];
+                    }
+                    if (d[i][j - 1] < u[i][j] && b.can_add(j - 1)) { // b is indexed by j, not i
+                        m_string_updates.reset();
+                        u[i][j] = d[i][j - 1];
+                    }
+                    if (d[i - 1][j - 1] < u[i][j] && (a.can_add(i - 1) || b.can_add(j - 1))) {
+                        m_string_updates.reset();
+                        u[i][j] = d[i - 1][j - 1];
+                    }
+                    if (d[i - 1][j] == u[i][j] && a.can_add(i - 1))
+                        add_string_update(side_t::left, op_t::add, j - 1, i - 1);
+
+                    if (d[i][j - 1] == u[i][j] && b.can_add(j - 1))
+                        add_string_update(side_t::right, op_t::add, i - 1, j - 1);
+
+                    if (d[i - 1][j] == u[i][j] && !a.is_value[i - 1])
+                        add_string_update(side_t::left, op_t::del, i - 1, 0);
+
+                    if (d[i][j - 1] == u[i][j] && !b.is_value[j - 1])
+                        add_string_update(side_t::right, op_t::del, j - 1, 0);
+
+                    if (d[i - 1][j - 1] == u[i][j] && !a.is_value[i - 1])
+                        add_string_update(side_t::left, op_t::copy, j - 1, i - 1);
+
+                    if (d[i - 1][j - 1] == u[i][j] && !b.is_value[j - 1])
+                        add_string_update(side_t::right, op_t::copy, i - 1, j - 1);
+                }
+            }
+        }
+        return u[n][m];
+    }
+
 
     bool seq_plugin::repair_down_str_eq_edit_distance_incremental(app* eq) {
         auto const& L = lhs(eq);
         auto const& R = rhs(eq);
-        zstring a, b;
-        bool_vector a_is_value, b_is_value;
+        string_instance a, b;
+        init_string_instance(L, a);
+        init_string_instance(R, b);
 
-        for (auto x : L) {
-            auto const& val = strval0(x);
-            auto len = val.length();
-            auto is_val = is_value(x);
-            a += val;
-            for (unsigned i = 0; i < len; ++i)
-                a_is_value.push_back(is_val);
-        }
-
-        for (auto y : R) {
-            auto const& val = strval0(y);
-            auto len = val.length();
-            auto is_val = is_value(y);
-            b += val;
-            for (unsigned i = 0; i < len; ++i)
-                b_is_value.push_back(is_val);
-        }
-
-        if (a == b)
-            return update(eq->get_arg(0), a) && update(eq->get_arg(1), b);
+        if (a.s == b.s)
+            return update(eq->get_arg(0), a.s) && update(eq->get_arg(1), b.s);
 
-        unsigned diff = edit_distance_with_updates(a, a_is_value, b, b_is_value);
-        if (a.length() == 0) {
-            m_str_updates.push_back({ eq->get_arg(1), zstring(), 1 });
-            m_str_updates.push_back({ eq->get_arg(0), zstring(b[0]), 1});
-            m_str_updates.push_back({ eq->get_arg(0), zstring(b[b.length() - 1]), 1});
-        }
-        if (b.length() == 0) {
-            m_str_updates.push_back({ eq->get_arg(0), zstring(), 1 });
-            m_str_updates.push_back({ eq->get_arg(1), zstring(a[0]), 1 });
-            m_str_updates.push_back({ eq->get_arg(1), zstring(a[a.length() - 1]), 1 });
-        }
+        unsigned diff = edit_distance_with_updates(a, b);
 
-        verbose_stream() << "diff \"" << a << "\" \"" << b << "\" diff " << diff << " updates " << m_string_updates.size() << "\n";
+
+        verbose_stream() << "diff \"" << a.s << "\" \"" << b.s << "\" diff " << diff << " updates " << m_string_updates.size() << "\n";
 #if 1
         for (auto const& [side, op, i, j] : m_string_updates) {
             switch (op) {
             case op_t::del:
                 if (side == side_t::left)
-                    verbose_stream() << "del " << a[i] << " @ " << i << " left\n";
+                    verbose_stream() << "del " << a.s[i] << " @ " << i << " left\n";
                 else
-                    verbose_stream() << "del " << b[i] << " @ " << i << " right\n";
+                    verbose_stream() << "del " << b.s[i] << " @ " << i << " right\n";
                 break;
             case op_t::add:
                 if (side == side_t::left)
-                    verbose_stream() << "add " << b[i] << " @ " << j << " left\n";
+                    verbose_stream() << "add " << b.s[i] << " @ " << j << " left\n";
                 else
-                    verbose_stream() << "add " << a[i] << " @ " << j << " right\n";
+                    verbose_stream() << "add " << a.s[i] << " @ " << j << " right\n";
                 break;
             case op_t::copy:
                 if (side == side_t::left)
-                    verbose_stream() << "copy " << b[i] << " @ " << j << " left\n";
+                    verbose_stream() << "copy " << b.s[i] << " @ " << j << " left\n";
                 else
-                    verbose_stream() << "copy " << a[i] << " @ " << j << " right\n";
+                    verbose_stream() << "copy " << a.s[i] << " @ " << j << " right\n";
                 break;
             }
         }
@@ -905,13 +901,13 @@ namespace sls {
             else if (op == op_t::del && side == side_t::right)
                 delete_char(R, i);
             else if (op == op_t::add && side == side_t::left)
-                add_char(L, j, b[i]);
+                add_char(L, j, b.s[i]);
             else if (op == op_t::add && side == side_t::right)
-                add_char(R, j, a[i]);
+                add_char(R, j, a.s[i]);
             else if (op == op_t::copy && side == side_t::left)
-                copy_char(L, j, b[i]);
+                copy_char(L, j, b.s[i]);
             else if (op == op_t::copy && side == side_t::right)
-                copy_char(R, j, a[i]);
+                copy_char(R, j, a.s[i]);
         }
         verbose_stream() << "num updates " << m_str_updates.size() << "\n";
         bool r = apply_update();
@@ -939,9 +935,6 @@ namespace sls {
         if (a == b)
             return update(eq->get_arg(0), a) && update(eq->get_arg(1), b);
 
-
-
-
         unsigned diff = edit_distance(a, b);
 
         //verbose_stream() << "solve: " << diff << " " << a << " " << b << "\n";
diff --git a/src/ast/sls/sls_seq_plugin.h b/src/ast/sls/sls_seq_plugin.h
index ea9b3f0b9..18a2d3a55 100644
--- a/src/ast/sls/sls_seq_plugin.h
+++ b/src/ast/sls/sls_seq_plugin.h
@@ -103,9 +103,20 @@ namespace sls {
             op_t op;
             unsigned i, j;
         };
+        struct string_instance {
+            zstring s;
+            bool_vector is_value;
+            bool_vector prev_is_var;
+            bool_vector next_is_var;
+
+            bool can_add(unsigned i) const {
+                return !is_value[i] || prev_is_var[i];
+            }
+        };
         svector<string_update> m_string_updates;
         void add_string_update(side_t side, op_t op, unsigned i, unsigned j) { m_string_updates.push_back({ side, op, i, j }); }
-        unsigned edit_distance_with_updates(zstring const& a, bool_vector const& a_is_value, zstring const& b, bool_vector const& b_is_value);
+        void init_string_instance(ptr_vector<expr> const& es, string_instance& a);
+        unsigned edit_distance_with_updates(string_instance const& a, string_instance const& b);
         unsigned edit_distance(zstring const& a, zstring const& b);
         void add_edit_updates(ptr_vector<expr> const& w, zstring const& val, zstring const& val_other, uint_set const& chars);
 
diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt
index 658647ea6..47b433563 100644
--- a/src/test/CMakeLists.txt
+++ b/src/test/CMakeLists.txt
@@ -111,6 +111,7 @@ add_executable(test-z3
   simplex.cpp
   simplifier.cpp
   sls_test.cpp
+  sls_seq_plugin.cpp
   small_object_allocator.cpp
   smt2print_parse.cpp
   smt_context.cpp
diff --git a/src/test/main.cpp b/src/test/main.cpp
index c05af0aa2..9fc1decb3 100644
--- a/src/test/main.cpp
+++ b/src/test/main.cpp
@@ -270,4 +270,5 @@ int main(int argc, char ** argv) {
     TST(euf_arith_plugin);
     TST(sls_test);
     TST(scoped_vector);
+    TST(sls_seq_plugin);
 }
diff --git a/src/test/sls_seq_plugin.cpp b/src/test/sls_seq_plugin.cpp
new file mode 100644
index 000000000..b7a23a596
--- /dev/null
+++ b/src/test/sls_seq_plugin.cpp
@@ -0,0 +1,354 @@
+#include "ast/ast.h"
+#include "util/vector.h"
+#include "ast/ast_pp.h"
+#include "ast/ast_ll_pp.h"
+#include "util/scoped_ptr_vector.h"
+#include "util/uint_set.h"
+#include "ast/reg_decl_plugins.h"
+
+struct test_seq {
+
+    test_seq(ast_manager& m) :
+        m(m),
+        seq(m),
+        a(m)
+    {
+    }
+
+    struct value {
+        value(ast_manager& m) : evalue(m) {}
+        zstring svalue;
+        expr_ref evalue;
+    };
+
+    struct eval {
+        eval(ast_manager& m) :
+            val0(m), val1(m) {
+        }
+        value val0;
+        value val1;
+        bool is_value = false;
+        unsigned min_length = 0;
+        unsigned max_length = UINT_MAX;
+        ptr_vector<expr> lhs, rhs;
+    };
+
+    ast_manager& m;
+    seq_util seq;
+    arith_util a;
+    scoped_ptr_vector<eval> m_values;
+    indexed_uint_set m_chars;
+    bool m_initialized = false;
+
+    struct str_update {
+        expr* e;
+        zstring value;
+        double m_score;
+    };
+    struct int_update {
+        expr* e;
+        rational value;
+        double m_score;
+    };
+    vector<str_update> m_str_updates;
+
+
+    enum op_t {
+        add, del, copy
+    };
+    enum side_t {
+        left, right
+    };
+    struct string_update {
+        side_t side;
+        op_t op;
+        unsigned i, j;
+    };
+    struct string_instance {
+        zstring s;
+        bool_vector is_value;
+        bool_vector prev_is_var;
+        bool_vector next_is_var;
+
+        bool can_add(unsigned i) const {
+            return !is_value[i] || prev_is_var[i];
+        }
+    };
+    svector<string_update> m_string_updates;
+
+
+    bool is_value(expr* e) {
+        if (seq.is_seq(e))
+            return get_eval(e).is_value;
+        return m.is_value(e);
+    }
+
+    void init_string_instance(ptr_vector<expr> const& es, string_instance& a) {
+        bool prev_is_var = false;
+        for (auto x : es) {
+            auto const& val = strval0(x);
+            auto len = val.length();
+            bool is_val = is_value(x);
+            a.s += val;
+            if (!prev_is_var && !is_val && !a.next_is_var.empty())
+                a.next_is_var.back() = true;
+            for (unsigned i = 0; i < len; ++i) {
+                a.is_value.push_back(is_val);
+                a.prev_is_var.push_back(false);
+                a.next_is_var.push_back(false);
+            }
+            if (len > 0 && is_val && prev_is_var && !a.is_value.empty())
+                a.prev_is_var[a.prev_is_var.size() - len] = true;
+            prev_is_var = !is_val;
+        }
+    }
+
+
+    /**
+    * \brief edit distance with update calculation
+    */
+    unsigned edit_distance_with_updates(string_instance const& a, string_instance const& b) {
+        unsigned n = a.s.length();
+        unsigned m = b.s.length();
+        vector<unsigned_vector> d(n + 1); // edit distance
+        vector<unsigned_vector> u(n + 1); // edit distance with updates.
+        m_string_updates.reset();
+        for (unsigned i = 0; i <= n; ++i) {
+            d[i].resize(m + 1, 0);
+            u[i].resize(m + 1, 0);
+        }
+        for (unsigned i = 0; i <= n; ++i)
+            d[i][0] = i, u[i][0] = i;
+        for (unsigned j = 0; j <= m; ++j)
+            d[0][j] = j, u[0][j] = j;
+        for (unsigned j = 1; j <= m; ++j) {
+            for (unsigned i = 1; i <= n; ++i) {
+                if (a.s[i - 1] == b.s[j - 1]) {
+                    d[i][j] = d[i - 1][j - 1];
+                    u[i][j] = u[i - 1][j - 1];
+                }
+                else {
+                    u[i][j] = 1 + std::min(u[i - 1][j], std::min(u[i][j - 1], u[i - 1][j - 1]));
+                    d[i][j] = 1 + std::min(d[i - 1][j], std::min(d[i][j - 1], d[i - 1][j - 1]));
+
+                    if (d[i - 1][j] < u[i][j] && a.can_add(i - 1)) {
+                        m_string_updates.reset();
+                        u[i][j] = d[i - 1][j];
+                    }
+                    if (d[i][j - 1] < u[i][j] && b.can_add(j - 1)) {
+                        m_string_updates.reset();
+                        u[i][j] = d[i][j - 1];
+                    }
+
+                    if (d[i - 1][j - 1] < u[i][j] && (!a.is_value[i - 1] || !b.is_value[j - 1])) {
+                        m_string_updates.reset();
+                        u[i][j] = d[i - 1][j - 1];
+                    }
+
+                    if (d[i - 1][j] == u[i][j] && a.can_add(i - 1))
+                        add_string_update(side_t::left, op_t::add, j - 1, i - 1);
+
+                    if (d[i][j - 1] == u[i][j] && b.can_add(j - 1))
+                        add_string_update(side_t::right, op_t::add, i - 1, j - 1);
+
+                    if (d[i][j - 1] < u[i][j] && b.next_is_var[j - 1] && j == m)
+                        add_string_update(side_t::right, op_t::add, i - 1, j);
+
+                    if (d[i - 1][j] < u[i][j] && a.next_is_var[i - 1] && i == n)
+                        add_string_update(side_t::left, op_t::add, j - 1, i);
+
+                    if (d[i - 1][j] == u[i][j] && !a.is_value[i - 1])
+                        add_string_update(side_t::left, op_t::del, i - 1, 0);
+
+                    if (d[i][j - 1] == u[i][j] && !b.is_value[j - 1])
+                        add_string_update(side_t::right, op_t::del, j - 1, 0);
+
+                    if (d[i - 1][j - 1] == u[i][j] && !a.is_value[i - 1])
+                        add_string_update(side_t::left, op_t::copy, j - 1, i - 1);
+
+                    if (d[i - 1][j - 1] == u[i][j] && !b.is_value[j - 1])
+                        add_string_update(side_t::right, op_t::copy, i - 1, j - 1);
+                }
+            }
+        }
+        return u[n][m];
+    }
+
+    void add_string_update(side_t side, op_t op, unsigned i, unsigned j) { m_string_updates.push_back({ side, op, i, j }); }
+
+
+    eval& get_eval(expr* e) {
+        unsigned id = e->get_id();
+        m_values.reserve(id + 1);
+        if (!m_values[id]) {
+            m_values.set(id, alloc(eval, m));
+            zstring s;
+            bool is_string = seq.str.is_string(e, s);
+            m_values[id]->is_value = is_string;
+            if (is_string)
+                m_values[id]->val0.svalue = s;
+        }
+        return *m_values[id];
+    }
+
+    eval* get_eval(expr* e) const {
+        unsigned id = e->get_id();
+        return m_values.get(id, nullptr);
+    }
+
+    ptr_vector<expr> const& lhs(expr* eq) {
+        auto& ev = get_eval(eq);
+        if (ev.lhs.empty()) {
+            expr* x, * y;
+            VERIFY(m.is_eq(eq, x, y));
+            seq.str.get_concat(x, ev.lhs);
+            seq.str.get_concat(y, ev.rhs);
+        }
+        return ev.lhs;
+    }
+
+    ptr_vector<expr> const& concats(expr* x) {
+        auto& ev = get_eval(x);
+        if (ev.lhs.empty())
+            seq.str.get_concat(x, ev.lhs);
+        return ev.lhs;
+    }
+
+    ptr_vector<expr> const& rhs(expr* eq) {
+        lhs(eq);
+        auto& e = get_eval(eq);
+        return e.rhs;
+    }
+
+    zstring& strval0(expr* e) {
+        SASSERT(seq.is_string(e->get_sort()));
+        return get_eval(e).val0.svalue;
+    }
+
+    bool repair_down_str_eq_edit_distance_incremental(app* eq) {
+        auto const& L = lhs(eq);
+        auto const& R = rhs(eq);
+        string_instance a, b;
+        verbose_stream() << "eq\n";
+        for (auto x : L)
+            verbose_stream() << mk_pp(x, m) << "\n";
+        init_string_instance(L, a);
+        init_string_instance(R, b);
+
+        verbose_stream() << a.s << " == " << b.s << "\n";
+        if (a.s == b.s)
+            return true;
+
+        unsigned diff = edit_distance_with_updates(a, b);
+
+
+        verbose_stream() << "diff \"" << a.s << "\" \"" << b.s << "\" diff " << diff << " updates " << m_string_updates.size() << "\n";
+#if 1
+        for (auto const& [side, op, i, j] : m_string_updates) {
+            switch (op) {
+            case op_t::del:
+                if (side == side_t::left)
+                    verbose_stream() << "del " << a.s[i] << " @ " << i << " left\n";
+                else
+                    verbose_stream() << "del " << b.s[i] << " @ " << i << " right\n";
+                break;
+            case op_t::add:
+                if (side == side_t::left)
+                    verbose_stream() << "add " << b.s[i] << " @ " << j << " left\n";
+                else
+                    verbose_stream() << "add " << a.s[i] << " @ " << j << " right\n";
+                break;
+            case op_t::copy:
+                if (side == side_t::left)
+                    verbose_stream() << "copy " << b.s[i] << " @ " << j << " left\n";
+                else
+                    verbose_stream() << "copy " << a.s[i] << " @ " << j << " right\n";
+                break;
+            }
+        }
+#endif
+        auto delete_char = [&](auto const& es, unsigned i) {
+            for (auto x : es) {
+                auto const& value = strval0(x);
+                if (i >= value.length())
+                    i -= value.length();
+                else {
+                    if (!is_value(x))
+                        m_str_updates.push_back({ x, value.extract(0, i) + value.extract(i + 1, value.length()), 1 });
+                    break;
+                }
+            }
+        };
+
+        auto add_char = [&](auto const& es, unsigned j, uint32_t ch) {
+            for (auto x : es) {
+                auto const& value = strval0(x);
+                //verbose_stream() << "add " << j << " " << value << " " << value.length() << " " << is_value(x) << "\n";
+                if (j > value.length() || (j == value.length() && j > 0)) {
+                    j -= value.length();
+                    continue;
+                }
+                if (!is_value(x))
+                    m_str_updates.push_back({ x, value.extract(0, j) + zstring(ch) + value.extract(j, value.length()), 1 });
+                if (j < value.length())
+                    break;
+            }
+        };
+
+        auto copy_char = [&](auto const& es, unsigned j, uint32_t ch) {
+            for (auto x : es) {
+                auto const& value = strval0(x);
+                if (j >= value.length())
+                    j -= value.length();
+                else {
+                    if (!is_value(x))
+                        m_str_updates.push_back({ x, value.extract(0, j) + zstring(ch) + value.extract(j + 1, value.length()), 1 });
+                    break;
+                }
+            }
+        };
+
+        for (auto& [side, op, i, j] : m_string_updates) {
+            if (op == op_t::del && side == side_t::left)
+                delete_char(L, i);
+            else if (op == op_t::del && side == side_t::right)
+                delete_char(R, i);
+            else if (op == op_t::add && side == side_t::left)
+                add_char(L, j, b.s[i]);
+            else if (op == op_t::add && side == side_t::right)
+                add_char(R, j, a.s[i]);
+            else if (op == op_t::copy && side == side_t::left)
+                copy_char(L, j, b.s[i]);
+            else if (op == op_t::copy && side == side_t::right)
+                copy_char(R, j, a.s[i]);
+        }
+        for (auto const& [e, value, score] : m_str_updates) {
+            verbose_stream() << mk_pp(e, m) << " := " << value << "\n";
+        }
+        return true;
+    }
+
+};
+
+void tst_sls_seq_plugin() {
+    ast_manager m;
+    reg_decl_plugins(m);
+    test_seq ts(m);
+    seq_util seq(m);
+    expr_ref_vector ls(m), rs(m);
+    sort* S = seq.str.mk_string_sort();
+    expr_ref x(m.mk_const("x", S), m);
+    expr_ref y(m.mk_const("y", S), m);
+    expr_ref z(m.mk_const("z", S), m);
+    expr_ref a(seq.str.mk_string("a"), m);
+    expr_ref b(seq.str.mk_string("b"), m);
+    expr_ref c(seq.str.mk_string("c"), m);
+
+    ls.push_back(x).push_back(a).push_back(y);
+    rs.push_back(b).push_back(c).push_back(z);
+    expr_ref l(seq.str.mk_concat(ls, S), m);
+    expr_ref r(seq.str.mk_concat(rs, S), m);
+    app_ref eq(m.mk_eq(l, r), m);
+    verbose_stream() << eq << "\n";
+    ts.repair_down_str_eq_edit_distance_incremental(eq);
+}
\ No newline at end of file