From 7eceeff34914c8d73eeebf5bbe6c5f171d0f3842 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Mon, 8 Mar 2021 10:09:04 -0800
Subject: [PATCH] move branch of unit variable

---
 examples/python/hs.py            | 136 +++++++++----
 src/ast/rewriter/seq_eq_solver.h |   3 +-
 src/smt/seq_eq_solver.cpp        | 336 +++++++++++++------------------
 src/smt/theory_seq.h             |   1 -
 4 files changed, 238 insertions(+), 238 deletions(-)

diff --git a/examples/python/hs.py b/examples/python/hs.py
index 64b49b6de..d32115360 100644
--- a/examples/python/hs.py
+++ b/examples/python/hs.py
@@ -8,9 +8,12 @@
 from z3 import *
 import random
 
+counter = 0
 
 def add_def(s, fml):
-    name = Bool(f"f{fml}")
+    global counter
+    name = Bool(f"def-{counter}")
+    counter += 1
     s.add(name == fml)
     return name
 
@@ -52,7 +55,8 @@ def count_sets_by_size(sets):
                 
 class Soft:
     def __init__(self, soft):
-        self.formulas = soft
+        self.formulas = set(soft)
+        self.original_soft = soft.copy()
         self.offset = 0
         self.init_names()
 
@@ -60,6 +64,11 @@ class Soft:
         self.name2formula = { Bool(f"s{s}") : s for s in self.formulas }
         self.formula2name = { s : v for (v, s) in self.name2formula.items() }
 
+#
+# TODO: try to replace this by a recursive invocation of HsMaxSAT
+# such that the invocation is incremental with respect to adding constraints
+# and has resource bounded invocation.
+# 
 class HsPicker:
     def __init__(self, soft):
         self.soft = soft
@@ -75,7 +84,7 @@ class HsPicker:
                 hs = hs | { h }
         print("approximate hitting set", len(hs), "smallest possible size", lo)
         return hs, lo
-
+    
     #
     # This can improve lower bound, but is expensive.
     # Note that Z3 does not work well for hitting set optimization.
@@ -86,6 +95,8 @@ class HsPicker:
     #
 
     def pick_hs(self, Ks, lo):
+        if len(Ks) == 0:
+            return set(), lo
         if self.opt_backoff_count < self.opt_backoff_limit:
             self.opt_backoff_count += 1
             return self.pick_hs_(Ks, lo)
@@ -100,7 +111,7 @@ class HsPicker:
         if is_sat == sat:
             mdl = opt.model()
             hs = [self.soft.name2formula[n] for n in self.soft.formula2name.values() if is_true(mdl.eval(n))]
-            return hs, lo
+            return set(hs), lo
         else:
             print("Timeout", self.timeout_value, "lo", lo, "limit", self.opt_backoff_limit)
             self.opt_backoff_limit += 1
@@ -113,17 +124,18 @@ class HsMaxSAT:
         
     def __init__(self, soft, s):
         self.s = s                    # solver object
-        self.original_soft = soft
         self.soft = Soft(soft)        # Soft constraints
         self.hs = HsPicker(self.soft) # Pick a hitting set
-        self.mdl = None               # Current best model
+        self.model = None               # Current best model
         self.lo = 0                   # Current lower bound
         self.hi = len(soft)           # Current upper bound
         self.Ks = []                  # Set of Cores
         self.Cs = []                  # Set of correction sets
         self.small_set_size = 6
-        self.small_set_threshold = 2
-        self.num_max_res_failures = 0        
+        self.small_set_threshold = 1
+        self.num_max_res_failures = 0
+        self.corr_set_enabled = True
+        self.patterns = []
 
     def has_many_small_sets(self, sets):
         small_count = len([c for c in sets if len(c) <= self.small_set_size])
@@ -149,7 +161,6 @@ class HsMaxSAT:
         self.Ks = []
         self.Cs = []
         self.lo -= num_cores_relaxed
-        self.hi -= num_cores_relaxed
         print("New offset", self.soft.offset)
                 
     def maxres(self):
