make #[hdl_module] support functions with generic parameters

This commit is contained in:
Jacob Lifshay 2024-07-25 22:10:33 -07:00
parent ef4b3b4081
commit b33566841d
Signed by: programmerjake
SSH key fingerprint: SHA256:B1iRVvUJkvd7upMIiMqn6OyxvD2SgJkAH3ZnUOj6z+c
2 changed files with 169 additions and 18 deletions

View file

@ -3,15 +3,17 @@
use crate::{
is_hdl_attr,
module::transform_body::{HdlLet, HdlLetKindIO},
options, Errors, HdlAttr,
options, Errors, HdlAttr, PairsIterExt,
};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::HashSet;
use syn::{
parse::{Parse, ParseStream},
parse_quote,
visit::{visit_pat, Visit},
Attribute, Block, Error, FnArg, Ident, ItemFn, ItemStruct, ReturnType, Signature, Visibility,
Attribute, Block, ConstParam, Error, FnArg, GenericParam, Generics, Ident, ItemFn, ItemStruct,
LifetimeParam, ReturnType, Signature, TypeParam, Visibility, WhereClause, WherePredicate,
};
mod transform_body;
@ -72,6 +74,7 @@ pub(crate) struct ModuleFn {
sig: Signature,
block: Box<Block>,
io: Vec<ModuleIO>,
struct_generics: Generics,
}
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
@ -80,6 +83,19 @@ pub(crate) enum ModuleKind {
Normal,
}
struct ContainsSkippedIdent<'a> {
skipped_idents: &'a HashSet<Ident>,
contains_skipped_ident: bool,
}
impl Visit<'_> for ContainsSkippedIdent<'_> {
fn visit_ident(&mut self, ident: &'_ Ident) {
if self.skipped_idents.contains(ident) {
self.contains_skipped_ident = true;
}
}
}
impl Parse for ModuleFn {
fn parse(input: ParseStream) -> syn::Result<Self> {
let ItemFn {
@ -95,7 +111,7 @@ impl Parse for ModuleFn {
ref abi,
fn_token: _,
ident: _,
ref generics,
ref mut generics,
paren_token: _,
ref mut inputs,
ref variadic,
@ -140,12 +156,61 @@ impl Parse for ModuleFn {
if let Some(abi) = abi {
errors.push(syn::Error::new_spanned(abi, "extern not allowed here"));
}
if !generics.params.is_empty() {
errors.push(syn::Error::new_spanned(
&generics.params,
"generics are not supported yet",
));
}
let mut skipped_idents = HashSet::new();
let struct_generic_params = generics
.params
.pairs_mut()
.filter_map_pair_value_mut(|v| match v {
GenericParam::Lifetime(LifetimeParam { attrs, .. }) => {
errors
.unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs));
None
}
GenericParam::Type(TypeParam { attrs, ident, .. })
| GenericParam::Const(ConstParam { attrs, ident, .. }) => {
if errors
.unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs))
.is_some()
{
skipped_idents.insert(ident.clone());
None
} else {
Some(v.clone())
}
}
})
.collect();
let struct_where_clause = generics
.where_clause
.as_mut()
.map(|where_clause| WhereClause {
where_token: where_clause.where_token,
predicates: where_clause
.predicates
.pairs_mut()
.filter_map_pair_value_mut(|v| match v {
WherePredicate::Lifetime(_) => None,
_ => {
let mut contains_skipped_ident = ContainsSkippedIdent {
skipped_idents: &skipped_idents,
contains_skipped_ident: false,
};
contains_skipped_ident.visit_where_predicate(v);
if contains_skipped_ident.contains_skipped_ident {
None
} else {
Some(v.clone())
}
}
})
.collect(),
});
let struct_generics = Generics {
lt_token: generics.lt_token,
params: struct_generic_params,
gt_token: generics.gt_token,
where_clause: struct_where_clause,
};
if let Some(variadic) = variadic {
errors.push(syn::Error::new_spanned(variadic, "... not allowed here"));
}
@ -166,6 +231,7 @@ impl Parse for ModuleFn {
sig,
block,
io,
struct_generics,
})
}
}
@ -180,6 +246,7 @@ impl ModuleFn {
sig,
block,
io,
struct_generics,
} = self;
let ConfigOptions {
outline_generated: _,
@ -211,10 +278,13 @@ impl ModuleFn {
ModuleKind::Normal => quote! { ::fayalite::module::NormalModule },
};
let fn_name = &outer_sig.ident;
let (_struct_impl_generics, struct_type_generics, struct_where_clause) =
struct_generics.split_for_impl();
let struct_ty = quote! {#fn_name #struct_type_generics};
body_sig.ident = parse_quote! {__body};
body_sig.inputs.insert(
0,
parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#fn_name, #module_kind_ty>},
parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#struct_ty, #module_kind_ty>},
);
let body_fn = ItemFn {
attrs: vec![],
@ -223,7 +293,7 @@ impl ModuleFn {
block,
};
outer_sig.output =
parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#fn_name>>};
parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#struct_ty>>};
let io_flips = io
.iter()
.map(|io| match io.kind.kind {
@ -236,9 +306,11 @@ impl ModuleFn {
let io_types = io.iter().map(|io| &io.kind.ty).collect::<Vec<_>>();
let io_names = io.iter().map(|io| &io.name).collect::<Vec<_>>();
let fn_name_str = fn_name.to_string();
let (_, body_type_generics, _) = body_fn.sig.generics.split_for_impl();
let body_turbofish_type_generics = body_type_generics.as_turbofish();
let block = parse_quote! {{
#body_fn
::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body(m, #(#param_names,)*))
::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body #body_turbofish_type_generics(m, #(#param_names,)*))
}};
let fixed_type = io.iter().all(|io| io.kind.ty_expr.is_none());
let struct_options = if fixed_type {
@ -254,7 +326,7 @@ impl ModuleFn {
::fayalite::__std::fmt::Debug)]
#[allow(non_camel_case_types)]
#struct_options
#vis struct #fn_name {
#vis struct #fn_name #struct_generics #struct_where_clause {
#(
#io_flips
#vis #io_names: #io_types,)*