From 59be3bd645cdca9b6d4276b0e8141375075b6785 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Sun, 24 Nov 2024 03:44:31 -0800 Subject: [PATCH] WIP working on implementing deduce_resets pass --- crates/fayalite/src/module.rs | 16 +- .../src/module/transform/deduce_resets.rs | 709 +++++++++++++++--- 2 files changed, 637 insertions(+), 88 deletions(-) diff --git a/crates/fayalite/src/module.rs b/crates/fayalite/src/module.rs index b8610e7..915bf43 100644 --- a/crates/fayalite/src/module.rs +++ b/crates/fayalite/src/module.rs @@ -180,7 +180,7 @@ impl Block { } } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] pub struct StmtConnect { pub lhs: Expr, pub rhs: Expr, @@ -235,7 +235,7 @@ impl fmt::Debug for StmtConnect { } } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] pub struct StmtFormal { pub kind: FormalKind, pub clk: Expr, @@ -284,6 +284,8 @@ pub struct StmtIf { pub blocks: [S::Block; 2], } +impl Copy for StmtIf {} + impl StmtIf { pub fn then_block(&self) -> S::Block { self.blocks[0] @@ -315,6 +317,8 @@ pub struct StmtMatch { pub blocks: Interned<[S::Block]>, } +impl Copy for StmtMatch {} + impl StmtMatch { #[track_caller] fn assert_validity(&self) { @@ -459,6 +463,8 @@ pub struct StmtWire { pub wire: Wire, } +impl Copy for StmtWire {} + #[derive(Hash, Clone, PartialEq, Eq, Debug)] pub struct StmtReg { pub annotations: S::StmtAnnotations, @@ -473,6 +479,8 @@ pub struct StmtInstance { pub instance: Instance, } +impl Copy for StmtInstance {} + wrapper_enum! { #[impl( () self: StmtDeclaration = self, @@ -490,6 +498,8 @@ wrapper_enum! { } } +impl Copy for StmtDeclaration {} + impl StmtDeclaration { pub fn annotations(&self) -> S::StmtAnnotations { match self { @@ -546,6 +556,8 @@ wrapper_enum! { } } +impl Copy for Stmt {} + impl Stmt { pub fn sub_stmt_blocks(&self) -> &[S::Block] { match self { diff --git a/crates/fayalite/src/module/transform/deduce_resets.rs b/crates/fayalite/src/module/transform/deduce_resets.rs index dc01b33..9b5cee3 100644 --- a/crates/fayalite/src/module/transform/deduce_resets.rs +++ b/crates/fayalite/src/module/transform/deduce_resets.rs @@ -2,20 +2,28 @@ // See Notices.txt for copyright information use hashbrown::{hash_map::Entry, HashMap}; -use petgraph::graph::{NodeIndex, UnGraph}; +use petgraph::{ + graph::{NodeIndex, UnGraph}, + unionfind::UnionFind, +}; use crate::{ bundle::{BundleField, BundleType}, enum_::{EnumType, EnumVariant}, - expr::target::{TargetPathArrayElement, TargetPathBundleField, TargetPathElement}, + expr::{ + ops, + target::{TargetBase, TargetPathArrayElement, TargetPathBundleField, TargetPathElement}, + Flow, + }, intern::{Intern, Interned, Memoize}, module::{ - AnnotatedModuleIO, ExprInInstantiatedModule, ExternModuleBody, InstantiatedModule, - ModuleBody, NormalModuleBody, TargetInInstantiatedModule, + AnnotatedModuleIO, Block, ExprInInstantiatedModule, ExternModuleBody, InstantiatedModule, + ModuleBody, ModuleIO, NormalModuleBody, Stmt, StmtConnect, StmtDeclaration, StmtFormal, + StmtIf, StmtInstance, StmtMatch, StmtReg, StmtWire, TargetInInstantiatedModule, }, prelude::*, }; -use std::fmt; +use std::{convert::Infallible, fmt}; #[derive(Debug)] pub enum DeduceResetsError { @@ -40,18 +48,132 @@ impl From for std::io::Error { } } +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +enum ResetTarget { + Base { + base: Interned, + }, + BundleField { + parent: Interned, + field_ty: CanonicalType, + field_index: usize, + }, + EnumVariant { + parent: Interned, + variant_ty: CanonicalType, + variant_index: usize, + }, + /// Array's Elements: + /// deduce_resets requires all array elements to deduce to the same pattern of async/sync resets, + /// so we don't track individual array elements but instead track all of an array's elements together. + ArraysElements { + parent: Interned, + element_ty: CanonicalType, + }, +} + +impl ResetTarget { + fn canonical_ty(self) -> CanonicalType { + match self { + ResetTarget::Base { base } => base.canonical_ty(), + ResetTarget::BundleField { field_ty, .. } => field_ty, + ResetTarget::EnumVariant { variant_ty, .. } => variant_ty, + ResetTarget::ArraysElements { element_ty, .. } => element_ty, + } + } + fn parent(self) -> Option> { + match self { + ResetTarget::Base { .. } => None, + ResetTarget::BundleField { parent, .. } + | ResetTarget::EnumVariant { parent, .. } + | ResetTarget::ArraysElements { parent, .. } => Some(parent), + } + } + fn base(mut self) -> Interned { + loop { + match self { + ResetTarget::Base { base } => break base, + ResetTarget::BundleField { parent, .. } + | ResetTarget::EnumVariant { parent, .. } + | ResetTarget::ArraysElements { parent, .. } => self = *parent, + } + } + } + fn bundle_field(self, field_index: usize) -> Self { + let field_ty = Bundle::from_canonical(self.canonical_ty()).fields()[field_index].ty; + Self::BundleField { + parent: self.intern_sized(), + field_ty, + field_index, + } + } + fn enum_variant(self, variant_index: usize) -> Self { + let variant_ty = Enum::from_canonical(self.canonical_ty()).variants()[variant_index] + .ty + .expect("known to have a variant field"); + Self::EnumVariant { + parent: self.intern_sized(), + variant_ty, + variant_index, + } + } + fn arrays_elements(self) -> Self { + let element_ty = ::from_canonical(self.canonical_ty()).element(); + Self::ArraysElements { + parent: self.intern_sized(), + element_ty, + } + } + fn for_each_child(self, mut f: impl FnMut(ResetTarget) -> Result<(), E>) -> Result<(), E> { + match self.canonical_ty() { + CanonicalType::UInt(_) + | CanonicalType::SInt(_) + | CanonicalType::Bool(_) + | CanonicalType::AsyncReset(_) + | CanonicalType::SyncReset(_) + | CanonicalType::Reset(_) + | CanonicalType::Clock(_) => Ok(()), + CanonicalType::Array(_) => f(self.arrays_elements()), + CanonicalType::Enum(ty) => { + for variant_index in 0..ty.variants().len() { + f(self.enum_variant(variant_index))?; + } + Ok(()) + } + CanonicalType::Bundle(ty) => { + for field_index in 0..ty.fields().len() { + f(self.bundle_field(field_index))?; + } + Ok(()) + } + } + } +} + +impl> From for ResetTarget { + fn from(base: T) -> Self { + ResetTarget::Base { + base: TargetBase::intern_sized(base.into()), + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +struct ResetTargetInInstantiatedModule { + instantiated_module: InstantiatedModule, + target: ResetTarget, +} + #[derive(Debug)] struct Node { - target: TargetInInstantiatedModule, + target: ResetTargetInInstantiatedModule, deduced_type: Option, } #[derive(Debug)] struct State { base_module: Interned>, - transformed_exprs: - HashMap, ExprInInstantiatedModule>, - node_ids: HashMap>, + node_ids: HashMap>, graph: UnGraph, fallback_to_sync_reset: bool, } @@ -89,88 +211,499 @@ fn type_contains_any_undeduced_resets(ty: CanonicalType) -> bool { MyMemoize.get_owned(ty) } -impl State { - fn add_target_to_graph_recursive( - &mut self, - target: TargetInInstantiatedModule, - ) -> NodeIndex { - let entry = match self.node_ids.entry(target) { - Entry::Vacant(entry) => entry, - Entry::Occupied(entry) => { - return *entry.get(); - } +trait ProcessStep { + type Error; + type Processed; + fn processed_as_ref(v: &Self::Processed) -> Self::Processed<&T>; + fn processed_as_mut(v: &mut Self::Processed) -> Self::Processed<&mut T>; + fn processed_zip(t: Self::Processed, u: Self::Processed) + -> Self::Processed<(T, U)>; + fn processed_map(v: Self::Processed, f: impl FnOnce(T) -> U) -> Self::Processed; + fn processed_make(t: T) -> Self::Processed; + fn processed_from_iter>( + iter: impl IntoIterator>, + ) -> Self::Processed; +} + +macro_rules! impl_process_step_with_empty_processed { + () => { + type Processed = (); + fn processed_as_ref(v: &Self::Processed) -> Self::Processed<&T> { + *v + } + fn processed_as_mut(v: &mut Self::Processed) -> Self::Processed<&mut T> { + *v + } + fn processed_zip( + t: Self::Processed, + u: Self::Processed, + ) -> Self::Processed<(T, U)> { + let _ = t; + let _ = u; + () + } + fn processed_map( + v: Self::Processed, + f: impl FnOnce(T) -> U, + ) -> Self::Processed { + let _ = f; + v + } + fn processed_make(t: T) -> Self::Processed { + let _ = t; + () + } + fn processed_from_iter>( + iter: impl IntoIterator>, + ) -> Self::Processed { + FromIterator::from_iter(iter) + } + }; +} + +struct Processed(Step::Processed); + +impl Processed { + fn as_ref(&self) -> Processed<&T, Step> { + Processed(Step::processed_as_ref(&self.0)) + } + fn as_mut(&mut self) -> Processed<&mut T, Step> { + Processed(Step::processed_as_mut(&mut self.0)) + } + fn zip(self, u: Processed) -> Processed<(T, U), Step> { + Processed(Step::processed_zip(self.0, u.0)) + } + fn new(v: T) -> Self { + Processed(Step::processed_make(v)) + } + fn map(self, f: impl FnOnce(T) -> U) -> Processed { + Processed(Step::processed_map(self.0, f)) + } +} + +impl, A, Step: ProcessStep> FromIterator> + for Processed +{ + fn from_iter>>(iter: T) -> Self { + Processed(Step::processed_from_iter(iter.into_iter().map(|v| v.0))) + } +} + +struct AddNodesToGraphStep; + +impl ProcessStep for AddNodesToGraphStep { + type Error = Infallible; + impl_process_step_with_empty_processed!(); +} + +struct AddEdgesToGraphStep { + union_find: UnionFind, +} + +impl ProcessStep for AddEdgesToGraphStep { + type Error = Infallible; + impl_process_step_with_empty_processed!(); +} + +trait RunProcessStep: Sized { + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error>; +} + +impl RunProcessStep for ResetTarget { + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut AddNodesToGraphStep, + ) -> Result, Infallible> { + let target = ResetTargetInInstantiatedModule { + instantiated_module, + target: self, }; - let ty = target.target.canonical_ty(); - let node_id = self.graph.add_node(Node { + let Entry::Vacant(entry) = state.node_ids.entry(target) else { + return Ok(Processed(())); + }; + let ty = self.canonical_ty(); + let node_id = state.graph.add_node(Node { target, deduced_type: type_contains_any_undeduced_resets(ty).then_some(ty), }); entry.insert(node_id); - match target.target.canonical_ty() { - CanonicalType::UInt(_) - | CanonicalType::SInt(_) - | CanonicalType::Bool(_) - | CanonicalType::AsyncReset(_) - | CanonicalType::SyncReset(_) - | CanonicalType::Clock(_) => {} - CanonicalType::Array(ty) => { - for index in 0..ty.len() { - self.add_target_to_graph_recursive(TargetInInstantiatedModule { - instantiated_module: target.instantiated_module, - target: target.target.join( - TargetPathElement::from(TargetPathArrayElement { index }) - .intern_sized(), - ), - }); - } - } - CanonicalType::Enum(_) => {} - CanonicalType::Bundle(ty) => { - for BundleField { name, .. } in ty.fields() { - self.add_target_to_graph_recursive(TargetInInstantiatedModule { - instantiated_module: target.instantiated_module, - target: target.target.join( - TargetPathElement::from(TargetPathBundleField { name }).intern_sized(), - ), - }); - } - } - CanonicalType::Reset(_) => {} - } - node_id + self.for_each_child(|target| -> Result<(), Infallible> { + target.run_process_step(instantiated_module, state, step)?; + Ok(()) + })?; + Ok(Processed(())) } - fn build_graph_for_module(&mut self, instantiated_module: Interned) { - let Module { - name: _, - source_location: _, - body, - io_ty: _, - module_io, - module_annotations: _, - } = *instantiated_module.leaf_module(); - for AnnotatedModuleIO { - annotations: _, - module_io, - } in module_io - { - self.add_target_to_graph_recursive(TargetInInstantiatedModule { - instantiated_module: *instantiated_module, - target: module_io.into(), - }); +} + +struct ConnectAndLhsInstantiatedModule { + lhs_instantiated_module: InstantiatedModule, + lhs: Expr, + rhs: Expr, + source_location: SourceLocation, +} + +impl RunProcessStep for ConnectAndLhsInstantiatedModule { + fn run_process_step( + self, + rhs_instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error> { + let Self { + lhs_instantiated_module, + lhs, + rhs, + source_location, + } = self; + todo!(); + } +} + +impl RunProcessStep for StmtConnect { + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error> { + let StmtConnect { + lhs, + rhs, + source_location, + } = self; + Ok(ConnectAndLhsInstantiatedModule { + lhs_instantiated_module: instantiated_module, + lhs, + rhs, + source_location, } - match body { - ModuleBody::Normal(NormalModuleBody { body }) => todo!(), - ModuleBody::Extern(ExternModuleBody { + .run_process_step(instantiated_module, state, step)? + .map( + |ConnectAndLhsInstantiatedModule { + lhs_instantiated_module: _, + lhs, + rhs, + source_location, + }| StmtConnect { + lhs, + rhs, + source_location, + }, + )) + } +} + +impl RunProcessStep for StmtFormal { + fn run_process_step( + self, + _instantiated_module: InstantiatedModule, + _state: &mut State, + _step: &mut Step, + ) -> Result, Step::Error> { + // no inputs are Reset + Ok(Processed::new(self)) + } +} + +impl RunProcessStep for StmtIf { + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error> { + todo!() + } +} + +impl RunProcessStep for StmtMatch { + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error> { + todo!() + } +} + +impl RunProcessStep for ModuleIO +where + ResetTarget: RunProcessStep, +{ + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error> { + Ok(ResetTarget::from(self) + .run_process_step(instantiated_module, state, step)? + .map(|target| { + let ResetTarget::Base { base } = target else { + unreachable!(); + }; + let TargetBase::ModuleIO(module_io) = *base else { + unreachable!(); + }; + module_io + })) + } +} + +impl RunProcessStep for StmtWire +where + ResetTarget: RunProcessStep, +{ + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error> { + let Self { annotations, wire } = self; + Ok(ResetTarget::from(wire) + .run_process_step(instantiated_module, state, step)? + .map(|target| { + let ResetTarget::Base { base } = target else { + unreachable!(); + }; + let TargetBase::Wire(wire) = *base else { + unreachable!(); + }; + Self { annotations, wire } + })) + } +} + +impl RunProcessStep for StmtReg +where + ResetTarget: RunProcessStep, +{ + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error> { + let Self { annotations, reg } = self; + Ok(ResetTarget::from(reg) + .run_process_step(instantiated_module, state, step)? + .map(|target| { + let ResetTarget::Base { base } = target else { + unreachable!(); + }; + let TargetBase::Reg(reg) = *base else { + unreachable!(); + }; + Self { annotations, reg } + })) + } +} + +impl RunProcessStep for StmtInstance +where + ResetTarget: RunProcessStep, +{ + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error> { + let Self { + annotations, + instance, + } = self; + let child_instantiated_module = InstantiatedModule::Child { + parent: instantiated_module.intern_sized(), + instance: instance.intern_sized(), + }; + instance + .instantiated() + .run_process_step(child_instantiated_module, state, step)?; + for (field_index, AnnotatedModuleIO { module_io, .. }) in + instance.instantiated().module_io().into_iter().enumerate() + { + let (lhs_instantiated_module, lhs, rhs_instantiated_module, rhs) = match module_io + .flow() + { + Flow::Source => { + // connect to submodule's input from instance input + ( + child_instantiated_module, + module_io.to_expr(), + instantiated_module, + ops::FieldAccess::new_by_index(instance.to_expr(), field_index).to_expr(), + ) + } + Flow::Sink => { + // connect to instance output from submodule's output + ( + instantiated_module, + ops::FieldAccess::new_by_index(instance.to_expr(), field_index).to_expr(), + child_instantiated_module, + module_io.to_expr(), + ) + } + Flow::Duplex => unreachable!(), + }; + ConnectAndLhsInstantiatedModule { + lhs_instantiated_module, + lhs, + rhs, + source_location: instance.source_location(), + } + .run_process_step(rhs_instantiated_module, state, step)?; + } + Ok(ResetTarget::from(instance) + .run_process_step(instantiated_module, state, step)? + .map(|target| { + let ResetTarget::Base { base } = target else { + unreachable!(); + }; + let TargetBase::Instance(instance) = *base else { + unreachable!(); + }; + Self { + annotations, + instance, + } + })) + } +} + +impl RunProcessStep for StmtDeclaration +where + ResetTarget: RunProcessStep, +{ + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error> { + Ok(match self { + StmtDeclaration::Wire(decl) => decl + .run_process_step(instantiated_module, state, step)? + .map(StmtDeclaration::from), + StmtDeclaration::Reg(decl) => decl + .run_process_step(instantiated_module, state, step)? + .map(StmtDeclaration::from), + StmtDeclaration::Instance(decl) => decl + .run_process_step(instantiated_module, state, step)? + .map(StmtDeclaration::from), + }) + } +} + +impl RunProcessStep for Stmt +where + ResetTarget: RunProcessStep, +{ + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error> { + Ok(match self { + Stmt::Connect(stmt) => stmt + .run_process_step(instantiated_module, state, step)? + .map(Stmt::from), + Stmt::Formal(stmt) => stmt + .run_process_step(instantiated_module, state, step)? + .map(Stmt::from), + Stmt::If(stmt) => stmt + .run_process_step(instantiated_module, state, step)? + .map(Stmt::from), + Stmt::Match(stmt) => stmt + .run_process_step(instantiated_module, state, step)? + .map(Stmt::from), + Stmt::Declaration(stmt) => stmt + .run_process_step(instantiated_module, state, step)? + .map(Stmt::from), + }) + } +} + +impl RunProcessStep for Block +where + ResetTarget: RunProcessStep, +{ + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error> { + let Block { memories, stmts } = self; + // memories and memory ports won't ever contain any Reset values, + // so always just use the old `memories` value. we add the ports to the graph anyway to make the other logic easier. + for memory in memories { + for port in memory.ports() { + ResetTarget::from(port).run_process_step(instantiated_module, state, step)?; + } + } + let stmts = Result::, _>::from_iter( + stmts + .iter() + .map(|stmt| stmt.run_process_step(instantiated_module, state, step)), + )?; + Ok(stmts.map(|stmts| Block { memories, stmts })) + } +} + +impl RunProcessStep for Module +where + ResetTarget: RunProcessStep, +{ + fn run_process_step( + self, + instantiated_module: InstantiatedModule, + state: &mut State, + step: &mut Step, + ) -> Result, Step::Error> { + let module = *instantiated_module.leaf_module(); + let module_io = + Result::, _>, _>::from_iter(module.module_io().iter().map( + |&AnnotatedModuleIO { + annotations, + module_io, + }| { + Ok(module_io + .run_process_step(instantiated_module, state, step)? + .map(|module_io| AnnotatedModuleIO { + annotations, + module_io, + })) + }, + ))?; + let body = match module.body() { + ModuleBody::Normal(NormalModuleBody { body }) => body + .run_process_step(instantiated_module, state, step)? + .map(|body| ModuleBody::Normal(NormalModuleBody { body })), + body @ ModuleBody::Extern(ExternModuleBody { verilog_name: _, parameters: _, - }) => {} - } - } - fn deduce_types(&mut self) -> Result<(), DeduceResetsError> { - todo!() - } - fn transform_module(&mut self, module: Interned>) -> Interned> { - todo!() + }) => Processed::new(body), + }; + Ok(module_io.zip(body).map(|(module_io, body)| { + Module::new_unchecked( + module.name_id(), + module.source_location(), + body, + module_io, + module.module_annotations(), + ) + })) } } @@ -180,12 +713,16 @@ pub fn deduce_resets( ) -> Result>, DeduceResetsError> { let mut state = State { base_module: module, - transformed_exprs: HashMap::new(), node_ids: HashMap::new(), graph: UnGraph::new_undirected(), fallback_to_sync_reset, }; - state.build_graph_for_module(InstantiatedModule::Base(module).intern_sized()); - state.deduce_types()?; - Ok(state.transform_module(module)) + let Ok(Processed(())) = module.run_process_step( + InstantiatedModule::Base(module), + &mut state, + &mut AddNodesToGraphStep, + ); + todo!("add edges"); + todo!("deduce types"); + Ok(todo!("transform module")) }