Compare commits

...

3 commits

Author SHA1 Message Date
Jacob Lifshay c5901cd217
make #[hdl_module] support functions with generic parameters
Some checks failed
/ test (push) Failing after 4m38s
2024-07-25 22:10:33 -07:00
Jacob Lifshay ef4b3b4081
make [T; N]: ToExpr for any N instead of a fixed list 2024-07-25 22:08:28 -07:00
Jacob Lifshay 7963f0a5cd
add Iterator<Item = Pair<T, P>> helpers 2024-07-25 22:07:23 -07:00
6 changed files with 315 additions and 118 deletions

View file

@ -7,7 +7,9 @@ use std::io::{ErrorKind, Write};
use syn::{
bracketed, parenthesized,
parse::{Parse, ParseStream, Parser},
parse_quote, AttrStyle, Attribute, Error, Item, Token,
parse_quote,
punctuated::Pair,
AttrStyle, Attribute, Error, Item, Token,
};
mod fold;
@ -318,6 +320,108 @@ impl<T: Parse> HdlAttr<T> {
}
}
#[allow(dead_code)]
pub(crate) trait PairsIterExt: Sized + Iterator {
fn map_pair<T1, T2, P1, P2, ValueFn: FnMut(T1) -> T2, PunctFn: FnMut(P1) -> P2>(
self,
mut value_fn: ValueFn,
mut punct_fn: PunctFn,
) -> impl Iterator<Item = Pair<T2, P2>>
where
Self: Iterator<Item = Pair<T1, P1>>,
{
self.map(move |p| {
let (t, p) = p.into_tuple();
let t = value_fn(t);
let p = p.map(&mut punct_fn);
Pair::new(t, p)
})
}
fn filter_map_pair<T1, T2, P1, P2, ValueFn: FnMut(T1) -> Option<T2>, PunctFn: FnMut(P1) -> P2>(
self,
mut value_fn: ValueFn,
mut punct_fn: PunctFn,
) -> impl Iterator<Item = Pair<T2, P2>>
where
Self: Iterator<Item = Pair<T1, P1>>,
{
self.filter_map(move |p| {
let (t, p) = p.into_tuple();
let t = value_fn(t)?;
let p = p.map(&mut punct_fn);
Some(Pair::new(t, p))
})
}
fn map_pair_value<T1, T2, P, F: FnMut(T1) -> T2>(
self,
f: F,
) -> impl Iterator<Item = Pair<T2, P>>
where
Self: Iterator<Item = Pair<T1, P>>,
{
self.map_pair(f, |v| v)
}
fn filter_map_pair_value<T1, T2, P, F: FnMut(T1) -> Option<T2>>(
self,
f: F,
) -> impl Iterator<Item = Pair<T2, P>>
where
Self: Iterator<Item = Pair<T1, P>>,
{
self.filter_map_pair(f, |v| v)
}
fn map_pair_value_mut<'a, T1: 'a, T2: 'a, P: Clone + 'a, F: FnMut(T1) -> T2 + 'a>(
self,
f: F,
) -> impl Iterator<Item = Pair<T2, P>> + 'a
where
Self: Iterator<Item = Pair<T1, &'a mut P>> + 'a,
{
self.map_pair(f, |v| v.clone())
}
fn filter_map_pair_value_mut<
'a,
T1: 'a,
T2: 'a,
P: Clone + 'a,
F: FnMut(T1) -> Option<T2> + 'a,
>(
self,
f: F,
) -> impl Iterator<Item = Pair<T2, P>> + 'a
where
Self: Iterator<Item = Pair<T1, &'a mut P>> + 'a,
{
self.filter_map_pair(f, |v| v.clone())
}
fn map_pair_value_ref<'a, T1: 'a, T2: 'a, P: Clone + 'a, F: FnMut(T1) -> T2 + 'a>(
self,
f: F,
) -> impl Iterator<Item = Pair<T2, P>> + 'a
where
Self: Iterator<Item = Pair<T1, &'a P>> + 'a,
{
self.map_pair(f, |v| v.clone())
}
fn filter_map_pair_value_ref<
'a,
T1: 'a,
T2: 'a,
P: Clone + 'a,
F: FnMut(T1) -> Option<T2> + 'a,
>(
self,
f: F,
) -> impl Iterator<Item = Pair<T2, P>> + 'a
where
Self: Iterator<Item = Pair<T1, &'a P>> + 'a,
{
self.filter_map_pair(f, |v| v.clone())
}
}
impl<T, P, Iter: Iterator<Item = Pair<T, P>>> PairsIterExt for Iter {}
pub(crate) struct Errors {
error: Option<Error>,
finished: bool,

View file

@ -3,15 +3,17 @@
use crate::{
is_hdl_attr,
module::transform_body::{HdlLet, HdlLetKindIO},
options, Errors, HdlAttr,
options, Errors, HdlAttr, PairsIterExt,
};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::HashSet;
use syn::{
parse::{Parse, ParseStream},
parse_quote,
visit::{visit_pat, Visit},
Attribute, Block, Error, FnArg, Ident, ItemFn, ItemStruct, ReturnType, Signature, Visibility,
Attribute, Block, ConstParam, Error, FnArg, GenericParam, Generics, Ident, ItemFn, ItemStruct,
LifetimeParam, ReturnType, Signature, TypeParam, Visibility, WhereClause, WherePredicate,
};
mod transform_body;
@ -72,6 +74,7 @@ pub(crate) struct ModuleFn {
sig: Signature,
block: Box<Block>,
io: Vec<ModuleIO>,
struct_generics: Generics,
}
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
@ -80,6 +83,19 @@ pub(crate) enum ModuleKind {
Normal,
}
struct ContainsSkippedIdent<'a> {
skipped_idents: &'a HashSet<Ident>,
contains_skipped_ident: bool,
}
impl Visit<'_> for ContainsSkippedIdent<'_> {
fn visit_ident(&mut self, ident: &'_ Ident) {
if self.skipped_idents.contains(ident) {
self.contains_skipped_ident = true;
}
}
}
impl Parse for ModuleFn {
fn parse(input: ParseStream) -> syn::Result<Self> {
let ItemFn {
@ -95,7 +111,7 @@ impl Parse for ModuleFn {
ref abi,
fn_token: _,
ident: _,
ref generics,
ref mut generics,
paren_token: _,
ref mut inputs,
ref variadic,
@ -140,12 +156,61 @@ impl Parse for ModuleFn {
if let Some(abi) = abi {
errors.push(syn::Error::new_spanned(abi, "extern not allowed here"));
}
if !generics.params.is_empty() {
errors.push(syn::Error::new_spanned(
&generics.params,
"generics are not supported yet",
));
}
let mut skipped_idents = HashSet::new();
let struct_generic_params = generics
.params
.pairs_mut()
.filter_map_pair_value_mut(|v| match v {
GenericParam::Lifetime(LifetimeParam { attrs, .. }) => {
errors
.unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs));
None
}
GenericParam::Type(TypeParam { attrs, ident, .. })
| GenericParam::Const(ConstParam { attrs, ident, .. }) => {
if errors
.unwrap_or_default(HdlAttr::<crate::kw::skip>::parse_and_take_attr(attrs))
.is_some()
{
skipped_idents.insert(ident.clone());
None
} else {
Some(v.clone())
}
}
})
.collect();
let struct_where_clause = generics
.where_clause
.as_mut()
.map(|where_clause| WhereClause {
where_token: where_clause.where_token,
predicates: where_clause
.predicates
.pairs_mut()
.filter_map_pair_value_mut(|v| match v {
WherePredicate::Lifetime(_) => None,
_ => {
let mut contains_skipped_ident = ContainsSkippedIdent {
skipped_idents: &skipped_idents,
contains_skipped_ident: false,
};
contains_skipped_ident.visit_where_predicate(v);
if contains_skipped_ident.contains_skipped_ident {
None
} else {
Some(v.clone())
}
}
})
.collect(),
});
let struct_generics = Generics {
lt_token: generics.lt_token,
params: struct_generic_params,
gt_token: generics.gt_token,
where_clause: struct_where_clause,
};
if let Some(variadic) = variadic {
errors.push(syn::Error::new_spanned(variadic, "... not allowed here"));
}
@ -166,6 +231,7 @@ impl Parse for ModuleFn {
sig,
block,
io,
struct_generics,
})
}
}
@ -180,6 +246,7 @@ impl ModuleFn {
sig,
block,
io,
struct_generics,
} = self;
let ConfigOptions {
outline_generated: _,
@ -211,10 +278,13 @@ impl ModuleFn {
ModuleKind::Normal => quote! { ::fayalite::module::NormalModule },
};
let fn_name = &outer_sig.ident;
let (_struct_impl_generics, struct_type_generics, struct_where_clause) =
struct_generics.split_for_impl();
let struct_ty = quote! {#fn_name #struct_type_generics};
body_sig.ident = parse_quote! {__body};
body_sig.inputs.insert(
0,
parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#fn_name, #module_kind_ty>},
parse_quote! {m: &mut ::fayalite::module::ModuleBuilder<#struct_ty, #module_kind_ty>},
);
let body_fn = ItemFn {
attrs: vec![],
@ -223,7 +293,7 @@ impl ModuleFn {
block,
};
outer_sig.output =
parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#fn_name>>};
parse_quote! {-> ::fayalite::intern::Interned<::fayalite::module::Module<#struct_ty>>};
let io_flips = io
.iter()
.map(|io| match io.kind.kind {
@ -236,9 +306,11 @@ impl ModuleFn {
let io_types = io.iter().map(|io| &io.kind.ty).collect::<Vec<_>>();
let io_names = io.iter().map(|io| &io.name).collect::<Vec<_>>();
let fn_name_str = fn_name.to_string();
let (_, body_type_generics, _) = body_fn.sig.generics.split_for_impl();
let body_turbofish_type_generics = body_type_generics.as_turbofish();
let block = parse_quote! {{
#body_fn
::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body(m, #(#param_names,)*))
::fayalite::module::ModuleBuilder::run(#fn_name_str, |m| __body #body_turbofish_type_generics(m, #(#param_names,)*))
}};
let fixed_type = io.iter().all(|io| io.kind.ty_expr.is_none());
let struct_options = if fixed_type {
@ -254,7 +326,7 @@ impl ModuleFn {
::fayalite::__std::fmt::Debug)]
#[allow(non_camel_case_types)]
#struct_options
#vis struct #fn_name {
#vis struct #fn_name #struct_generics #struct_where_clause {
#(
#io_flips
#vis #io_names: #io_types,)*

View file

@ -1,6 +1,6 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
// See Notices.txt for copyright information
use crate::{module::transform_body::Visitor, options, Errors, HdlAttr};
use crate::{module::transform_body::Visitor, options, Errors, HdlAttr, PairsIterExt};
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote_spanned, ToTokens, TokenStreamExt};
use syn::{
@ -207,7 +207,7 @@ impl StructOrEnumLiteral {
}
pub(crate) fn map_fields(
self,
mut f: impl FnMut(StructOrEnumLiteralField) -> StructOrEnumLiteralField,
f: impl FnMut(StructOrEnumLiteralField) -> StructOrEnumLiteralField,
) -> Self {
let Self {
attrs,
@ -217,10 +217,7 @@ impl StructOrEnumLiteral {
dot2_token,
rest,
} = self;
let fields = Punctuated::from_iter(fields.into_pairs().map(|p| {
let (field, comma) = p.into_tuple();
Pair::new(f(field), comma)
}));
let fields = fields.into_pairs().map_pair_value(f).collect();
Self {
attrs,
path,
@ -247,26 +244,22 @@ impl From<ExprStruct> for StructOrEnumLiteral {
attrs,
path: TypePath { qself, path },
brace_or_paren: BraceOrParen::Brace(brace_token),
fields: Punctuated::from_iter(fields.into_pairs().map(|v| {
let (
FieldValue {
fields: fields
.into_pairs()
.map_pair_value(
|FieldValue {
attrs,
member,
colon_token,
expr,
}| StructOrEnumLiteralField {
attrs,
member,
colon_token,
expr,
},
comma,
) = v.into_tuple();
Pair::new(
StructOrEnumLiteralField {
attrs,
member,
colon_token,
expr,
},
comma,
)
})),
.collect(),
dot2_token,
rest,
}
@ -514,20 +507,24 @@ impl Visitor {
}
};
};
let fields = Punctuated::from_iter(args.into_pairs().enumerate().map(|(index, p)| {
let (expr, comma) = p.into_tuple();
let mut index = Index::from(index);
index.span = hdl_attr.hdl.span;
Pair::new(
StructOrEnumLiteralField {
attrs: vec![],
member: Member::Unnamed(index),
colon_token: None,
expr,
},
comma,
)
}));
let fields = args
.into_pairs()
.enumerate()
.map(|(index, p)| {
let (expr, comma) = p.into_tuple();
let mut index = Index::from(index);
index.span = hdl_attr.hdl.span;
Pair::new(
StructOrEnumLiteralField {
attrs: vec![],
member: Member::Unnamed(index),
colon_token: None,
expr,
},
comma,
)
})
.collect();
self.process_struct_enum(
hdl_attr,
StructOrEnumLiteral {

View file

@ -6,7 +6,7 @@ use crate::{
expand_aggregate_literals::{AggregateLiteralOptions, StructOrEnumPath},
with_debug_clone_and_fold, Visitor,
},
Errors, HdlAttr,
Errors, HdlAttr, PairsIterExt,
};
use proc_macro2::{Span, TokenStream};
use quote::{ToTokens, TokenStreamExt};
@ -238,11 +238,7 @@ trait ParseMatchPat: Sized {
leading_vert,
cases: cases
.into_pairs()
.filter_map(|pair| {
let (pat, punct) = pair.into_tuple();
let pat = Self::parse(state, pat).ok()?;
Some(Pair::new(pat, punct))
})
.filter_map_pair_value(|pat| Self::parse(state, pat).ok())
.collect(),
})),
Pat::Paren(PatParen {
@ -282,10 +278,8 @@ trait ParseMatchPat: Sized {
}) => {
let fields = fields
.into_pairs()
.filter_map(|pair| {
let (field_pat, punct) = pair.into_tuple();
let field_pat = MatchPatStructField::parse(state, field_pat).ok()?;
Some(Pair::new(field_pat, punct))
.filter_map_pair_value(|field_pat| {
MatchPatStructField::parse(state, field_pat).ok()
})
.collect();
let path = TypePath { qself, path };

View file

@ -548,70 +548,21 @@ impl<E: ToExpr<Type = T>, T: FixedType> ToExpr for Vec<E> {
}
}
impl<E: ToExpr<Type = T>, T: FixedType> ToExpr for [E; 0] {
type Type = ArrayType<[T::Value; 0]>;
impl<E: ToExpr<Type = T>, T: FixedType, const N: usize> ToExpr for [E; N] {
type Type = ArrayType<[T::Value; N]>;
fn ty(&self) -> Self::Type {
ArrayType::new_with_len_type(FixedType::fixed_type(), ())
}
fn to_expr(&self) -> Expr<<Self::Type as Type>::Value> {
Array::new(FixedType::fixed_type(), Arc::new([])).to_expr()
let elements = Intern::intern_owned(Vec::from_iter(
self.iter().map(|v| v.to_expr().to_canonical_dyn()),
));
ArrayLiteral::new_unchecked(elements, self.ty()).to_expr()
}
}
macro_rules! impl_to_expr_for_non_empty_array {
($N:literal) => {
impl<E: ToExpr<Type = T>, T: Type> ToExpr for [E; $N] {
type Type = ArrayType<[T::Value; $N]>;
fn ty(&self) -> Self::Type {
ArrayType::new_with_len_type(self[0].ty(), ())
}
fn to_expr(&self) -> Expr<<Self::Type as Type>::Value> {
let elements = Intern::intern_owned(Vec::from_iter(
self.iter().map(|v| v.to_expr().to_canonical_dyn()),
));
ArrayLiteral::new_unchecked(elements, self.ty()).to_expr()
}
}
};
}
impl_to_expr_for_non_empty_array!(1);
impl_to_expr_for_non_empty_array!(2);
impl_to_expr_for_non_empty_array!(3);
impl_to_expr_for_non_empty_array!(4);
impl_to_expr_for_non_empty_array!(5);
impl_to_expr_for_non_empty_array!(6);
impl_to_expr_for_non_empty_array!(7);
impl_to_expr_for_non_empty_array!(8);
impl_to_expr_for_non_empty_array!(9);
impl_to_expr_for_non_empty_array!(10);
impl_to_expr_for_non_empty_array!(11);
impl_to_expr_for_non_empty_array!(12);
impl_to_expr_for_non_empty_array!(13);
impl_to_expr_for_non_empty_array!(14);
impl_to_expr_for_non_empty_array!(15);
impl_to_expr_for_non_empty_array!(16);
impl_to_expr_for_non_empty_array!(17);
impl_to_expr_for_non_empty_array!(18);
impl_to_expr_for_non_empty_array!(19);
impl_to_expr_for_non_empty_array!(20);
impl_to_expr_for_non_empty_array!(21);
impl_to_expr_for_non_empty_array!(22);
impl_to_expr_for_non_empty_array!(23);
impl_to_expr_for_non_empty_array!(24);
impl_to_expr_for_non_empty_array!(25);
impl_to_expr_for_non_empty_array!(26);
impl_to_expr_for_non_empty_array!(27);
impl_to_expr_for_non_empty_array!(28);
impl_to_expr_for_non_empty_array!(29);
impl_to_expr_for_non_empty_array!(30);
impl_to_expr_for_non_empty_array!(31);
impl_to_expr_for_non_empty_array!(32);
#[derive(Clone, Debug)]
pub struct ArrayIntoIter<VA: ValueArrayOrSlice> {
array: Arc<VA>,

View file

@ -5,6 +5,7 @@ use fayalite::{
array::Array,
assert_export_firrtl,
clock::{Clock, ClockDomain},
expr::ToExpr,
hdl_module,
int::{DynUInt, DynUIntType, IntCmp, SInt, UInt},
intern::Intern,
@ -12,7 +13,7 @@ use fayalite::{
module::transform::simplify_enums::{simplify_enums, SimplifyEnumsKind},
reset::{SyncReset, ToReset},
source_location::SourceLocation,
ty::Value,
ty::{FixedType, Value},
};
use serde_json::json;
@ -139,22 +140,22 @@ circuit my_module:
}
#[hdl_module(outline_generated)]
pub fn check_array_repeat() {
pub fn check_array_repeat<const N: usize>() {
#[hdl]
let i: UInt<8> = m.input();
#[hdl]
let o: Array<[UInt<8>; 3]> = m.output();
let o: Array<[UInt<8>; N]> = m.output();
m.connect(
o,
#[hdl]
[i; 3],
[i; N],
);
}
#[test]
fn test_array_repeat() {
let _n = SourceLocation::normalize_files_for_tests();
let m = check_array_repeat();
let m = check_array_repeat::<3>();
dbg!(m);
#[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
assert_export_firrtl! {
@ -169,6 +170,84 @@ circuit check_array_repeat:
connect _array_literal_expr[1], i
connect _array_literal_expr[2], i
connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1]
",
};
let m = check_array_repeat::<4>();
dbg!(m);
#[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
assert_export_firrtl! {
m =>
"/test/check_array_repeat_1.fir": r"FIRRTL version 3.2.0
circuit check_array_repeat_1:
module check_array_repeat_1: @[module-XXXXXXXXXX.rs 1:1]
input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1]
output o: UInt<8>[4] @[module-XXXXXXXXXX.rs 3:1]
wire _array_literal_expr: UInt<8>[4]
connect _array_literal_expr[0], i
connect _array_literal_expr[1], i
connect _array_literal_expr[2], i
connect _array_literal_expr[3], i
connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 4:1]
",
};
}
#[hdl_module(outline_generated)]
pub fn check_skipped_generics<T, #[hdl(skip)] U, const N: usize, #[hdl(skip)] const M: usize>(v: U)
where
T: Value<Type: FixedType<Value = T>>,
U: std::fmt::Display,
{
dbg!(M);
#[hdl]
let i: T = m.input();
#[hdl]
let o: Array<[T; N]> = m.output();
let bytes = v.to_string().as_bytes().to_expr();
#[hdl]
let o2: Array<[UInt<8>]> = m.output(bytes.ty());
m.connect(
o,
#[hdl]
[i; N],
);
m.connect(o2[i], bytes);
}
#[test]
fn test_skipped_generics() {
let _n = SourceLocation::normalize_files_for_tests();
let m = check_skipped_generics::<UInt<8>, _, 3, 4>("Hello World!\n");
dbg!(m);
#[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161
assert_export_firrtl! {
m =>
"/test/check_skipped_generics.fir": r"FIRRTL version 3.2.0
circuit check_skipped_generics:
module check_skipped_generics: @[module-XXXXXXXXXX.rs 1:1]
input i: UInt<8> @[module-XXXXXXXXXX.rs 2:1]
output o: UInt<8>[3] @[module-XXXXXXXXXX.rs 3:1]
output o2: UInt<8>[13] @[module-XXXXXXXXXX.rs 4:1]
wire _array_literal_expr: UInt<8>[3]
connect _array_literal_expr[0], i
connect _array_literal_expr[1], i
connect _array_literal_expr[2], i
connect o, _array_literal_expr @[module-XXXXXXXXXX.rs 5:1]
wire _array_literal_expr_1: UInt<8>[13]
connect _array_literal_expr_1[0], UInt<8>(0h48)
connect _array_literal_expr_1[1], UInt<8>(0h65)
connect _array_literal_expr_1[2], UInt<8>(0h6C)
connect _array_literal_expr_1[3], UInt<8>(0h6C)
connect _array_literal_expr_1[4], UInt<8>(0h6F)
connect _array_literal_expr_1[5], UInt<8>(0h20)
connect _array_literal_expr_1[6], UInt<8>(0h57)
connect _array_literal_expr_1[7], UInt<8>(0h6F)
connect _array_literal_expr_1[8], UInt<8>(0h72)
connect _array_literal_expr_1[9], UInt<8>(0h6C)
connect _array_literal_expr_1[10], UInt<8>(0h64)
connect _array_literal_expr_1[11], UInt<8>(0h21)
connect _array_literal_expr_1[12], UInt<8>(0hA)
connect o2, _array_literal_expr_1 @[module-XXXXXXXXXX.rs 6:1]
",
};
}