// SPDX-License-Identifier: LGPL-3.0-or-later
// See Notices.txt for copyright information

//! The `simplify_enums` module transform.
//!
//! Rewrites enum-typed values in a module according to a [`SimplifyEnumsKind`].
//!
//! - `SimplifyToEnumsWithNoBody`: an enum with at least one payload becomes a bundle.
//!   The bundle's `tag` is a payload-free copy of the enum and its `body` is a `UInt`.
//!   Enums without payloads are left unchanged.
//! - `ReplaceWithBundleOfUInts`: every enum becomes a bundle `{ tag: UInt, body: UInt }`.
//! - `ReplaceWithUInt`: every enum becomes a single `UInt` of the enum's bit width.
//!   The low `discriminant_bit_width()` bits hold the tag and the bits above hold the body.
//!
//! Besides types, the pass rewrites the constructs that depend on the enum layout:
//! enum literals, variant accesses, `match` statements, connects between differing
//! types, structural equality, and memories whose element type contains enums.
use crate::{
    bundle::{BundleField, BundleType},
    enum_::{EnumType, EnumVariant},
    expr::{
        ExprEnum,
        ops::{self, EnumLiteral, StructuralEq, StructuralEqFlags},
    },
    intern::{Intern, InternSlice, Interned, Memoize},
    memory::{DynPortType, MemPort},
    module::{
        Block, Id, NameId, ScopedNameId, Stmt, StmtConnect, StmtIf, StmtMatch, StmtWire,
        transform::visit::{Fold, Folder},
    },
    prelude::*,
    util::HashMap,
};
use serde::{Deserialize, Serialize};
use std::fmt;

/// Errors produced by [`simplify_enums`].
#[derive(Debug)]
pub enum SimplifyEnumsError {
    /// The enum's type properties report `is_castable_from_bits == false`, so it
    /// cannot be lowered to a bit-level tag/body representation.
    EnumIsNotCastableFromBits { enum_type: Enum },
}

impl fmt::Display for SimplifyEnumsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SimplifyEnumsError::EnumIsNotCastableFromBits { enum_type } => write!(
                f,
                "simplify_enums failed: enum type is not castable from bits: {enum_type:?}"
            ),
        }
    }
}

impl std::error::Error for SimplifyEnumsError {}

/// Wraps a [`SimplifyEnumsError`] in an I/O error of kind
/// [`std::io::ErrorKind::Other`].
impl From for std::io::Error {
    fn from(value: SimplifyEnumsError) -> Self {
        std::io::Error::new(std::io::ErrorKind::Other, value)
    }
}

/// Returns `true` if `ty` is an enum or (recursively) contains one.
///
/// The search looks through array elements, bundle fields, and `TraceAsString`
/// inner types. All other canonical types return `false`. The answer is
/// computed through the crate's `Memoize` helper (presumably cached per type).
fn contains_any_enum_types(ty: CanonicalType) -> bool {
    // Zero-sized key type for the memoization table.
    #[derive(Copy, Clone, PartialEq, Eq, Hash)]
    struct TheMemoize;
    impl Memoize for TheMemoize {
        type Input = CanonicalType;
        type InputOwned = CanonicalType;
        type Output = bool;

        fn inner(self, ty: &Self::Input) -> Self::Output {
            match *ty {
                CanonicalType::Array(array_type) => contains_any_enum_types(array_type.element()),
                CanonicalType::Enum(_) => true,
                CanonicalType::Bundle(bundle) => bundle
                    .fields()
                    .iter()
                    .any(|field| contains_any_enum_types(field.ty)),
                CanonicalType::TraceAsString(ty) => contains_any_enum_types(ty.inner_ty()),
                CanonicalType::UInt(_)
                | CanonicalType::SInt(_)
                | CanonicalType::Bool(_)
                | CanonicalType::AsyncReset(_)
                | CanonicalType::SyncReset(_)
                | CanonicalType::Reset(_)
                | CanonicalType::Clock(_)
                | CanonicalType::PhantomConst(_)
                | CanonicalType::DynSimOnly(_) => false,
            }
        }
    }
    TheMemoize.get_owned(ty)
}

// Bundle used as the lowered form of an enum: `tag` identifies the active
// variant and `body` holds the variant's payload bits.
#[hdl]
struct TagAndBody {
    tag: Tag,
    body: Body,
}

/// How a particular enum type is lowered; cached per enum in `State::enum_types`.
#[derive(Clone, Debug)]
enum EnumTypeState {
    /// `SimplifyToEnumsWithNoBody` for an enum with payloads.
    ///
    /// The result is a bundle of a payload-free copy of the enum (the tag) and
    /// a `UInt` body.
    TagEnumAndBody(TagAndBody),
    /// `ReplaceWithBundleOfUInts`.
    ///
    /// The result is a bundle of a `discriminant_bit_width()`-bit `UInt` tag
    /// and a `UInt` body.
    TagUIntAndBody(TagAndBody),
    /// `ReplaceWithUInt`: a single `UInt` of the enum's full bit width.
    UInt(UInt),
    /// The enum is kept as-is (`SimplifyToEnumsWithNoBody` on an enum with no payloads).
    Unchanged,
}

/// Per-module state.
///
/// Pushed by `fold_module` before folding a module and popped afterwards.
struct ModuleState {
    /// Name of the module being folded; used to scope generated names.
    module_name: NameId,
    /// Memoized `fold_expr_enum` results for this module.
    ///
    /// NOTE(review): this cache is module-wide, while statements generated during
    /// folding (see `handle_enum_structural_eq`) are attached to the current block.
    /// Confirm that cached results are never reused from a different block.
    expr_cache: HashMap,
    /// The module's source location; used for statements generated by
    /// `handle_enum_structural_eq`.
    source_location: SourceLocation,
}

impl ModuleState {
    /// Creates a new name scoped to this module.
    ///
    /// The name pairs `name` with a fresh `Id::new()` (presumably what keeps
    /// repeated uses of the same `name` distinct).
    fn gen_name(&mut self, name: &str) -> ScopedNameId {
        ScopedNameId(self.module_name.into(), NameId(name.intern(), Id::new()))
    }
}

/// State of the `simplify_enums` folder.
struct State {
    /// Lowering chosen for each enum type seen so far.
    enum_types: HashMap,
    /// Original memory ports whose type changed, mapped to the wire that replaces
    /// every use of that port. Filled in by `fold_block`.
    replacement_mem_ports: HashMap, Wire>,
    /// Which lowering strategy to apply.
    kind: SimplifyEnumsKind,
    /// One entry per module currently being folded (innermost last).
    module_state_stack: Vec,
    /// Statements to insert at the start of the block currently being folded
    /// (e.g. wire declarations).
    new_prefix_stmts_for_block: Vec,
    /// Statements to append at the end of the block currently being folded.
    new_suffix_stmts_for_block: Vec,
}

/// RAII guard for folding a nested block.
///
/// On construction it swaps fresh prefix/suffix statement buffers into `state`.
/// On drop (including early `?` returns) it restores the parent block's buffers.
struct BlockScope<'a> {
    state: &'a mut State,
    parent_new_prefix_stmts_for_block: Vec,
    parent_new_suffix_stmts_for_block: Vec,
}

