3
0
Fork 0
mirror of https://github.com/YosysHQ/yosys synced 2025-07-31 00:13:18 +00:00

smtbmc: Improvements for --incremental and .yw fixes

This extends the experimental incremental JSON API to allow arbitrary
smtlib subexpressions, defining smtlib constants and to allow access of
signals by their .yw path.

It also fixes a bug during .yw writing where values would be re-emitted
in later cycles if they have no newer defined value and a potential
crash when using --track-assumes.
This commit is contained in:
Jannis Harder 2024-05-07 17:57:37 +02:00
parent 71f2540cd8
commit a52088b6af
3 changed files with 284 additions and 97 deletions

View file

@ -199,7 +199,6 @@ def help():
--minimize-assumes
when using --track-assumes, solve for a minimal set of sufficient assumptions.
""" + so.helpmsg())
def usage():
@ -670,18 +669,12 @@ if aimfile is not None:
ywfile_hierwitness_cache = None
def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
def ywfile_hierwitness():
global ywfile_hierwitness_cache
if map_steps is None:
map_steps = {}
if ywfile_hierwitness_cache is None:
ywfile_hierwitness = smt.hierwitness(topmod, allregs=True, blackbox=True)
with open(inywfile, "r") as f:
inyw = ReadWitness(f)
if ywfile_hierwitness_cache is None:
ywfile_hierwitness_cache = smt.hierwitness(topmod, allregs=True, blackbox=True)
inits, seqs, clocks, mems = ywfile_hierwitness_cache
inits, seqs, clocks, mems = ywfile_hierwitness
smt_wires = defaultdict(list)
smt_mems = defaultdict(list)
@ -692,9 +685,128 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
for mem in mems:
smt_mems[mem["path"]].append(mem)
addr_re = re.compile(r'\\\[[0-9]+\]$')
bits_re = re.compile(r'[01?]*$')
ywfile_hierwitness_cache = inits, seqs, clocks, mems, smt_wires, smt_mems
return ywfile_hierwitness_cache
def_bits_re = re.compile(r'([01]+)')
def smt_extract_mask(smt_expr, mask):
chunks = []
def_bits = ''
mask_index_order = mask[::-1]
for matched in def_bits_re.finditer(mask_index_order):
chunks.append(matched.span())
def_bits += matched[0]
if not chunks:
return
if len(chunks) == 1:
start, end = chunks[0]
if start == 0 and end == len(mask_index_order):
combined_chunks = smt_expr
else:
combined_chunks = '((_ extract %d %d) %s)' % (end - 1, start, smt_expr)
else:
combined_chunks = '(let ((x %s)) (concat %s))' % (smt_expr, ' '.join(
'((_ extract %d %d) x)' % (end - 1, start)
for start, end in reversed(chunks)
))
return combined_chunks, ''.join(mask_index_order[start:end] for start, end in chunks)[::-1]
def smt_concat(exprs):
if not exprs:
return ""
if len(exprs) == 1:
return exprs[1]
return "(concat %s)" % ' '.join(exprs)
def ywfile_signal(sig, step, mask=None):
assert sig.width > 0
inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness()
sig_end = sig.offset + sig.width
output = []
if sig.path in smt_wires:
for wire in smt_wires[sig.path]:
width, offset = wire["width"], wire["offset"]
smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1
offset = max(offset, 0)
end = width + offset
common_offset = max(sig.offset, offset)
common_end = min(sig_end, end)
if common_end <= common_offset:
continue
smt_expr = smt.witness_net_expr(topmod, f"s{step}", wire)
if not smt_bool:
slice_high = common_end - offset - 1
slice_low = common_offset - offset
smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr)
else:
smt_expr = "(ite %s #b1 #b0)" % smt_expr
output.append(((common_offset - sig.offset), (common_end - sig.offset), smt_expr))
if sig.memory_path:
if sig.memory_path in smt_mems:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]
smt_expr = smt.net_expr(topmod, f"s{step}", mem["smtpath"])
if bv:
word_low = sig.memory_addr * width
word_high = word_low + width - 1
smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr)
else:
addr_width = (size - 1).bit_length()
addr_bits = f"{sig.memory_addr:0{addr_width}b}"
smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits)
if sig.width < width:
slice_high = sig.offset + sig.width - 1
smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr)
output.append((0, sig.width, smt_expr))
output.sort()
output = [chunk for chunk in output if chunk[0] != chunk[1]]
pos = 0
for start, end, smt_expr in output:
assert start == pos
pos = end
assert pos == sig.width
if len(output) == 1:
return output[0][-1]
return smt_concat(smt_expr for start, end, smt_expr in reversed(output))
def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
global ywfile_hierwitness_cache
if map_steps is None:
map_steps = {}
with open(inywfile, "r") as f:
inyw = ReadWitness(f)
inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness()
bits_re = re.compile(r'[01?]*$')
max_t = -1
for t, step in inyw.steps():
@ -706,77 +818,14 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file")
sig_end = sig.offset + len(bits)
if sig.path in smt_wires:
for wire in smt_wires[sig.path]:
width, offset = wire["width"], wire["offset"]
smt_expr = ywfile_signal(sig, map_steps.get(t, t))
smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1
smt_expr, bits = smt_extract_mask(smt_expr, bits)
offset = max(offset, 0)
smt_constr = "(= %s #b%s)" % (smt_expr, bits)
constr_assumes[t].append((inywfile, smt_constr))
end = width + offset
common_offset = max(sig.offset, offset)
common_end = min(sig_end, end)
if common_end <= common_offset:
continue
smt_expr = smt.witness_net_expr(topmod, f"s{map_steps.get(t, t)}", wire)
if not smt_bool:
slice_high = common_end - offset - 1
slice_low = common_offset - offset
smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr)
bit_slice = bits[len(bits) - (common_end - sig.offset):len(bits) - (common_offset - sig.offset)]
if bit_slice.count("?") == len(bit_slice):
continue
if smt_bool:
assert width == 1
smt_constr = "(= %s %s)" % (smt_expr, "true" if bit_slice == "1" else "false")
else:
if "?" in bit_slice:
mask = bit_slice.replace("0", "1").replace("?", "0")
bit_slice = bit_slice.replace("?", "0")
smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
constr_assumes[t].append((inywfile, smt_constr))
if sig.memory_path:
if sig.memory_path in smt_mems:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]
smt_expr = smt.net_expr(topmod, f"s{map_steps.get(t, t)}", mem["smtpath"])
if bv:
word_low = sig.memory_addr * width
word_high = word_low + width - 1
smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr)
else:
addr_width = (size - 1).bit_length()
addr_bits = f"{sig.memory_addr:0{addr_width}b}"
smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits)
if len(bits) < width:
slice_high = sig.offset + len(bits) - 1
smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr)
bit_slice = bits
if "?" in bit_slice:
mask = bit_slice.replace("0", "1").replace("?", "0")
bit_slice = bit_slice.replace("?", "0")
smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
constr_assumes[t].append((inywfile, smt_constr))
max_t = t
return max_t
if inywfile is not None:
@ -1367,11 +1416,11 @@ def write_yw_trace(steps, index, allregs=False, filename=None):
exprs.extend(smt.witness_net_expr(topmod, f"s{k}", sig) for sig in sigs)
all_sigs.append(sigs)
all_sigs.append((step_values, sigs))
bvs = iter(smt.get_list(exprs))
for sigs in all_sigs:
for (step_values, sigs) in all_sigs:
for sig in sigs:
value = smt.bv2bin(next(bvs))
step_values[sig["sig"]] = value