From b33566841d2adcdc2823b2a62d7d006278a02118 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 25 Jul 2024 22:10:33 -0700 Subject: [PATCH] make #[hdl_module] support functions with generic parameters --- .../fayalite-proc-macros-impl/src/module.rs | 98 ++++++++++++++++--- crates/fayalite/tests/module.rs | 89 ++++++++++++++++- 2 files changed, 169 insertions(+), 18 deletions(-) diff --git a/crates/fayalite-proc-macros-impl/src/module.rs b/crates/fayalite-proc-macros-impl/src/module.rs index a84bc71..b490096 100644 --- a/crates/fayalite-proc-macros-impl/src/module.rs +++ b/crates/fayalite-proc-macros-impl/src/module.rs @@ -3,15 +3,17 @@ use crate::{ is_hdl_attr, module::transform_body::{HdlLet, HdlLetKindIO}, - options, Errors, HdlAttr, + options, Errors, HdlAttr, PairsIterExt, }; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; +use std::collections::HashSet; use syn::{ parse::{Parse, ParseStream}, parse_quote, visit::{visit_pat, Visit}, - Attribute, Block, Error, FnArg, Ident, ItemFn, ItemStruct, ReturnType, Signature, Visibility, + Attribute, Block, ConstParam, Error, FnArg, GenericParam, Generics, Ident, ItemFn, ItemStruct, + LifetimeParam, ReturnType, Signature, TypeParam, Visibility, WhereClause, WherePredicate, }; mod transform_body; @@ -72,6 +74,7 @@ pub(crate) struct ModuleFn { sig: Signature, block: Box, io: Vec, + struct_generics: Generics, } #[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] @@ -80,6 +83,19 @@ pub(crate) enum ModuleKind { Normal, } +struct ContainsSkippedIdent<'a> { + skipped_idents: &'a HashSet, + contains_skipped_ident: bool, +} + +impl Visit<'_> for ContainsSkippedIdent<'_> { + fn visit_ident(&mut self, ident: &'_ Ident) { + if self.skipped_idents.contains(ident) { + self.contains_skipped_ident = true; + } + } +} + impl Parse for ModuleFn { fn parse(input: ParseStream) -> syn::Result { let ItemFn { @@ -95,7 +111,7 @@ impl Parse for ModuleFn { ref abi, fn_token: _, ident: _, - ref generics, + ref mut generics, paren_token: _, ref mut inputs, ref variadic, @@ -140,12 +156,61 @@ impl Parse for ModuleFn { if let Some(abi) = abi { errors.push(syn::Error::new_spanned(abi, "extern not allowed here")); } - if !generics.params.is_empty() { - errors.push(syn::Error::new_spanned( - &generics.params, - "generics are not supported yet", - )); - } + let mut skipped_idents = HashSet::new(); + let struct_generic_params = generics + .params + .pairs_mut() + .filter_map_pair_value_mut(|v| match v { + GenericParam::Lifetime(LifetimeParam { attrs, .. }) => { + errors + .unwrap_or_default(HdlAttr::::parse_and_take_attr(attrs)); + None + } + GenericParam::Type(TypeParam { attrs, ident, .. }) + | GenericParam::Const(ConstParam { attrs, ident, .. }) => { + if errors + .unwrap_or_default(HdlAttr::::parse_and_take_attr(attrs)) + .is_some() + { + skipped_idents.insert(ident.clone()); + None + } else { + Some(v.clone()) + } + } + }) + .collect(); + let struct_where_clause = generics + .where_clause + .as_mut() + .map(|where_clause| WhereClause { + where_token: where_clause.where_token, + predicates: where_clause + .predicates + .pairs_mut() + .filter_map_pair_value_mut(|v| match v { + WherePredicate::Lifetime(_) => None, + _ => { + let mut contains_skipped_ident = ContainsSkippedIdent { + skipped_idents: &skipped_idents, + contains_skipped_ident: false, + }; + contains_skipped_ident.visit_where_predicate(v); + if contains_skipped_ident.contains_skipped_ident { + None + } else { + Some(v.clone()) + } + } + }) + .collect(), + }); + let struct_generics = Generics { + lt_token: generics.lt_token, + params: struct_generic_params, + gt_token: generics.gt_token, + where_clause: struct_where_clause, + }; if let Some(variadic) = variadic { errors.push(syn::Error::new_spanned(variadic, "... not allowed here")); } @@ -166,6 +231,7 @@ impl Parse for ModuleFn { sig, block, io, + struct_generics, }) } } @@ -180,6 +246,7 @@ impl ModuleFn { sig, block, io, + struct_generics, } = self; let ConfigOptions { outline_generated: _, @@ -211,10 +278,13 @@ impl ModuleFn { ModuleKind::Normal => quote! { ::fayalite::module::NormalModule }, }; let fn_name = &outer_sig.ident; + let (_struct_impl_generics, struct_type_generics, struct_where_clause) = + struct_generics.split_for_impl(); + let struct_ty = quote! {#fn_name #struct_type_generics}; body_sig.ident = parse_quote! {__body}; body_sig.inputs.insert( 0, - parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#fn_name, #module_kind_ty>}, + parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#struct_ty, #module_kind_ty>}, ); let body_fn = ItemFn { attrs: vec![], @@ -223,7 +293,7 @@ impl ModuleFn { block, }; outer_sig.output = - parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#fn_name>>}; + parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#struct_ty>>}; let io_flips = io .iter() .map(|io| match io.kind.kind { @@ -236,9 +306,11 @@ impl ModuleFn { let io_types = io.iter().map(|io| &io.kind.ty).collect::>(); let io_names = io.iter().map(|io| &io.name).collect::>(); let fn_name_str = fn_name.to_string(); + let (_, body_type_generics, _) = body_fn.sig.generics.split_for_impl(); + let body_turbofish_type_generics = body_type_generics.as_turbofish(); let block = parse_quote! {{ #body_fn - ::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body(m, #(#param_names,)*)) + ::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body #body_turbofish_type_generics(m, #(#param_names,)*)) }}; let fixed_type = io.iter().all(|io| io.kind.ty_expr.is_none()); let struct_options = if fixed_type { @@ -254,7 +326,7 @@ impl ModuleFn { ::fayalite::__std::fmt::Debug)] #[allow(non_camel_case_types)] #struct_options - #vis struct #fn_name { + #vis struct #fn_name #struct_generics #struct_where_clause { #( #io_flips #vis #io_names: #io_types,)* diff --git a/crates/fayalite/tests/module.rs b/crates/fayalite/tests/module.rs index d996a17..9a83c59 100644 --- a/crates/fayalite/tests/module.rs +++ b/crates/fayalite/tests/module.rs @@ -5,6 +5,7 @@ use fayalite::{ array::Array, assert_export_firrtl, clock::{Clock, ClockDomain}, + expr::ToExpr, hdl_module, int::{DynUInt, DynUIntType, IntCmp, SInt, UInt}, intern::Intern, @@ -12,7 +13,7 @@ use fayalite::{ module::transform::simplify_enums::{simplify_enums, SimplifyEnumsKind}, reset::{SyncReset, ToReset}, source_location::SourceLocation, - ty::Value, + ty::{FixedType, Value}, }; use serde_json::json; @@ -139,22 +140,22 @@ circuit my_module: } #[hdl_module(outline_generated)] -pub fn check_array_repeat() { +pub fn check_array_repeat() { #[hdl] let i: UInt<8> = m.input(); #[hdl] - let o: Array<[UInt<8>; 3]> = m.output(); + let o: Array<[UInt<8>; N]> = m.output(); m.connect( o, #[hdl] - [i; 3], + [i; N], ); } #[test] fn test_array_repeat() { let _n = SourceLocation::normalize_files_for_tests(); - let m = check_array_repeat(); + let m = check_array_repeat::<3>(); dbg!(m); #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 assert_export_firrtl! { @@ -169,6 +170,84 @@ circuit check_array_repeat: connect _array_literal_expr[1], i connect _array_literal_expr[2], i connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1] +", + }; + let m = check_array_repeat::<4>(); + dbg!(m); + #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 + assert_export_firrtl! { + m => + "/test/check_array_repeat_1.fir": r"FIRRTL version 3.2.0 +circuit check_array_repeat_1: + module check_array_repeat_1: @[module-XXXXXXXXXX.rs 1:1] + input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1] + output o: UInt<8>[4] @[module-XXXXXXXXXX.rs 3:1] + wire _array_literal_expr: UInt<8>[4] + connect _array_literal_expr[0], i + connect _array_literal_expr[1], i + connect _array_literal_expr[2], i + connect _array_literal_expr[3], i + connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1] +", + }; +} + +#[hdl_module(outline_generated)] +pub fn check_skipped_generics(v: U) +where + T: Value>, + U: std::fmt::Display, +{ + dbg!(M); + #[hdl] + let i: T = m.input(); + #[hdl] + let o: Array<[T; N]> = m.output(); + let bytes = v.to_string().as_bytes().to_expr(); + #[hdl] + let o2: Array<[UInt<8>]> = m.output(bytes.ty()); + m.connect( + o, + #[hdl] + [i; N], + ); + m.connect(o2, bytes); +} + +#[test] +fn test_skipped_generics() { + let _n = SourceLocation::normalize_files_for_tests(); + let m = check_skipped_generics::, _, 3, 4>("Hello World!\n"); + dbg!(m); + #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 + assert_export_firrtl! { + m => + "/test/check_skipped_generics.fir": r"FIRRTL version 3.2.0 +circuit check_skipped_generics: + module check_skipped_generics: @[module-XXXXXXXXXX.rs 1:1] + input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1] + output o: UInt<8>[3] @[module-XXXXXXXXXX.rs 3:1] + output o2: UInt<8>[13] @[module-XXXXXXXXXX.rs 4:1] + wire _array_literal_expr: UInt<8>[3] + connect _array_literal_expr[0], i + connect _array_literal_expr[1], i + connect _array_literal_expr[2], i + connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 5:1] + wire _array_literal_expr_1: UInt<8>[13] + connect _array_literal_expr_1[0], UInt<8>(0h48) + connect _array_literal_expr_1[1], UInt<8>(0h65) + connect _array_literal_expr_1[2], UInt<8>(0h6C) + connect _array_literal_expr_1[3], UInt<8>(0h6C) + connect _array_literal_expr_1[4], UInt<8>(0h6F) + connect _array_literal_expr_1[5], UInt<8>(0h20) + connect _array_literal_expr_1[6], UInt<8>(0h57) + connect _array_literal_expr_1[7], UInt<8>(0h6F) + connect _array_literal_expr_1[8], UInt<8>(0h72) + connect _array_literal_expr_1[9], UInt<8>(0h6C) + connect _array_literal_expr_1[10], UInt<8>(0h64) + connect _array_literal_expr_1[11], UInt<8>(0h21) + connect _array_literal_expr_1[12], UInt<8>(0hA) + connect o2, _array_literal_expr_1 @[module-XXXXXXXXXX.rs 6:1] ", }; }