make #[hdl_module] support functions with generic parameters

This commit is contained in:
Jacob Lifshay 2024-07-25 22:10:33 -07:00
parent ef4b3b4081
commit b33566841d
Signed by: programmerjake
SSH key fingerprint: SHA256:B1iRVvUJkvd7upMIiMqn6OyxvD2SgJkAH3ZnUOj6z+c
2 changed files with 169 additions and 18 deletions

View file

@ -3,15 +3,17 @@
use crate::{
is_hdl_attr,
module::transform_body::{HdlLet, HdlLetKindIO},
options, Errors, HdlAttr,
options, Errors, HdlAttr, PairsIterExt,
};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::HashSet;
use syn::{
parse::{Parse, ParseStream},
parse_quote,
visit::{visit_pat, Visit},
Attribute, Block, Error, FnArg, Ident, ItemFn, ItemStruct, ReturnType, Signature, Visibility,
Attribute, Block, ConstParam, Error, FnArg, GenericParam, Generics, Ident, ItemFn, ItemStruct,
LifetimeParam, ReturnType, Signature, TypeParam, Visibility, WhereClause, WherePredicate,
};
mod transform_body;
@ -72,6 +74,7 @@ pub(crate) struct ModuleFn {
sig: Signature,
block: Box<Block>,
io: Vec<ModuleIO>,
struct_generics: Generics,
}
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
@ -80,6 +83,19 @@ pub(crate) enum ModuleKind {
Normal,
}
struct ContainsSkippedIdent<'a> {
skipped_idents: &'a HashSet<Ident>,
contains_skipped_ident: bool,
}
impl Visit<'_> for ContainsSkippedIdent<'_> {
fn visit_ident(&mut self, ident: &'_ Ident) {
if self.skipped_idents.contains(ident) {
self.contains_skipped_ident = true;
}
}
}
impl Parse for ModuleFn {
fn parse(input: ParseStream) -> syn::Result<Self> {
let ItemFn {
@ -95,7 +111,7 @@ impl Parse for ModuleFn {
ref abi,
fn_token: _,
ident: _,
ref generics,
ref mut generics,
paren_token: _,
ref mut inputs,
ref variadic,
@ -140,12 +156,61 @@ impl Parse for ModuleFn {
if let Some(abi) = abi {
errors.push(syn::Error::new_spanned(abi, "extern not allowed here"));
}
if !generics.params.is_empty() {
errors.push(syn::Error::new_spanned(
&generics.params,
"generics are not supported yet",
));
}
let mut skipped_idents = HashSet::new();
let struct_generic_params = generics
.params
.pairs_mut()
.filter_map_pair_value_mut(|v| match v {
GenericParam::Lifetime(LifetimeParam { attrs, .. }) => {
errors
.unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs));
None
}
GenericParam::Type(TypeParam { attrs, ident, .. })
| GenericParam::Const(ConstParam { attrs, ident, .. }) => {
if errors
.unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs))
.is_some()
{
skipped_idents.insert(ident.clone());
None
} else {
Some(v.clone())
}
}
})
.collect();
let struct_where_clause = generics
.where_clause
.as_mut()
.map(|where_clause| WhereClause {
where_token: where_clause.where_token,
predicates: where_clause
.predicates
.pairs_mut()
.filter_map_pair_value_mut(|v| match v {
WherePredicate::Lifetime(_) => None,
_ => {
let mut contains_skipped_ident = ContainsSkippedIdent {
skipped_idents: &skipped_idents,
contains_skipped_ident: false,
};
contains_skipped_ident.visit_where_predicate(v);
if contains_skipped_ident.contains_skipped_ident {
None
} else {
Some(v.clone())
}
}
})
.collect(),
});
let struct_generics = Generics {
lt_token: generics.lt_token,
params: struct_generic_params,
gt_token: generics.gt_token,
where_clause: struct_where_clause,
};
if let Some(variadic) = variadic {
errors.push(syn::Error::new_spanned(variadic, "... not allowed here"));
}
@ -166,6 +231,7 @@ impl Parse for ModuleFn {
sig,
block,
io,
struct_generics,
})
}
}
@ -180,6 +246,7 @@ impl ModuleFn {
sig,
block,
io,
struct_generics,
} = self;
let ConfigOptions {
outline_generated: _,
@ -211,10 +278,13 @@ impl ModuleFn {
ModuleKind::Normal => quote! { ::fayalite::module::NormalModule },
};
let fn_name = &outer_sig.ident;
let (_struct_impl_generics, struct_type_generics, struct_where_clause) =
struct_generics.split_for_impl();
let struct_ty = quote! {#fn_name #struct_type_generics};
body_sig.ident = parse_quote! {__body};
body_sig.inputs.insert(
0,
parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#fn_name, #module_kind_ty>},
parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#struct_ty, #module_kind_ty>},
);
let body_fn = ItemFn {
attrs: vec![],
@ -223,7 +293,7 @@ impl ModuleFn {
block,
};
outer_sig.output =
parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#fn_name>>};
parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#struct_ty>>};
let io_flips = io
.iter()
.map(|io| match io.kind.kind {
@ -236,9 +306,11 @@ impl ModuleFn {
let io_types = io.iter().map(|io| &io.kind.ty).collect::<Vec<_>>();
let io_names = io.iter().map(|io| &io.name).collect::<Vec<_>>();
let fn_name_str = fn_name.to_string();
let (_, body_type_generics, _) = body_fn.sig.generics.split_for_impl();
let body_turbofish_type_generics = body_type_generics.as_turbofish();
let block = parse_quote! {{
#body_fn
::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body(m, #(#param_names,)*))
::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body #body_turbofish_type_generics(m, #(#param_names,)*))
}};
let fixed_type = io.iter().all(|io| io.kind.ty_expr.is_none());
let struct_options = if fixed_type {
@ -254,7 +326,7 @@ impl ModuleFn {
::fayalite::__std::fmt::Debug)]
#[allow(non_camel_case_types)]
#struct_options
#vis struct #fn_name {
#vis struct #fn_name #struct_generics #struct_where_clause {
#(
#io_flips
#vis #io_names: #io_types,)*

View file

@ -5,6 +5,7 @@ use fayalite::{
array::Array,
assert_export_firrtl,
clock::{Clock, ClockDomain},
expr::ToExpr,
hdl_module,
int::{DynUInt, DynUIntType, IntCmp, SInt, UInt},
intern::Intern,
@ -12,7 +13,7 @@ use fayalite::{
module::transform::simplify_enums::{simplify_enums, SimplifyEnumsKind},
reset::{SyncReset, ToReset},
source_location::SourceLocation,
ty::Value,
ty::{FixedType, Value},
};
use serde_json::json;
@ -139,22 +140,22 @@ circuit my_module:
}
#[hdl_module(outline_generated)]
pub fn check_array_repeat() {
pub fn check_array_repeat<const N: usize>() {
#[hdl]
let i: UInt<8> = m.input();
#[hdl]
let o: Array<[UInt<8>; 3]> = m.output();
let o: Array<[UInt<8>; N]> = m.output();
m.connect(
o,
#[hdl]
[i; 3],
[i; N],
);
}
#[test]
fn test_array_repeat() {
let _n = SourceLocation::normalize_files_for_tests();
let m = check_array_repeat();
let m = check_array_repeat::<3>();
dbg!(m);
#[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
assert_export_firrtl! {
@ -169,6 +170,84 @@ circuit check_array_repeat:
connect _array_literal_expr[1], i
connect _array_literal_expr[2], i
connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1]
",
};
let m = check_array_repeat::<4>();
dbg!(m);
#[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
assert_export_firrtl! {
m =>
"/test/check_array_repeat_1.fir": r"FIRRTL version 3.2.0
circuit check_array_repeat_1:
module check_array_repeat_1: @[module-XXXXXXXXXX.rs 1:1]
input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1]
output o: UInt<8>[4] @[module-XXXXXXXXXX.rs 3:1]
wire _array_literal_expr: UInt<8>[4]
connect _array_literal_expr[0], i
connect _array_literal_expr[1], i
connect _array_literal_expr[2], i
connect _array_literal_expr[3], i
connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1]
",
};
}
#[hdl_module(outline_generated)]
pub fn check_skipped_generics<T, #[hdl(skip)] U, const N: usize, #[hdl(skip)] const M: usize>(v: U)
where
T: Value<Type: FixedType<Value = T>>,
U: std::fmt::Display,
{
dbg!(M);
#[hdl]
let i: T = m.input();
#[hdl]
let o: Array<[T; N]> = m.output();
let bytes = v.to_string().as_bytes().to_expr();
#[hdl]
let o2: Array<[UInt<8>]> = m.output(bytes.ty());
m.connect(
o,
#[hdl]
[i; N],
);
m.connect(o2, bytes);
}
#[test]
fn test_skipped_generics() {
let _n = SourceLocation::normalize_files_for_tests();
let m = check_skipped_generics::<UInt<8>, _, 3, 4>("Hello World!\n");
dbg!(m);
#[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
assert_export_firrtl! {
m =>
"/test/check_skipped_generics.fir": r"FIRRTL version 3.2.0
circuit check_skipped_generics:
module check_skipped_generics: @[module-XXXXXXXXXX.rs 1:1]
input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1]
output o: UInt<8>[3] @[module-XXXXXXXXXX.rs 3:1]
output o2: UInt<8>[13] @[module-XXXXXXXXXX.rs 4:1]
wire _array_literal_expr: UInt<8>[3]
connect _array_literal_expr[0], i
connect _array_literal_expr[1], i
connect _array_literal_expr[2], i
connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 5:1]
wire _array_literal_expr_1: UInt<8>[13]
connect _array_literal_expr_1[0], UInt<8>(0h48)
connect _array_literal_expr_1[1], UInt<8>(0h65)
connect _array_literal_expr_1[2], UInt<8>(0h6C)
connect _array_literal_expr_1[3], UInt<8>(0h6C)
connect _array_literal_expr_1[4], UInt<8>(0h6F)
connect _array_literal_expr_1[5], UInt<8>(0h20)
connect _array_literal_expr_1[6], UInt<8>(0h57)
connect _array_literal_expr_1[7], UInt<8>(0h6F)
connect _array_literal_expr_1[8], UInt<8>(0h72)
connect _array_literal_expr_1[9], UInt<8>(0h6C)
connect _array_literal_expr_1[10], UInt<8>(0h64)
connect _array_literal_expr_1[11], UInt<8>(0h21)
connect _array_literal_expr_1[12], UInt<8>(0hA)
connect o2, _array_literal_expr_1 @[module-XXXXXXXXXX.rs 6:1]
",
};
}