diff --git a/src/smt/seq/seq_nielsen.cpp b/src/smt/seq/seq_nielsen.cpp
index 213282831..d439abbe2 100644
--- a/src/smt/seq/seq_nielsen.cpp
+++ b/src/smt/seq/seq_nielsen.cpp
@@ -24,6 +24,7 @@ Author:
 #include "ast/ast_pp.h"
 #include "util/bit_util.h"
 #include "util/hashtable.h"
+#include <algorithm>
 #include
 
 namespace seq {
@@ -145,6 +146,179 @@ namespace seq {
         return true;
     }
 
+    // -----------------------------------------------
+    // char_set
+    // -----------------------------------------------
+
+    // Total number of characters covered, summed over all ranges.
+    unsigned char_set::char_count() const {
+        unsigned count = 0;
+        for (auto const& r : m_ranges)
+            count += r.length();
+        return count;
+    }
+
+    // Membership test: binary search over the sorted, non-overlapping ranges.
+    bool char_set::contains(unsigned c) const {
+        int lo = 0, hi = static_cast<int>(m_ranges.size()) - 1;
+        while (lo <= hi) {
+            int mid = lo + (hi - lo) / 2;
+            if (c < m_ranges[mid].m_lo)
+                hi = mid - 1;
+            else if (c >= m_ranges[mid].m_hi)
+                lo = mid + 1;
+            else
+                return true;
+        }
+        return false;
+    }
+
+    // Insert a single character. A neighbouring range that c touches is
+    // extended instead, so the list stays sorted and coalesced.
+    void char_set::add(unsigned c) {
+        if (m_ranges.empty()) {
+            m_ranges.push_back(char_range(c));
+            return;
+        }
+        // binary search for insertion point
+        int lo = 0, hi = static_cast<int>(m_ranges.size()) - 1;
+        while (lo <= hi) {
+            int mid = lo + (hi - lo) / 2;
+            if (c < m_ranges[mid].m_lo)
+                hi = mid - 1;
+            else if (c >= m_ranges[mid].m_hi)
+                lo = mid + 1;
+            else
+                return; // already contained
+        }
+        // lo is the insertion point
+        unsigned idx = static_cast<unsigned>(lo);
+        bool merge_left = idx > 0 && m_ranges[idx - 1].m_hi == c;
+        bool merge_right = idx < m_ranges.size() && m_ranges[idx].m_lo == c + 1;
+        if (merge_left && merge_right) {
+            m_ranges[idx - 1].m_hi = m_ranges[idx].m_hi;
+            m_ranges.erase(m_ranges.begin() + idx);
+        } else if (merge_left) {
+            m_ranges[idx - 1].m_hi = c + 1;
+        } else if (merge_right) {
+            m_ranges[idx].m_lo = c;
+        } else {
+            // positional insert: shift elements right and place new element
+            m_ranges.push_back(char_range());
+            for (unsigned k = m_ranges.size() - 1; k > idx; --k)
+                m_ranges[k] = m_ranges[k - 1];
+            m_ranges[idx] = char_range(c);
+        }
+    }
+
+    // Union with another set. Both range lists are sorted, so one linear
+    // merge suffices; its cost depends on the number of ranges rather
+    // than on the number of characters they cover.
+    void char_set::add(char_set const& other) {
+        svector<char_range> merged;
+        unsigned const n_this = m_ranges.size();
+        unsigned const n_other = other.m_ranges.size();
+        unsigned i = 0, j = 0;
+        while (i < n_this || j < n_other) {
+            // take whichever pending range starts first
+            bool take_this = j == n_other ||
+                (i < n_this && m_ranges[i].m_lo <= other.m_ranges[j].m_lo);
+            char_range r = take_this ? m_ranges[i] : other.m_ranges[j];
+            if (take_this) ++i; else ++j;
+            if (r.is_empty())
+                continue;
+            unsigned last = merged.size();
+            if (last > 0 && r.m_lo <= merged[last - 1].m_hi) {
+                // overlapping or adjacent: extend the previous range
+                merged[last - 1].m_hi = std::max(merged[last - 1].m_hi, r.m_hi);
+            } else {
+                merged.push_back(r);
+            }
+        }
+        m_ranges = merged;
+    }
+
+    // Intersection, computed by sweeping both sorted range lists in step.
+    char_set char_set::intersect_with(char_set const& other) const {
+        char_set result;
+        unsigned i = 0, j = 0;
+        while (i < m_ranges.size() && j < other.m_ranges.size()) {
+            unsigned lo = std::max(m_ranges[i].m_lo, other.m_ranges[j].m_lo);
+            unsigned hi = std::min(m_ranges[i].m_hi, other.m_ranges[j].m_hi);
+            if (lo < hi)
+                result.m_ranges.push_back(char_range(lo, hi));
+            // advance the list whose current range ends first
+            if (m_ranges[i].m_hi < other.m_ranges[j].m_hi)
+                ++i;
+            else
+                ++j;
+        }
+        return result;
+    }
+
+    // Complement relative to [0, max_char]. Ranges reaching past max_char
+    // are clipped so the result never leaves that universe.
+    // NOTE(review): max_char + 1 wraps when max_char == UINT_MAX; assumed
+    // not to occur for real alphabets - confirm against callers.
+    char_set char_set::complement(unsigned max_char) const {
+        char_set result;
+        unsigned const limit = max_char + 1; // exclusive upper bound
+        unsigned from = 0;
+        for (auto const& r : m_ranges) {
+            if (from >= limit)
+                break;
+            if (from < r.m_lo)
+                result.m_ranges.push_back(char_range(from, std::min(r.m_lo, limit)));
+            from = r.m_hi;
+        }
+        if (from < limit)
+            result.m_ranges.push_back(char_range(from, limit));
+        return result;
+    }
+
+    // True when the two sets share no character.
+    bool char_set::is_disjoint(char_set const& other) const {
+        unsigned i = 0, j = 0;
+        while (i < m_ranges.size() && j < other.m_ranges.size()) {
+            if (m_ranges[i].m_hi <= other.m_ranges[j].m_lo)
+                ++i;
+            else if (other.m_ranges[j].m_hi <= m_ranges[i].m_lo)
+                ++j;
+            else
+                return false;
+        }
+        return true;
+    }
+
+    // Print as "{ a, [48-57], #[10] }": ASCII letters and digits verbatim,
+    // other single characters as #[code], longer ranges as [lo-hi] with
+    // an inclusive upper bound. The empty set prints as "{}".
+    std::ostream& char_set::display(std::ostream& out) const {
+        if (m_ranges.empty()) {
+            out << "{}";
+            return out;
+        }
+        out << "{ ";
+        bool first = true;
+        for (auto const& r : m_ranges) {
+            if (!first) out << ", ";
+            first = false;
+            if (r.is_unit()) {
+                unsigned c = r.m_lo;
+                bool alnum = (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
+                             (c >= '0' && c <= '9');
+                if (alnum)
+                    out << static_cast<char>(c);
+                else
+                    out << "#[" << c << "]";
+            } else {
+                out << "[" << r.m_lo << "-" << (r.m_hi - 1) << "]";
+            }
+        }
+        out << " }";
+        return out;
+    }
+
     // -----------------------------------------------
     // nielsen_edge
     // -----------------------------------------------
@@ -164,10 +338,23 @@ namespace seq {
     void nielsen_node::clone_from(nielsen_node const& parent) {
         m_str_eq.reset();
         m_str_mem.reset();
+        m_char_diseqs.reset();
+        m_char_ranges.reset();
         for (auto const& eq : parent.m_str_eq)
             m_str_eq.push_back(str_eq(eq.m_lhs, eq.m_rhs, eq.m_dep));
         for (auto const& mem : parent.m_str_mem)
             m_str_mem.push_back(str_mem(mem.m_str, mem.m_regex, mem.m_history, mem.m_id, mem.m_dep));
+        // clone character disequalities
+        for (auto const& kv : parent.m_char_diseqs) {
+            ptr_vector<euf::snode> diseqs;
+            for (euf::snode* s : kv.m_value)
+                diseqs.push_back(s);
+            m_char_diseqs.insert(kv.m_key, diseqs);
+        }
+        // clone character ranges
+        for (auto const& kv : parent.m_char_ranges) {
+            m_char_ranges.insert(kv.m_key, kv.m_value.clone());
+        }
     }
 
     void nielsen_node::apply_subst(euf::sgraph& sg, nielsen_subst const& s) {
@@ -188,6 +375,132 @@ namespace seq {
         }
     }
 
+    // Constrain sym_char to `range`, intersecting with any range already
+    // recorded for it. An empty result marks the node as a conflict.
+    void nielsen_node::add_char_range(euf::snode* sym_char, char_set const& range) {
+        SASSERT(sym_char && sym_char->is_unit());
+        unsigned id = sym_char->id();
+        bool unsatisfiable;
+        if (m_char_ranges.contains(id)) {
+            char_set& existing = m_char_ranges.find(id);
+            char_set inter = existing.intersect_with(range);
+            existing = inter;
+            unsatisfiable = inter.is_empty();
+        } else {
+            m_char_ranges.insert(id, range.clone());
+            // an empty first range is as contradictory as an empty intersection
+            unsatisfiable = range.is_empty();
+        }
+        if (unsatisfiable) {
+            m_is_general_conflict = true;
+            m_reason = backtrack_reason::character_range;
+        }
+    }
+
+    // Record the disequality sym_char != other under sym_char's id.
+    // Duplicates are ignored; sym_char != sym_char is an immediate conflict.
+    void nielsen_node::add_char_diseq(euf::snode* sym_char, euf::snode* other) {
+        SASSERT(sym_char && sym_char->is_unit());
+        SASSERT(other && other->is_unit());
+        if (sym_char == other) {
+            // ?c != ?c can never hold
+            m_is_general_conflict = true;
+            m_reason = backtrack_reason::character_range;
+            return;
+        }
+        unsigned id = sym_char->id();
+        if (!m_char_diseqs.contains(id))
+            m_char_diseqs.insert(id, ptr_vector<euf::snode>());
+        ptr_vector<euf::snode>& existing = m_char_diseqs.find(id);
+        // check for duplicates
+        for (euf::snode* s : existing)
+            if (s == other) return;
+        existing.push_back(other);
+    }
+
+    // Apply the character substitution s.m_var -> s.m_val to every string
+    // constraint, then reconcile the character constraints of s.m_var:
+    //  - symbolic target: a recorded s.m_var != s.m_val is a conflict;
+    //    otherwise the disequalities and range carry over to s.m_val;
+    //  - concrete target: the character must lie in the recorded range.
+    // NOTE(review): disequalities stored under other keys that mention
+    // s.m_var on their right-hand side are not rewritten here.
+    void nielsen_node::apply_char_subst(euf::sgraph& sg, char_subst const& s) {
+        if (!s.m_var) return;
+
+        // replace occurrences of s.m_var with s.m_val in all string constraints
+        for (unsigned i = 0; i < m_str_eq.size(); ++i) {
+            str_eq& eq = m_str_eq[i];
+            eq.m_lhs = sg.subst(eq.m_lhs, s.m_var, s.m_val);
+            eq.m_rhs = sg.subst(eq.m_rhs, s.m_var, s.m_val);
+            eq.sort();
+        }
+        for (unsigned i = 0; i < m_str_mem.size(); ++i) {
+            str_mem& mem = m_str_mem[i];
+            mem.m_str = sg.subst(mem.m_str, s.m_var, s.m_val);
+            mem.m_regex = sg.subst(mem.m_regex, s.m_var, s.m_val);
+        }
+
+        unsigned var_id = s.m_var->id();
+        bool has_diseqs = m_char_diseqs.contains(var_id);
+        bool has_range = m_char_ranges.contains(var_id);
+
+        // copy the constraints of s.m_var: the maps are modified below, so
+        // references obtained from find() must not be held across that
+        ptr_vector<euf::snode> diseqs;
+        if (has_diseqs) {
+            for (euf::snode* d : m_char_diseqs.find(var_id))
+                diseqs.push_back(d);
+        }
+        char_set range;
+        if (has_range)
+            range = m_char_ranges.find(var_id).clone();
+
+        bool to_symbolic = s.m_val->is_unit();
+        if (to_symbolic) {
+            // symbolic char → symbolic char: check disequalities
+            for (euf::snode* d : diseqs) {
+                if (d == s.m_val) {
+                    m_is_general_conflict = true;
+                    m_reason = backtrack_reason::character_range;
+                    return;
+                }
+            }
+        } else {
+            SASSERT(s.m_val->is_char());
+            // symbolic char → concrete char: check range constraints
+            if (has_range) {
+                // extract the concrete char value from the s_char snode
+                unsigned ch_val = 0;
+                seq_util& seq = sg.get_seq_util();
+                expr* unit_expr = s.m_val->get_expr();
+                expr* ch_expr = nullptr;
+                bool known = unit_expr && seq.str.is_unit(unit_expr, ch_expr) &&
+                             seq.is_const_char(ch_expr, ch_val);
+                // decide only on an extracted value; a failed extraction
+                // must not be read as character 0
+                if (known && !range.contains(ch_val)) {
+                    m_is_general_conflict = true;
+                    m_reason = backtrack_reason::character_range;
+                    return;
+                }
+            }
+        }
+
+        // s.m_var no longer occurs in this node: drop its entries, whichever
+        // of the two maps holds one
+        if (has_diseqs) m_char_diseqs.remove(var_id);
+        if (has_range) m_char_ranges.remove(var_id);
+
+        if (to_symbolic) {
+            // the replacement inherits the constraints of the replaced char
+            for (euf::snode* d : diseqs)
+                add_char_diseq(s.m_val, d);
+            if (has_range)
+                add_char_range(s.m_val, range);
+        }
+    }
+
     // -----------------------------------------------
     // nielsen_graph
     // -----------------------------------------------
@@ -446,6 +759,20 @@ namespace seq {
                     << dot_html_escape(snode_label(mem.m_regex, m))
                     << "<br/>";
             }
+            // character ranges
+            for (auto const& kv : m_char_ranges) {
+                if (!any) { out << "Cnstr:<br/>"; any = true; }
+                out << "?" << kv.m_key << " ∈ ";
+                kv.m_value.display(out);
+                out << "<br/>";
+            }
+            // character disequalities
+            for (auto const& kv : m_char_diseqs) {
+                if (!any) { out << "Cnstr:<br/>"; any = true; }
+                for (euf::snode* d : kv.m_value) {
+                    out << "?" << kv.m_key << " ≠ ?" << d->id() << "<br/>";
+                }
+            }
             if (!any)
                 out << "⊤"; // ⊤ (trivially satisfied)
 
@@ -541,6 +868,12 @@ namespace seq {
                     << " → " // mapping arrow
                     << dot_html_escape(snode_label(s.m_replacement, m));
             }
+            for (auto const& cs : e->char_substs()) {
+                if (!first) out << "<br/>";
+                first = false;
+                out << "?" << cs.m_var->id()
+                    << " → ?" << cs.m_val->id();
+            }
             out << ">";
 
             // colour
@@ -1104,6 +1437,17 @@ namespace seq {
         return m_sg.mk_var(symbol(name.c_str()));
     }
 
+    euf::snode* nielsen_graph::mk_fresh_char_var() {
+        ++m_stats.m_num_fresh_vars;
+        std::string name = "?c!" + std::to_string(m_fresh_cnt++);
+        seq_util& seq = m_sg.get_seq_util();
+        ast_manager& m = m_sg.get_manager();
+        sort* char_sort = seq.mk_char_sort();
+        expr_ref fresh_const(m.mk_fresh_const(name.c_str(), char_sort), m);
+        expr_ref unit(seq.str.mk_unit(fresh_const), m);
+        return m_sg.mk(unit);
+    }
+
     // -----------------------------------------------------------------------
     // nielsen_graph: apply_regex_char_split
     // -----------------------------------------------------------------------
@@ -1643,8 +1987,8 @@ namespace seq {
                 created = true;
             }
 
-            // Branch 2+: for each minterm m_i, x → fresh_char · x'
-            // where fresh_char is constrained by the minterm
+            // Branch 2+: for each minterm m_i, x → ?c · x'
+            // where ?c is a symbolic char constrained by the minterm
             for (euf::snode* mt : minterms) {
                 if (mt->is_fail()) continue;
 
@@ -1654,13 +1998,15 @@ namespace seq {
                 if (deriv && deriv->is_fail()) continue;
 
                 euf::snode* fresh_var = mk_fresh_var();
-                euf::snode* fresh_char = mk_fresh_var();
+                euf::snode* fresh_char = mk_fresh_char_var();
                 euf::snode* replacement = m_sg.mk_concat(fresh_char, fresh_var);
                 nielsen_node* child = mk_child(node);
                 nielsen_edge* e = mk_edge(node, child, true);
                 nielsen_subst s(first, replacement, mem.m_dep);
                 e->add_subst(s);
                 child->apply_subst(m_sg, s);
+                // TODO: derive char_set from minterm and add as range constraint
+                // child->add_char_range(fresh_char, minterm_to_char_set(mt));
                 created = true;
             }
 
diff --git a/src/smt/seq/seq_nielsen.h b/src/smt/seq/seq_nielsen.h
index a892cbadb..caaa8533d 100644
--- a/src/smt/seq/seq_nielsen.h
+++ b/src/smt/seq/seq_nielsen.h
@@ -232,6 +232,7 @@ Author:
 
 #include "util/vector.h"
 #include "util/uint_set.h"
+#include "util/map.h"
 #include "ast/ast.h"
 #include "ast/arith_decl_plugin.h"
 #include "ast/seq_decl_plugin.h"
@@ -291,6 +292,104 @@ namespace seq {
         bool operator!=(dep_tracker const& other) const { return !(*this == other); }
     };
 
+    // -----------------------------------------------
+    // character range and set types
+    // mirrors ZIPT's CharacterRange and CharacterSet
+    // -----------------------------------------------
+
+    // half-open character interval [lo, hi)
+    // mirrors ZIPT's CharacterRange
+    struct char_range {
+        unsigned m_lo;
+        unsigned m_hi; // exclusive
+
+        char_range(): m_lo(0), m_hi(0) {}
+        char_range(unsigned c): m_lo(c), m_hi(c + 1) {}
+        char_range(unsigned lo, unsigned hi): m_lo(lo), m_hi(hi) { SASSERT(lo <= hi); }
+
+        bool is_empty() const { return m_lo == m_hi; }
+        bool is_unit() const { return m_hi == m_lo + 1; }
+        unsigned length() const { return m_hi - m_lo; }
+        bool contains(unsigned c) const { return c >= m_lo && c < m_hi; }
+
+        bool operator==(char_range const& o) const { return m_lo == o.m_lo && m_hi == o.m_hi; }
+        bool operator<(char_range const& o) const {
+            return m_lo < o.m_lo || (m_lo == o.m_lo && m_hi < o.m_hi);
+        }
+    };
+
+    // sorted list of non-overlapping character intervals
+    // mirrors ZIPT's CharacterSet
+    class char_set {
+        svector<char_range> m_ranges;
+    public:
+        char_set() = default;
+        explicit char_set(char_range const& r) { if (!r.is_empty()) m_ranges.push_back(r); }
+
+        static char_set full(unsigned max_char) { return char_set(char_range(0, max_char + 1)); }
+
+        bool is_empty() const { return m_ranges.empty(); }
+        bool is_full(unsigned max_char) const {
+            return m_ranges.size() == 1 && m_ranges[0].m_lo == 0 && m_ranges[0].m_hi == max_char + 1;
+        }
+        bool is_unit() const { return m_ranges.size() == 1 && m_ranges[0].is_unit(); }
+        unsigned first_char() const { SASSERT(!is_empty()); return m_ranges[0].m_lo; }
+
+        svector<char_range> const& ranges() const { return m_ranges; }
+
+        // total number of characters in the set
+        unsigned char_count() const;
+
+        // membership test via binary search
+        bool contains(unsigned c) const;
+
+        // add a single character
+        void add(unsigned c);
+
+        // union with another char_set
+        void add(char_set const& other);
+
+        // intersect with another char_set, returns the result
+        char_set intersect_with(char_set const& other) const;
+
+        // complement relative to [0, max_char]
+        char_set complement(unsigned max_char) const;
+
+        // check if two sets are disjoint
+        bool is_disjoint(char_set const& other) const;
+
+        bool operator==(char_set const& other) const { return m_ranges == other.m_ranges; }
+
+        char_set clone() const { char_set r; r.m_ranges = m_ranges; return r; }
+
+        std::ostream& display(std::ostream& out) const;
+    };
+
+    // -----------------------------------------------
+    // character-level substitution
+    // mirrors ZIPT's CharSubst
+    // -----------------------------------------------
+
+    // maps a symbolic char (s_unit snode) to a concrete or symbolic char
+    struct char_subst {
+        euf::snode* m_var; // the symbolic char being substituted (s_unit)
+        euf::snode* m_val; // replacement: s_char (concrete) or s_unit (symbolic)
+
+        char_subst(): m_var(nullptr), m_val(nullptr) {}
+        char_subst(euf::snode* var, euf::snode* val):
+            m_var(var), m_val(val) {
+            SASSERT(var && var->is_unit());
+            SASSERT(val && (val->is_char() || val->is_unit()));
+        }
+
+        // true when the replacement is a concrete character
+        bool is_eliminating() const { return m_val && m_val->is_char(); }
+
+        bool operator==(char_subst const& o) const {
+            return m_var == o.m_var && m_val == o.m_val;
+        }
+    };
+
     // string equality constraint: lhs = rhs
     // mirrors ZIPT's StrEq (both sides are regex-free snode trees)
     struct str_eq {
@@ -387,6 +486,7 @@ namespace seq {
         nielsen_node* m_src;
         nielsen_node* m_tgt;
         vector<nielsen_subst> m_subst;
+        vector<char_subst> m_char_subst; // character-level substitutions (mirrors ZIPT's SubstC)
         ptr_vector<str_eq> m_side_str_eq; // side constraints: string equalities
         ptr_vector<str_mem> m_side_str_mem; // side constraints: regex memberships
bool m_is_progress; // does this edge represent progress? @@ -401,6 +501,9 @@ namespace seq { vector const& subst() const { return m_subst; } void add_subst(nielsen_subst const& s) { m_subst.push_back(s); } + vector const& char_substs() const { return m_char_subst; } + void add_char_subst(char_subst const& s) { m_char_subst.push_back(s); } + void add_side_str_eq(str_eq* eq) { m_side_str_eq.push_back(eq); } void add_side_str_mem(str_mem* mem) { m_side_str_mem.push_back(mem); } @@ -426,6 +529,11 @@ namespace seq { vector m_str_eq; // string equalities vector m_str_mem; // regex memberships + // character constraints (mirrors ZIPT's DisEqualities and CharRanges) + // key: snode id of the s_unit symbolic character + u_map> m_char_diseqs; // ?c != {?d, ?e, ...} + u_map m_char_ranges; // ?c in [lo, hi) + // edges ptr_vector m_outgoing; nielsen_node* m_backedge = nullptr; @@ -455,6 +563,21 @@ namespace seq { void add_str_eq(str_eq const& eq) { m_str_eq.push_back(eq); } void add_str_mem(str_mem const& mem) { m_str_mem.push_back(mem); } + // character constraint access (mirrors ZIPT's DisEqualities / CharRanges) + u_map> const& char_diseqs() const { return m_char_diseqs; } + u_map const& char_ranges() const { return m_char_ranges; } + + // add a character range constraint for a symbolic char. + // intersects with existing range; sets conflict if result is empty. + void add_char_range(euf::snode* sym_char, char_set const& range); + + // add a character disequality: sym_char != other + void add_char_diseq(euf::snode* sym_char, euf::snode* other); + + // apply a character-level substitution to all constraints. + // checks disequalities and ranges; sets conflict on violation. 
+        void apply_char_subst(euf::sgraph& sg, char_subst const& s);
+
         // edge access
         ptr_vector<nielsen_edge> const& outgoing() const { return m_outgoing; }
         void add_outgoing(nielsen_edge* e) { m_outgoing.push_back(e); }
@@ -655,6 +778,10 @@ namespace seq {
         // create a fresh variable with a unique name
         euf::snode* mk_fresh_var();
 
+        // create a fresh symbolic character: seq.unit(fresh_char_const)
+        // analogous to ZIPT's SymCharToken creation
+        euf::snode* mk_fresh_char_var();
+
         // deterministic modifier: var = ε, same-head cancel
         bool apply_det_modifier(nielsen_node* node);
 
diff --git a/src/test/nseq_zipt.cpp b/src/test/nseq_zipt.cpp
index 755c05c93..fac62fa28 100644
--- a/src/test/nseq_zipt.cpp
+++ b/src/test/nseq_zipt.cpp
@@ -546,8 +546,95 @@ static void test_zipt_parikh() {
     std::cout << " ok\n";
 }
 
+// -----------------------------------------------------------------------
+// Tricky string equation benchmarks (hand-crafted, beyond ZIPT suite).
+//
+// SAT witnesses are noted inline. UNSAT arguments are grouped by type:
+// [first-char] — immediate first-character mismatch
+// [after-cancel] — mismatch exposed after prefix/suffix cancellation
+// [induction] — recursive unrolling forces a = b contradiction
+// [parity] — length parity (odd vs even) rules out all solutions
+// [midpoint] — equal length forced by lengths; midpoint char differs
+// -----------------------------------------------------------------------
+static void test_tricky_str_equations() {
+    std::cout << "test_tricky_str_equations\n";
+
+    // --- SAT: commutativity / rotation / symmetry ---
+
+    // XY = YX (classic commutativity; witness: X="ab", Y="abab")
+    VERIFY(eq_sat("XY", "YX"));
+
+    // Xab = abX (X commutes with the word "ab"; witness: X="ab")
+    VERIFY(eq_sat("Xab", "abX"));
+
+    // XaY = YaX (swap-symmetric; witness: X=Y=any, e.g. X=Y="b")
+    VERIFY(eq_sat("XaY", "YaX"));
+
+    // XYX = YXY (Markov-type; witness: X=Y)
+    VERIFY(eq_sat("XYX", "YXY"));
+
+    // XYZ = ZYX (reverse-palindrome; witness: X="a",Y="b",Z="a")
+    VERIFY(eq_sat("XYZ", "ZYX"));
+
+    // XabY = YabX (rotation-like; witness: X="",Y="ab")
+    VERIFY(eq_sat("XabY", "YabX"));
+
+    // aXYa = aYXa (cancel outer 'a'; reduces to XY=YX; witness: X=Y="")
+    VERIFY(eq_sat("aXYa", "aYXa"));
+
+    // XaXb = YaYb (both halves share variable; witness: X=Y)
+    VERIFY(eq_sat("XaXb", "YaYb"));
+
+    // abXba = Xabba (witness: X="" gives "abba"="abba")
+    VERIFY(eq_sat("abXba", "Xabba"));
+
+    // --- UNSAT: first-character mismatch ---
+
+    // abXba = baXab (LHS starts 'a', RHS starts 'b')
+    VERIFY(eq_unsat("abXba", "baXab"));
+
+    // --- UNSAT: mismatch exposed after cancellation ---
+
+    // XabX = XbaX (cancel X prefix/suffix → "ab"="ba"; 'a'≠'b')
+    VERIFY(eq_unsat("XabX", "XbaX"));
+
+    // XaYb = XbYa (cancel X prefix → aYb=bYa; first char 'a'≠'b')
+    VERIFY(eq_unsat("XaYb", "XbYa"));
+
+    // XaYbX = XbYaX (cancel X prefix+suffix → aYb=bYa; first char 'a'≠'b')
+    VERIFY(eq_unsat("XaYbX", "XbYaX"));
+
+    // XaXbX = XbXaX (cancel X prefix+suffix → aXb=bXa; first char 'a'≠'b')
+    VERIFY(eq_unsat("XaXbX", "XbXaX"));
+
+    // --- UNSAT: induction ---
+
+    // aXb = Xba (forces X=a^n; final step requires a=b)
+    VERIFY(eq_unsat("aXb", "Xba"));
+
+    // XaY = YbX (a≠b; recursive unrolling forces a=b)
+    VERIFY(eq_unsat("XaY", "YbX"));
+
+    // --- UNSAT: length parity ---
+
+    // XaX = YY (|XaX|=2|X|+1 is odd; |YY|=2|Y| is even)
+    VERIFY(eq_unsat("XaX", "YY"));
+
+    // XaaX = YbY (|XaaX|=2|X|+2 is even; |YbY|=2|Y|+1 is odd)
+    VERIFY(eq_unsat("XaaX", "YbY"));
+
+    // --- UNSAT: midpoint argument ---
+
+    // XaX = YbY (equal length forces |X|=|Y|; midpoint position |X|
+    // holds 'a' in LHS and 'b' in RHS, but 'a'≠'b')
+    VERIFY(eq_unsat("XaX", "YbY"));
+
+    std::cout << " ok\n";
+}
+
 void tst_nseq_zipt() {
     test_zipt_str_equations();
+    test_tricky_str_equations();
     test_zipt_regex_ground();
     test_zipt_str_membership();
     test_zipt_parikh();