3
0
Fork 0
mirror of https://github.com/YosysHQ/yosys synced 2025-08-10 13:10:51 +00:00

smtbmc: Add --incremental mode

This commit is contained in:
Jannis Harder 2023-11-16 13:15:54 +01:00
parent 032fab1f54
commit e319606ec9
4 changed files with 512 additions and 64 deletions

View file

@ -17,7 +17,7 @@
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#
import os, sys, getopt, re, bisect
import os, sys, getopt, re, bisect, json
##yosys-sys-path##
from smtio import SmtIo, SmtOpts, MkVcd
from ywio import ReadWitness, WriteWitness, WitnessValues
@ -56,6 +56,7 @@ binarymode = False
keep_going = False
check_witness = False
detect_loops = False
incremental = None
so = SmtOpts()
@ -185,6 +186,9 @@ def help():
check if states are unique in temporal induction counter examples
(this feature is experimental and incomplete)
--incremental
run in incremental mode (experimental)
""" + so.helpmsg())
def usage():
@ -196,7 +200,7 @@ try:
opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts +
["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat",
"dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=",
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops"])
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental"])
except:
usage()
@ -282,6 +286,9 @@ for o, a in opts:
check_witness = True
elif o == "--detect-loops":
detect_loops = True
elif o == "--incremental":
from smtbmc_incremental import Incremental
incremental = Incremental()
elif so.handle(o, a):
pass
else:
@ -290,7 +297,7 @@ for o, a in opts:
if len(args) != 1:
usage()
if sum([tempind, gentrace, covermode]) > 1:
if sum([tempind, gentrace, covermode, incremental is not None]) > 1:
usage()
constr_final_start = None
@ -444,8 +451,10 @@ if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None:
smt.produce_models = False
def print_msg(msg):
print("%s %s" % (smt.timestamp(), msg))
sys.stdout.flush()
if incremental:
incremental.print_msg(msg)
else:
print("%s %s" % (smt.timestamp(), msg), flush=True)
print_msg("Solver: %s" % (so.solver))
@ -640,10 +649,9 @@ if aimfile is not None:
num_steps = max(num_steps, step+2)
step += 1
if inywfile is not None:
if not got_topt:
skip_steps = 0
num_steps = 0
def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
if map_steps is None:
map_steps = {}
with open(inywfile, "r") as f:
inyw = ReadWitness(f)
@ -662,10 +670,14 @@ if inywfile is not None:
addr_re = re.compile(r'\\\[[0-9]+\]$')
bits_re = re.compile(r'[01?]*$')
max_t = -1
for t, step in inyw.steps():
present_signals, missing = step.present_signals(inyw.sigmap)
for sig in present_signals:
bits = step[sig]
if skip_x:
bits = bits.replace('x', '?')
if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file")
@ -684,7 +696,7 @@ if inywfile is not None:
if common_end <= common_offset:
continue
smt_expr = smt.witness_net_expr(topmod, f"s{t}", wire)
smt_expr = smt.witness_net_expr(topmod, f"s{map_steps.get(t, t)}", wire)
if not smt_bool:
slice_high = common_end - offset - 1
@ -714,7 +726,7 @@ if inywfile is not None:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]
smt_expr = smt.net_expr(topmod, f"s{t}", mem["smtpath"])
smt_expr = smt.net_expr(topmod, f"s{map_steps.get(t, t)}", mem["smtpath"])
if bv:
word_low = sig.memory_addr * width
@ -738,11 +750,21 @@ if inywfile is not None:
smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
constr_assumes[t].append((inywfile, smt_constr))
max_t = t
if not got_topt:
if not check_witness:
skip_steps = max(skip_steps, t)
num_steps = max(num_steps, t+1)
return max_t
if inywfile is not None:
if not got_topt:
skip_steps = 0
num_steps = 0
max_t = ywfile_constraints(inywfile, constr_assumes)
if not got_topt:
if not check_witness:
skip_steps = max(skip_steps, max_t)
num_steps = max(num_steps, max_t+1)
if btorwitfile is not None:
with open(btorwitfile, "r") as f:
@ -841,7 +863,7 @@ if btorwitfile is not None:
skip_steps = step
num_steps = step+1
def collect_mem_trace_data(steps_start, steps_stop, vcd=None):
def collect_mem_trace_data(steps, vcd=None):
mem_trace_data = dict()
for mempath in sorted(smt.hiermems(topmod)):
@ -849,16 +871,16 @@ def collect_mem_trace_data(steps_start, steps_stop, vcd=None):
expr_id = list()
expr_list = list()
for i in range(steps_start, steps_stop):
for seq, i in enumerate(steps):
for j in range(rports):
expr_id.append(('R', i-steps_start, j, 'A'))
expr_id.append(('R', i-steps_start, j, 'D'))
expr_id.append(('R', seq, j, 'A'))
expr_id.append(('R', seq, j, 'D'))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dA" % j))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dD" % j))
for j in range(wports):
expr_id.append(('W', i-steps_start, j, 'A'))
expr_id.append(('W', i-steps_start, j, 'D'))
expr_id.append(('W', i-steps_start, j, 'M'))
expr_id.append(('W', seq, j, 'A'))
expr_id.append(('W', seq, j, 'D'))
expr_id.append(('W', seq, j, 'M'))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dA" % j))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dD" % j))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dM" % j))
@ -943,14 +965,14 @@ def collect_mem_trace_data(steps_start, steps_stop, vcd=None):
netpath[-1] += "<%0*x>" % ((len(addr)+3) // 4, int_addr)
vcd.add_net([topmod] + netpath, width)
for i in range(steps_start, steps_stop):
for seq, i in enumerate(steps):
if i not in mem_trace_data:
mem_trace_data[i] = list()
mem_trace_data[i].append((netpath, int_addr, "".join(tdata[i-steps_start])))
mem_trace_data[i].append((netpath, int_addr, "".join(tdata[seq])))
return mem_trace_data
def write_vcd_trace(steps_start, steps_stop, index):
def write_vcd_trace(steps, index, seq_time=False):
filename = vcdfile.replace("%", index)
print_msg("Writing trace to VCD file: %s" % (filename))
@ -971,10 +993,10 @@ def write_vcd_trace(steps_start, steps_stop, index):
vcd.add_clock([topmod] + netpath, edge)
path_list.append(netpath)
mem_trace_data = collect_mem_trace_data(steps_start, steps_stop, vcd)
mem_trace_data = collect_mem_trace_data(steps, vcd)
for i in range(steps_start, steps_stop):
vcd.set_time(i)
for seq, i in enumerate(steps):
vcd.set_time(seq if seq_time else i)
value_list = smt.get_net_bin_list(topmod, path_list, "s%d" % i)
for path, value in zip(path_list, value_list):
vcd.set_net([topmod] + path, value)
@ -982,7 +1004,14 @@ def write_vcd_trace(steps_start, steps_stop, index):
for path, addr, value in mem_trace_data[i]:
vcd.set_net([topmod] + path, value)
vcd.set_time(steps_stop)
if seq_time:
end_time = len(steps)
elif steps:
end_time = steps[-1] + 1
else:
end_time = 0
vcd.set_time(end_time)
def detect_state_loop(steps_start, steps_stop):
print_msg(f"Checking for loops in found induction counter example")
@ -1027,7 +1056,7 @@ def escape_identifier(identifier):
def write_vlogtb_trace(steps_start, steps_stop, index):
def write_vlogtb_trace(steps, index):
filename = vlogtbfile.replace("%", index)
print_msg("Writing trace to Verilog testbench: %s" % (filename))
@ -1092,7 +1121,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
print(" initial begin", file=f)
regs = sorted(smt.hiernets(vlogtb_topmod, regs_only=True))
regvals = smt.get_net_bin_list(vlogtb_topmod, regs, vlogtb_state.replace("@@step_idx@@", str(steps_start)))
regvals = smt.get_net_bin_list(vlogtb_topmod, regs, vlogtb_state.replace("@@step_idx@@", str(steps[0])))
print("`ifndef VERILATOR", file=f)
print(" #1;", file=f)
@ -1107,7 +1136,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
anyconsts = sorted(smt.hieranyconsts(vlogtb_topmod))
for info in anyconsts:
if info[3] is not None:
modstate = smt.net_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(steps_start)), info[0])
modstate = smt.net_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(steps[0])), info[0])
value = smt.bv2bin(smt.get("(|%s| %s)" % (info[1], modstate)))
print(" UUT.%s = %d'b%s;" % (".".join(escape_identifier(info[0] + [info[3]])), len(value), value), file=f);
@ -1117,7 +1146,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
addr_expr_list = list()
data_expr_list = list()
for i in range(steps_start, steps_stop):
for i in steps:
for j in range(rports):
addr_expr_list.append(smt.mem_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(i)), mempath, "R%dA" % j))
data_expr_list.append(smt.mem_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(i)), mempath, "R%dD" % j))
@ -1138,7 +1167,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
print("", file=f)
anyseqs = sorted(smt.hieranyseqs(vlogtb_topmod))
for i in range(steps_start, steps_stop):
for i in steps:
pi_names = [[name] for name, _ in primary_inputs if name not in clock_inputs]
pi_values = smt.get_net_bin_list(vlogtb_topmod, pi_names, vlogtb_state.replace("@@step_idx@@", str(i)))
@ -1170,14 +1199,14 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
print(" end", file=f)
print(" always @(posedge clock) begin", file=f)
print(" genclock <= cycle < %d;" % (steps_stop-1), file=f)
print(" genclock <= cycle < %d;" % (steps[-1]), file=f)
print(" cycle <= cycle + 1;", file=f)
print(" end", file=f)
print("endmodule", file=f)
def write_constr_trace(steps_start, steps_stop, index):
def write_constr_trace(steps, index):
filename = outconstr.replace("%", index)
print_msg("Writing trace to constraints file: %s" % (filename))
@ -1194,7 +1223,7 @@ def write_constr_trace(steps_start, steps_stop, index):
constr_prefix = smtctop[1] + "."
if smtcinit:
steps_start = steps_stop - 1
steps = [steps[-1]]
with open(filename, "w") as f:
primary_inputs = list()
@ -1203,13 +1232,13 @@ def write_constr_trace(steps_start, steps_stop, index):
width = smt.modinfo[constr_topmod].wsize[name]
primary_inputs.append((name, width))
if steps_start == 0 or smtcinit:
if steps[0] == 0 or smtcinit:
print("initial", file=f)
else:
print("state %d" % steps_start, file=f)
print("state %d" % steps[0], file=f)
regnames = sorted(smt.hiernets(constr_topmod, regs_only=True))
regvals = smt.get_net_list(constr_topmod, regnames, constr_state.replace("@@step_idx@@", str(steps_start)))
regvals = smt.get_net_list(constr_topmod, regnames, constr_state.replace("@@step_idx@@", str(steps[0])))
for name, val in zip(regnames, regvals):
print("assume (= [%s%s] %s)" % (constr_prefix, ".".join(name), val), file=f)
@ -1220,7 +1249,7 @@ def write_constr_trace(steps_start, steps_stop, index):
addr_expr_list = list()
data_expr_list = list()
for i in range(steps_start, steps_stop):
for i in steps:
for j in range(rports):
addr_expr_list.append(smt.mem_expr(constr_topmod, constr_state.replace("@@step_idx@@", str(i)), mempath, "R%dA" % j))
data_expr_list.append(smt.mem_expr(constr_topmod, constr_state.replace("@@step_idx@@", str(i)), mempath, "R%dD" % j))
@ -1236,7 +1265,7 @@ def write_constr_trace(steps_start, steps_stop, index):
for addr, data in addr_data.items():
print("assume (= (select [%s%s] %s) %s)" % (constr_prefix, ".".join(mempath), addr, data), file=f)
for k in range(steps_start, steps_stop):
for k in steps:
if not smtcinit:
print("", file=f)
print("state %d" % k, file=f)
@ -1247,11 +1276,14 @@ def write_constr_trace(steps_start, steps_stop, index):
for name, val in zip(pi_names, pi_values):
print("assume (= [%s%s] %s)" % (constr_prefix, ".".join(name), val), file=f)
def write_yw_trace(steps_start, steps_stop, index, allregs=False):
filename = outywfile.replace("%", index)
print_msg("Writing trace to Yosys witness file: %s" % (filename))
def write_yw_trace(steps, index, allregs=False, filename=None):
if filename is None:
if outywfile is None:
return
filename = outywfile.replace("%", index)
print_msg("Writing trace to Yosys witness file: %s" % (filename))
mem_trace_data = collect_mem_trace_data(steps_start, steps_stop)
mem_trace_data = collect_mem_trace_data(steps)
with open(filename, "w") as f:
inits, seqs, clocks, mems = smt.hierwitness(topmod, allregs)
@ -1295,10 +1327,10 @@ def write_yw_trace(steps_start, steps_stop, index, allregs=False):
sig = yw.add_sig(word_path, overlap_start, overlap_end - overlap_start, True)
mem_init_values.append((sig, overlap_bits.replace("x", "?")))
for k in range(steps_start, steps_stop):
for i, k in enumerate(steps):
step_values = WitnessValues()
if k == steps_start:
if not i:
for sig, value in mem_init_values:
step_values[sig] = value
sigs = inits + seqs
@ -1314,17 +1346,24 @@ def write_yw_trace(steps_start, steps_stop, index, allregs=False):
def write_trace(steps_start, steps_stop, index, allregs=False):
if steps_stop is None:
steps = steps_start
seq_time = True
else:
steps = list(range(steps_start, steps_stop))
seq_time = False
if vcdfile is not None:
write_vcd_trace(steps_start, steps_stop, index)
write_vcd_trace(steps, index, seq_time=seq_time)
if vlogtbfile is not None:
write_vlogtb_trace(steps_start, steps_stop, index)
write_vlogtb_trace(steps, index)
if outconstr is not None:
write_constr_trace(steps_start, steps_stop, index)
write_constr_trace(steps, index)
if outywfile is not None:
write_yw_trace(steps_start, steps_stop, index, allregs)
write_yw_trace(steps, index, allregs)
def print_failed_asserts_worker(mod, state, path, extrainfo, infomap, infokey=()):
@ -1596,7 +1635,11 @@ def smt_check_sat(expected=["sat", "unsat"]):
smt_forall_assert()
return smt.check_sat(expected=expected)
if tempind:
if incremental:
incremental.mainloop()
elif tempind:
retstatus = "FAILED"
skip_counter = step_size
for step in range(num_steps, -1, -1):
@ -1954,5 +1997,6 @@ else: # not tempind, covermode
smt.write("(exit)")
smt.wait()
print_msg("Status: %s" % retstatus)
sys.exit(0 if retstatus == "PASSED" else 1)
if not incremental:
print_msg("Status: %s" % retstatus)
sys.exit(0 if retstatus == "PASSED" else 1)