diff --git a/crates/fayalite-proc-macros-impl/src/module/transform_body.rs b/crates/fayalite-proc-macros-impl/src/module/transform_body.rs index 6e99e87..c67f8dc 100644 --- a/crates/fayalite-proc-macros-impl/src/module/transform_body.rs +++ b/crates/fayalite-proc-macros-impl/src/module/transform_body.rs @@ -1109,7 +1109,7 @@ fn parse_quote_let_pat>( } } -fn wrap_ty_with_expr(ty: impl ToTokens) -> Type { +pub(crate) fn wrap_ty_with_expr(ty: impl ToTokens) -> Type { parse_quote_spanned! {ty.span()=> ::fayalite::expr::Expr<#ty> } @@ -1586,7 +1586,7 @@ impl Visitor<'_> { } } -fn empty_let() -> Local { +pub(crate) fn empty_let() -> Local { Local { attrs: vec![], let_token: Default::default(), @@ -1672,7 +1672,7 @@ impl Fold for Visitor<'_> { } } - fn fold_local(&mut self, let_stmt: Local) -> Local { + fn fold_local(&mut self, mut let_stmt: Local) -> Local { match self .errors .ok(HdlAttr::::parse_and_leave_attr( @@ -1682,6 +1682,25 @@ impl Fold for Visitor<'_> { Some(None) => return fold_local(self, let_stmt), Some(Some(HdlAttr { .. })) => {} }; + let mut pat = &let_stmt.pat; + if let Pat::Type(pat_type) = pat { + pat = &pat_type.pat; + } + let Pat::Ident(syn::PatIdent { + attrs: _, + by_ref: None, + mutability: _, + ident: _, + subpat: None, + }) = pat + else { + let hdl_attr = HdlAttr::::parse_and_take_attr(&mut let_stmt.attrs) + .ok() + .flatten() + .expect("already checked above"); + let let_stmt = fold_local(self, let_stmt); + return self.process_hdl_let_pat(hdl_attr, let_stmt); + }; let hdl_let = syn::parse2::>>(let_stmt.into_token_stream()); let Some(hdl_let) = self.errors.ok(hdl_let) else { return empty_let(); diff --git a/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs b/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs index 1d53104..f1ff2c2 100644 --- a/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs +++ b/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs @@ -3,22 +3,111 @@ use crate::{ fold::{impl_fold, DoFold}, kw, - module::transform_body::{with_debug_clone_and_fold, Visitor}, + module::transform_body::{empty_let, with_debug_clone_and_fold, wrap_ty_with_expr, Visitor}, Errors, HdlAttr, PairsIterExt, }; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote_spanned, ToTokens, TokenStreamExt}; +use std::collections::BTreeSet; use syn::{ - fold::{fold_arm, fold_expr_match, fold_pat, Fold}, + fold::{fold_arm, fold_expr_match, fold_local, fold_pat, Fold}, parse::Nothing, parse_quote_spanned, punctuated::Punctuated, spanned::Spanned, token::{Brace, Paren}, - Arm, Attribute, Expr, ExprMatch, FieldPat, Ident, Member, Pat, PatIdent, PatOr, PatParen, - PatPath, PatRest, PatStruct, PatTupleStruct, PatWild, Path, PathSegment, Token, TypePath, + Arm, Attribute, Expr, ExprMatch, FieldPat, Ident, Local, Member, Pat, PatIdent, PatOr, + PatParen, PatPath, PatRest, PatStruct, PatTuple, PatTupleStruct, PatWild, Path, PathSegment, + Token, TypePath, }; +macro_rules! visit_trait { + ( + $($vis:vis fn $fn:ident($state:ident: _, $value:ident: &$Value:ty) $block:block)* + ) => { + trait VisitMatchPat<'a> { + $(fn $fn(&mut self, $value: &'a $Value) { + $fn(self, $value); + })* + } + + $($vis fn $fn<'a>($state: &mut (impl ?Sized + VisitMatchPat<'a>), $value: &'a $Value) $block)* + }; +} + +visit_trait! { + fn visit_match_pat_binding(_state: _, v: &MatchPatBinding) { + let MatchPatBinding { ident: _ } = v; + } + fn visit_match_pat_wild(_state: _, v: &MatchPatWild) { + let MatchPatWild { underscore_token: _ } = v; + } + fn visit_match_pat_rest(_state: _, v: &MatchPatRest) { + let MatchPatRest { dot2_token: _ } = v; + } + fn visit_match_pat_paren(state: _, v: &MatchPatParen) { + let MatchPatParen { paren_token: _, pat } = v; + state.visit_match_pat(pat); + } + fn visit_match_pat_paren_simple(state: _, v: &MatchPatParen) { + let MatchPatParen { paren_token: _, pat } = v; + state.visit_match_pat_simple(pat); + } + fn visit_match_pat_or(state: _, v: &MatchPatOr) { + let MatchPatOr { leading_vert: _, cases } = v; + for v in cases { + state.visit_match_pat(v); + } + } + fn visit_match_pat_or_simple(state: _, v: &MatchPatOr) { + let MatchPatOr { leading_vert: _, cases } = v; + for v in cases { + state.visit_match_pat_simple(v); + } + } + fn visit_match_pat_struct_field(state: _, v: &MatchPatStructField) { + let MatchPatStructField { field_name: _, colon_token: _, pat } = v; + state.visit_match_pat_simple(pat); + } + fn visit_match_pat_struct(state: _, v: &MatchPatStruct) { + let MatchPatStruct { match_span: _, path: _, brace_token: _, fields, rest: _ } = v; + for v in fields { + state.visit_match_pat_struct_field(v); + } + } + fn visit_match_pat_tuple(state: _, v: &MatchPatTuple) { + let MatchPatTuple { paren_token: _, fields } = v; + for v in fields { + state.visit_match_pat_simple(v); + } + } + fn visit_match_pat_enum_variant(state: _, v: &MatchPatEnumVariant) { + let MatchPatEnumVariant {match_span:_, variant_path: _, enum_path: _, variant_name: _, field } = v; + if let Some((_, v)) = field { + state.visit_match_pat_simple(v); + } + } + fn visit_match_pat_simple(state: _, v: &MatchPatSimple) { + match v { + MatchPatSimple::Paren(v) => state.visit_match_pat_paren_simple(v), + MatchPatSimple::Or(v) => state.visit_match_pat_or_simple(v), + MatchPatSimple::Binding(v) => state.visit_match_pat_binding(v), + MatchPatSimple::Wild(v) => state.visit_match_pat_wild(v), + MatchPatSimple::Rest(v) => state.visit_match_pat_rest(v), + } + } + fn visit_match_pat(state: _, v: &MatchPat) { + match v { + MatchPat::Simple(v) => state.visit_match_pat_simple(v), + MatchPat::Or(v) => state.visit_match_pat_or(v), + MatchPat::Paren(v) => state.visit_match_pat_paren(v), + MatchPat::Struct(v) => state.visit_match_pat_struct(v), + MatchPat::Tuple(v) => state.visit_match_pat_tuple(v), + MatchPat::EnumVariant(v) => state.visit_match_pat_enum_variant(v), + } + } +} + with_debug_clone_and_fold! { struct MatchPatBinding<> { ident: Ident, @@ -53,6 +142,15 @@ with_debug_clone_and_fold! { } } +impl

MatchPatOr

{ + /// returns the first `|` between two patterns + fn first_inner_vert(&self) -> Option { + let mut pairs = self.cases.pairs(); + pairs.next_back(); + pairs.next().and_then(|v| v.into_tuple().1.copied()) + } +} + impl ToTokens for MatchPatOr

{ fn to_tokens(&self, tokens: &mut TokenStream) { let Self { @@ -77,6 +175,19 @@ impl ToTokens for MatchPatWild { } } +with_debug_clone_and_fold! { + struct MatchPatRest<> { + dot2_token: Token![..], + } +} + +impl ToTokens for MatchPatRest { + fn to_tokens(&self, tokens: &mut TokenStream) { + let Self { dot2_token } = self; + dot2_token.to_tokens(tokens); + } +} + with_debug_clone_and_fold! { struct MatchPatStructField<> { field_name: Ident, @@ -159,6 +270,25 @@ impl ToTokens for MatchPatStruct { } } +with_debug_clone_and_fold! { + struct MatchPatTuple<> { + paren_token: Paren, + fields: Punctuated, + } +} + +impl ToTokens for MatchPatTuple { + fn to_tokens(&self, tokens: &mut TokenStream) { + let Self { + paren_token, + fields, + } = self; + paren_token.surround(tokens, |tokens| { + fields.to_tokens(tokens); + }) + } +} + with_debug_clone_and_fold! { struct MatchPatEnumVariant<> { match_span: Span, @@ -194,6 +324,7 @@ enum MatchPatSimple { Or(MatchPatOr), Binding(MatchPatBinding), Wild(MatchPatWild), + Rest(MatchPatRest), } impl_fold! { @@ -202,6 +333,7 @@ impl_fold! { Or(MatchPatOr), Binding(MatchPatBinding), Wild(MatchPatWild), + Rest(MatchPatRest), } } @@ -212,6 +344,7 @@ impl ToTokens for MatchPatSimple { Self::Paren(v) => v.to_tokens(tokens), Self::Binding(v) => v.to_tokens(tokens), Self::Wild(v) => v.to_tokens(tokens), + Self::Rest(v) => v.to_tokens(tokens), } } } @@ -278,6 +411,7 @@ trait ParseMatchPat: Sized { fn or(v: MatchPatOr) -> Self; fn paren(v: MatchPatParen) -> Self; fn struct_(state: &mut HdlMatchParseState<'_>, v: MatchPatStruct) -> Result; + fn tuple(state: &mut HdlMatchParseState<'_>, v: MatchPatTuple) -> Result; fn enum_variant(state: &mut HdlMatchParseState<'_>, v: MatchPatEnumVariant) -> Result; fn parse(state: &mut HdlMatchParseState<'_>, pat: Pat) -> Result { @@ -462,7 +596,34 @@ trait ParseMatchPat: Sized { }) => Ok(Self::simple(MatchPatSimple::Wild(MatchPatWild { underscore_token, }))), - Pat::Tuple(_) | Pat::Slice(_) | Pat::Const(_) | Pat::Lit(_) | Pat::Range(_) => { + Pat::Tuple(PatTuple { + attrs: _, + paren_token, + elems, + }) => { + let fields = elems + .into_pairs() + .filter_map_pair_value(|field_pat| { + if let Pat::Rest(PatRest { + attrs: _, + dot2_token, + }) = field_pat + { + Some(MatchPatSimple::Rest(MatchPatRest { dot2_token })) + } else { + MatchPatSimple::parse(state, field_pat).ok() + } + }) + .collect(); + Self::tuple( + state, + MatchPatTuple { + paren_token, + fields, + }, + ) + } + Pat::Slice(_) | Pat::Const(_) | Pat::Lit(_) | Pat::Range(_) => { state .errors .error(pat, "not yet implemented in #[hdl] patterns"); @@ -497,6 +658,14 @@ impl ParseMatchPat for MatchPatSimple { Err(()) } + fn tuple(state: &mut HdlMatchParseState<'_>, v: MatchPatTuple) -> Result { + state.errors.push(syn::Error::new( + v.paren_token.span.open(), + "matching tuples is not yet implemented inside structs/enums in #[hdl] patterns", + )); + Err(()) + } + fn enum_variant( state: &mut HdlMatchParseState<'_>, v: MatchPatEnumVariant, @@ -515,6 +684,7 @@ enum MatchPat { Or(MatchPatOr), Paren(MatchPatParen), Struct(MatchPatStruct), + Tuple(MatchPatTuple), EnumVariant(MatchPatEnumVariant), } @@ -524,6 +694,7 @@ impl_fold! { Or(MatchPatOr), Paren(MatchPatParen), Struct(MatchPatStruct), + Tuple(MatchPatTuple), EnumVariant(MatchPatEnumVariant), } } @@ -545,6 +716,10 @@ impl ParseMatchPat for MatchPat { Ok(Self::Struct(v)) } + fn tuple(_state: &mut HdlMatchParseState<'_>, v: MatchPatTuple) -> Result { + Ok(Self::Tuple(v)) + } + fn enum_variant( _state: &mut HdlMatchParseState<'_>, v: MatchPatEnumVariant, @@ -560,6 +735,7 @@ impl ToTokens for MatchPat { Self::Or(v) => v.to_tokens(tokens), Self::Paren(v) => v.to_tokens(tokens), Self::Struct(v) => v.to_tokens(tokens), + Self::Tuple(v) => v.to_tokens(tokens), Self::EnumVariant(v) => v.to_tokens(tokens), } } @@ -622,10 +798,6 @@ struct RewriteAsCheckMatch { } impl Fold for RewriteAsCheckMatch { - fn fold_field_pat(&mut self, mut i: FieldPat) -> FieldPat { - i.colon_token = Some(Token![:](i.member.span())); - i - } fn fold_pat(&mut self, pat: Pat) -> Pat { match pat { Pat::Ident(mut pat_ident) => match parse_enum_ident(pat_ident.ident) { @@ -740,6 +912,30 @@ impl Fold for RewriteAsCheckMatch { // don't recurse into expressions i } + fn fold_local(&mut self, mut let_stmt: Local) -> Local { + if let Some(syn::LocalInit { + eq_token, + expr: _, + diverge, + }) = let_stmt.init.take() + { + let_stmt.init = Some(syn::LocalInit { + eq_token, + expr: parse_quote_spanned! {self.span=> + __match_value + }, + diverge: diverge.map(|(else_, _expr)| { + ( + else_, + parse_quote_spanned! {self.span=> + match __infallible {} + }, + ) + }), + }); + } + fold_local(self, let_stmt) + } } struct HdlMatchParseState<'a> { @@ -747,7 +943,123 @@ struct HdlMatchParseState<'a> { errors: &'a mut Errors, } +struct HdlLetPatVisitState<'a> { + errors: &'a mut Errors, + bindings: BTreeSet<&'a Ident>, +} + +impl<'a> VisitMatchPat<'a> for HdlLetPatVisitState<'a> { + fn visit_match_pat_binding(&mut self, v: &'a MatchPatBinding) { + self.bindings.insert(&v.ident); + } + + fn visit_match_pat_or(&mut self, v: &'a MatchPatOr) { + if let Some(first_inner_vert) = v.first_inner_vert() { + self.errors.error( + first_inner_vert, + "or-patterns are not supported in let statements", + ); + } + visit_match_pat_or(self, v); + } + + fn visit_match_pat_or_simple(&mut self, v: &'a MatchPatOr) { + if let Some(first_inner_vert) = v.first_inner_vert() { + self.errors.error( + first_inner_vert, + "or-patterns are not supported in let statements", + ); + } + visit_match_pat_or_simple(self, v); + } + + fn visit_match_pat_enum_variant(&mut self, v: &'a MatchPatEnumVariant) { + self.errors.error(v, "refutable pattern in let statement"); + } +} + impl Visitor<'_> { + pub(crate) fn process_hdl_let_pat( + &mut self, + _hdl_attr: HdlAttr, + mut let_stmt: Local, + ) -> Local { + let span = let_stmt.let_token.span(); + if let Pat::Type(pat) = &mut let_stmt.pat { + *pat.ty = wrap_ty_with_expr((*pat.ty).clone()); + } + let check_let_stmt = RewriteAsCheckMatch { span }.fold_local(let_stmt.clone()); + let Local { + attrs: _, + let_token, + pat, + init, + semi_token, + } = let_stmt; + self.require_normal_module_or_fn(let_token); + let Some(syn::LocalInit { + eq_token, + expr, + diverge, + }) = init + else { + self.errors + .error(let_token, "#[hdl] let must be assigned a value"); + return empty_let(); + }; + if let Some((else_, _)) = diverge { + // TODO: implement let-else + self.errors + .error(else_, "#[hdl] let ... else { ... } is not implemented"); + return empty_let(); + } + let Ok(pat) = MatchPat::parse( + &mut HdlMatchParseState { + match_span: span, + errors: &mut self.errors, + }, + pat, + ) else { + return empty_let(); + }; + let mut state = HdlLetPatVisitState { + errors: &mut self.errors, + bindings: BTreeSet::new(), + }; + state.visit_match_pat(&pat); + let HdlLetPatVisitState { + errors: _, + bindings, + } = state; + let retval = parse_quote_spanned! {span=> + let (#(#bindings,)* __scope,) = { + type __MatchTy = ::MatchVariant; + let __match_expr = ::fayalite::expr::ToExpr::to_expr(&(#expr)); + ::fayalite::expr::check_match_expr(__match_expr, |__match_value, __infallible| { + #[allow(unused_variables)] + #check_let_stmt + match __infallible {} + }); + let mut __match_iter = ::fayalite::module::match_(__match_expr); + let ::fayalite::__std::option::Option::Some(__match_variant) = ::fayalite::__std::iter::Iterator::next(&mut __match_iter) else { + ::fayalite::__std::unreachable!("#[hdl] let with uninhabited type"); + }; + let ::fayalite::__std::option::Option::None = ::fayalite::__std::iter::Iterator::next(&mut __match_iter) else { + ::fayalite::__std::unreachable!("#[hdl] let with refutable pattern"); + }; + let (__match_variant, __scope) = + ::fayalite::ty::MatchVariantAndInactiveScope::match_activate_scope( + __match_variant, + ); + #let_token #pat #eq_token __match_variant #semi_token + (#(#bindings,)* __scope,) + }; + }; + match retval { + syn::Stmt::Local(retval) => retval, + _ => unreachable!(), + } + } pub(crate) fn process_hdl_match( &mut self, _hdl_attr: HdlAttr, diff --git a/crates/fayalite/src/_docs/modules/module_bodies/hdl_let_statements.rs b/crates/fayalite/src/_docs/modules/module_bodies/hdl_let_statements.rs index 61d29b5..229871b 100644 --- a/crates/fayalite/src/_docs/modules/module_bodies/hdl_let_statements.rs +++ b/crates/fayalite/src/_docs/modules/module_bodies/hdl_let_statements.rs @@ -2,6 +2,7 @@ // See Notices.txt for copyright information //! ## `#[hdl] let` statements +pub mod destructuring; pub mod inputs_outputs; pub mod instances; pub mod memories; diff --git a/crates/fayalite/src/_docs/modules/module_bodies/hdl_let_statements/destructuring.rs b/crates/fayalite/src/_docs/modules/module_bodies/hdl_let_statements/destructuring.rs new file mode 100644 index 0000000..1fc4705 --- /dev/null +++ b/crates/fayalite/src/_docs/modules/module_bodies/hdl_let_statements/destructuring.rs @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +// See Notices.txt for copyright information +//! ### Destructuring Let +//! +//! You can use `#[hdl] let` to destructure types, similarly to Rust `let` statements with non-trivial patterns. +//! +//! `#[hdl] let` statements can only match one level of struct/tuple pattern for now, +//! e.g. you can match with the pattern `MyStruct { a, b }`, but not `MyStruct { a, b: Struct2 { v } }`. +//! +//! ``` +//! # use fayalite::prelude::*; +//! #[hdl] +//! struct MyStruct { +//! a: UInt<8>, +//! b: Bool, +//! } +//! +//! #[hdl_module] +//! fn my_module() { +//! #[hdl] +//! let my_input: MyStruct = m.input(); +//! #[hdl] +//! let my_output: UInt<8> = m.input(); +//! #[hdl] +//! let MyStruct { a, b } = my_input; +//! #[hdl] +//! if b { +//! connect(my_output, a); +//! } else { +//! connect(my_output, 0_hdl_u8); +//! } +//! } +//! ``` diff --git a/crates/fayalite/src/_docs/modules/module_bodies/hdl_match_statements.rs b/crates/fayalite/src/_docs/modules/module_bodies/hdl_match_statements.rs index 9e6c511..6df70f1 100644 --- a/crates/fayalite/src/_docs/modules/module_bodies/hdl_match_statements.rs +++ b/crates/fayalite/src/_docs/modules/module_bodies/hdl_match_statements.rs @@ -7,5 +7,5 @@ //! //! `#[hdl] match` statements' bodies must evaluate to type `()` for now. //! -//! `#[hdl] match` statements can only match one level of struct/enum pattern for now, +//! `#[hdl] match` statements can only match one level of struct/tuple/enum pattern for now, //! e.g. you can match with the pattern `HdlSome(v)`, but not `HdlSome(HdlSome(_))`. diff --git a/crates/fayalite/tests/module.rs b/crates/fayalite/tests/module.rs index 4e56df4..49f5689 100644 --- a/crates/fayalite/tests/module.rs +++ b/crates/fayalite/tests/module.rs @@ -4345,3 +4345,79 @@ circuit check_cfgs: ", }; } + +#[hdl_module(outline_generated)] +pub fn check_let_patterns() { + #[hdl] + let tuple_in: (UInt<1>, SInt<1>, Bool) = m.input(); + #[hdl] + let (tuple_0, tuple_1, tuple_2) = tuple_in; + #[hdl] + let tuple_0_out: UInt<1> = m.output(); + connect(tuple_0_out, tuple_0); + #[hdl] + let tuple_1_out: SInt<1> = m.output(); + connect(tuple_1_out, tuple_1); + #[hdl] + let tuple_2_out: Bool = m.output(); + connect(tuple_2_out, tuple_2); + + #[hdl] + let test_struct_in: TestStruct> = m.input(); + #[hdl] + let TestStruct::<_> { a, b } = test_struct_in; + #[hdl] + let test_struct_a_out: SInt<8> = m.output(); + connect(test_struct_a_out, a); + #[hdl] + let test_struct_b_out: UInt<8> = m.output(); + connect(test_struct_b_out, b); + + #[hdl] + let test_struct_2_in: TestStruct2 = m.input(); + #[hdl] + let TestStruct2 { v } = test_struct_2_in; + #[hdl] + let test_struct_2_v_out: UInt<8> = m.output(); + connect(test_struct_2_v_out, v); + + #[hdl] + let test_struct_3_in: TestStruct3 = m.input(); + #[hdl] + let TestStruct3 {} = test_struct_3_in; +} + +#[test] +fn test_let_patterns() { + let _n = SourceLocation::normalize_files_for_tests(); + let m = check_let_patterns(); + dbg!(m); + #[rustfmt::skip] // work around https://github.com/rust-lang/rustfmt/issues/6161 + assert_export_firrtl! { + m => + "/test/check_let_patterns.fir": r"FIRRTL version 3.2.0 +circuit check_let_patterns: + type Ty0 = {`0`: UInt<1>, `1`: SInt<1>, `2`: UInt<1>} + type Ty1 = {a: SInt<8>, b: UInt<8>} + type Ty2 = {v: UInt<8>} + type Ty3 = {} + module check_let_patterns: @[module-XXXXXXXXXX.rs 1:1] + input tuple_in: Ty0 @[module-XXXXXXXXXX.rs 2:1] + output tuple_0_out: UInt<1> @[module-XXXXXXXXXX.rs 4:1] + output tuple_1_out: SInt<1> @[module-XXXXXXXXXX.rs 6:1] + output tuple_2_out: UInt<1> @[module-XXXXXXXXXX.rs 8:1] + input test_struct_in: Ty1 @[module-XXXXXXXXXX.rs 10:1] + output test_struct_a_out: SInt<8> @[module-XXXXXXXXXX.rs 12:1] + output test_struct_b_out: UInt<8> @[module-XXXXXXXXXX.rs 14:1] + input test_struct_2_in: Ty2 @[module-XXXXXXXXXX.rs 16:1] + output test_struct_2_v_out: UInt<8> @[module-XXXXXXXXXX.rs 18:1] + input test_struct_3_in: Ty3 @[module-XXXXXXXXXX.rs 20:1] + connect tuple_0_out, tuple_in.`0` @[module-XXXXXXXXXX.rs 5:1] + connect tuple_1_out, tuple_in.`1` @[module-XXXXXXXXXX.rs 7:1] + connect tuple_2_out, tuple_in.`2` @[module-XXXXXXXXXX.rs 9:1] + connect test_struct_a_out, test_struct_in.a @[module-XXXXXXXXXX.rs 13:1] + connect test_struct_b_out, test_struct_in.b @[module-XXXXXXXXXX.rs 15:1] + connect test_struct_2_v_out, test_struct_2_in.v @[module-XXXXXXXXXX.rs 19:1] +", + }; +}