impl<'a> BlockScope<'a> {
    /// Installs the given buffers as the current block's prefix/suffix statements,
    /// saving the previous ones so `Drop` can put them back.
    fn new(
        state: &'a mut State,
        new_prefix_stmts_for_block: Vec,
        new_suffix_stmts_for_block: Vec,
    ) -> Self {
        let parent_new_prefix_stmts_for_block = std::mem::replace(
            &mut state.new_prefix_stmts_for_block,
            new_prefix_stmts_for_block,
        );
        let parent_new_suffix_stmts_for_block = std::mem::replace(
            &mut state.new_suffix_stmts_for_block,
            new_suffix_stmts_for_block,
        );
        Self {
            state,
            parent_new_prefix_stmts_for_block,
            parent_new_suffix_stmts_for_block,
        }
    }
}

impl Drop for BlockScope<'_> {
    /// Restores the parent block's buffers.
    ///
    /// Anything still left in the nested buffers is discarded, so callers must
    /// copy or drain them before the scope ends.
    fn drop(&mut self) {
        self.state.new_prefix_stmts_for_block =
            std::mem::take(&mut self.parent_new_prefix_stmts_for_block);
        self.state.new_suffix_stmts_for_block =
            std::mem::take(&mut self.parent_new_suffix_stmts_for_block);
    }
}

impl State {
    /// Returns the lowering for `enum_type`, computing and caching it on first use.
    ///
    /// # Errors
    ///
    /// Returns [`SimplifyEnumsError::EnumIsNotCastableFromBits`] if the enum is
    /// not castable from bits.
    fn get_or_make_enum_type_state(
        &mut self,
        enum_type: Enum,
    ) -> Result {
        if let Some(retval) = self.enum_types.get(&enum_type) {
            return Ok(retval.clone());
        }
        if !enum_type.type_properties().is_castable_from_bits {
            return Err(SimplifyEnumsError::EnumIsNotCastableFromBits { enum_type });
        }
        // Whether any variant carries a payload.
        let has_body = enum_type
            .variants()
            .iter()
            .any(|variant| variant.ty.is_some());
        let retval = match (self.kind, has_body) {
            (SimplifyEnumsKind::SimplifyToEnumsWithNoBody, true) => {
                EnumTypeState::TagEnumAndBody(TagAndBody {
                    // Same variant names, payloads removed.
                    tag: Enum::new(Interned::from_iter(enum_type.variants().iter().map(|v| {
                        EnumVariant {
                            name: v.name,
                            ty: None,
                        }
                    }))),
                    // Everything in the enum that is not discriminant bits.
                    body: UInt::new_dyn(
                        enum_type.type_properties().bit_width - enum_type.discriminant_bit_width(),
                    ),
                })
            }
            (SimplifyEnumsKind::SimplifyToEnumsWithNoBody, false) => EnumTypeState::Unchanged,
            (SimplifyEnumsKind::ReplaceWithBundleOfUInts, _) => {
                EnumTypeState::TagUIntAndBody(TagAndBody {
                    tag: UInt::new_dyn(enum_type.discriminant_bit_width()),
                    body: UInt::new_dyn(
                        enum_type.type_properties().bit_width - enum_type.discriminant_bit_width(),
                    ),
                })
            }
            (SimplifyEnumsKind::ReplaceWithUInt, _) => {
                EnumTypeState::UInt(UInt::new_dyn(enum_type.type_properties().bit_width))
            }
        };
        self.enum_types.insert(enum_type, retval.clone());
        Ok(retval)
    }

    // Builds the lowered replacement for an enum literal of `unfolded_enum_type`
    // selecting `variant_index`, with an optional already-folded payload.
    //
    // The payload is converted to bits and then to the body's `UInt` type with
    // `cast_to`. A missing payload yields a zero body. For the `UInt` lowering,
    // a `TagAndBody` value is built and then flattened with `cast_to_bits`.
    #[hdl]
    fn handle_enum_literal(
        &mut self,
        unfolded_enum_type: Enum,
        variant_index: usize,
        folded_variant_value: Option>,
    ) -> Result, SimplifyEnumsError> {
        Ok(
            match self.get_or_make_enum_type_state(unfolded_enum_type)? {
                EnumTypeState::TagEnumAndBody(TagAndBody { tag, body }) => Expr::canonical(
                    #[hdl]
                    TagAndBody {
                        tag: EnumLiteral::new_by_index(tag, variant_index, None),
                        body: match folded_variant_value {
                            Some(variant_value) => variant_value.cast_to_bits().cast_to(body),
                            None => body.zero().to_expr(),
                        },
                    },
                ),
                EnumTypeState::TagUIntAndBody(TagAndBody { tag, body }) => Expr::canonical(
                    #[hdl]
                    TagAndBody {
                        tag: tag.from_int_wrapping(variant_index),
                        body: match folded_variant_value {
                            Some(folded_variant_value) => {
                                folded_variant_value.cast_to_bits().cast_to(body)
                            }
                            None => body.zero().to_expr(),
                        },
                    },
                ),
                EnumTypeState::UInt(_) => {
                    // Same tag/body split as the bundle lowering, then flattened to one UInt.
                    let tag = UInt[unfolded_enum_type.discriminant_bit_width()];
                    let body = UInt[unfolded_enum_type.type_properties().bit_width - tag.width()];
                    Expr::canonical(
                        (#[hdl]
                        TagAndBody {
                            tag: tag.from_int_wrapping(variant_index),
                            body: match folded_variant_value {
                                Some(folded_variant_value) => {
                                    folded_variant_value.cast_to_bits().cast_to(body)
                                }
                                None => body.zero().to_expr(),
                            },
                        })
                        .cast_to_bits(),
                    )
                }
                EnumTypeState::Unchanged => Expr::canonical(
                    ops::EnumLiteral::new_by_index(
                        unfolded_enum_type,
                        variant_index,
                        folded_variant_value,
                    )
                    .to_expr(),
                ),
            },
        )
    }

    /// Lowers a variant access on `folded_base_expr` (already folded) of the
    /// original type `unfolded_enum_type`.
    ///
    /// For the bundle and `UInt` lowerings, the low `variant_type.bit_width()`
    /// bits of the body are reinterpreted as the (folded) variant type. For the
    /// `UInt` lowering the body starts at bit `discriminant_bit_width()`. A
    /// variant without a payload yields `()`.
    fn handle_variant_access(
        &mut self,
        unfolded_enum_type: Enum,
        folded_base_expr: Expr,
        variant_index: usize,
    ) -> Result, SimplifyEnumsError> {
        let unfolded_variant_type = unfolded_enum_type.variants()[variant_index].ty;
        Ok(
            match self.get_or_make_enum_type_state(unfolded_enum_type)? {
                EnumTypeState::TagEnumAndBody(_) | EnumTypeState::TagUIntAndBody(_) => {
                    match unfolded_variant_type {
                        Some(variant_type) => Expr::canonical(
                            Expr::>::from_canonical(
                                folded_base_expr,
                            )
                            .body[..variant_type.bit_width()]
                            .cast_bits_to(variant_type.fold(self)?),
                        ),
                        None => Expr::canonical(().to_expr()),
                    }
                }
                EnumTypeState::UInt(_) => match unfolded_variant_type {
                    Some(variant_type) => {
                        let base_int = Expr::::from_canonical(folded_base_expr);
                        let variant_type_bit_width = variant_type.bit_width();
                        // Skip the tag bits, then take just the payload width.
                        Expr::canonical(
                            base_int[unfolded_enum_type.discriminant_bit_width()..]
                                [..variant_type_bit_width]
                                .cast_bits_to(variant_type.fold(self)?),
                        )
                    }
                    None => Expr::canonical(().to_expr()),
                },
                EnumTypeState::Unchanged => match unfolded_variant_type {
                    Some(_) => ops::VariantAccess::new_by_index(
                        Expr::from_canonical(folded_base_expr),
                        variant_index,
                    )
                    .to_expr(),
                    None => Expr::canonical(().to_expr()),
                },
            },
        )
    }

    /// Lowers a `match` over a value of the original type `unfolded_enum_type`.
    ///
    /// `folded_blocks` must be in variant order: block `i` runs for variant `i`.
    ///
    /// - `TagEnumAndBody`: `match` on the payload-free `tag` field.
    /// - `TagUIntAndBody` / `UInt`: an `if`/`else` chain on the integer tag
    ///   (see `match_int_tag`).
    /// - `Unchanged`: the original `match`.
    fn handle_match(
        &mut self,
        unfolded_enum_type: Enum,
        folded_expr: Expr,
        source_location: SourceLocation,
        folded_blocks: &[Block],
    ) -> Result {
        match self.get_or_make_enum_type_state(unfolded_enum_type)? {
            EnumTypeState::TagEnumAndBody(_) => Ok(StmtMatch {
                expr: Expr::>::from_canonical(folded_expr).tag,
                source_location,
                blocks: folded_blocks.intern(),
            }
            .into()),
            EnumTypeState::TagUIntAndBody(_) => {
                let int_tag_expr = Expr::>::from_canonical(folded_expr).tag;
                Ok(match_int_tag(int_tag_expr, source_location, folded_blocks).into())
            }
            EnumTypeState::UInt(_) => {
                // The tag occupies the low discriminant bits of the flattened UInt.
                let int_tag_expr = Expr::::from_canonical(folded_expr)
                    [..unfolded_enum_type.discriminant_bit_width()];
                Ok(match_int_tag(int_tag_expr, source_location, folded_blocks).into())
            }
            EnumTypeState::Unchanged => Ok(StmtMatch {
                expr: Expr::from_canonical(folded_expr),
                source_location,
                blocks: folded_blocks.intern(),
            }
            .into()),
        }
    }

    /// Connects two arrays of equal length element by element.
    ///
    /// Panics if the lengths differ.
    fn handle_stmt_connect_array(
        &mut self,
        unfolded_lhs_ty: Array,
        unfolded_rhs_ty: Array,
        folded_lhs: Expr,
        folded_rhs: Expr,
        source_location: SourceLocation,
        output_stmts: &mut Vec,
    ) -> Result<(), SimplifyEnumsError> {
        assert_eq!(unfolded_lhs_ty.len(), unfolded_rhs_ty.len());
        let unfolded_lhs_element_ty = unfolded_lhs_ty.element();
        let unfolded_rhs_element_ty = unfolded_rhs_ty.element();
        for array_index in 0..unfolded_lhs_ty.len() {
            self.handle_stmt_connect(
                unfolded_lhs_element_ty,
                unfolded_rhs_element_ty,
                folded_lhs[array_index],
                folded_rhs[array_index],
                source_location,
                output_stmts,
            )?;
        }
        Ok(())
    }

    /// Connects two `TraceAsString` values by connecting their inner values.
    fn handle_stmt_connect_trace_as_string(
        &mut self,
        unfolded_lhs_ty: TraceAsString,
        unfolded_rhs_ty: TraceAsString,
        folded_lhs: Expr,
        folded_rhs: Expr,
        source_location: SourceLocation,
        output_stmts: &mut Vec,
    ) -> Result<(), SimplifyEnumsError> {
        self.handle_stmt_connect(
            unfolded_lhs_ty.inner_ty(),
            unfolded_rhs_ty.inner_ty(),
            ops::TraceAsStringAsInner::new(folded_lhs).to_expr(),
            ops::TraceAsStringAsInner::new(folded_rhs).to_expr(),
            source_location,
            output_stmts,
        )
    }

    /// Connects two bundles field by field.
    ///
    /// Flipped fields are connected in the opposite direction. Panics if the
    /// field count, names, or flipped-ness differ.
    fn handle_stmt_connect_bundle(
        &mut self,
        unfolded_lhs_ty: Bundle,
        unfolded_rhs_ty: Bundle,
        folded_lhs: Expr,
        folded_rhs: Expr,
        source_location: SourceLocation,
        output_stmts: &mut Vec,
    ) -> Result<(), SimplifyEnumsError> {
        let unfolded_lhs_fields = unfolded_lhs_ty.fields();
        let unfolded_rhs_fields = unfolded_rhs_ty.fields();
        assert_eq!(unfolded_lhs_fields.len(), unfolded_rhs_fields.len());
        for (
            field_index,
            (
                &BundleField {
                    name,
                    flipped,
                    ty: unfolded_lhs_field_ty,
                },
                unfolded_rhs_field,
            ),
        ) in unfolded_lhs_fields
            .iter()
            .zip(&unfolded_rhs_fields)
            .enumerate()
        {
            assert_eq!(name, unfolded_rhs_field.name);
            assert_eq!(flipped, unfolded_rhs_field.flipped);
            let folded_lhs_field =
                ops::FieldAccess::new_by_index(folded_lhs, field_index).to_expr();
            let folded_rhs_field =
                ops::FieldAccess::new_by_index(folded_rhs, field_index).to_expr();
            if flipped {
                // swap lhs/rhs
                self.handle_stmt_connect(
                    unfolded_rhs_field.ty,
                    unfolded_lhs_field_ty,
                    folded_rhs_field,
                    folded_lhs_field,
                    source_location,
                    output_stmts,
                )?;
            } else {
                self.handle_stmt_connect(
                    unfolded_lhs_field_ty,
                    unfolded_rhs_field.ty,
                    folded_lhs_field,
                    folded_rhs_field,
                    source_location,
                    output_stmts,
                )?;
            }
        }
        Ok(())
    }

    /// Connects two *different* enum types whose variants line up by index.
    ///
    /// The variants must match in name and in payload presence. The generated
    /// code is a `match` on the rhs with one block per variant `i`:
    ///
    /// 1. If the variant has a payload, declare a `__connect_variant_body` wire
    ///    of the folded lhs payload type.
    /// 2. Connect the rhs payload into that wire (recursively).
    /// 3. Connect lhs to the lhs enum's literal for variant `i` built from the wire.
    fn handle_stmt_connect_enum(
        &mut self,
        unfolded_lhs_ty: Enum,
        unfolded_rhs_ty: Enum,
        folded_lhs: Expr,
        folded_rhs: Expr,
        source_location: SourceLocation,
        output_stmts: &mut Vec,
    ) -> Result<(), SimplifyEnumsError> {
        let unfolded_lhs_variants = unfolded_lhs_ty.variants();
        let unfolded_rhs_variants = unfolded_rhs_ty.variants();
        assert_eq!(unfolded_lhs_variants.len(), unfolded_rhs_variants.len());
        let mut folded_blocks = vec![];
        for (
            variant_index,
            (
                &EnumVariant {
                    name,
                    ty: unfolded_lhs_variant_ty,
                },
                unfolded_rhs_variant,
            ),
        ) in unfolded_lhs_variants
            .iter()
            .zip(&unfolded_rhs_variants)
            .enumerate()
        {
            // Per-variant statements; deliberately shadows the outer `output_stmts`.
            let mut output_stmts = vec![];
            assert_eq!(name, unfolded_rhs_variant.name);
            assert_eq!(
                unfolded_lhs_variant_ty.is_some(),
                unfolded_rhs_variant.ty.is_some()
            );
            let folded_variant_value =
                if let (Some(unfolded_lhs_variant_ty), Some(unfolded_rhs_variant_ty)) =
                    (unfolded_lhs_variant_ty, unfolded_rhs_variant.ty)
                {
                    let lhs_wire = Wire::new_unchecked(
                        self.module_state_stack
                            .last_mut()
                            .unwrap()
                            .gen_name("__connect_variant_body"),
                        source_location,
                        unfolded_lhs_variant_ty.fold(self)?,
                    );
                    output_stmts.push(
                        StmtWire {
                            annotations: Interned::default(),
                            wire: lhs_wire,
                        }
                        .into(),
                    );
                    let lhs_wire = lhs_wire.to_expr();
                    let folded_rhs_variant =
                        self.handle_variant_access(unfolded_rhs_ty, folded_rhs, variant_index)?;
                    self.handle_stmt_connect(
                        unfolded_lhs_variant_ty,
                        unfolded_rhs_variant_ty,
                        lhs_wire,
                        folded_rhs_variant,
                        source_location,
                        &mut output_stmts,
                    )?;
                    Some(lhs_wire)
                } else {
                    None
                };
            output_stmts.push(
                StmtConnect {
                    lhs: folded_lhs,
                    rhs: self.handle_enum_literal(
                        unfolded_lhs_ty,
                        variant_index,
                        folded_variant_value,
                    )?,
                    source_location,
                }
                .into(),
            );
            folded_blocks.push(Block {
                memories: Interned::default(),
                stmts: Intern::intern_owned(output_stmts),
            });
        }
        // Here `output_stmts` is the caller's buffer again.
        output_stmts.push(self.handle_match(
            unfolded_rhs_ty,
            folded_rhs,
            source_location,
            &folded_blocks,
        )?);
        Ok(())
    }

    /// Appends statements implementing `folded_lhs <= folded_rhs` to `output_stmts`.
    ///
    /// A single connect is emitted unless the two unfolded types differ *and*
    /// either side contains an enum. In that case the connect is expanded
    /// structurally, because the two enums may be lowered differently.
    fn handle_stmt_connect(
        &mut self,
        unfolded_lhs_ty: CanonicalType,
        unfolded_rhs_ty: CanonicalType,
        folded_lhs: Expr,
        folded_rhs: Expr,
        source_location: SourceLocation,
        output_stmts: &mut Vec,
    ) -> Result<(), SimplifyEnumsError> {
        let needs_expansion = unfolded_lhs_ty != unfolded_rhs_ty
            && (contains_any_enum_types(unfolded_lhs_ty)
                || contains_any_enum_types(unfolded_rhs_ty));
        if !needs_expansion {
            output_stmts.push(
                StmtConnect {
                    lhs: folded_lhs,
                    rhs: folded_rhs,
                    source_location,
                }
                .into(),
            );
            return Ok(());
        }
        match unfolded_lhs_ty {
            CanonicalType::Array(unfolded_lhs_ty) => self.handle_stmt_connect_array(
                unfolded_lhs_ty,
                Array::from_canonical(unfolded_rhs_ty),
                Expr::from_canonical(folded_lhs),
                Expr::from_canonical(folded_rhs),
                source_location,
                output_stmts,
            ),
            CanonicalType::Enum(unfolded_lhs_ty) => self.handle_stmt_connect_enum(
                unfolded_lhs_ty,
                Enum::from_canonical(unfolded_rhs_ty),
                folded_lhs,
                folded_rhs,
                source_location,
                output_stmts,
            ),
            CanonicalType::Bundle(unfolded_lhs_ty) => self.handle_stmt_connect_bundle(
                unfolded_lhs_ty,
                Bundle::from_canonical(unfolded_rhs_ty),
                Expr::from_canonical(folded_lhs),
                Expr::from_canonical(folded_rhs),
                source_location,
                output_stmts,
            ),
            CanonicalType::TraceAsString(unfolded_lhs_ty) => self
                .handle_stmt_connect_trace_as_string(
                    unfolded_lhs_ty,
                    TraceAsString::from_canonical(unfolded_rhs_ty),
                    Expr::from_canonical(folded_lhs),
                    Expr::from_canonical(folded_rhs),
                    source_location,
                    output_stmts,
                ),
            // These types contain no enums, so `needs_expansion` was false above.
            CanonicalType::UInt(_)
            | CanonicalType::SInt(_)
            | CanonicalType::Bool(_)
            | CanonicalType::AsyncReset(_)
            | CanonicalType::SyncReset(_)
            | CanonicalType::Reset(_)
            | CanonicalType::Clock(_)
            | CanonicalType::PhantomConst(_)
            | CanonicalType::DynSimOnly(_) => unreachable!(),
        }
    }

