From 7963f0a5cd0b9eb9123aaa8d6523bb600405ac26 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 25 Jul 2024 22:07:23 -0700 Subject: [PATCH 1/3] add Iterator> helpers --- crates/fayalite-proc-macros-impl/src/lib.rs | 106 +++++++++++++++++- .../expand_aggregate_literals.rs | 65 +++++------ .../src/module/transform_body/expand_match.rs | 14 +-- 3 files changed, 140 insertions(+), 45 deletions(-) diff --git a/crates/fayalite-proc-macros-impl/src/lib.rs b/crates/fayalite-proc-macros-impl/src/lib.rs index 7aafe66..3a4a326 100644 --- a/crates/fayalite-proc-macros-impl/src/lib.rs +++ b/crates/fayalite-proc-macros-impl/src/lib.rs @@ -7,7 +7,9 @@ use std::io::{ErrorKind, Write}; use syn::{ bracketed, parenthesized, parse::{Parse, ParseStream, Parser}, - parse_quote, AttrStyle, Attribute, Error, Item, Token, + parse_quote, + punctuated::Pair, + AttrStyle, Attribute, Error, Item, Token, }; mod fold; @@ -318,6 +320,108 @@ impl HdlAttr { } } +#[allow(dead_code)] +pub(crate) trait PairsIterExt: Sized + Iterator { + fn map_pair T2, PunctFn: FnMut(P1) -> P2>( + self, + mut value_fn: ValueFn, + mut punct_fn: PunctFn, + ) -> impl Iterator> + where + Self: Iterator>, + { + self.map(move |p| { + let (t, p) = p.into_tuple(); + let t = value_fn(t); + let p = p.map(&mut punct_fn); + Pair::new(t, p) + }) + } + fn filter_map_pair Option, PunctFn: FnMut(P1) -> P2>( + self, + mut value_fn: ValueFn, + mut punct_fn: PunctFn, + ) -> impl Iterator> + where + Self: Iterator>, + { + self.filter_map(move |p| { + let (t, p) = p.into_tuple(); + let t = value_fn(t)?; + let p = p.map(&mut punct_fn); + Some(Pair::new(t, p)) + }) + } + fn map_pair_value T2>( + self, + f: F, + ) -> impl Iterator> + where + Self: Iterator>, + { + self.map_pair(f, |v| v) + } + fn filter_map_pair_value Option>( + self, + f: F, + ) -> impl Iterator> + where + Self: Iterator>, + { + self.filter_map_pair(f, |v| v) + } + fn map_pair_value_mut<'a, T1: 'a, T2: 'a, P: Clone + 'a, F: FnMut(T1) -> T2 + 'a>( + self, + f: F, + ) -> impl Iterator> + 'a + where + Self: Iterator> + 'a, + { + self.map_pair(f, |v| v.clone()) + } + fn filter_map_pair_value_mut< + 'a, + T1: 'a, + T2: 'a, + P: Clone + 'a, + F: FnMut(T1) -> Option + 'a, + >( + self, + f: F, + ) -> impl Iterator> + 'a + where + Self: Iterator> + 'a, + { + self.filter_map_pair(f, |v| v.clone()) + } + fn map_pair_value_ref<'a, T1: 'a, T2: 'a, P: Clone + 'a, F: FnMut(T1) -> T2 + 'a>( + self, + f: F, + ) -> impl Iterator> + 'a + where + Self: Iterator> + 'a, + { + self.map_pair(f, |v| v.clone()) + } + fn filter_map_pair_value_ref< + 'a, + T1: 'a, + T2: 'a, + P: Clone + 'a, + F: FnMut(T1) -> Option + 'a, + >( + self, + f: F, + ) -> impl Iterator> + 'a + where + Self: Iterator> + 'a, + { + self.filter_map_pair(f, |v| v.clone()) + } +} + +impl>> PairsIterExt for Iter {} + pub(crate) struct Errors { error: Option, finished: bool, diff --git a/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_aggregate_literals.rs b/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_aggregate_literals.rs index ccc7252..bfa4a51 100644 --- a/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_aggregate_literals.rs +++ b/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_aggregate_literals.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: LGPL-3.0-or-later // See Notices.txt for copyright information -use crate::{module::transform_body::Visitor, options, Errors, HdlAttr}; +use crate::{module::transform_body::Visitor, options, Errors, HdlAttr, PairsIterExt}; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote_spanned, ToTokens, TokenStreamExt}; use syn::{ @@ -207,7 +207,7 @@ impl StructOrEnumLiteral { } pub(crate) fn map_fields( self, - mut f: impl FnMut(StructOrEnumLiteralField) -> StructOrEnumLiteralField, + f: impl FnMut(StructOrEnumLiteralField) -> StructOrEnumLiteralField, ) -> Self { let Self { attrs, @@ -217,10 +217,7 @@ impl StructOrEnumLiteral { dot2_token, rest, } = self; - let fields = Punctuated::from_iter(fields.into_pairs().map(|p| { - let (field, comma) = p.into_tuple(); - Pair::new(f(field), comma) - })); + let fields = fields.into_pairs().map_pair_value(f).collect(); Self { attrs, path, @@ -247,26 +244,22 @@ impl From for StructOrEnumLiteral { attrs, path: TypePath { qself, path }, brace_or_paren: BraceOrParen::Brace(brace_token), - fields: Punctuated::from_iter(fields.into_pairs().map(|v| { - let ( - FieldValue { + fields: fields + .into_pairs() + .map_pair_value( + |FieldValue { + attrs, + member, + colon_token, + expr, + }| StructOrEnumLiteralField { attrs, member, colon_token, expr, }, - comma, - ) = v.into_tuple(); - Pair::new( - StructOrEnumLiteralField { - attrs, - member, - colon_token, - expr, - }, - comma, ) - })), + .collect(), dot2_token, rest, } @@ -514,20 +507,24 @@ impl Visitor { } }; }; - let fields = Punctuated::from_iter(args.into_pairs().enumerate().map(|(index, p)| { - let (expr, comma) = p.into_tuple(); - let mut index = Index::from(index); - index.span = hdl_attr.hdl.span; - Pair::new( - StructOrEnumLiteralField { - attrs: vec![], - member: Member::Unnamed(index), - colon_token: None, - expr, - }, - comma, - ) - })); + let fields = args + .into_pairs() + .enumerate() + .map(|(index, p)| { + let (expr, comma) = p.into_tuple(); + let mut index = Index::from(index); + index.span = hdl_attr.hdl.span; + Pair::new( + StructOrEnumLiteralField { + attrs: vec![], + member: Member::Unnamed(index), + colon_token: None, + expr, + }, + comma, + ) + }) + .collect(); self.process_struct_enum( hdl_attr, StructOrEnumLiteral { diff --git a/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs b/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs index 58b0bbb..fe1a895 100644 --- a/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs +++ b/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs @@ -6,7 +6,7 @@ use crate::{ expand_aggregate_literals::{AggregateLiteralOptions, StructOrEnumPath}, with_debug_clone_and_fold, Visitor, }, - Errors, HdlAttr, + Errors, HdlAttr, PairsIterExt, }; use proc_macro2::{Span, TokenStream}; use quote::{ToTokens, TokenStreamExt}; @@ -238,11 +238,7 @@ trait ParseMatchPat: Sized { leading_vert, cases: cases .into_pairs() - .filter_map(|pair| { - let (pat, punct) = pair.into_tuple(); - let pat = Self::parse(state, pat).ok()?; - Some(Pair::new(pat, punct)) - }) + .filter_map_pair_value(|pat| Self::parse(state, pat).ok()) .collect(), })), Pat::Paren(PatParen { @@ -282,10 +278,8 @@ trait ParseMatchPat: Sized { }) => { let fields = fields .into_pairs() - .filter_map(|pair| { - let (field_pat, punct) = pair.into_tuple(); - let field_pat = MatchPatStructField::parse(state, field_pat).ok()?; - Some(Pair::new(field_pat, punct)) + .filter_map_pair_value(|field_pat| { + MatchPatStructField::parse(state, field_pat).ok() }) .collect(); let path = TypePath { qself, path }; From ef4b3b4081bfa06d194f55276cdda36f682c77a7 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 25 Jul 2024 22:08:28 -0700 Subject: [PATCH 2/3] make [T; N]: ToExpr for any N instead of a fixed list --- crates/fayalite/src/array.rs | 61 ++++-------------------------------- 1 file changed, 6 insertions(+), 55 deletions(-) diff --git a/crates/fayalite/src/array.rs b/crates/fayalite/src/array.rs index c0db3c2..d25987c 100644 --- a/crates/fayalite/src/array.rs +++ b/crates/fayalite/src/array.rs @@ -548,70 +548,21 @@ impl, T: FixedType> ToExpr for Vec { } } -impl, T: FixedType> ToExpr for [E; 0] { - type Type = ArrayType<[T::Value; 0]>; +impl, T: FixedType, const N: usize> ToExpr for [E; N] { + type Type = ArrayType<[T::Value; N]>; fn ty(&self) -> Self::Type { ArrayType::new_with_len_type(FixedType::fixed_type(), ()) } fn to_expr(&self) -> Expr<::Value> { - Array::new(FixedType::fixed_type(), Arc::new([])).to_expr() + let elements = Intern::intern_owned(Vec::from_iter( + self.iter().map(|v| v.to_expr().to_canonical_dyn()), + )); + ArrayLiteral::new_unchecked(elements, self.ty()).to_expr() } } -macro_rules! impl_to_expr_for_non_empty_array { - ($N:literal) => { - impl, T: Type> ToExpr for [E; $N] { - type Type = ArrayType<[T::Value; $N]>; - - fn ty(&self) -> Self::Type { - ArrayType::new_with_len_type(self[0].ty(), ()) - } - - fn to_expr(&self) -> Expr<::Value> { - let elements = Intern::intern_owned(Vec::from_iter( - self.iter().map(|v| v.to_expr().to_canonical_dyn()), - )); - ArrayLiteral::new_unchecked(elements, self.ty()).to_expr() - } - } - }; -} - -impl_to_expr_for_non_empty_array!(1); -impl_to_expr_for_non_empty_array!(2); -impl_to_expr_for_non_empty_array!(3); -impl_to_expr_for_non_empty_array!(4); -impl_to_expr_for_non_empty_array!(5); -impl_to_expr_for_non_empty_array!(6); -impl_to_expr_for_non_empty_array!(7); -impl_to_expr_for_non_empty_array!(8); -impl_to_expr_for_non_empty_array!(9); -impl_to_expr_for_non_empty_array!(10); -impl_to_expr_for_non_empty_array!(11); -impl_to_expr_for_non_empty_array!(12); -impl_to_expr_for_non_empty_array!(13); -impl_to_expr_for_non_empty_array!(14); -impl_to_expr_for_non_empty_array!(15); -impl_to_expr_for_non_empty_array!(16); -impl_to_expr_for_non_empty_array!(17); -impl_to_expr_for_non_empty_array!(18); -impl_to_expr_for_non_empty_array!(19); -impl_to_expr_for_non_empty_array!(20); -impl_to_expr_for_non_empty_array!(21); -impl_to_expr_for_non_empty_array!(22); -impl_to_expr_for_non_empty_array!(23); -impl_to_expr_for_non_empty_array!(24); -impl_to_expr_for_non_empty_array!(25); -impl_to_expr_for_non_empty_array!(26); -impl_to_expr_for_non_empty_array!(27); -impl_to_expr_for_non_empty_array!(28); -impl_to_expr_for_non_empty_array!(29); -impl_to_expr_for_non_empty_array!(30); -impl_to_expr_for_non_empty_array!(31); -impl_to_expr_for_non_empty_array!(32); - #[derive(Clone, Debug)] pub struct ArrayIntoIter { array: Arc, From c5901cd2177c38a43da03603367ea87a6561ead3 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 25 Jul 2024 22:10:33 -0700 Subject: [PATCH 3/3] make #[hdl_module] support functions with generic parameters --- .../fayalite-proc-macros-impl/src/module.rs | 98 ++++++++++++++++--- crates/fayalite/tests/module.rs | 89 ++++++++++++++++- 2 files changed, 169 insertions(+), 18 deletions(-) diff --git a/crates/fayalite-proc-macros-impl/src/module.rs b/crates/fayalite-proc-macros-impl/src/module.rs index a84bc71..b490096 100644 --- a/crates/fayalite-proc-macros-impl/src/module.rs +++ b/crates/fayalite-proc-macros-impl/src/module.rs @@ -3,15 +3,17 @@ use crate::{ is_hdl_attr, module::transform_body::{HdlLet, HdlLetKindIO}, - options, Errors, HdlAttr, + options, Errors, HdlAttr, PairsIterExt, }; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; +use std::collections::HashSet; use syn::{ parse::{Parse, ParseStream}, parse_quote, visit::{visit_pat, Visit}, - Attribute, Block, Error, FnArg, Ident, ItemFn, ItemStruct, ReturnType, Signature, Visibility, + Attribute, Block, ConstParam, Error, FnArg, GenericParam, Generics, Ident, ItemFn, ItemStruct, + LifetimeParam, ReturnType, Signature, TypeParam, Visibility, WhereClause, WherePredicate, }; mod transform_body; @@ -72,6 +74,7 @@ pub(crate) struct ModuleFn { sig: Signature, block: Box, io: Vec, + struct_generics: Generics, } #[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] @@ -80,6 +83,19 @@ pub(crate) enum ModuleKind { Normal, } +struct ContainsSkippedIdent<'a> { + skipped_idents: &'a HashSet, + contains_skipped_ident: bool, +} + +impl Visit<'_> for ContainsSkippedIdent<'_> { + fn visit_ident(&mut self, ident: &'_ Ident) { + if self.skipped_idents.contains(ident) { + self.contains_skipped_ident = true; + } + } +} + impl Parse for ModuleFn { fn parse(input: ParseStream) -> syn::Result { let ItemFn { @@ -95,7 +111,7 @@ impl Parse for ModuleFn { ref abi, fn_token: _, ident: _, - ref generics, + ref mut generics, paren_token: _, ref mut inputs, ref variadic, @@ -140,12 +156,61 @@ impl Parse for ModuleFn { if let Some(abi) = abi { errors.push(syn::Error::new_spanned(abi, "extern not allowed here")); } - if !generics.params.is_empty() { - errors.push(syn::Error::new_spanned( - &generics.params, - "generics are not supported yet", - )); - } + let mut skipped_idents = HashSet::new(); + let struct_generic_params = generics + .params + .pairs_mut() + .filter_map_pair_value_mut(|v| match v { + GenericParam::Lifetime(LifetimeParam { attrs, .. }) => { + errors + .unwrap_or_default(HdlAttr::::parse_and_take_attr(attrs)); + None + } + GenericParam::Type(TypeParam { attrs, ident, .. }) + | GenericParam::Const(ConstParam { attrs, ident, .. }) => { + if errors + .unwrap_or_default(HdlAttr::::parse_and_take_attr(attrs)) + .is_some() + { + skipped_idents.insert(ident.clone()); + None + } else { + Some(v.clone()) + } + } + }) + .collect(); + let struct_where_clause = generics + .where_clause + .as_mut() + .map(|where_clause| WhereClause { + where_token: where_clause.where_token, + predicates: where_clause + .predicates + .pairs_mut() + .filter_map_pair_value_mut(|v| match v { + WherePredicate::Lifetime(_) => None, + _ => { + let mut contains_skipped_ident = ContainsSkippedIdent { + skipped_idents: &skipped_idents, + contains_skipped_ident: false, + }; + contains_skipped_ident.visit_where_predicate(v); + if contains_skipped_ident.contains_skipped_ident { + None + } else { + Some(v.clone()) + } + } + }) + .collect(), + }); + let struct_generics = Generics { + lt_token: generics.lt_token, + params: struct_generic_params, + gt_token: generics.gt_token, + where_clause: struct_where_clause, + }; if let Some(variadic) = variadic { errors.push(syn::Error::new_spanned(variadic, "... not allowed here")); } @@ -166,6 +231,7 @@ impl Parse for ModuleFn { sig, block, io, + struct_generics, }) } } @@ -180,6 +246,7 @@ impl ModuleFn { sig, block, io, + struct_generics, } = self; let ConfigOptions { outline_generated: _, @@ -211,10 +278,13 @@ impl ModuleFn { ModuleKind::Normal => quote! { ::fayalite::module::NormalModule }, }; let fn_name = &outer_sig.ident; + let (_struct_impl_generics, struct_type_generics, struct_where_clause) = + struct_generics.split_for_impl(); + let struct_ty = quote! {#fn_name #struct_type_generics}; body_sig.ident = parse_quote! {__body}; body_sig.inputs.insert( 0, - parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#fn_name, #module_kind_ty>}, + parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#struct_ty, #module_kind_ty>}, ); let body_fn = ItemFn { attrs: vec![], @@ -223,7 +293,7 @@ impl ModuleFn { block, }; outer_sig.output = - parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#fn_name>>}; + parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#struct_ty>>}; let io_flips = io .iter() .map(|io| match io.kind.kind { @@ -236,9 +306,11 @@ impl ModuleFn { let io_types = io.iter().map(|io| &io.kind.ty).collect::>(); let io_names = io.iter().map(|io| &io.name).collect::>(); let fn_name_str = fn_name.to_string(); + let (_, body_type_generics, _) = body_fn.sig.generics.split_for_impl(); + let body_turbofish_type_generics = body_type_generics.as_turbofish(); let block = parse_quote! {{ #body_fn - ::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body(m, #(#param_names,)*)) + ::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body #body_turbofish_type_generics(m, #(#param_names,)*)) }}; let fixed_type = io.iter().all(|io| io.kind.ty_expr.is_none()); let struct_options = if fixed_type { @@ -254,7 +326,7 @@ impl ModuleFn { ::fayalite::__std::fmt::Debug)] #[allow(non_camel_case_types)] #struct_options - #vis struct #fn_name { + #vis struct #fn_name #struct_generics #struct_where_clause { #( #io_flips #vis #io_names: #io_types,)* diff --git a/crates/fayalite/tests/module.rs b/crates/fayalite/tests/module.rs index d996a17..c0e49f5 100644 --- a/crates/fayalite/tests/module.rs +++ b/crates/fayalite/tests/module.rs @@ -5,6 +5,7 @@ use fayalite::{ array::Array, assert_export_firrtl, clock::{Clock, ClockDomain}, + expr::ToExpr, hdl_module, int::{DynUInt, DynUIntType, IntCmp, SInt, UInt}, intern::Intern, @@ -12,7 +13,7 @@ use fayalite::{ module::transform::simplify_enums::{simplify_enums, SimplifyEnumsKind}, reset::{SyncReset, ToReset}, source_location::SourceLocation, - ty::Value, + ty::{FixedType, Value}, }; use serde_json::json; @@ -139,22 +140,22 @@ circuit my_module: } #[hdl_module(outline_generated)] -pub fn check_array_repeat() { +pub fn check_array_repeat() { #[hdl] let i: UInt<8> = m.input(); #[hdl] - let o: Array<[UInt<8>; 3]> = m.output(); + let o: Array<[UInt<8>; N]> = m.output(); m.connect( o, #[hdl] - [i; 3], + [i; N], ); } #[test] fn test_array_repeat() { let _n = SourceLocation::normalize_files_for_tests(); - let m = check_array_repeat(); + let m = check_array_repeat::<3>(); dbg!(m); #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 assert_export_firrtl! { @@ -169,6 +170,84 @@ circuit check_array_repeat: connect _array_literal_expr[1], i connect _array_literal_expr[2], i connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1] +", + }; + let m = check_array_repeat::<4>(); + dbg!(m); + #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 + assert_export_firrtl! { + m => + "/test/check_array_repeat_1.fir": r"FIRRTL version 3.2.0 +circuit check_array_repeat_1: + module check_array_repeat_1: @[module-XXXXXXXXXX.rs 1:1] + input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1] + output o: UInt<8>[4] @[module-XXXXXXXXXX.rs 3:1] + wire _array_literal_expr: UInt<8>[4] + connect _array_literal_expr[0], i + connect _array_literal_expr[1], i + connect _array_literal_expr[2], i + connect _array_literal_expr[3], i + connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1] +", + }; +} + +#[hdl_module(outline_generated)] +pub fn check_skipped_generics(v: U) +where + T: Value>, + U: std::fmt::Display, +{ + dbg!(M); + #[hdl] + let i: T = m.input(); + #[hdl] + let o: Array<[T; N]> = m.output(); + let bytes = v.to_string().as_bytes().to_expr(); + #[hdl] + let o2: Array<[UInt<8>]> = m.output(bytes.ty()); + m.connect( + o, + #[hdl] + [i; N], + ); + m.connect(o2[i], bytes); +} + +#[test] +fn test_skipped_generics() { + let _n = SourceLocation::normalize_files_for_tests(); + let m = check_skipped_generics::, _, 3, 4>("Hello World!\n"); + dbg!(m); + #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 + assert_export_firrtl! { + m => + "/test/check_skipped_generics.fir": r"FIRRTL version 3.2.0 +circuit check_skipped_generics: + module check_skipped_generics: @[module-XXXXXXXXXX.rs 1:1] + input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1] + output o: UInt<8>[3] @[module-XXXXXXXXXX.rs 3:1] + output o2: UInt<8>[13] @[module-XXXXXXXXXX.rs 4:1] + wire _array_literal_expr: UInt<8>[3] + connect _array_literal_expr[0], i + connect _array_literal_expr[1], i + connect _array_literal_expr[2], i + connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 5:1] + wire _array_literal_expr_1: UInt<8>[13] + connect _array_literal_expr_1[0], UInt<8>(0h48) + connect _array_literal_expr_1[1], UInt<8>(0h65) + connect _array_literal_expr_1[2], UInt<8>(0h6C) + connect _array_literal_expr_1[3], UInt<8>(0h6C) + connect _array_literal_expr_1[4], UInt<8>(0h6F) + connect _array_literal_expr_1[5], UInt<8>(0h20) + connect _array_literal_expr_1[6], UInt<8>(0h57) + connect _array_literal_expr_1[7], UInt<8>(0h6F) + connect _array_literal_expr_1[8], UInt<8>(0h72) + connect _array_literal_expr_1[9], UInt<8>(0h6C) + connect _array_literal_expr_1[10], UInt<8>(0h64) + connect _array_literal_expr_1[11], UInt<8>(0h21) + connect _array_literal_expr_1[12], UInt<8>(0hA) + connect o2, _array_literal_expr_1 @[module-XXXXXXXXXX.rs 6:1] ", }; }