make #[hdl_module] support functions with generic parameters
Some checks failed
/ test (push) Failing after 4m38s
Some checks failed
/ test (push) Failing after 4m38s
This commit is contained in:
parent
ef4b3b4081
commit
c5901cd217
|
@ -3,15 +3,17 @@
|
|||
use crate::{
|
||||
is_hdl_attr,
|
||||
module::transform_body::{HdlLet, HdlLetKindIO},
|
||||
options, Errors, HdlAttr,
|
||||
options, Errors, HdlAttr, PairsIterExt,
|
||||
};
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::{format_ident, quote, quote_spanned, ToTokens};
|
||||
use std::collections::HashSet;
|
||||
use syn::{
|
||||
parse::{Parse, ParseStream},
|
||||
parse_quote,
|
||||
visit::{visit_pat, Visit},
|
||||
Attribute, Block, Error, FnArg, Ident, ItemFn, ItemStruct, ReturnType, Signature, Visibility,
|
||||
Attribute, Block, ConstParam, Error, FnArg, GenericParam, Generics, Ident, ItemFn, ItemStruct,
|
||||
LifetimeParam, ReturnType, Signature, TypeParam, Visibility, WhereClause, WherePredicate,
|
||||
};
|
||||
|
||||
mod transform_body;
|
||||
|
@ -72,6 +74,7 @@ pub(crate) struct ModuleFn {
|
|||
sig: Signature,
|
||||
block: Box<Block>,
|
||||
io: Vec<ModuleIO>,
|
||||
struct_generics: Generics,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
|
||||
|
@ -80,6 +83,19 @@ pub(crate) enum ModuleKind {
|
|||
Normal,
|
||||
}
|
||||
|
||||
struct ContainsSkippedIdent<'a> {
|
||||
skipped_idents: &'a HashSet<Ident>,
|
||||
contains_skipped_ident: bool,
|
||||
}
|
||||
|
||||
impl Visit<'_> for ContainsSkippedIdent<'_> {
|
||||
fn visit_ident(&mut self, ident: &'_ Ident) {
|
||||
if self.skipped_idents.contains(ident) {
|
||||
self.contains_skipped_ident = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for ModuleFn {
|
||||
fn parse(input: ParseStream) -> syn::Result<Self> {
|
||||
let ItemFn {
|
||||
|
@ -95,7 +111,7 @@ impl Parse for ModuleFn {
|
|||
ref abi,
|
||||
fn_token: _,
|
||||
ident: _,
|
||||
ref generics,
|
||||
ref mut generics,
|
||||
paren_token: _,
|
||||
ref mut inputs,
|
||||
ref variadic,
|
||||
|
@ -140,12 +156,61 @@ impl Parse for ModuleFn {
|
|||
if let Some(abi) = abi {
|
||||
errors.push(syn::Error::new_spanned(abi, "extern not allowed here"));
|
||||
}
|
||||
if !generics.params.is_empty() {
|
||||
errors.push(syn::Error::new_spanned(
|
||||
&generics.params,
|
||||
"generics are not supported yet",
|
||||
));
|
||||
}
|
||||
let mut skipped_idents = HashSet::new();
|
||||
let struct_generic_params = generics
|
||||
.params
|
||||
.pairs_mut()
|
||||
.filter_map_pair_value_mut(|v| match v {
|
||||
GenericParam::Lifetime(LifetimeParam { attrs, .. }) => {
|
||||
errors
|
||||
.unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs));
|
||||
None
|
||||
}
|
||||
GenericParam::Type(TypeParam { attrs, ident, .. })
|
||||
| GenericParam::Const(ConstParam { attrs, ident, .. }) => {
|
||||
if errors
|
||||
.unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs))
|
||||
.is_some()
|
||||
{
|
||||
skipped_idents.insert(ident.clone());
|
||||
None
|
||||
} else {
|
||||
Some(v.clone())
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let struct_where_clause = generics
|
||||
.where_clause
|
||||
.as_mut()
|
||||
.map(|where_clause| WhereClause {
|
||||
where_token: where_clause.where_token,
|
||||
predicates: where_clause
|
||||
.predicates
|
||||
.pairs_mut()
|
||||
.filter_map_pair_value_mut(|v| match v {
|
||||
WherePredicate::Lifetime(_) => None,
|
||||
_ => {
|
||||
let mut contains_skipped_ident = ContainsSkippedIdent {
|
||||
skipped_idents: &skipped_idents,
|
||||
contains_skipped_ident: false,
|
||||
};
|
||||
contains_skipped_ident.visit_where_predicate(v);
|
||||
if contains_skipped_ident.contains_skipped_ident {
|
||||
None
|
||||
} else {
|
||||
Some(v.clone())
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
});
|
||||
let struct_generics = Generics {
|
||||
lt_token: generics.lt_token,
|
||||
params: struct_generic_params,
|
||||
gt_token: generics.gt_token,
|
||||
where_clause: struct_where_clause,
|
||||
};
|
||||
if let Some(variadic) = variadic {
|
||||
errors.push(syn::Error::new_spanned(variadic, "... not allowed here"));
|
||||
}
|
||||
|
@ -166,6 +231,7 @@ impl Parse for ModuleFn {
|
|||
sig,
|
||||
block,
|
||||
io,
|
||||
struct_generics,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -180,6 +246,7 @@ impl ModuleFn {
|
|||
sig,
|
||||
block,
|
||||
io,
|
||||
struct_generics,
|
||||
} = self;
|
||||
let ConfigOptions {
|
||||
outline_generated: _,
|
||||
|
@ -211,10 +278,13 @@ impl ModuleFn {
|
|||
ModuleKind::Normal => quote! { ::fayalite::module::NormalModule },
|
||||
};
|
||||
let fn_name = &outer_sig.ident;
|
||||
let (_struct_impl_generics, struct_type_generics, struct_where_clause) =
|
||||
struct_generics.split_for_impl();
|
||||
let struct_ty = quote! {#fn_name #struct_type_generics};
|
||||
body_sig.ident = parse_quote! {__body};
|
||||
body_sig.inputs.insert(
|
||||
0,
|
||||
parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#fn_name, #module_kind_ty>},
|
||||
parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#struct_ty, #module_kind_ty>},
|
||||
);
|
||||
let body_fn = ItemFn {
|
||||
attrs: vec![],
|
||||
|
@ -223,7 +293,7 @@ impl ModuleFn {
|
|||
block,
|
||||
};
|
||||
outer_sig.output =
|
||||
parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#fn_name>>};
|
||||
parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#struct_ty>>};
|
||||
let io_flips = io
|
||||
.iter()
|
||||
.map(|io| match io.kind.kind {
|
||||
|
@ -236,9 +306,11 @@ impl ModuleFn {
|
|||
let io_types = io.iter().map(|io| &io.kind.ty).collect::<Vec<_>>();
|
||||
let io_names = io.iter().map(|io| &io.name).collect::<Vec<_>>();
|
||||
let fn_name_str = fn_name.to_string();
|
||||
let (_, body_type_generics, _) = body_fn.sig.generics.split_for_impl();
|
||||
let body_turbofish_type_generics = body_type_generics.as_turbofish();
|
||||
let block = parse_quote! {{
|
||||
#body_fn
|
||||
::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body(m, #(#param_names,)*))
|
||||
::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body #body_turbofish_type_generics(m, #(#param_names,)*))
|
||||
}};
|
||||
let fixed_type = io.iter().all(|io| io.kind.ty_expr.is_none());
|
||||
let struct_options = if fixed_type {
|
||||
|
@ -254,7 +326,7 @@ impl ModuleFn {
|
|||
::fayalite::__std::fmt::Debug)]
|
||||
#[allow(non_camel_case_types)]
|
||||
#struct_options
|
||||
#vis struct #fn_name {
|
||||
#vis struct #fn_name #struct_generics #struct_where_clause {
|
||||
#(
|
||||
#io_flips
|
||||
#vis #io_names: #io_types,)*
|
||||
|
|
|
@ -5,6 +5,7 @@ use fayalite::{
|
|||
array::Array,
|
||||
assert_export_firrtl,
|
||||
clock::{Clock, ClockDomain},
|
||||
expr::ToExpr,
|
||||
hdl_module,
|
||||
int::{DynUInt, DynUIntType, IntCmp, SInt, UInt},
|
||||
intern::Intern,
|
||||
|
@ -12,7 +13,7 @@ use fayalite::{
|
|||
module::transform::simplify_enums::{simplify_enums, SimplifyEnumsKind},
|
||||
reset::{SyncReset, ToReset},
|
||||
source_location::SourceLocation,
|
||||
ty::Value,
|
||||
ty::{FixedType, Value},
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
|
@ -139,22 +140,22 @@ circuit my_module:
|
|||
}
|
||||
|
||||
#[hdl_module(outline_generated)]
|
||||
pub fn check_array_repeat() {
|
||||
pub fn check_array_repeat<const N: usize>() {
|
||||
#[hdl]
|
||||
let i: UInt<8> = m.input();
|
||||
#[hdl]
|
||||
let o: Array<[UInt<8>; 3]> = m.output();
|
||||
let o: Array<[UInt<8>; N]> = m.output();
|
||||
m.connect(
|
||||
o,
|
||||
#[hdl]
|
||||
[i; 3],
|
||||
[i; N],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_array_repeat() {
|
||||
let _n = SourceLocation::normalize_files_for_tests();
|
||||
let m = check_array_repeat();
|
||||
let m = check_array_repeat::<3>();
|
||||
dbg!(m);
|
||||
#[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
|
||||
assert_export_firrtl! {
|
||||
|
@ -169,6 +170,84 @@ circuit check_array_repeat:
|
|||
connect _array_literal_expr[1], i
|
||||
connect _array_literal_expr[2], i
|
||||
connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1]
|
||||
",
|
||||
};
|
||||
let m = check_array_repeat::<4>();
|
||||
dbg!(m);
|
||||
#[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
|
||||
assert_export_firrtl! {
|
||||
m =>
|
||||
"/test/check_array_repeat_1.fir": r"FIRRTL version 3.2.0
|
||||
circuit check_array_repeat_1:
|
||||
module check_array_repeat_1: @[module-XXXXXXXXXX.rs 1:1]
|
||||
input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1]
|
||||
output o: UInt<8>[4] @[module-XXXXXXXXXX.rs 3:1]
|
||||
wire _array_literal_expr: UInt<8>[4]
|
||||
connect _array_literal_expr[0], i
|
||||
connect _array_literal_expr[1], i
|
||||
connect _array_literal_expr[2], i
|
||||
connect _array_literal_expr[3], i
|
||||
connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1]
|
||||
",
|
||||
};
|
||||
}
|
||||
|
||||
#[hdl_module(outline_generated)]
|
||||
pub fn check_skipped_generics<T, #[hdl(skip)] U, const N: usize, #[hdl(skip)] const M: usize>(v: U)
|
||||
where
|
||||
T: Value<Type: FixedType<Value = T>>,
|
||||
U: std::fmt::Display,
|
||||
{
|
||||
dbg!(M);
|
||||
#[hdl]
|
||||
let i: T = m.input();
|
||||
#[hdl]
|
||||
let o: Array<[T; N]> = m.output();
|
||||
let bytes = v.to_string().as_bytes().to_expr();
|
||||
#[hdl]
|
||||
let o2: Array<[UInt<8>]> = m.output(bytes.ty());
|
||||
m.connect(
|
||||
o,
|
||||
#[hdl]
|
||||
[i; N],
|
||||
);
|
||||
m.connect(o2[i], bytes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skipped_generics() {
|
||||
let _n = SourceLocation::normalize_files_for_tests();
|
||||
let m = check_skipped_generics::<UInt<8>, _, 3, 4>("Hello World!\n");
|
||||
dbg!(m);
|
||||
#[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
|
||||
assert_export_firrtl! {
|
||||
m =>
|
||||
"/test/check_skipped_generics.fir": r"FIRRTL version 3.2.0
|
||||
circuit check_skipped_generics:
|
||||
module check_skipped_generics: @[module-XXXXXXXXXX.rs 1:1]
|
||||
input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1]
|
||||
output o: UInt<8>[3] @[module-XXXXXXXXXX.rs 3:1]
|
||||
output o2: UInt<8>[13] @[module-XXXXXXXXXX.rs 4:1]
|
||||
wire _array_literal_expr: UInt<8>[3]
|
||||
connect _array_literal_expr[0], i
|
||||
connect _array_literal_expr[1], i
|
||||
connect _array_literal_expr[2], i
|
||||
connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 5:1]
|
||||
wire _array_literal_expr_1: UInt<8>[13]
|
||||
connect _array_literal_expr_1[0], UInt<8>(0h48)
|
||||
connect _array_literal_expr_1[1], UInt<8>(0h65)
|
||||
connect _array_literal_expr_1[2], UInt<8>(0h6C)
|
||||
connect _array_literal_expr_1[3], UInt<8>(0h6C)
|
||||
connect _array_literal_expr_1[4], UInt<8>(0h6F)
|
||||
connect _array_literal_expr_1[5], UInt<8>(0h20)
|
||||
connect _array_literal_expr_1[6], UInt<8>(0h57)
|
||||
connect _array_literal_expr_1[7], UInt<8>(0h6F)
|
||||
connect _array_literal_expr_1[8], UInt<8>(0h72)
|
||||
connect _array_literal_expr_1[9], UInt<8>(0h6C)
|
||||
connect _array_literal_expr_1[10], UInt<8>(0h64)
|
||||
connect _array_literal_expr_1[11], UInt<8>(0h21)
|
||||
connect _array_literal_expr_1[12], UInt<8>(0hA)
|
||||
connect o2, _array_literal_expr_1 @[module-XXXXXXXXXX.rs 6:1]
|
||||
",
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue