make #[hdl_module] support functions with generic parameters

This commit is contained in:
Jacob Lifshay 2024-07-25 22:10:33 -07:00
parent ef4b3b4081
commit b33566841d
Signed by: programmerjake
SSH key fingerprint: SHA256:B1iRVvUJkvd7upMIiMqn6OyxvD2SgJkAH3ZnUOj6z+c
2 changed files with 169 additions and 18 deletions

View file

@ -3,15 +3,17 @@
use crate::{ use crate::{
is_hdl_attr, is_hdl_attr,
module::transform_body::{HdlLet, HdlLetKindIO}, module::transform_body::{HdlLet, HdlLetKindIO},
options, Errors, HdlAttr, options, Errors, HdlAttr, PairsIterExt,
}; };
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens}; use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::HashSet;
use syn::{ use syn::{
parse::{Parse, ParseStream}, parse::{Parse, ParseStream},
parse_quote, parse_quote,
visit::{visit_pat, Visit}, visit::{visit_pat, Visit},
Attribute, Block, Error, FnArg, Ident, ItemFn, ItemStruct, ReturnType, Signature, Visibility, Attribute, Block, ConstParam, Error, FnArg, GenericParam, Generics, Ident, ItemFn, ItemStruct,
LifetimeParam, ReturnType, Signature, TypeParam, Visibility, WhereClause, WherePredicate,
}; };
mod transform_body; mod transform_body;
@ -72,6 +74,7 @@ pub(crate) struct ModuleFn {
sig: Signature, sig: Signature,
block: Box<Block>, block: Box<Block>,
io: Vec<ModuleIO>, io: Vec<ModuleIO>,
struct_generics: Generics,
} }
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] #[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
@ -80,6 +83,19 @@ pub(crate) enum ModuleKind {
Normal, Normal,
} }
struct ContainsSkippedIdent<'a> {
skipped_idents: &'a HashSet<Ident>,
contains_skipped_ident: bool,
}
impl Visit<'_> for ContainsSkippedIdent<'_> {
fn visit_ident(&mut self, ident: &'_ Ident) {
if self.skipped_idents.contains(ident) {
self.contains_skipped_ident = true;
}
}
}
impl Parse for ModuleFn { impl Parse for ModuleFn {
fn parse(input: ParseStream) -> syn::Result<Self> { fn parse(input: ParseStream) -> syn::Result<Self> {
let ItemFn { let ItemFn {
@ -95,7 +111,7 @@ impl Parse for ModuleFn {
ref abi, ref abi,
fn_token: _, fn_token: _,
ident: _, ident: _,
ref generics, ref mut generics,
paren_token: _, paren_token: _,
ref mut inputs, ref mut inputs,
ref variadic, ref variadic,
@ -140,12 +156,61 @@ impl Parse for ModuleFn {
if let Some(abi) = abi { if let Some(abi) = abi {
errors.push(syn::Error::new_spanned(abi, "extern not allowed here")); errors.push(syn::Error::new_spanned(abi, "extern not allowed here"));
} }
if !generics.params.is_empty() { let mut skipped_idents = HashSet::new();
errors.push(syn::Error::new_spanned( let struct_generic_params = generics
&generics.params, .params
"generics are not supported yet", .pairs_mut()
)); .filter_map_pair_value_mut(|v| match v {
GenericParam::Lifetime(LifetimeParam { attrs, .. }) => {
errors
.unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs));
None
} }
GenericParam::Type(TypeParam { attrs, ident, .. })
| GenericParam::Const(ConstParam { attrs, ident, .. }) => {
if errors
.unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs))
.is_some()
{
skipped_idents.insert(ident.clone());
None
} else {
Some(v.clone())
}
}
})
.collect();
let struct_where_clause = generics
.where_clause
.as_mut()
.map(|where_clause| WhereClause {
where_token: where_clause.where_token,
predicates: where_clause
.predicates
.pairs_mut()
.filter_map_pair_value_mut(|v| match v {
WherePredicate::Lifetime(_) => None,
_ => {
let mut contains_skipped_ident = ContainsSkippedIdent {
skipped_idents: &skipped_idents,
contains_skipped_ident: false,
};
contains_skipped_ident.visit_where_predicate(v);
if contains_skipped_ident.contains_skipped_ident {
None
} else {
Some(v.clone())
}
}
})
.collect(),
});
let struct_generics = Generics {
lt_token: generics.lt_token,
params: struct_generic_params,
gt_token: generics.gt_token,
where_clause: struct_where_clause,
};
if let Some(variadic) = variadic { if let Some(variadic) = variadic {
errors.push(syn::Error::new_spanned(variadic, "... not allowed here")); errors.push(syn::Error::new_spanned(variadic, "... not allowed here"));
} }
@ -166,6 +231,7 @@ impl Parse for ModuleFn {
sig, sig,
block, block,
io, io,
struct_generics,
}) })
} }
} }
@ -180,6 +246,7 @@ impl ModuleFn {
sig, sig,
block, block,
io, io,
struct_generics,
} = self; } = self;
let ConfigOptions { let ConfigOptions {
outline_generated: _, outline_generated: _,
@ -211,10 +278,13 @@ impl ModuleFn {
ModuleKind::Normal => quote! { ::fayalite::module::NormalModule }, ModuleKind::Normal => quote! { ::fayalite::module::NormalModule },
}; };
let fn_name = &outer_sig.ident; let fn_name = &outer_sig.ident;
let (_struct_impl_generics, struct_type_generics, struct_where_clause) =
struct_generics.split_for_impl();
let struct_ty = quote! {#fn_name #struct_type_generics};
body_sig.ident = parse_quote! {__body}; body_sig.ident = parse_quote! {__body};
body_sig.inputs.insert( body_sig.inputs.insert(
0, 0,
parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#fn_name, #module_kind_ty>}, parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#struct_ty, #module_kind_ty>},
); );
let body_fn = ItemFn { let body_fn = ItemFn {
attrs: vec![], attrs: vec![],
@ -223,7 +293,7 @@ impl ModuleFn {
block, block,
}; };
outer_sig.output = outer_sig.output =
parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#fn_name>>}; parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#struct_ty>>};
let io_flips = io let io_flips = io
.iter() .iter()
.map(|io| match io.kind.kind { .map(|io| match io.kind.kind {
@ -236,9 +306,11 @@ impl ModuleFn {
let io_types = io.iter().map(|io| &io.kind.ty).collect::<Vec<_>>(); let io_types = io.iter().map(|io| &io.kind.ty).collect::<Vec<_>>();
let io_names = io.iter().map(|io| &io.name).collect::<Vec<_>>(); let io_names = io.iter().map(|io| &io.name).collect::<Vec<_>>();
let fn_name_str = fn_name.to_string(); let fn_name_str = fn_name.to_string();
let (_, body_type_generics, _) = body_fn.sig.generics.split_for_impl();
let body_turbofish_type_generics = body_type_generics.as_turbofish();
let block = parse_quote! {{ let block = parse_quote! {{
#body_fn #body_fn
::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body(m, #(#param_names,)*)) ::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body #body_turbofish_type_generics(m, #(#param_names,)*))
}}; }};
let fixed_type = io.iter().all(|io| io.kind.ty_expr.is_none()); let fixed_type = io.iter().all(|io| io.kind.ty_expr.is_none());
let struct_options = if fixed_type { let struct_options = if fixed_type {
@ -254,7 +326,7 @@ impl ModuleFn {
::fayalite::__std::fmt::Debug)] ::fayalite::__std::fmt::Debug)]
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
#struct_options #struct_options
#vis struct #fn_name { #vis struct #fn_name #struct_generics #struct_where_clause {
#( #(
#io_flips #io_flips
#vis #io_names: #io_types,)* #vis #io_names: #io_types,)*

View file

@ -5,6 +5,7 @@ use fayalite::{
array::Array, array::Array,
assert_export_firrtl, assert_export_firrtl,
clock::{Clock, ClockDomain}, clock::{Clock, ClockDomain},
expr::ToExpr,
hdl_module, hdl_module,
int::{DynUInt, DynUIntType, IntCmp, SInt, UInt}, int::{DynUInt, DynUIntType, IntCmp, SInt, UInt},
intern::Intern, intern::Intern,
@ -12,7 +13,7 @@ use fayalite::{
module::transform::simplify_enums::{simplify_enums, SimplifyEnumsKind}, module::transform::simplify_enums::{simplify_enums, SimplifyEnumsKind},
reset::{SyncReset, ToReset}, reset::{SyncReset, ToReset},
source_location::SourceLocation, source_location::SourceLocation,
ty::Value, ty::{FixedType, Value},
}; };
use serde_json::json; use serde_json::json;
@ -139,22 +140,22 @@ circuit my_module:
} }
#[hdl_module(outline_generated)] #[hdl_module(outline_generated)]
pub fn check_array_repeat() { pub fn check_array_repeat<const N: usize>() {
#[hdl] #[hdl]
let i: UInt<8> = m.input(); let i: UInt<8> = m.input();
#[hdl] #[hdl]
let o: Array<[UInt<8>; 3]> = m.output(); let o: Array<[UInt<8>; N]> = m.output();
m.connect( m.connect(
o, o,
#[hdl] #[hdl]
[i; 3], [i; N],
); );
} }
#[test] #[test]
fn test_array_repeat() { fn test_array_repeat() {
let _n = SourceLocation::normalize_files_for_tests(); let _n = SourceLocation::normalize_files_for_tests();
let m = check_array_repeat(); let m = check_array_repeat::<3>();
dbg!(m); dbg!(m);
#[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
assert_export_firrtl! { assert_export_firrtl! {
@ -169,6 +170,84 @@ circuit check_array_repeat:
connect _array_literal_expr[1], i connect _array_literal_expr[1], i
connect _array_literal_expr[2], i connect _array_literal_expr[2], i
connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1] connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1]
",
};
let m = check_array_repeat::<4>();
dbg!(m);
#[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
assert_export_firrtl! {
m =>
"/test/check_array_repeat_1.fir": r"FIRRTL version 3.2.0
circuit check_array_repeat_1:
module check_array_repeat_1: @[module-XXXXXXXXXX.rs 1:1]
input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1]
output o: UInt<8>[4] @[module-XXXXXXXXXX.rs 3:1]
wire _array_literal_expr: UInt<8>[4]
connect _array_literal_expr[0], i
connect _array_literal_expr[1], i
connect _array_literal_expr[2], i
connect _array_literal_expr[3], i
connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1]
",
};
}
#[hdl_module(outline_generated)]
pub fn check_skipped_generics<T, #[hdl(skip)] U, const N: usize, #[hdl(skip)] const M: usize>(v: U)
where
T: Value<Type: FixedType<Value = T>>,
U: std::fmt::Display,
{
dbg!(M);
#[hdl]
let i: T = m.input();
#[hdl]
let o: Array<[T; N]> = m.output();
let bytes = v.to_string().as_bytes().to_expr();
#[hdl]
let o2: Array<[UInt<8>]> = m.output(bytes.ty());
m.connect(
o,
#[hdl]
[i; N],
);
m.connect(o2, bytes);
}
#[test]
fn test_skipped_generics() {
let _n = SourceLocation::normalize_files_for_tests();
let m = check_skipped_generics::<UInt<8>, _, 3, 4>("Hello World!\n");
dbg!(m);
#[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
assert_export_firrtl! {
m =>
"/test/check_skipped_generics.fir": r"FIRRTL version 3.2.0
circuit check_skipped_generics:
module check_skipped_generics: @[module-XXXXXXXXXX.rs 1:1]
input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1]
output o: UInt<8>[3] @[module-XXXXXXXXXX.rs 3:1]
output o2: UInt<8>[13] @[module-XXXXXXXXXX.rs 4:1]
wire _array_literal_expr: UInt<8>[3]
connect _array_literal_expr[0], i
connect _array_literal_expr[1], i
connect _array_literal_expr[2], i
connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 5:1]
wire _array_literal_expr_1: UInt<8>[13]
connect _array_literal_expr_1[0], UInt<8>(0h48)
connect _array_literal_expr_1[1], UInt<8>(0h65)
connect _array_literal_expr_1[2], UInt<8>(0h6C)
connect _array_literal_expr_1[3], UInt<8>(0h6C)
connect _array_literal_expr_1[4], UInt<8>(0h6F)
connect _array_literal_expr_1[5], UInt<8>(0h20)
connect _array_literal_expr_1[6], UInt<8>(0h57)
connect _array_literal_expr_1[7], UInt<8>(0h6F)
connect _array_literal_expr_1[8], UInt<8>(0h72)
connect _array_literal_expr_1[9], UInt<8>(0h6C)
connect _array_literal_expr_1[10], UInt<8>(0h64)
connect _array_literal_expr_1[11], UInt<8>(0h21)
connect _array_literal_expr_1[12], UInt<8>(0hA)
connect o2, _array_literal_expr_1 @[module-XXXXXXXXXX.rs 6:1]
", ",
}; };
} }