3
0
Fork 0
mirror of https://github.com/YosysHQ/sby.git synced 2025-04-11 00:13:33 +00:00

tools/cexenum: Add '--callback' option and refactor enumeration loop

This requires YosysHQ/yosys#4078
This commit is contained in:
Jannis Harder 2023-12-14 17:32:52 +01:00
parent eeee1a1ec5
commit 1ad19048e1

View file

@ -3,16 +3,25 @@ from __future__ import annotations
import asyncio
import json
import threading
import traceback
import argparse
import shutil
import shlex
import os
import sys
import urllib.parse
from pathlib import Path
from typing import Any, Awaitable, Literal
from typing import Any, Awaitable, Iterable, Literal
try:
import readline # type: ignore # noqa
except ImportError:
pass
import yosys_mau.task_loop.job_server as job
from yosys_mau import task_loop as tl
from yosys_mau.stable_set import StableSet
libexec = Path(__file__).parent.resolve() / "libexec"
@ -21,6 +30,10 @@ if libexec.exists():
os.environb[b"PATH"] = bytes(libexec) + b":" + os.environb[b"PATH"]
def safe_smtlib_id(name: str) -> str:
return f"|{urllib.parse.quote_plus(name).replace('+', ' ')}|"
def arg_parser():
parser = argparse.ArgumentParser(
prog="cexenum", usage="%(prog)s [options] <sby_workdir>"
@ -66,6 +79,15 @@ def arg_parser():
default="-s yices --unroll",
)
parser.add_argument(
"--callback",
metavar='"..."',
type=shlex.split,
help="command that will receive enumerated traces on stdin and can control "
"the enumeration via stdout (pass '-' to handle callbacks interactively)",
default="",
)
parser.add_argument("--debug", action="store_true", help="enable debug logging")
parser.add_argument(
"--debug-events", action="store_true", help="enable debug event logging"
@ -98,6 +120,8 @@ class App:
enum_depth: int
sim: bool
callback: list[str]
smtbmc_options: list[str]
work_dir: Path
@ -266,6 +290,7 @@ class MinimizeTrace(tl.Task):
App.cache_dir / "design_aiger.ywa",
min_yw,
cwd=App.trace_dir_min,
options=("--skip-x", "--present-only"),
)
aiw2yw[tl.LogContext].scope = f"aiw2yw[{stem}]"
aiw2yw.depends_on(aigcexmin)
@ -286,6 +311,18 @@ def relative_to(target: Path, cwd: Path) -> Path:
prefix = Path("")
target = target.resolve()
cwd = cwd.resolve()
ok = False
for limit in (Path.cwd(), App.work_dir):
limit = Path.cwd().resolve()
try:
target.relative_to(limit)
ok = True
except ValueError:
pass
if not ok:
return target
while True:
try:
return prefix / (target.relative_to(cwd))
@ -304,11 +341,13 @@ class YosysWitness(tl.process.Process):
mapfile: Path,
output: Path,
cwd: Path,
options: Iterable[str] = (),
):
super().__init__(
[
"yosys-witness",
mode,
*(options or []),
str(relative_to(input, cwd)),
str(relative_to(mapfile, cwd)),
str(relative_to(output, cwd)),
@ -381,87 +420,467 @@ class SimTrace(tl.process.Process):
self.log_output()
class Callback(tl.TaskGroup):
recover_from_errors = False
def __init__(self, enumeration: Enumeration):
super().__init__()
self[tl.LogContext].scope = "callback"
self.search_next = False
self.enumeration = enumeration
self.tmp_counter = 0
async def step_callback(self, step: int) -> Literal["advance", "search"]:
with self.as_current_task():
return await self._callback(step=step)
async def unsat_callback(self, step: int) -> Literal["advance", "search"]:
with self.as_current_task():
return await self._callback(step=step, unsat=True)
async def trace_callback(
self, step: int, path: Path
) -> Literal["advance", "search"]:
with self.as_current_task():
return await self._callback(step=step, trace_path=path)
async def _callback(
self,
step: int,
trace_path: Path | None = None,
unsat: bool = False,
) -> Literal["advance", "search"]:
if not self.is_active():
if unsat:
return "advance"
return "search"
if trace_path is None and self.search_next:
if unsat:
return "advance"
return "search"
self.search_next = False
info = dict(step=step, enabled=sorted(self.enumeration.active_assumptions))
if trace_path:
self.callback_write(
{**info, "event": "trace", "trace_path": str(trace_path)}
)
elif unsat:
self.callback_write({**info, "event": "unsat"})
else:
self.callback_write({**info, "event": "step"})
while True:
try:
response: dict[Any, Any] | Any = await self.callback_read()
if not isinstance(response, (Exception, dict)):
raise ValueError(
f"expected JSON object, got: {json.dumps(response)}"
)
if isinstance(response, Exception):
raise response
did_something = False
if "block_yw" in response:
did_something = True
block_yw: Any = response["block_yw"]
if not isinstance(block_yw, str):
raise ValueError(
"'block_yw' must be a string containing a file path, "
f"got: {json.dumps(block_yw)}"
)
name: Any = response.get("name")
if name is not None and not isinstance(name, str):
raise ValueError(
"optional 'name' must be a string when present, "
"got: {json.dumps(name)}"
)
self.enumeration.block_trace(Path(block_yw), name=name)
if "block_aiw" in response:
did_something = True
block_aiw: Any = response["block_aiw"]
if not isinstance(block_aiw, str):
raise ValueError(
"'block_yw' must be a string containing a file path, "
f"got: {json.dumps(block_aiw)}"
)
name: Any = response.get("name")
if name is not None and not isinstance(name, str):
raise ValueError(
"optional 'name' must be a string when present, "
"got: {json.dumps(name)}"
)
tmpdir = App.work_subdir / "tmp"
tmpdir.mkdir(exist_ok=True)
self.tmp_counter += 1
block_yw = tmpdir / f"callback_{self.tmp_counter}.yw"
aiw2yw = YosysWitness(
"aiw2yw",
Path(block_aiw),
App.cache_dir / "design_aiger.ywa",
Path(block_yw),
cwd=tmpdir,
options=("--skip-x", "--present-only"),
)
aiw2yw[tl.LogContext].scope = f"aiw2yw[callback_{self.tmp_counter}]"
await aiw2yw.finished
self.enumeration.block_trace(Path(block_yw), name=name)
if "disable" in response:
did_something = True
name = response["disable"]
if not isinstance(name, str):
raise ValueError(
"'disable' must be a string representing an assumption, "
f"got: {json.dumps(name)}"
)
self.enumeration.disable_assumption(name)
if "enable" in response:
did_something = True
name = response["enable"]
if not isinstance(name, str):
raise ValueError(
"'disable' must be a string representing an assumption, "
f"got: {json.dumps(name)}"
)
self.enumeration.enable_assumption(name)
action: Any = response.get("action")
if action == "next":
did_something = True
self.search_next = True
action = "search"
if action in ("search", "advance"):
return action
if not did_something:
raise ValueError(
f"could not interpret callback response: {response}"
)
except Exception as e:
tl.log_exception(e, raise_error=False)
if not self.recover_from_errors:
raise
def is_active(self) -> bool:
return False
def callback_write(self, data: Any):
return
async def callback_read(self) -> Any:
raise NotImplementedError("must be implemented in Callback subclass")
class InteractiveCallback(Callback):
recover_from_errors = True
interactive_shortcuts = {
"n": '{"action": "next"}',
"s": '{"action": "search"}',
"a": '{"action": "advance"}',
}
def __init__(self, enumeration: Enumeration):
super().__init__(enumeration)
self.__eof_reached = False
def is_active(self) -> bool:
return not self.__eof_reached
def callback_write(self, data: Any):
print(f"callback: {json.dumps(data)}")
async def callback_read(self) -> Any:
future: asyncio.Future[Any] = asyncio.Future()
loop = asyncio.get_event_loop()
def blocking_read():
try:
try:
result = ""
while not result:
result = input(
"callback> " if sys.stdout.isatty() else ""
).strip()
result = self.interactive_shortcuts.get(result, result)
result_data = json.loads(result)
except EOFError:
print()
self.__eof_reached = True
result_data = dict(action="next")
loop.call_soon_threadsafe(lambda: future.set_result(result_data))
except Exception as exc:
exception = exc
loop.call_soon_threadsafe(lambda: future.set_exception(exception))
thread = threading.Thread(target=blocking_read, daemon=True)
thread.start()
return await future
class ProcessCallback(Callback):
def __init__(self, enumeration: Enumeration, command: list[str]):
super().__init__(enumeration)
self[tl.LogContext].scope = "callback"
self.__eof_reached = False
self._command = command
self._lines: list[asyncio.Future[str]] = [asyncio.Future()]
async def on_prepare(self) -> None:
self.process = tl.Process(self._command, cwd=Path.cwd(), interact=True)
self.process.use_lease = False
future_line: asyncio.Future[str] = self._lines[-1]
def stdout_handler(event: tl.process.StdoutEvent):
nonlocal future_line
future_line.set_result(event.output)
future_line = asyncio.Future()
self._lines.append(future_line)
self.process.sync_handle_events(tl.process.StdoutEvent, stdout_handler)
def stderr_handler(event: tl.process.StderrEvent):
tl.log(event.output)
self.process.sync_handle_events(tl.process.StderrEvent, stderr_handler)
def exit_handler(event: tl.process.ExitEvent):
self._lines[-1].set_exception(EOFError("callback process exited"))
self.process.sync_handle_events(tl.process.ExitEvent, exit_handler)
def is_active(self) -> bool:
return not self.__eof_reached
def callback_write(self, data: Any):
if not self.process.is_finished:
self.process.write(json.dumps(data) + "\n")
async def callback_read(self) -> Any:
future_line = self._lines.pop(0)
try:
data = json.loads(await future_line)
tl.log_debug(f"callback action: {data}")
return data
except EOFError:
self._lines.insert(0, future_line)
self.__eof_reached = True
return dict(action="next")
class Enumeration(tl.Task):
callback_mode: Literal["off", "interactive", "process"]
callback_auto_search: bool = False
def __init__(self, aig_model: tl.Task):
self.aig_model = aig_model
self._pending_blocks: list[tuple[str | None, Path]] = []
self.named_assumptions: StableSet[str] = StableSet()
self.active_assumptions: StableSet[str] = StableSet()
super().__init__()
async def on_prepare(self) -> None:
if App.callback:
if App.callback == ["-"]:
self.callback_task = InteractiveCallback(self)
else:
self.callback_task = ProcessCallback(self, App.callback)
else:
self.callback_task = Callback(self)
async def on_run(self) -> None:
smtbmc = Smtbmc(App.work_dir / "model" / "design_smt2.smt2")
self.smtbmc = smtbmc = Smtbmc(App.work_dir / "model" / "design_smt2.smt2")
self._push_level = 0
await smtbmc.ping()
pred = None
i = 0
i = -1
limit = App.depth
first_failure = None
while i <= limit:
tl.log(f"Checking assumptions in step {i}..")
presat_checked = await batch(
smtbmc.bmc_step(i, initial=i == 0, assertions=None, pred=pred),
smtbmc.check(),
)
if presat_checked != "sat":
if first_failure is None:
tl.log_error("Assumptions are not satisfiable")
else:
tl.log("No further counter-examples are reachable")
return
checked = "skip"
tl.log(f"Checking assertions in step {i}..")
checked = await batch(
smtbmc.push(),
smtbmc.assertions(i, False),
smtbmc.check(),
)
pred = i
if checked != "unsat":
counter = 0
while i <= limit or limit < 0:
if checked != "skip":
checked = await self._search_counter_example(i)
if checked == "unsat":
if i >= 0:
action = await self.callback_task.unsat_callback(i)
if action == "search":
continue
checked = "skip"
if checked == "skip":
checked = "unsat"
i += 1
if i > limit and limit >= 0:
break
action = await self.callback_task.step_callback(i)
pending = batch(
self._top(),
smtbmc.bmc_step(
i, initial=i == 0, assertions=None, pred=i - 1 if i else None
),
)
if action == "advance":
tl.log(f"Skipping step {i}")
await batch(pending, smtbmc.assertions(i))
checked = "skip"
continue
assert action == "search"
tl.log(f"Checking assumptions in step {i}..")
presat_checked = await batch(
pending,
smtbmc.check(),
)
if presat_checked != "sat":
if first_failure is None:
tl.log_error("Assumptions are not satisfiable")
else:
tl.log("No further counter-examples are reachable")
smtbmc.close_stdin()
return
tl.log(f"Checking assertions in step {i}..")
counter = 0
continue
elif checked == "sat":
if first_failure is None:
first_failure = i
limit = i + App.enum_depth
if App.enum_depth < 0:
limit = -1
else:
limit = i + App.enum_depth
tl.log("BMC failed! Enumerating counter-examples..")
counter = 0
assert checked == "sat"
path = App.trace_dir_full / f"trace{i}_{counter}.yw"
await smtbmc.incremental_command(cmd="write_yw_trace", path=str(path))
tl.log(f"Written counter-example to {path.name}")
while checked == "sat":
await smtbmc.incremental_command(
cmd="write_yw_trace", path=str(path)
)
tl.log(f"Written counter-example to {path.name}")
minimize = MinimizeTrace(path.name, self.aig_model)
minimize.depends_on(self.aig_model)
minimize = MinimizeTrace(path.name, self.aig_model)
minimize.depends_on(self.aig_model)
await minimize.aiw2yw.finished
await minimize.aiw2yw.finished
min_path = App.trace_dir_min / f"trace{i}_{counter}.yw"
min_path = App.trace_dir_min / f"trace{i}_{counter}.yw"
action = await self.callback_task.trace_callback(i, min_path)
if action == "advance":
tl.log("Skipping remaining counter-examples for this step")
checked = "skip"
continue
assert action == "search"
checked = await batch(
smtbmc.incremental_command(
cmd="read_yw_trace",
name="last",
path=str(min_path),
skip_x=True,
),
smtbmc.assert_(
["not", ["and", *(["yw", "last", k] for k in range(i + 1))]]
),
smtbmc.check(),
)
counter += 1
path = App.trace_dir_full / f"trace{i}_{counter}.yw"
await batch(smtbmc.pop(), smtbmc.assertions(i))
i += 1
self.block_trace(min_path)
counter += 1
else:
tl.log_error(f"Unexpected solver result: {checked!r}")
smtbmc.close_stdin()
def block_trace(self, path: Path, name: str | None = None):
if name is not None:
if name in self.named_assumptions:
raise ValueError(f"an assumption with name {name} was already defined")
self.named_assumptions.add(name)
self.active_assumptions.add(name)
self._pending_blocks.append((name, path.absolute()))
def enable_assumption(self, name: str):
if name not in self.named_assumptions:
raise ValueError(f"unknown assumption {name!r}")
self.active_assumptions.add(name)
def disable_assumption(self, name: str):
if name not in self.named_assumptions:
raise ValueError(f"unknown assumption {name!r}")
self.active_assumptions.discard(name)
def _top(self) -> Awaitable[Any]:
return batch(*(self._pop() for _ in range(self._push_level)))
def _pop(self) -> Awaitable[Any]:
self._push_level -= 1
tl.log_debug(f"pop to {self._push_level}")
return self.smtbmc.pop()
def _push(self) -> Awaitable[Any]:
self._push_level += 1
tl.log_debug(f"push to {self._push_level}")
return self.smtbmc.push()
def _search_counter_example(self, step: int) -> Awaitable[Any]:
smtbmc = self.smtbmc
pending_blocks, self._pending_blocks = self._pending_blocks, []
pending = self._top()
for name, block_path in pending_blocks:
result = smtbmc.incremental_command(
cmd="read_yw_trace",
name="last",
path=str(block_path),
skip_x=True,
)
async def check_yw_trace_len():
last_step = (await result).get("last_step", step)
if last_step > step:
tl.log_warning(
f"Ignoring future time steps "
f"{step + 1} to {last_step} of "
f"{relative_to(block_path, Path.cwd())}"
)
return last_step
expr = [
"not",
["and", *(["yw", "last", k] for k in range(step + 1))],
]
if name is not None:
name_id = safe_smtlib_id(f"cexenum trace {name}")
pending = batch(
pending, smtbmc.smtlib(f"(declare-const {name_id} Bool)")
)
expr = ["=", expr, ["smtlib", name_id, "Bool"]]
pending = batch(
pending,
check_yw_trace_len(),
smtbmc.assert_(expr),
)
pending = batch(
pending,
self._push(),
smtbmc.assertions(step, False),
)
for name in self.active_assumptions:
name_id = safe_smtlib_id(f"cexenum trace {name}")
pending = batch(pending, smtbmc.assert_(["smtlib", name_id, "Bool"]))
return batch(
pending,
smtbmc.check(),
)
class Smtbmc(tl.process.Process):
def __init__(self, smt2_model: Path):
@ -483,7 +902,7 @@ class Smtbmc(tl.process.Process):
async def on_run(self) -> None:
def output_handler(event: tl.process.StdoutEvent):
line = event.output.strip()
if line.startswith('{'):
if line.startswith("{"):
result = json.loads(event.output)
else:
result = dict(msg=line)
@ -527,6 +946,9 @@ class Smtbmc(tl.process.Process):
def check(self) -> Awaitable[str]:
return self.incremental_command(cmd="check")
def smtlib(self, command: str) -> Awaitable[str]:
return self.incremental_command(cmd="smtlib", command=command)
def assert_antecedent(self, expr: Any) -> Awaitable[None]:
return self.incremental_command(cmd="assert_antecedent", expr=expr)