// SPDX-License-Identifier: LGPL-3.0-or-later // See Notices.txt for copyright information use crate::{ hdl_type_common::{ParsedGenerics, SplitForImpl}, module::transform_body::{HdlLet, HdlLetKindIO}, options, Errors, HdlAttr, PairsIterExt, }; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; use std::collections::HashSet; use syn::{ parse::{Parse, ParseStream}, parse_quote, visit::{visit_pat, Visit}, Attribute, Block, ConstParam, Error, FnArg, GenericParam, Generics, Ident, ItemFn, ItemStruct, LifetimeParam, ReturnType, Signature, TypeParam, Visibility, WhereClause, WherePredicate, }; mod transform_body; options! { #[options = ConfigOptions] #[no_ident_fragment] pub(crate) enum ConfigOption { OutlineGenerated(outline_generated), Extern(extern_), } } options! { pub(crate) enum ModuleIOKind { Input(input), Output(output), } } pub(crate) fn check_name_conflicts_with_module_builder(name: &Ident) -> syn::Result<()> { if name == "m" { Err(Error::new_spanned( name, "name conflicts with implicit `m: &mut ModuleBuilder<_>`", )) } else { Ok(()) } } pub(crate) struct CheckNameConflictsWithModuleBuilderVisitor<'a> { pub(crate) errors: &'a mut Errors, } impl Visit<'_> for CheckNameConflictsWithModuleBuilderVisitor<'_> { // TODO: change this to only check for identifiers defining new variables fn visit_ident(&mut self, node: &Ident) { self.errors .push_result(check_name_conflicts_with_module_builder(node)); } } pub(crate) type ModuleIO = HdlLet; pub(crate) struct ModuleFn { attrs: Vec, config_options: HdlAttr, module_kind: ModuleKind, vis: Visibility, sig: Signature, block: Box, io: Vec, struct_generics: ParsedGenerics, } #[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] pub(crate) enum ModuleKind { Extern, Normal, } struct ContainsSkippedIdent<'a> { skipped_idents: &'a HashSet, contains_skipped_ident: bool, } impl Visit<'_> for ContainsSkippedIdent<'_> { fn visit_ident(&mut self, ident: &'_ Ident) { if self.skipped_idents.contains(ident) { self.contains_skipped_ident = true; } } } impl Parse for ModuleFn { fn parse(input: ParseStream) -> syn::Result { let ItemFn { mut attrs, vis, mut sig, block, } = input.parse()?; let Signature { ref constness, ref asyncness, ref unsafety, ref abi, fn_token: _, ident: _, ref mut generics, paren_token: _, ref mut inputs, ref variadic, ref output, } = sig; let mut errors = Errors::new(); let config_options = errors .unwrap_or_default(HdlAttr::parse_and_take_attr(&mut attrs)) .unwrap_or_default(); let ConfigOptions { outline_generated: _, extern_, } = config_options.body; let module_kind = match extern_ { Some(_) => ModuleKind::Extern, None => ModuleKind::Normal, }; for fn_arg in inputs { match fn_arg { FnArg::Receiver(_) => { errors.push(syn::Error::new_spanned(fn_arg, "self not allowed here")); } FnArg::Typed(fn_arg) => { visit_pat( &mut CheckNameConflictsWithModuleBuilderVisitor { errors: &mut errors, }, &fn_arg.pat, ); } } } if let Some(constness) = constness { errors.push(syn::Error::new_spanned(constness, "const not allowed here")); } if let Some(asyncness) = asyncness { errors.push(syn::Error::new_spanned(asyncness, "async not allowed here")); } if let Some(unsafety) = unsafety { errors.push(syn::Error::new_spanned(unsafety, "unsafe not allowed here")); } if let Some(abi) = abi { errors.push(syn::Error::new_spanned(abi, "extern not allowed here")); } let mut skipped_idents = HashSet::new(); let struct_generic_params = generics .params .pairs_mut() .filter_map_pair_value_mut(|v| match v { GenericParam::Lifetime(LifetimeParam { attrs, .. }) => { errors .unwrap_or_default(HdlAttr::::parse_and_take_attr(attrs)); None } GenericParam::Type(TypeParam { attrs, ident, .. }) | GenericParam::Const(ConstParam { attrs, ident, .. }) => { if errors .unwrap_or_default(HdlAttr::::parse_and_take_attr(attrs)) .is_some() { skipped_idents.insert(ident.clone()); None } else { Some(v.clone()) } } }) .collect(); let struct_where_clause = generics .where_clause .as_mut() .map(|where_clause| WhereClause { where_token: where_clause.where_token, predicates: where_clause .predicates .pairs_mut() .filter_map_pair_value_mut(|v| match v { WherePredicate::Lifetime(_) => None, _ => { let mut contains_skipped_ident = ContainsSkippedIdent { skipped_idents: &skipped_idents, contains_skipped_ident: false, }; contains_skipped_ident.visit_where_predicate(v); if contains_skipped_ident.contains_skipped_ident { None } else { Some(v.clone()) } } }) .collect(), }); let struct_generics = Generics { lt_token: generics.lt_token, params: struct_generic_params, gt_token: generics.gt_token, where_clause: struct_where_clause, }; if let Some(variadic) = variadic { errors.push(syn::Error::new_spanned(variadic, "... not allowed here")); } if !matches!(output, ReturnType::Default) { errors.push(syn::Error::new_spanned( output, "return type not allowed here", )); } let struct_generics = errors.ok(ParsedGenerics::parse(&mut { struct_generics })); let body_results = struct_generics.as_ref().and_then(|struct_generics| { errors.ok(transform_body::transform_body( module_kind, block, struct_generics, )) }); errors.finish()?; let struct_generics = struct_generics.unwrap(); let (block, io) = body_results.unwrap(); Ok(Self { attrs, config_options, module_kind, vis, sig, block, io, struct_generics, }) } } impl ModuleFn { pub(crate) fn generate(self) -> TokenStream { let Self { attrs, config_options, module_kind, vis, sig, block, io, struct_generics, } = self; let ConfigOptions { outline_generated: _, extern_: _, } = config_options.body; let mut outer_sig = sig.clone(); let mut body_sig = sig; let param_names = Vec::from_iter(outer_sig.inputs.iter_mut().enumerate().map(|(index, arg)| { let FnArg::Typed(arg) = arg else { unreachable!("already checked"); }; let name = if let syn::Pat::Ident(pat) = &*arg.pat { pat.ident.clone() } else { format_ident!("__param{}", index) }; *arg.pat = syn::Pat::Ident(syn::PatIdent { attrs: vec![], by_ref: None, mutability: None, ident: name.clone(), subpat: None, }); name })); let module_kind_value = match module_kind { ModuleKind::Extern => quote! { ::fayalite::module::ModuleKind::Extern }, ModuleKind::Normal => quote! { ::fayalite::module::ModuleKind::Normal }, }; let fn_name = &outer_sig.ident; let (_struct_impl_generics, struct_type_generics, struct_where_clause) = struct_generics.split_for_impl(); let struct_ty = quote! {#fn_name #struct_type_generics}; body_sig.ident = parse_quote! {__body}; body_sig .inputs .insert(0, parse_quote! { m: &::fayalite::module::ModuleBuilder }); let body_fn = ItemFn { attrs: vec![], vis: Visibility::Inherited, sig: body_sig, block, }; outer_sig.output = parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#struct_ty>>}; let io_flips = io .iter() .map(|io| match io.kind.kind { ModuleIOKind::Input((input,)) => quote_spanned! {input.span=> #[hdl(flip)] }, ModuleIOKind::Output(_) => quote! {}, }) .collect::>(); let io_types = io.iter().map(|io| &io.kind.ty).collect::>(); let io_names = io.iter().map(|io| &io.name).collect::>(); let fn_name_str = fn_name.to_string(); let (_, body_type_generics, _) = body_fn.sig.generics.split_for_impl(); let body_turbofish_type_generics = body_type_generics.as_turbofish(); let block = parse_quote! {{ #body_fn ::fayalite::module::ModuleBuilder::run( #fn_name_str, #module_kind_value, |m| __body #body_turbofish_type_generics(m, #(#param_names,)*), ) }}; let the_struct: ItemStruct = parse_quote! { #[allow(non_camel_case_types)] #[hdl(no_runtime_generics, no_static)] #vis struct #fn_name #struct_generics #struct_where_clause { #( #io_flips #vis #io_names: #io_types,)* } }; let outer_fn = ItemFn { attrs, vis, sig: outer_sig, block, }; let mut retval = outer_fn.into_token_stream(); retval.extend(crate::hdl_bundle::hdl_bundle(the_struct).unwrap()); retval } }