From 2c1afd1cd63f8da01300bf866c510f5db313da2b Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 17 Sep 2024 15:39:23 -0700 Subject: [PATCH] const generics on hdl_module work! --- .../src/hdl_type_common.rs | 359 +++++++++++++++--- .../fayalite-proc-macros-impl/src/module.rs | 65 ++-- crates/fayalite/tests/module.rs | 17 +- 3 files changed, 355 insertions(+), 86 deletions(-) diff --git a/crates/fayalite-proc-macros-impl/src/hdl_type_common.rs b/crates/fayalite-proc-macros-impl/src/hdl_type_common.rs index 0ae6480..bc0b074 100644 --- a/crates/fayalite-proc-macros-impl/src/hdl_type_common.rs +++ b/crates/fayalite-proc-macros-impl/src/hdl_type_common.rs @@ -11,8 +11,8 @@ use syn::{ AngleBracketedGenericArguments, Attribute, ConstParam, Expr, ExprIndex, ExprPath, ExprTuple, Field, FieldMutability, Fields, FieldsNamed, FieldsUnnamed, GenericArgument, GenericParam, Generics, Ident, ImplGenerics, Index, ItemStruct, Path, PathArguments, PathSegment, - PredicateType, Token, Turbofish, Type, TypeGenerics, TypeGroup, TypeParam, TypeParen, TypePath, - TypeTuple, Visibility, WhereClause, WherePredicate, + PredicateType, QSelf, Token, Turbofish, Type, TypeGenerics, TypeGroup, TypeParam, TypeParen, + TypePath, TypeTuple, Visibility, WhereClause, WherePredicate, }; crate::options! { @@ -1236,7 +1236,6 @@ impl ParseTypes for ParsedType { parser .errors .error(ident, "constant provided when a type was expected"); - todo!(); return Err(ParseFailed); } }, @@ -1524,6 +1523,7 @@ pub(crate) enum UnparsedGenericParam { ident: Ident, colon_token: Token![:], bounds: ParsedBounds, + mask_type_bounds: ParsedTypeBounds, }, Const { attrs: Vec, @@ -1532,6 +1532,7 @@ pub(crate) enum UnparsedGenericParam { ident: Ident, colon_token: Token![:], ty: ParsedConstGenericType, + bounds: Option, }, } @@ -1595,6 +1596,20 @@ pub(crate) mod known_items { Err(path) } } + #[allow(dead_code)] + pub(crate) fn parse_path_with_arguments(mut path: Path) -> Result<(Self, PathArguments), Path> { + let Some(last_segment) = path.segments.last_mut() else { + return Err(path); + }; + let arguments = std::mem::replace(&mut last_segment.arguments, PathArguments::None); + match Self::parse_path(path) { + Ok(retval) => Ok((retval, arguments)), + Err(mut path) => { + path.segments.last_mut().unwrap().arguments = arguments; + Err(path) + } + } + } } impl Parse for $known_item { @@ -1664,6 +1679,7 @@ pub(crate) mod known_items { impl_known_item!(::fayalite::ty::CanonicalType); impl_known_item!(::fayalite::ty::StaticType); impl_known_item!(::fayalite::ty::Type); + impl_known_item!(::fayalite::ty::Type::MaskType); impl_known_item!(::fayalite::util::ConstUsize); impl_known_item!(::fayalite::__std::primitive::usize); } @@ -1822,6 +1838,19 @@ macro_rules! impl_bounds { Ok(retval) } } + + impl $struct_type { + #[allow(dead_code)] + $vis fn add_implied_bounds(&mut self) { + let orig_bounds = self.clone(); + self.extend( + self.clone() + .into_iter() + .map($enum_type::implied_bounds), + ); + self.extend([orig_bounds]); // keep spans of explicitly provided bounds + } + } }; } @@ -1849,6 +1878,64 @@ impl_bounds! { } } +impl From for ParsedBound { + fn from(value: ParsedTypeBound) -> Self { + match value { + ParsedTypeBound::BundleType(v) => ParsedBound::BundleType(v), + ParsedTypeBound::EnumType(v) => ParsedBound::EnumType(v), + ParsedTypeBound::IntType(v) => ParsedBound::IntType(v), + ParsedTypeBound::StaticType(v) => ParsedBound::StaticType(v), + ParsedTypeBound::Type(v) => ParsedBound::Type(v), + } + } +} + +impl From for ParsedBounds { + fn from(value: ParsedTypeBounds) -> Self { + let ParsedTypeBounds { + BundleType, + EnumType, + IntType, + StaticType, + Type, + } = value; + Self { + BundleType, + EnumType, + IntType, + KnownSize: None, + Size: None, + StaticType, + Type, + } + } +} + +impl ParsedTypeBound { + fn implied_bounds(self) -> ParsedTypeBounds { + let span = self.span(); + match self { + Self::BundleType(v) => ParsedTypeBounds::from_iter([ + ParsedTypeBound::from(v), + ParsedTypeBound::Type(known_items::Type(span)), + ]), + Self::EnumType(v) => ParsedTypeBounds::from_iter([ + ParsedTypeBound::from(v), + ParsedTypeBound::Type(known_items::Type(span)), + ]), + Self::IntType(v) => ParsedTypeBounds::from_iter([ + ParsedTypeBound::from(v), + ParsedTypeBound::Type(known_items::Type(span)), + ]), + Self::StaticType(v) => ParsedTypeBounds::from_iter([ + ParsedTypeBound::from(v), + ParsedTypeBound::Type(known_items::Type(span)), + ]), + Self::Type(v) => ParsedTypeBounds::from_iter([ParsedTypeBound::from(v)]), + } + } +} + impl_bounds! { #[struct = ParsedSizeTypeBounds] pub(crate) enum ParsedSizeTypeBound { @@ -1857,6 +1944,43 @@ impl_bounds! { } } +impl From for ParsedBound { + fn from(value: ParsedSizeTypeBound) -> Self { + match value { + ParsedSizeTypeBound::KnownSize(v) => ParsedBound::KnownSize(v), + ParsedSizeTypeBound::Size(v) => ParsedBound::Size(v), + } + } +} + +impl From for ParsedBounds { + fn from(value: ParsedSizeTypeBounds) -> Self { + let ParsedSizeTypeBounds { KnownSize, Size } = value; + Self { + BundleType: None, + EnumType: None, + IntType: None, + KnownSize, + Size, + StaticType: None, + Type: None, + } + } +} + +impl ParsedSizeTypeBound { + fn implied_bounds(self) -> ParsedSizeTypeBounds { + let span = self.span(); + match self { + Self::KnownSize(v) => ParsedSizeTypeBounds::from_iter([ + ParsedSizeTypeBound::from(v), + ParsedSizeTypeBound::Size(known_items::Size(span)), + ]), + Self::Size(v) => ParsedSizeTypeBounds::from_iter([ParsedSizeTypeBound::from(v)]), + } + } +} + #[derive(Clone, Debug)] pub(crate) enum ParsedBoundsCategory { Type(ParsedTypeBounds), @@ -1918,42 +2042,13 @@ impl ParsedBound { } } fn implied_bounds(self) -> ParsedBounds { - let span = self.span(); - match self { - Self::BundleType(v) => ParsedBounds::from_iter([ - ParsedBound::from(v), - ParsedBound::Type(known_items::Type(span)), - ]), - Self::EnumType(v) => ParsedBounds::from_iter([ - ParsedBound::from(v), - ParsedBound::Type(known_items::Type(span)), - ]), - Self::IntType(v) => ParsedBounds::from_iter([ - ParsedBound::from(v), - ParsedBound::Type(known_items::Type(span)), - ]), - Self::KnownSize(v) => ParsedBounds::from_iter([ - ParsedBound::from(v), - ParsedBound::Size(known_items::Size(span)), - ]), - Self::Size(v) => ParsedBounds::from_iter([ParsedBound::from(v)]), - Self::StaticType(v) => ParsedBounds::from_iter([ - ParsedBound::from(v), - ParsedBound::Type(known_items::Type(span)), - ]), - Self::Type(v) => ParsedBounds::from_iter([ParsedBound::from(v)]), + match self.categorize() { + ParsedBoundCategory::Type(v) => v.implied_bounds().into(), + ParsedBoundCategory::SizeType(v) => v.implied_bounds().into(), } } } -impl ParsedBounds { - fn add_implied_bounds(&mut self) { - let orig_bounds = self.clone(); - self.extend(self.clone().into_iter().map(ParsedBound::implied_bounds)); - self.extend([orig_bounds]); // keep spans of explicitly provided bounds - } -} - #[derive(Debug, Clone)] pub(crate) struct ParsedTypeParam { pub(crate) attrs: Vec, @@ -2116,6 +2211,7 @@ impl ParsedGenerics { } ParsedGenericParam::Const(ParsedConstParam { ident, bounds, .. }) => { bounds + .bounds .KnownSize .get_or_insert_with(|| known_items::KnownSize(ident.span())); } @@ -2527,6 +2623,7 @@ impl ParsedGenerics { ident: ident.clone(), colon_token, bounds: ParsedBounds::default(), + mask_type_bounds: ParsedTypeBounds::default(), }, LateParsedParam { default: default @@ -2563,6 +2660,7 @@ impl ParsedGenerics { ident: ident.clone(), colon_token: *colon_token, ty: ParsedConstGenericType::Usize(known_items::usize(ident.span())), + bounds: None, }, LateParsedParam { default: None, @@ -2591,20 +2689,139 @@ impl ParsedGenerics { lifetimes: None, bounded_ty: Type::Path(TypePath { - qself: None, + qself, path: bounded_ty, }), - colon_token: _, + colon_token, bounds: unparsed_bounds, }) = predicate else { errors.error(predicate, "unsupported where predicate kind"); continue; }; - ParsedTypeNamed { - path: todo!(), - args: todo!(), - }; + if let Some(qself) = &qself { + if let QSelf { + lt_token: _, + ty: base_ty, + position: 3, + as_token: Some(_), + gt_token: _, + } = qself + { + if bounded_ty.segments.len() == 4 && unparsed_bounds.len() == 1 { + if let ( + Ok(_), + Type::Path(TypePath { + qself: None, + path: base_ty, + }), + ) = ( + known_items::MaskType::parse_path(bounded_ty.clone()), + &**base_ty, + ) { + let Some(&index) = base_ty + .get_ident() + .and_then(|base_ty| param_name_to_index_map.get(base_ty)) + else { + errors.error( + TypePath { + qself: Some(qself.clone()), + path: bounded_ty, + }, + "unsupported where predicate kind", + ); + continue; + }; + let parsed_bounds = match &mut unparsed_params[index] { + UnparsedGenericParam::Type { + mask_type_bounds, .. + } => mask_type_bounds, + UnparsedGenericParam::Const { ident, .. } => { + errors.error( + bounded_ty, + format_args!( + "expected type, found const parameter `{ident}`" + ), + ); + continue; + } + }; + parsed_bounds.extend(errors.ok(syn::parse2::( + unparsed_bounds.to_token_stream(), + ))); + continue; + } + } + } + errors.error( + TypePath { + qself: Some(qself.clone()), + path: bounded_ty, + }, + "unsupported where predicate kind", + ); + continue; + } + if let Ok(( + const_usize, + PathArguments::AngleBracketed(AngleBracketedGenericArguments { + colon2_token: _, + lt_token, + args, + gt_token, + }), + )) = known_items::ConstUsize::parse_path_with_arguments(bounded_ty.clone()) + { + if args.len() != 1 { + errors.error(const_usize, "ConstUsize must have one argument"); + continue; + } + let GenericArgument::Type(Type::Path(TypePath { + qself: None, + path: arg, + })) = &args[0] + else { + errors.error( + const_usize, + "the only supported ConstUsize argument is a const generic parameter", + ); + continue; + }; + let arg = arg.get_ident(); + let Some((arg, &index)) = + arg.and_then(|arg| Some((arg, param_name_to_index_map.get(arg)?))) + else { + errors.error( + const_usize, + "the only supported ConstUsize argument is a const generic parameter", + ); + continue; + }; + let parsed_bounds = match &mut unparsed_params[index] { + UnparsedGenericParam::Const { bounds, .. } => bounds, + UnparsedGenericParam::Type { ident, .. } => { + errors.error( + bounded_ty, + format_args!("expected const generic parameter, found type `{ident}`"), + ); + continue; + } + }; + parsed_bounds + .get_or_insert_with(|| ParsedConstParamWhereBounds { + const_usize, + lt_token, + ident: arg.clone(), + gt_token, + colon_token, + bounds: ParsedSizeTypeBounds::default(), + }) + .bounds + .extend(errors.ok(syn::parse2::( + unparsed_bounds.to_token_stream(), + ))); + continue; + } let Some(&index) = bounded_ty .get_ident() .and_then(|bounded_ty| param_name_to_index_map.get(bounded_ty)) @@ -2639,7 +2856,26 @@ impl ParsedGenerics { ident, colon_token, mut bounds, + mask_type_bounds, } => { + for bound in mask_type_bounds { + bounds + .Type + .get_or_insert_with(|| known_items::Type(bound.span())); + match bound { + ParsedTypeBound::BundleType(_) + | ParsedTypeBound::EnumType(_) + | ParsedTypeBound::IntType(_) => { + errors.error(bound, "bound on mask type not implemented"); + } + ParsedTypeBound::StaticType(bound) => { + if bounds.StaticType.is_none() { + errors.error(bound, "StaticType bound on mask type without corresponding StaticType bound on original type is not implemented"); + } + }, + ParsedTypeBound::Type(_) => {} + } + } bounds.add_implied_bounds(); match bounds.categorize(&mut errors, ident.span()) { ParsedBoundsCategory::Type(bounds) => { @@ -2671,18 +2907,35 @@ impl ParsedGenerics { ident, colon_token, ty, - } => ParsedGenericParam::Const(ParsedConstParam { - bounds: ParsedSizeTypeBounds { - KnownSize: None, - Size: Some(known_items::Size(ident.span())), - }, - attrs, - options, - const_token, - ident, - colon_token, - ty, - }), + bounds, + } => { + let span = ident.span(); + let mut bounds = bounds.unwrap_or_else(|| ParsedConstParamWhereBounds { + const_usize: known_items::ConstUsize(span), + lt_token: Token![<](span), + ident: ident.clone(), + gt_token: Token![>](span), + colon_token: Token![:](span), + bounds: ParsedSizeTypeBounds { + KnownSize: None, + Size: Some(known_items::Size(span)), + }, + }); + bounds + .bounds + .Size + .get_or_insert_with(|| known_items::Size(span)); + bounds.bounds.add_implied_bounds(); + ParsedGenericParam::Const(ParsedConstParam { + bounds, + attrs, + options, + const_token, + ident, + colon_token, + ty, + }) + } }, )); let mut retval = Self { diff --git a/crates/fayalite-proc-macros-impl/src/module.rs b/crates/fayalite-proc-macros-impl/src/module.rs index 7d816d3..0945abb 100644 --- a/crates/fayalite-proc-macros-impl/src/module.rs +++ b/crates/fayalite-proc-macros-impl/src/module.rs @@ -66,8 +66,8 @@ pub(crate) struct ModuleFn { vis: Visibility, sig: Signature, block: Box, - io: Vec, struct_generics: ParsedGenerics, + the_struct: TokenStream, } #[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] @@ -224,6 +224,41 @@ impl Parse for ModuleFn { errors.finish()?; let struct_generics = struct_generics.unwrap(); let (block, io) = body_results.unwrap(); + let (_struct_impl_generics, _struct_type_generics, struct_where_clause) = + struct_generics.split_for_impl(); + let struct_where_clause: Option = parse_quote! { #struct_where_clause }; + if let Some(struct_where_clause) = &struct_where_clause { + sig.generics + .where_clause + .get_or_insert_with(|| WhereClause { + where_token: struct_where_clause.where_token, + predicates: Default::default(), + }) + .predicates + .extend(struct_where_clause.predicates.clone()); + } + let fn_name = &sig.ident; + let io_flips = io + .iter() + .map(|io| match io.kind.kind { + ModuleIOKind::Input((input,)) => quote_spanned! {input.span=> + #[hdl(flip)] + }, + ModuleIOKind::Output(_) => quote! {}, + }) + .collect::>(); + let io_types = io.iter().map(|io| &io.kind.ty).collect::>(); + let io_names = io.iter().map(|io| &io.name).collect::>(); + let the_struct: ItemStruct = parse_quote! { + #[allow(non_camel_case_types)] + #[hdl(no_runtime_generics, no_static)] + #vis struct #fn_name #struct_generics #struct_where_clause { + #( + #io_flips + #vis #io_names: #io_types,)* + } + }; + let the_struct = crate::hdl_bundle::hdl_bundle(the_struct)?; Ok(Self { attrs, config_options, @@ -231,8 +266,8 @@ impl Parse for ModuleFn { vis, sig, block, - io, struct_generics, + the_struct, }) } } @@ -246,8 +281,8 @@ impl ModuleFn { vis, sig, block, - io, struct_generics, + the_struct, } = self; let ConfigOptions { outline_generated: _, @@ -279,7 +314,7 @@ impl ModuleFn { ModuleKind::Normal => quote! { ::fayalite::module::ModuleKind::Normal }, }; let fn_name = &outer_sig.ident; - let (_struct_impl_generics, struct_type_generics, struct_where_clause) = + let (_struct_impl_generics, struct_type_generics, _struct_where_clause) = struct_generics.split_for_impl(); let struct_ty = quote! {#fn_name #struct_type_generics}; body_sig.ident = parse_quote! {__body}; @@ -294,17 +329,6 @@ impl ModuleFn { }; outer_sig.output = parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#struct_ty>>}; - let io_flips = io - .iter() - .map(|io| match io.kind.kind { - ModuleIOKind::Input((input,)) => quote_spanned! {input.span=> - #[hdl(flip)] - }, - ModuleIOKind::Output(_) => quote! {}, - }) - .collect::>(); - let io_types = io.iter().map(|io| &io.kind.ty).collect::>(); - let io_names = io.iter().map(|io| &io.name).collect::>(); let fn_name_str = fn_name.to_string(); let (_, body_type_generics, _) = body_fn.sig.generics.split_for_impl(); let body_turbofish_type_generics = body_type_generics.as_turbofish(); @@ -316,15 +340,6 @@ impl ModuleFn { |m| __body #body_turbofish_type_generics(m, #(#param_names,)*), ) }}; - let the_struct: ItemStruct = parse_quote! { - #[allow(non_camel_case_types)] - #[hdl(no_runtime_generics, no_static)] - #vis struct #fn_name #struct_generics #struct_where_clause { - #( - #io_flips - #vis #io_names: #io_types,)* - } - }; let outer_fn = ItemFn { attrs, vis, @@ -332,7 +347,7 @@ impl ModuleFn { block, }; let mut retval = outer_fn.into_token_stream(); - retval.extend(crate::hdl_bundle::hdl_bundle(the_struct).unwrap()); + retval.extend(the_struct); retval } } diff --git a/crates/fayalite/tests/module.rs b/crates/fayalite/tests/module.rs index 8df3771..f039750 100644 --- a/crates/fayalite/tests/module.rs +++ b/crates/fayalite/tests/module.rs @@ -6,6 +6,7 @@ use fayalite::{ intern::Intern, module::transform::simplify_enums::{simplify_enums, SimplifyEnumsKind}, prelude::*, + ty::StaticType, }; use serde_json::json; @@ -133,9 +134,11 @@ circuit my_module: }; } -#[cfg(todo)] #[hdl_module(outline_generated)] -pub fn check_array_repeat() { +pub fn check_array_repeat() +where + ConstUsize: KnownSize, +{ #[hdl] let i: UInt<8> = m.input(); #[hdl] @@ -147,7 +150,6 @@ pub fn check_array_repeat() { ); } -#[cfg(todo)] #[test] fn test_array_repeat() { let _n = SourceLocation::normalize_files_for_tests(); @@ -188,21 +190,21 @@ circuit check_array_repeat_1: }; } -#[cfg(todo)] #[hdl_module(outline_generated)] pub fn check_skipped_generics(v: U) where - T: StaticValue, + T: StaticType, + ConstUsize: KnownSize, U: std::fmt::Display, { dbg!(M); #[hdl] let i: T = m.input(); #[hdl] - let o: Array<[T; N]> = m.output(); + let o: Array = m.output(); let bytes = v.to_string().as_bytes().to_expr(); #[hdl] - let o2: Array<[UInt<8>]> = m.output(bytes.ty()); + let o2: Array> = m.output(Expr::ty(bytes)); connect( o, #[hdl] @@ -211,7 +213,6 @@ where connect(o2, bytes); } -#[cfg(todo)] #[test] fn test_skipped_generics() { let _n = SourceLocation::normalize_files_for_tests();