// SPDX-License-Identifier: LGPL-3.0-or-later
// See Notices.txt for copyright information
//! Parsing and code generation for functions annotated with `#[hdl]` or `#[hdl_module]`.
//!
//! [`ModuleFn::parse_from_fn`] validates the annotated function and rewrites its body via
//! `transform_body`; [`ModuleFn::generate`] turns the result back into tokens (for
//! `#[hdl_module]`, a module-constructor function plus a generated I/O bundle struct).
use crate::{
    hdl_type_common::{ParsedGenerics, SplitForImpl},
    kw,
    module::transform_body::{HdlLet, HdlLetKindIO},
    options, Errors, HdlAttr, PairsIterExt,
};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::HashSet;
use syn::{
    parse_quote,
    visit::{visit_pat, Visit},
    Attribute, Block, ConstParam, Error, FnArg, GenericParam, Generics, Ident, ItemFn, ItemStruct,
    LifetimeParam, ReturnType, Signature, TypeParam, Visibility, WhereClause, WherePredicate,
};

mod transform_body;

// Options accepted inside the function attribute, e.g. `#[hdl_module(extern)]`.
// The crate's `options!` macro also produces the `ConfigOptions` struct used below
// (fields `outline_generated` and `extern_`; `extern_` is matched as an `Option`
// in `ModuleFn::parse_from_fn`).
options! {
    #[options = ConfigOptions]
    #[no_ident_fragment]
    pub(crate) enum ConfigOption {
        OutlineGenerated(outline_generated),
        Extern(extern_),
    }
}

// Direction keyword of a module I/O declaration; matched on `io.kind.kind` in
// `ModuleFn::parse_from_fn` to decide which struct fields get `#[hdl(flip)]`.
options! {
    pub(crate) enum ModuleIOKind {
        Input(input),
        Output(output),
    }
}

/// Rejects the identifier `m`.
///
/// `ModuleFn::generate` inserts a leading parameter named `m` into the generated `__body`
/// function and calls it from a closure `|m| ...`. A user binding with the same name would
/// clash with it.
///
/// # Errors
/// Returns an error spanned on `name` when `name == "m"`.
// NOTE(review): the message says `&mut ModuleBuilder<_>`, but `ModuleFn::generate` injects
// `m: &::fayalite::module::ModuleBuilder` (a shared reference) — consider updating the text.
pub(crate) fn check_name_conflicts_with_module_builder(name: &Ident) -> syn::Result<()> {
    if name == "m" {
        Err(Error::new_spanned(
            name,
            "name conflicts with implicit `m: &mut ModuleBuilder<_>`",
        ))
    } else {
        Ok(())
    }
}

/// Syntax visitor that runs [`check_name_conflicts_with_module_builder`] on every identifier
/// it reaches and pushes any failure into `errors`.
pub(crate) struct CheckNameConflictsWithModuleBuilderVisitor<'a> {
    /// Sink for the conflict errors found while visiting.
    pub(crate) errors: &'a mut Errors,
}

impl Visit<'_> for CheckNameConflictsWithModuleBuilderVisitor<'_> {
    // TODO: change this to only check for identifiers defining new variables
    // (currently every identifier the visitor reaches is checked, not just bindings).
    fn visit_ident(&mut self, node: &Ident) {
        self.errors
            .push_result(check_name_conflicts_with_module_builder(node));
    }
}

/// An I/O declaration of a module body.
///
/// Presumably the element type of the `io` list returned by `transform_body`. Each entry
/// exposes `name`, `kind.kind` (a `ModuleIOKind`) and `kind.ty`, as used in
/// `ModuleFn::parse_from_fn`.
pub(crate) type ModuleIO = HdlLet;

/// Parsed form of a function annotated with `#[hdl_module]`.
struct ModuleFnModule {
    /// The function's attributes with the `#[hdl_module]` attribute removed.
    attrs: Vec,
    /// The parsed `#[hdl_module(...)]` attribute.
    config_options: HdlAttr,
    /// `Extern` if `extern` was given in the attribute, otherwise `Normal`.
    module_kind: ModuleKind,
    vis: Visibility,
    /// Original signature; the I/O struct's where-clause predicates are appended to it.
    sig: Signature,
    /// Function body as rewritten by `transform_body`.
    block: Box,
    /// Generics of the generated I/O struct.
    ///
    /// These come from the function's generics with three things dropped:
    /// - all lifetime params;
    /// - type/const params that carried the generic-param attribute parsed in `parse_from_fn`;
    /// - where predicates that mention a dropped param.
    struct_generics: ParsedGenerics,
    /// Tokens of the generated I/O bundle struct (output of `crate::hdl_bundle::hdl_bundle`).
    the_struct: TokenStream,
}

/// Result of parsing an annotated function.
enum ModuleFnImpl {
    /// `#[hdl_module]` function.
    Module(ModuleFnModule),
    /// Plain `#[hdl]` function: body rewritten by `transform_body`, no I/O struct.
    Fn {
        attrs: Vec,
        config_options: HdlAttr,
        vis: Visibility,
        sig: Signature,
        block: Box,
    },
}

// Which attribute introduced the function: `#[hdl]` or `#[hdl_module]`.
options! {
    pub(crate) enum HdlOrHdlModule {
        Hdl(hdl),
        HdlModule(hdl_module),
    }
}

/// A function annotated with `#[hdl]` or `#[hdl_module]`.
///
/// Build one with [`ModuleFn::parse_from_fn`] and expand it with [`ModuleFn::generate`].
pub(crate) struct ModuleFn(ModuleFnImpl);

/// Kind of module being generated: `Extern` for `#[hdl_module(extern)]`, `Normal` otherwise.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub(crate) enum ModuleKind {
    Extern,
    Normal,
}

/// Visitor that sets `contains_skipped_ident` once it sees any identifier that is in
/// `skipped_idents`.
struct ContainsSkippedIdent<'a> {
    skipped_idents: &'a HashSet,
    contains_skipped_ident: bool,
}

impl Visit<'_> for ContainsSkippedIdent<'_> {
    fn visit_ident(&mut self, ident: &'_ Ident) {
        if self.skipped_idents.contains(ident) {
            self.contains_skipped_ident = true;
        }
    }
}

impl ModuleFn {
    /// Returns a clone of the options body parsed from the function's attribute, for either
    /// `#[hdl]` or `#[hdl_module]`.
    pub(crate) fn config_options(&self) -> ConfigOptions {
        let (ModuleFnImpl::Module(ModuleFnModule {
            config_options: HdlAttr { body, .. },
            ..
        })
        | ModuleFnImpl::Fn {
            config_options: HdlAttr { body, .. },
            ..
        }) = &self.0;
        body.clone()
    }

