diff --git a/src/smt/seq/seq_nielsen.cpp b/src/smt/seq/seq_nielsen.cpp
index af599d285..098f58518 100644
--- a/src/smt/seq/seq_nielsen.cpp
+++ b/src/smt/seq/seq_nielsen.cpp
@@ -1009,6 +1009,10 @@ namespace seq {
             }
 
             // pass 2: detect symbol clashes, empty-propagation, and prefix cancellation
+            // unit equalities produced by unit-unit prefix/suffix splits below.
+            // Buffered here and appended to m_str_eq only after this loop:
+            // pushing into m_str_eq while range-iterating it may invalidate eq.
+            svector<str_eq> unit_eqs;
             for (str_eq& eq : m_str_eq) {
                 if (!eq.m_lhs || !eq.m_rhs)
                     continue;
@@ -1046,6 +1050,13 @@ namespace seq {
                             m_reason = backtrack_reason::symbol_clash;
                             return simplify_result::conflict;
                         }
+                        else if (lt->is_char_or_unit() && rt->is_char_or_unit()) {
+                            // unit(a) ++ rest1 == unit(b) ++ rest2: split into unit(a)==unit(b) and rest1==rest2
+                            str_eq ueq(lt, rt, eq.m_dep);
+                            ueq.sort();
+                            unit_eqs.push_back(ueq);
+                            ++prefix;
+                        }
                         else
                             break;
                     }
@@ -1063,6 +1074,13 @@ namespace seq {
                             m_reason = backtrack_reason::symbol_clash;
                             return simplify_result::conflict;
                         }
+                        else if (lt->is_char_or_unit() && rt->is_char_or_unit()) {
+                            // rest1 ++ unit(a) == rest2 ++ unit(b): split into unit(a)==unit(b) and rest1==rest2
+                            str_eq ueq(lt, rt, eq.m_dep);
+                            ueq.sort();
+                            unit_eqs.push_back(ueq);
+                            ++suffix;
+                        }
                         else
                             break;
                     }
@@ -1126,6 +1144,15 @@ namespace seq {
                 }
             }
 
