fayalite/crates/fayalite-proc-macros-impl/src/module.rs

355 lines
12 KiB
Rust
Raw Normal View History

2024-06-11 06:09:13 +00:00
// SPDX-License-Identifier: LGPL-3.0-or-later
// See Notices.txt for copyright information
use crate::{
is_hdl_attr,
module::transform_body::{HdlLet, HdlLetKindIO},
options, Errors, HdlAttr, PairsIterExt,
2024-06-11 06:09:13 +00:00
};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::HashSet;
2024-06-11 06:09:13 +00:00
use syn::{
parse::{Parse, ParseStream},
parse_quote,
visit::{visit_pat, Visit},
Attribute, Block, ConstParam, Error, FnArg, GenericParam, Generics, Ident, ItemFn, ItemStruct,
LifetimeParam, ReturnType, Signature, TypeParam, Visibility, WhereClause, WherePredicate,
2024-06-11 06:09:13 +00:00
};
mod transform_body;
// Declares `ConfigOption` through the crate-local `options!` macro, with the
// variants `OutlineGenerated(outline_generated)` and `Extern(extern_)`.
// `ConfigOptions` (named by `#[options = ...]`) is the type this file later
// destructures into the fields `outline_generated` and `extern_`.
// NOTE(review): what exactly `options!` expands to is defined elsewhere — confirm there.
options! {
    #[options = ConfigOptions]
    #[no_ident_fragment]
    pub(crate) enum ConfigOption {
        OutlineGenerated(outline_generated),
        Extern(extern_),
    }
}
// Declares `ModuleIOKind` through the crate-local `options!` macro: a module IO
// is either `Input(input)` or `Output(output)`.
// `ModuleFn::generate` matches on this to decide which generated struct fields
// receive `#[hdl(flip)]` (inputs do, outputs do not).
options! {
    pub(crate) enum ModuleIOKind {
        Input(input),
        Output(output),
    }
}
pub(crate) fn check_name_conflicts_with_module_builder(name: &Ident) -> syn::Result<()> {
if name == "m" {
Err(Error::new_spanned(
name,
"name conflicts with implicit `m: &mut ModuleBuilder<_>`",
))
} else {
Ok(())
}
}
/// AST visitor that runs [`check_name_conflicts_with_module_builder`] on every
/// identifier it visits and collects the failures.
pub(crate) struct CheckNameConflictsWithModuleBuilderVisitor<'a> {
    /// Error accumulator that receives the result of each identifier check.
    pub(crate) errors: &'a mut Errors,
}
impl Visit<'_> for CheckNameConflictsWithModuleBuilderVisitor<'_> {
    // TODO: change this to only check for identifiers defining new variables
    /// Checks a single identifier and hands the outcome to the error accumulator.
    fn visit_ident(&mut self, node: &Ident) {
        let outcome = check_name_conflicts_with_module_builder(node);
        self.errors.push_result(outcome);
    }
}
/// Filters the attributes of a struct in place.
///
/// `f` is applied first to the attributes on the struct item itself and then to
/// the attributes of each field, in field order; an attribute is kept only when
/// `f` returns `true` for it.
fn retain_struct_attrs<F: FnMut(&Attribute) -> bool>(item: &mut ItemStruct, mut f: F) {
    item.attrs.retain(|attr| f(attr));
    item.fields
        .iter_mut()
        .for_each(|field| field.attrs.retain(|attr| f(attr)));
}
/// A module IO declaration: an [`HdlLet`] whose kind is [`HdlLetKindIO`].
pub(crate) type ModuleIO = HdlLet<HdlLetKindIO>;
/// A parsed module function, produced by the `Parse` impl and consumed by
/// `ModuleFn::generate`.
pub(crate) struct ModuleFn {
    /// Attributes left on the function after the config-options attribute was
    /// parsed out of them; re-emitted on the generated outer function.
    attrs: Vec<Attribute>,
    /// The parsed config options (`outline_generated`, `extern_`).
    config_options: HdlAttr<ConfigOptions>,
    /// `Extern` when the `extern_` option was given, `Normal` otherwise.
    module_kind: ModuleKind,
    /// Visibility of the function; reused for the generated struct and its fields.
    vis: Visibility,
    /// The function's signature.
    sig: Signature,
    /// The function body as returned by `transform_body::transform_body`.
    block: Box<Block>,
    /// The module IOs returned by `transform_body::transform_body`.
    io: Vec<ModuleIO>,
    /// Generics for the generated struct: the function's generics without
    /// lifetime parameters, without parameters carrying the `skip` HDL attribute,
    /// and without where-predicates that are lifetime predicates or mention a
    /// skipped parameter.
    struct_generics: Generics,
}
/// Which flavour of module a module function describes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum ModuleKind {
    /// Selected when the `extern_` config option is present.
    Extern,
    /// Selected when the `extern_` config option is absent.
    Normal,
}
/// AST visitor that records whether anything it visits mentions one of a given
/// set of identifiers.
struct ContainsSkippedIdent<'a> {
    /// The identifiers to look for.
    skipped_idents: &'a HashSet<Ident>,
    /// Set to `true` once a visited identifier is found in `skipped_idents`.
    contains_skipped_ident: bool,
}
impl Visit<'_> for ContainsSkippedIdent<'_> {
    /// Latches the flag when `ident` is one of the skipped identifiers; the flag
    /// is never cleared again by this visitor.
    fn visit_ident(&mut self, ident: &'_ Ident) {
        self.contains_skipped_ident |= self.skipped_idents.contains(ident);
    }
}
2024-06-11 06:09:13 +00:00
impl Parse for ModuleFn {
fn parse(input: ParseStream) -> syn::Result<Self> {
let ItemFn {
mut attrs,
vis,
mut sig,
block,
} = input.parse()?;
let Signature {
ref constness,
ref asyncness,
ref unsafety,
ref abi,
fn_token: _,
ident: _,
ref mut generics,
2024-06-11 06:09:13 +00:00
paren_token: _,
ref mut inputs,
ref variadic,
ref output,
} = sig;
let mut errors = Errors::new();
let config_options = errors
.unwrap_or_default(HdlAttr::parse_and_take_attr(&mut attrs))
.unwrap_or_default();
let ConfigOptions {
outline_generated: _,
extern_,
} = config_options.body;
let module_kind = match extern_ {
Some(_) => ModuleKind::Extern,
None => ModuleKind::Normal,
};
for fn_arg in inputs {
match fn_arg {
FnArg::Receiver(_) => {
errors.push(syn::Error::new_spanned(fn_arg, "self not allowed here"));
}
FnArg::Typed(fn_arg) => {
visit_pat(
&mut CheckNameConflictsWithModuleBuilderVisitor {
errors: &mut errors,
},
&fn_arg.pat,
);
}
}
}
if let Some(constness) = constness {
errors.push(syn::Error::new_spanned(constness, "const not allowed here"));
}
if let Some(asyncness) = asyncness {
errors.push(syn::Error::new_spanned(asyncness, "async not allowed here"));
}
if let Some(unsafety) = unsafety {
errors.push(syn::Error::new_spanned(unsafety, "unsafe not allowed here"));
}
if let Some(abi) = abi {
errors.push(syn::Error::new_spanned(abi, "extern not allowed here"));
}
let mut skipped_idents = HashSet::new();
let struct_generic_params = generics
.params
.pairs_mut()
.filter_map_pair_value_mut(|v| match v {
GenericParam::Lifetime(LifetimeParam { attrs, .. }) => {
errors
.unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs));
None
}
GenericParam::Type(TypeParam { attrs, ident, .. })
| GenericParam::Const(ConstParam { attrs, ident, .. }) => {
if errors
.unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs))
.is_some()
{
skipped_idents.insert(ident.clone());
None
} else {
Some(v.clone())
}
}
})
.collect();
let struct_where_clause = generics
.where_clause
.as_mut()
.map(|where_clause| WhereClause {
where_token: where_clause.where_token,
predicates: where_clause
.predicates
.pairs_mut()
.filter_map_pair_value_mut(|v| match v {
WherePredicate::Lifetime(_) => None,
_ => {
let mut contains_skipped_ident = ContainsSkippedIdent {
skipped_idents: &skipped_idents,
contains_skipped_ident: false,
};
contains_skipped_ident.visit_where_predicate(v);
if contains_skipped_ident.contains_skipped_ident {
None
} else {
Some(v.clone())
}
}
})
.collect(),
});
let struct_generics = Generics {
lt_token: generics.lt_token,
params: struct_generic_params,
gt_token: generics.gt_token,
where_clause: struct_where_clause,
};
2024-06-11 06:09:13 +00:00
if let Some(variadic) = variadic {
errors.push(syn::Error::new_spanned(variadic, "... not allowed here"));
}
if !matches!(output, ReturnType::Default) {
errors.push(syn::Error::new_spanned(
output,
"return type not allowed here",
));
}
let body_results = errors.ok(transform_body::transform_body(module_kind, block));
errors.finish()?;
let (block, io) = body_results.unwrap();
Ok(Self {
attrs,
config_options,
module_kind,
vis,
sig,
block,
io,
struct_generics,
2024-06-11 06:09:13 +00:00
})
}
}
impl ModuleFn {
pub(crate) fn generate(self) -> TokenStream {
let Self {
attrs,
config_options,
module_kind,
vis,
sig,
block,
io,
struct_generics,
2024-06-11 06:09:13 +00:00
} = self;
let ConfigOptions {
outline_generated: _,
extern_: _,
} = config_options.body;
let mut outer_sig = sig.clone();
let mut body_sig = sig;
let param_names =
Vec::from_iter(outer_sig.inputs.iter_mut().enumerate().map(|(index, arg)| {
let FnArg::Typed(arg) = arg else {
unreachable!("already checked");
};
let name = if let syn::Pat::Ident(pat) = &*arg.pat {
pat.ident.clone()
} else {
format_ident!("__param{}", index)
};
*arg.pat = syn::Pat::Ident(syn::PatIdent {
attrs: vec![],
by_ref: None,
mutability: None,
ident: name.clone(),
subpat: None,
});
name
}));
let module_kind_ty = match module_kind {
ModuleKind::Extern => quote! { ::fayalite::module::ExternModule },
ModuleKind::Normal => quote! { ::fayalite::module::NormalModule },
};
let fn_name = &outer_sig.ident;
let (_struct_impl_generics, struct_type_generics, struct_where_clause) =
struct_generics.split_for_impl();
let struct_ty = quote! {#fn_name #struct_type_generics};
2024-06-11 06:09:13 +00:00
body_sig.ident = parse_quote! {__body};
body_sig.inputs.insert(
0,
parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#struct_ty, #module_kind_ty>},
2024-06-11 06:09:13 +00:00
);
let body_fn = ItemFn {
attrs: vec![],
vis: Visibility::Inherited,
sig: body_sig,
block,
};
outer_sig.output =
parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#struct_ty>>};
2024-06-11 06:09:13 +00:00
let io_flips = io
.iter()
.map(|io| match io.kind.kind {
ModuleIOKind::Input((input,)) => quote_spanned! {input.span=>
#[hdl(flip)]
},
ModuleIOKind::Output(_) => quote! {},
})
.collect::<Vec<_>>();
let io_types = io.iter().map(|io| &io.kind.ty).collect::<Vec<_>>();
let io_names = io.iter().map(|io| &io.name).collect::<Vec<_>>();
let fn_name_str = fn_name.to_string();
let (_, body_type_generics, _) = body_fn.sig.generics.split_for_impl();
let body_turbofish_type_generics = body_type_generics.as_turbofish();
2024-06-11 06:09:13 +00:00
let block = parse_quote! {{
#body_fn
::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body #body_turbofish_type_generics(m, #(#param_names,)*))
2024-06-11 06:09:13 +00:00
}};
let fixed_type = io.iter().all(|io| io.kind.ty_expr.is_none());
let struct_options = if fixed_type {
quote! { #[hdl(fixed_type)] }
} else {
quote! {}
};
let the_struct: ItemStruct = parse_quote! {
#[derive(::fayalite::__std::clone::Clone,
::fayalite::__std::hash::Hash,
::fayalite::__std::cmp::PartialEq,
::fayalite::__std::cmp::Eq,
::fayalite::__std::fmt::Debug)]
#[allow(non_camel_case_types)]
#struct_options
#vis struct #fn_name #struct_generics #struct_where_clause {
2024-06-11 06:09:13 +00:00
#(
#io_flips
#vis #io_names: #io_types,)*
}
};
let mut struct_without_hdl_attrs = the_struct.clone();
let mut struct_without_derives = the_struct;
retain_struct_attrs(&mut struct_without_hdl_attrs, |attr| !is_hdl_attr(attr));
retain_struct_attrs(&mut struct_without_derives, |attr| {
!attr.path().is_ident("derive")
});
let outer_fn = ItemFn {
attrs,
vis,
sig: outer_sig,
block,
};
let mut retval = outer_fn.into_token_stream();
struct_without_hdl_attrs.to_tokens(&mut retval);
retval.extend(
crate::value_derive_struct::value_derive_struct(struct_without_derives).unwrap(),
);
retval
}
}