diff --git a/crates/fayalite-proc-macros-impl/src/fold.rs b/crates/fayalite-proc-macros-impl/src/fold.rs index 7f2e580..49cc8c1 100644 --- a/crates/fayalite-proc-macros-impl/src/fold.rs +++ b/crates/fayalite-proc-macros-impl/src/fold.rs @@ -224,6 +224,7 @@ forward_fold!(syn::ExprPath => fold_expr_path); forward_fold!(syn::ExprRepeat => fold_expr_repeat); forward_fold!(syn::ExprStruct => fold_expr_struct); forward_fold!(syn::ExprTuple => fold_expr_tuple); +forward_fold!(syn::FieldPat => fold_field_pat); forward_fold!(syn::Ident => fold_ident); forward_fold!(syn::Member => fold_member); forward_fold!(syn::Path => fold_path); diff --git a/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs b/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs index a9153c6..ae21a73 100644 --- a/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs +++ b/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs @@ -1,12 +1,12 @@ // SPDX-License-Identifier: LGPL-3.0-or-later // See Notices.txt for copyright information use crate::{ - fold::impl_fold, + fold::{impl_fold, DoFold}, module::transform_body::{with_debug_clone_and_fold, Visitor}, Errors, HdlAttr, PairsIterExt, }; use proc_macro2::{Span, TokenStream}; -use quote::{format_ident, ToTokens, TokenStreamExt}; +use quote::{format_ident, quote_spanned, ToTokens, TokenStreamExt}; use syn::{ fold::{fold_arm, fold_expr_match, fold_pat, Fold}, parse::Nothing, @@ -130,6 +130,7 @@ impl MatchPatStructField { with_debug_clone_and_fold! { struct MatchPatStruct<> { + match_span: Span, path: Path, brace_token: Brace, fields: Punctuated, @@ -140,12 +141,16 @@ with_debug_clone_and_fold! { impl ToTokens for MatchPatStruct { fn to_tokens(&self, tokens: &mut TokenStream) { let Self { + match_span, path, brace_token, fields, rest, } = self; - path.to_tokens(tokens); + quote_spanned! {*match_span=> + __MatchTy::<#path> + } + .to_tokens(tokens); brace_token.surround(tokens, |tokens| { fields.to_tokens(tokens); rest.to_tokens(tokens); @@ -155,6 +160,7 @@ impl ToTokens for MatchPatStruct { with_debug_clone_and_fold! { struct MatchPatEnumVariant<> { + match_span: Span, variant_path: Path, enum_path: Path, variant_name: Ident, @@ -165,12 +171,16 @@ with_debug_clone_and_fold! { impl ToTokens for MatchPatEnumVariant { fn to_tokens(&self, tokens: &mut TokenStream) { let Self { - variant_path, - enum_path: _, - variant_name: _, + match_span, + variant_path: _, + enum_path, + variant_name, field, } = self; - variant_path.to_tokens(tokens); + quote_spanned! {*match_span=> + __MatchTy::<#enum_path>::#variant_name + } + .to_tokens(tokens); if let Some((paren_token, field)) = field { paren_token.surround(tokens, |tokens| field.to_tokens(tokens)); } @@ -301,6 +311,7 @@ trait ParseMatchPat: Sized { }) => Self::enum_variant( state, MatchPatEnumVariant { + match_span: state.match_span, variant_path, enum_path, variant_name, @@ -346,6 +357,7 @@ trait ParseMatchPat: Sized { Self::enum_variant( state, MatchPatEnumVariant { + match_span: state.match_span, variant_path, enum_path, variant_name, @@ -376,6 +388,7 @@ trait ParseMatchPat: Sized { Self::struct_( state, MatchPatStruct { + match_span: state.match_span, path, brace_token, fields, @@ -428,6 +441,7 @@ trait ParseMatchPat: Sized { Self::enum_variant( state, MatchPatEnumVariant { + match_span: state.match_span, variant_path, enum_path, variant_name, @@ -626,6 +640,44 @@ impl Fold for RewriteAsCheckMatch { Pat::Ident(self.fold_pat_ident(pat_ident)) } }, + Pat::Path(PatPath { + attrs: _, + qself, + path, + }) => match parse_enum_path(TypePath { qself, path }) { + Ok(EnumPath { + variant_path: _, + enum_path, + variant_name, + }) => parse_quote_spanned! {self.span=> + __MatchTy::<#enum_path>::#variant_name {} + }, + Err(type_path) => parse_quote_spanned! {self.span=> + __MatchTy::<#type_path> {} + }, + }, + Pat::Struct(PatStruct { + attrs: _, + qself, + path, + brace_token, + fields, + rest, + }) => { + let type_path = TypePath { qself, path }; + let path = parse_quote_spanned! {self.span=> + __MatchTy::<#type_path> + }; + let fields = fields.do_fold(self); + Pat::Struct(PatStruct { + attrs: vec![], + qself: None, + path, + brace_token, + fields, + rest, + }) + } Pat::TupleStruct(PatTupleStruct { attrs, qself, @@ -690,6 +742,7 @@ impl Fold for RewriteAsCheckMatch { } struct HdlMatchParseState<'a> { + match_span: Span, errors: &'a mut Errors, } @@ -710,13 +763,14 @@ impl Visitor<'_> { } = expr_match; self.require_normal_module(match_token); let mut state = HdlMatchParseState { + match_span: span, errors: &mut self.errors, }; let arms = Vec::from_iter( arms.into_iter() .filter_map(|arm| MatchArm::parse(&mut state, arm).ok()), ); - let expr = quote::quote_spanned! {span=> + let expr = quote_spanned! {span=> { type __MatchTy = ::MatchVariant; let __match_expr = ::fayalite::expr::ToExpr::to_expr(&(#expr)); @@ -735,7 +789,6 @@ impl Visitor<'_> { } } }; - println!("{}", expr); syn::parse2(expr).unwrap() } } diff --git a/crates/fayalite/src/expr.rs b/crates/fayalite/src/expr.rs index e51303a..607ff1e 100644 --- a/crates/fayalite/src/expr.rs +++ b/crates/fayalite/src/expr.rs @@ -21,7 +21,7 @@ use crate::{ wire::Wire, }; use bitvec::slice::BitSlice; -use std::{fmt, ops::Deref}; +use std::{convert::Infallible, fmt, ops::Deref}; pub mod ops; pub mod target; @@ -690,3 +690,10 @@ pub trait CastTo: ToExpr { } impl CastTo for T {} + +#[doc(hidden)] +pub fn check_match_expr( + _expr: Expr, + _check_fn: impl FnOnce(T::MatchVariant, Infallible), +) { +} diff --git a/crates/fayalite/tests/module.rs b/crates/fayalite/tests/module.rs index f039750..23cd61f 100644 --- a/crates/fayalite/tests/module.rs +++ b/crates/fayalite/tests/module.rs @@ -662,7 +662,6 @@ circuit check_enum_literals: }; } -#[cfg(todo)] #[hdl_module(outline_generated)] pub fn check_struct_enum_match() { #[hdl] @@ -688,7 +687,7 @@ pub fn check_struct_enum_match() { match i2 { TestEnum::A => connect(o[2], 0_hdl_u8), TestEnum::B(v) => connect(o[2], v), - TestEnum::C(v) => connect_any(o[2], v[1]), + TestEnum::C(v) => connect_any(o[2], v[1].cast_to(UInt[1])), } #[hdl] match i2 { @@ -699,12 +698,11 @@ pub fn check_struct_enum_match() { #[hdl] match i2 { TestEnum::B(_) => connect(o[4], 1_hdl_u8), - TestEnum::C(v) => connect_any(o[4], v[2]), + TestEnum::C(v) => connect_any(o[4], v[2].cast_to(UInt[1])), _ => connect(o[4], 0_hdl_u8), } } -#[cfg(todo)] #[test] fn test_struct_enum_match() { let _n = SourceLocation::normalize_files_for_tests(); @@ -715,14 +713,14 @@ fn test_struct_enum_match() { m => "/test/check_struct_enum_match.fir": r"FIRRTL version 3.2.0 circuit check_struct_enum_match: - type Ty0 = {|None, Some: UInt<8>|} + type Ty0 = {|HdlNone, HdlSome: UInt<8>|} type Ty1 = {|A, B: UInt<8>, C: UInt<1>[3]|} module check_struct_enum_match: @[module-XXXXXXXXXX.rs 1:1] input i1: Ty0 @[module-XXXXXXXXXX.rs 2:1] input i2: Ty1 @[module-XXXXXXXXXX.rs 3:1] output o: UInt<8>[5] @[module-XXXXXXXXXX.rs 4:1] match i1: @[module-XXXXXXXXXX.rs 5:1] - None: + HdlNone: match i2: @[module-XXXXXXXXXX.rs 6:1] A: connect o[0], UInt<8>(0h17) @[module-XXXXXXXXXX.rs 7:1] @@ -730,12 +728,12 @@ circuit check_struct_enum_match: connect o[0], add(_match_arm_value, UInt<8>(0h2)) @[module-XXXXXXXXXX.rs 8:1] C(_match_arm_value_1): connect o[0], UInt<8>(0h17) @[module-XXXXXXXXXX.rs 7:1] - Some(_match_arm_value_2): + HdlSome(_match_arm_value_2): connect o[0], _match_arm_value_2 @[module-XXXXXXXXXX.rs 9:1] match i1: @[module-XXXXXXXXXX.rs 10:1] - None: + HdlNone: connect o[1], UInt<8>(0h0) @[module-XXXXXXXXXX.rs 11:1] - Some(_match_arm_value_3): + HdlSome(_match_arm_value_3): connect o[1], UInt<8>(0h1) @[module-XXXXXXXXXX.rs 12:1] match i2: @[module-XXXXXXXXXX.rs 13:1] A: @@ -768,7 +766,7 @@ circuit check_struct_enum_match: m => "/test/check_struct_enum_match.fir": r"FIRRTL version 3.2.0 circuit check_struct_enum_match: - type Ty0 = {|None, Some|} + type Ty0 = {|HdlNone, HdlSome|} type Ty1 = {tag: Ty0, body: UInt<8>} type Ty2 = {|A, B, C|} type Ty3 = {tag: Ty2, body: UInt<8>} @@ -777,7 +775,7 @@ circuit check_struct_enum_match: input i2: Ty3 @[module-XXXXXXXXXX.rs 3:1] output o: UInt<8>[5] @[module-XXXXXXXXXX.rs 4:1] match i1.tag: @[module-XXXXXXXXXX.rs 5:1] - None: + HdlNone: match i2.tag: @[module-XXXXXXXXXX.rs 6:1] A: connect o[0], UInt<8>(0h17) @[module-XXXXXXXXXX.rs 7:1] @@ -785,12 +783,12 @@ circuit check_struct_enum_match: connect o[0], add(bits(i2.body, 7, 0), UInt<8>(0h2)) @[module-XXXXXXXXXX.rs 8:1] C: connect o[0], UInt<8>(0h17) @[module-XXXXXXXXXX.rs 7:1] - Some: + HdlSome: connect o[0], bits(i1.body, 7, 0) @[module-XXXXXXXXXX.rs 9:1] match i1.tag: @[module-XXXXXXXXXX.rs 10:1] - None: + HdlNone: connect o[1], UInt<8>(0h0) @[module-XXXXXXXXXX.rs 11:1] - Some: + HdlSome: connect o[1], UInt<8>(0h1) @[module-XXXXXXXXXX.rs 12:1] match i2.tag: @[module-XXXXXXXXXX.rs 13:1] A: