diff --git a/backends/smt2/smtio.py b/backends/smt2/smtio.py index ebf364f06..c904aea95 100644 --- a/backends/smt2/smtio.py +++ b/backends/smt2/smtio.py @@ -79,6 +79,20 @@ def except_hook(exctype, value, traceback): sys.excepthook = except_hook +def recursion_helper(iteration, *request): + stack = [iteration(*request)] + + while stack: + top = stack.pop() + try: + request = next(top) + except StopIteration: + continue + + stack.append(top) + stack.append(iteration(*request)) + + hex_dict = { "0": "0000", "1": "0001", "2": "0010", "3": "0011", "4": "0100", "5": "0101", "6": "0110", "7": "0111", @@ -298,10 +312,22 @@ class SmtIo: return stmt def unroll_stmt(self, stmt): - if not isinstance(stmt, list): - return stmt + result = [] + recursion_helper(self._unroll_stmt_into, stmt, result) + return result.pop() - stmt = [self.unroll_stmt(s) for s in stmt] + def _unroll_stmt_into(self, stmt, output, depth=128): + if not isinstance(stmt, list): + output.append(stmt) + return + + new_stmt = [] + for s in stmt: + if depth: + yield from self._unroll_stmt_into(s, new_stmt, depth - 1) + else: + yield s, new_stmt + stmt = new_stmt if len(stmt) >= 2 and not isinstance(stmt[0], list) and stmt[0] in self.unroll_decls: assert stmt[1] in self.unroll_objs @@ -330,12 +356,19 @@ class SmtIo: decl[2] = list() if len(decl) > 0: - decl = self.unroll_stmt(decl) + tmp = [] + if depth: + yield from self._unroll_stmt_into(decl, tmp, depth - 1) + else: + yield decl, tmp + + decl = tmp.pop() self.write(self.unparse(decl), unroll=False) - return self.unroll_cache[key] + output.append(self.unroll_cache[key]) + return - return stmt + output.append(stmt) def p_thread_main(self): while True: