diff --git a/backends/smt2/witness.py b/backends/smt2/witness.py
index 8d0cc8112..8e13cba27 100644
--- a/backends/smt2/witness.py
+++ b/backends/smt2/witness.py
@@ -84,26 +84,48 @@ def stats(input):
 Transform a Yosys witness trace.
 
 Currently no transformations are implemented, so it is only useful for testing.
+If two or more inputs are provided they will be concatenated together into the output.
 """)
-@click.argument("input", type=click.File("r"))
+@click.argument("inputs", type=click.File("r"), nargs=-1)
 @click.argument("output", type=click.File("w"))
-def yw2yw(input, output):
-    click.echo(f"Copying yosys witness trace from {input.name!r} to {output.name!r}...")
-    inyw = ReadWitness(input)
+def yw2yw(inputs, output):
     outyw = WriteWitness(output, "yosys-witness yw2yw")
+    join_inputs = len(inputs) > 1
+    inyws = {}
+    for input in inputs:
+        if (join_inputs):
+            click.echo(f"Loading signals from yosys witness trace {input.name!r}...")
+        inyw = ReadWitness(input)
+        inyws[input] = inyw
+        for clock in inyw.clocks:
+            if clock not in outyw.clocks:
+                outyw.add_clock(clock["path"], clock["offset"], clock["edge"])
 
-    for clock in inyw.clocks:
-        outyw.add_clock(clock["path"], clock["offset"], clock["edge"])
+        for sig in inyw.signals:
+            if sig not in outyw.signals:
+                outyw.add_sig(sig.path, sig.offset, sig.width, sig.init_only)
 
-    for sig in inyw.signals:
-        outyw.add_sig(sig.path, sig.offset, sig.width, sig.init_only)
+    init_values = sum([inyw.init_step() for inyw in inyws.values()], start=WitnessValues())
 
-    for t, values in inyw.steps():
-        outyw.step(values)
+    first_witness = True
+    for (input, inyw) in inyws.items():
+        click.echo(f"Copying yosys witness trace from {input.name!r} to {output.name!r}...")
+
+        if first_witness:
+            outyw.step(init_values)
+        else:
+            outyw.step(inyw.first_step())
+
+        for t, values in inyw.steps(1):
+            outyw.step(values)
+
+        click.echo(f"Copied {t + 1} time steps.")
+        first_witness = False
 
     outyw.end_trace()
 
-    click.echo(f"Copied {outyw.t + 1} time steps.")
+    if join_inputs:
+        click.echo(f"Copied {outyw.t} total time steps.")
 
 
 class AigerMap:
diff --git a/backends/smt2/ywio.py b/backends/smt2/ywio.py
index 39cfac41e..2b897200f 100644
--- a/backends/smt2/ywio.py
+++ b/backends/smt2/ywio.py
@@ -165,8 +165,8 @@ class WitnessSig:
         else:
             return f"{pretty_path(self.path)}[{self.offset}]"
 
-    def __eq__(self):
-        return self.sort_key
+    def __eq__(self, other):
+        return self.sort_key == other.sort_key
 
     def __hash__(self):
         return hash(self.sort_key)
@@ -294,6 +294,16 @@ class WitnessValues:
 
         return sorted(signals), missing_signals
 
+    def __add__(self, other: "WitnessValues"):
+        new = WitnessValues()
+        new += self
+        new += other
+        return new
+
+    def __iadd__(self, other: "WitnessValues"):
+        for key, value in other.values.items():
+            self.values.setdefault(key, value)
+        return self
 
 class WriteWitness:
     def __init__(self, f, generator):
@@ -380,13 +390,24 @@ class ReadWitness:
 
         self.bits = [step["bits"] for step in data["steps"]]
 
+    def init_step(self):
+        return self.step(0)
+    
+    def first_step(self):
+        values = WitnessValues()
+        if len(self.bits) <= 1:
+            raise NotImplementedError("ReadWitness.first_step() not supported for less than 2 steps")
+        non_init_bits = len(self.bits[1])
+        values.unpack(WitnessSigMap([sig for sig in self.signals if not sig.init_only]), self.bits[0][-non_init_bits:])
+        return values
+
     def step(self, t):
         values = WitnessValues()
         values.unpack(self.sigmap, self.bits[t])
         return values
 
-    def steps(self):
-        for i in range(len(self.bits)):
+    def steps(self, start=0):
+        for i in range(start, len(self.bits)):
             yield i, self.step(i)
 
     def __len__(self):