make #[hdl_module] support functions with generic parameters
	
		
			
	
		
	
	
		
	
		
			Some checks failed
		
		
	
	
		
			
				
	
				/ test (push) Failing after 4m38s
				
			
		
		
	
	
				
					
				
			
		
			Some checks failed
		
		
	
	/ test (push) Failing after 4m38s
				
			This commit is contained in:
		
							parent
							
								
									ef4b3b4081
								
							
						
					
					
						commit
						c5901cd217
					
				
					 2 changed files with 169 additions and 18 deletions
				
			
		| 
						 | 
				
			
			@ -3,15 +3,17 @@
 | 
			
		|||
use crate::{
 | 
			
		||||
    is_hdl_attr,
 | 
			
		||||
    module::transform_body::{HdlLet, HdlLetKindIO},
 | 
			
		||||
    options, Errors, HdlAttr,
 | 
			
		||||
    options, Errors, HdlAttr, PairsIterExt,
 | 
			
		||||
};
 | 
			
		||||
use proc_macro2::TokenStream;
 | 
			
		||||
use quote::{format_ident, quote, quote_spanned, ToTokens};
 | 
			
		||||
use std::collections::HashSet;
 | 
			
		||||
use syn::{
 | 
			
		||||
    parse::{Parse, ParseStream},
 | 
			
		||||
    parse_quote,
 | 
			
		||||
    visit::{visit_pat, Visit},
 | 
			
		||||
    Attribute, Block, Error, FnArg, Ident, ItemFn, ItemStruct, ReturnType, Signature, Visibility,
 | 
			
		||||
    Attribute, Block, ConstParam, Error, FnArg, GenericParam, Generics, Ident, ItemFn, ItemStruct,
 | 
			
		||||
    LifetimeParam, ReturnType, Signature, TypeParam, Visibility, WhereClause, WherePredicate,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
mod transform_body;
 | 
			
		||||
| 
						 | 
				
			
			@ -72,6 +74,7 @@ pub(crate) struct ModuleFn {
 | 
			
		|||
    sig: Signature,
 | 
			
		||||
    block: Box<Block>,
 | 
			
		||||
    io: Vec<ModuleIO>,
 | 
			
		||||
    struct_generics: Generics,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
 | 
			
		||||
| 
						 | 
				
			
			@ -80,6 +83,19 @@ pub(crate) enum ModuleKind {
 | 
			
		|||
    Normal,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct ContainsSkippedIdent<'a> {
 | 
			
		||||
    skipped_idents: &'a HashSet<Ident>,
 | 
			
		||||
    contains_skipped_ident: bool,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Visit<'_> for ContainsSkippedIdent<'_> {
 | 
			
		||||
    fn visit_ident(&mut self, ident: &'_ Ident) {
 | 
			
		||||
        if self.skipped_idents.contains(ident) {
 | 
			
		||||
            self.contains_skipped_ident = true;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Parse for ModuleFn {
 | 
			
		||||
    fn parse(input: ParseStream) -> syn::Result<Self> {
 | 
			
		||||
        let ItemFn {
 | 
			
		||||
| 
						 | 
				
			
			@ -95,7 +111,7 @@ impl Parse for ModuleFn {
 | 
			
		|||
            ref abi,
 | 
			
		||||
            fn_token: _,
 | 
			
		||||
            ident: _,
 | 
			
		||||
            ref generics,
 | 
			
		||||
            ref mut generics,
 | 
			
		||||
            paren_token: _,
 | 
			
		||||
            ref mut inputs,
 | 
			
		||||
            ref variadic,
 | 
			
		||||
| 
						 | 
				
			
			@ -140,12 +156,61 @@ impl Parse for ModuleFn {
 | 
			
		|||
        if let Some(abi) = abi {
 | 
			
		||||
            errors.push(syn::Error::new_spanned(abi, "extern not allowed here"));
 | 
			
		||||
        }
 | 
			
		||||
        if !generics.params.is_empty() {
 | 
			
		||||
            errors.push(syn::Error::new_spanned(
 | 
			
		||||
                &generics.params,
 | 
			
		||||
                "generics are not supported yet",
 | 
			
		||||
            ));
 | 
			
		||||
        let mut skipped_idents = HashSet::new();
 | 
			
		||||
        let struct_generic_params = generics
 | 
			
		||||
            .params
 | 
			
		||||
            .pairs_mut()
 | 
			
		||||
            .filter_map_pair_value_mut(|v| match v {
 | 
			
		||||
                GenericParam::Lifetime(LifetimeParam { attrs, .. }) => {
 | 
			
		||||
                    errors
 | 
			
		||||
                        .unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs));
 | 
			
		||||
                    None
 | 
			
		||||
                }
 | 
			
		||||
                GenericParam::Type(TypeParam { attrs, ident, .. })
 | 
			
		||||
                | GenericParam::Const(ConstParam { attrs, ident, .. }) => {
 | 
			
		||||
                    if errors
 | 
			
		||||
                        .unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs))
 | 
			
		||||
                        .is_some()
 | 
			
		||||
                    {
 | 
			
		||||
                        skipped_idents.insert(ident.clone());
 | 
			
		||||
                        None
 | 
			
		||||
                    } else {
 | 
			
		||||
                        Some(v.clone())
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            })
 | 
			
		||||
            .collect();
 | 
			
		||||
        let struct_where_clause = generics
 | 
			
		||||
            .where_clause
 | 
			
		||||
            .as_mut()
 | 
			
		||||
            .map(|where_clause| WhereClause {
 | 
			
		||||
                where_token: where_clause.where_token,
 | 
			
		||||
                predicates: where_clause
 | 
			
		||||
                    .predicates
 | 
			
		||||
                    .pairs_mut()
 | 
			
		||||
                    .filter_map_pair_value_mut(|v| match v {
 | 
			
		||||
                        WherePredicate::Lifetime(_) => None,
 | 
			
		||||
                        _ => {
 | 
			
		||||
                            let mut contains_skipped_ident = ContainsSkippedIdent {
 | 
			
		||||
                                skipped_idents: &skipped_idents,
 | 
			
		||||
                                contains_skipped_ident: false,
 | 
			
		||||
                            };
 | 
			
		||||
                            contains_skipped_ident.visit_where_predicate(v);
 | 
			
		||||
                            if contains_skipped_ident.contains_skipped_ident {
 | 
			
		||||
                                None
 | 
			
		||||
                            } else {
 | 
			
		||||
                                Some(v.clone())
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                    })
 | 
			
		||||
                    .collect(),
 | 
			
		||||
            });
 | 
			
		||||
        let struct_generics = Generics {
 | 
			
		||||
            lt_token: generics.lt_token,
 | 
			
		||||
            params: struct_generic_params,
 | 
			
		||||
            gt_token: generics.gt_token,
 | 
			
		||||
            where_clause: struct_where_clause,
 | 
			
		||||
        };
 | 
			
		||||
        if let Some(variadic) = variadic {
 | 
			
		||||
            errors.push(syn::Error::new_spanned(variadic, "... not allowed here"));
 | 
			
		||||
        }
 | 
			
		||||
| 
						 | 
				
			
			@ -166,6 +231,7 @@ impl Parse for ModuleFn {
 | 
			
		|||
            sig,
 | 
			
		||||
            block,
 | 
			
		||||
            io,
 | 
			
		||||
            struct_generics,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -180,6 +246,7 @@ impl ModuleFn {
 | 
			
		|||
            sig,
 | 
			
		||||
            block,
 | 
			
		||||
            io,
 | 
			
		||||
            struct_generics,
 | 
			
		||||
        } = self;
 | 
			
		||||
        let ConfigOptions {
 | 
			
		||||
            outline_generated: _,
 | 
			
		||||
| 
						 | 
				
			
			@ -211,10 +278,13 @@ impl ModuleFn {
 | 
			
		|||
            ModuleKind::Normal => quote! { ::fayalite::module::NormalModule },
 | 
			
		||||
        };
 | 
			
		||||
        let fn_name = &outer_sig.ident;
 | 
			
		||||
        let (_struct_impl_generics, struct_type_generics, struct_where_clause) =
 | 
			
		||||
            struct_generics.split_for_impl();
 | 
			
		||||
        let struct_ty = quote! {#fn_name #struct_type_generics};
 | 
			
		||||
        body_sig.ident = parse_quote! {__body};
 | 
			
		||||
        body_sig.inputs.insert(
 | 
			
		||||
            0,
 | 
			
		||||
            parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#fn_name, #module_kind_ty>},
 | 
			
		||||
            parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#struct_ty, #module_kind_ty>},
 | 
			
		||||
        );
 | 
			
		||||
        let body_fn = ItemFn {
 | 
			
		||||
            attrs: vec![],
 | 
			
		||||
| 
						 | 
				
			
			@ -223,7 +293,7 @@ impl ModuleFn {
 | 
			
		|||
            block,
 | 
			
		||||
        };
 | 
			
		||||
        outer_sig.output =
 | 
			
		||||
            parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#fn_name>>};
 | 
			
		||||
            parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#struct_ty>>};
 | 
			
		||||
        let io_flips = io
 | 
			
		||||
            .iter()
 | 
			
		||||
            .map(|io| match io.kind.kind {
 | 
			
		||||
| 
						 | 
				
			
			@ -236,9 +306,11 @@ impl ModuleFn {
 | 
			
		|||
        let io_types = io.iter().map(|io| &io.kind.ty).collect::<Vec<_>>();
 | 
			
		||||
        let io_names = io.iter().map(|io| &io.name).collect::<Vec<_>>();
 | 
			
		||||
        let fn_name_str = fn_name.to_string();
 | 
			
		||||
        let (_, body_type_generics, _) = body_fn.sig.generics.split_for_impl();
 | 
			
		||||
        let body_turbofish_type_generics = body_type_generics.as_turbofish();
 | 
			
		||||
        let block = parse_quote! {{
 | 
			
		||||
            #body_fn
 | 
			
		||||
            ::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body(m, #(#param_names,)*))
 | 
			
		||||
            ::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body #body_turbofish_type_generics(m, #(#param_names,)*))
 | 
			
		||||
        }};
 | 
			
		||||
        let fixed_type = io.iter().all(|io| io.kind.ty_expr.is_none());
 | 
			
		||||
        let struct_options = if fixed_type {
 | 
			
		||||
| 
						 | 
				
			
			@ -254,7 +326,7 @@ impl ModuleFn {
 | 
			
		|||
                ::fayalite::__std::fmt::Debug)]
 | 
			
		||||
            #[allow(non_camel_case_types)]
 | 
			
		||||
            #struct_options
 | 
			
		||||
            #vis struct #fn_name {
 | 
			
		||||
            #vis struct #fn_name #struct_generics #struct_where_clause {
 | 
			
		||||
                #(
 | 
			
		||||
                #io_flips
 | 
			
		||||
                #vis #io_names: #io_types,)*
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,6 +5,7 @@ use fayalite::{
 | 
			
		|||
    array::Array,
 | 
			
		||||
    assert_export_firrtl,
 | 
			
		||||
    clock::{Clock, ClockDomain},
 | 
			
		||||
    expr::ToExpr,
 | 
			
		||||
    hdl_module,
 | 
			
		||||
    int::{DynUInt, DynUIntType, IntCmp, SInt, UInt},
 | 
			
		||||
    intern::Intern,
 | 
			
		||||
| 
						 | 
				
			
			@ -12,7 +13,7 @@ use fayalite::{
 | 
			
		|||
    module::transform::simplify_enums::{simplify_enums, SimplifyEnumsKind},
 | 
			
		||||
    reset::{SyncReset, ToReset},
 | 
			
		||||
    source_location::SourceLocation,
 | 
			
		||||
    ty::Value,
 | 
			
		||||
    ty::{FixedType, Value},
 | 
			
		||||
};
 | 
			
		||||
use serde_json::json;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -139,22 +140,22 @@ circuit my_module:
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
#[hdl_module(outline_generated)]
 | 
			
		||||
pub fn check_array_repeat() {
 | 
			
		||||
pub fn check_array_repeat<const N: usize>() {
 | 
			
		||||
    #[hdl]
 | 
			
		||||
    let i: UInt<8> = m.input();
 | 
			
		||||
    #[hdl]
 | 
			
		||||
    let o: Array<[UInt<8>; 3]> = m.output();
 | 
			
		||||
    let o: Array<[UInt<8>; N]> = m.output();
 | 
			
		||||
    m.connect(
 | 
			
		||||
        o,
 | 
			
		||||
        #[hdl]
 | 
			
		||||
        [i; 3],
 | 
			
		||||
        [i; N],
 | 
			
		||||
    );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[test]
 | 
			
		||||
fn test_array_repeat() {
 | 
			
		||||
    let _n = SourceLocation::normalize_files_for_tests();
 | 
			
		||||
    let m = check_array_repeat();
 | 
			
		||||
    let m = check_array_repeat::<3>();
 | 
			
		||||
    dbg!(m);
 | 
			
		||||
    #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
 | 
			
		||||
    assert_export_firrtl! {
 | 
			
		||||
| 
						 | 
				
			
			@ -169,6 +170,84 @@ circuit check_array_repeat:
 | 
			
		|||
        connect _array_literal_expr[1], i
 | 
			
		||||
        connect _array_literal_expr[2], i
 | 
			
		||||
        connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1]
 | 
			
		||||
",
 | 
			
		||||
    };
 | 
			
		||||
    let m = check_array_repeat::<4>();
 | 
			
		||||
    dbg!(m);
 | 
			
		||||
    #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
 | 
			
		||||
    assert_export_firrtl! {
 | 
			
		||||
        m =>
 | 
			
		||||
        "/test/check_array_repeat_1.fir": r"FIRRTL version 3.2.0
 | 
			
		||||
circuit check_array_repeat_1:
 | 
			
		||||
    module check_array_repeat_1: @[module-XXXXXXXXXX.rs 1:1]
 | 
			
		||||
        input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1]
 | 
			
		||||
        output o: UInt<8>[4] @[module-XXXXXXXXXX.rs 3:1]
 | 
			
		||||
        wire _array_literal_expr: UInt<8>[4]
 | 
			
		||||
        connect _array_literal_expr[0], i
 | 
			
		||||
        connect _array_literal_expr[1], i
 | 
			
		||||
        connect _array_literal_expr[2], i
 | 
			
		||||
        connect _array_literal_expr[3], i
 | 
			
		||||
        connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1]
 | 
			
		||||
",
 | 
			
		||||
    };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[hdl_module(outline_generated)]
 | 
			
		||||
pub fn check_skipped_generics<T, #[hdl(skip)] U, const N: usize, #[hdl(skip)] const M: usize>(v: U)
 | 
			
		||||
where
 | 
			
		||||
    T: Value<Type: FixedType<Value = T>>,
 | 
			
		||||
    U: std::fmt::Display,
 | 
			
		||||
{
 | 
			
		||||
    dbg!(M);
 | 
			
		||||
    #[hdl]
 | 
			
		||||
    let i: T = m.input();
 | 
			
		||||
    #[hdl]
 | 
			
		||||
    let o: Array<[T; N]> = m.output();
 | 
			
		||||
    let bytes = v.to_string().as_bytes().to_expr();
 | 
			
		||||
    #[hdl]
 | 
			
		||||
    let o2: Array<[UInt<8>]> = m.output(bytes.ty());
 | 
			
		||||
    m.connect(
 | 
			
		||||
        o,
 | 
			
		||||
        #[hdl]
 | 
			
		||||
        [i; N],
 | 
			
		||||
    );
 | 
			
		||||
    m.connect(o2[i], bytes);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[test]
 | 
			
		||||
fn test_skipped_generics() {
 | 
			
		||||
    let _n = SourceLocation::normalize_files_for_tests();
 | 
			
		||||
    let m = check_skipped_generics::<UInt<8>, _, 3, 4>("Hello World!\n");
 | 
			
		||||
    dbg!(m);
 | 
			
		||||
    #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
 | 
			
		||||
    assert_export_firrtl! {
 | 
			
		||||
        m =>
 | 
			
		||||
        "/test/check_skipped_generics.fir": r"FIRRTL version 3.2.0
 | 
			
		||||
circuit check_skipped_generics:
 | 
			
		||||
    module check_skipped_generics: @[module-XXXXXXXXXX.rs 1:1]
 | 
			
		||||
        input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1]
 | 
			
		||||
        output o: UInt<8>[3] @[module-XXXXXXXXXX.rs 3:1]
 | 
			
		||||
        output o2: UInt<8>[13] @[module-XXXXXXXXXX.rs 4:1]
 | 
			
		||||
        wire _array_literal_expr: UInt<8>[3]
 | 
			
		||||
        connect _array_literal_expr[0], i
 | 
			
		||||
        connect _array_literal_expr[1], i
 | 
			
		||||
        connect _array_literal_expr[2], i
 | 
			
		||||
        connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 5:1]
 | 
			
		||||
        wire _array_literal_expr_1: UInt<8>[13]
 | 
			
		||||
        connect _array_literal_expr_1[0], UInt<8>(0h48)
 | 
			
		||||
        connect _array_literal_expr_1[1], UInt<8>(0h65)
 | 
			
		||||
        connect _array_literal_expr_1[2], UInt<8>(0h6C)
 | 
			
		||||
        connect _array_literal_expr_1[3], UInt<8>(0h6C)
 | 
			
		||||
        connect _array_literal_expr_1[4], UInt<8>(0h6F)
 | 
			
		||||
        connect _array_literal_expr_1[5], UInt<8>(0h20)
 | 
			
		||||
        connect _array_literal_expr_1[6], UInt<8>(0h57)
 | 
			
		||||
        connect _array_literal_expr_1[7], UInt<8>(0h6F)
 | 
			
		||||
        connect _array_literal_expr_1[8], UInt<8>(0h72)
 | 
			
		||||
        connect _array_literal_expr_1[9], UInt<8>(0h6C)
 | 
			
		||||
        connect _array_literal_expr_1[10], UInt<8>(0h64)
 | 
			
		||||
        connect _array_literal_expr_1[11], UInt<8>(0h21)
 | 
			
		||||
        connect _array_literal_expr_1[12], UInt<8>(0hA)
 | 
			
		||||
        connect o2, _array_literal_expr_1 @[module-XXXXXXXXXX.rs 6:1]
 | 
			
		||||
",
 | 
			
		||||
    };
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue