From 1553bae20c5b18cf908c5f4b41c811683472619f Mon Sep 17 00:00:00 2001 From: Clemens Eisenhofer <56730610+CEisenhofer@users.noreply.github.com> Date: Tue, 21 Jan 2025 17:01:59 +0100 Subject: [PATCH] Performance improvements for seq-sls (#7519) * Improve length repair * Fixed arguments * Special case regex membership with constant string * Trying hybrid eq-repair strategy * Different heuristic * Fixed stoi --- src/ast/rewriter/seq_rewriter.cpp | 21 ++++++ src/ast/seq_decl_plugin.h | 1 + src/ast/sls/sls_context.cpp | 2 +- src/ast/sls/sls_seq_plugin.cpp | 119 ++++++++++++++++++++++-------- src/ast/sls/sls_seq_plugin.h | 9 ++- src/params/sls_params.pyg | 2 +- 6 files changed, 118 insertions(+), 36 deletions(-) diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index b117e02e1..4c4cc900f 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -4545,6 +4545,27 @@ br_status seq_rewriter::mk_str_in_regexp(expr* a, expr* b, expr_ref& result) { result = m().mk_true(); return BR_DONE; } + + zstring s; + if (str().is_string(a, s) && re().is_ground(b)) { + // Just check membership and replace by true/false + expr_ref r(b, m()); + for (unsigned i = 0; i < s.length(); i++) { + if (re().is_empty(r)) { + result = m().mk_false(); + return BR_DONE; + } + unsigned ch = s[i]; + expr_ref new_r = mk_derivative(m_util.mk_char(ch), r); + r = new_r; + } + if (re().get_info(r).nullable) + result = m().mk_true(); + else + result = m().mk_false(); + return BR_DONE; + } + expr_ref b_s(m()); if (lift_str_from_to_re(b, b_s)) { result = m_br.mk_eq_rw(a, b_s); diff --git a/src/ast/seq_decl_plugin.h b/src/ast/seq_decl_plugin.h index 03cab284a..78ad6d544 100644 --- a/src/ast/seq_decl_plugin.h +++ b/src/ast/seq_decl_plugin.h @@ -636,6 +636,7 @@ public: } family_id get_family_id() const { return m_fid; } + family_id get_char_family_id() const { return ch.get_family_id(); } }; inline std::ostream& operator<<(std::ostream& out, seq_util::rex::pp const & p) { return p.display(out); } diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 5333e43f7..33ff74f22 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -72,7 +72,7 @@ namespace sls { register_plugin(alloc(array_plugin, *this)); else if (fid == datatype_util(m).get_family_id()) register_plugin(alloc(datatype_plugin, *this)); - else if (fid == seq_util(m).get_family_id()) + else if (fid == seq_util(m).get_family_id() || fid == seq_util(m).get_char_family_id()) register_plugin(alloc(seq_plugin, *this)); else { verbose_stream() << "did not find plugin for " << fid << "\n"; diff --git a/src/ast/sls/sls_seq_plugin.cpp b/src/ast/sls/sls_seq_plugin.cpp index eaccfb9bd..2c7e7f8d9 100644 --- a/src/ast/sls/sls_seq_plugin.cpp +++ b/src/ast/sls/sls_seq_plugin.cpp @@ -116,7 +116,7 @@ namespace sls { { m_fid = seq.get_family_id(); sls_params p(c.get_params()); - m_str_update_strategy = p.str_update_strategy() == 0 ? EDIT_CHAR : EDIT_SUBSTR; + m_str_update_strategy = (edit_distance_strategy)p.str_update_strategy(); } void seq_plugin::propagate_literal(sat::literal lit) { @@ -148,6 +148,7 @@ namespace sls { for (expr* e : ctx.subterms()) { expr* x, * y, * z = nullptr; rational r; + // std::cout << "Checking "<< mk_pp(e, m) << std::endl; // coherence between string / integer functions is delayed // so we check and enforce it here. if (seq.str.is_length(e, x) && seq.is_string(x->get_sort())) { @@ -158,10 +159,10 @@ namespace sls { // set e to length of x or // set x to a string of length e - if (r == 0 || sx.length() == 0) { - verbose_stream() << "todo-create lemma: len(x) = 0 <=> x = \"\"\n"; - // create a lemma: len(x) = 0 => x = "" - } + if (r == 0 || sx.length() == 0) + // create lemma: len(x) = 0 <=> x = "" + ctx.add_constraint(m.mk_eq(m.mk_eq(e, a.mk_int(0)), m.mk_eq(x, seq.str.mk_string("")))); + if (ctx.rand(2) == 0 && update(e, rational(sx.length()))) return false; // TODO: Why from the beginning? We can take any subsequence of given length @@ -196,8 +197,27 @@ namespace sls { update(e, rational(sx.indexofu(sy, val_z.get_unsigned()))); return false; } - // last-index-of - // str-to-int + if (seq.str.is_last_index(e, x, y) && seq.is_string(x->get_sort())) { + // TODO + SASSERT(false); + } + if (seq.str.is_stoi(e, x) && seq.is_string(x->get_sort())) { + auto sx = strval0(x); + rational val_e; + VERIFY(a.is_numeral(ctx.get_value(e), val_e)); + // std::cout << "stoi: \"" << sx << "\" -> " << val_e << std::endl; + if (!is_num_string(sx)) { + if (val_e == -1) + continue; + update(e, rational(-1)); + return false; + } + rational val_x(sx.encode().c_str()); + if (val_e == val_x) + continue; + update(e, val_x); + return false; + } } return true; } @@ -538,24 +558,25 @@ namespace sls { rational r; unsigned len_u; VERIFY(a.is_numeral(len, r)); + // std::cout << "repair-str-len: " << mk_pp(e, m) << ": " << r << "" << std::endl; if (!r.is_unsigned()) return false; zstring val_x = strval0(x); + // std::cout << "Arg: \"" << val_x << "\"" << std::endl; len_u = r.get_unsigned(); if (len_u == val_x.length()) return true; - if (len_u < val_x.length()) { - for (unsigned i = 0; i + len_u < val_x.length(); ++i) + if (len_u < val_x.length()) { + for (unsigned i = 0; i + len_u < val_x.length(); ++i) { m_str_updates.push_back({ x, val_x.extract(i, len_u), 1 }); + } + return apply_update(); } - zstring ch = !m_chars.empty() ? m_chars[ctx.rand(m_chars.size())] : zstring("a"); - zstring val_x_new = val_x + ch; - m_str_updates.push_back({ x, val_x_new, 1 }); - zstring val_x_new2 = ch + val_x; - if (val_x_new != val_x_new2) - m_str_updates.push_back({ x, val_x_new2, 1 }); - - return apply_update(); + zstring val_x_new = val_x; + for (unsigned i = val_x.length(); i < len_u; ++i) { + val_x_new += !m_chars.empty() ? m_chars[ctx.rand(m_chars.size())] : 'a'; + } + return update(x, val_x_new); } void seq_plugin::repair_up_str_stoi(app* e) { @@ -563,14 +584,18 @@ namespace sls { VERIFY(seq.str.is_stoi(e, x)); rational val_e; - rational val_x(strval0(x).encode().c_str()); VERIFY(a.is_numeral(ctx.get_value(e), val_e)); - if (val_e.is_unsigned() && val_e == val_x) + // std::cout << "repair-up-str-stoi " << mk_pp(e, m) << ": " << val_e << "; Arg: \""<< strval0(x) << "\"" << std::endl; + if (!is_num_string(strval0(x))) { + if (val_e == -1) + return; + update(e, rational(-1)); return; - if (val_x < 0) - update(e, rational(0)); - else - update(e, val_x); + } + rational val_x(strval0(x).encode().c_str()); + if (val_e == val_x) + return; + update(e, val_x); } void seq_plugin::repair_up_str_itos(app* e) { @@ -682,11 +707,17 @@ namespace sls { return d[n][m]; } - void seq_plugin::add_edit_updates(ptr_vector const& w, zstring const& val, zstring const& val_other, uint_set const& chars) { + void seq_plugin::add_edit_updates(ptr_vector const& w, zstring const& val, zstring const& val_other, uint_set const& chars, unsigned diff) { if (m_str_update_strategy == EDIT_CHAR) add_char_edit_updates(w, val, val_other, chars); - else + else if (m_str_update_strategy == EDIT_SUBSTR) add_substr_edit_updates(w, val, val_other, chars); + else { + if (val.length() / 3 >= diff - 1) + add_char_edit_updates(w, val, val_other, chars); + else + add_substr_edit_updates(w, val, val_other, chars); + } } void seq_plugin::add_substr_edit_updates(ptr_vector const& w, zstring const& val, zstring const& val_other, uint_set const& chars) { @@ -1013,6 +1044,9 @@ namespace sls { b_chars.insert(ch); b += strval0(y); } + + // std::cout << "Repair down " << mk_pp(eq, m) << ": \"" << a << "\" = \"" << b << "\"" << std::endl; + if (a == b) return update(eq->get_arg(0), a) && update(eq->get_arg(1), b); @@ -1020,8 +1054,8 @@ namespace sls { //verbose_stream() << "solve: " << diff << " " << a << " " << b << "\n"; - add_edit_updates(L, a, b, b_chars); - add_edit_updates(R, b, a, a_chars); + add_edit_updates(L, a, b, b_chars, diff); + add_edit_updates(R, b, a, a_chars, diff); for (auto& [x, s, score] : m_str_updates) { a.reset(); @@ -1275,7 +1309,20 @@ namespace sls { rational r; VERIFY(seq.str.is_stoi(e, x)); VERIFY(a.is_numeral(ctx.get_value(e), r) && r.is_int()); - if (r < 0) + // std::cout << "repair-down " << mk_pp(e, m) << ": \"" << strval0(x) << "\" -> " << r << std::endl; + // It might be satisfied already (not checked before, as the value is of integer sort) + if (!is_num_string(strval0(x))) { + if (r == -1) + return true; + } + else { + if (r == rational(strval0(x).encode().c_str())) + return true; + } + if (r == -1) + // TODO: Add some random character somewhere or make it empty + return false; + if (r < -1) return false; zstring r_val(r.to_string()); m_str_updates.push_back({ x, r_val, 1 }); @@ -1286,9 +1333,11 @@ namespace sls { expr* x, * y; VERIFY(seq.str.is_at(e, x, y)); zstring se = strval0(e); + // std::cout << "repair-str-at: " << mk_pp(e, m) << ": \"" << se << "\"" << std::endl; if (se.length() > 1) return false; zstring sx = strval0(x); + // std::cout << "Arg: " << sx << std::endl; unsigned lenx = sx.length(); expr_ref idx = ctx.get_value(y); rational r; @@ -1578,9 +1627,9 @@ namespace sls { for (auto ch : value0) chars.insert(ch); - add_edit_updates(es, value, value0, chars); - unsigned diff = edit_distance(value, value0); + add_edit_updates(es, value, value0, chars, diff); + for (auto& [x, s, score] : m_str_updates) { value.reset(); for (auto z : es) { @@ -1817,6 +1866,14 @@ namespace sls { return m.is_value(e); } + bool seq_plugin::is_num_string(const zstring& s) { + bool is_valid = s.length() > 0; + for (unsigned i = 0; is_valid && i < s.length(); ++i) { + is_valid = s[i] >= '0' && s[i] <= '9'; + } + return is_valid; + } + // Regular expressions bool seq_plugin::is_in_re(zstring const& s, expr* _r) { @@ -1863,7 +1920,7 @@ namespace sls { zstring prefix = s.extract(0, i); choose(d_r, 2, prefix, lookaheads); expr_ref ch(seq.str.mk_char(s[i]), m); - d_r = rw.mk_derivative(ch, d_r); + d_r = rw.mk_derivative(ch, d_r); } unsigned current_min_length = UINT_MAX; if (!seq.re.is_empty(d_r)) { diff --git a/src/ast/sls/sls_seq_plugin.h b/src/ast/sls/sls_seq_plugin.h index 7c7fc16ff..ad2d5c58c 100644 --- a/src/ast/sls/sls_seq_plugin.h +++ b/src/ast/sls/sls_seq_plugin.h @@ -43,8 +43,9 @@ namespace sls { }; enum edit_distance_strategy { - EDIT_CHAR, - EDIT_SUBSTR, + EDIT_CHAR = 0, + EDIT_SUBSTR = 1, + EDIT_COMBINED = 2, }; seq_util seq; @@ -127,7 +128,7 @@ namespace sls { void init_string_instance(ptr_vector const& es, string_instance& a); unsigned edit_distance_with_updates(string_instance const& a, string_instance const& b); unsigned edit_distance(zstring const& a, zstring const& b); - void add_edit_updates(ptr_vector const& w, zstring const& val, zstring const& val_other, uint_set const& chars); + void add_edit_updates(ptr_vector const& w, zstring const& val, zstring const& val_other, uint_set const& chars, unsigned diff); void add_char_edit_updates(ptr_vector const& w, zstring const& val, zstring const& val_other, uint_set const& chars); void add_substr_edit_updates(ptr_vector const& w, zstring const& val, zstring const& val_other, uint_set const& chars); @@ -148,6 +149,8 @@ namespace sls { bool is_in_re(zstring const& s, expr* r); + bool is_num_string(zstring const& s); // Checks if s \in [0-9]+ (i.e., str.to_int is not -1) + // access evaluation bool is_seq_predicate(expr* e); diff --git a/src/params/sls_params.pyg b/src/params/sls_params.pyg index 5df6c1c63..66f3c82a3 100644 --- a/src/params/sls_params.pyg +++ b/src/params/sls_params.pyg @@ -31,5 +31,5 @@ def_module_params('sls', ('bv_use_top_level_assertions', BOOL, True, 'use top-level assertions for BV lookahead solver'), ('bv_use_lookahead', BOOL, True, 'use lookahead solver for BV'), ('bv_allow_rotation', BOOL, True, 'allow model rotation when repairing literal assignment'), - ('str_update_strategy', UINT, 1, 'string update candidate selection: 0 - single character based update, 1 - subsequence based update') + ('str_update_strategy', UINT, 2, 'string update candidate selection: 0 - single character based update, 1 - subsequence based update, 2 - combined') ))