From d03c5e2a00fc1fbb485bdd99bdb264658e316058 Mon Sep 17 00:00:00 2001
From: Jannis Harder <me@jix.one>
Date: Wed, 21 Feb 2024 16:35:17 +0100
Subject: [PATCH] smtbmc: Break dependency recursion during unrolling

Previously `unroll_stmt` would recurse over the smtlib expressions as
well as recursively follow not-yet-emitted definitions the current
expression depends on. While the depth of smtlib expressions generated
by yosys seems to be reasonably bounded, the dependency chain of
not-yet-emitted definitions can grow linearly with the size of the
design and linearly in the BMC depth.

This makes `unroll_stmt` use a `list` as stack, using python generators
and `recursion_helper` function to keep the overall code structure of
the previous recursive implementation.
---
 backends/smt2/smtio.py | 45 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 39 insertions(+), 6 deletions(-)

diff --git a/backends/smt2/smtio.py b/backends/smt2/smtio.py
index ebf364f06..c904aea95 100644
--- a/backends/smt2/smtio.py
+++ b/backends/smt2/smtio.py
@@ -79,6 +79,20 @@ def except_hook(exctype, value, traceback):
 sys.excepthook = except_hook
 
 
+def recursion_helper(iteration, *request):
+    stack = [iteration(*request)]
+
+    while stack:
+        top = stack.pop()
+        try:
+            request = next(top)
+        except StopIteration:
+            continue
+
+        stack.append(top)
+        stack.append(iteration(*request))
+
+
 hex_dict = {
     "0": "0000", "1": "0001", "2": "0010", "3": "0011",
     "4": "0100", "5": "0101", "6": "0110", "7": "0111",
@@ -298,10 +312,22 @@ class SmtIo:
         return stmt
 
     def unroll_stmt(self, stmt):
-        if not isinstance(stmt, list):
-            return stmt
+        result = []
+        recursion_helper(self._unroll_stmt_into, stmt, result)
+        return result.pop()
 
-        stmt = [self.unroll_stmt(s) for s in stmt]
+    def _unroll_stmt_into(self, stmt, output, depth=128):
+        if not isinstance(stmt, list):
+            output.append(stmt)
+            return
+
+        new_stmt = []
+        for s in stmt:
+            if depth:
+                yield from self._unroll_stmt_into(s, new_stmt, depth - 1)
+            else:
+                yield s, new_stmt
+        stmt = new_stmt
 
         if len(stmt) >= 2 and not isinstance(stmt[0], list) and stmt[0] in self.unroll_decls:
             assert stmt[1] in self.unroll_objs
@@ -330,12 +356,19 @@ class SmtIo:
                     decl[2] = list()
 
                 if len(decl) > 0:
-                    decl = self.unroll_stmt(decl)
+                    tmp = []
+                    if depth:
+                        yield from self._unroll_stmt_into(decl, tmp, depth - 1)
+                    else:
+                        yield decl, tmp
+
+                    decl = tmp.pop()
                     self.write(self.unparse(decl), unroll=False)
 
-            return self.unroll_cache[key]
+            output.append(self.unroll_cache[key])
+            return
 
-        return stmt
+        output.append(stmt)
 
     def p_thread_main(self):
         while True: