3
0
Fork 0
mirror of https://github.com/YosysHQ/yosys synced 2025-04-07 01:54:10 +00:00

smtbmc: Add --incremental mode

This commit is contained in:
Jannis Harder 2023-11-16 13:15:54 +01:00
parent 032fab1f54
commit e319606ec9
4 changed files with 512 additions and 64 deletions

View file

@ -17,7 +17,7 @@
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
# #
import os, sys, getopt, re, bisect import os, sys, getopt, re, bisect, json
##yosys-sys-path## ##yosys-sys-path##
from smtio import SmtIo, SmtOpts, MkVcd from smtio import SmtIo, SmtOpts, MkVcd
from ywio import ReadWitness, WriteWitness, WitnessValues from ywio import ReadWitness, WriteWitness, WitnessValues
@ -56,6 +56,7 @@ binarymode = False
keep_going = False keep_going = False
check_witness = False check_witness = False
detect_loops = False detect_loops = False
incremental = None
so = SmtOpts() so = SmtOpts()
@ -185,6 +186,9 @@ def help():
check if states are unique in temporal induction counter examples check if states are unique in temporal induction counter examples
(this feature is experimental and incomplete) (this feature is experimental and incomplete)
--incremental
run in incremental mode (experimental)
""" + so.helpmsg()) """ + so.helpmsg())
def usage(): def usage():
@ -196,7 +200,7 @@ try:
opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts + opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts +
["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat", ["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat",
"dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=", "dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=",
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops"]) "smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental"])
except: except:
usage() usage()
@ -282,6 +286,9 @@ for o, a in opts:
check_witness = True check_witness = True
elif o == "--detect-loops": elif o == "--detect-loops":
detect_loops = True detect_loops = True
elif o == "--incremental":
from smtbmc_incremental import Incremental
incremental = Incremental()
elif so.handle(o, a): elif so.handle(o, a):
pass pass
else: else:
@ -290,7 +297,7 @@ for o, a in opts:
if len(args) != 1: if len(args) != 1:
usage() usage()
if sum([tempind, gentrace, covermode]) > 1: if sum([tempind, gentrace, covermode, incremental is not None]) > 1:
usage() usage()
constr_final_start = None constr_final_start = None
@ -444,8 +451,10 @@ if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None:
smt.produce_models = False smt.produce_models = False
def print_msg(msg): def print_msg(msg):
print("%s %s" % (smt.timestamp(), msg)) if incremental:
sys.stdout.flush() incremental.print_msg(msg)
else:
print("%s %s" % (smt.timestamp(), msg), flush=True)
print_msg("Solver: %s" % (so.solver)) print_msg("Solver: %s" % (so.solver))
@ -640,10 +649,9 @@ if aimfile is not None:
num_steps = max(num_steps, step+2) num_steps = max(num_steps, step+2)
step += 1 step += 1
if inywfile is not None: def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
if not got_topt: if map_steps is None:
skip_steps = 0 map_steps = {}
num_steps = 0
with open(inywfile, "r") as f: with open(inywfile, "r") as f:
inyw = ReadWitness(f) inyw = ReadWitness(f)
@ -662,10 +670,14 @@ if inywfile is not None:
addr_re = re.compile(r'\\\[[0-9]+\]$') addr_re = re.compile(r'\\\[[0-9]+\]$')
bits_re = re.compile(r'[01?]*$') bits_re = re.compile(r'[01?]*$')
max_t = -1
for t, step in inyw.steps(): for t, step in inyw.steps():
present_signals, missing = step.present_signals(inyw.sigmap) present_signals, missing = step.present_signals(inyw.sigmap)
for sig in present_signals: for sig in present_signals:
bits = step[sig] bits = step[sig]
if skip_x:
bits = bits.replace('x', '?')
if not bits_re.match(bits): if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file") raise ValueError("unsupported bit value in Yosys witness file")
@ -684,7 +696,7 @@ if inywfile is not None:
if common_end <= common_offset: if common_end <= common_offset:
continue continue
smt_expr = smt.witness_net_expr(topmod, f"s{t}", wire) smt_expr = smt.witness_net_expr(topmod, f"s{map_steps.get(t, t)}", wire)
if not smt_bool: if not smt_bool:
slice_high = common_end - offset - 1 slice_high = common_end - offset - 1
@ -714,7 +726,7 @@ if inywfile is not None:
for mem in smt_mems[sig.memory_path]: for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"] width, size, bv = mem["width"], mem["size"], mem["statebv"]
smt_expr = smt.net_expr(topmod, f"s{t}", mem["smtpath"]) smt_expr = smt.net_expr(topmod, f"s{map_steps.get(t, t)}", mem["smtpath"])
if bv: if bv:
word_low = sig.memory_addr * width word_low = sig.memory_addr * width
@ -738,11 +750,21 @@ if inywfile is not None:
smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice) smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
constr_assumes[t].append((inywfile, smt_constr)) constr_assumes[t].append((inywfile, smt_constr))
max_t = t
if not got_topt: return max_t
if not check_witness:
skip_steps = max(skip_steps, t) if inywfile is not None:
num_steps = max(num_steps, t+1) if not got_topt:
skip_steps = 0
num_steps = 0
max_t = ywfile_constraints(inywfile, constr_assumes)
if not got_topt:
if not check_witness:
skip_steps = max(skip_steps, max_t)
num_steps = max(num_steps, max_t+1)
if btorwitfile is not None: if btorwitfile is not None:
with open(btorwitfile, "r") as f: with open(btorwitfile, "r") as f:
@ -841,7 +863,7 @@ if btorwitfile is not None:
skip_steps = step skip_steps = step
num_steps = step+1 num_steps = step+1
def collect_mem_trace_data(steps_start, steps_stop, vcd=None): def collect_mem_trace_data(steps, vcd=None):
mem_trace_data = dict() mem_trace_data = dict()
for mempath in sorted(smt.hiermems(topmod)): for mempath in sorted(smt.hiermems(topmod)):
@ -849,16 +871,16 @@ def collect_mem_trace_data(steps_start, steps_stop, vcd=None):
expr_id = list() expr_id = list()
expr_list = list() expr_list = list()
for i in range(steps_start, steps_stop): for seq, i in enumerate(steps):
for j in range(rports): for j in range(rports):
expr_id.append(('R', i-steps_start, j, 'A')) expr_id.append(('R', seq, j, 'A'))
expr_id.append(('R', i-steps_start, j, 'D')) expr_id.append(('R', seq, j, 'D'))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dA" % j)) expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dA" % j))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dD" % j)) expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dD" % j))
for j in range(wports): for j in range(wports):
expr_id.append(('W', i-steps_start, j, 'A')) expr_id.append(('W', seq, j, 'A'))
expr_id.append(('W', i-steps_start, j, 'D')) expr_id.append(('W', seq, j, 'D'))
expr_id.append(('W', i-steps_start, j, 'M')) expr_id.append(('W', seq, j, 'M'))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dA" % j)) expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dA" % j))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dD" % j)) expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dD" % j))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dM" % j)) expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dM" % j))
@ -943,14 +965,14 @@ def collect_mem_trace_data(steps_start, steps_stop, vcd=None):
netpath[-1] += "<%0*x>" % ((len(addr)+3) // 4, int_addr) netpath[-1] += "<%0*x>" % ((len(addr)+3) // 4, int_addr)
vcd.add_net([topmod] + netpath, width) vcd.add_net([topmod] + netpath, width)
for i in range(steps_start, steps_stop): for seq, i in enumerate(steps):
if i not in mem_trace_data: if i not in mem_trace_data:
mem_trace_data[i] = list() mem_trace_data[i] = list()
mem_trace_data[i].append((netpath, int_addr, "".join(tdata[i-steps_start]))) mem_trace_data[i].append((netpath, int_addr, "".join(tdata[seq])))
return mem_trace_data return mem_trace_data
def write_vcd_trace(steps_start, steps_stop, index): def write_vcd_trace(steps, index, seq_time=False):
filename = vcdfile.replace("%", index) filename = vcdfile.replace("%", index)
print_msg("Writing trace to VCD file: %s" % (filename)) print_msg("Writing trace to VCD file: %s" % (filename))
@ -971,10 +993,10 @@ def write_vcd_trace(steps_start, steps_stop, index):
vcd.add_clock([topmod] + netpath, edge) vcd.add_clock([topmod] + netpath, edge)
path_list.append(netpath) path_list.append(netpath)
mem_trace_data = collect_mem_trace_data(steps_start, steps_stop, vcd) mem_trace_data = collect_mem_trace_data(steps, vcd)
for i in range(steps_start, steps_stop): for seq, i in enumerate(steps):
vcd.set_time(i) vcd.set_time(seq if seq_time else i)
value_list = smt.get_net_bin_list(topmod, path_list, "s%d" % i) value_list = smt.get_net_bin_list(topmod, path_list, "s%d" % i)
for path, value in zip(path_list, value_list): for path, value in zip(path_list, value_list):
vcd.set_net([topmod] + path, value) vcd.set_net([topmod] + path, value)
@ -982,7 +1004,14 @@ def write_vcd_trace(steps_start, steps_stop, index):
for path, addr, value in mem_trace_data[i]: for path, addr, value in mem_trace_data[i]:
vcd.set_net([topmod] + path, value) vcd.set_net([topmod] + path, value)
vcd.set_time(steps_stop) if seq_time:
end_time = len(steps)
elif steps:
end_time = steps[-1] + 1
else:
end_time = 0
vcd.set_time(end_time)
def detect_state_loop(steps_start, steps_stop): def detect_state_loop(steps_start, steps_stop):
print_msg(f"Checking for loops in found induction counter example") print_msg(f"Checking for loops in found induction counter example")
@ -1027,7 +1056,7 @@ def escape_identifier(identifier):
def write_vlogtb_trace(steps_start, steps_stop, index): def write_vlogtb_trace(steps, index):
filename = vlogtbfile.replace("%", index) filename = vlogtbfile.replace("%", index)
print_msg("Writing trace to Verilog testbench: %s" % (filename)) print_msg("Writing trace to Verilog testbench: %s" % (filename))
@ -1092,7 +1121,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
print(" initial begin", file=f) print(" initial begin", file=f)
regs = sorted(smt.hiernets(vlogtb_topmod, regs_only=True)) regs = sorted(smt.hiernets(vlogtb_topmod, regs_only=True))
regvals = smt.get_net_bin_list(vlogtb_topmod, regs, vlogtb_state.replace("@@step_idx@@", str(steps_start))) regvals = smt.get_net_bin_list(vlogtb_topmod, regs, vlogtb_state.replace("@@step_idx@@", str(steps[0])))
print("`ifndef VERILATOR", file=f) print("`ifndef VERILATOR", file=f)
print(" #1;", file=f) print(" #1;", file=f)
@ -1107,7 +1136,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
anyconsts = sorted(smt.hieranyconsts(vlogtb_topmod)) anyconsts = sorted(smt.hieranyconsts(vlogtb_topmod))
for info in anyconsts: for info in anyconsts:
if info[3] is not None: if info[3] is not None:
modstate = smt.net_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(steps_start)), info[0]) modstate = smt.net_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(steps[0])), info[0])
value = smt.bv2bin(smt.get("(|%s| %s)" % (info[1], modstate))) value = smt.bv2bin(smt.get("(|%s| %s)" % (info[1], modstate)))
print(" UUT.%s = %d'b%s;" % (".".join(escape_identifier(info[0] + [info[3]])), len(value), value), file=f); print(" UUT.%s = %d'b%s;" % (".".join(escape_identifier(info[0] + [info[3]])), len(value), value), file=f);
@ -1117,7 +1146,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
addr_expr_list = list() addr_expr_list = list()
data_expr_list = list() data_expr_list = list()
for i in range(steps_start, steps_stop): for i in steps:
for j in range(rports): for j in range(rports):
addr_expr_list.append(smt.mem_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(i)), mempath, "R%dA" % j)) addr_expr_list.append(smt.mem_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(i)), mempath, "R%dA" % j))
data_expr_list.append(smt.mem_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(i)), mempath, "R%dD" % j)) data_expr_list.append(smt.mem_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(i)), mempath, "R%dD" % j))
@ -1138,7 +1167,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
print("", file=f) print("", file=f)
anyseqs = sorted(smt.hieranyseqs(vlogtb_topmod)) anyseqs = sorted(smt.hieranyseqs(vlogtb_topmod))
for i in range(steps_start, steps_stop): for i in steps:
pi_names = [[name] for name, _ in primary_inputs if name not in clock_inputs] pi_names = [[name] for name, _ in primary_inputs if name not in clock_inputs]
pi_values = smt.get_net_bin_list(vlogtb_topmod, pi_names, vlogtb_state.replace("@@step_idx@@", str(i))) pi_values = smt.get_net_bin_list(vlogtb_topmod, pi_names, vlogtb_state.replace("@@step_idx@@", str(i)))
@ -1170,14 +1199,14 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
print(" end", file=f) print(" end", file=f)
print(" always @(posedge clock) begin", file=f) print(" always @(posedge clock) begin", file=f)
print(" genclock <= cycle < %d;" % (steps_stop-1), file=f) print(" genclock <= cycle < %d;" % (steps[-1]), file=f)
print(" cycle <= cycle + 1;", file=f) print(" cycle <= cycle + 1;", file=f)
print(" end", file=f) print(" end", file=f)
print("endmodule", file=f) print("endmodule", file=f)
def write_constr_trace(steps_start, steps_stop, index): def write_constr_trace(steps, index):
filename = outconstr.replace("%", index) filename = outconstr.replace("%", index)
print_msg("Writing trace to constraints file: %s" % (filename)) print_msg("Writing trace to constraints file: %s" % (filename))
@ -1194,7 +1223,7 @@ def write_constr_trace(steps_start, steps_stop, index):
constr_prefix = smtctop[1] + "." constr_prefix = smtctop[1] + "."
if smtcinit: if smtcinit:
steps_start = steps_stop - 1 steps = [steps[-1]]
with open(filename, "w") as f: with open(filename, "w") as f:
primary_inputs = list() primary_inputs = list()
@ -1203,13 +1232,13 @@ def write_constr_trace(steps_start, steps_stop, index):
width = smt.modinfo[constr_topmod].wsize[name] width = smt.modinfo[constr_topmod].wsize[name]
primary_inputs.append((name, width)) primary_inputs.append((name, width))
if steps_start == 0 or smtcinit: if steps[0] == 0 or smtcinit:
print("initial", file=f) print("initial", file=f)
else: else:
print("state %d" % steps_start, file=f) print("state %d" % steps[0], file=f)
regnames = sorted(smt.hiernets(constr_topmod, regs_only=True)) regnames = sorted(smt.hiernets(constr_topmod, regs_only=True))
regvals = smt.get_net_list(constr_topmod, regnames, constr_state.replace("@@step_idx@@", str(steps_start))) regvals = smt.get_net_list(constr_topmod, regnames, constr_state.replace("@@step_idx@@", str(steps[0])))
for name, val in zip(regnames, regvals): for name, val in zip(regnames, regvals):
print("assume (= [%s%s] %s)" % (constr_prefix, ".".join(name), val), file=f) print("assume (= [%s%s] %s)" % (constr_prefix, ".".join(name), val), file=f)
@ -1220,7 +1249,7 @@ def write_constr_trace(steps_start, steps_stop, index):
addr_expr_list = list() addr_expr_list = list()
data_expr_list = list() data_expr_list = list()
for i in range(steps_start, steps_stop): for i in steps:
for j in range(rports): for j in range(rports):
addr_expr_list.append(smt.mem_expr(constr_topmod, constr_state.replace("@@step_idx@@", str(i)), mempath, "R%dA" % j)) addr_expr_list.append(smt.mem_expr(constr_topmod, constr_state.replace("@@step_idx@@", str(i)), mempath, "R%dA" % j))
data_expr_list.append(smt.mem_expr(constr_topmod, constr_state.replace("@@step_idx@@", str(i)), mempath, "R%dD" % j)) data_expr_list.append(smt.mem_expr(constr_topmod, constr_state.replace("@@step_idx@@", str(i)), mempath, "R%dD" % j))
@ -1236,7 +1265,7 @@ def write_constr_trace(steps_start, steps_stop, index):
for addr, data in addr_data.items(): for addr, data in addr_data.items():
print("assume (= (select [%s%s] %s) %s)" % (constr_prefix, ".".join(mempath), addr, data), file=f) print("assume (= (select [%s%s] %s) %s)" % (constr_prefix, ".".join(mempath), addr, data), file=f)
for k in range(steps_start, steps_stop): for k in steps:
if not smtcinit: if not smtcinit:
print("", file=f) print("", file=f)
print("state %d" % k, file=f) print("state %d" % k, file=f)
@ -1247,11 +1276,14 @@ def write_constr_trace(steps_start, steps_stop, index):
for name, val in zip(pi_names, pi_values): for name, val in zip(pi_names, pi_values):
print("assume (= [%s%s] %s)" % (constr_prefix, ".".join(name), val), file=f) print("assume (= [%s%s] %s)" % (constr_prefix, ".".join(name), val), file=f)
def write_yw_trace(steps_start, steps_stop, index, allregs=False): def write_yw_trace(steps, index, allregs=False, filename=None):
filename = outywfile.replace("%", index) if filename is None:
print_msg("Writing trace to Yosys witness file: %s" % (filename)) if outywfile is None:
return
filename = outywfile.replace("%", index)
print_msg("Writing trace to Yosys witness file: %s" % (filename))
mem_trace_data = collect_mem_trace_data(steps_start, steps_stop) mem_trace_data = collect_mem_trace_data(steps)
with open(filename, "w") as f: with open(filename, "w") as f:
inits, seqs, clocks, mems = smt.hierwitness(topmod, allregs) inits, seqs, clocks, mems = smt.hierwitness(topmod, allregs)
@ -1295,10 +1327,10 @@ def write_yw_trace(steps_start, steps_stop, index, allregs=False):
sig = yw.add_sig(word_path, overlap_start, overlap_end - overlap_start, True) sig = yw.add_sig(word_path, overlap_start, overlap_end - overlap_start, True)
mem_init_values.append((sig, overlap_bits.replace("x", "?"))) mem_init_values.append((sig, overlap_bits.replace("x", "?")))
for k in range(steps_start, steps_stop): for i, k in enumerate(steps):
step_values = WitnessValues() step_values = WitnessValues()
if k == steps_start: if not i:
for sig, value in mem_init_values: for sig, value in mem_init_values:
step_values[sig] = value step_values[sig] = value
sigs = inits + seqs sigs = inits + seqs
@ -1314,17 +1346,24 @@ def write_yw_trace(steps_start, steps_stop, index, allregs=False):
def write_trace(steps_start, steps_stop, index, allregs=False): def write_trace(steps_start, steps_stop, index, allregs=False):
if steps_stop is None:
steps = steps_start
seq_time = True
else:
steps = list(range(steps_start, steps_stop))
seq_time = False
if vcdfile is not None: if vcdfile is not None:
write_vcd_trace(steps_start, steps_stop, index) write_vcd_trace(steps, index, seq_time=seq_time)
if vlogtbfile is not None: if vlogtbfile is not None:
write_vlogtb_trace(steps_start, steps_stop, index) write_vlogtb_trace(steps, index)
if outconstr is not None: if outconstr is not None:
write_constr_trace(steps_start, steps_stop, index) write_constr_trace(steps, index)
if outywfile is not None: if outywfile is not None:
write_yw_trace(steps_start, steps_stop, index, allregs) write_yw_trace(steps, index, allregs)
def print_failed_asserts_worker(mod, state, path, extrainfo, infomap, infokey=()): def print_failed_asserts_worker(mod, state, path, extrainfo, infomap, infokey=()):
@ -1596,7 +1635,11 @@ def smt_check_sat(expected=["sat", "unsat"]):
smt_forall_assert() smt_forall_assert()
return smt.check_sat(expected=expected) return smt.check_sat(expected=expected)
if tempind:
if incremental:
incremental.mainloop()
elif tempind:
retstatus = "FAILED" retstatus = "FAILED"
skip_counter = step_size skip_counter = step_size
for step in range(num_steps, -1, -1): for step in range(num_steps, -1, -1):
@ -1954,5 +1997,6 @@ else: # not tempind, covermode
smt.write("(exit)") smt.write("(exit)")
smt.wait() smt.wait()
print_msg("Status: %s" % retstatus) if not incremental:
sys.exit(0 if retstatus == "PASSED" else 1) print_msg("Status: %s" % retstatus)
sys.exit(0 if retstatus == "PASSED" else 1)