@@ -160,26 +171,33 @@ class HsMaxSAT:
         if self.has_many_small_sets(self.Ks):
             self.num_max_res_failures = 0
             cores = self.get_small_disjoint_sets(self.Ks)
-            self.soft.formulas = set(self.soft.formulas)
             for core in cores:
-                self.small_set_size = min(self.small_set_size, len(core) - 2)
+                self.small_set_size = max(4, min(self.small_set_size, len(core) - 2))
                 relax_core(self.s, core, self.soft.formulas)
             self.reinit_soft(len(cores))
+            self.corr_set_enabled = True
             return
         #
         # If there are sufficiently many small correction sets, then
         # we reduce the soft constraints by dual maxres (IJCAI 2014)
+        #
+        # TODO: the heuristic for when to invoking correction set restriction
+        # needs fine-tuning. For example, the if min(Ks)*optimality_gap < min(Cs)*(max(SS))
+        # we might want to prioritize core relaxation to make progress with less overhead.
+        # here: max(SS) = |Soft|-min(Cs) is the size of the maximal satisfying subset
+        # the optimality gap is self.hi - self.offset
+        # which is a bound on how many cores have to be relaxed before determining optimality.
         # 
-        if self.has_many_small_sets(self.Cs):
+        if self.corr_set_enabled and self.has_many_small_sets(self.Cs):
             self.num_max_res_failures = 0
             cs = self.get_small_disjoint_sets(self.Cs)
-            self.soft.formulas = set(self.soft.formulas)            
             for corr_set in cs:
                 print("restrict cs", len(corr_set))
-                self.small_set_size = min(self.small_set_size, len(corr_set) - 2)
+                self.small_set_size = max(4, min(self.small_set_size, len(corr_set) - 2))
                 restrict_cs(self.s, corr_set, self.soft.formulas)
-                s.add(Or(corr_set))
+                self.s.add(Or(corr_set))
             self.reinit_soft(0)
+            self.corr_set_enabled = False
             return
         #
         # Increment the failure count. If the failure count reaches a threshold
@@ -197,34 +215,55 @@ class HsMaxSAT:
     def save_model(self):
         # 
         # You can save a model here.
-        # For example, add the string: self.mdl.sexpr()
+        # For example, add the string: self.model.sexpr()
         # to a file, or print bounds in custom format.
         #
         # print(f"Bound: {self.lo}")
-        # for f in self.original_soft:
-        #     print(f"{f} := {self.mdl.eval(f)}")
+        # for f in self.soft.original_soft:
+        #     print(f"{f} := {self.model.eval(f)}")
         pass
 
+    def add_pattern(self, orig_cs):
+        named = { f"{f}" : f for f in self.soft.original_soft }
+        sorted_names = sorted(named.keys())
+        sorted_soft = [named[f] for f in sorted_names]
+        bits = [1 if f not in orig_cs else 0 for f in sorted_soft]
+        def eq_bits(b1, b2):
+            return all(b1[i] == b2[i] for i in range(len(b1)))
+        def num_overlaps(b1, b2):
+            return sum(b1[i] == b2[i] for i in range(len(b1)))
+        
+        if not any(eq_bits(b, bits) for b in self.patterns):
+            if len(self.patterns) > 0:
+                print(num_overlaps(bits, self.patterns[-1]), len(bits), bits)
+            self.patterns += [bits]
+            counts = [sum(b[i] for b in self.patterns) for i in range(len(bits))]
+            print(counts)
+                
+
     def improve(self, new_model):
         mss = { f for f in self.soft.formulas if is_true(new_model.eval(f)) }
-        cs = set(self.soft.formulas) - mss
+        cs = self.soft.formulas - mss
         self.Cs += [cs]
-        cost = len(cs)
-        if self.mdl is None:
-            self.mdl = new_model
+        orig_cs = { f for f in self.soft.original_soft if not is_true(new_model.eval(f)) }
+        cost = len(orig_cs) 
+        if self.model is None:
+            self.model = new_model
         if cost <= self.hi:
+            self.add_pattern(orig_cs)
             print("improve", self.hi, cost)
-            self.mdl = new_model
+            self.model = new_model
             self.save_model()
         if cost < self.hi:
             self.hi = cost
-        assert self.mdl
+        assert self.model
 
-    def local_mss(self, hi, new_model):
+    def local_mss(self, new_model):
         mss = { f for f in self.soft.formulas if is_true(new_model.eval(f)) }
-        ps = set(self.soft.formulas) - mss
+        ps = self.soft.formulas - mss
         backbones = set()
         qs = set()
+        backbone2core = {}
         while len(ps) > 0:
             p = random.choice([p for p in ps])
             ps = ps - { p }
@@ -249,30 +288,43 @@ class HsMaxSAT:
                 qs = qs - rs
                 self.improve(mdl)
             elif is_sat == unsat:
+                core = set()
+                for c in self.s.unsat_core():
+                    if c in backbone2core:
+                        core = core | backbone2core[c]
+                    elif not p.eq(c):
+                        core = core | { c }
+                self.Ks += [core]
+                backbone2core[Not(p)] = core
                 backbones = backbones | { Not(p) }
             else:
                 qs = qs | { p }
         if len(qs) > 0:
             print("Number undetermined", len(qs))
+        
 
     def get_cores(self, hs):
         core = self.s.unsat_core()
-        remaining = set(self.soft.formulas) - set(core) - set(hs)
+        remaining = self.soft.formulas - hs
         num_cores = 0
         cores = [core]
         if len(core) == 0:
-            self.lo = self.hi
-            return []
+            self.lo = self.hi - self.soft.offset
+            return
         print("new core of size", len(core))    
         while True:        
             is_sat = self.s.check(remaining)
             if unsat == is_sat:
                 core = self.s.unsat_core()
                 print("new core of size", len(core))
+                if len(core) == 0:
+                    self.lo = self.hi - self.soft.offset
+                    return
                 cores += [core]
-                remaining = remaining - set(core)
+                h = random.choice([c for c in core])                
+                remaining = remaining - { h }
             elif sat == is_sat and num_cores == len(cores):
-                self.local_mss(self.hi, self.s.model())
+                self.local_mss(self.s.model())
                 break
             elif sat == is_sat:
                 self.improve(self.s.model())
@@ -283,37 +335,33 @@ class HsMaxSAT:
                 # The new hitting set contains at least one new element
                 # from the original cores
                 #
-                hs = set(hs)
-                for i in range(num_cores, len(cores)):
-                    h = random.choice([c for c in cores[i]])
-                    hs = hs | { h }
-                remaining = set(self.soft.formulas) - set(core) - set(hs)
+                hs = hs | { random.choice([c for c in cores[i]]) for i in range(num_cores, len(cores)) }
+                remaining = self.soft.formulas - hs
                 num_cores = len(cores)
             else:
                 print(is_sat)
                 break
-        return cores
+        self.Ks += [set(core) for core in cores]
+        print("total number of cores", len(self.Ks))
+        print("total number of correction sets", len(self.Cs))
 
     def step(self):
         soft = self.soft
         hs = self.pick_hs()
-        is_sat = self.s.check(set(soft.formulas) - set(hs))    
+        is_sat = self.s.check(soft.formulas - set(hs))    
         if is_sat == sat:
             self.improve(self.s.model())
         elif is_sat == unsat:
-            cores = self.get_cores(hs)            
-            self.Ks += [set(core) for core in cores]
-            print("total number of cores", len(self.Ks))
-            print("total number of correction sets", len(self.Cs))
+            self.get_cores(hs)            
         else:
             print("unknown")
-        print("maxsat [", self.lo + soft.offset, ", ", self.hi + soft.offset, "]","offset", soft.offset)
+        print("maxsat [", self.lo + soft.offset, ", ", self.hi, "]","offset", soft.offset)
         count_sets_by_size(self.Ks)
         count_sets_by_size(self.Cs)
         self.maxres()
 
     def run(self):