+            // flush unit equalities generated by prefix/suffix unit splits
+            // NOTE(review): a flushed unit(a)==unit(b) equation itself matches the
+            // unit-unit branch above -- confirm a later pass cannot re-split it and
+            // set changed again indefinitely.
+            for (str_eq const& ueq : unit_eqs) {
+                m_str_eq.push_back(ueq);
+                changed = true;
+            }
+
             // pass 3: power simplification (mirrors ZIPT's LcpCompression +
             // SimplifyPowerElim + SimplifyPowerSingle)
             for (str_eq& eq : m_str_eq) {
diff --git a/src/test/seq_nielsen.cpp b/src/test/seq_nielsen.cpp
index 9db2a6c1e..f1d1f759f 100644
--- a/src/test/seq_nielsen.cpp
+++ b/src/test/seq_nielsen.cpp
@@ -3638,6 +3638,140 @@ static void test_var_bound_watcher_multi_var() {
     std::cout << " ok\n";
 }
 
+// test simplify_and_init: unit-unit prefix split
+// unit(a) ++ x = unit(b) ++ y -> unit(a)==unit(b), x==y
+static void test_simplify_unit_prefix_split() {
+    std::cout << "test_simplify_unit_prefix_split\n";
+    ast_manager m;
+    reg_decl_plugins(m);
+    euf::egraph eg(m);
+    euf::sgraph sg(m, eg);
+    seq_util seq(m);
+
+    dummy_simple_solver solver;
+    seq::nielsen_graph ng(sg, solver);
+
+    // create symbolic char variables a, b (non-concrete -> s_unit)
+    sort* char_sort = seq.mk_char_sort();
+    expr_ref sym_a(m.mk_const(symbol("a"), char_sort), m);
+    expr_ref sym_b(m.mk_const(symbol("b"), char_sort), m);
+    expr_ref unit_a_expr(seq.str.mk_unit(sym_a), m);
+    expr_ref unit_b_expr(seq.str.mk_unit(sym_b), m);
+    euf::snode* ua = sg.mk(unit_a_expr);
+    euf::snode* ub = sg.mk(unit_b_expr);
+    SASSERT(ua->is_unit());
+    SASSERT(ub->is_unit());
+
+    euf::snode* x = sg.mk_var(symbol("x"), sg.get_str_sort());
+    euf::snode* y = sg.mk_var(symbol("y"), sg.get_str_sort());
+
+    // ua ++ x = ub ++ y
+    euf::snode* lhs = sg.mk_concat(ua, x);
+    euf::snode* rhs = sg.mk_concat(ub, y);
+
+    seq::nielsen_node* node = ng.mk_node();
+    seq::dep_tracker dep = nullptr;
+    node->add_str_eq(seq::str_eq(lhs, rhs, dep));
+
+    auto sr = node->simplify_and_init();
+    SASSERT(sr == seq::simplify_result::proceed);
+    // original eq stripped to x==y, plus a new unit(a)==unit(b) eq
+    SASSERT(node->str_eqs().size() == 2);
+    // at least one eq has a char/unit token on both sides (the unit equality)
+    bool found_unit_eq = false;
+    for (auto const& eq : node->str_eqs()) {
+        if (eq.m_lhs && eq.m_rhs &&
+            eq.m_lhs->is_char_or_unit() && eq.m_rhs->is_char_or_unit())
+            found_unit_eq = true;
+    }
+    SASSERT(found_unit_eq);
+    std::cout << " ok\n";
+}
+
+// test simplify_and_init: unit-unit prefix split with empty rest on rhs
+// unit(a) ++ x = unit(b) -> unit(a)==unit(b), x==empty
+static void test_simplify_unit_prefix_split_empty_rest() {
+    std::cout << "test_simplify_unit_prefix_split_empty_rest\n";
+    ast_manager m;
+    reg_decl_plugins(m);
+    euf::egraph eg(m);
+    euf::sgraph sg(m, eg);
+    seq_util seq(m);
+
+    dummy_simple_solver solver;
+    seq::nielsen_graph ng(sg, solver);
+
+    sort* char_sort = seq.mk_char_sort();
+    expr_ref sym_a(m.mk_const(symbol("a"), char_sort), m);
+    expr_ref sym_b(m.mk_const(symbol("b"), char_sort), m);
+    expr_ref unit_a_expr(seq.str.mk_unit(sym_a), m);
+    expr_ref unit_b_expr(seq.str.mk_unit(sym_b), m);
+    euf::snode* ua = sg.mk(unit_a_expr);
+    euf::snode* ub = sg.mk(unit_b_expr);
+
+    euf::snode* x = sg.mk_var(symbol("x"), sg.get_str_sort());
+
+    // ua ++ x = ub  (rhs has no rest after unit)
+    euf::snode* lhs = sg.mk_concat(ua, x);
+
+    seq::nielsen_node* node = ng.mk_node();
+    seq::dep_tracker dep = nullptr;
+    node->add_str_eq(seq::str_eq(lhs, ub, dep));
+
+    auto sr = node->simplify_and_init();
+    // unit(a)==unit(b) and x==empty are produced; x==empty forces x->epsilon and satisfied
+    SASSERT(sr == seq::simplify_result::satisfied || sr == seq::simplify_result::proceed);
+    std::cout << " ok\n";
+}
+
+// test simplify_and_init: unit-unit suffix split
+// x ++ unit(a) = y ++ unit(b) -> unit(a)==unit(b), x==y
+static void test_simplify_unit_suffix_split() {
+    std::cout << "test_simplify_unit_suffix_split\n";
+    ast_manager m;
+    reg_decl_plugins(m);
+    euf::egraph eg(m);
+    euf::sgraph sg(m, eg);
+    seq_util seq(m);
+
+    dummy_simple_solver solver;
+    seq::nielsen_graph ng(sg, solver);
+
+    sort* char_sort = seq.mk_char_sort();
+    expr_ref sym_a(m.mk_const(symbol("a"), char_sort), m);
+    expr_ref sym_b(m.mk_const(symbol("b"), char_sort), m);
+    expr_ref unit_a_expr(seq.str.mk_unit(sym_a), m);
+    expr_ref unit_b_expr(seq.str.mk_unit(sym_b), m);
+    euf::snode* ua = sg.mk(unit_a_expr);
+    euf::snode* ub = sg.mk(unit_b_expr);
+    SASSERT(ua->is_unit());
+    SASSERT(ub->is_unit());
+
+    euf::snode* x = sg.mk_var(symbol("x"), sg.get_str_sort());
+    euf::snode* y = sg.mk_var(symbol("y"), sg.get_str_sort());
+
+    // x ++ ua = y ++ ub
+    euf::snode* lhs = sg.mk_concat(x, ua);
+    euf::snode* rhs = sg.mk_concat(y, ub);
+
+    seq::nielsen_node* node = ng.mk_node();
+    seq::dep_tracker dep = nullptr;
+    node->add_str_eq(seq::str_eq(lhs, rhs, dep));
+
+    auto sr = node->simplify_and_init();
+    SASSERT(sr == seq::simplify_result::proceed);
+    // original eq stripped to x==y, plus a new unit(a)==unit(b) eq
+    SASSERT(node->str_eqs().size() == 2);
+    bool found_unit_eq = false;
+    for (auto const& eq : node->str_eqs()) {
+        if (eq.m_lhs && eq.m_rhs &&
+            eq.m_lhs->is_char_or_unit() && eq.m_rhs->is_char_or_unit())
+            found_unit_eq = true;
+    }
+    SASSERT(found_unit_eq);
+    std::cout << " ok\n";
+}
+
 void tst_seq_nielsen() {
     test_dep_tracker();
     test_str_eq();
@@ -3759,4 +3893,7 @@ void tst_seq_nielsen() {
     test_assert_root_constraints_to_solver();
     test_assert_root_constraints_once();
     test_var_bound_watcher_multi_var();
+    test_simplify_unit_prefix_split();
+    test_simplify_unit_prefix_split_empty_rest();
+    test_simplify_unit_suffix_split();
 }