    /// Lowers structural equality of two values of the original enum type `unfolded_ty`.
    ///
    /// If `flags.assume_padding_is_zeroed` is set, or the enum is `Unchanged`, a
    /// plain `StructuralEq` on the folded values is returned. Otherwise a
    /// `Bool` wire `__enum_structural_eq` is returned and the current block is
    /// extended:
    ///
    /// - prefix: declare the wire;
    /// - suffix: `wire = false`, then
    ///   `if tags_eq { match lhs { variant_i => wire = <payloads equal, or true> } }`.
    ///
    /// Each match arm is folded in its own `BlockScope`. Statements generated
    /// while comparing that arm's payloads therefore stay inside the arm.
    fn handle_enum_structural_eq(
        &mut self,
        unfolded_ty: Enum,
        folded_lhs: Expr,
        folded_rhs: Expr,
        flags: StructuralEqFlags,
    ) -> Result, SimplifyEnumsError> {
        if flags.assume_padding_is_zeroed {
            return Ok(StructuralEq::with_flags(folded_lhs, folded_rhs, flags).to_expr());
        }
        let enum_type_state = self.get_or_make_enum_type_state(unfolded_ty)?;
        if let EnumTypeState::Unchanged = enum_type_state {
            return Ok(StructuralEq::with_flags(folded_lhs, folded_rhs, flags).to_expr());
        }
        // NOTE(review): the generated wire's declaration and connects are attached
        // to the *current* block, but `fold_expr_enum` caches the returned wire in
        // the module-wide `ModuleState::expr_cache`. If the same (interned)
        // `StructuralEq` expression is folded again inside a different block, the
        // cache hands back this wire even though it is declared and driven only
        // in the first block. Confirm whether that can happen; if so, the cache
        // likely needs to be scoped per block.
        let module_state = self.module_state_stack.last_mut().unwrap();
        // Uses the module's source location, not the expression's.
        let source_location = module_state.source_location;
        let output_wire = Wire::new_unchecked(
            module_state.gen_name("__enum_structural_eq"),
            source_location,
            Bool,
        );
        self.new_prefix_stmts_for_block.push(
            StmtWire {
                annotations: Interned::default(),
                wire: output_wire.canonical(),
            }
            .into(),
        );
        let output_wire = output_wire.to_expr();
        // Default to "not equal"; the `if`/`match` pushed below overrides it when
        // the tags match.
        self.new_suffix_stmts_for_block.push(
            StmtConnect {
                lhs: Expr::canonical(output_wire),
                rhs: Expr::canonical(false.to_expr()),
                source_location,
            }
            .into(),
        );
        let tags_eq = match enum_type_state {
            // The tag is a payload-free enum, so it has no padding to worry about.
            EnumTypeState::TagEnumAndBody(_) => StructuralEq::with_flags(
                Expr::canonical(Expr::>::from_canonical(folded_lhs).tag),
                Expr::canonical(Expr::>::from_canonical(folded_rhs).tag),
                StructuralEqFlags {
                    assume_padding_is_zeroed: true,
                },
            )
            .to_expr(),
            EnumTypeState::TagUIntAndBody(_) => {
                let lhs = Expr::>::from_canonical(folded_lhs).tag;
                let rhs = Expr::>::from_canonical(folded_rhs).tag;
                lhs.cmp_eq(rhs)
            }
            EnumTypeState::UInt(_) => {
                let lhs_int_tag_expr = Expr::::from_canonical(folded_lhs)
                    [..unfolded_ty.discriminant_bit_width()];
                let rhs_int_tag_expr = Expr::::from_canonical(folded_rhs)
                    [..unfolded_ty.discriminant_bit_width()];
                lhs_int_tag_expr.cmp_eq(rhs_int_tag_expr)
            }
            EnumTypeState::Unchanged => unreachable!(),
        };
        let mut match_arms = Vec::with_capacity(unfolded_ty.variants().len());
        for (variant_index, variant) in unfolded_ty.variants().iter().enumerate() {
            // Fresh prefix/suffix buffers for this arm; restored when the scope drops.
            let block_scope = BlockScope::new(self, vec![], vec![]);
            let this = &mut *block_scope.state;
            let eq = if let Some(variant_ty) = variant.ty {
                let folded_lhs =
                    this.handle_variant_access(unfolded_ty, folded_lhs, variant_index)?;
                let folded_rhs =
                    this.handle_variant_access(unfolded_ty, folded_rhs, variant_index)?;
                this.handle_structural_eq(variant_ty, folded_lhs, folded_rhs, flags)?
            } else {
                // Payload-free variants are equal once the tags match.
                true.to_expr()
            };
            // Arm body: arm prefix, `wire = eq`, arm suffix.
            match_arms.push(Block {
                memories: [].intern_slice(),
                stmts: this
                    .new_prefix_stmts_for_block
                    .drain(..)
                    .chain([StmtConnect {
                        lhs: Expr::canonical(output_wire),
                        rhs: Expr::canonical(eq),
                        source_location,
                    }
                    .into()])
                    .chain(this.new_suffix_stmts_for_block.drain(..))
                    .collect(),
            });
        }
        let match_stmt = self.handle_match(unfolded_ty, folded_lhs, source_location, &match_arms)?;
        self.new_suffix_stmts_for_block.push(
            StmtIf {
                cond: tags_eq,
                source_location,
                blocks: [
                    Block {
                        memories: [].intern_slice(),
                        stmts: [match_stmt].intern_slice(),
                    },
                    Block {
                        memories: [].intern_slice(),
                        stmts: [].intern_slice(),
                    },
                ],
            }
            .into(),
        );
        Ok(output_wire)
    }

    /// Lowers structural equality of two values of the original type `unfolded_ty`.
    ///
    /// - Types without enums use a plain `StructuralEq`.
    /// - Arrays and bundles AND together the per-element / per-field results.
    ///   Empty ones fall back to a plain `StructuralEq`.
    /// - `TraceAsString` compares its inner values.
    /// - Enums are handled by `handle_enum_structural_eq`.
    fn handle_structural_eq(
        &mut self,
        unfolded_ty: CanonicalType,
        folded_lhs: Expr,
        folded_rhs: Expr,
        flags: StructuralEqFlags,
    ) -> Result, SimplifyEnumsError> {
        if !contains_any_enum_types(unfolded_ty) {
            return Ok(StructuralEq::with_flags(folded_lhs, folded_rhs, flags).to_expr());
        }
        match unfolded_ty {
            CanonicalType::Array(unfolded_ty) => {
                let unfolded_element_ty = unfolded_ty.element();
                let mut retval = None;
                for i in 0..unfolded_ty.len() {
                    let element_eq = self.handle_structural_eq(
                        unfolded_element_ty,
                        ops::ArrayIndex::new(Expr::from_canonical(folded_lhs), i).to_expr(),
                        ops::ArrayIndex::new(Expr::from_canonical(folded_rhs), i).to_expr(),
                        flags,
                    )?;
                    retval = Some(match retval {
                        Some(old_eq) => old_eq & element_eq,
                        None => element_eq,
                    });
                }
                Ok(retval.unwrap_or_else(|| {
                    StructuralEq::with_flags(folded_lhs, folded_rhs, flags).to_expr()
                }))
            }
            CanonicalType::Enum(unfolded_ty) => {
                self.handle_enum_structural_eq(unfolded_ty, folded_lhs, folded_rhs, flags)
            }
            CanonicalType::Bundle(unfolded_ty) => {
                let mut retval = None;
                for (i, field) in unfolded_ty.fields().iter().enumerate() {
                    let field_eq = self.handle_structural_eq(
                        field.ty,
                        ops::FieldAccess::new_by_index(Expr::from_canonical(folded_lhs), i)
                            .to_expr(),
                        ops::FieldAccess::new_by_index(Expr::from_canonical(folded_rhs), i)
                            .to_expr(),
                        flags,
                    )?;
                    retval = Some(match retval {
                        Some(old_eq) => old_eq & field_eq,
                        None => field_eq,
                    });
                }
                Ok(retval.unwrap_or_else(|| {
                    StructuralEq::with_flags(folded_lhs, folded_rhs, flags).to_expr()
                }))
            }
            CanonicalType::TraceAsString(unfolded_ty) => self.handle_structural_eq(
                unfolded_ty.inner_ty(),
                *Expr::::from_canonical(folded_lhs),
                *Expr::::from_canonical(folded_rhs),
                flags,
            ),
            CanonicalType::UInt(_)
            | CanonicalType::SInt(_)
            | CanonicalType::Bool(_)
            | CanonicalType::AsyncReset(_)
            | CanonicalType::SyncReset(_)
            | CanonicalType::Reset(_)
            | CanonicalType::Clock(_)
            | CanonicalType::PhantomConst(_)
            | CanonicalType::DynSimOnly(_) => unreachable!("doesn't contain any enum types"),
        }
    }
}