View file

@ -0,0 +1,389 @@
from collections import defaultdict
import json
import typing
from functools import partial
if typing.TYPE_CHECKING:
import smtbmc
else:
import sys
smtbmc = sys.modules["__main__"]
class InteractiveError(Exception):
pass
class Incremental:
def __init__(self):
self.traceidx = 0
self.state_set = set()
self.map_cache = {}
self._cached_hierwitness = {}
self._witness_index = None
self._yw_constraints = {}
def setup(self):
generic_assert_map = smtbmc.get_assert_map(
smtbmc.topmod, "state", smtbmc.topmod
)
self.inv_generic_assert_map = {
tuple(data[1:]): key for key, data in generic_assert_map.items()
}
assert len(self.inv_generic_assert_map) == len(generic_assert_map)
def print_json(self, **kwargs):
print(json.dumps(kwargs), flush=True)
def print_msg(self, msg):
self.print_json(msg=msg)
def get_cached_assert(self, step, name):
try:
assert_map = self.map_cache[step]
except KeyError:
assert_map = self.map_cache[step] = smtbmc.get_assert_map(
smtbmc.topmod, f"s{step}", smtbmc.topmod
)
return assert_map[self.inv_generic_assert_map[name]][0]
def arg_step(self, cmd, declare=False, name="step", optional=False):
step = cmd.get(name, None)
if step is None and optional:
return None
if not isinstance(step, int):
if optional:
raise InteractiveError(f"{name} must be an integer")
else:
raise InteractiveError(f"integer {name} argument required")
if declare and step in self.state_set:
raise InteractiveError(f"step {step} already declared")
if not declare and step not in self.state_set:
raise InteractiveError(f"step {step} not declared")
return step
def expr_arg_len(self, expr, min_len, max_len=-1):
if max_len == -1:
max_len = min_len
arg_len = len(expr) - 1
if min_len is not None and arg_len < min_len:
if min_len == max_len:
raise (
f"{json.dumps(expr[0])} expression must have "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
else:
raise (
f"{json.dumps(expr[0])} expression must have at least "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
if max_len is not None and arg_len > max_len:
raise (
f"{json.dumps(expr[0])} expression can have at most "
f"{min_len} argument{'s' if max_len != 1 else ''}"
)
def expr_step(self, expr, smt_out):
self.expr_arg_len(expr, 1)
step = expr[1]
if step not in self.state_set:
raise InteractiveError(f"step {step} not declared")
smt_out.append(f"s{step}")
return "module", smtbmc.topmod
def expr_mod_constraint(self, expr, smt_out):
self.expr_arg_len(expr, 1)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None])
module = arg_sort[1]
suffix = expr[0][3:]
smt_out[position] = f"(|{module}{suffix}| "
smt_out.append(")")
return "Bool"
def expr_mod_constraint2(self, expr, smt_out):
self.expr_arg_len(expr, 2)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None])
smt_out.append(" ")
self.expr(expr[2], smt_out, required_sort=arg_sort)
module = arg_sort[1]
suffix = expr[0][3:]
smt_out[position] = f"(|{module}{suffix}| "
smt_out.append(")")
return "Bool"
def expr_not(self, expr, smt_out):
self.expr_arg_len(expr, 1)
smt_out.append("(not ")
self.expr(expr[1], smt_out, required_sort="Bool")
smt_out.append(")")
return "Bool"
def expr_eq(self, expr, smt_out):
self.expr_arg_len(expr, 2)
smt_out.append("(= ")
arg_sort = self.expr(expr[1], smt_out)
if (
smtbmc.smt.unroll
and isinstance(arg_sort, (list, tuple))
and arg_sort[0] == "module"
):
raise InteractiveError("state equality not supported in unroll mode")
smt_out.append(" ")
self.expr(expr[2], smt_out, required_sort=arg_sort)
smt_out.append(")")
return "Bool"
def expr_andor(self, expr, smt_out):
if len(expr) == 1:
smt_out.push({"and": "true", "or": "false"}[expr[0]])
elif len(expr) == 2:
arg_sort = self.expr(expr[1], smt_out)
if arg_sort != "Bool":
raise InteractiveError(
f"arguments of {json.dumps(expr[0])} must have sort Bool"
)
else:
sep = f"({expr[0]} "
for arg in expr[1:]:
smt_out.append(sep)
sep = " "
self.expr(arg, smt_out, required_sort="Bool")
smt_out.append(")")
return "Bool"
def expr_yw(self, expr, smt_out):
if len(expr) == 2:
name = None
step = expr[1]
elif len(expr) == 3:
name = expr[1]
step = expr[2]
if step not in self.state_set:
raise InteractiveError(f"step {step} not declared")
if name not in self._yw_constraints:
raise InteractiveError(f"no yw file loaded as name {name!r}")
constraints = self._yw_constraints[name].get(step, [])
if len(constraints) == 0:
smt_out.append("true")
elif len(constraints) == 1:
smt_out.append(constraints[0])
else:
sep = "(and "
for constraint in constraints:
smt_out.append(sep)
sep = " "
smt_out.append(constraint)
smt_out.append(")")
return "Bool"
def expr_label(self, expr, smt_out):
if len(expr) != 3:
raise InteractiveError(f'expected ["!", label, sub_expr], got {expr!r}')
label = expr[1]
subexpr = expr[2]
if not isinstance(label, str):
raise InteractiveError(f"expression label has to be a string")
smt_out.append("(! ")
smt_out.appedd(label)
smt_out.append(" ")
sort = self.expr(subexpr, smt_out)
smt_out.append(")")
return sort
expr_handlers = {
"step": expr_step,
"mod_h": expr_mod_constraint,
"mod_is": expr_mod_constraint,
"mod_i": expr_mod_constraint,
"mod_a": expr_mod_constraint,
"mod_u": expr_mod_constraint,
"mod_t": expr_mod_constraint2,
"not": expr_not,
"and": expr_andor,
"or": expr_andor,
"=": expr_eq,
"yw": expr_yw,
"!": expr_label,
}
def expr(self, expr, smt_out, required_sort=None):
if not isinstance(expr, (list, tuple)) or not expr:
raise InteractiveError(
f"expression must be a non-empty JSON array, found: {json.dumps(expr)}"
)
name = expr[0]
handler = self.expr_handlers.get(name)
if handler:
sort = handler(self, expr, smt_out)
if required_sort is not None:
if isinstance(required_sort, (list, tuple)):
if (
not isinstance(sort, (list, tuple))
or len(sort) != len(required_sort)
or any(
r is not None and r != s
for r, s in zip(required_sort, sort)
)
):
raise InteractiveError(
f"required sort {json.dumps(required_sort)} found sort {json.dumps(sort)}"
)
return sort
raise InteractiveError(f"unknown expression {json.dumps(expr[0])}")
def expr_smt(self, expr, required_sort):
smt_out = []
self.expr(expr, smt_out, required_sort=required_sort)
out = "".join(smt_out)
return out
def cmd_new_step(self, cmd):
step = self.arg_step(cmd, declare=True)
self.state_set.add(step)
smtbmc.smt_state(step)
def cmd_assert(self, cmd):
name = cmd.get("cmd")
assert_fn = {
"assert_antecedent": smtbmc.smt_assert_antecedent,
"assert_consequent": smtbmc.smt_assert_consequent,
"assert": smtbmc.smt_assert,
}[name]
assert_fn(self.expr_smt(cmd.get("expr"), "Bool"))
def cmd_push(self, cmd):
smtbmc.smt_push()
def cmd_pop(self, cmd):
smtbmc.smt_pop()
def cmd_check(self, cmd):
return smtbmc.smt_check_sat()
def cmd_design_hierwitness(self, cmd=None):
allregs = (cmd is None) or bool(cmd.get("allreges", False))
if self._cached_hierwitness[allregs] is not None:
return self._cached_hierwitness[allregs]
inits, seqs, clocks, mems = smtbmc.smt.hierwitness(smtbmc.topmod, allregs)
self._cached_hierwitness[allregs] = result = dict(
inits=inits, seqs=seqs, clocks=clocks, mems=mems
)
return result
def cmd_write_yw_trace(self, cmd):
steps = cmd.get("steps")
allregs = bool(cmd.get("allregs", False))
if steps is None:
steps = sorted(self.state_set)
path = cmd.get("path")
smtbmc.write_yw_trace(steps, self.traceidx, allregs=allregs, filename=path)
if path is None:
self.traceidx += 1
def cmd_read_yw_trace(self, cmd):
steps = cmd.get("steps")
path = cmd.get("path")
name = cmd.get("name")
skip_x = cmd.get("skip_x", False)
if path is None:
raise InteractiveError("path required")
constraints = defaultdict(list)
if steps is None:
steps = sorted(self.state_set)
map_steps = {i: int(j) for i, j in enumerate(steps)}
smtbmc.ywfile_constraints(path, constraints, map_steps=map_steps, skip_x=skip_x)
self._yw_constraints[name] = {
map_steps.get(i, i): [smtexpr for cexfile, smtexpr in constraint_list]
for i, constraint_list in constraints.items()
}
def cmd_ping(self, cmd):
return cmd
cmd_handlers = {
"new_step": cmd_new_step,
"assert": cmd_assert,
"assert_antecedent": cmd_assert,
"assert_consequent": cmd_assert,
"push": cmd_push,
"pop": cmd_pop,
"check": cmd_check,
"design_hierwitness": cmd_design_hierwitness,
"write_yw_trace": cmd_write_yw_trace,
"read_yw_trace": cmd_read_yw_trace,
"ping": cmd_ping,
}
def handle_command(self, cmd):
if not isinstance(cmd, dict) or "cmd" not in cmd:
raise InteractiveError('object with "cmd" key required')
name = cmd.get("cmd", None)
handler = self.cmd_handlers.get(name)
if handler:
return handler(self, cmd)
else:
raise InteractiveError(f"unknown command: {name}")
def mainloop(self):
self.setup()
while True:
try:
cmd = input().strip()
if not cmd or cmd.startswith("#") or cmd.startswith("//"):
continue
try:
cmd = json.loads(cmd)
except json.decoder.JSONDecodeError as e:
self.print_json(err=f"invalid JSON: {e}")
continue
except EOFError:
break
try:
result = self.handle_command(cmd)
except InteractiveError as e:
self.print_json(err=str(e))
continue
except Exception as e:
self.print_json(err=f"internal error: {e}")
raise
else:
self.print_json(ok=result)

