From 6446b71afd245e6372a9ce3df97c646445d15336 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 27 Nov 2024 23:20:22 -0800 Subject: [PATCH] deduce_resets works! --- .../src/module/transform/deduce_resets.rs | 877 +++++++++++++++--- crates/fayalite/tests/module.rs | 263 +++++- 2 files changed, 989 insertions(+), 151 deletions(-) diff --git a/crates/fayalite/src/module/transform/deduce_resets.rs b/crates/fayalite/src/module/transform/deduce_resets.rs index ad943ae..b2d2cb5 100644 --- a/crates/fayalite/src/module/transform/deduce_resets.rs +++ b/crates/fayalite/src/module/transform/deduce_resets.rs @@ -1,39 +1,34 @@ // SPDX-License-Identifier: LGPL-3.0-or-later // See Notices.txt for copyright information -#![cfg(todo)] - use crate::{ annotations::{Annotation, TargetedAnnotation}, bundle::{BundleField, BundleType}, enum_::{EnumType, EnumVariant}, expr::{ - ops, + ops::{self, ArrayLiteral}, target::{ Target, TargetBase, TargetChild, TargetPathArrayElement, TargetPathBundleField, TargetPathDynArrayElement, TargetPathElement, }, ExprEnum, }, + formal::FormalKind, int::{SIntValue, UIntValue}, intern::{Intern, Interned, Memoize}, memory::{DynPortType, MemPort}, module::{ AnnotatedModuleIO, Block, ExprInInstantiatedModule, ExternModuleBody, InstantiatedModule, - ModuleBody, ModuleIO, NameId, NormalModuleBody, + ModuleBody, ModuleIO, NameId, NormalModuleBody, Stmt, StmtConnect, StmtDeclaration, + StmtFormal, StmtIf, StmtInstance, StmtMatch, StmtReg, StmtWire, }, prelude::*, reset::{ResetType, ResetTypeDispatch}, }; -use hashbrown::{hash_map::Entry, HashMap}; +use hashbrown::{hash_map::Entry, HashMap, HashSet}; use num_bigint::BigInt; use petgraph::unionfind::UnionFind; -use std::{ - borrow::Borrow, - fmt, - hash::{Hash, Hasher}, - marker::PhantomData, -}; +use std::{fmt, marker::PhantomData}; #[derive(Debug)] pub enum DeduceResetsError { @@ -313,60 +308,6 @@ impl ResetGraph { } } -#[derive(Debug)] -enum MaybeInterned<'a, T: ?Sized + Intern> { - Interned(Interned), - Borrowed(&'a T), -} - -impl<'a, T: ?Sized + Intern> Borrow for MaybeInterned<'a, T> { - fn borrow(&self) -> &T { - &**self - } -} - -impl<'a, T: ?Sized + Intern + PartialEq> PartialEq for MaybeInterned<'a, T> { - fn eq(&self, other: &Self) -> bool { - **self == **other - } -} - -impl<'a, T: ?Sized + Intern + Eq> Eq for MaybeInterned<'a, T> {} - -impl<'a, T: ?Sized + Intern + Hash> Hash for MaybeInterned<'a, T> { - fn hash(&self, state: &mut H) { - (**self).hash(state); - } -} - -impl<'a, T: ?Sized + Intern> Copy for MaybeInterned<'a, T> {} - -impl<'a, T: ?Sized + Intern> Clone for MaybeInterned<'a, T> { - fn clone(&self) -> Self { - *self - } -} - -impl<'a, T: ?Sized + Intern> std::ops::Deref for MaybeInterned<'a, T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - match self { - MaybeInterned::Interned(v) => v, - MaybeInterned::Borrowed(v) => v, - } - } -} - -impl<'a, T: ?Sized + Intern> MaybeInterned<'a, T> { - fn to_interned(self) -> Interned { - match self { - MaybeInterned::Interned(v) => v, - MaybeInterned::Borrowed(v) => v.intern(), - } - } -} - #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] struct Resets { ty: CanonicalType, @@ -532,7 +473,8 @@ impl Resets { #[derive(Debug)] struct State { - base_module: Interned>, + modules_added_to_graph: HashSet, + substituted_modules: HashMap>, expr_resets: HashMap, Resets>, reset_graph: ResetGraph, fallback_to_sync_reset: bool, @@ -584,12 +526,6 @@ impl PassOutput { fn from_fn(f: impl FnOnce() -> T) -> Self { PassOutput::new(()).map(|()| f()) } - fn as_ref(&self) -> PassOutput<&T, P> { - P::output_as_ref(self) - } - fn as_mut(&mut self) -> PassOutput<&mut T, P> { - P::output_as_mut(self) - } fn map(self, f: impl FnOnce(T) -> U) -> PassOutput { P::map(self, f) } @@ -686,11 +622,12 @@ trait PassDispatch: Sized { trait Pass: Sized { type Output; fn output_new(v: T) -> PassOutput; - fn output_as_ref(v: &PassOutput) -> PassOutput<&T, Self>; - fn output_as_mut(v: &mut PassOutput) -> PassOutput<&mut T, Self>; fn output_from_iter, A>( iter: impl IntoIterator>, ) -> PassOutput; + fn try_array_from_fn( + f: impl FnMut(usize) -> Result, E>, + ) -> Result, E>; fn map(v: PassOutput, f: impl FnOnce(T) -> U) -> PassOutput; fn zip(t: PassOutput, u: PassOutput) -> PassOutput<(T, U), Self>; fn dispatch(dispatch: D, input: D::Input) -> D::Output; @@ -705,14 +642,6 @@ impl Pass for BuildResetGraph { PassOutput(()) } - fn output_as_ref(_v: &PassOutput) -> PassOutput<&T, Self> { - PassOutput(()) - } - - fn output_as_mut(_v: &mut PassOutput) -> PassOutput<&mut T, Self> { - PassOutput(()) - } - fn output_from_iter, A>( iter: impl IntoIterator>, ) -> PassOutput { @@ -720,6 +649,15 @@ impl Pass for BuildResetGraph { PassOutput(()) } + fn try_array_from_fn( + mut f: impl FnMut(usize) -> Result, E>, + ) -> Result, E> { + for i in 0..N { + f(i)?; + } + Ok(PassOutput(())) + } + fn map(_v: PassOutput, _f: impl FnOnce(T) -> U) -> PassOutput { PassOutput(()) } @@ -742,20 +680,24 @@ impl Pass for SubstituteResets { PassOutput(v) } - fn output_as_ref(v: &PassOutput) -> PassOutput<&T, Self> { - PassOutput(&v.0) - } - - fn output_as_mut(v: &mut PassOutput) -> PassOutput<&mut T, Self> { - PassOutput(&mut v.0) - } - fn output_from_iter, A>( iter: impl IntoIterator>, ) -> PassOutput { PassOutput(T::from_iter(iter.into_iter().map(|PassOutput(v)| v))) } + fn try_array_from_fn( + mut f: impl FnMut(usize) -> Result, E>, + ) -> Result, E> { + let mut retval = [const { None }; N]; + for i in 0..N { + retval[i] = Some(f(i)?.0); + } + Ok(PassOutput( + retval.map(|v| v.expect("just wrote Some to all elements")), + )) + } + fn map(v: PassOutput, f: impl FnOnce(T) -> U) -> PassOutput { PassOutput(f(v.0)) } @@ -821,6 +763,23 @@ impl<'a, P: Pass> PassArgs<'a, P> { self.state .get_or_make_resets(self.instantiated_module, expr, source_location) } + fn union( + &mut self, + a: Resets, + b: Resets, + fallback_error_source_location: Option, + ) -> Result<(), DeduceResetsError> { + assert_eq!(a.layout, b.layout); + assert_eq!(a.ty, b.ty); + for (a_node_index, b_node_index) in a.node_indexes.into_iter().zip(b.node_indexes) { + self.state.reset_graph.union( + a_node_index, + b_node_index, + fallback_error_source_location.unwrap_or(self.fallback_error_source_location), + )?; + } + Ok(()) + } } trait RunPass: Sized { @@ -874,6 +833,61 @@ impl RunPass