/// Appends connects between two memory-port-shaped expressions whose types may differ.
///
/// `fold_block` calls it with the rebuilt port on the left and its replacement
/// wire on the right. Transparent types are unwrapped first.
///
/// - Identical types get a single connect.
/// - A bundle paired with a `UInt`/`Bool` fans the scalar out to every field.
///   Fields must not be flipped. Presumably this is for write-mask fields,
///   whose type `fold_block` keeps from the original port.
/// - Bundle/bundle recurses field by field (by name), swapping direction for
///   flipped fields.
/// - Array/array recurses element-wise.
/// - Any other combination panics.
fn connect_port(
    stmts: &mut Vec,
    lhs: Expr,
    rhs: Expr,
    source_location: SourceLocation,
) {
    let lhs = Expr::unwrap_transparent_types(lhs);
    let rhs = Expr::unwrap_transparent_types(rhs);
    if lhs.ty() == rhs.ty() {
        stmts.push(
            StmtConnect {
                lhs,
                rhs,
                source_location,
            }
            .into(),
        );
        return;
    }
    match (lhs.ty(), rhs.ty()) {
        (CanonicalType::Bundle(lhs_type), CanonicalType::UInt(_) | CanonicalType::Bool(_)) => {
            let lhs = Expr::::from_canonical(lhs);
            for field in lhs_type.fields() {
                assert!(!field.flipped);
                connect_port(stmts, Expr::field(lhs, &field.name), rhs, source_location);
            }
        }
        (CanonicalType::UInt(_) | CanonicalType::Bool(_), CanonicalType::Bundle(rhs_type)) => {
            let rhs = Expr::::from_canonical(rhs);
            for field in rhs_type.fields() {
                assert!(!field.flipped);
                connect_port(stmts, lhs, Expr::field(rhs, &field.name), source_location);
            }
        }
        (CanonicalType::Bundle(lhs_type), CanonicalType::Bundle(_)) => {
            let lhs = Expr::::from_canonical(lhs);
            let rhs = Expr::::from_canonical(rhs);
            for field in lhs_type.fields() {
                let (lhs_field, rhs_field) = if field.flipped {
                    (Expr::field(rhs, &field.name), Expr::field(lhs, &field.name))
                } else {
                    (Expr::field(lhs, &field.name), Expr::field(rhs, &field.name))
                };
                connect_port(stmts, lhs_field, rhs_field, source_location);
            }
        }
        (CanonicalType::Array(lhs_type), CanonicalType::Array(_)) => {
            let lhs = Expr::::from_canonical(lhs);
            let rhs = Expr::::from_canonical(rhs);
            for index in 0..lhs_type.len() {
                connect_port(stmts, lhs[index], rhs[index], source_location);
            }
        }
        (CanonicalType::TraceAsString(_), CanonicalType::TraceAsString(_)) => {
            unreachable!("handled by unwrap_transparent_types")
        }
        (CanonicalType::Bundle(_), _)
        | (CanonicalType::Enum(_), _)
        | (CanonicalType::Array(_), _)
        | (CanonicalType::UInt(_), _)
        | (CanonicalType::SInt(_), _)
        | (CanonicalType::Bool(_), _)
        | (CanonicalType::Clock(_), _)
        | (CanonicalType::AsyncReset(_), _)
        | (CanonicalType::SyncReset(_), _)
        | (CanonicalType::Reset(_), _)
        | (CanonicalType::PhantomConst(_), _)
        | (CanonicalType::DynSimOnly(_), _)
        | (CanonicalType::TraceAsString(_), _) => unreachable!(
            "trying to connect memory ports:\n{:?}\n{:?}",
            lhs.ty(),
            rhs.ty(),
        ),
    }
}

/// Replaces a `match` on an integer tag with an `if`/`else` chain.
///
/// The generated shape is
/// `if tag == 0 { b0 } else if tag == 1 { b1 } ... else { b_last }`, where
/// block `i` in `folded_blocks` belongs to variant index `i`.
///
/// The last block is the unconditional fallback, so it also runs for tag values
/// that match no earlier index. With one block the result is
/// `if true { b0 } else {}`. With no blocks, both branches are `Block::default()`.
fn match_int_tag(
    int_tag_expr: Expr,
    source_location: SourceLocation,
    folded_blocks: &[Block],
) -> StmtIf {
    let mut blocks_iter = folded_blocks.iter().copied().enumerate();
    let (_, last_block) = blocks_iter.next_back().unwrap_or_default();
    let Some((next_to_last_variant_index, next_to_last_block)) = blocks_iter.next_back() else {
        return StmtIf {
            cond: true.to_expr(),
            source_location,
            blocks: [last_block, Block::default()],
        };
    };
    // Innermost `if` first; earlier variants are wrapped around it below.
    let mut retval = StmtIf {
        cond: int_tag_expr.cmp_eq(
            int_tag_expr
                .ty()
                .from_int_wrapping(next_to_last_variant_index),
        ),
        source_location,
        blocks: [next_to_last_block, last_block],
    };
    for (variant_index, block) in blocks_iter.rev() {
        retval = StmtIf {
            cond: int_tag_expr.cmp_eq(int_tag_expr.ty().from_int_wrapping(variant_index)),
            source_location,
            blocks: [
                block,
                Block {
                    memories: Default::default(),
                    stmts: [Stmt::from(retval)].intern_slice(),
                },
            ],
        };
    }
    retval
}

impl Folder for State {
    type Error = SimplifyEnumsError;

    /// Never reached: enum types are intercepted in `fold_canonical_type`.
    fn fold_enum(&mut self, _v: Enum) -> Result {
        unreachable!()
    }

    /// Folds a module with its own `ModuleState` on the stack.
    ///
    /// The state is popped even when folding fails.
    fn fold_module(&mut self, v: Module) -> Result, Self::Error> {
        self.module_state_stack.push(ModuleState {
            module_name: v.name_id(),
            expr_cache: HashMap::default(),
            source_location: v.source_location(),
        });
        let retval = Fold::default_fold(v, self);
        self.module_state_stack.pop();
        retval
    }