View file

@ -33,10 +33,14 @@ def cli():
Display a Yosys witness trace in a human readable format. Display a Yosys witness trace in a human readable format.
""") """)
@click.argument("input", type=click.File("r")) @click.argument("input", type=click.File("r"))
def display(input): @click.option("--skip-x", help="Treat x bits as unassigned.", is_flag=True)
def display(input, skip_x):
click.echo(f"Reading Yosys witness trace {input.name!r}...") click.echo(f"Reading Yosys witness trace {input.name!r}...")
inyw = ReadWitness(input) inyw = ReadWitness(input)
if skip_x:
inyw.skip_x()
def output(): def output():
yield click.style("*** RTLIL bit-order below may differ from source level declarations ***", fg="red") yield click.style("*** RTLIL bit-order below may differ from source level declarations ***", fg="red")
@ -91,7 +95,11 @@ If two or more inputs are provided they will be concatenated together into the o
@click.option("--append", "-p", type=int, multiple=True, @click.option("--append", "-p", type=int, multiple=True,
help="Number of steps (+ve or -ve) to append to end of input trace. " help="Number of steps (+ve or -ve) to append to end of input trace. "
+"Can be defined multiple times, following the same order as input traces. ") +"Can be defined multiple times, following the same order as input traces. ")
def yw2yw(inputs, output, append): @click.option("--skip-x", help="Leave input x bits unassigned.", is_flag=True)
def yw2yw(inputs, output, append, skip_x):
if len(inputs) == 0:
raise click.ClickException(f"no inputs specified")
outyw = WriteWitness(output, "yosys-witness yw2yw") outyw = WriteWitness(output, "yosys-witness yw2yw")
join_inputs = len(inputs) > 1 join_inputs = len(inputs) > 1
inyws = {} inyws = {}
@ -129,12 +137,12 @@ def yw2yw(inputs, output, append):
click.echo(f"Copying yosys witness trace from {input.name!r} to {output.name!r}...") click.echo(f"Copying yosys witness trace from {input.name!r} to {output.name!r}...")
if first_witness: if first_witness:
outyw.step(init_values) outyw.step(init_values, skip_x=skip_x)
else: else:
outyw.step(inyw.first_step()) outyw.step(inyw.first_step(), skip_x=skip_x)
for t, values in inyw.steps(1): for t, values in inyw.steps(1):
outyw.step(values) outyw.step(values, skip_x=skip_x)
click.echo(f" copied {t + 1} time steps.") click.echo(f" copied {t + 1} time steps.")
first_witness = False first_witness = False
@ -174,7 +182,8 @@ This requires a Yosys witness AIGER map file as generated by 'write_aiger -ywmap
@click.argument("input", type=click.File("r")) @click.argument("input", type=click.File("r"))
@click.argument("mapfile", type=click.File("r")) @click.argument("mapfile", type=click.File("r"))
@click.argument("output", type=click.File("w")) @click.argument("output", type=click.File("w"))
def aiw2yw(input, mapfile, output): @click.option("--skip-x", help="Leave input x bits unassigned.", is_flag=True)
def aiw2yw(input, mapfile, output, skip_x):
input_name = input.name input_name = input.name
click.echo(f"Converting AIGER witness trace {input_name!r} to Yosys witness trace {output.name!r}...") click.echo(f"Converting AIGER witness trace {input_name!r} to Yosys witness trace {output.name!r}...")
click.echo(f"Using Yosys witness AIGER map file {mapfile.name!r}") click.echo(f"Using Yosys witness AIGER map file {mapfile.name!r}")
@ -245,7 +254,7 @@ def aiw2yw(input, mapfile, output):
values[bit] = v values[bit] = v
outyw.step(values) outyw.step(values, skip_x=skip_x)
outyw.end_trace() outyw.end_trace()

View file

@ -351,11 +351,14 @@ class WriteWitness:
self.out.name("steps") self.out.name("steps")
self.out.begin_array() self.out.begin_array()
def step(self, values): def step(self, values, skip_x=False):
if not self.header_written: if not self.header_written:
self.write_header() self.write_header()
self.out.value({"bits": values.pack(self.sigmap)}) packed = values.pack(self.sigmap)
if skip_x:
packed = packed.replace('x', '?')
self.out.value({"bits": packed})
self.t += 1 self.t += 1
@ -390,6 +393,9 @@ class ReadWitness:
self.bits = [step["bits"] for step in data["steps"]] self.bits = [step["bits"] for step in data["steps"]]
def skip_x(self):
self.bits = [step.replace('x', '?') for step in self.bits]
def init_step(self): def init_step(self):
return self.step(0) return self.step(0)