    /// Parses `item`, which must carry a `#[hdl]` or `#[hdl_module]` attribute.
    ///
    /// For `#[hdl_module]`, this also:
    /// - rejects a `self` receiver, `const`, `async`, `unsafe`, an `extern` ABI, variadics,
    ///   a return type, and any parameter identifier `m`;
    /// - computes the generics of the I/O struct;
    /// - rewrites the body with `transform_body`;
    /// - builds the I/O bundle struct, named after the function, from the collected I/O
    ///   declarations.
    ///
    /// For `#[hdl]`, `extern` in the attribute is an error. The body is still rewritten by
    /// `transform_body` (with `module_kind == None` and empty generics).
    ///
    /// # Errors
    /// Returns the errors accumulated in `Errors`, or the error from
    /// `crate::hdl_bundle::hdl_bundle`.
    pub(crate) fn parse_from_fn(item: ItemFn) -> syn::Result {
        let ItemFn {
            mut attrs,
            vis,
            mut sig,
            block,
        } = item;
        // Borrow the signature parts that are validated or edited below; `sig` itself is
        // used again afterwards.
        let Signature {
            ref constness,
            ref asyncness,
            ref unsafety,
            ref abi,
            fn_token: _,
            ident: _,
            ref mut generics,
            paren_token: _,
            ref mut inputs,
            ref variadic,
            ref output,
        } = sig;
        let mut errors = Errors::new();
        // Take the `#[hdl]`/`#[hdl_module]` attribute off the function (a parse failure is
        // presumably recorded in `errors` and treated as absent).
        let Some(mut config_options) = errors.unwrap_or_default(
            HdlAttr::::parse_and_take_attr(&mut attrs),
        ) else {
            errors.error(sig.ident, "missing #[hdl] or #[hdl_module] attribute");
            // An error was just recorded, so `finish` is expected to return `Err` here.
            errors.finish()?;
            unreachable!();
        };
        let ConfigOptions {
            outline_generated: _,
            extern_,
        } = config_options.body;
        // `None` means a plain `#[hdl]` function; `Some(kind)` means an `#[hdl_module]`.
        let module_kind = match (config_options.kw, extern_) {
            (HdlOrHdlModule::Hdl(_), None) => None,
            (HdlOrHdlModule::Hdl(_), Some(extern2)) => {
                // Clear the invalid option so it is not carried into the `Fn` result.
                config_options.body.extern_ = None;
                errors.error(
                    extern2.0,
                    "extern can only be used as #[hdl_module(extern)]",
                );
                None
            }
            (HdlOrHdlModule::HdlModule(_), None) => Some(ModuleKind::Normal),
            (HdlOrHdlModule::HdlModule(_), Some(_)) => Some(ModuleKind::Extern),
        };
        if let HdlOrHdlModule::HdlModule(_) = config_options.kw {
            // Module functions are wrapped by `generate`; reject signature features that
            // the wrapper does not support.
            for fn_arg in inputs {
                match fn_arg {
                    FnArg::Receiver(_) => {
                        errors.push(syn::Error::new_spanned(fn_arg, "self not allowed here"));
                    }
                    FnArg::Typed(fn_arg) => {
                        // Parameters must not clash with the injected `m` parameter.
                        visit_pat(
                            &mut CheckNameConflictsWithModuleBuilderVisitor {
                                errors: &mut errors,
                            },
                            &fn_arg.pat,
                        );
                    }
                }
            }
            if let Some(constness) = constness {
                errors.push(syn::Error::new_spanned(constness, "const not allowed here"));
            }
            if let Some(asyncness) = asyncness {
                errors.push(syn::Error::new_spanned(asyncness, "async not allowed here"));
            }
            if let Some(unsafety) = unsafety {
                errors.push(syn::Error::new_spanned(unsafety, "unsafe not allowed here"));
            }
            if let Some(abi) = abi {
                errors.push(syn::Error::new_spanned(abi, "extern not allowed here"));
            }
        }
        // Identifiers of type/const params excluded from the I/O struct's generics.
        let mut skipped_idents = HashSet::new();
        // Build the I/O struct's generic params. The generic-param attribute is stripped from
        // every param in `sig` as a side effect. Lifetimes are always dropped. Type/const
        // params that carried the attribute are dropped and remembered in `skipped_idents`.
        let struct_generic_params = generics
            .params
            .pairs_mut()
            .filter_map_pair_value_mut(|v| match v {
                GenericParam::Lifetime(LifetimeParam { attrs, .. }) => {
                    errors.unwrap_or_default(
                        HdlAttr::::parse_and_take_attr(attrs),
                    );
                    None
                }
                GenericParam::Type(TypeParam { attrs, ident, .. })
                | GenericParam::Const(ConstParam { attrs, ident, .. }) => {
                    if errors
                        .unwrap_or_default(
                            HdlAttr::::parse_and_take_attr(attrs),
                        )
                        .is_some()
                    {
                        skipped_idents.insert(ident.clone());
                        None
                    } else {
                        Some(v.clone())
                    }
                }
            })
            .collect();
        // Build the I/O struct's where clause (only for `#[hdl_module]`). Lifetime predicates
        // are dropped, and so is any predicate that mentions a skipped param.
        let struct_where_clause = generics
            .where_clause
            .as_mut()
            .filter(|_| matches!(config_options.kw, HdlOrHdlModule::HdlModule(_)))
            .map(|where_clause| WhereClause {
                where_token: where_clause.where_token,
                predicates: where_clause
                    .predicates
                    .pairs_mut()
                    .filter_map_pair_value_mut(|v| match v {
                        WherePredicate::Lifetime(_) => None,
                        _ => {
                            let mut contains_skipped_ident = ContainsSkippedIdent {
                                skipped_idents: &skipped_idents,
                                contains_skipped_ident: false,
                            };
                            contains_skipped_ident.visit_where_predicate(v);
                            if contains_skipped_ident.contains_skipped_ident {
                                None
                            } else {
                                Some(v.clone())
                            }
                        }
                    })
                    .collect(),
            });
        // `None` here means `ParsedGenerics::parse` failed (its error went into `errors`).
        let struct_generics = if let HdlOrHdlModule::HdlModule(_) = config_options.kw {
            let mut struct_generics = Generics {
                lt_token: generics.lt_token,
                params: struct_generic_params,
                gt_token: generics.gt_token,
                where_clause: struct_where_clause,
            };
            if let Some(variadic) = variadic {
                errors.push(syn::Error::new_spanned(variadic, "... not allowed here"));
            }
            if !matches!(output, ReturnType::Default) {
                errors.push(syn::Error::new_spanned(
                    output,
                    "return type not allowed here",
                ));
            }
            errors.ok(ParsedGenerics::parse(&mut struct_generics))
        } else {
            Some(ParsedGenerics::default())
        };
        // Rewrite the body only if the generics parsed; yields `(block, io)`.
        let body_results = struct_generics.as_ref().and_then(|struct_generics| {
            errors.ok(transform_body::transform_body(
                module_kind,
                block,
                struct_generics,
            ))
        });
        // Report every accumulated error at once. The unwraps below rely on `errors.ok`
        // having recorded an error whenever it returned `None` (assumed; defined elsewhere).
        errors.finish()?;
        let struct_generics = struct_generics.unwrap();
        let (block, io) = body_results.unwrap();
        // Narrow `HdlAttr<_, HdlOrHdlModule>` to the concrete keyword type. A plain `#[hdl]`
        // function returns early here; nothing below applies to it.
        let config_options = match config_options {
            HdlAttr {
                pound_token,
                style,
                bracket_token,
                kw: HdlOrHdlModule::Hdl((kw,)),
                paren_token,
                body,
            } => {
                // Plain `#[hdl]` functions are not expected to declare module I/O.
                debug_assert!(io.is_empty());
                return Ok(Self(ModuleFnImpl::Fn {
                    attrs,
                    config_options: HdlAttr {
                        pound_token,
                        style,
                        bracket_token,
                        kw,
                        paren_token,
                        body,
                    },
                    vis,
                    sig,
                    block,
                }));
            }
            HdlAttr {
                pound_token,
                style,
                bracket_token,
                kw: HdlOrHdlModule::HdlModule((kw,)),
                paren_token,
                body,
            } => HdlAttr {
                pound_token,
                style,
                bracket_token,
                kw,
                paren_token,
                body,
            },
        };
        let (_struct_impl_generics, _struct_type_generics, struct_where_clause) =
            struct_generics.split_for_impl();
        let struct_where_clause: Option = parse_quote! { #struct_where_clause };
        // Append the struct's where-clause predicates to the function's own where clause
        // (creating one if needed), presumably so the struct's bounds also hold wherever
        // the function names the struct type.
        if let Some(struct_where_clause) = &struct_where_clause {
            sig.generics
                .where_clause
                .get_or_insert_with(|| WhereClause {
                    where_token: struct_where_clause.where_token,
                    predicates: Default::default(),
                })
                .predicates
                .extend(struct_where_clause.predicates.clone());
        }
        // The I/O struct reuses the function's name (hence `non_camel_case_types`); input
        // fields are marked `#[hdl(flip)]`, output fields are not.
        let fn_name = &sig.ident;
        let io_flips = io
            .iter()
            .map(|io| match io.kind.kind {
                ModuleIOKind::Input((input,)) => quote_spanned! {input.span=>
                    #[hdl(flip)]
                },
                ModuleIOKind::Output(_) => quote! {},
            })
            .collect::>();
        let io_types = io.iter().map(|io| &io.kind.ty).collect::>();
        let io_names = io.iter().map(|io| &io.name).collect::>();
        let the_struct: ItemStruct = parse_quote! {
            #[allow(non_camel_case_types)]
            #[hdl(no_runtime_generics, no_static)]
            #vis struct #fn_name #struct_generics #struct_where_clause {
                #(
                    #io_flips
                    #vis #io_names: #io_types,
                )*
            }
        };
        // Expand the struct through the bundle derive so `generate` can emit it directly.
        let the_struct = crate::hdl_bundle::hdl_bundle(the_struct)?;
        Ok(Self(ModuleFnImpl::Module(ModuleFnModule {
            attrs,
            config_options,
            // Always `Some` for `#[hdl_module]` (see the `module_kind` match above).
            module_kind: module_kind.unwrap(),
            vis,
            sig,
            block,
            struct_generics,
            the_struct,
        })))
    }
}

impl ModuleFn {
    /// Expands the parsed function into output tokens.
    ///
    /// A plain `#[hdl]` function is re-emitted as an ordinary function with its rewritten
    /// body.
    ///
    /// An `#[hdl_module]` function is emitted as a function with the same name, visibility,
    /// attributes and parameters, with these changes:
    /// - it returns `::fayalite::intern::Interned<::fayalite::module::Module<Struct>>`, where
    ///   `Struct` is the generated I/O struct;
    /// - its block defines a nested `__body` function containing the original body, with an
    ///   extra leading `m: &::fayalite::module::ModuleBuilder` parameter;
    /// - it passes a closure calling `__body` to `::fayalite::module::ModuleBuilder::run`.
    ///
    /// The generated I/O struct tokens are appended after that function.
    pub(crate) fn generate(self) -> TokenStream {
        let ModuleFnModule {
            attrs,
            config_options,
            module_kind,
            vis,
            sig,
            block,
            struct_generics,
            the_struct,
        } = match self.0 {
            ModuleFnImpl::Module(v) => v,
            ModuleFnImpl::Fn {
                attrs,
                config_options,
                vis,
                sig,
                block,
            } => {
                // Every field is listed without `..`, so adding a new option makes this a
                // compile error until it is handled here.
                let ConfigOptions {
                    outline_generated: _,
                    extern_: _,
                } = config_options.body;
                return ItemFn {
                    attrs,
                    vis,
                    sig,
                    block,
                }
                .into_token_stream();
            }
        };
        // Same exhaustiveness check as above, for the module case.
        let ConfigOptions {
            outline_generated: _,
            extern_: _,
        } = config_options.body;
        let mut outer_sig = sig.clone();
        let mut body_sig = sig;
        // Give every outer parameter a plain identifier pattern (non-identifier patterns get
        // `__param{index}`; `ref`/`mut`/sub-patterns are dropped) so each argument can be
        // forwarded by name to `__body`, which keeps the original patterns.
        let param_names =
            Vec::from_iter(outer_sig.inputs.iter_mut().enumerate().map(|(index, arg)| {
                let FnArg::Typed(arg) = arg else {
                    // Receivers were rejected in `parse_from_fn` for `#[hdl_module]`.
                    unreachable!("already checked");
                };
                let name = if let syn::Pat::Ident(pat) = &*arg.pat {
                    pat.ident.clone()
                } else {
                    format_ident!("__param{}", index)
                };
                *arg.pat = syn::Pat::Ident(syn::PatIdent {
                    attrs: vec![],
                    by_ref: None,
                    mutability: None,
                    ident: name.clone(),
                    subpat: None,
                });
                name
            }));
        let module_kind_value = match module_kind {
            ModuleKind::Extern => quote! { ::fayalite::module::ModuleKind::Extern },
            ModuleKind::Normal => quote! { ::fayalite::module::ModuleKind::Normal },
        };
        let fn_name = &outer_sig.ident;
        let (_struct_impl_generics, struct_type_generics, _struct_where_clause) =
            struct_generics.split_for_impl();
        // The I/O struct type, named after the function.
        let struct_ty = quote! {#fn_name #struct_type_generics};
        // Inner function holding the original body, with the builder injected as `m`.
        body_sig.ident = parse_quote! {__body};
        body_sig
            .inputs
            .insert(0, parse_quote! { m: &::fayalite::module::ModuleBuilder });
        let body_fn = ItemFn {
            attrs: vec![],
            vis: Visibility::Inherited,
            sig: body_sig,
            block,
        };
        outer_sig.output =
            parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#struct_ty>>};
        let fn_name_str = fn_name.to_string();
        // Call `__body` with explicit generic arguments (turbofish) so they don't need to be
        // inferred inside the closure.
        let (_, body_type_generics, _) = body_fn.sig.generics.split_for_impl();
        let body_turbofish_type_generics = body_type_generics.as_turbofish();
        let block = parse_quote! {{
            #body_fn
            ::fayalite::module::ModuleBuilder::run(
                #fn_name_str,
                #module_kind_value,
                |m| __body #body_turbofish_type_generics(m, #(#param_names,)*),
            )
        }};
        let outer_fn = ItemFn {
            attrs,
            vis,
            sig: outer_sig,
            block,
        };
        let mut retval = outer_fn.into_token_stream();
        retval.extend(the_struct);
        retval
    }
}