    /// Folds an expression node, memoizing the result in the current module's
    /// `expr_cache`.
    ///
    /// - Enum literals, variant accesses, and structural equality are lowered.
    /// - Memory ports that were replaced in `fold_block` become their
    ///   replacement wire.
    /// - Everything else is folded by default.
    fn fold_expr_enum(&mut self, op: ExprEnum) -> Result {
        if let Some(folded_op) = self
            .module_state_stack
            .last()
            .expect("known to be in module")
            .expr_cache
            .get(&op)
        {
            return Ok(*folded_op);
        }
        let folded_op = match op {
            ExprEnum::EnumLiteral(op) => {
                let folded_variant_value = op.variant_value().map(|v| v.fold(self)).transpose()?;
                *Expr::expr_enum(self.handle_enum_literal(
                    op.ty(),
                    op.variant_index(),
                    folded_variant_value,
                )?)
            }
            ExprEnum::VariantAccess(op) => {
                let folded_base_expr = Expr::canonical(op.base()).fold(self)?;
                *Expr::expr_enum(self.handle_variant_access(
                    op.base().ty(),
                    folded_base_expr,
                    op.variant_index(),
                )?)
            }
            ExprEnum::StructuralEq(op) => {
                let ty = op.lhs().ty();
                assert_eq!(ty, op.rhs().ty());
                let folded_lhs = Expr::canonical(op.lhs()).fold(self)?;
                let folded_rhs = Expr::canonical(op.rhs()).fold(self)?;
                *Expr::expr_enum(self.handle_structural_eq(
                    ty,
                    folded_lhs,
                    folded_rhs,
                    op.flags(),
                )?)
            }
            ExprEnum::MemPort(mem_port) => {
                if let Some(&wire) = self.replacement_mem_ports.get(&mem_port) {
                    ExprEnum::Wire(wire)
                } else {
                    ExprEnum::MemPort(mem_port.fold(self)?)
                }
            }
            ExprEnum::UIntLiteral(_)
            | ExprEnum::SIntLiteral(_)
            | ExprEnum::BoolLiteral(_)
            | ExprEnum::PhantomConst(_)
            | ExprEnum::BundleLiteral(_)
            | ExprEnum::ArrayLiteral(_)
            | ExprEnum::Uninit(_)
            | ExprEnum::NotU(_)
            | ExprEnum::NotS(_)
            | ExprEnum::NotB(_)
            | ExprEnum::Neg(_)
            | ExprEnum::BitAndU(_)
            | ExprEnum::BitAndS(_)
            | ExprEnum::BitAndB(_)
            | ExprEnum::BitOrU(_)
            | ExprEnum::BitOrS(_)
            | ExprEnum::BitOrB(_)
            | ExprEnum::BitXorU(_)
            | ExprEnum::BitXorS(_)
            | ExprEnum::BitXorB(_)
            | ExprEnum::AddU(_)
            | ExprEnum::AddS(_)
            | ExprEnum::SubU(_)
            | ExprEnum::SubS(_)
            | ExprEnum::MulU(_)
            | ExprEnum::MulS(_)
            | ExprEnum::DivU(_)
            | ExprEnum::DivS(_)
            | ExprEnum::RemU(_)
            | ExprEnum::RemS(_)
            | ExprEnum::DynShlU(_)
            | ExprEnum::DynShlS(_)
            | ExprEnum::DynShrU(_)
            | ExprEnum::DynShrS(_)
            | ExprEnum::FixedShlU(_)
            | ExprEnum::FixedShlS(_)
            | ExprEnum::FixedShrU(_)
            | ExprEnum::FixedShrS(_)
            | ExprEnum::CmpLtB(_)
            | ExprEnum::CmpLeB(_)
            | ExprEnum::CmpGtB(_)
            | ExprEnum::CmpGeB(_)
            | ExprEnum::CmpEqB(_)
            | ExprEnum::CmpNeB(_)
            | ExprEnum::CmpLtU(_)
            | ExprEnum::CmpLeU(_)
            | ExprEnum::CmpGtU(_)
            | ExprEnum::CmpGeU(_)
            | ExprEnum::CmpEqU(_)
            | ExprEnum::CmpNeU(_)
            | ExprEnum::CmpLtS(_)
            | ExprEnum::CmpLeS(_)
            | ExprEnum::CmpGtS(_)
            | ExprEnum::CmpGeS(_)
            | ExprEnum::CmpEqS(_)
            | ExprEnum::CmpNeS(_)
            | ExprEnum::CastUIntToUInt(_)
            | ExprEnum::CastUIntToSInt(_)
            | ExprEnum::CastSIntToUInt(_)
            | ExprEnum::CastSIntToSInt(_)
            | ExprEnum::CastBoolToUInt(_)
            | ExprEnum::CastBoolToSInt(_)
            | ExprEnum::CastUIntToBool(_)
            | ExprEnum::CastSIntToBool(_)
            | ExprEnum::CastBoolToSyncReset(_)
            | ExprEnum::CastUIntToSyncReset(_)
            | ExprEnum::CastSIntToSyncReset(_)
            | ExprEnum::CastBoolToAsyncReset(_)
            | ExprEnum::CastUIntToAsyncReset(_)
            | ExprEnum::CastSIntToAsyncReset(_)
            | ExprEnum::CastSyncResetToBool(_)
            | ExprEnum::CastSyncResetToUInt(_)
            | ExprEnum::CastSyncResetToSInt(_)
            | ExprEnum::CastSyncResetToReset(_)
            | ExprEnum::CastAsyncResetToBool(_)
            | ExprEnum::CastAsyncResetToUInt(_)
            | ExprEnum::CastAsyncResetToSInt(_)
            | ExprEnum::CastAsyncResetToReset(_)
            | ExprEnum::CastResetToBool(_)
            | ExprEnum::CastResetToUInt(_)
            | ExprEnum::CastResetToSInt(_)
            | ExprEnum::CastBoolToClock(_)
            | ExprEnum::CastUIntToClock(_)
            | ExprEnum::CastSIntToClock(_)
            | ExprEnum::CastClockToBool(_)
            | ExprEnum::CastClockToUInt(_)
            | ExprEnum::CastClockToSInt(_)
            | ExprEnum::FieldAccess(_)
            | ExprEnum::ArrayIndex(_)
            | ExprEnum::DynArrayIndex(_)
            | ExprEnum::ReduceBitAndU(_)
            | ExprEnum::ReduceBitAndS(_)
            | ExprEnum::ReduceBitOrU(_)
            | ExprEnum::ReduceBitOrS(_)
            | ExprEnum::ReduceBitXorU(_)
            | ExprEnum::ReduceBitXorS(_)
            | ExprEnum::SliceUInt(_)
            | ExprEnum::SliceSInt(_)
            | ExprEnum::CastToBits(_)
            | ExprEnum::CastBitsTo(_)
            | ExprEnum::TraceAsStringAsInner(_)
            | ExprEnum::ToTraceAsString(_)
            | ExprEnum::ModuleIO(_)
            | ExprEnum::Instance(_)
            | ExprEnum::Wire(_)
            | ExprEnum::Reg(_)
            | ExprEnum::RegSync(_)
            | ExprEnum::RegAsync(_)
            | ExprEnum::FormalInput(_)
            | ExprEnum::SimIoForGlobal(_) => op.default_fold(self)?,
        };
        self.module_state_stack
            .last_mut()
            .expect("known to be in module")
            .expr_cache
            .insert(op, folded_op);
        Ok(folded_op)
    }

