3
0
Fork 0
mirror of https://github.com/YosysHQ/yosys synced 2025-04-06 01:24:10 +00:00
yosys/backends/smt2/smtbmc_incremental.py
Jannis Harder a52088b6af smtbmc: Improvements for --incremental and .yw fixes
This extends the experimental incremental JSON API to allow arbitrary
smtlib subexpressions, defining smtlib constants and to allow access of
signals by their .yw path.

It also fixes a bug during .yw writing where values would be re-emitted
in later cycles if they have no newer defined value and a potential
crash when using --track-assumes.
2024-05-07 17:57:37 +02:00

617 lines
19 KiB
Python

from collections import defaultdict
import json
import typing
import ywio
if typing.TYPE_CHECKING:
import smtbmc
else:
import sys
smtbmc = sys.modules["__main__"]
class InteractiveError(Exception):
pass
def mkkey(data):
if isinstance(data, list):
return tuple(map(mkkey, data))
elif isinstance(data, dict):
raise InteractiveError(f"JSON objects found in assumption key: {data!r}")
return data
class Incremental:
def __init__(self):
self.traceidx = 0
self.state_set = set()
self.map_cache = {}
self._cached_hierwitness = {}
self._witness_index = None
self._yw_constraints = {}
self._define_sorts = {}
def setup(self):
generic_assert_map = smtbmc.get_assert_map(
smtbmc.topmod, "state", smtbmc.topmod
)
self.inv_generic_assert_map = {
tuple(data[1:]): key for key, data in generic_assert_map.items()
}
assert len(self.inv_generic_assert_map) == len(generic_assert_map)
def print_json(self, **kwargs):
print(json.dumps(kwargs), flush=True)
def print_msg(self, msg):
self.print_json(msg=msg)
def get_cached_assert(self, step, name):
try:
assert_map = self.map_cache[step]
except KeyError:
assert_map = self.map_cache[step] = smtbmc.get_assert_map(
smtbmc.topmod, f"s{step}", smtbmc.topmod
)
return assert_map[self.inv_generic_assert_map[name]][0]
def arg_step(self, cmd, declare=False, name="step", optional=False):
step = cmd.get(name, None)
if step is None and optional:
return None
if not isinstance(step, int):
if optional:
raise InteractiveError(f"{name} must be an integer")
else:
raise InteractiveError(f"integer {name} argument required")
if declare and step in self.state_set:
raise InteractiveError(f"step {step} already declared")
if not declare and step not in self.state_set:
raise InteractiveError(f"step {step} not declared")
return step
def expr_arg_len(self, expr, min_len, max_len=-1):
if max_len == -1:
max_len = min_len
arg_len = len(expr) - 1
if min_len is not None and arg_len < min_len:
if min_len == max_len:
raise InteractiveError(
f"{json.dumps(expr[0])} expression must have "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
else:
raise InteractiveError(
f"{json.dumps(expr[0])} expression must have at least "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
if max_len is not None and arg_len > max_len:
raise InteractiveError(
f"{json.dumps(expr[0])} expression can have at most "
f"{min_len} argument{'s' if max_len != 1 else ''}"
)
def expr_step(self, expr, smt_out):
self.expr_arg_len(expr, 1)
step = expr[1]
if step not in self.state_set:
raise InteractiveError(f"step {step} not declared")
smt_out.append(f"s{step}")
return "module", smtbmc.topmod
def expr_cell(self, expr, smt_out):
self.expr_arg_len(expr, 2)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[2], smt_out, required_sort=["module", None])
smt_out.append(")")
module = arg_sort[1]
cell = expr[1]
submod = smtbmc.smt.modinfo[module].cells.get(cell)
if submod is None:
raise InteractiveError(f"module {module!r} has no cell {cell!r}")
smt_out[position] = f"(|{module}_h {cell}| "
return ("module", submod)
def expr_mod_constraint(self, expr, smt_out):
suffix = expr[0][3:]
self.expr_arg_len(expr, 1, 2 if suffix in ["_a", "_u", "_c"] else 1)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[-1], smt_out, required_sort=["module", None])
module = arg_sort[1]
if len(expr) == 3:
smt_out[position] = f"(|{module}{suffix} {expr[1]}| "
else:
smt_out[position] = f"(|{module}{suffix}| "
smt_out.append(")")
return "Bool"
def expr_mod_constraint2(self, expr, smt_out):
self.expr_arg_len(expr, 2)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None])
smt_out.append(" ")
self.expr(expr[2], smt_out, required_sort=arg_sort)
module = arg_sort[1]
suffix = expr[0][3:]
smt_out[position] = f"(|{module}{suffix}| "
smt_out.append(")")
return "Bool"
def expr_not(self, expr, smt_out):
self.expr_arg_len(expr, 1)
smt_out.append("(not ")
self.expr(expr[1], smt_out, required_sort="Bool")
smt_out.append(")")
return "Bool"
def expr_eq(self, expr, smt_out):
self.expr_arg_len(expr, 2)
smt_out.append("(= ")
arg_sort = self.expr(expr[1], smt_out)
if (
smtbmc.smt.unroll
and isinstance(arg_sort, (list, tuple))
and arg_sort[0] == "module"
):
raise InteractiveError("state equality not supported in unroll mode")
smt_out.append(" ")
self.expr(expr[2], smt_out, required_sort=arg_sort)
smt_out.append(")")
return "Bool"
def expr_andor(self, expr, smt_out):
if len(expr) == 1:
smt_out.push({"and": "true", "or": "false"}[expr[0]])
elif len(expr) == 2:
self.expr(expr[1], smt_out, required_sort="Bool")
else:
sep = f"({expr[0]} "
for arg in expr[1:]:
smt_out.append(sep)
sep = " "
self.expr(arg, smt_out, required_sort="Bool")
smt_out.append(")")
return "Bool"
def expr_bv_binop(self, expr, smt_out):
self.expr_arg_len(expr, 2)
smt_out.append(f"({expr[0]} ")
arg_sort = self.expr(expr[1], smt_out, required_sort=("BitVec", None))
smt_out.append(" ")
self.expr(expr[2], smt_out, required_sort=arg_sort)
smt_out.append(")")
return arg_sort
def expr_extract(self, expr, smt_out):
self.expr_arg_len(expr, 3)
hi = expr[1]
lo = expr[2]
smt_out.append(f"((_ extract {hi} {lo}) ")
arg_sort = self.expr(expr[3], smt_out, required_sort=("BitVec", None))
smt_out.append(")")
if not (isinstance(hi, int) and 0 <= hi < arg_sort[1]):
raise InteractiveError(
f"high bit index must be 0 <= index < {arg_sort[1]}, is {hi!r}"
)
if not (isinstance(lo, int) and 0 <= lo <= hi):
raise InteractiveError(
f"low bit index must be 0 <= index < {hi}, is {lo!r}"
)
return "BitVec", hi - lo + 1
def expr_bv(self, expr, smt_out):
self.expr_arg_len(expr, 1)
arg = expr[1]
if not isinstance(arg, str) or arg.count("0") + arg.count("1") != len(arg):
raise InteractiveError("bv argument must contain only 0 or 1 bits")
smt_out.append("#b" + arg)
return "BitVec", len(arg)
def expr_yw(self, expr, smt_out):
self.expr_arg_len(expr, 1, 2)
if len(expr) == 2:
name = None
step = expr[1]
elif len(expr) == 3:
name = expr[1]
step = expr[2]
if step not in self.state_set:
raise InteractiveError(f"step {step} not declared")
if name not in self._yw_constraints:
raise InteractiveError(f"no yw file loaded as name {name!r}")
constraints = self._yw_constraints[name].get(step, [])
if len(constraints) == 0:
smt_out.append("true")
elif len(constraints) == 1:
smt_out.append(constraints[0])
else:
sep = "(and "
for constraint in constraints:
smt_out.append(sep)
sep = " "
smt_out.append(constraint)
smt_out.append(")")
return "Bool"
def expr_yw_sig(self, expr, smt_out):
self.expr_arg_len(expr, 3, 4)
step = expr[1]
path = expr[2]
offset = expr[3]
width = expr[4] if len(expr) == 5 else 1
if not isinstance(offset, int) or offset < 0:
raise InteractiveError(
f"offset must be a non-negative integer, got {json.dumps(offset)}"
)
if not isinstance(width, int) or width <= 0:
raise InteractiveError(
f"width must be a positive integer, got {json.dumps(width)}"
)
if not isinstance(path, list) or not all(isinstance(s, str) for s in path):
raise InteractiveError(
f"path must be a string list, got {json.dumps(path)}"
)
if step not in self.state_set:
raise InteractiveError(f"step {step} not declared")
smt_expr = smtbmc.ywfile_signal(
ywio.WitnessSig(path=path, offset=offset, width=width), step
)
smt_out.append(smt_expr)
return "BitVec", width
def expr_smtlib(self, expr, smt_out):
self.expr_arg_len(expr, 2)
smtlib_expr = expr[1]
sort = expr[2]
if not isinstance(smtlib_expr, str):
raise InteractiveError(
"raw SMT-LIB expression has to be a string, "
f"got {json.dumps(smtlib_expr)}"
)
if (
isinstance(sort, list)
and len(sort) == 2
and sort[0] == "BitVec"
and (sort[1] is None or isinstance(sort[1], int))
):
sort = tuple(sort)
elif not isinstance(sort, str):
raise InteractiveError(f"unsupported raw SMT-LIB sort {json.dumps(sort)}")
smt_out.append(smtlib_expr)
return sort
def expr_label(self, expr, smt_out):
if len(expr) != 3:
raise InteractiveError(
f'expected ["!", label, sub_expr], got {json.dumps(expr)}'
)
label = expr[1]
subexpr = expr[2]
if not isinstance(label, str):
raise InteractiveError("expression label has to be a string")
smt_out.append("(! ")
sort = self.expr(subexpr, smt_out)
smt_out.append(" :named ")
smt_out.append(label)
smt_out.append(")")
return sort
def expr_def(self, expr, smt_out):
self.expr_arg_len(expr, 1)
sort = self._define_sorts.get(expr[1])
if sort is None:
raise InteractiveError(f"unknown definition {json.dumps(expr)}")
smt_out.append(expr[1])
return sort
expr_handlers = {
"step": expr_step,
"cell": expr_cell,
"mod_h": expr_mod_constraint,
"mod_is": expr_mod_constraint,
"mod_i": expr_mod_constraint,
"mod_a": expr_mod_constraint,
"mod_u": expr_mod_constraint,
"mod_t": expr_mod_constraint2,
"not": expr_not,
"and": expr_andor,
"or": expr_andor,
"bv": expr_bv,
"bvand": expr_bv_binop,
"bvor": expr_bv_binop,
"bvxor": expr_bv_binop,
"extract": expr_extract,
"def": expr_def,
"=": expr_eq,
"yw": expr_yw,
"yw_sig": expr_yw_sig,
"smtlib": expr_smtlib,
"!": expr_label,
}
def expr(self, expr, smt_out, required_sort=None):
if not isinstance(expr, (list, tuple)) or not expr:
raise InteractiveError(
f"expression must be a non-empty JSON array, found: {json.dumps(expr)}"
)
name = expr[0]
handler = self.expr_handlers.get(name)
if handler:
sort = handler(self, expr, smt_out)
if required_sort is not None:
if isinstance(required_sort, (list, tuple)):
if (
not isinstance(sort, (list, tuple))
or len(sort) != len(required_sort)
or any(
r is not None and r != s
for r, s in zip(required_sort, sort)
)
):
raise InteractiveError(
f"required sort {json.dumps(required_sort)} "
f"found sort {json.dumps(sort)}"
)
return sort
raise InteractiveError(f"unknown expression {json.dumps(expr[0])}")
def expr_smt(self, expr, required_sort):
return self.expr_smt_and_sort(expr, required_sort)[0]
def expr_smt_and_sort(self, expr, required_sort=None):
smt_out = []
output_sort = self.expr(expr, smt_out, required_sort=required_sort)
out = "".join(smt_out)
return out, output_sort
def cmd_new_step(self, cmd):
step = self.arg_step(cmd, declare=True)
self.state_set.add(step)
smtbmc.smt_state(step)
def cmd_assert(self, cmd):
name = cmd.get("cmd")
assert_fn = {
"assert_antecedent": smtbmc.smt_assert_antecedent,
"assert_consequent": smtbmc.smt_assert_consequent,
"assert": smtbmc.smt_assert,
}[name]
assert_fn(self.expr_smt(cmd.get("expr"), "Bool"))
def cmd_assert_design_assumes(self, cmd):
step = self.arg_step(cmd)
smtbmc.smt_assert_design_assumes(step)
def cmd_get_design_assume(self, cmd):
key = mkkey(cmd.get("key"))
return smtbmc.assume_enables.get(key)
def cmd_update_assumptions(self, cmd):
expr = cmd.get("expr")
key = cmd.get("key")
key = mkkey(key)
result = smtbmc.smt.smt2_assumptions.pop(key, None)
if expr is not None:
expr = self.expr_smt(expr, "Bool")
smtbmc.smt.smt2_assumptions[key] = expr
return result
def cmd_get_unsat_assumptions(self, cmd):
return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get("minimize")))
def cmd_push(self, cmd):
smtbmc.smt_push()
def cmd_pop(self, cmd):
smtbmc.smt_pop()
def cmd_check(self, cmd):
return smtbmc.smt_check_sat()
def cmd_smtlib(self, cmd):
command = cmd.get("command")
response = cmd.get("response", False)
if not isinstance(command, str):
raise InteractiveError(
f"raw SMT-LIB command must be a string, found {json.dumps(command)}"
)
smtbmc.smt.write(command)
if response:
return smtbmc.smt.read()
def cmd_define(self, cmd):
expr = cmd.get("expr")
if expr is None:
raise InteractiveError("'define' copmmand requires 'expr' parameter")
expr, sort = self.expr_smt_and_sort(expr)
if isinstance(sort, tuple) and sort[0] == "module":
raise InteractiveError("'define' does not support module sorts")
define_name = f"|inc def {len(self._define_sorts)}|"
self._define_sorts[define_name] = sort
if isinstance(sort, tuple):
sort = f"(_ {' '.join(map(str, sort))})"
smtbmc.smt.write(f"(define-const {define_name} {sort} {expr})")
return {"name": define_name}
def cmd_design_hierwitness(self, cmd=None):
allregs = (cmd is None) or bool(cmd.get("allreges", False))
if self._cached_hierwitness[allregs] is not None:
return self._cached_hierwitness[allregs]
inits, seqs, clocks, mems = smtbmc.smt.hierwitness(smtbmc.topmod, allregs)
self._cached_hierwitness[allregs] = result = dict(
inits=inits, seqs=seqs, clocks=clocks, mems=mems
)
return result
def cmd_write_yw_trace(self, cmd):
steps = cmd.get("steps")
allregs = bool(cmd.get("allregs", False))
if steps is None:
steps = sorted(self.state_set)
path = cmd.get("path")
smtbmc.write_yw_trace(steps, self.traceidx, allregs=allregs, filename=path)
if path is None:
self.traceidx += 1
def cmd_read_yw_trace(self, cmd):
steps = cmd.get("steps")
path = cmd.get("path")
name = cmd.get("name")
skip_x = cmd.get("skip_x", False)
if path is None:
raise InteractiveError("path required")
constraints = defaultdict(list)
if steps is None:
steps = sorted(self.state_set)
map_steps = {i: int(j) for i, j in enumerate(steps)}
last_step = smtbmc.ywfile_constraints(
path, constraints, map_steps=map_steps, skip_x=skip_x
)
self._yw_constraints[name] = {
map_steps.get(i, i): [smtexpr for cexfile, smtexpr in constraint_list]
for i, constraint_list in constraints.items()
}
return dict(last_step=last_step)
def cmd_modinfo(self, cmd):
fields = cmd.get("fields", [])
mod = cmd.get("mod")
if mod is None:
mod = smtbmc.topmod
modinfo = smtbmc.smt.modinfo.get(mod)
if modinfo is None:
return None
result = dict(name=mod)
for field in fields:
result[field] = getattr(modinfo, field, None)
return result
def cmd_ping(self, cmd):
return cmd
cmd_handlers = {
"new_step": cmd_new_step,
"assert": cmd_assert,
"assert_antecedent": cmd_assert,
"assert_consequent": cmd_assert,
"assert_design_assumes": cmd_assert_design_assumes,
"get_design_assume": cmd_get_design_assume,
"update_assumptions": cmd_update_assumptions,
"get_unsat_assumptions": cmd_get_unsat_assumptions,
"push": cmd_push,
"pop": cmd_pop,
"check": cmd_check,
"smtlib": cmd_smtlib,
"define": cmd_define,
"design_hierwitness": cmd_design_hierwitness,
"write_yw_trace": cmd_write_yw_trace,
"read_yw_trace": cmd_read_yw_trace,
"modinfo": cmd_modinfo,
"ping": cmd_ping,
}
def handle_command(self, cmd):
if not isinstance(cmd, dict) or "cmd" not in cmd:
raise InteractiveError('object with "cmd" key required')
name = cmd.get("cmd", None)
handler = self.cmd_handlers.get(name)
if handler:
return handler(self, cmd)
else:
raise InteractiveError(f"unknown command: {name}")
def mainloop(self):
self.setup()
while True:
try:
cmd = input().strip()
if not cmd or cmd.startswith("#") or cmd.startswith("//"):
continue
try:
cmd = json.loads(cmd)
except json.decoder.JSONDecodeError as e:
self.print_json(err=f"invalid JSON: {e}")
continue
except EOFError:
break
try:
result = self.handle_command(cmd)
except InteractiveError as e:
self.print_json(err=str(e))
continue
except Exception as e:
self.print_json(err=f"internal error: {e}")
raise
else:
self.print_json(ok=result)