diff --git a/backends/smt2/witness.py b/backends/smt2/witness.py index 8d0cc8112..8e13cba27 100644 --- a/backends/smt2/witness.py +++ b/backends/smt2/witness.py @@ -84,26 +84,48 @@ def stats(input): Transform a Yosys witness trace. Currently no transformations are implemented, so it is only useful for testing. +If two or more inputs are provided they will be concatenated together into the output. """) -@click.argument("input", type=click.File("r")) +@click.argument("inputs", type=click.File("r"), nargs=-1) @click.argument("output", type=click.File("w")) -def yw2yw(input, output): - click.echo(f"Copying yosys witness trace from {input.name!r} to {output.name!r}...") - inyw = ReadWitness(input) +def yw2yw(inputs, output): outyw = WriteWitness(output, "yosys-witness yw2yw") + join_inputs = len(inputs) > 1 + inyws = {} + for input in inputs: + if (join_inputs): + click.echo(f"Loading signals from yosys witness trace {input.name!r}...") + inyw = ReadWitness(input) + inyws[input] = inyw + for clock in inyw.clocks: + if clock not in outyw.clocks: + outyw.add_clock(clock["path"], clock["offset"], clock["edge"]) - for clock in inyw.clocks: - outyw.add_clock(clock["path"], clock["offset"], clock["edge"]) + for sig in inyw.signals: + if sig not in outyw.signals: + outyw.add_sig(sig.path, sig.offset, sig.width, sig.init_only) - for sig in inyw.signals: - outyw.add_sig(sig.path, sig.offset, sig.width, sig.init_only) + init_values = sum([inyw.init_step() for inyw in inyws.values()], start=WitnessValues()) - for t, values in inyw.steps(): - outyw.step(values) + first_witness = True + for (input, inyw) in inyws.items(): + click.echo(f"Copying yosys witness trace from {input.name!r} to {output.name!r}...") + + if first_witness: + outyw.step(init_values) + else: + outyw.step(inyw.first_step()) + + for t, values in inyw.steps(1): + outyw.step(values) + + click.echo(f"Copied {t + 1} time steps.") + first_witness = False outyw.end_trace() - click.echo(f"Copied {outyw.t + 1} time steps.") + if join_inputs: + click.echo(f"Copied {outyw.t} total time steps.") class AigerMap: diff --git a/backends/smt2/ywio.py b/backends/smt2/ywio.py index 39cfac41e..2b897200f 100644 --- a/backends/smt2/ywio.py +++ b/backends/smt2/ywio.py @@ -165,8 +165,8 @@ class WitnessSig: else: return f"{pretty_path(self.path)}[{self.offset}]" - def __eq__(self): - return self.sort_key + def __eq__(self, other): + return self.sort_key == other.sort_key def __hash__(self): return hash(self.sort_key) @@ -294,6 +294,16 @@ class WitnessValues: return sorted(signals), missing_signals + def __add__(self, other: "WitnessValues"): + new = WitnessValues() + new += self + new += other + return new + + def __iadd__(self, other: "WitnessValues"): + for key, value in other.values.items(): + self.values.setdefault(key, value) + return self class WriteWitness: def __init__(self, f, generator): @@ -380,13 +390,24 @@ class ReadWitness: self.bits = [step["bits"] for step in data["steps"]] + def init_step(self): + return self.step(0) + + def first_step(self): + values = WitnessValues() + if len(self.bits) <= 1: + raise NotImplementedError("ReadWitness.first_step() not supported for less than 2 steps") + non_init_bits = len(self.bits[1]) + values.unpack(WitnessSigMap([sig for sig in self.signals if not sig.init_only]), self.bits[0][-non_init_bits:]) + return values + def step(self, t): values = WitnessValues() values.unpack(self.sigmap, self.bits[t]) return values - def steps(self): - for i in range(len(self.bits)): + def steps(self, start=0): + for i in range(start, len(self.bits)): yield i, self.step(i) def __len__(self):