    /// Folds a block inside its own `BlockScope`.
    ///
    /// For each memory whose element type changes, the memory is rebuilt with
    /// new ports of the folded element type. For each such port whose bundle
    /// type (with the write mask's original type restored) differs from the new
    /// port type:
    ///
    /// - a wire named `<memory>_<port>` is declared;
    /// - the wire is connected to the new port via `connect_port`;
    /// - the wire is recorded in `replacement_mem_ports`, so later uses of the
    ///   original port become the wire.
    ///
    /// The block's statements are then folded. Statements collected for this
    /// block's prefix/suffix are placed at the very start and end.
    fn fold_block(&mut self, block: Block) -> Result {
        let block_scope = BlockScope::new(self, vec![], vec![]);
        let this = &mut *block_scope.state;
        let mut memories = vec![];
        let mut stmts = vec![];
        for memory in block.memories {
            let old_element_ty = memory.array_type().element();
            let new_element_ty = memory.array_type().element().fold(this)?;
            if new_element_ty != old_element_ty {
                let mut new_ports = vec![];
                for port in memory.ports() {
                    let new_port = MemPort::::new_unchecked(
                        port.mem_name(),
                        port.source_location(),
                        port.port_name(),
                        port.addr_type(),
                        new_element_ty,
                    );
                    new_ports.push(new_port);
                    let new_port_ty = new_port.ty();
                    let mut wire_ty_fields = Vec::from_iter(new_port_ty.fields());
                    // The replacement wire keeps the original port's write-mask type.
                    if let Some(wmask_name) = new_port.port_kind().wmask_name() {
                        let index = *new_port_ty
                            .name_indexes()
                            .get(&wmask_name.intern())
                            .unwrap();
                        wire_ty_fields[index].ty = port.ty().fields()[index].ty;
                    }
                    let wire_ty = Bundle::new(Intern::intern_owned(wire_ty_fields));
                    // No adapter wire needed if the types already agree.
                    if wire_ty == new_port_ty {
                        continue;
                    }
                    let wire = Wire::new_unchecked(
                        this.module_state_stack
                            .last_mut()
                            .unwrap()
                            .gen_name(&format!(
                                "{}_{}",
                                memory.scoped_name().1.0,
                                port.port_name()
                            )),
                        port.source_location(),
                        wire_ty,
                    );
                    stmts.push(
                        StmtWire {
                            annotations: Default::default(),
                            wire: wire.canonical(),
                        }
                        .into(),
                    );
                    connect_port(
                        &mut stmts,
                        Expr::canonical(new_port.to_expr()),
                        Expr::canonical(wire.to_expr()),
                        port.source_location(),
                    );
                    this.replacement_mem_ports.insert(port, wire.canonical());
                }
                memories.push(Mem::new_unchecked(
                    memory.scoped_name(),
                    memory.source_location(),
                    ArrayType::new_dyn(new_element_ty, memory.array_type().len()),
                    memory.initial_value(),
                    Intern::intern_owned(new_ports),
                    memory.read_latency(),
                    memory.write_latency(),
                    memory.read_under_write(),
                    memory.port_annotations(),
                    memory.mem_annotations(),
                ));
            } else {
                memories.push(memory.fold(this)?);
            }
        }
        stmts.extend_from_slice(&block.stmts.fold(this)?);
        stmts.splice(0..0, this.new_prefix_stmts_for_block.drain(..));
        stmts.extend_from_slice(&this.new_suffix_stmts_for_block);
        Ok(Block {
            memories: Intern::intern_owned(memories),
            stmts: Intern::intern_owned(stmts),
        })
    }

    /// Folds a statement, lowering `match` and connect statements.
    ///
    /// A connect that expands into several statements is wrapped in
    /// `if true { ... } else {}`, so the result is still a single statement.
    fn fold_stmt(&mut self, stmt: Stmt) -> Result {
        match stmt {
            Stmt::Match(StmtMatch {
                expr,
                source_location,
                blocks,
            }) => {
                let folded_expr = Expr::canonical(expr).fold(self)?;
                let folded_blocks = blocks.fold(self)?;
                self.handle_match(expr.ty(), folded_expr, source_location, &folded_blocks)
            }
            Stmt::Connect(StmtConnect {
                lhs,
                rhs,
                source_location,
            }) => {
                let folded_lhs = lhs.fold(self)?;
                let folded_rhs = rhs.fold(self)?;
                let mut output_stmts = vec![];
                self.handle_stmt_connect(
                    lhs.ty(),
                    rhs.ty(),
                    folded_lhs,
                    folded_rhs,
                    source_location,
                    &mut output_stmts,
                )?;
                if output_stmts.len() == 1 {
                    Ok(output_stmts.pop().unwrap())
                } else {
                    Ok(StmtIf {
                        cond: true.to_expr(),
                        source_location,
                        blocks: [
                            Block {
                                memories: Interned::default(),
                                stmts: Intern::intern_owned(output_stmts),
                            },
                            Block::default(),
                        ],
                    }
                    .into())
                }
            }
            Stmt::Formal(_) | Stmt::If(_) | Stmt::Declaration(_) => stmt.default_fold(self),
        }
    }

    /// Never reached: `match` statements are intercepted in `fold_stmt`.
    fn fold_stmt_match(&mut self, _v: StmtMatch) -> Result {
        unreachable!()
    }

    /// Replaces each enum type with its lowered type (see `EnumTypeState`);
    /// all other types are folded by default.
    fn fold_canonical_type(
        &mut self,
        canonical_type: CanonicalType,
    ) -> Result {
        match canonical_type {
            CanonicalType::Enum(enum_type) => {
                Ok(match self.get_or_make_enum_type_state(enum_type)? {
                    EnumTypeState::TagEnumAndBody(ty) => ty.canonical(),
                    EnumTypeState::TagUIntAndBody(ty) => ty.canonical(),
                    EnumTypeState::UInt(ty) => ty.canonical(),
                    EnumTypeState::Unchanged => enum_type.canonical(),
                })
            }
            CanonicalType::Bundle(_)
            | CanonicalType::Array(_)
            | CanonicalType::UInt(_)
            | CanonicalType::SInt(_)
            | CanonicalType::Bool(_)
            | CanonicalType::Clock(_)
            | CanonicalType::AsyncReset(_)
            | CanonicalType::SyncReset(_)
            | CanonicalType::Reset(_)
            | CanonicalType::PhantomConst(_)
            | CanonicalType::DynSimOnly(_)
            | CanonicalType::TraceAsString(_) => canonical_type.default_fold(self),
        }
    }

    /// Expected to be unreachable: enum types (and so their variants) are
    /// intercepted in `fold_canonical_type`.
    fn fold_enum_variant(&mut self, _v: EnumVariant) -> Result {
        unreachable!()
    }

    /// Never reached: enum literals are intercepted in `fold_expr_enum`.
    fn fold_enum_literal>(
        &mut self,
        _v: ops::EnumLiteral,
    ) -> Result, Self::Error> {
        unreachable!()
    }

    /// Never reached: variant accesses are intercepted in `fold_expr_enum`.
    fn fold_variant_access(
        &mut self,
        _v: ops::VariantAccess,
    ) -> Result, Self::Error> {
        unreachable!()
    }
}

// Lowering strategy for `simplify_enums`.
// (Plain `//` comments are used here so clap's `ValueEnum` help text is unchanged.)
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, clap::ValueEnum, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum SimplifyEnumsKind {
    // Enums with payloads become `{ tag: <payload-free enum>, body: UInt }`;
    // enums without payloads are left unchanged.
    SimplifyToEnumsWithNoBody,
    // Every enum becomes `{ tag: UInt, body: UInt }`.
    #[clap(name = "replace-with-bundle-of-uints")]
    #[serde(rename = "replace-with-bundle-of-uints")]
    ReplaceWithBundleOfUInts,
    // Every enum becomes a single `UInt` (tag in the low bits, body above it).
    #[clap(name = "replace-with-uint")]
    #[serde(rename = "replace-with-uint")]
    ReplaceWithUInt,
}

/// Runs the enum-lowering pass over `module` using the strategy `kind`.
///
/// # Errors
///
/// Returns [`SimplifyEnumsError::EnumIsNotCastableFromBits`] if an enum that
/// needs to be inspected is not castable from bits.
pub fn simplify_enums(
    module: Interned>,
    kind: SimplifyEnumsKind,
) -> Result>, SimplifyEnumsError> {
    // TODO: deduce StructuralEq's assume_padding_is_zeroed
    module.fold(&mut State {
        enum_types: HashMap::default(),
        replacement_mem_ports: HashMap::default(),
        kind,
        module_state_stack: vec![],
        new_prefix_stmts_for_block: vec![],
        new_suffix_stmts_for_block: vec![],
    })
}