From 546010739a187e0343e46edda3fd5b59d8c27341 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 26 Nov 2024 21:26:56 -0800 Subject: [PATCH] working on deduce_resets --- .../src/module/transform/deduce_resets.rs | 744 ++++++++++++++++-- 1 file changed, 697 insertions(+), 47 deletions(-) diff --git a/crates/fayalite/src/module/transform/deduce_resets.rs b/crates/fayalite/src/module/transform/deduce_resets.rs index e6b5987..d9009f8 100644 --- a/crates/fayalite/src/module/transform/deduce_resets.rs +++ b/crates/fayalite/src/module/transform/deduce_resets.rs @@ -3,44 +3,41 @@ #![cfg(todo)] -use hashbrown::{hash_map::Entry, HashMap}; -use num_bigint::BigInt; -use petgraph::{ - graph::{NodeIndex, UnGraph}, - unionfind::UnionFind, -}; - use crate::{ annotations::{Annotation, TargetedAnnotation}, bundle::{BundleField, BundleType}, enum_::{EnumType, EnumVariant}, expr::{ - ops, target::{ Target, TargetBase, TargetChild, TargetPathArrayElement, TargetPathBundleField, TargetPathDynArrayElement, TargetPathElement, }, - ExprEnum, Flow, + ExprEnum, }, + int::{SIntValue, UIntValue}, intern::{Intern, Interned, Memoize}, memory::{DynPortType, MemPort}, module::{ - transform::visit::{Folder, Visit, Visitor}, - AnnotatedModuleIO, Block, ExprInInstantiatedModule, ExternModuleBody, - ExternModuleParameter, ExternModuleParameterValue, InstantiatedModule, ModuleBody, - ModuleIO, NameId, NormalModuleBody, Stmt, StmtConnect, StmtDeclaration, StmtFormal, StmtIf, - StmtInstance, StmtMatch, StmtReg, StmtWire, + AnnotatedModuleIO, Block, ExprInInstantiatedModule, ExternModuleBody, InstantiatedModule, + ModuleBody, ModuleIO, NameId, NormalModuleBody, }, prelude::*, - source_location, + reset::{ResetType, ResetTypeDispatch}, +}; +use hashbrown::{hash_map::Entry, HashMap}; +use hashbrown::{hash_map::Entry, HashMap}; +use num_bigint::BigInt; +use num_bigint::BigInt; +use petgraph::unionfind::UnionFind; +use petgraph::{ + graph::{NodeIndex, UnGraph}, + unionfind::UnionFind, }; use std::{ borrow::Borrow, - convert::Infallible, fmt, hash::{Hash, Hasher}, marker::PhantomData, - mem, }; #[derive(Debug)] @@ -70,6 +67,25 @@ impl From for std::io::Error { } } +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +enum AnyReg { + Reg(Reg), + RegSync(Reg), + RegAsync(Reg), +} + +macro_rules! match_any_reg { + ( + $match_expr:expr, $fn:expr + ) => { + match $match_expr { + AnyReg::Reg(reg) => $fn(reg), + AnyReg::RegSync(reg) => $fn(reg), + AnyReg::RegAsync(reg) => $fn(reg), + } + }; +} + #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] enum ResetsLayout { NoResets, @@ -261,6 +277,45 @@ impl ResetGraph { }) } } + fn append_new_nodes_for_layout( + &mut self, + layout: ResetsLayout, + node_indexes: &mut Vec, + source_location: Option, + ) { + match layout { + ResetsLayout::NoResets => {} + ResetsLayout::Reset => node_indexes.push(self.new_node(None, source_location)), + ResetsLayout::SyncReset => { + node_indexes.push(self.new_node(Some(false), source_location)) + } + ResetsLayout::AsyncReset => { + node_indexes.push(self.new_node(Some(true), source_location)) + } + ResetsLayout::Bundle { + fields, + reset_count: _, + } => { + for field in fields { + self.append_new_nodes_for_layout(field, node_indexes, source_location); + } + } + ResetsLayout::Enum { + variants, + reset_count: _, + } => { + for variant in variants { + self.append_new_nodes_for_layout(variant, node_indexes, source_location); + } + } + ResetsLayout::Array { + element, + reset_count: _, + } => { + self.append_new_nodes_for_layout(*element, node_indexes, source_location); + } + } + } } #[derive(Debug)] @@ -318,23 +373,164 @@ impl<'a, T: ?Sized + Intern> MaybeInterned<'a, T> { } #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] -struct Resets<'a> { +struct Resets { ty: CanonicalType, layout: ResetsLayout, - node_indexes: MaybeInterned<'a, [ResetNodeIndex]>, + node_indexes: Interned<[ResetNodeIndex]>, } -impl<'a> Resets<'a> { - fn into_interned(self) -> Resets<'static> { - let Self { +impl Resets { + fn with_new_nodes( + reset_graph: &mut ResetGraph, + ty: CanonicalType, + source_location: Option, + ) -> Self { + let layout = ResetsLayout::new(ty); + let mut node_indexes = Vec::with_capacity(layout.reset_count()); + reset_graph.append_new_nodes_for_layout(layout, &mut node_indexes, source_location); + let node_indexes = Intern::intern_owned(node_indexes); + Self { ty, layout, node_indexes, - } = self; - Resets { - ty, - layout, - node_indexes: MaybeInterned::Interned(node_indexes.to_interned()), + } + } + fn array_elements(self) -> Self { + let array = ::from_canonical(self.ty); + let ResetsLayout::Array { + element, + reset_count: _, + } = self.layout + else { + unreachable!(); + }; + Self { + ty: array.element(), + layout: *element, + node_indexes: self.node_indexes, + } + } + fn bundle_fields(self) -> impl Iterator { + let bundle = Bundle::from_canonical(self.ty); + let ResetsLayout::Bundle { + fields, + reset_count: _, + } = self.layout + else { + unreachable!(); + }; + bundle.fields().into_iter().zip(fields).scan( + 0, + move |start_index, (BundleField { ty, .. }, layout)| { + let end_index = *start_index + layout.reset_count(); + let node_indexes = self.node_indexes[*start_index..end_index].intern(); + *start_index = end_index; + Some(Self { + ty, + layout, + node_indexes, + }) + }, + ) + } + fn enum_variants(self) -> impl Iterator> { + let enum_ = Enum::from_canonical(self.ty); + let ResetsLayout::Enum { + variants, + reset_count: _, + } = self.layout + else { + unreachable!(); + }; + enum_.variants().into_iter().zip(variants).scan( + 0, + move |start_index, (EnumVariant { ty, .. }, layout)| { + let end_index = *start_index + layout.reset_count(); + let node_indexes = self.node_indexes[*start_index..end_index].intern(); + *start_index = end_index; + Some(ty.map(|ty| Self { + ty, + layout, + node_indexes, + })) + }, + ) + } + fn substituted_type( + self, + reset_graph: &mut ResetGraph, + fallback_to_sync_reset: bool, + fallback_error_source_location: SourceLocation, + ) -> Result { + if self.layout.reset_count() == 0 { + return Ok(self.ty); + } + match self.ty { + CanonicalType::UInt(_) + | CanonicalType::SInt(_) + | CanonicalType::Bool(_) + | CanonicalType::AsyncReset(_) + | CanonicalType::SyncReset(_) + | CanonicalType::Clock(_) => Ok(self.ty), + CanonicalType::Array(ty) => Ok(CanonicalType::Array(Array::new_dyn( + self.array_elements().substituted_type( + reset_graph, + fallback_to_sync_reset, + fallback_error_source_location, + )?, + ty.len(), + ))), + CanonicalType::Enum(ty) => Ok(CanonicalType::Enum(Enum::new(Result::from_iter( + self.enum_variants().zip(ty.variants()).map( + |(resets, EnumVariant { name, ty: _ })| { + Ok(EnumVariant { + name, + ty: resets + .map(|resets| { + resets.substituted_type( + reset_graph, + fallback_to_sync_reset, + fallback_error_source_location, + ) + }) + .transpose()?, + }) + }, + ), + )?))), + CanonicalType::Bundle(ty) => Ok(CanonicalType::Bundle(Bundle::new(Result::from_iter( + self.bundle_fields().zip(ty.fields()).map( + |( + resets, + BundleField { + name, + flipped, + ty: _, + }, + )| { + Ok(BundleField { + name, + flipped, + ty: resets.substituted_type( + reset_graph, + fallback_to_sync_reset, + fallback_error_source_location, + )?, + }) + }, + ), + )?))), + CanonicalType::Reset(_) => Ok( + if reset_graph.is_async( + self.node_indexes[0], + fallback_to_sync_reset, + fallback_error_source_location, + )? { + CanonicalType::AsyncReset(AsyncReset) + } else { + CanonicalType::SyncReset(SyncReset) + }, + ), } } } @@ -342,11 +538,48 @@ impl<'a> Resets<'a> { #[derive(Debug)] struct State { base_module: Interned>, - expr_resets: HashMap, Resets<'static>>, + expr_resets: HashMap, Resets>, reset_graph: ResetGraph, fallback_to_sync_reset: bool, } +impl State { + fn get_resets( + &self, + instantiated_module: InstantiatedModule, + expr: impl ToExpr, + ) -> Option { + self.expr_resets + .get(&ExprInInstantiatedModule { + instantiated_module, + expr: Expr::canonical(expr.to_expr()), + }) + .copied() + } + fn get_or_make_resets( + &mut self, + instantiated_module: InstantiatedModule, + expr: impl ToExpr, + source_location: Option, + ) -> (Resets, bool) { + let expr = Expr::canonical(expr.to_expr()); + match self.expr_resets.entry(ExprInInstantiatedModule { + instantiated_module, + expr, + }) { + Entry::Occupied(entry) => (*entry.get(), false), + Entry::Vacant(entry) => ( + *entry.insert(Resets::with_new_nodes( + &mut self.reset_graph, + Expr::ty(expr), + source_location, + )), + true, + ), + } + } +} + struct PassOutput(P::Output); impl PassOutput { @@ -442,6 +675,19 @@ impl, P: Pass, A> FromIterator> for PassOutp } } +trait PassDispatch: Sized { + type Input; + type Output; + fn build_reset_graph( + self, + input: Self::Input, + ) -> Self::Output; + fn substitute_resets( + self, + input: Self::Input, + ) -> Self::Output; +} + trait Pass: Sized { type Output; fn output_new(v: T) -> PassOutput; @@ -452,6 +698,7 @@ trait Pass: Sized { ) -> PassOutput; fn map(v: PassOutput, f: impl FnOnce(T) -> U) -> PassOutput; fn zip(t: PassOutput, u: PassOutput) -> PassOutput<(T, U), Self>; + fn dispatch(dispatch: D, input: D::Input) -> D::Output; } struct BuildResetGraph; @@ -485,6 +732,10 @@ impl Pass for BuildResetGraph { fn zip(_t: PassOutput, _u: PassOutput) -> PassOutput<(T, U), Self> { PassOutput(()) } + + fn dispatch(dispatch: D, input: D::Input) -> D::Output { + dispatch.build_reset_graph(input) + } } struct SubstituteResets; @@ -517,6 +768,10 @@ impl Pass for SubstituteResets { fn zip(t: PassOutput, u: PassOutput) -> PassOutput<(T, U), Self> { PassOutput((t.0, u.0)) } + + fn dispatch(dispatch: D, input: D::Input) -> D::Output { + dispatch.substitute_resets(input) + } } struct PassArgs<'a, P: Pass> { @@ -560,6 +815,17 @@ impl<'a, P: Pass> PassArgs<'a, P> { _phantom: PhantomData, } } + fn get_resets(&self, expr: impl ToExpr) -> Option { + self.state.get_resets(self.instantiated_module, expr) + } + fn get_or_make_resets( + &mut self, + expr: impl ToExpr, + source_location: Option, + ) -> (Resets, bool) { + self.state + .get_or_make_resets(self.instantiated_module, expr, source_location) + } } trait RunPass: Sized { @@ -569,6 +835,50 @@ trait RunPass: Sized { ) -> Result, DeduceResetsError>; } +trait RunPassDispatch: Sized { + fn build_reset_graph( + &self, + pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result, DeduceResetsError>; + fn substitute_resets( + &self, + pass_args: PassArgs<'_, SubstituteResets>, + ) -> Result, DeduceResetsError>; + fn dispatch( + &self, + pass_args: PassArgs<'_, P>, + ) -> Result, DeduceResetsError> { + struct Dispatch<'a, T>(T, PhantomData<&'a mut ()>); + impl<'a, T: RunPassDispatch> PassDispatch for Dispatch<'a, &'_ T> { + type Input = PassArgs<'a, P>; + type Output = Result, DeduceResetsError>; + + fn build_reset_graph( + self, + input: Self::Input, + ) -> Self::Output { + self.0.build_reset_graph(input) + } + fn substitute_resets( + self, + input: Self::Input, + ) -> Self::Output { + self.0.substitute_resets(input) + } + } + P::dispatch(Dispatch(self, PhantomData), pass_args) + } +} + +impl RunPass

