From 9092e45447136e332a346c82d4e2bb6bb16dac0f Mon Sep 17 00:00:00 2001
From: Jacob Lifshay <programmerjake@gmail.com>
Date: Sun, 30 Mar 2025 01:25:07 -0700
Subject: [PATCH] fix #[hdl(sim)] match on enums

---
 .../src/module/transform_body/expand_match.rs | 46 +++++++++++++++++--
 crates/fayalite/tests/sim.rs                  |  6 +++
 2 files changed, 47 insertions(+), 5 deletions(-)

diff --git a/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs b/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs
index 57e919a..a2e0375 100644
--- a/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs
+++ b/crates/fayalite-proc-macros-impl/src/module/transform_body/expand_match.rs
@@ -83,7 +83,14 @@ visit_trait! {
         }
     }
     fn visit_match_pat_enum_variant(state: _, v: &MatchPatEnumVariant) {
-        let MatchPatEnumVariant {match_span:_, variant_path: _, enum_path: _, variant_name: _, field } = v;
+        let MatchPatEnumVariant {
+            match_span:_,
+            sim:_,
+            variant_path: _,
+            enum_path: _,
+            variant_name: _,
+            field,
+        } = v;
         if let Some((_, v)) = field {
             state.visit_match_pat_simple(v);
         }
@@ -293,6 +300,7 @@ impl ToTokens for MatchPatTuple {
 with_debug_clone_and_fold! {
     struct MatchPatEnumVariant<> {
         match_span: Span,
+        sim: Option<(kw::sim,)>,
         variant_path: Path,
         enum_path: Path,
         variant_name: Ident,
@@ -304,6 +312,7 @@ impl ToTokens for MatchPatEnumVariant {
     fn to_tokens(&self, tokens: &mut TokenStream) {
         let Self {
             match_span,
+            sim,
             variant_path: _,
             enum_path,
             variant_name,
@@ -313,7 +322,28 @@ impl ToTokens for MatchPatEnumVariant {
             __MatchTy::<#enum_path>::#variant_name
         }
         .to_tokens(tokens);
-        if let Some((paren_token, field)) = field {
+        if sim.is_some() {
+            if let Some((paren_token, field)) = field {
+                paren_token.surround(tokens, |tokens| {
+                    field.to_tokens(tokens);
+                    match field {
+                        MatchPatSimple::Paren(_)
+                        | MatchPatSimple::Or(_)
+                        | MatchPatSimple::Binding(_)
+                        | MatchPatSimple::Wild(_) => quote_spanned! {*match_span=>
+                            , _
+                        }
+                        .to_tokens(tokens),
+                        MatchPatSimple::Rest(_) => {}
+                    }
+                });
+            } else {
+                quote_spanned! {*match_span=>
+                    (_)
+                }
+                .to_tokens(tokens);
+            }
+        } else if let Some((paren_token, field)) = field {
             paren_token.surround(tokens, |tokens| field.to_tokens(tokens));
         }
     }
@@ -448,6 +478,7 @@ trait ParseMatchPat: Sized {
                         state,
                         MatchPatEnumVariant {
                             match_span: state.match_span,
+                            sim: state.sim,
                             variant_path,
                             enum_path,
                             variant_name,
@@ -494,6 +525,7 @@ trait ParseMatchPat: Sized {
                     state,
                     MatchPatEnumVariant {
                         match_span: state.match_span,
+                        sim: state.sim,
                         variant_path,
                         enum_path,
                         variant_name,
@@ -578,6 +610,7 @@ trait ParseMatchPat: Sized {
                     state,
                     MatchPatEnumVariant {
                         match_span: state.match_span,
+                        sim: state.sim,
                         variant_path,
                         enum_path,
                         variant_name,
@@ -940,6 +973,7 @@ impl Fold for RewriteAsCheckMatch {
 }
 
 struct HdlMatchParseState<'a> {
+    sim: Option<(kw::sim,)>,
     match_span: Span,
     errors: &'a mut Errors,
 }
@@ -986,6 +1020,7 @@ impl Visitor<'_> {
         mut let_stmt: Local,
     ) -> Local {
         let span = let_stmt.let_token.span();
+        let ExprOptions { sim } = hdl_attr.body;
         if let Pat::Type(pat) = &mut let_stmt.pat {
             *pat.ty = wrap_ty_with_expr((*pat.ty).clone());
         }
@@ -1015,6 +1050,7 @@ impl Visitor<'_> {
         }
         let Ok(pat) = MatchPat::parse(
             &mut HdlMatchParseState {
+                sim,
                 match_span: span,
                 errors: &mut self.errors,
             },
@@ -1031,7 +1067,6 @@ impl Visitor<'_> {
             errors: _,
             bindings,
         } = state;
-        let ExprOptions { sim } = hdl_attr.body;
         let retval = if sim.is_some() {
             parse_quote_spanned! {span=>
                 let (#(#bindings,)*) = {
@@ -1093,7 +1128,9 @@ impl Visitor<'_> {
             brace_token: _,
             arms,
         } = expr_match;
+        let ExprOptions { sim } = hdl_attr.body;
         let mut state = HdlMatchParseState {
+            sim,
             match_span: span,
             errors: &mut self.errors,
         };
@@ -1101,13 +1138,12 @@ impl Visitor<'_> {
             arms.into_iter()
                 .filter_map(|arm| MatchArm::parse(&mut state, arm).ok()),
         );
-        let ExprOptions { sim } = hdl_attr.body;
         let expr = if sim.is_some() {
             quote_spanned! {span=>
                 {
                     type __MatchTy<T> = <T as ::fayalite::ty::Type>::SimValue;
                     let __match_expr = ::fayalite::sim::value::ToSimValue::to_sim_value(&(#expr));
-                    #match_token *__match_expr {
+                    #match_token ::fayalite::sim::value::SimValue::into_value(__match_expr) {
                         #(#arms)*
                     }
                 }
diff --git a/crates/fayalite/tests/sim.rs b/crates/fayalite/tests/sim.rs
index 398fe18..71e53ea 100644
--- a/crates/fayalite/tests/sim.rs
+++ b/crates/fayalite/tests/sim.rs
@@ -497,6 +497,12 @@ fn test_enums() {
             "vcd:\n{}\ncycle: {cycle}",
             String::from_utf8(writer.take()).unwrap(),
         );
+        // make sure matching on SimValue<SomeEnum> works
+        #[hdl(sim)]
+        match io.b_out {
+            HdlNone => println!("io.b_out is HdlNone"),
+            HdlSome(v) => println!("io.b_out is HdlSome(({:?}, {:?}))", *v.0, *v.1),
+        }
         sim.write_clock(sim.io().cd.clk, false);
         sim.advance_time(SimDuration::from_micros(1));
         sim.write_clock(sim.io().cd.clk, true);