add ExternModuleSimulationState::fork_join_scope

This commit is contained in:
Jacob Lifshay 2025-11-07 02:18:43 -08:00
parent fbc8ffa5ae
commit 45fea70c18
Signed by: programmerjake
SSH key fingerprint: SHA256:HnFTLGpSm4Q4Fj502oCFisjZSoakwEuTsJJMSke63RQ
4 changed files with 2325 additions and 1 deletions

View file

@ -49,7 +49,7 @@ use std::{
any::Any,
borrow::Cow,
cell::{Cell, RefCell},
collections::BTreeMap,
collections::{BTreeMap, BTreeSet},
fmt,
future::{Future, IntoFuture},
hash::Hash,
@ -3347,6 +3347,128 @@ impl ExternModuleSimulationState {
module_index,
}
}
pub async fn fork_join_scope<'env, F, Fut>(
&mut self,
in_scope: F,
) -> <Fut::IntoFuture as Future>::Output
where
F: FnOnce(ForkJoinScope<'env>, ExternModuleSimulationState) -> Fut,
Fut: IntoFuture<IntoFuture: 'env + Future<Output: 'env>>,
{
let scope = ForkJoinScope {
new_tasks: Rc::new(RefCell::new(vec![])),
sim: self.forked_state(),
};
let join_handle = scope.spawn(in_scope);
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
struct TaskId(u64);
struct TasksStateInner {
next_task_id: u64,
ready_tasks: BTreeSet<TaskId>,
not_ready_tasks: BTreeSet<TaskId>,
base_waker: std::task::Waker,
}
impl Default for TasksStateInner {
fn default() -> Self {
Self {
next_task_id: Default::default(),
ready_tasks: Default::default(),
not_ready_tasks: Default::default(),
base_waker: std::task::Waker::noop().clone(),
}
}
}
#[derive(Default)]
struct TasksState {
inner: Mutex<TasksStateInner>,
}
#[derive(Clone)]
struct TaskWaker {
state: std::sync::Weak<TasksState>,
task: TaskId,
}
impl std::task::Wake for TaskWaker {
fn wake(self: Arc<Self>) {
self.wake_by_ref();
}
fn wake_by_ref(self: &Arc<Self>) {
let Some(state) = self.state.upgrade() else {
return;
};
let mut inner = state.inner.lock().expect("not poisoned");
if inner.not_ready_tasks.remove(&self.task) {
inner.ready_tasks.insert(self.task);
inner.base_waker.wake_by_ref();
}
}
}
struct Task<'env> {
task: Pin<Box<dyn Future<Output = ()> + 'env>>,
waker: std::task::Waker,
}
let mut tasks: BTreeMap<TaskId, Task> = BTreeMap::new();
let tasks_state = Arc::new(TasksState::default());
std::future::poll_fn(move |cx: &mut std::task::Context<'_>| {
let mut state_inner = tasks_state.inner.lock().expect("not poisoned");
state_inner.base_waker.clone_from(cx.waker());
loop {
for new_task in scope.new_tasks.borrow_mut().drain(..) {
let task_id = TaskId(state_inner.next_task_id);
let Some(next_task_id) = state_inner.next_task_id.checked_add(1) else {
drop(state_inner);
panic!("spawned too many tasks");
};
state_inner.next_task_id = next_task_id;
state_inner.ready_tasks.insert(task_id);
tasks.insert(
task_id,
Task {
task: new_task,
waker: Arc::new(TaskWaker {
state: Arc::downgrade(&tasks_state),
task: task_id,
})
.into(),
},
);
}
let Some(task_id) = state_inner.ready_tasks.pop_first() else {
if state_inner.not_ready_tasks.is_empty() {
return Poll::Ready(());
} else {
return Poll::Pending;
};
};
state_inner.not_ready_tasks.insert(task_id); // task can be woken while we're running poll
drop(state_inner);
let std::collections::btree_map::Entry::Occupied(mut entry) = tasks.entry(task_id)
else {
unreachable!();
};
let Task { task, waker } = entry.get_mut();
match task.as_mut().poll(&mut std::task::Context::from_waker(
&std::task::Waker::from(waker.clone()),
)) {
Poll::Pending => {
state_inner = tasks_state.inner.lock().expect("not poisoned");
continue;
}
Poll::Ready(()) => {}
}
drop(entry.remove()); // drop outside lock
state_inner = tasks_state.inner.lock().expect("not poisoned");
state_inner.not_ready_tasks.remove(&task_id);
state_inner.ready_tasks.remove(&task_id);
}
})
.await;
match &mut *join_handle.state.borrow_mut() {
JoinHandleState::Running(_) => unreachable!(),
JoinHandleState::Finished(state) => state
.take()
.expect("filled by running all futures to completion"),
}
}
impl_simulation_methods!(
async_await = (async, await),
track_caller = (),
@ -3354,6 +3476,125 @@ impl ExternModuleSimulationState {
);
}
pub struct ForkJoinScope<'env> {
new_tasks: Rc<RefCell<Vec<Pin<Box<dyn Future<Output = ()> + 'env>>>>>,
sim: ExternModuleSimulationState,
}
impl<'env> Clone for ForkJoinScope<'env> {
fn clone(&self) -> Self {
Self {
new_tasks: self.new_tasks.clone(),
sim: self.sim.forked_state(),
}
}
}
impl<'env> ForkJoinScope<'env> {
fn spawn_inner(&self, fut: Pin<Box<dyn Future<Output = ()> + 'env>>) {
self.new_tasks.borrow_mut().push(fut);
}
pub fn spawn_detached_future(
&self,
fut: impl IntoFuture<IntoFuture: 'env + Future<Output = ()>>,
) {
self.spawn_inner(Box::pin(fut.into_future()));
}
pub fn spawn_detached<F, Fut>(&self, f: F)
where
F: FnOnce(ForkJoinScope<'env>, ExternModuleSimulationState) -> Fut,
Fut: IntoFuture<IntoFuture: 'env + Future<Output = ()>>,
{
self.spawn_detached_future(f(self.clone(), self.sim.forked_state()));
}
pub fn spawn<F, Fut>(&self, f: F) -> JoinHandle<Fut::Output>
where
F: FnOnce(ForkJoinScope<'env>, ExternModuleSimulationState) -> Fut,
Fut: IntoFuture<IntoFuture: 'env + Future<Output: 'env>>,
{
let join_handle = JoinHandle {
state: Default::default(),
};
let state = Rc::downgrade(&join_handle.state);
let fut = f(self.clone(), self.sim.forked_state()).into_future();
self.spawn_detached_future(async move {
let result = fut.await;
let Some(state) = state.upgrade() else { return };
let mut state = state.borrow_mut();
let waker = match &mut *state {
JoinHandleState::Running(waker) => waker.take(),
JoinHandleState::Finished(_) => unreachable!(),
};
*state = JoinHandleState::Finished(Some(result));
drop(state);
let Some(waker) = waker else { return };
waker.wake();
});
join_handle
}
}
enum JoinHandleState<T> {
Running(Option<std::task::Waker>),
Finished(Option<T>),
}
impl<T> Default for JoinHandleState<T> {
fn default() -> Self {
Self::Running(None)
}
}
pub struct JoinHandle<T> {
state: Rc<RefCell<JoinHandleState<T>>>,
}
impl<T> JoinHandle<T> {
pub fn is_finished(&self) -> bool {
matches!(*self.state.borrow(), JoinHandleState::Finished(_))
}
pub fn try_join(self) -> Result<T, Self> {
let mut state = self.state.borrow_mut();
match &mut *state {
JoinHandleState::Running(_) => {
drop(state);
Err(self)
}
JoinHandleState::Finished(retval) => {
let Some(retval) = retval.take() else {
panic!("already returned the value in poll");
};
Ok(retval)
}
}
}
pub async fn join(self) -> T {
self.await
}
}
impl<T> Future for JoinHandle<T> {
type Output = T;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
match &mut *self.state.borrow_mut() {
JoinHandleState::Running(waker) => {
match waker {
None => *waker = Some(cx.waker().clone()),
Some(waker) => waker.clone_from(cx.waker()),
}
Poll::Pending
}
JoinHandleState::Finished(retval) => {
let Some(retval) = retval.take() else {
panic!("already returned Poll::Ready");
};
Poll::Ready(retval)
}
}
}
}
struct ForkJoinImpl<'a> {
futures: Vec<Pin<Box<dyn Future<Output = ()> + 'a>>>,
}

View file

@ -2109,6 +2109,99 @@ fn test_sim_fork_join() {
}
}
#[hdl_module(outline_generated, extern)]
pub fn sim_fork_join_scope<const N: usize>()
where
ConstUsize<N>: KnownSize,
{
#[hdl]
let clocks: Array<Clock, N> = m.input();
#[hdl]
let outputs: Array<UInt<8>, N> = m.output();
m.extern_module_simulation_fn((clocks, outputs), |(clocks, outputs), mut sim| async move {
sim.write(outputs, [0u8; N]).await;
loop {
let written = vec![std::cell::Cell::new(false); N]; // test shared scope
let written = &written; // work around move in async move
sim.fork_join_scope(|scope, _| async move {
let mut spawned = vec![];
for i in 0..N {
let join_handle =
scope.spawn(move |_, mut sim: ExternModuleSimulationState| async move {
sim.wait_for_clock_edge(clocks[i]).await;
let v = sim
.read_bool_or_int(outputs[i])
.await
.to_bigint()
.try_into()
.expect("known to be in range");
sim.write(outputs[i], 1u8.wrapping_add(v)).await;
written[i].set(true);
i
});
if i % 2 == 0 && i < N - 1 {
spawned.push((i, join_handle));
}
}
for (i, join_handle) in spawned {
assert_eq!(i, join_handle.join().await);
}
})
.await;
for written in written {
assert!(written.get());
}
}
});
}
#[test]
fn test_sim_fork_join_scope() {
let _n = SourceLocation::normalize_files_for_tests();
const N: usize = 3;
let mut sim = Simulation::new(sim_fork_join_scope::<N>());
let mut writer = RcWriter::default();
sim.add_trace_writer(VcdWriterDecls::new(writer.clone()));
sim.write(sim.io().clocks, [false; N]);
let mut clocks_triggered = [false; N];
let mut expected = [0u8; N];
for i0 in 0..N {
for i1 in 0..N {
for i2 in 0..N {
for i3 in 0..N {
let indexes = [i0, i1, i2, i3];
for i in indexes {
sim.advance_time(SimDuration::from_micros(1));
sim.write(sim.io().clocks[i], true);
sim.advance_time(SimDuration::from_micros(1));
sim.write(sim.io().clocks[i], false);
if !clocks_triggered[i] {
expected[i] = expected[i].wrapping_add(1);
}
clocks_triggered[i] = true;
if clocks_triggered == [true; N] {
clocks_triggered = [false; N];
}
let output = sim.read(sim.io().outputs);
assert_eq!(output, expected.to_sim_value(), "indexes={indexes:?} i={i}");
}
}
}
}
}
sim.flush_traces().unwrap();
let vcd = String::from_utf8(writer.take()).unwrap();
println!("####### VCD:\n{vcd}\n#######");
if vcd != include_str!("sim/expected/sim_fork_join_scope.vcd") {
panic!();
}
let sim_debug = format!("{sim:#?}");
println!("#######\n{sim_debug}\n#######");
if sim_debug != include_str!("sim/expected/sim_fork_join_scope.txt") {
panic!();
}
}
#[hdl_module(outline_generated, extern)]
pub fn sim_resettable_counter<R: ResetType>() {
#[hdl]

View file

@ -0,0 +1,523 @@
Simulation {
state: State {
insns: Insns {
state_layout: StateLayout {
ty: TypeLayout {
small_slots: StatePartLayout<SmallSlots> {
len: 0,
debug_data: [],
..
},
big_slots: StatePartLayout<BigSlots> {
len: 6,
debug_data: [
SlotDebugData {
name: "InstantiatedModule(sim_fork_join_scope: sim_fork_join_scope).sim_fork_join_scope::clocks[0]",
ty: Clock,
},
SlotDebugData {
name: "InstantiatedModule(sim_fork_join_scope: sim_fork_join_scope).sim_fork_join_scope::clocks[1]",
ty: Clock,
},
SlotDebugData {
name: "InstantiatedModule(sim_fork_join_scope: sim_fork_join_scope).sim_fork_join_scope::clocks[2]",
ty: Clock,
},
SlotDebugData {
name: "InstantiatedModule(sim_fork_join_scope: sim_fork_join_scope).sim_fork_join_scope::outputs[0]",
ty: UInt<8>,
},
SlotDebugData {
name: "InstantiatedModule(sim_fork_join_scope: sim_fork_join_scope).sim_fork_join_scope::outputs[1]",
ty: UInt<8>,
},
SlotDebugData {
name: "InstantiatedModule(sim_fork_join_scope: sim_fork_join_scope).sim_fork_join_scope::outputs[2]",
ty: UInt<8>,
},
],
..
},
sim_only_slots: StatePartLayout<SimOnlySlots> {
len: 0,
debug_data: [],
layout_data: [],
..
},
},
memories: StatePartLayout<Memories> {
len: 0,
debug_data: [],
layout_data: [],
..
},
},
insns: [
// at: module-XXXXXXXXXX.rs:1:1
0: Return,
],
..
},
pc: 0,
memory_write_log: [],
memories: StatePart {
value: [],
},
small_slots: StatePart {
value: [],
},
big_slots: StatePart {
value: [
0,
0,
0,
49,
50,
50,
],
},
sim_only_slots: StatePart {
value: [],
},
},
io: Instance {
name: <simulator>::sim_fork_join_scope,
instantiated: Module {
name: sim_fork_join_scope,
..
},
},
main_module: SimulationModuleState {
base_targets: [
Instance {
name: <simulator>::sim_fork_join_scope,
instantiated: Module {
name: sim_fork_join_scope,
..
},
}.clocks,
Instance {
name: <simulator>::sim_fork_join_scope,
instantiated: Module {
name: sim_fork_join_scope,
..
},
}.outputs,
],
uninitialized_ios: {},
io_targets: {
Instance {
name: <simulator>::sim_fork_join_scope,
instantiated: Module {
name: sim_fork_join_scope,
..
},
}.clocks,
Instance {
name: <simulator>::sim_fork_join_scope,
instantiated: Module {
name: sim_fork_join_scope,
..
},
}.clocks[0],
Instance {
name: <simulator>::sim_fork_join_scope,
instantiated: Module {
name: sim_fork_join_scope,
..
},
}.clocks[1],
Instance {
name: <simulator>::sim_fork_join_scope,
instantiated: Module {
name: sim_fork_join_scope,
..
},
}.clocks[2],
Instance {
name: <simulator>::sim_fork_join_scope,
instantiated: Module {
name: sim_fork_join_scope,
..
},
}.outputs,
Instance {
name: <simulator>::sim_fork_join_scope,
instantiated: Module {
name: sim_fork_join_scope,
..
},
}.outputs[0],
Instance {
name: <simulator>::sim_fork_join_scope,
instantiated: Module {
name: sim_fork_join_scope,
..
},
}.outputs[1],
Instance {
name: <simulator>::sim_fork_join_scope,
instantiated: Module {
name: sim_fork_join_scope,
..
},
}.outputs[2],
},
did_initial_settle: true,
},
extern_modules: [
SimulationExternModuleState {
module_state: SimulationModuleState {
base_targets: [
ModuleIO {
name: sim_fork_join_scope::clocks,
is_input: true,
ty: Array<Clock, 3>,
..
},
ModuleIO {
name: sim_fork_join_scope::outputs,
is_input: false,
ty: Array<UInt<8>, 3>,
..
},
],
uninitialized_ios: {},
io_targets: {
ModuleIO {
name: sim_fork_join_scope::clocks,
is_input: true,
ty: Array<Clock, 3>,
..
},
ModuleIO {
name: sim_fork_join_scope::clocks,
is_input: true,
ty: Array<Clock, 3>,
..
}[0],
ModuleIO {
name: sim_fork_join_scope::clocks,
is_input: true,
ty: Array<Clock, 3>,
..
}[1],
ModuleIO {
name: sim_fork_join_scope::clocks,
is_input: true,
ty: Array<Clock, 3>,
..
}[2],
ModuleIO {
name: sim_fork_join_scope::outputs,
is_input: false,
ty: Array<UInt<8>, 3>,
..
},
ModuleIO {
name: sim_fork_join_scope::outputs,
is_input: false,
ty: Array<UInt<8>, 3>,
..
}[0],
ModuleIO {
name: sim_fork_join_scope::outputs,
is_input: false,
ty: Array<UInt<8>, 3>,
..
}[1],
ModuleIO {
name: sim_fork_join_scope::outputs,
is_input: false,
ty: Array<UInt<8>, 3>,
..
}[2],
},
did_initial_settle: true,
},
sim: ExternModuleSimulation {
generator: SimGeneratorFn {
args: (
ModuleIO {
name: sim_fork_join_scope::clocks,
is_input: true,
ty: Array<Clock, 3>,
..
},
ModuleIO {
name: sim_fork_join_scope::outputs,
is_input: false,
ty: Array<UInt<8>, 3>,
..
},
),
f: ...,
},
sim_io_to_generator_map: {
ModuleIO {
name: sim_fork_join_scope::clocks,
is_input: true,
ty: Array<Clock, 3>,
..
}: ModuleIO {
name: sim_fork_join_scope::clocks,
is_input: true,
ty: Array<Clock, 3>,
..
},
ModuleIO {
name: sim_fork_join_scope::outputs,
is_input: false,
ty: Array<UInt<8>, 3>,
..
}: ModuleIO {
name: sim_fork_join_scope::outputs,
is_input: false,
ty: Array<UInt<8>, 3>,
..
},
},
source_location: SourceLocation(
module-XXXXXXXXXX.rs:4:1,
),
},
running_generator: Some(
...,
),
},
],
trace_decls: TraceModule {
name: "sim_fork_join_scope",
children: [
TraceModuleIO {
name: "clocks",
child: TraceArray {
name: "clocks",
elements: [
TraceClock {
location: TraceScalarId(0),
name: "[0]",
flow: Source,
},
TraceClock {
location: TraceScalarId(1),
name: "[1]",
flow: Source,
},
TraceClock {
location: TraceScalarId(2),
name: "[2]",
flow: Source,
},
],
ty: Array<Clock, 3>,
flow: Source,
},
ty: Array<Clock, 3>,
flow: Source,
},
TraceModuleIO {
name: "outputs",
child: TraceArray {
name: "outputs",
elements: [
TraceUInt {
location: TraceScalarId(3),
name: "[0]",
ty: UInt<8>,
flow: Sink,
},
TraceUInt {
location: TraceScalarId(4),
name: "[1]",
ty: UInt<8>,
flow: Sink,
},
TraceUInt {
location: TraceScalarId(5),
name: "[2]",
ty: UInt<8>,
flow: Sink,
},
],
ty: Array<UInt<8>, 3>,
flow: Sink,
},
ty: Array<UInt<8>, 3>,
flow: Sink,
},
],
},
traces: [
SimTrace {
id: TraceScalarId(0),
kind: BigClock {
index: StatePartIndex<BigSlots>(0),
},
state: 0x0,
last_state: 0x0,
},
SimTrace {
id: TraceScalarId(1),
kind: BigClock {
index: StatePartIndex<BigSlots>(1),
},
state: 0x0,
last_state: 0x0,
},
SimTrace {
id: TraceScalarId(2),
kind: BigClock {
index: StatePartIndex<BigSlots>(2),
},
state: 0x0,
last_state: 0x1,
},
SimTrace {
id: TraceScalarId(3),
kind: BigUInt {
index: StatePartIndex<BigSlots>(3),
ty: UInt<8>,
},
state: 0x31,
last_state: 0x31,
},
SimTrace {
id: TraceScalarId(4),
kind: BigUInt {
index: StatePartIndex<BigSlots>(4),
ty: UInt<8>,
},
state: 0x32,
last_state: 0x32,
},
SimTrace {
id: TraceScalarId(5),
kind: BigUInt {
index: StatePartIndex<BigSlots>(5),
ty: UInt<8>,
},
state: 0x32,
last_state: 0x32,
},
],
trace_memories: {},
trace_writers: [
Running(
VcdWriter {
finished_init: true,
timescale: 1 ps,
..
},
),
],
clocks_triggered: [],
event_queue: EventQueue(EventQueueData {
instant: 648 μs,
events: {},
}),
waiting_sensitivity_sets_by_address: {
SensitivitySet {
id: 198,
values: {
CompiledValue {
layout: CompiledTypeLayout {
ty: Clock,
layout: TypeLayout {
small_slots: StatePartLayout<SmallSlots> {
len: 0,
debug_data: [],
..
},
big_slots: StatePartLayout<BigSlots> {
len: 1,
debug_data: [
SlotDebugData {
name: "",
ty: Clock,
},
],
..
},
sim_only_slots: StatePartLayout<SimOnlySlots> {
len: 0,
debug_data: [],
layout_data: [],
..
},
},
body: Scalar,
},
range: TypeIndexRange {
small_slots: StatePartIndexRange<SmallSlots> { start: 0, len: 0 },
big_slots: StatePartIndexRange<BigSlots> { start: 0, len: 1 },
sim_only_slots: StatePartIndexRange<SimOnlySlots> { start: 0, len: 0 },
},
write: None,
}: SimValue {
ty: Clock,
value: OpaqueSimValue {
bits: 0x0_u1,
sim_only_values: [],
},
},
},
changed: Cell {
value: false,
},
..
},
},
waiting_sensitivity_sets_by_compiled_value: {
CompiledValue {
layout: CompiledTypeLayout {
ty: Clock,
layout: TypeLayout {
small_slots: StatePartLayout<SmallSlots> {
len: 0,
debug_data: [],
..
},
big_slots: StatePartLayout<BigSlots> {
len: 1,
debug_data: [
SlotDebugData {
name: "",
ty: Clock,
},
],
..
},
sim_only_slots: StatePartLayout<SimOnlySlots> {
len: 0,
debug_data: [],
layout_data: [],
..
},
},
body: Scalar,
},
range: TypeIndexRange {
small_slots: StatePartIndexRange<SmallSlots> { start: 0, len: 0 },
big_slots: StatePartIndexRange<BigSlots> { start: 0, len: 1 },
sim_only_slots: StatePartIndexRange<SimOnlySlots> { start: 0, len: 0 },
},
write: None,
}: (
SimValue {
ty: Clock,
value: OpaqueSimValue {
bits: 0x0_u1,
sim_only_values: [],
},
},
{
SensitivitySet {
id: 198,
..
},
},
),
},
..
}

File diff suppressed because it is too large Load diff