From 43f4feb7842b01509cc9194936008bfe3a4e77e2 Mon Sep 17 00:00:00 2001 From: Jannis Harder Date: Fri, 11 Oct 2024 16:06:27 +0200 Subject: [PATCH] Update cexenum tool to latest version --- tools/aigcexmin/src/aig_eval.rs | 11 + tools/aigcexmin/src/care_graph.rs | 44 +- tools/aigcexmin/src/main.rs | 5 + tools/cexenum/cexenum.py | 1098 ++++++++++++++++++++++++++--- 4 files changed, 1074 insertions(+), 84 deletions(-) diff --git a/tools/aigcexmin/src/aig_eval.rs b/tools/aigcexmin/src/aig_eval.rs index 7e5543c..b80f7f0 100644 --- a/tools/aigcexmin/src/aig_eval.rs +++ b/tools/aigcexmin/src/aig_eval.rs @@ -6,6 +6,17 @@ pub trait AigValue: Copy { fn invert_if(self, en: bool, ctx: &mut Context) -> Self; fn and(self, other: Self, ctx: &mut Context) -> Self; fn constant(value: bool, ctx: &mut Context) -> Self; + + fn invert(self, ctx: &mut Context) -> Self { + self.invert_if(true, ctx) + } + + fn or(self, other: Self, ctx: &mut Context) -> Self { + let not_self = self.invert(ctx); + let not_other = other.invert(ctx); + let not_or = not_self.and(not_other, ctx); + not_or.invert(ctx) + } } pub fn initial_frame( diff --git a/tools/aigcexmin/src/care_graph.rs b/tools/aigcexmin/src/care_graph.rs index 0f608fe..9a8d41c 100644 --- a/tools/aigcexmin/src/care_graph.rs +++ b/tools/aigcexmin/src/care_graph.rs @@ -23,7 +23,15 @@ pub struct NodeRef { impl std::fmt::Debug for NodeRef { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_tuple("NodeRef::new").field(&self.index()).finish() + if self.is_const() { + if *self == Self::TRUE { + f.debug_struct("NodeRef::TRUE").finish() + } else { + f.debug_struct("NodeRef::FALSE").finish() + } + } else { + f.debug_tuple("NodeRef::new").field(&self.index()).finish() + } } } @@ -48,6 +56,10 @@ impl NodeRef { pub fn index(self) -> usize { !(self.code.0.get()) as usize } + + pub fn is_const(self) -> bool { + self == Self::TRUE || self == Self::FALSE + } } #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] @@ -348,7 +360,10 @@ impl AndOrGraph { } }; - if shuffle > 0 && output != NodeRef::FALSE && output != NodeRef::TRUE { + if (shuffle > 0 || new_inputs != inputs) + && output != NodeRef::FALSE + && output != NodeRef::TRUE + { if let Ok((gate, inputs)) = self.nodes[output.index()].def.as_gate() { let [a, b] = inputs.map(|input| self.nodes[input.index()].priority); @@ -438,6 +453,7 @@ pub enum Verification { pub struct MinimizationOptions { pub fixed_init: bool, + pub satisfy_assumptions: bool, pub verify: Option, } @@ -509,6 +525,7 @@ pub fn minimize( ); let mut minimization_target = 'minimization_target: { + let mut invariant_failed = (Some(false), NodeRef::TRUE); for (t, inputs) in frame_inputs.iter().enumerate() { if t > 0 { successor_frame( @@ -527,18 +544,29 @@ pub fn minimize( &mut graph, ); } + + if options.satisfy_assumptions { + for invariant in aig.invariant_constraints.iter() { + let (var, polarity) = unpack_lit(*invariant); + let inv_invariant = state[var].invert_if(!polarity, &mut graph); + + invariant_failed = invariant_failed.or(inv_invariant, &mut graph); + } + } + let mut good_state = (Some(true), NodeRef::TRUE); for (i, bad) in aig.bad_state_properties.iter().enumerate() { let (var, polarity) = unpack_lit(*bad); let inv_bad = state[var].invert_if(!polarity, &mut graph); - if inv_bad.0 == Some(false) { + if inv_bad.0 == Some(false) && invariant_failed.0 == Some(false) { println!("bad state property {i} active in frame {t}"); } good_state = good_state.and(inv_bad, &mut graph); } + good_state = good_state.or(invariant_failed, &mut graph); if good_state.0 == Some(false) { println!("bad state found in frame {t}"); @@ -691,6 +719,16 @@ pub fn minimize( ); } + if options.satisfy_assumptions { + for invariant in aig.invariant_constraints.iter() { + let (var, polarity) = unpack_lit(*invariant); + let invariant_output = check_state[var].invert_if(polarity, &mut ()); + if invariant_output != Some(true) { + break 'frame; + } + } + } + for bad in aig.bad_state_properties.iter() { let (var, polarity) = unpack_lit(*bad); let bad_output = check_state[var].invert_if(polarity, &mut ()); diff --git a/tools/aigcexmin/src/main.rs b/tools/aigcexmin/src/main.rs index 30874db..b3192c5 100644 --- a/tools/aigcexmin/src/main.rs +++ b/tools/aigcexmin/src/main.rs @@ -41,6 +41,10 @@ pub struct Options { /// initialization but instead perform initialization using inputs in the first frame. #[clap(long)] latches: bool, + + /// Require assumptions to stay satisfied during minimization + #[clap(long)] + satisfy_assumptions: bool } #[derive(Copy, Clone, ValueEnum)] @@ -133,6 +137,7 @@ fn main() -> color_eyre::Result<()> { &mut writer_output, &care_graph::MinimizationOptions { fixed_init: !options.latches, + satisfy_assumptions: options.satisfy_assumptions, verify: match options.verify { VerificationOption::Off => None, VerificationOption::Cex => Some(care_graph::Verification::Cex), diff --git a/tools/cexenum/cexenum.py b/tools/cexenum/cexenum.py index 3ac8891..6b32145 100755 --- a/tools/cexenum/cexenum.py +++ b/tools/cexenum/cexenum.py @@ -3,22 +3,39 @@ from __future__ import annotations import asyncio import json +import threading import traceback import argparse import shutil import shlex import os +import sys +import urllib.parse +import random +from collections import Counter, defaultdict, deque from pathlib import Path -from typing import Any, Awaitable, Literal +from typing import Any, Awaitable, Iterable, Literal -import yosys_mau.task_loop.job_server as job -from yosys_mau import task_loop as tl +try: + import readline # type: ignore # noqa +except ImportError: + pass libexec = Path(__file__).parent.resolve() / "libexec" if libexec.exists(): os.environb[b"PATH"] = bytes(libexec) + b":" + os.environb[b"PATH"] + sys.path.insert(0, str(libexec)) + +if True: + import yosys_mau.task_loop.job_server as job + from yosys_mau import task_loop as tl + from yosys_mau.stable_set import StableSet + + +def safe_smtlib_id(name: str) -> str: + return f"|{urllib.parse.quote_plus(name).replace('+', ' ')}|" def arg_parser(): @@ -66,6 +83,58 @@ def arg_parser(): default="-s yices --unroll", ) + parser.add_argument( + "--callback", + metavar='"..."', + type=shlex.split, + help="command that will receive enumerated traces on stdin and can control " + "the enumeration via stdout (pass '-' to handle callbacks interactively)", + default="", + ) + + parser.add_argument( + "--track-assumes", + help="track individual assumptions and in case the assumptions cannot be " + "satisfied, report a subset of used assumptions that are in conflict", + action="store_true", + ) + parser.add_argument( + "--minimize-assumes", + help="when using --track-assumes, solve for a minimal set of sufficient " + "assumptions", + action="store_true", + ) + parser.add_argument( + "--satisfy-assumes", + help="when minimizing counter examples, keep design assumptions satisfied", + action="store_true", + ) + parser.add_argument( + "--diversify", + help="produce a more diverse set of counter examples early on", + action="store_true", + ) + + parser.add_argument( + "--diversify-seed", + metavar="", + help="random seed used for --diversify", + type=int, + default=1, + ) + parser.add_argument( + "--diversify-depth", + help="stop diversification after enumerating the given number of traces", + type=int, + ) + + parser.add_argument( + "--max-per-step", + metavar="", + help="advance to the next step after enumerating N traces", + type=int, + ) + parser.add_argument("--debug", action="store_true", help="enable debug logging") parser.add_argument( "--debug-events", action="store_true", help="enable debug event logging" @@ -83,7 +152,7 @@ def arg_parser(): return parser -def lines(*args): +def lines(*args: Any): return "".join(f"{line}\n" for line in args) @@ -94,9 +163,19 @@ class App: debug: bool = False debug_events: bool = False + track_assumes: bool = False + minimize_assumes: bool = False + satisfy_assumes: bool = False + depth: int enum_depth: int sim: bool + diversify: bool + diversify_seed: int + diversify_depth: int | None = None + max_per_step: int = 0 + + callback: list[str] smtbmc_options: list[str] @@ -148,10 +227,11 @@ def setup_logging(): tl.current_task().set_error_handler(None, error_handler) -async def batch(*args): +async def batch(*args: Awaitable[Any]): result = None for arg in args: - result = await arg + if arg is not None: + result = await arg return result @@ -200,7 +280,7 @@ class AigModel(tl.process.Process): "hierarchy -simcheck", "flatten", "setundef -undriven -anyseq", - "setattr -set keep 1 w:\*", + "setattr -set keep 1 w:\\*", "delete -output", "opt -full", "techmap", @@ -235,6 +315,7 @@ class MinimizeTrace(tl.Task): full_yw = App.trace_dir_full / self.trace_name min_yw = App.trace_dir_min / self.trace_name + min_x_yw = min_yw.with_suffix(".x.yw") stem = full_yw.stem @@ -266,26 +347,60 @@ class MinimizeTrace(tl.Task): App.cache_dir / "design_aiger.ywa", min_yw, cwd=App.trace_dir_min, + options=("--skip-x", "--present-only"), ) aiw2yw[tl.LogContext].scope = f"aiw2yw[{stem}]" aiw2yw.depends_on(aigcexmin) - if App.sim: - sim = SimTrace( - App.cache_dir / "design_aiger.il", - min_yw, - min_yw.with_suffix(".fst"), - cwd=App.trace_dir_min, - ) + self.aiw2yw_x = aiw2yw_x = YosysWitness( + "aiw2yw", + min_aiw, + App.cache_dir / "design_aiger.ywa", + min_x_yw, + cwd=App.trace_dir_min, + options=("--present-only",), + ) + aiw2yw_x[tl.LogContext].scope = f"aiw2yw[{stem}(x)]" + aiw2yw_x.depends_on(aigcexmin) - sim[tl.LogContext].scope = f"sim[{stem}]" - sim.depends_on(aiw2yw) + sim = SimTrace( + App.cache_dir / "design_aiger.il", + full_yw, + full_yw.with_suffix(".fst"), + cwd=App.trace_dir_full, + script_only=not App.sim, + ) + sim.depends_on(aig_model) + sim[tl.LogContext].scope = f"full-sim[{stem}]" + + sim = SimTrace( + App.cache_dir / "design_aiger.il", + min_x_yw, + min_yw.with_suffix(".fst"), + cwd=App.trace_dir_min, + script_only=not App.sim, + ) + + sim[tl.LogContext].scope = f"sim[{stem}]" + sim.depends_on(aiw2yw_x) def relative_to(target: Path, cwd: Path) -> Path: prefix = Path("") target = target.resolve() cwd = cwd.resolve() + + ok = False + for limit in (Path.cwd(), App.work_dir): + limit = Path.cwd().resolve() + try: + target.relative_to(limit) + ok = True + except ValueError: + pass + if not ok: + return target + while True: try: return prefix / (target.relative_to(cwd)) @@ -299,16 +414,18 @@ def relative_to(target: Path, cwd: Path) -> Path: class YosysWitness(tl.process.Process): def __init__( self, - mode: Literal["yw2aiw"] | Literal["aiw2yw"], + mode: Literal["yw2aiw", "aiw2yw"], input: Path, mapfile: Path, output: Path, cwd: Path, + options: Iterable[str] = (), ): super().__init__( [ "yosys-witness", mode, + *(options or []), str(relative_to(input, cwd)), str(relative_to(mapfile, cwd)), str(relative_to(output, cwd)), @@ -324,9 +441,13 @@ class YosysWitness(tl.process.Process): class AigCexMin(tl.process.Process): def __init__(self, design_aig: Path, input_aiw: Path, output_aiw: Path, cwd: Path): + aigcexmin_args = [] + if App.satisfy_assumes: + aigcexmin_args.append("--satisfy-assumptions") super().__init__( [ "aigcexmin", + *aigcexmin_args, str(relative_to(design_aig, cwd)), str(relative_to(input_aiw, cwd)), str(relative_to(output_aiw, cwd)), @@ -353,8 +474,16 @@ class AigCexMin(tl.process.Process): class SimTrace(tl.process.Process): - def __init__(self, design_il: Path, input_yw: Path, output_fst: Path, cwd: Path): + def __init__( + self, + design_il: Path, + input_yw: Path, + output_fst: Path, + cwd: Path, + script_only: bool = False, + ): self[tl.LogContext].scope = "sim" + self._script_only = script_only script_file = output_fst.with_suffix(".fst.ys") log_file = output_fst.with_suffix(".fst.log") @@ -380,108 +509,899 @@ class SimTrace(tl.process.Process): self.name = "sim" self.log_output() + async def on_run(self): + if self._script_only: + tl.log(f"manually run {self.shell_command} to generate a full trace") + else: + await super().on_run() + + +class Callback(tl.TaskGroup): + recover_from_errors = False + + def __init__(self, enumeration: Enumeration): + super().__init__() + self[tl.LogContext].scope = "callback" + self.search_next = False + self.enumeration = enumeration + self.tmp_counter = 0 + + async def step_callback(self, step: int) -> Literal["advance", "search"]: + with self.as_current_task(): + return await self._callback(step=step) + + async def unsat_callback(self, step: int) -> Literal["advance", "search"]: + with self.as_current_task(): + return await self._callback(step=step, unsat=True) + + async def trace_callback( + self, step: int, path: Path + ) -> Literal["advance", "search"]: + with self.as_current_task(): + return await self._callback(step=step, trace_path=path) + + def exit(self): + pass + + async def _callback( + self, + step: int, + trace_path: Path | None = None, + unsat: bool = False, + ) -> Literal["advance", "search"]: + if not self.is_active(): + if unsat: + return "advance" + return "search" + if trace_path is None and self.search_next: + if unsat: + return "advance" + return "search" + self.search_next = False + + info = dict(step=step, enabled=sorted(self.enumeration.active_assumptions)) + + if trace_path: + self.callback_write( + {**info, "event": "trace", "trace_path": str(trace_path)} + ) + elif unsat: + self.callback_write({**info, "event": "unsat"}) + else: + self.callback_write({**info, "event": "step"}) + + while True: + try: + response: dict[Any, Any] | Any = await self.callback_read() + if not isinstance(response, (Exception, dict)): + raise ValueError( + f"expected JSON object, got: {json.dumps(response)}" + ) + if isinstance(response, Exception): + raise response + + did_something = False + + if "block_yw" in response: + did_something = True + block_yw: Any = response["block_yw"] + + if not isinstance(block_yw, str): + raise ValueError( + "'block_yw' must be a string containing a file path, " + f"got: {json.dumps(block_yw)}" + ) + name: Any = response.get("name") + if name is not None and not isinstance(name, str): + raise ValueError( + "optional 'name' must be a string when present, " + "got: {json.dumps(name)}" + ) + self.enumeration.block_trace(Path(block_yw), name=name) + + if "block_aiw" in response: + did_something = True + block_aiw: Any = response["block_aiw"] + if not isinstance(block_aiw, str): + raise ValueError( + "'block_yw' must be a string containing a file path, " + f"got: {json.dumps(block_aiw)}" + ) + name: Any = response.get("name") + if name is not None and not isinstance(name, str): + raise ValueError( + "optional 'name' must be a string when present, " + "got: {json.dumps(name)}" + ) + + tmpdir = App.work_subdir / "tmp" + tmpdir.mkdir(exist_ok=True) + self.tmp_counter += 1 + block_yw = tmpdir / f"callback_{self.tmp_counter}.yw" + + aiw2yw = YosysWitness( + "aiw2yw", + Path(block_aiw), + App.cache_dir / "design_aiger.ywa", + Path(block_yw), + cwd=tmpdir, + options=("--skip-x", "--present-only"), + ) + aiw2yw[tl.LogContext].scope = f"aiw2yw[callback_{self.tmp_counter}]" + + await aiw2yw.finished + self.enumeration.block_trace(Path(block_yw), name=name) + + if "disable" in response: + did_something = True + name = response["disable"] + if not isinstance(name, str): + raise ValueError( + "'disable' must be a string representing an assumption, " + f"got: {json.dumps(name)}" + ) + self.enumeration.disable_assumption(name) + if "enable" in response: + did_something = True + name = response["enable"] + if not isinstance(name, str): + raise ValueError( + "'disable' must be a string representing an assumption, " + f"got: {json.dumps(name)}" + ) + self.enumeration.enable_assumption(name) + action: Any = response.get("action") + if action == "next": + did_something = True + self.search_next = True + action = "search" + if action in ("search", "advance"): + return action + if not did_something: + raise ValueError( + f"could not interpret callback response: {response}" + ) + except Exception as e: + tl.log_exception(e, raise_error=False) + if not self.recover_from_errors: + raise + + def is_active(self) -> bool: + return False + + def callback_write(self, data: Any): + return + + async def callback_read(self) -> Any: + raise NotImplementedError("must be implemented in Callback subclass") + + +class InteractiveCallback(Callback): + recover_from_errors = True + + interactive_shortcuts = { + "n": '{"action": "next"}', + "s": '{"action": "search"}', + "a": '{"action": "advance"}', + } + + def __init__(self, enumeration: Enumeration): + super().__init__(enumeration) + self.__eof_reached = False + + def is_active(self) -> bool: + return not self.__eof_reached + + def callback_write(self, data: Any): + print(f"callback: {json.dumps(data)}") + + async def callback_read(self) -> Any: + future: asyncio.Future[Any] = asyncio.Future() + + loop = asyncio.get_event_loop() + + def blocking_read(): + try: + try: + result = "" + while not result: + result = input( + "callback> " if sys.stdout.isatty() else "" + ).strip() + result = self.interactive_shortcuts.get(result, result) + result_data = json.loads(result) + except EOFError: + print() + self.__eof_reached = True + result_data = dict(action="next") + loop.call_soon_threadsafe(lambda: future.set_result(result_data)) + except Exception as exc: + exception = exc + loop.call_soon_threadsafe(lambda: future.set_exception(exception)) + + thread = threading.Thread(target=blocking_read, daemon=True) + thread.start() + + return await future + + +class ProcessCallback(Callback): + def __init__(self, enumeration: Enumeration, command: list[str]): + super().__init__(enumeration) + self[tl.LogContext].scope = "callback" + self.__eof_reached = False + self._command = command + self._lines: list[asyncio.Future[str]] = [asyncio.Future()] + + def exit(self): + self.process.close_stdin() + + async def on_prepare(self) -> None: + self.process = tl.Process(self._command, cwd=Path.cwd(), interact=True) + self.process.use_lease = False + + future_line: asyncio.Future[str] = self._lines[-1] + + def stdout_handler(event: tl.process.StdoutEvent): + nonlocal future_line + future_line.set_result(event.output) + future_line = asyncio.Future() + self._lines.append(future_line) + + self.process.sync_handle_events(tl.process.StdoutEvent, stdout_handler) + + def stderr_handler(event: tl.process.StderrEvent): + tl.log(event.output) + + self.process.sync_handle_events(tl.process.StderrEvent, stderr_handler) + + def exit_handler(event: tl.process.ExitEvent): + future_line.set_exception(EOFError("callback process exited")) + # Mark the exception as retrieved + future_line.exception() + + self.process.sync_handle_events(tl.process.ExitEvent, exit_handler) + + def is_active(self) -> bool: + return not self.__eof_reached + + def callback_write(self, data: Any): + if not self.process.is_finished: + self.process.write(json.dumps(data) + "\n") + + async def callback_read(self) -> Any: + future_line = self._lines.pop(0) + try: + data = json.loads(await future_line) + tl.log_debug(f"callback action: {data}") + return data + except EOFError: + self._lines.insert(0, future_line) + self.__eof_reached = True + return dict(action="next") + + +def enter_task(fn): + # TODO move into mau + import functools + + @functools.wraps(fn) + def wrapped_fn(self, *args, **kwargs): + with self.as_current_task(): + return fn(self, *args, **kwargs) + + return wrapped_fn + + +def inv_signal(sig): + return (*sig[:-1], "0" if sig[-1] != "0" else "1") + + +class Diversifier(tl.TaskGroup): + def __init__(self, seed, enumeration_depth): + super().__init__() + self._rng = random.Random(seed) + self._failed: StableSet[Any] = StableSet() + self._failed_sets: dict[Any, list[StableSet[Any]]] = defaultdict(list) + self._active: StableSet[Any] = StableSet() + + self._current_id = 0 + + self._targets: dict[int, StableSet[Any]] = {} + + self._sig_index: dict[Any, StableSet[int]] = defaultdict(StableSet) + + self._depth_limit = 0 + + self._enumeration_depth_left = enumeration_depth + + self._unsat_counter = 0 + + self[tl.LogContext].scope = "diversifier" + # self[tl.LogContext].level = "debug" + + def _new_id(self): + self._current_id += 1 + return self._current_id + + @enter_task + def new_step(self): + self._failed = StableSet() + self._failed_sets = defaultdict(list) + self._active = StableSet() + self._depth_limit = 0 + self._unsat_counter = 0 + + @enter_task + def diversify(self): + if self._enumeration_depth_left is not None: + if self._enumeration_depth_left <= 0: + return StableSet() + + selected = StableSet() + blocked = StableSet() + + pending = list(self._targets.values()) + + tl.log_debug(f"{len(pending)=} initial") + + depth = 0 + + deferred = [] + while pending or (deferred and depth < self._depth_limit): + if not pending: + for subspace in deferred: + new_subspace = StableSet( + sig for sig in subspace if sig[:-1] not in blocked + ) + if new_subspace: + pending.append(new_subspace) + deferred = [] + depth += 1 + continue + + counter = Counter() + + for subspace in self._shuffle_list(pending): + counter.update(self._shuffle_list(subspace)) + + sig, sig_count = max(counter.items(), key=lambda item: item[1]) + + tl.log_debug(f"{sig=} {sig_count=}") + + selected.add(sig) + blocked.add(sig[:-1]) + + failed = False + + for candidate in self._failed_sets[sig]: + if candidate.issubset(selected): + failed = True + tl.log_debug(f"failed {candidate=}") + break + + if not failed: + for candidate in self._failed_sets[sig]: + target = next(s for s in candidate if s not in selected) + self._failed_sets[target].append(candidate) + + self._failed_sets.pop(sig, None) + + if failed: + selected.remove(sig) + sig = None + + new_pending = [] + + for subspace in pending: + new_subspace = StableSet( + sig for sig in subspace if sig[:-1] not in blocked + ) + if not new_subspace: + continue + if sig in subspace: + deferred.append(new_subspace) + else: + new_pending.append(new_subspace) + pending = new_pending + + tl.log_debug(f"{len(pending)=} {selected=}") + + return selected + + def _add_target(self, signals): + signals.difference_update(self._failed) + if signals: + id = self._new_id() + self._targets[id] = signals + for sig in signals: + self._sig_index[sig].add(id) + + def _remove_target(self, id): + signals = self._targets.pop(id) + for sig in signals: + self._sig_index[sig].remove(id) + + def _remove_target_sig(self, id, sig): + self._targets[id].remove(sig) + self._sig_index[sig].remove(id) + + @enter_task + def update(self, signals): + if self._enumeration_depth_left is not None: + if self._enumeration_depth_left > 0: + self._enumeration_depth_left -= 1 + if self._enumeration_depth_left == 0: + tl.log( + "reached diversification depth, " + "disabling diversification for future traces" + ) + self._unsat_counter = 0 + inv_signals = StableSet(map(inv_signal, signals)) + self._add_target(inv_signals) + self._depth_limit += 1 + tl.log_debug(f"increased {self._depth_limit=}") + + @enter_task + def failed(self, failed): + self._unsat_counter += 1 + self._depth_limit = max(0, self._depth_limit - 1) + tl.log_debug(f"decreased {self._depth_limit=}") + + failed_sig = next(iter(failed)) + self._failed_sets[failed_sig].append(failed) + + if len(failed) == 1: + self._failed.add(failed_sig) + + for id in list(self._sig_index[failed_sig]): + self._remove_target_sig(id, failed_sig) + + tl.log(f"found {len(self._failed)} settled inputs") + + def should_minimize(self): + return self._unsat_counter > 2 + + def _shuffle_set(self, items): + return StableSet(self._shuffle_list(items)) + + def _shuffle_list(self, items): + as_list = list(items) + self._rng.shuffle(as_list) + return as_list + + +def yw_to_dicts(yw_trace): + """Convert a .yw JSON object to a list of dicts. + + Every item of the list corresponds to one step. For every step it contains a dict + mapping (*path, index) pairs to the bit value for bits that are present. + """ + + signal_bits = [] + for signal in yw_trace["signals"]: + chunk_offset = signal["offset"] + path = tuple(signal["path"]) + for i in range(signal["width"]): + signal_bits.append((*path, chunk_offset + i)) + + return [ + { + signal: value + for signal, value in zip(signal_bits, step["bits"][::-1]) + if value in "01" + } + for step in yw_trace["steps"] + ] + class Enumeration(tl.Task): + callback_mode: Literal["off", "interactive", "process"] + callback_auto_search: bool = False + def __init__(self, aig_model: tl.Task): self.aig_model = aig_model + self._pending_blocks: list[tuple[str | None, Path, StableSet[Any] | None]] = [] + self.named_assumptions: StableSet[str] = StableSet() + self.active_assumptions: StableSet[str] = StableSet() + + self._probes: dict[Any, str] = {} + self._probe_defs: dict[str, Any] = {} + if App.diversify: + self._diversifier = Diversifier(App.diversify_seed, App.diversify_depth) + + self._last_id: int = 0 + super().__init__() + def _new_id(self) -> int: + self._last_id += 1 + return self._last_id + + async def on_prepare(self) -> None: + if App.callback: + if App.callback == ["-"]: + self.callback_task = InteractiveCallback(self) + else: + self.callback_task = ProcessCallback(self, App.callback) + else: + self.callback_task = Callback(self) + async def on_run(self) -> None: - smtbmc = Smtbmc(App.work_dir / "model" / "design_smt2.smt2") + self.smtbmc = smtbmc = Smtbmc(App.work_dir / "model" / "design_smt2.smt2") + self._push_level = 0 await smtbmc.ping() - pred = None + try: + await self._bmc_loop() + finally: + smtbmc.close_stdin() + self.callback_task.exit() - i = 0 + async def _bmc_loop(self) -> None: + smtbmc = self.smtbmc + i = -1 limit = App.depth first_failure = None - while i <= limit: - tl.log(f"Checking assumptions in step {i}..") - presat_checked = await batch( - smtbmc.bmc_step(i, initial=i == 0, assertions=None, pred=pred), - smtbmc.check(), - ) - if presat_checked != "sat": - if first_failure is None: - tl.log_error("Assumptions are not satisfiable") - else: - tl.log("No further counter-examples are reachable") - return + checked = "skip" - tl.log(f"Checking assertions in step {i}..") - checked = await batch( - smtbmc.push(), - smtbmc.assertions(i, False), - smtbmc.check(), - ) - pred = i - if checked != "unsat": + counter = 0 + + while i <= limit or limit < 0: + if checked != "skip": + checked = await self._search_counter_example(i) + + if checked == "unsat": + if i >= 0: + action = await self.callback_task.unsat_callback(i) + if action == "search": + continue + checked = "skip" + if checked == "skip": + checked = "unsat" + i += 1 + if App.diversify: + self._diversifier.new_step() + if i > limit and limit >= 0: + break + action = await self.callback_task.step_callback(i) + + pending = batch( + self._top(), + smtbmc.bmc_step( + i, + initial=i == 0, + assertions=False, + pred=i - 1 if i else None, + ), + ) + + if action == "advance": + tl.log(f"Skipping step {i}") + await batch(pending, smtbmc.assertions(i)) + checked = "skip" + continue + assert action == "search" + + tl.log(f"Checking assumptions in step {i}..") + presat_checked = await batch( + pending, + smtbmc.check(), + ) + if presat_checked != "sat": + + if App.track_assumes: + tl.log("No further counter-examples are reachable") + + tl.log("Conflicting assumptions:") + unsat_assumptions = await smtbmc.get_unsat_assumptions( + App.minimize_assumes + ) + + for key in unsat_assumptions: + desc = await smtbmc.incremental_command( + cmd="get_design_assume", key=key + ) + if desc is None: + tl.log(f" Callback blocked trace {key[1]!r}") + continue + expr, path, info = desc + tl.log(f" In {path}: {info}") + + if not self.active_assumptions: + return + if first_failure is None and not self.active_assumptions: + tl.log_error("Assumptions are not satisfiable") + else: + tl.log("No further counter-examples are reachable") + if not self.active_assumptions: + return + + tl.log(f"Checking assertions in step {i}..") + counter = 0 + continue + elif checked == "sat": if first_failure is None: first_failure = i - limit = i + App.enum_depth + if App.enum_depth < 0: + limit = -1 + else: + limit = i + App.enum_depth tl.log("BMC failed! Enumerating counter-examples..") - counter = 0 - assert checked == "sat" path = App.trace_dir_full / f"trace{i}_{counter}.yw" + await smtbmc.incremental_command(cmd="write_yw_trace", path=str(path)) + tl.log(f"Written counter-example to {path.name}") - while checked == "sat": - await smtbmc.incremental_command( - cmd="write_yw_trace", path=str(path) + minimize = MinimizeTrace(path.name, self.aig_model) + minimize.depends_on(self.aig_model) + + await minimize.aiw2yw.finished + + min_path = App.trace_dir_min / f"trace{i}_{counter}.yw" + + action = await self.callback_task.trace_callback(i, min_path) + + if action == "search" and counter + 1 == App.max_per_step: + tl.log(f"Reached trace limit {App.max_per_step} for this step") + action = "advance" + + if action == "advance": + tl.log("Skipping remaining counter-examples for this step") + checked = "skip" + continue + assert action == "search" + + self.block_trace(min_path, diversify=App.diversify) + + counter += 1 + else: + tl.log_error(f"Unexpected solver result: {checked!r}") + + def block_trace(self, path: Path, name: str | None = None, diversify: bool = False): + if name is not None: + if name in self.named_assumptions: + raise ValueError(f"an assumption with name {name} was already defined") + self.named_assumptions.add(name) + self.active_assumptions.add(name) + + self._pending_blocks.append((name, path.absolute(), diversify)) + + def enable_assumption(self, name: str): + if name not in self.named_assumptions: + raise ValueError(f"unknown assumption {name!r}") + self.active_assumptions.add(name) + + def disable_assumption(self, name: str): + if name not in self.named_assumptions: + raise ValueError(f"unknown assumption {name!r}") + self.active_assumptions.discard(name) + self._diversifier.new_step() + + def _top(self) -> Awaitable[Any]: + return batch(*(self._pop() for _ in range(self._push_level))) + + def _pop(self) -> Awaitable[Any]: + self._push_level -= 1 + tl.log_debug(f"pop to {self._push_level}") + return self.smtbmc.pop() + + def _push(self) -> Awaitable[Any]: + self._push_level += 1 + tl.log_debug(f"push to {self._push_level}") + return self.smtbmc.push() + + async def _search_counter_example(self, step: int) -> str: + smtbmc = self.smtbmc + pending_blocks, self._pending_blocks = self._pending_blocks, [] + + pending = self._top() + + add_to_pending = [] + + for name, block_path, diversify in pending_blocks: + result = smtbmc.incremental_command( + cmd="read_yw_trace", + name="last", + path=str(block_path), + skip_x=True, + ) + + if diversify: + + signals = StableSet( + (s, step_signal, value) + for s, step_signals in enumerate( + yw_to_dicts(json.loads(block_path.read_text())) ) - tl.log(f"Written counter-example to {path.name}") + for step_signal, value in step_signals.items() + ) - minimize = MinimizeTrace(path.name, self.aig_model) - minimize.depends_on(self.aig_model) + self._diversifier.update(signals) - await minimize.aiw2yw.finished - - min_path = App.trace_dir_min / f"trace{i}_{counter}.yw" - - checked = await batch( - smtbmc.incremental_command( - cmd="read_yw_trace", - name="last", - path=str(min_path), - skip_x=True, - ), - smtbmc.assert_( - ["not", ["and", *(["yw", "last", k] for k in range(i + 1))]] - ), - smtbmc.check(), + async def check_yw_trace_len(): + last_step = (await result).get("last_step", step) + if last_step > step: + tl.log_warning( + f"Ignoring future time steps " + f"{step + 1} to {last_step} of " + f"{relative_to(block_path, Path.cwd())}" ) + return last_step - counter += 1 - path = App.trace_dir_full / f"trace{i}_{counter}.yw" + expr = [ + "not", + ["and", *(["yw", "last", k] for k in range(step + 1))], + ] - await batch(smtbmc.pop(), smtbmc.assertions(i)) + if name is not None: + name_id = safe_smtlib_id(f"cexenum trace {name}") + add_to_pending.append(smtbmc.smtlib(f"(declare-const {name_id} Bool)")) + expr = ["=", expr, ["smtlib", name_id, "Bool"]] - i += 1 + add_to_pending.append(check_yw_trace_len()) + add_to_pending.append(smtbmc.assert_(expr)) - smtbmc.close_stdin() + active_diversifications = StableSet() + + try: + while True: + diversifications = StableSet() + if App.diversify: + + diversifications = StableSet(self._diversifier.diversify()) + + for signal_def in active_diversifications.difference( + diversifications + ): + add_to_pending.append( + smtbmc.incremental_command( + cmd="update_assumptions", + key=("cexenum probe", self._probes[signal_def]), + expr=None, + ) + ) + + new_probes = False + + for signal_def in diversifications: + if signal_def in self._probes: + continue + + sig_step, (*path, offset), value = signal_def + probe_result = smtbmc.incremental_command( + cmd="define", + expr=[ + "=", + ["yw_sig", sig_step, path, offset], + ["bv", value], + ], + ) + + async def store_probe(signal_def, probe_result): + defined = await probe_result + self._probes[signal_def] = defined["name"] + self._probe_defs[defined["name"]] = signal_def + + add_to_pending.append(probe_result) + add_to_pending.append(store_probe(signal_def, probe_result)) + new_probes = True + + if new_probes: + await batch(pending, *add_to_pending) + pending, add_to_pending = None, [] + + for signal_def in diversifications.difference( + active_diversifications + ): + add_to_pending.append( + smtbmc.incremental_command( + cmd="update_assumptions", + key=("cexenum probe", self._probes[signal_def]), + expr=["def", self._probes[signal_def]], + ) + ) + + active_diversifications = diversifications + + add_to_pending.append(self._push()) + add_to_pending.append(smtbmc.assertions(step, False)) + + for name in self.active_assumptions: + name_id = safe_smtlib_id(f"cexenum trace {name}") + expr = ["smtlib", name_id, "Bool"] + if App.track_assumes: + add_to_pending.append( + smtbmc.incremental_command( + cmd="update_assumptions", + key=("cexenum trace", name), + expr=expr, + ) + ) + else: + add_to_pending.append(smtbmc.assert_(expr)) + + pending = batch(pending, *add_to_pending) + + result = await batch( + pending, + smtbmc.check(), + ) + pending, add_to_pending = None, [] + + if result == "sat" or not active_diversifications: + return result + + failed_assumptions = StableSet( + map( + tuple, + await smtbmc.get_unsat_assumptions( + minimize=self._diversifier.should_minimize() + ), + ) + ) + + failed_diversifications = StableSet() + + for key, expr in failed_assumptions: + if key == "cexenum probe": + failed_diversifications.add(self._probe_defs[expr]) + + if failed_diversifications: + self._diversifier.failed(failed_diversifications) + else: + return result + + add_to_pending.append(self._top()) + finally: + for signal_def in active_diversifications: + add_to_pending.append( + smtbmc.incremental_command( + cmd="update_assumptions", + key=("cexenum probe", self._probes[signal_def]), + expr=None, + ) + ) + + await batch(pending, *add_to_pending) class Smtbmc(tl.process.Process): def __init__(self, smt2_model: Path): self[tl.LogContext].scope = "smtbmc" + + smtbmc_options = [] + if App.track_assumes: + smtbmc_options += ["--track-assumes"] + if App.minimize_assumes: + smtbmc_options += ["--minimize-assumes"] + if App.diversify: + smtbmc_options += ["--smt2-option", ":produce-unsat-assumptions=true"] + + smtbmc_options += App.smtbmc_options super().__init__( [ "yosys-smtbmc", "--incremental", - *App.smtbmc_options, + "--noprogress", + *smtbmc_options, str(smt2_model), ], interact=True, ) self.name = "smtbmc" - self.expected_results = [] + self.expected_results: list[asyncio.Future[Any]] = [] async def on_run(self) -> None: - def output_handler(event: tl.process.StderrEvent): - result = json.loads(event.output) + def output_handler(event: tl.process.StdoutEvent): + line = event.output.strip() + if line.startswith("{"): + result = json.loads(event.output) + else: + result = dict(msg=line) tl.log_debug(f"smtbmc > {result!r}") if "err" in result: exception = tl.logging.LoggedError( @@ -501,11 +1421,11 @@ class Smtbmc(tl.process.Process): def ping(self) -> Awaitable[None]: return self.incremental_command(cmd="ping") - def incremental_command(self, **command: dict[Any]) -> Awaitable[Any]: + def incremental_command(self, **command: Any) -> Awaitable[Any]: tl.log_debug(f"smtbmc < {command!r}") self.write(json.dumps(command)) self.write("\n") - result = asyncio.Future() + result: asyncio.Future[Any] = asyncio.Future() self.expected_results.append(result) return result @@ -522,6 +1442,11 @@ class Smtbmc(tl.process.Process): def check(self) -> Awaitable[str]: return self.incremental_command(cmd="check") + def smtlib(self, command: str, response=False) -> Awaitable[str]: + return self.incremental_command( + cmd="smtlib", command=command, response=response + ) + def assert_antecedent(self, expr: Any) -> Awaitable[None]: return self.incremental_command(cmd="assert_antecedent", expr=expr) @@ -534,7 +1459,16 @@ class Smtbmc(tl.process.Process): def hierarchy(self, step: int) -> Awaitable[None]: return self.assert_antecedent(["mod_h", ["step", step]]) + def assert_design_assumes(self, step: int) -> Awaitable[None]: + return self.incremental_command(cmd="assert_design_assumes", step=step) + + def get_unsat_assumptions(self, minimize: bool) -> Awaitable[list[Any]]: + return self.incremental_command(cmd="get_unsat_assumptions", minimize=minimize) + def assumptions(self, step: int, valid: bool = True) -> Awaitable[None]: + if valid: + return self.assert_design_assumes(step) + expr = ["mod_u", ["step", step]] if not valid: expr = ["not", expr] @@ -562,20 +1496,22 @@ class Smtbmc(tl.process.Process): self, step: int, initial: bool = False, - assertions: bool | None = True, + assertions: bool = True, pred: int | None = None, + assumptions: bool = True, ) -> Awaitable[None]: - futures = [] + futures: list[Awaitable[None]] = [] futures.append(self.new_step(step)) futures.append(self.hierarchy(step)) - futures.append(self.assumptions(step)) + if assumptions: + futures.append(self.assumptions(step)) futures.append(self.initial(step, initial)) if pred is not None: futures.append(self.transition(pred, step)) - if assertions is not None: - futures.append(self.assertions(assertions)) + if assertions: + futures.append(self.assertions(step)) return batch(*futures)