for T { } } +trait RunPassExpr: ToExpr + Sized { + type Args<'a>: IntoIterator> + 'a + where + Self: 'a; + fn args<'a>(&'a self) -> Self::Args<'a>; + fn source_location(&self) -> Option; + fn union_parts( + &self, + resets: Resets, + args_resets: Vec, + pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result<(), DeduceResetsError>; + fn new( + &self, + ty: CanonicalType, + new_args: Vec>, + ) -> Result; +} + +impl RunPassDispatch for T { + fn build_reset_graph( + &self, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result, DeduceResetsError> { + let source_location = self.source_location(); + let (resets, _) = pass_args.get_or_make_resets(self, source_location); + let args_resets = Result::from_iter(self.args().into_iter().map(|arg| { + arg.run_pass(pass_args.as_mut())?; + let (resets, _) = pass_args.get_or_make_resets(arg, source_location); + Ok(resets) + }))?; + self.union_parts(resets, args_resets, pass_args)?; + Ok(PassOutput(())) + } + + fn substitute_resets( + &self, + mut pass_args: PassArgs<'_, SubstituteResets>, + ) -> Result, DeduceResetsError> { + let source_location = self.source_location(); + let (resets, _) = pass_args.get_or_make_resets(self, source_location); + let ty = resets.substituted_type( + &mut pass_args.state.reset_graph, + pass_args.state.fallback_to_sync_reset, + pass_args.fallback_error_source_location, + )?; + let new_args = Result::from_iter( + self.args() + .into_iter() + .map(|arg| Ok(arg.run_pass(pass_args.as_mut())?.0)), + )?; + Ok(PassOutput(self.new(ty, new_args)?)) + } +} + impl + Intern + Clone, P: Pass> RunPass

for Interned { fn run_pass( &self, @@ -895,6 +909,15 @@ where } } +impl, P: Pass, const N: usize> RunPass

for [T; N] { + fn run_pass( + &self, + mut pass_args: PassArgs<'_, P>, + ) -> Result, DeduceResetsError> { + P::try_array_from_fn(|i| self[i].run_pass(pass_args.as_mut())) + } +} + impl, P: Pass> RunPass

for Option { fn run_pass( &self, @@ -937,6 +960,7 @@ fn cast_bit_op( Expr::canonical(self.arg).run_pass(pass_args.as_mut())?; let (expr_resets, _) = pass_args.get_or_make_resets(self.expr, None); let (arg_resets, _) = pass_args.get_or_make_resets(self.arg, None); + // don't use PassArgs::union since types don't match and we want to just union resets if they exist for (expr_node_index, arg_node_index) in expr_resets .node_indexes .into_iter() @@ -1166,6 +1190,17 @@ impl RunPass

for Expr { } } +impl RunPass

for Expr { + fn run_pass( + &self, + pass_args: PassArgs<'_, P>, + ) -> Result, DeduceResetsError> { + Ok(Expr::canonical(*self) + .run_pass(pass_args)? + .map(Expr::from_canonical)) + } +} + impl RunPass

for Expr { fn run_pass( &self, @@ -1177,61 +1212,374 @@ impl RunPass

for Expr { } } -impl RunPassDispatch for ModuleIO { - fn build_reset_graph( - &self, - mut pass_args: PassArgs<'_, BuildResetGraph>, - ) -> Result, DeduceResetsError> { - pass_args.get_or_make_resets(self, Some(self.source_location())); - Ok(PassOutput(())) +impl RunPassExpr for ops::Uninit { + type Args<'a> = [Expr; 0]; + + fn args<'a>(&'a self) -> Self::Args<'a> { + [] } - fn substitute_resets( + fn source_location(&self) -> Option { + None + } + + fn union_parts( &self, - pass_args: PassArgs<'_, SubstituteResets>, - ) -> Result, DeduceResetsError> { - let resets = pass_args - .get_resets(self) - .expect("added resets in build_reset_graph"); - Ok(PassOutput(Self::new_unchecked( + _resets: Resets, + _args_resets: Vec, + _pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result<(), DeduceResetsError> { + Ok(()) + } + + fn new( + &self, + ty: CanonicalType, + _new_args: Vec>, + ) -> Result { + Ok(ops::Uninit::new(ty)) + } +} + +impl RunPassExpr for ops::BundleLiteral { + type Args<'a> = Interned<[Expr]>; + + fn args<'a>(&'a self) -> Self::Args<'a> { + self.field_values() + } + + fn source_location(&self) -> Option { + None + } + + fn union_parts( + &self, + resets: Resets, + args_resets: Vec, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result<(), DeduceResetsError> { + for (resets_field, field_expr_resets) in resets.bundle_fields().zip(args_resets) { + pass_args.union(resets_field, field_expr_resets, None)?; + } + Ok(()) + } + + fn new( + &self, + ty: CanonicalType, + new_args: Vec>, + ) -> Result { + Ok(ops::BundleLiteral::new( + Bundle::from_canonical(ty), + Intern::intern_owned(new_args), + )) + } +} + +impl RunPassExpr for ArrayLiteral { + type Args<'a> = Interned<[Expr]>; + + fn args<'a>(&'a self) -> Self::Args<'a> { + self.element_values() + } + + fn source_location(&self) -> Option { + None + } + + fn union_parts( + &self, + resets: Resets, + args_resets: Vec, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result<(), DeduceResetsError> { + let resets_elements = resets.array_elements(); + for arg_resets in args_resets { + pass_args.union(resets_elements, arg_resets, None)?; + } + Ok(()) + } + + fn new( + &self, + ty: CanonicalType, + new_args: Vec>, + ) -> Result { + Ok(Self::new( + ::from_canonical(ty).element(), + Intern::intern_owned(new_args), + )) + } +} + +impl RunPassExpr for ops::EnumLiteral { + type Args<'a> = Option>; + + fn args<'a>(&'a self) -> Self::Args<'a> { + self.variant_value() + } + + fn source_location(&self) -> Option { + None + } + + fn union_parts( + &self, + resets: Resets, + args_resets: Vec, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result<(), DeduceResetsError> { + if let Some(Some(variant_resets)) = resets.enum_variants().nth(self.variant_index()) { + pass_args.union(variant_resets, args_resets[0], None)?; + } + Ok(()) + } + + fn new( + &self, + ty: CanonicalType, + new_args: Vec>, + ) -> Result { + Ok(Self::new_by_index( + Enum::from_canonical(ty), + self.variant_index(), + new_args.get(0).copied(), + )) + } +} + +impl RunPassExpr for ops::FieldAccess { + type Args<'a> = [Expr; 1]; + + fn args<'a>(&'a self) -> Self::Args<'a> { + [Expr::canonical(self.base())] + } + + fn source_location(&self) -> Option { + None + } + + fn union_parts( + &self, + resets: Resets, + args_resets: Vec, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result<(), DeduceResetsError> { + let Some(field_resets) = args_resets[0].bundle_fields().nth(self.field_index()) else { + unreachable!(); + }; + pass_args.union(resets, field_resets, None) + } + + fn new( + &self, + _ty: CanonicalType, + new_args: Vec>, + ) -> Result { + Ok(Self::new_by_index( + Expr::from_canonical(new_args[0]), + self.field_index(), + )) + } +} + +impl RunPassExpr for ops::VariantAccess { + type Args<'a> = [Expr; 1]; + + fn args<'a>(&'a self) -> Self::Args<'a> { + [Expr::canonical(self.base())] + } + + fn source_location(&self) -> Option { + None + } + + fn union_parts( + &self, + resets: Resets, + args_resets: Vec, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result<(), DeduceResetsError> { + if let Some(Some(variant_resets)) = args_resets[0].enum_variants().nth(self.variant_index()) + { + pass_args.union(resets, variant_resets, None)?; + } + Ok(()) + } + + fn new( + &self, + _ty: CanonicalType, + new_args: Vec>, + ) -> Result { + Ok(Self::new_by_index( + Expr::from_canonical(new_args[0]), + self.variant_index(), + )) + } +} + +impl RunPassExpr for ops::ArrayIndex { + type Args<'a> = [Expr; 1]; + + fn args<'a>(&'a self) -> Self::Args<'a> { + [Expr::canonical(self.base())] + } + + fn source_location(&self) -> Option { + None + } + + fn union_parts( + &self, + resets: Resets, + args_resets: Vec, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result<(), DeduceResetsError> { + pass_args.union(resets, args_resets[0].array_elements(), None) + } + + fn new( + &self, + _ty: CanonicalType, + new_args: Vec>, + ) -> Result { + Ok(Self::new( + Expr::from_canonical(new_args[0]), + self.element_index(), + )) + } +} + +impl RunPassExpr for ops::DynArrayIndex { + type Args<'a> = [Expr; 2]; + + fn args<'a>(&'a self) -> Self::Args<'a> { + [ + Expr::canonical(self.base()), + Expr::canonical(self.element_index()), + ] + } + + fn source_location(&self) -> Option { + None + } + + fn union_parts( + &self, + resets: Resets, + args_resets: Vec, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result<(), DeduceResetsError> { + pass_args.union(resets, args_resets[0].array_elements(), None) + } + + fn new( + &self, + _ty: CanonicalType, + new_args: Vec>, + ) -> Result { + Ok(Self::new( + Expr::from_canonical(new_args[0]), + Expr::from_canonical(new_args[1]), + )) + } +} + +impl RunPassExpr for ops::CastBitsTo { + type Args<'a> = [Expr; 1]; + + fn args<'a>(&'a self) -> Self::Args<'a> { + [Expr::canonical(self.arg())] + } + + fn source_location(&self) -> Option { + None + } + + fn union_parts( + &self, + _resets: Resets, + _args_resets: Vec, + _pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result<(), DeduceResetsError> { + Ok(()) + } + + fn new( + &self, + ty: CanonicalType, + new_args: Vec>, + ) -> Result { + Ok(Self::new(Expr::from_canonical(new_args[0]), ty)) + } +} + +impl RunPassExpr for ModuleIO { + type Args<'a> = [Expr; 0]; + + fn args<'a>(&'a self) -> Self::Args<'a> { + [] + } + + fn source_location(&self) -> Option { + Some(self.source_location()) + } + + fn union_parts( + &self, + _resets: Resets, + _args_resets: Vec, + _pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result<(), DeduceResetsError> { + Ok(()) + } + + fn new( + &self, + ty: CanonicalType, + _new_args: Vec>, + ) -> Result { + Ok(Self::new_unchecked( self.containing_module_name_id(), self.name_id(), self.source_location(), self.is_input(), - resets.substituted_type( - &mut pass_args.state.reset_graph, - pass_args.state.fallback_to_sync_reset, - self.source_location(), - )?, - ))) + ty, + )) } } -impl RunPassDispatch for Wire { - fn build_reset_graph( - &self, - mut pass_args: PassArgs<'_, BuildResetGraph>, - ) -> Result, DeduceResetsError> { - pass_args.get_or_make_resets(self, Some(self.source_location())); - Ok(PassOutput(())) +impl RunPassExpr for Wire { + type Args<'a> = [Expr; 0]; + + fn args<'a>(&'a self) -> Self::Args<'a> { + [] } - fn substitute_resets( + fn source_location(&self) -> Option { + Some(self.source_location()) + } + + fn union_parts( &self, - pass_args: PassArgs<'_, SubstituteResets>, - ) -> Result, DeduceResetsError> { - let resets = pass_args - .get_resets(self) - .expect("added resets in build_reset_graph"); - Ok(PassOutput(Self::new_unchecked( + _resets: Resets, + _args_resets: Vec, + _pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result<(), DeduceResetsError> { + Ok(()) + } + + fn new( + &self, + ty: CanonicalType, + _new_args: Vec>, + ) -> Result { + Ok(Self::new_unchecked( self.scoped_name(), self.source_location(), - resets.substituted_type( - &mut pass_args.state.reset_graph, - pass_args.state.fallback_to_sync_reset, - self.source_location(), - )?, - ))) + ty, + )) } } @@ -1325,6 +1673,54 @@ impl RunPassDispatch for AnyReg { } } +impl RunPassDispatch for Instance { + fn build_reset_graph( + &self, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result, DeduceResetsError> { + self.instantiated().run_pass(PassArgs:: { + state: pass_args.state, + instantiated_module: InstantiatedModule::Child { + parent: pass_args.instantiated_module.intern_sized(), + instance: self.intern(), + }, + fallback_error_source_location: self.instantiated().source_location(), + _phantom: PhantomData, + })?; + let (resets, _) = pass_args.get_or_make_resets(self, Some(self.source_location())); + for (resets_field, module_io) in resets.bundle_fields().zip(self.instantiated().module_io()) + { + let (module_io_resets, _) = pass_args.get_or_make_resets( + module_io.module_io, + Some(self.instantiated().source_location()), + ); + pass_args.union(resets_field, module_io_resets, Some(self.source_location()))?; + } + Ok(PassOutput(())) + } + + fn substitute_resets( + &self, + pass_args: PassArgs<'_, SubstituteResets>, + ) -> Result, DeduceResetsError> { + let PassOutput(instantiated) = + self.instantiated().run_pass(PassArgs:: { + state: pass_args.state, + instantiated_module: InstantiatedModule::Child { + parent: pass_args.instantiated_module.intern_sized(), + instance: self.intern(), + }, + fallback_error_source_location: self.instantiated().source_location(), + _phantom: PhantomData, + })?; + Ok(PassOutput(Self::new_unchecked( + self.scoped_name(), + instantiated, + self.source_location(), + ))) + } +} + macro_rules! impl_run_pass_copy { ([$($generics:tt)*] $ty:ty) => { impl RunPass

for $ty { @@ -1352,8 +1748,9 @@ macro_rules! impl_run_pass_clone { } impl_run_pass_clone!([] BigInt); -impl_run_pass_clone!([] UIntValue); impl_run_pass_clone!([] SIntValue); +impl_run_pass_clone!([] std::ops::Range); +impl_run_pass_clone!([] UIntValue); impl_run_pass_copy!([] BlackBoxInlineAnnotation); impl_run_pass_copy!([] BlackBoxPathAnnotation); impl_run_pass_copy!([] bool); @@ -1363,9 +1760,12 @@ impl_run_pass_copy!([] DontTouchAnnotation); impl_run_pass_copy!([] ExternModuleBody); impl_run_pass_copy!([] Interned); impl_run_pass_copy!([] NameId); +impl_run_pass_copy!([] SInt); impl_run_pass_copy!([] SourceLocation); impl_run_pass_copy!([] SVAttributeAnnotation); +impl_run_pass_copy!([] UInt); impl_run_pass_copy!([] usize); +impl_run_pass_copy!([] FormalKind); macro_rules! impl_run_pass_for_struct { ( @@ -1530,6 +1930,80 @@ impl_run_pass_for_int_cast_op!(ops::CastUIntToSInt); impl_run_pass_for_int_cast_op!(ops::CastSIntToUInt); impl_run_pass_for_int_cast_op!(ops::CastSIntToSInt); +impl_run_pass_for_struct! { + #[constructor = ops::SliceUInt::new(base, range)] + impl[] RunPass for ops::SliceUInt { + base(): _, + range(): _, + } +} + +impl_run_pass_for_struct! { + #[constructor = ops::SliceSInt::new(base, range)] + impl[] RunPass for ops::SliceSInt { + base(): _, + range(): _, + } +} + +impl_run_pass_for_struct! { + #[constructor = ops::CastToBits::new(arg)] + impl[] RunPass for ops::CastToBits { + arg(): _, + } +} + +impl_run_pass_for_struct! { + impl[] RunPass for StmtFormal { + kind: _, + clk: _, + pred: _, + en: _, + text: _, + source_location: _, + } +} + +impl_run_pass_for_struct! { + impl[] RunPass for StmtIf { + cond: _, + source_location: _, + blocks: _, + } +} + +impl_run_pass_for_struct! { + impl[] RunPass for StmtMatch { + expr: _, + source_location: _, + blocks: _, + } +} + +impl_run_pass_for_struct! { + impl[] RunPass for StmtWire { + annotations: _, + wire: _, + } +} + +impl_run_pass_for_struct! { + impl[] RunPass for StmtInstance { + annotations: _, + instance: _, + } +} + +impl_run_pass_for_enum! { + impl[] RunPass for Stmt { + Connect(v), + Formal(v), + If(v), + Match(v), + Declaration(v), + } +} + impl_run_pass_for_struct! { impl[] RunPass for Block { memories: _, @@ -1546,12 +2020,51 @@ impl_run_pass_for_struct! { impl_run_pass_copy!([] MemPort); // Mem can't contain any `Reset` types impl_run_pass_copy!([] Mem); // Mem can't contain any `Reset` types +impl RunPassDispatch for StmtConnect { + fn build_reset_graph( + &self, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result, DeduceResetsError> { + let Self { + lhs, + rhs, + source_location, + } = *self; + pass_args.fallback_error_source_location = source_location; + lhs.run_pass(pass_args.as_mut())?; + rhs.run_pass(pass_args.as_mut())?; + let (lhs_resets, _) = pass_args.get_or_make_resets(lhs, Some(source_location)); + let (rhs_resets, _) = pass_args.get_or_make_resets(rhs, Some(source_location)); + pass_args.union(lhs_resets, rhs_resets, Some(source_location))?; + Ok(PassOutput(())) + } + + fn substitute_resets( + &self, + mut pass_args: PassArgs<'_, SubstituteResets>, + ) -> Result, DeduceResetsError> { + let StmtConnect { + lhs, + rhs, + source_location, + } = *self; + pass_args.fallback_error_source_location = source_location; + let lhs = lhs.run_pass(pass_args.as_mut())?.0; + let rhs = rhs.run_pass(pass_args)?.0; + Ok(PassOutput(StmtConnect { + lhs, + rhs, + source_location, + })) + } +} + impl RunPass

for TargetBase { fn run_pass( &self, pass_args: PassArgs<'_, P>, ) -> Result, DeduceResetsError> { - let reg = match self { + let reg: AnyReg = match self { TargetBase::ModuleIO(v) => return Ok(v.run_pass(pass_args)?.map(TargetBase::ModuleIO)), TargetBase::MemPort(v) => return Ok(v.run_pass(pass_args)?.map(TargetBase::MemPort)), &TargetBase::Reg(v) => v.into(), @@ -1568,6 +2081,36 @@ impl RunPass

for TargetBase { } } +impl RunPass

for StmtDeclaration { + fn run_pass( + &self, + mut pass_args: PassArgs<'_, P>, + ) -> Result, DeduceResetsError> { + let (annotations, reg) = match self { + StmtDeclaration::Wire(v) => { + return Ok(v.run_pass(pass_args)?.map(StmtDeclaration::Wire)) + } + &StmtDeclaration::Reg(StmtReg { annotations, reg }) => (annotations, AnyReg::from(reg)), + &StmtDeclaration::RegSync(StmtReg { annotations, reg }) => { + (annotations, AnyReg::from(reg)) + } + &StmtDeclaration::RegAsync(StmtReg { annotations, reg }) => { + (annotations, AnyReg::from(reg)) + } + StmtDeclaration::Instance(v) => { + return Ok(v.run_pass(pass_args)?.map(StmtDeclaration::Instance)) + } + }; + let annotations = annotations.run_pass(pass_args.as_mut())?; + let reg = reg.run_pass(pass_args)?; + Ok((annotations, reg).call(|(annotations, reg)| match reg { + AnyReg::Reg(reg) => StmtReg { annotations, reg }.into(), + AnyReg::RegSync(reg) => StmtReg { annotations, reg }.into(), + AnyReg::RegAsync(reg) => StmtReg { annotations, reg }.into(), + })) + } +} + impl_run_pass_for_struct! { impl[] RunPass for TargetPathBundleField { name: _, @@ -1640,23 +2183,56 @@ impl_run_pass_for_struct! { } } -impl_run_pass_for_struct! { - #[adjust_pass_args = |module: &Module<_>, pass_args: &mut PassArgs<'_, _>| { - pass_args.fallback_error_source_location = module.source_location(); - }] - #[constructor = Module::new_unchecked( - name_id, - source_location, - body, - module_io, - module_annotations, - )] - impl[] RunPass for Module { - name_id(): _, - source_location(): _, - body(): _, - module_io(): _, - module_annotations(): _, +impl RunPassDispatch for Module { + fn build_reset_graph( + &self, + mut pass_args: PassArgs<'_, BuildResetGraph>, + ) -> Result, DeduceResetsError> { + pass_args.fallback_error_source_location = self.source_location(); + if pass_args + .state + .modules_added_to_graph + .insert(pass_args.instantiated_module) + { + self.name_id().run_pass(pass_args.as_mut())?; + self.source_location().run_pass(pass_args.as_mut())?; + self.module_io().run_pass(pass_args.as_mut())?; + self.body().run_pass(pass_args.as_mut())?; + self.module_annotations().run_pass(pass_args.as_mut())?; + } + Ok(PassOutput(())) + } + + fn substitute_resets( + &self, + mut pass_args: PassArgs<'_, SubstituteResets>, + ) -> Result, DeduceResetsError> { + pass_args.fallback_error_source_location = self.source_location(); + if let Some(&retval) = pass_args + .state + .substituted_modules + .get(&pass_args.instantiated_module) + { + return Ok(PassOutput(retval)); + } + let PassOutput(name_id) = self.name_id().run_pass(pass_args.as_mut())?; + let PassOutput(source_location) = self.source_location().run_pass(pass_args.as_mut())?; + let PassOutput(module_io) = self.module_io().run_pass(pass_args.as_mut())?; + let PassOutput(body) = self.body().run_pass(pass_args.as_mut())?; + let PassOutput(module_annotations) = + self.module_annotations().run_pass(pass_args.as_mut())?; + let retval = Module::new_unchecked( + name_id, + source_location, + body, + module_io, + module_annotations, + ); + pass_args + .state + .substituted_modules + .insert(pass_args.instantiated_module, retval); + Ok(PassOutput(retval)) } } @@ -1665,7 +2241,8 @@ pub fn deduce_resets( fallback_to_sync_reset: bool, ) -> Result>, DeduceResetsError> { let mut state = State { - base_module: module, + modules_added_to_graph: HashSet::new(), + substituted_modules: HashMap::new(), expr_resets: HashMap::new(), reset_graph: ResetGraph::default(), fallback_to_sync_reset, diff --git a/crates/fayalite/tests/module.rs b/crates/fayalite/tests/module.rs index 222f7ba..4cb3057 100644 --- a/crates/fayalite/tests/module.rs +++ b/crates/fayalite/tests/module.rs @@ -2,7 +2,8 @@ // See Notices.txt for copyright information use fayalite::{ assert_export_firrtl, firrtl::ExportOptions, intern::Intern, - module::transform::simplify_enums::SimplifyEnumsKind, prelude::*, ty::StaticType, + module::transform::simplify_enums::SimplifyEnumsKind, prelude::*, reset::ResetType, + ty::StaticType, }; use serde_json::json; @@ -4026,3 +4027,263 @@ circuit check_enum_connect_any: ", }; } + +#[hdl_module(outline_generated)] +pub fn check_deduce_resets(ty: T) { + #[hdl] + let cd: ClockDomain = m.input(ClockDomain[ty]); + #[hdl] + let my_reg = reg_builder().reset(0u8).clock_domain(cd); + #[hdl] + let u8_in: UInt<8> = m.input(); + connect(my_reg, u8_in); + #[hdl] + let u8_out: UInt<8> = m.output(); + connect(u8_out, my_reg); + #[hdl] + let enum_in: OneOfThree = m.input(); + #[hdl] + let enum_out: OneOfThree = m.output(); + #[hdl] + let reset_out: Reset = m.output(); + connect(reset_out, cd.rst.to_reset()); + #[hdl] + match enum_in { + OneOfThree::<_, _, _>::A(v) => { + connect( + enum_out, + OneOfThree[Reset][AsyncReset][SyncReset].A(cd.rst.to_reset()), + ); + connect(reset_out, v); + } + OneOfThree::<_, _, _>::B(v) => { + connect(enum_out, OneOfThree[Reset][AsyncReset][SyncReset].B(v)) + } + OneOfThree::<_, _, _>::C(v) => { + connect(enum_out, OneOfThree[Reset][AsyncReset][SyncReset].C(v)) + } + } +} + +#[test] +fn test_deduce_resets() { + let _n = SourceLocation::normalize_files_for_tests(); + let m = check_deduce_resets(Reset); + dbg!(m); + #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 + assert_export_firrtl! { + m => + options: ExportOptions { + simplify_enums: None, + ..ExportOptions::default() + }, + "/test/check_deduce_resets.fir": r"FIRRTL version 3.2.0 +circuit check_deduce_resets: + type Ty0 = {clk: Clock, rst: Reset} + type Ty1 = {|A: Reset, B: AsyncReset, C: UInt<1>|} + module check_deduce_resets: @[module-XXXXXXXXXX.rs 1:1] + input cd: Ty0 @[module-XXXXXXXXXX.rs 2:1] + input u8_in: UInt<8> @[module-XXXXXXXXXX.rs 4:1] + output u8_out: UInt<8> @[module-XXXXXXXXXX.rs 6:1] + input enum_in: Ty1 @[module-XXXXXXXXXX.rs 8:1] + output enum_out: Ty1 @[module-XXXXXXXXXX.rs 9:1] + output reset_out: Reset @[module-XXXXXXXXXX.rs 10:1] + regreset my_reg: UInt<8>, cd.clk, cd.rst, UInt<8>(0h0) @[module-XXXXXXXXXX.rs 3:1] + connect my_reg, u8_in @[module-XXXXXXXXXX.rs 5:1] + connect u8_out, my_reg @[module-XXXXXXXXXX.rs 7:1] + connect reset_out, cd.rst @[module-XXXXXXXXXX.rs 11:1] + match enum_in: @[module-XXXXXXXXXX.rs 12:1] + A(_match_arm_value): + connect enum_out, {|A: Reset, B: AsyncReset, C: UInt<1>|}(A, cd.rst) @[module-XXXXXXXXXX.rs 13:1] + connect reset_out, _match_arm_value @[module-XXXXXXXXXX.rs 14:1] + B(_match_arm_value_1): + connect enum_out, {|A: Reset, B: AsyncReset, C: UInt<1>|}(B, _match_arm_value_1) @[module-XXXXXXXXXX.rs 15:1] + C(_match_arm_value_2): + connect enum_out, {|A: Reset, B: AsyncReset, C: UInt<1>|}(C, _match_arm_value_2) @[module-XXXXXXXXXX.rs 16:1] +", + }; + fayalite::module::transform::deduce_resets::deduce_resets(m.canonical().intern_sized(), false) + .unwrap_err(); + let m = fayalite::module::transform::deduce_resets::deduce_resets( + m.canonical().intern_sized(), + true, + ) + .unwrap(); + dbg!(m); + #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 + assert_export_firrtl! { + m => + options: ExportOptions { + simplify_enums: None, + ..ExportOptions::default() + }, + "/test/check_deduce_resets.fir": r"FIRRTL version 3.2.0 +circuit check_deduce_resets: + type Ty0 = {clk: Clock, rst: UInt<1>} + type Ty1 = {|A: UInt<1>, B: AsyncReset, C: UInt<1>|} + module check_deduce_resets: @[module-XXXXXXXXXX.rs 1:1] + input cd: Ty0 @[module-XXXXXXXXXX.rs 2:1] + input u8_in: UInt<8> @[module-XXXXXXXXXX.rs 4:1] + output u8_out: UInt<8> @[module-XXXXXXXXXX.rs 6:1] + input enum_in: Ty1 @[module-XXXXXXXXXX.rs 8:1] + output enum_out: Ty1 @[module-XXXXXXXXXX.rs 9:1] + output reset_out: UInt<1> @[module-XXXXXXXXXX.rs 10:1] + regreset my_reg: UInt<8>, cd.clk, cd.rst, UInt<8>(0h0) @[module-XXXXXXXXXX.rs 3:1] + connect my_reg, u8_in @[module-XXXXXXXXXX.rs 5:1] + connect u8_out, my_reg @[module-XXXXXXXXXX.rs 7:1] + connect reset_out, cd.rst @[module-XXXXXXXXXX.rs 11:1] + match enum_in: @[module-XXXXXXXXXX.rs 12:1] + A(_match_arm_value): + connect enum_out, {|A: UInt<1>, B: AsyncReset, C: UInt<1>|}(A, cd.rst) @[module-XXXXXXXXXX.rs 13:1] + connect reset_out, _match_arm_value @[module-XXXXXXXXXX.rs 14:1] + B(_match_arm_value_1): + connect enum_out, {|A: UInt<1>, B: AsyncReset, C: UInt<1>|}(B, _match_arm_value_1) @[module-XXXXXXXXXX.rs 15:1] + C(_match_arm_value_2): + connect enum_out, {|A: UInt<1>, B: AsyncReset, C: UInt<1>|}(C, _match_arm_value_2) @[module-XXXXXXXXXX.rs 16:1] +", + }; + let m = check_deduce_resets(SyncReset); + dbg!(m); + #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 + assert_export_firrtl! { + m => + options: ExportOptions { + simplify_enums: None, + ..ExportOptions::default() + }, + "/test/check_deduce_resets.fir": r"FIRRTL version 3.2.0 +circuit check_deduce_resets: + type Ty0 = {clk: Clock, rst: UInt<1>} + type Ty1 = {|A: Reset, B: AsyncReset, C: UInt<1>|} + module check_deduce_resets: @[module-XXXXXXXXXX.rs 1:1] + input cd: Ty0 @[module-XXXXXXXXXX.rs 2:1] + input u8_in: UInt<8> @[module-XXXXXXXXXX.rs 4:1] + output u8_out: UInt<8> @[module-XXXXXXXXXX.rs 6:1] + input enum_in: Ty1 @[module-XXXXXXXXXX.rs 8:1] + output enum_out: Ty1 @[module-XXXXXXXXXX.rs 9:1] + output reset_out: Reset @[module-XXXXXXXXXX.rs 10:1] + regreset my_reg: UInt<8>, cd.clk, cd.rst, UInt<8>(0h0) @[module-XXXXXXXXXX.rs 3:1] + connect my_reg, u8_in @[module-XXXXXXXXXX.rs 5:1] + connect u8_out, my_reg @[module-XXXXXXXXXX.rs 7:1] + connect reset_out, cd.rst @[module-XXXXXXXXXX.rs 11:1] + match enum_in: @[module-XXXXXXXXXX.rs 12:1] + A(_match_arm_value): + connect enum_out, {|A: Reset, B: AsyncReset, C: UInt<1>|}(A, cd.rst) @[module-XXXXXXXXXX.rs 13:1] + connect reset_out, _match_arm_value @[module-XXXXXXXXXX.rs 14:1] + B(_match_arm_value_1): + connect enum_out, {|A: Reset, B: AsyncReset, C: UInt<1>|}(B, _match_arm_value_1) @[module-XXXXXXXXXX.rs 15:1] + C(_match_arm_value_2): + connect enum_out, {|A: Reset, B: AsyncReset, C: UInt<1>|}(C, _match_arm_value_2) @[module-XXXXXXXXXX.rs 16:1] +", + }; + let m = fayalite::module::transform::deduce_resets::deduce_resets( + m.canonical().intern_sized(), + false, + ) + .unwrap(); + dbg!(m); + #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 + assert_export_firrtl! { + m => + options: ExportOptions { + simplify_enums: None, + ..ExportOptions::default() + }, + "/test/check_deduce_resets.fir": r"FIRRTL version 3.2.0 +circuit check_deduce_resets: + type Ty0 = {clk: Clock, rst: UInt<1>} + type Ty1 = {|A: UInt<1>, B: AsyncReset, C: UInt<1>|} + module check_deduce_resets: @[module-XXXXXXXXXX.rs 1:1] + input cd: Ty0 @[module-XXXXXXXXXX.rs 2:1] + input u8_in: UInt<8> @[module-XXXXXXXXXX.rs 4:1] + output u8_out: UInt<8> @[module-XXXXXXXXXX.rs 6:1] + input enum_in: Ty1 @[module-XXXXXXXXXX.rs 8:1] + output enum_out: Ty1 @[module-XXXXXXXXXX.rs 9:1] + output reset_out: UInt<1> @[module-XXXXXXXXXX.rs 10:1] + regreset my_reg: UInt<8>, cd.clk, cd.rst, UInt<8>(0h0) @[module-XXXXXXXXXX.rs 3:1] + connect my_reg, u8_in @[module-XXXXXXXXXX.rs 5:1] + connect u8_out, my_reg @[module-XXXXXXXXXX.rs 7:1] + connect reset_out, cd.rst @[module-XXXXXXXXXX.rs 11:1] + match enum_in: @[module-XXXXXXXXXX.rs 12:1] + A(_match_arm_value): + connect enum_out, {|A: UInt<1>, B: AsyncReset, C: UInt<1>|}(A, cd.rst) @[module-XXXXXXXXXX.rs 13:1] + connect reset_out, _match_arm_value @[module-XXXXXXXXXX.rs 14:1] + B(_match_arm_value_1): + connect enum_out, {|A: UInt<1>, B: AsyncReset, C: UInt<1>|}(B, _match_arm_value_1) @[module-XXXXXXXXXX.rs 15:1] + C(_match_arm_value_2): + connect enum_out, {|A: UInt<1>, B: AsyncReset, C: UInt<1>|}(C, _match_arm_value_2) @[module-XXXXXXXXXX.rs 16:1] +", + }; + let m = check_deduce_resets(AsyncReset); + dbg!(m); + #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 + assert_export_firrtl! { + m => + options: ExportOptions { + simplify_enums: None, + ..ExportOptions::default() + }, + "/test/check_deduce_resets.fir": r"FIRRTL version 3.2.0 +circuit check_deduce_resets: + type Ty0 = {clk: Clock, rst: AsyncReset} + type Ty1 = {|A: Reset, B: AsyncReset, C: UInt<1>|} + module check_deduce_resets: @[module-XXXXXXXXXX.rs 1:1] + input cd: Ty0 @[module-XXXXXXXXXX.rs 2:1] + input u8_in: UInt<8> @[module-XXXXXXXXXX.rs 4:1] + output u8_out: UInt<8> @[module-XXXXXXXXXX.rs 6:1] + input enum_in: Ty1 @[module-XXXXXXXXXX.rs 8:1] + output enum_out: Ty1 @[module-XXXXXXXXXX.rs 9:1] + output reset_out: Reset @[module-XXXXXXXXXX.rs 10:1] + regreset my_reg: UInt<8>, cd.clk, cd.rst, UInt<8>(0h0) @[module-XXXXXXXXXX.rs 3:1] + connect my_reg, u8_in @[module-XXXXXXXXXX.rs 5:1] + connect u8_out, my_reg @[module-XXXXXXXXXX.rs 7:1] + connect reset_out, cd.rst @[module-XXXXXXXXXX.rs 11:1] + match enum_in: @[module-XXXXXXXXXX.rs 12:1] + A(_match_arm_value): + connect enum_out, {|A: Reset, B: AsyncReset, C: UInt<1>|}(A, cd.rst) @[module-XXXXXXXXXX.rs 13:1] + connect reset_out, _match_arm_value @[module-XXXXXXXXXX.rs 14:1] + B(_match_arm_value_1): + connect enum_out, {|A: Reset, B: AsyncReset, C: UInt<1>|}(B, _match_arm_value_1) @[module-XXXXXXXXXX.rs 15:1] + C(_match_arm_value_2): + connect enum_out, {|A: Reset, B: AsyncReset, C: UInt<1>|}(C, _match_arm_value_2) @[module-XXXXXXXXXX.rs 16:1] +", + }; + let m = fayalite::module::transform::deduce_resets::deduce_resets( + m.canonical().intern_sized(), + false, + ) + .unwrap(); + dbg!(m); + #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 + assert_export_firrtl! { + m => + options: ExportOptions { + simplify_enums: None, + ..ExportOptions::default() + }, + "/test/check_deduce_resets.fir": r"FIRRTL version 3.2.0 +circuit check_deduce_resets: + type Ty0 = {clk: Clock, rst: AsyncReset} + type Ty1 = {|A: AsyncReset, B: AsyncReset, C: UInt<1>|} + module check_deduce_resets: @[module-XXXXXXXXXX.rs 1:1] + input cd: Ty0 @[module-XXXXXXXXXX.rs 2:1] + input u8_in: UInt<8> @[module-XXXXXXXXXX.rs 4:1] + output u8_out: UInt<8> @[module-XXXXXXXXXX.rs 6:1] + input enum_in: Ty1 @[module-XXXXXXXXXX.rs 8:1] + output enum_out: Ty1 @[module-XXXXXXXXXX.rs 9:1] + output reset_out: AsyncReset @[module-XXXXXXXXXX.rs 10:1] + regreset my_reg: UInt<8>, cd.clk, cd.rst, UInt<8>(0h0) @[module-XXXXXXXXXX.rs 3:1] + connect my_reg, u8_in @[module-XXXXXXXXXX.rs 5:1] + connect u8_out, my_reg @[module-XXXXXXXXXX.rs 7:1] + connect reset_out, cd.rst @[module-XXXXXXXXXX.rs 11:1] + match enum_in: @[module-XXXXXXXXXX.rs 12:1] + A(_match_arm_value): + connect enum_out, {|A: AsyncReset, B: AsyncReset, C: UInt<1>|}(A, cd.rst) @[module-XXXXXXXXXX.rs 13:1] + connect reset_out, _match_arm_value @[module-XXXXXXXXXX.rs 14:1] + B(_match_arm_value_1): + connect enum_out, {|A: AsyncReset, B: AsyncReset, C: UInt<1>|}(B, _match_arm_value_1) @[module-XXXXXXXXXX.rs 15:1] + C(_match_arm_value_2): + connect enum_out, {|A: AsyncReset, B: AsyncReset, C: UInt<1>|}(C, _match_arm_value_2) @[module-XXXXXXXXXX.rs 16:1] +", + }; +}