-        while self.lo < self.hi:
+        while self.lo + self.soft.offset < self.hi:
             self.step()
 
                 
diff --git a/src/ast/rewriter/seq_eq_solver.h b/src/ast/rewriter/seq_eq_solver.h
index 1053b143d..c4d98d0a3 100644
--- a/src/ast/rewriter/seq_eq_solver.h
+++ b/src/ast/rewriter/seq_eq_solver.h
@@ -79,7 +79,6 @@ namespace seq {
 
 
         bool branch_unit_variable(eqr const& e);
-        bool branch_unit_variable(expr* X, expr_ref_vector const& units);
 
 
         /**
@@ -156,6 +155,8 @@ namespace seq {
 
         bool can_align_from_lhs_aux(expr_ref_vector const& ls, expr_ref_vector const& rs);
         bool can_align_from_rhs_aux(expr_ref_vector const& ls, expr_ref_vector const& rs);
+
+        bool branch_unit_variable(expr* X, expr_ref_vector const& units);
         
     };
 
diff --git a/src/smt/seq_eq_solver.cpp b/src/smt/seq_eq_solver.cpp
index 9776797c6..2a9ed15d1 100644
--- a/src/smt/seq_eq_solver.cpp
+++ b/src/smt/seq_eq_solver.cpp
@@ -125,145 +125,6 @@ expr* theory_seq::expr2rep(expr* e) {
     return ctx.get_enode(e)->get_root()->get_expr();
 }
 
-
-#if 0
-
-/**
-   \brief
-
-   This step performs destructive superposition
-
-   Based on the implementation it would do the following:
-  
-   e:   l1 + l2 + l3 + l = r1 + r2 + r
-   G |- len(l1) = len(l2) = len(r1) = 0
-   e':  l1 + l2 + l3 + l = r3 + r'         occurs prior to e among equations
-   G |- len(r3) = len(r2)
-   r2, r3 are not identical
-   ----------------------------------
-   e'' : r3 + r' = r1 + r2 + r
-
-   e:   l1 + l2 + l3 + l = r1 + r2 + r
-   G |- len(l1) = len(l2) = len(r1) = 0
-   e':  l1 + l2 + l3 + l = r3 + r'         occurs prior to e among equations
-   G |- len(r3) = len(r2) + offset
-   r2, r3 are not identical
-   ----------------------------------
-   e'' : r3 + r' = r1 + r2 + r
-
-   NB, this doesn't make sense because e'' is just e', which already occurs.
-   It doesn't inherit the constraints from e either, which would get lost.
-
-   NB, if len(r3) = len(r2) would be used, then the justification for the equality
-   needs to be tracked in dependencies.
-    
-   TODO: propagate length offsets for last vars
-
-*/
-bool theory_seq::find_better_rep(expr_ref_vector const& ls, expr_ref_vector const& rs, unsigned idx,
-                                 dependency*& deps, expr_ref_vector & res) {
-
-    // disabled until functionality is clarified
-    return false;
-
-    if (ls.empty() || rs.empty())
-        return false;
-    expr* l_fst = find_fst_non_empty_var(ls);
-    expr* r_fst = find_fst_non_empty_var(rs);
-    if (!r_fst) return false;
-    expr_ref len_r_fst = mk_len(r_fst);
-    expr_ref len_l_fst(m);
-    enode * root2;
-    if (!ctx.e_internalized(len_r_fst)) {
-        return false;
-    }
-    if (l_fst) {
-        len_l_fst = mk_len(l_fst);
-    }
-
-    root2 = get_root(len_r_fst);
-
-    // Offset = 0, No change
-    if (l_fst && get_root(len_l_fst) == root2) {
-        TRACE("seq", tout << "(" << mk_pp(l_fst, m) << ", " << mk_pp(r_fst, m) << ")\n";);
-        return false;
-    }
-
-    // Offset = 0, Changed
-
-    for (unsigned i = 0; i < idx; ++i) {
-        depeq const& e = m_eqs[i];
-        if (e.ls != ls) continue;
-        expr* nl_fst = nullptr;
-        if (e.rs.size() > 1 && is_var(e.rs.get(0)))
-            nl_fst = e.rs.get(0);
-        if (nl_fst && nl_fst != r_fst && root2 == get_root(mk_len(nl_fst))) {
-            res.reset();
-            res.append(e.rs);
-            deps = m_dm.mk_join(e.dep(), deps);
-            return true;
-        }
-    }
-    // Offset != 0, No change
-    if (l_fst && ctx.e_internalized(len_l_fst)) {
-        enode * root1 = get_root(len_l_fst);
-        if (m_offset_eq.contains(root1, root2)) {
-            TRACE("seq", tout << "(" << mk_pp(l_fst, m) << ", " << mk_pp(r_fst,m) << ")\n";);
-            return false;
-        }
-    }
-    // Offset != 0, Changed
-    if (m_offset_eq.contains(root2)) {
-        for (unsigned i = 0; i < idx; ++i) {
-            depeq const& e = m_eqs[i];
-            if (e.ls != ls) continue;
-            expr* nl_fst = nullptr;
-            if (e.rs.size() > 1 && is_var(e.rs.get(0)))
-                nl_fst = e.rs.get(0);
-            if (nl_fst && nl_fst != r_fst) {
-                expr_ref len_nl_fst = mk_len(nl_fst);
-                if (ctx.e_internalized(len_nl_fst)) {
-                    enode * root1 = get_root(len_nl_fst);
-                    if (m_offset_eq.contains(root2, root1)) {
-                        res.reset();
-                        res.append(e.rs);
-                        deps = m_dm.mk_join(e.dep(), deps);
-                        return true;
-                    }
-                }
-            }
-        }
-    }
-    return false;
-}
-
-int theory_seq::find_fst_non_empty_idx(expr_ref_vector const& xs) {
-    for (unsigned i = 0; i < xs.size(); ++i) {
-        expr* x = xs[i];
-        if (!is_var(x)) 
-            return -1;
-        expr_ref e = mk_len(x);
-        if (ctx.e_internalized(e)) {
-            enode* root = ctx.get_enode(e)->get_root();
-            rational val;
-            if (m_autil.is_numeral(root->get_expr(), val) && val.is_zero()) {
-                continue;
-            }
-        }
-        return i;
-    }
-    return -1;
-}
-
-expr* theory_seq::find_fst_non_empty_var(expr_ref_vector const& x) {
-    int i = find_fst_non_empty_idx(x);
-    if (i >= 0)
-        return x[i];
-    return nullptr;
-}
-
-#endif
-
 bool theory_seq::has_len_offset(expr_ref_vector const& ls, expr_ref_vector const& rs, int & offset) {
 
     if (ls.empty() || rs.empty()) 
@@ -597,7 +458,8 @@ bool theory_seq::branch_binary_variable(depeq const& e) {
     if (lenX <= rational(ys.size())) {
         expr_ref_vector Ys(m);
         Ys.append(ys.size(), ys.c_ptr());
-        if (branch_unit_variable(e.dep(), x, Ys))
+        m_eq_deps = e.dep();
+        if (m_eq.branch_unit_variable(x, Ys))
             return true;
     }
     expr_ref le(m_autil.mk_le(mk_len(x), m_autil.mk_int(ys.size())), m);
@@ -625,67 +487,17 @@ bool theory_seq::branch_binary_variable(depeq const& e) {
 bool theory_seq::branch_unit_variable() {
     bool result = false;
     for (auto const& e : m_eqs) {
-#if 0
-        eqr er(e.ls, e.rs);
-        m_eq_deps = e.deps;
+        seq::eqr er(e.ls, e.rs);
+        m_eq_deps = e.dep();
         if (m_eq.branch(0, er)) {
             result = true;
             break;
         }
-#else
-        if (is_unit_eq(e.ls, e.rs) && 
-            branch_unit_variable(e.dep(), e.ls[0], e.rs)) {
-            result = true;
-            break;
-        }
-        if (is_unit_eq(e.rs, e.ls) && 
-            branch_unit_variable(e.dep(), e.rs[0], e.ls)) {
-            result = true;
-            break;
-        }
-#endif
     }
     CTRACE("seq", result, tout << "branch unit variable\n";);
     return result;
 }
 
-bool theory_seq::branch_unit_variable(dependency* dep, expr* X, expr_ref_vector const& units) {
-    SASSERT(is_var(X));
-    rational lenX;
-    if (!get_length(X, lenX)) {
-        TRACE("seq", tout << "enforce length on " << mk_bounded_pp(X, m, 2) << "\n";);
-        add_length_to_eqc(X);
-        return true;
-    }
-    if (lenX > rational(units.size())) {
-        expr_ref le(m_autil.mk_le(mk_len(X), m_autil.mk_int(units.size())), m);
-        TRACE("seq", tout << "propagate length on " << mk_bounded_pp(X, m, 2) << "\n";);
-        propagate_lit(dep, 0, nullptr, mk_literal(le));
-        return true;
-    }
-    SASSERT(lenX.is_unsigned());
-    unsigned lX = lenX.get_unsigned();
-    if (lX == 0) {
-        TRACE("seq", tout << "set empty length " << mk_bounded_pp(X, m, 2) << "\n";);
-        set_empty(X);
-        return true;
-    }
-
-    literal lit = mk_eq(m_autil.mk_int(lX), mk_len(X), false);
-    switch (ctx.get_assignment(lit)) {
-    case l_true: {
-        expr_ref R = mk_concat(lX, units.c_ptr(), X->get_sort());     
-        return propagate_eq(dep, lit, X, R);
-    }
-    case l_undef: 
-        TRACE("seq", tout << "set phase " << mk_pp(X, m) << "\n";);
-        ctx.mark_as_relevant(lit);
-        ctx.force_phase(lit);
-        return true;
-    default:
-        return false;
-    }
-}
 
 bool theory_seq::branch_ternary_variable() {
     for (auto const& e : m_eqs) {
@@ -1337,3 +1149,143 @@ bool theory_seq::solve_nth_eq(expr_ref_vector const& ls, expr_ref_vector const&
 }
 
 
+
+
+
+#if 0
+
+/**
+   \brief
+
+   This step performs destructive superposition
+
+   Based on the implementation it would do the following:
+  
+   e:   l1 + l2 + l3 + l = r1 + r2 + r
+   G |- len(l1) = len(l2) = len(r1) = 0
+   e':  l1 + l2 + l3 + l = r3 + r'         occurs prior to e among equations
+   G |- len(r3) = len(r2)
+   r2, r3 are not identical
+   ----------------------------------
+   e'' : r3 + r' = r1 + r2 + r
+
+   e:   l1 + l2 + l3 + l = r1 + r2 + r
+   G |- len(l1) = len(l2) = len(r1) = 0
+   e':  l1 + l2 + l3 + l = r3 + r'         occurs prior to e among equations
+   G |- len(r3) = len(r2) + offset
+   r2, r3 are not identical
+   ----------------------------------
+   e'' : r3 + r' = r1 + r2 + r
+
+   NB, this doesn't make sense because e'' is just e', which already occurs.
+   It doesn't inherit the constraints from e either, which would get lost.
+
+   NB, if len(r3) = len(r2) would be used, then the justification for the equality
+   needs to be tracked in dependencies.
+    
+   TODO: propagate length offsets for last vars
+
+*/
+bool theory_seq::find_better_rep(expr_ref_vector const& ls, expr_ref_vector const& rs, unsigned idx,
+                                 dependency*& deps, expr_ref_vector & res) {
+
+    // disabled until functionality is clarified
+    return false;
+
+    if (ls.empty() || rs.empty())
+        return false;
+    expr* l_fst = find_fst_non_empty_var(ls);
+    expr* r_fst = find_fst_non_empty_var(rs);
+    if (!r_fst) return false;
+    expr_ref len_r_fst = mk_len(r_fst);
+    expr_ref len_l_fst(m);
+    enode * root2;
+    if (!ctx.e_internalized(len_r_fst)) {
+        return false;
+    }
+    if (l_fst) {
+        len_l_fst = mk_len(l_fst);
+    }
+
+    root2 = get_root(len_r_fst);
+
+    // Offset = 0, No change
+    if (l_fst && get_root(len_l_fst) == root2) {
+        TRACE("seq", tout << "(" << mk_pp(l_fst, m) << ", " << mk_pp(r_fst, m) << ")\n";);
+        return false;
+    }
+
+    // Offset = 0, Changed
+
+    for (unsigned i = 0; i < idx; ++i) {
+        depeq const& e = m_eqs[i];
+        if (e.ls != ls) continue;
+        expr* nl_fst = nullptr;
+        if (e.rs.size() > 1 && is_var(e.rs.get(0)))
+            nl_fst = e.rs.get(0);
+        if (nl_fst && nl_fst != r_fst && root2 == get_root(mk_len(nl_fst))) {
+            res.reset();
+            res.append(e.rs);
+            deps = m_dm.mk_join(e.dep(), deps);
+            return true;
+        }
+    }
+    // Offset != 0, No change
+    if (l_fst && ctx.e_internalized(len_l_fst)) {
+        enode * root1 = get_root(len_l_fst);
+        if (m_offset_eq.contains(root1, root2)) {
+            TRACE("seq", tout << "(" << mk_pp(l_fst, m) << ", " << mk_pp(r_fst,m) << ")\n";);
+            return false;
+        }
+    }
+    // Offset != 0, Changed
+    if (m_offset_eq.contains(root2)) {
+        for (unsigned i = 0; i < idx; ++i) {
+            depeq const& e = m_eqs[i];
+            if (e.ls != ls) continue;
+            expr* nl_fst = nullptr;
+            if (e.rs.size() > 1 && is_var(e.rs.get(0)))
+                nl_fst = e.rs.get(0);
+            if (nl_fst && nl_fst != r_fst) {
+                expr_ref len_nl_fst = mk_len(nl_fst);
+                if (ctx.e_internalized(len_nl_fst)) {
+                    enode * root1 = get_root(len_nl_fst);
+                    if (m_offset_eq.contains(root2, root1)) {
+                        res.reset();
+                        res.append(e.rs);
+                        deps = m_dm.mk_join(e.dep(), deps);
+                        return true;
+                    }
+                }
+            }
+        }
+    }
+    return false;
+}
+
+int theory_seq::find_fst_non_empty_idx(expr_ref_vector const& xs) {
+    for (unsigned i = 0; i < xs.size(); ++i) {
+        expr* x = xs[i];
+        if (!is_var(x)) 
+            return -1;
+        expr_ref e = mk_len(x);
+        if (ctx.e_internalized(e)) {
+            enode* root = ctx.get_enode(e)->get_root();
+            rational val;
+            if (m_autil.is_numeral(root->get_expr(), val) && val.is_zero()) {
+                continue;
+            }
+        }
+        return i;
+    }
+    return -1;
+}
+
+expr* theory_seq::find_fst_non_empty_var(expr_ref_vector const& x) {
+    int i = find_fst_non_empty_idx(x);
+    if (i >= 0)
+        return x[i];
+    return nullptr;
+}
+
+#endif
diff --git a/src/smt/theory_seq.h b/src/smt/theory_seq.h
index a962e08cd..527df86a5 100644
--- a/src/smt/theory_seq.h
+++ b/src/smt/theory_seq.h
@@ -436,7 +436,6 @@ namespace smt {
         bool check_length_coherence(expr* e);
         bool fixed_length(bool is_zero = false);
         bool fixed_length(expr* e, bool is_zero);
-        bool branch_unit_variable(dependency* dep, expr* X, expr_ref_vector const& units);
         bool branch_variable_eq(depeq const& e);
         bool branch_binary_variable(depeq const& e);
         bool can_align_from_lhs(expr_ref_vector const& ls, expr_ref_vector const& rs);