From 7ed185aa9e3c3170b2c68bda9fb34fa0014d02e5 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 24 Nov 2024 19:09:50 -0800 Subject: [PATCH] add comments Signed-off-by: Nikolaj Bjorner --- src/ast/sls/sls_arith_base.cpp | 111 ++++++++++++++++------ src/ast/sls/sls_arith_base.h | 9 +- src/ast/sls/sls_seq_plugin.cpp | 164 +++++++++++++++++++++++++++++---- 3 files changed, 235 insertions(+), 49 deletions(-) diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 4a599ad75..93025909f 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -548,7 +548,7 @@ namespace sls { template bool arith_base::find_lin_moves(sat::literal lit) { m_updates.reset(); - auto* ineq = atom(lit.var()); + auto* ineq = get_ineq(lit.var()); num_t a, b; if (!ineq) return false; @@ -582,7 +582,7 @@ namespace sls { num_t d(1), d2; bool first = true; for (auto a : ctx.get_clause(cl)) { - auto const* ineq = atom(a.var()); + auto const* ineq = get_ineq(a.var()); if (!ineq) continue; d2 = dtt(a.sign(), *ineq); @@ -601,7 +601,7 @@ namespace sls { num_t d(1), d2; bool first = true; for (auto lit : ctx.get_clause(cl)) { - auto const* ineq = atom(lit.var()); + auto const* ineq = get_ineq(lit.var()); if (!ineq) continue; d2 = dtt(lit.sign(), *ineq, v, new_value); @@ -667,8 +667,8 @@ namespace sls { } buffer to_flip; - for (auto const& [coeff, bv] : vi.m_bool_vars) { - auto& ineq = *atom(bv); + for (auto const& [coeff, bv] : vi.m_ineqs) { + auto& ineq = *get_ineq(bv); bool old_sign = sign(bv); sat::literal lit(bv, old_sign); SASSERT(ctx.is_true(lit)); @@ -684,9 +684,9 @@ namespace sls { m_last_var = v; for (auto bv : to_flip) { - if (dtt(sign(bv), *atom(bv)) != 0) + if (dtt(sign(bv), *get_ineq(bv)) != 0) ctx.flip(bv); - SASSERT(dtt(sign(bv), *atom(bv)) == 0); + SASSERT(dtt(sign(bv), *get_ineq(bv)) == 0); } IF_VERBOSE(10, verbose_stream() << "new value eh " << mk_bounded_pp(e, m) << "\n"); @@ -933,12 +933,12 @@ namespace sls { template void 
arith_base::init_bool_var(sat::bool_var bv) { expr* e = ctx.atom(bv); - if (m_bool_vars.get(bv, nullptr)) + if (m_ineqs.get(bv, nullptr)) return; if (!e) return; expr* x, * y; - m_bool_vars.reserve(bv + 1); + m_ineqs.reserve(bv + 1); if (a.is_le(e, x, y) || a.is_ge(e, y, x)) { auto& ineq = new_ineq(ineq_kind::LE, num_t(0)); add_args(ineq, x, num_t(1)); @@ -963,8 +963,8 @@ namespace sls { add_args(ineq, y, num_t(-1)); init_ineq(bv, ineq); } - else if (m.is_distinct(e) && a.is_int_real(to_app(e)->get_arg(0))) { - NOT_IMPLEMENTED_YET(); + else if (is_distinct(e)) { + verbose_stream() << "distinct " << mk_pp(e, m) << "\n"; } else if (a.is_is_int(e, x)) { @@ -1004,7 +1004,7 @@ namespace sls { // compute the value of the linear term, and accumulate non-linear sub-terms i.m_args_value = i.m_coeff; for (auto const& [coeff, v] : i.m_args) { - m_vars[v].m_bool_vars.push_back({ coeff, bv }); + m_vars[v].m_ineqs.push_back({ coeff, bv }); i.m_args_value += coeff * value(v); if (is_mul(v)) { auto const& [w, monomial] = get_mul(v); @@ -1044,21 +1044,28 @@ namespace sls { } // attach i to bv - m_bool_vars.set(bv, &i); + m_ineqs.set(bv, &i); } template void arith_base::init_bool_var_assignment(sat::bool_var v) { - auto* ineq = atom(v); + auto* ineq = get_ineq(v); if (ineq && ineq->is_true() != ctx.is_true(v)) ctx.flip(v); + if (is_distinct(ctx.atom(v)) && eval_distinct(ctx.atom(v)) != ctx.is_true(v)) + ctx.flip(v); } template void arith_base::propagate_literal(sat::literal lit) { if (!ctx.is_true(lit)) return; - auto const* ineq = atom(lit.var()); + expr* e = ctx.atom(lit.var()); + if (is_distinct(e) && eval_distinct(e) != ctx.is_true(lit)) { + repair_distinct(e); + return; + } + auto const* ineq = get_ineq(lit.var()); if (!ineq) return; if (ineq->is_true() != lit.sign()) @@ -1136,7 +1143,7 @@ namespace sls { void arith_base::repair_up(app* e) { if (m.is_bool(e)) { auto v = ctx.atom2bool_var(e); - auto const* ineq = atom(v); + auto const* ineq = get_ineq(v); if (ineq && 
ineq->is_true() != ctx.is_true(v)) ctx.flip(v); return; @@ -1333,7 +1340,7 @@ namespace sls { template void arith_base::initialize_unit(sat::literal lit) { init_bool_var(lit.var()); - auto* ineq = atom(lit.var()); + auto* ineq = get_ineq(lit.var()); if (!ineq) return; @@ -1623,11 +1630,11 @@ namespace sls { double arith_base::compute_score(var_t x, num_t const& delta) { int result = 0; int breaks = 0; - for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) { + for (auto const& [coeff, bv] : m_vars[x].m_ineqs) { bool old_sign = sign(bv); auto lit = sat::literal(bv, old_sign); - auto dtt_old = dtt(old_sign, *atom(bv)); - auto dtt_new = dtt(old_sign, *atom(bv), coeff, delta); + auto dtt_old = dtt(old_sign, *get_ineq(bv)); + auto dtt_new = dtt(old_sign, *get_ineq(bv), coeff, delta); #if 1 if (dtt_new == 0 && dtt_old != 0) result += 1; @@ -1711,7 +1718,7 @@ namespace sls { template bool arith_base::find_nl_moves(sat::literal lit) { m_updates.reset(); - auto* ineq = atom(lit.var()); + auto* ineq = get_ineq(lit.var()); num_t a, b; if (!ineq) return false; @@ -1766,7 +1773,7 @@ namespace sls { template bool arith_base::find_reset_moves(sat::literal lit) { m_updates.reset(); - auto* ineq = atom(lit.var()); + auto* ineq = get_ineq(lit.var()); num_t a, b; if (!ineq) return false; @@ -1892,7 +1899,7 @@ namespace sls { template void arith_base::check_ineqs() { for (unsigned bv = 0; bv < ctx.num_bool_vars(); ++bv) { - auto const* ineq = atom(bv); + auto const* ineq = get_ineq(bv); if (!ineq) continue; num_t d = dtt(sign(bv), *ineq); @@ -1918,6 +1925,45 @@ namespace sls { mk_term(arg); } + template + bool arith_base::is_distinct(expr* e) { + return m.is_distinct(e) && + to_app(e)->get_num_args() > 0 && + a.is_int_real(to_app(e)->get_arg(0)); + } + + template + bool arith_base::eval_distinct(expr* e) { + auto const& args = *to_app(e); + for (unsigned i = 0; i < args.get_num_args(); ++i) + for (unsigned j = i + 1; j < args.get_num_args(); ++j) { + auto v1 = 
mk_term(args.get_arg(i)); + auto v2 = mk_term(args.get_arg(j)); + if (value(v1) == value(v2)) + return false; + } + return true; + } + + template + void arith_base::repair_distinct(expr* e) { + auto const& args = *to_app(e); + for (unsigned i = 0; i < args.get_num_args(); ++i) + for (unsigned j = i + 1; j < args.get_num_args(); ++j) { + auto v1 = mk_term(args.get_arg(i)); + auto v2 = mk_term(args.get_arg(j)); + if (value(v1) == value(v2)) { + auto new_value = value(v1) + num_t(1); + if (new_value == value(v2)) + new_value += num_t(1); + if (!is_fixed(v2)) + update(v2, new_value); + else if (!is_fixed(v1)) + update(v1, new_value); + } + } + } + template bool arith_base::set_value(expr* e, expr* v) { if (!a.is_int_real(e)) @@ -1956,7 +2002,14 @@ namespace sls { for (auto lit : clause.m_clause) { if (!ctx.is_true(lit)) continue; - auto ineq = atom(lit.var()); + if (is_distinct(ctx.atom(lit.var()))) { + if (eval_distinct(ctx.atom(lit.var())) != lit.sign()) { + sat = true; + break; + } + continue; + } + auto ineq = get_ineq(lit.var()); if (!ineq) { sat = true; break; @@ -1972,7 +2025,7 @@ namespace sls { verbose_stream() << clause << "\n"; for (auto lit : clause.m_clause) { verbose_stream() << lit << " (" << ctx.is_true(lit) << ") "; - auto ineq = atom(lit.var()); + auto ineq = get_ineq(lit.var()); if (!ineq) continue; verbose_stream() << *ineq << "\n"; @@ -2069,9 +2122,9 @@ namespace sls { out << " "; } - if (!vi.m_bool_vars.empty()) { + if (!vi.m_ineqs.empty()) { out << " bool: "; - for (auto [c, bv] : vi.m_bool_vars) + for (auto [c, bv] : vi.m_ineqs) out << c << "@" << bv << " "; } return out; @@ -2080,7 +2133,7 @@ namespace sls { template std::ostream& arith_base::display(std::ostream& out) const { for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) { - auto ineq = atom(v); + auto ineq = get_ineq(v); if (ineq) out << v << ": " << *ineq << "\n"; } @@ -2176,7 +2229,7 @@ namespace sls { template void arith_base::invariant() { for (unsigned v = 0; v < 
ctx.num_bool_vars(); ++v) { - auto ineq = atom(v); + auto ineq = get_ineq(v); if (ineq) invariant(*ineq); } diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index d70a49d20..60a03cb36 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -87,7 +87,7 @@ namespace sls { var_sort m_sort; arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP; unsigned m_def_idx = UINT_MAX; - vector> m_bool_vars; + vector> m_ineqs; unsigned_vector m_muls; unsigned_vector m_adds; optional m_lo, m_hi; @@ -159,7 +159,7 @@ namespace sls { stats m_stats; config m_config; - scoped_ptr_vector m_bool_vars; + scoped_ptr_vector m_ineqs; vector m_vars; vector m_muls; vector m_adds; @@ -181,6 +181,9 @@ namespace sls { unsigned get_num_vars() const { return m_vars.size(); } + bool is_distinct(expr* e); + bool eval_distinct(expr* e); + void repair_distinct(expr* e); bool eval_is_correct(var_t v); bool repair_mul(mul_def const& md); bool repair_add(add_def const& ad); @@ -219,7 +222,7 @@ namespace sls { // double reward(sat::literal lit); bool sign(sat::bool_var v) const { return !ctx.is_true(sat::literal(v, false)); } - ineq* atom(sat::bool_var bv) const { return m_bool_vars.get(bv, nullptr); } + ineq* get_ineq(sat::bool_var bv) const { return m_ineqs.get(bv, nullptr); } num_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); } num_t dtt(bool sign, num_t const& args_value, ineq const& ineq) const; num_t dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const; diff --git a/src/ast/sls/sls_seq_plugin.cpp b/src/ast/sls/sls_seq_plugin.cpp index ded8d10fb..3a85f3db7 100644 --- a/src/ast/sls/sls_seq_plugin.cpp +++ b/src/ast/sls/sls_seq_plugin.cpp @@ -12,6 +12,44 @@ Abstract: Author: Nikolaj Bjorner (nbjorner) 2024-11-22 + +Notes: + +- regex + Assume regexes are ground and for zstring. 
+ to repair: + x in R + - get prefix of x that can be in R + - extend prefix by sampled string y, such that prefix(x)y in R + + x not in R: + - assume x is in R, then + - sample prefix of x that is not in R + - sample extension of x that is not in R + - sample prefix of x in R, with extension not in R + +- sequences + +- use length constraints as tabu for updates. + +- alternate to lookahead strategy: + Lookahead repair based on changing leaves: + With each predicate, track the leaves of non-value arguments. + Suppose x is a leaf string used in a violated predicate. + then we can repair x by taking sub-string, or adding a character, + or adding x with an existing constant within the domain of known constants. + or truncating x to the empty string. + Suppose z is a leaf integer. + we can increment, decrement z, set z to -1, 0, or a known bound. + Lookahead works by updating strval1 starting from the leaf. + - create a priority buffer array of vector> based on depth. + - walk from lowest depth up. Reset each inner buffer when processed. Parents always + have higher depth. + - calculate repair/break score when hitting a predicate based on bval1. + - strval1 and bval1 are modified by + - use a global timestamp. + - label each eval subterm by a timestamp that gets set. + - strval0 evaluates to strval1 if timestamp matches global timestamp. --*/ @@ -55,6 +93,42 @@ namespace sls { } bool seq_plugin::is_sat() { + + for (expr* e : ctx.subterms()) { + expr* x, * y, * z = nullptr; + rational r; + // coherence between string / integer functions is delayed + // so we check and enforce it here. 
+ if (seq.str.is_length(e, x) && seq.is_string(x->get_sort())) { + auto sx = strval0(x); + auto ve = ctx.get_value(e); + if (a.is_numeral(ve, r) && r == sx.length()) + continue; + update(e, rational(sx.length())); + return false; + } + if ((seq.str.is_index(e, x, y, z) || seq.str.is_index(e, x, y)) && seq.is_string(x->get_sort())) { + auto sx = strval0(x); + auto sy = strval0(y); + rational val_z, val_e; + if (z) { + VERIFY(a.is_numeral(ctx.get_value(z), val_z)); + } + VERIFY(a.is_numeral(ctx.get_value(e), val_e)); + // case: x is empty, val_z = 0 + if (val_e < 0 && (val_z < 0 || (val_z >= sx.length() && sx.length() > 0))) + continue; + if (val_z.is_unsigned() && rational(sx.indexofu(sy, val_z.get_unsigned())) == val_e) + continue; + if (val_z < 0 || (val_z >= sx.length() && sx.length() > 0)) + update(e, rational(-1)); + else + update(e, rational(sx.indexofu(sy, val_z.get_unsigned()))); + return false; + } + // last-index-of + // str-to-int + } return true; } @@ -66,9 +140,8 @@ namespace sls { if (is_app(e) && to_app(e)->get_family_id() == m_fid && all_of(*to_app(e), [&](expr* arg) { return is_value(arg); })) - get_eval(e).is_value = true; - return; - } + get_eval(e).is_value = true; + } } std::ostream& seq_plugin::display(std::ostream& out) const { @@ -82,7 +155,7 @@ namespace sls { auto* ev = get_eval(t); if (!ev) continue; - out << mk_pp(t, m) << " -> " << ev->val0.svalue; + out << mk_pp(t, m) << " -> \"" << ev->val0.svalue << "\""; if (ev->min_length > 0) out << " min-length: " << ev->min_length; if (ev->max_length < UINT_MAX) @@ -145,6 +218,7 @@ namespace sls { bool seq_plugin::bval1_seq(app* e) { expr* a, *b; + SASSERT(e->get_family_id() == seq.get_family_id()); switch (e->get_decl_kind()) { case OP_SEQ_CONTAINS: VERIFY(seq.str.is_contains(e, a, b)); @@ -177,6 +251,7 @@ namespace sls { NOT_IMPLEMENTED_YET(); break; default: + UNREACHABLE(); break; } return false; @@ -267,6 +342,8 @@ namespace sls { case OP_RE_DERIVATIVE: case OP_STRING_ITOS: case 
OP_STRING_FROM_CODE: + case OP_STRING_UBVTOS: + case OP_STRING_SBVTOS: verbose_stream() << "strval1 " << mk_bounded_pp(e, m) << "\n"; NOT_IMPLEMENTED_YET(); break; @@ -291,8 +368,6 @@ namespace sls { case OP_SEQ_INDEX: case OP_SEQ_LAST_INDEX: case OP_STRING_STOI: - case OP_STRING_UBVTOS: - case OP_STRING_SBVTOS: case OP_STRING_LT: case OP_STRING_LE: case OP_STRING_IS_DIGIT: @@ -411,6 +486,7 @@ namespace sls { bool is_true = ctx.is_true(e); expr* x, * y; VERIFY(m.is_eq(e, x, y)); + verbose_stream() << is_true << ": " << mk_bounded_pp(e, m, 3) << "\n"; if (ctx.is_true(e)) { if (!is_value(x)) m_str_updates.push_back({ x, strval1(y), 1 }); @@ -768,14 +844,40 @@ namespace sls { bool seq_plugin::repair_down_str_concat(app* e) { zstring val_e = strval0(e); unsigned len_e = val_e.length(); - // sample a ranom partition. + // sample a random partition. + // the current sample algorithm isn't uniformly sampling + // each possible partition, but favors what would be a + // normal distribution sbuffer lengths(e->get_num_args(), 0); - for (unsigned i = 0; i < len_e; ++i) - lengths[ctx.rand(lengths.size())]++; - unsigned i = 0, len_prefix = 0; + sbuffer non_values; + unsigned i = 0; + //verbose_stream() << "repair concat " << mk_bounded_pp(e, m) << "\n"; + for (expr* arg : *e) { + ++i; + if (!is_value(arg)) { + non_values.push_back(i - 1); + continue; + } + auto const& arg_val = strval0(arg); + if (arg_val.length() > len_e) + return false; + lengths[i - 1] = arg_val.length(); + len_e -= arg_val.length(); + } + // TODO: take duplications into account + while (len_e > 0 && !non_values.empty()) { + lengths[non_values[ctx.rand(non_values.size())]]++; + --len_e; + } + if (len_e > 0 && non_values.empty()) + return false; + i = 0; + //verbose_stream() << "repair concat2 " << mk_bounded_pp(e, m) << "\n"; + unsigned len_prefix = 0; for (expr* arg : *e) { auto len = lengths[i]; auto val_arg = val_e.extract(len_prefix, len); + //verbose_stream() << "repair concat3 " << mk_bounded_pp(arg, 
m) << " " << val_arg << "\n"; if (!update(arg, val_arg)) return false; ++i; @@ -785,6 +887,7 @@ namespace sls { } + bool seq_plugin::apply_update() { double sum_scores = 0; for (auto const& [e, val, score] : m_str_updates) @@ -813,9 +916,8 @@ namespace sls { if (is_str_update) { auto [e, value, score] = m_str_updates[i]; - verbose_stream() << "set value " << mk_bounded_pp(e, m) << " := \"" << value << "\"\n"; - if (update(e, value)) { + verbose_stream() << "set value " << mk_bounded_pp(e, m) << " := \"" << value << "\"\n"; m_str_updates.reset(); m_int_updates.reset(); return true; @@ -845,10 +947,12 @@ namespace sls { } bool seq_plugin::update(expr* e, zstring const& value) { - if (is_value(e)) - return false; if (value == strval0(e)) return true; + if (is_value(e)) + return false; + if (get_eval(e).min_length > value.length() || get_eval(e).max_length < value.length()) + return false; strval0(e) = value; ctx.new_value_eh(e); return true; @@ -878,6 +982,34 @@ namespace sls { } } } + for (auto t : ctx.subterms()) { + if (seq.str.is_string(t)) { + auto& ev = get_eval(t); + ev.min_length = strval0(t).length(); + ev.max_length = strval0(t).length(); + } + if (seq.str.is_concat(t)) { + unsigned min_length = 0; + unsigned max_length = 0; + for (expr* arg : *to_app(t)) { + auto& ev = get_eval(arg); + min_length += ev.min_length; + if (ev.max_length < UINT_MAX && max_length != UINT_MAX) + max_length += ev.max_length; + else + max_length = UINT_MAX; + } + auto& ev = get_eval(t); + ev.min_length = std::max(min_length, ev.min_length); + ev.max_length = std::min(max_length, ev.max_length); + } + if (seq.str.is_at(t)) { + auto& ev = get_eval(t); + ev.max_length = 1; + } + // extract with constant length. + + } } void seq_plugin::repair_literal(sat::literal lit) { @@ -894,7 +1026,5 @@ namespace sls { if (seq.is_seq(e)) return get_eval(e).is_value; return m.is_value(e); - } - - + } }