for T { + fn run_pass( + &self, + pass_args: PassArgs<'_, P>, + ) -> Result, DeduceResetsError> { + T::dispatch(self, pass_args) + } +} + impl + Intern + Clone, P: Pass> RunPass

for Interned { fn run_pass( &self, @@ -590,6 +900,332 @@ where } } +impl, P: Pass> RunPass

for Option { + fn run_pass( + &self, + pass_args: PassArgs<'_, P>, + ) -> Result, DeduceResetsError> { + match self { + Some(v) => Ok(v.run_pass(pass_args)?.map(Some)), + None => Ok(PassOutput::new(None)), + } + } +} + +fn reg_expr_run_pass( + reg: &Reg, + pass_args: PassArgs<'_, P>, +) -> Result, DeduceResetsError> { + Ok(AnyReg::from(*reg) + .run_pass(pass_args)? + .map(|reg| match_any_reg!(reg, ExprEnum::from))) +} + +impl RunPass

for ExprEnum { + fn run_pass( + &self, + pass_args: PassArgs<'_, P>, + ) -> Result, DeduceResetsError> { + match self { + ExprEnum::UIntLiteral(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::SIntLiteral(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::BoolLiteral(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::BundleLiteral(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::ArrayLiteral(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::EnumLiteral(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::Uninit(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::NotU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::NotS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::NotB(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::Neg(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::BitAndU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::BitAndS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::BitAndB(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::BitOrU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::BitOrS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::BitOrB(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::BitXorU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::BitXorS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::BitXorB(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::AddU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::AddS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::SubU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::SubS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::MulU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::MulS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::DivU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::DivS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::RemU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::RemS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::DynShlU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::DynShlS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::DynShrU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::DynShrS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::FixedShlU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::FixedShlS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::FixedShrU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::FixedShrS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpLtB(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpLeB(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpGtB(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpGeB(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpEqB(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpNeB(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpLtU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpLeU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpGtU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpGeU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpEqU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpNeU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpLtS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpLeS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpGtS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpGeS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpEqS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CmpNeS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastUIntToUInt(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastUIntToSInt(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastSIntToUInt(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastSIntToSInt(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastBoolToUInt(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastBoolToSInt(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastUIntToBool(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastSIntToBool(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastBoolToSyncReset(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastUIntToSyncReset(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastSIntToSyncReset(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastBoolToAsyncReset(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastUIntToAsyncReset(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastSIntToAsyncReset(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastSyncResetToBool(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastSyncResetToUInt(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastSyncResetToSInt(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastSyncResetToReset(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastAsyncResetToBool(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastAsyncResetToUInt(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastAsyncResetToSInt(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastAsyncResetToReset(expr) => { + Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)) + } + ExprEnum::CastResetToBool(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastResetToUInt(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastResetToSInt(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastBoolToClock(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastUIntToClock(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastSIntToClock(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastClockToBool(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastClockToUInt(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastClockToSInt(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::FieldAccess(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::VariantAccess(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::ArrayIndex(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::DynArrayIndex(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::ReduceBitAndU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::ReduceBitAndS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::ReduceBitOrU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::ReduceBitOrS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::ReduceBitXorU(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::ReduceBitXorS(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::SliceUInt(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::SliceSInt(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastToBits(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::CastBitsTo(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::ModuleIO(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::Instance(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::Wire(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + ExprEnum::Reg(expr) => reg_expr_run_pass(expr, pass_args), + ExprEnum::RegSync(expr) => reg_expr_run_pass(expr, pass_args), + ExprEnum::RegAsync(expr) => reg_expr_run_pass(expr, pass_args), + ExprEnum::MemPort(expr) => Ok(expr.run_pass(pass_args)?.map(ExprEnum::from)), + } + } +} + +impl RunPass

for Expr { + fn run_pass( + &self, + pass_args: PassArgs<'_, P>, + ) -> Result, DeduceResetsError> { + Ok(Expr::expr_enum(*self) + .run_pass(pass_args)? + .map(|expr_enum| expr_enum.to_expr())) + } +} + +impl RunPassDispatch for ModuleIO { + fn build_reset_graph( + &self, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result, DeduceResetsError> { + pass_args.get_or_make_resets(self, Some(self.source_location())); + Ok(PassOutput(())) + } + + fn substitute_resets( + &self, + pass_args: PassArgs<'_, SubstituteResets>, + ) -> Result, DeduceResetsError> { + let resets = pass_args + .get_resets(self) + .expect("added resets in build_reset_graph"); + Ok(PassOutput(Self::new_unchecked( + self.containing_module_name_id(), + self.name_id(), + self.source_location(), + self.is_input(), + resets.substituted_type( + &mut pass_args.state.reset_graph, + pass_args.state.fallback_to_sync_reset, + self.source_location(), + )?, + ))) + } +} + +impl RunPassDispatch for Wire { + fn build_reset_graph( + &self, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result, DeduceResetsError> { + pass_args.get_or_make_resets(self, Some(self.source_location())); + Ok(PassOutput(())) + } + + fn substitute_resets( + &self, + pass_args: PassArgs<'_, SubstituteResets>, + ) -> Result, DeduceResetsError> { + let resets = pass_args + .get_resets(self) + .expect("added resets in build_reset_graph"); + Ok(PassOutput(Self::new_unchecked( + self.scoped_name(), + self.source_location(), + resets.substituted_type( + &mut pass_args.state.reset_graph, + pass_args.state.fallback_to_sync_reset, + self.source_location(), + )?, + ))) + } +} + +impl From> for AnyReg { + fn from(value: Reg) -> Self { + struct Dispatch; + impl ResetTypeDispatch for Dispatch { + type Input = Reg; + type Output = AnyReg; + + fn reset(self, input: Self::Input) -> Self::Output { + AnyReg::Reg(input) + } + + fn sync_reset(self, input: Self::Input) -> Self::Output { + AnyReg::RegSync(input) + } + + fn async_reset(self, input: Self::Input) -> Self::Output { + AnyReg::RegAsync(input) + } + } + T::dispatch(value, Dispatch) + } +} + +impl RunPassDispatch for AnyReg { + fn build_reset_graph( + &self, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result, DeduceResetsError> { + match_any_reg!(self, |reg: &Reg| { + pass_args + .get_or_make_resets(Expr::canonical(reg.to_expr()), Some(reg.source_location())); + reg.init().run_pass(pass_args.as_mut())?; + Expr::canonical(reg.clock_domain()).run_pass(pass_args)?; + Ok(PassOutput(())) + }) + } + + fn substitute_resets( + &self, + mut pass_args: PassArgs<'_, SubstituteResets>, + ) -> Result, DeduceResetsError> { + match_any_reg!(self, |reg: &Reg| { + let scoped_name = reg.scoped_name(); + let source_location = reg.source_location(); + let resets = pass_args + .get_resets(Expr::canonical(reg.to_expr())) + .expect("added resets in build_reset_graph"); + let ty = resets.substituted_type( + &mut pass_args.state.reset_graph, + pass_args.state.fallback_to_sync_reset, + source_location, + )?; + let init = reg.init().run_pass(pass_args.as_mut())?.0; + let clock_domain = Expr::::from_canonical( + Expr::canonical(reg.clock_domain()).run_pass(pass_args)?.0, + ); + match Expr::ty(clock_domain) + .field_by_name("rst".intern()) + .expect("ClockDomain has rst field") + .ty + { + CanonicalType::AsyncReset(_) => { + Ok(PassOutput(AnyReg::RegAsync(Reg::new_unchecked( + scoped_name, + source_location, + ty, + Expr::from_bundle(clock_domain), + init, + )))) + } + CanonicalType::SyncReset(_) => Ok(PassOutput(AnyReg::RegSync(Reg::new_unchecked( + scoped_name, + source_location, + ty, + Expr::from_bundle(clock_domain), + init, + )))), + CanonicalType::UInt(_) + | CanonicalType::SInt(_) + | CanonicalType::Bool(_) + | CanonicalType::Array(_) + | CanonicalType::Enum(_) + | CanonicalType::Bundle(_) + | CanonicalType::Reset(_) + | CanonicalType::Clock(_) => unreachable!(), + } + }) + } +} + macro_rules! impl_run_pass_copy { ([$($generics:tt)*] $ty:ty) => { impl RunPass

for $ty { @@ -616,10 +1252,21 @@ macro_rules! impl_run_pass_clone { }; } -impl_run_pass_copy!([] Interned); -impl_run_pass_copy!([] usize); -impl_run_pass_copy!([] bool); impl_run_pass_clone!([] BigInt); +impl_run_pass_clone!([] UIntValue); +impl_run_pass_clone!([] SIntValue); +impl_run_pass_copy!([] BlackBoxInlineAnnotation); +impl_run_pass_copy!([] BlackBoxPathAnnotation); +impl_run_pass_copy!([] bool); +impl_run_pass_copy!([] CustomFirrtlAnnotation); +impl_run_pass_copy!([] DocStringAnnotation); +impl_run_pass_copy!([] DontTouchAnnotation); +impl_run_pass_copy!([] ExternModuleBody); +impl_run_pass_copy!([] Interned); +impl_run_pass_copy!([] NameId); +impl_run_pass_copy!([] SourceLocation); +impl_run_pass_copy!([] SVAttributeAnnotation); +impl_run_pass_copy!([] usize); macro_rules! impl_run_pass_for_struct { ( @@ -690,25 +1337,28 @@ impl_run_pass_for_struct! { } } -impl_run_pass_copy!([] ExternModuleBody); -impl_run_pass_copy!([] NameId); -impl_run_pass_copy!([] SourceLocation); -impl_run_pass_copy!([] DontTouchAnnotation); -impl_run_pass_copy!([] SVAttributeAnnotation); -impl_run_pass_copy!([] BlackBoxInlineAnnotation); -impl_run_pass_copy!([] BlackBoxPathAnnotation); -impl_run_pass_copy!([] DocStringAnnotation); -impl_run_pass_copy!([] CustomFirrtlAnnotation); impl_run_pass_copy!([] MemPort); // Mem can't contain any `Reset` types impl_run_pass_copy!([] Mem); // Mem can't contain any `Reset` types -impl_run_pass_for_enum! { - impl[] RunPass for TargetBase { - ModuleIO(v), - MemPort(v), - Reg(v), - Wire(v), - Instance(v), +impl RunPass

for TargetBase { + fn run_pass( + &self, + pass_args: PassArgs<'_, P>, + ) -> Result, DeduceResetsError> { + let reg = match self { + TargetBase::ModuleIO(v) => return Ok(v.run_pass(pass_args)?.map(TargetBase::ModuleIO)), + TargetBase::MemPort(v) => return Ok(v.run_pass(pass_args)?.map(TargetBase::MemPort)), + &TargetBase::Reg(v) => v.into(), + &TargetBase::RegSync(v) => v.into(), + &TargetBase::RegAsync(v) => v.into(), + TargetBase::Wire(v) => return Ok(v.run_pass(pass_args)?.map(TargetBase::Wire)), + TargetBase::Instance(v) => return Ok(v.run_pass(pass_args)?.map(TargetBase::Instance)), + }; + Ok(reg.run_pass(pass_args)?.map(|reg| match reg { + AnyReg::Reg(reg) => TargetBase::Reg(reg), + AnyReg::RegSync(reg) => TargetBase::RegSync(reg), + AnyReg::RegAsync(reg) => TargetBase::RegAsync(reg), + })) } }