From 31353862ceba3d255fd6712813a457688b269358 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 1 Jun 2026 23:10:43 -0700 Subject: [PATCH] fayalite/src/module: check that expressions are visible where they are used, e.g. erroring when a wire is inside an `if` but used outside. --- crates/fayalite/src/module.rs | 287 ++++++++++++++++++++++++++++++++-- 1 file changed, 276 insertions(+), 11 deletions(-) diff --git a/crates/fayalite/src/module.rs b/crates/fayalite/src/module.rs index 816a286..86fdd40 100644 --- a/crates/fayalite/src/module.rs +++ b/crates/fayalite/src/module.rs @@ -8,7 +8,7 @@ use crate::{ clock::{Clock, ClockDomain}, enum_::{Enum, EnumMatchVariantsIter, EnumType}, expr::{ - Expr, Flow, ToExpr, ValueType, + Expr, ExprEnum, Flow, ToExpr, ValueType, ops::VariantAccess, target::{ GetTarget, Target, TargetBase, TargetPathArrayElement, TargetPathBundleField, @@ -20,6 +20,7 @@ use crate::{ int::{Bool, DynSize, Size}, intern::{Intern, Interned}, memory::{Mem, MemBuilder, MemBuilderTarget, PortName}, + module::transform::visit::{Visit, Visitor}, platform::PlatformIOBuilder, reg::Reg, reset::{AsyncReset, Reset, ResetType, ResetTypeDispatch, SyncReset}, @@ -1598,9 +1599,54 @@ impl TargetState { } } +struct VisibleExprsStack { + buf: Vec>, + len: usize, +} + +impl VisibleExprsStack { + fn top(&mut self) -> &mut HashSet { + &mut self.buf[self.len - 1] + } + fn slice(&self) -> &[HashSet] { + &self.buf[..self.len] + } + fn contains(&self, v: &ExprEnum) -> bool { + self.slice().iter().any(|i| i.contains(v)) + } + fn push_empty(&mut self) { + #[cold] + fn push_empty_cold(stack: &mut VisibleExprsStack) { + stack.buf.push(HashSet::default()); + assert_eq!(stack.buf.len(), stack.len) + } + self.len += 1; + if self.len > self.buf.len() { + push_empty_cold(self) + } + } + fn pop(&mut self) { + let Some(new_len) = self.len.checked_sub(1) else { + unreachable!("visible exprs stack underflow"); + }; + self.buf[new_len].clear(); + self.len = new_len; + } +} + +impl Default for VisibleExprsStack { + fn default() -> Self { + Self { + buf: Vec::new(), + len: 0, + } + } +} + struct AssertValidityState { module: Module, blocks: Vec, + visible_exprs: VisibleExprsStack, target_states: HashMap, TargetState>, } @@ -1771,6 +1817,7 @@ impl AssertValidityState { } } } + #[track_caller] fn process_conditional_sub_blocks( &mut self, parent_block: usize, @@ -1784,17 +1831,40 @@ impl AssertValidityState { } } #[track_caller] + fn assert_expr_validity(&mut self, expr: Expr, source_location: SourceLocation) { + let mut visitor = AssertExprValidity { state: self }; + match visitor.visit_expr(&expr) { + Ok(()) => {} + Err(e) => match e { + InvalidExpr::ExprIsNotVisible(expr) => { + if let Some(target) = expr.target() { + panic!( + "at {source_location}: expression isn't visible here, it's defined:\n\ + at {}: {expr:?}", + target.base().source_location(), + ); + } else { + panic!("at {source_location}: expression isn't visible here: {expr:?}"); + } + } + }, + } + } + #[track_caller] fn assert_subtree_validity(&mut self, block: usize) { + self.visible_exprs.push_empty(); let module = self.module; if block == 0 { for module_io in &*module.module_io { self.insert_new_base(TargetBase::intern_sized(module_io.module_io.into()), block); + self.visible_exprs.top().insert(module_io.module_io.into()); } } let Block { memories, stmts } = self.blocks[block]; for m in memories { for port in m.ports() { self.insert_new_base(TargetBase::intern_sized(port.into()), block); + self.visible_exprs.top().insert(port.into()); } } for stmt in stmts { @@ -1808,44 +1878,104 @@ impl AssertValidityState { } = connect; self.set_connect_side_written(lhs, source_location, true, block); self.set_connect_side_written(rhs, source_location, false, block); + self.assert_expr_validity(lhs, source_location); + self.assert_expr_validity(rhs, source_location); + } + Stmt::Formal(formal) => { + let StmtFormal { + kind: _, + clk, + pred, + en, + text: _, + source_location, + } = formal; + self.assert_expr_validity(clk, source_location); + self.assert_expr_validity(pred, source_location); + self.assert_expr_validity(en, source_location); } - Stmt::Formal(_) => {} Stmt::If(if_stmt) => { - let sub_blocks = if_stmt.blocks.map(|block| self.make_block_index(block)); + let StmtIf { + cond, + source_location, + blocks: sub_blocks, + } = if_stmt; + self.assert_expr_validity(cond, source_location); + let sub_blocks = sub_blocks.map(|block| self.make_block_index(block)); self.process_conditional_sub_blocks(block, sub_blocks) } Stmt::Match(match_stmt) => { match_stmt.assert_validity(); + let StmtMatch { + expr, + source_location, + blocks: sub_blocks, + } = match_stmt; + self.assert_expr_validity(expr, source_location); let sub_blocks = Vec::from_iter( - match_stmt - .blocks + sub_blocks .into_iter() .map(|block| self.make_block_index(block)), ); - self.process_conditional_sub_blocks(block, sub_blocks.iter().copied()) + self.visible_exprs.push_empty(); + let visible_exprs_top = self.visible_exprs.top(); + for variant_index in 0..expr.ty().variants().len() { + visible_exprs_top + .insert(::new_by_index(expr, variant_index).into()); + } + self.process_conditional_sub_blocks(block, sub_blocks.iter().copied()); + self.visible_exprs.pop(); } Stmt::Declaration(StmtDeclaration::Wire(StmtWire { annotations: _, wire, - })) => self.insert_new_base(TargetBase::intern_sized(wire.into()), block), + })) => { + self.insert_new_base(TargetBase::intern_sized(wire.into()), block); + self.visible_exprs.top().insert(wire.into()); + } Stmt::Declaration(StmtDeclaration::Reg(StmtReg { annotations: _, reg, - })) => self.insert_new_base(TargetBase::intern_sized(reg.into()), block), + })) => { + self.assert_expr_validity(reg.clock_domain(), reg.source_location()); + if let Some(init) = reg.init() { + self.assert_expr_validity(init, reg.source_location()); + } + self.insert_new_base(TargetBase::intern_sized(reg.into()), block); + self.visible_exprs.top().insert(reg.into()); + } Stmt::Declaration(StmtDeclaration::RegSync(StmtReg { annotations: _, reg, - })) => self.insert_new_base(TargetBase::intern_sized(reg.into()), block), + })) => { + self.assert_expr_validity(reg.clock_domain(), reg.source_location()); + if let Some(init) = reg.init() { + self.assert_expr_validity(init, reg.source_location()); + } + self.insert_new_base(TargetBase::intern_sized(reg.into()), block); + self.visible_exprs.top().insert(reg.into()); + } Stmt::Declaration(StmtDeclaration::RegAsync(StmtReg { annotations: _, reg, - })) => self.insert_new_base(TargetBase::intern_sized(reg.into()), block), + })) => { + self.assert_expr_validity(reg.clock_domain(), reg.source_location()); + if let Some(init) = reg.init() { + self.assert_expr_validity(init, reg.source_location()); + } + self.insert_new_base(TargetBase::intern_sized(reg.into()), block); + self.visible_exprs.top().insert(reg.into()); + } Stmt::Declaration(StmtDeclaration::Instance(StmtInstance { annotations: _, instance, - })) => self.insert_new_base(TargetBase::intern_sized(instance.into()), block), + })) => { + self.insert_new_base(TargetBase::intern_sized(instance.into()), block); + self.visible_exprs.top().insert(instance.into()); + } } } + self.visible_exprs.pop(); } #[track_caller] fn assert_validity(&mut self) { @@ -1874,6 +2004,140 @@ impl AssertValidityState { } } +struct AssertExprValidity<'a> { + state: &'a mut AssertValidityState, +} + +enum InvalidExpr { + ExprIsNotVisible(Expr), +} + +impl transform::visit::Visitor for AssertExprValidity<'_> { + type Error = InvalidExpr; + fn visit_expr_enum(&mut self, v: &ExprEnum) -> Result<(), Self::Error> { + match v { + ExprEnum::UIntLiteral(_) + | ExprEnum::SIntLiteral(_) + | ExprEnum::BoolLiteral(_) + | ExprEnum::PhantomConst(_) + | ExprEnum::BundleLiteral(_) + | ExprEnum::ArrayLiteral(_) + | ExprEnum::EnumLiteral(_) + | ExprEnum::Uninit(_) + | ExprEnum::NotU(_) + | ExprEnum::NotS(_) + | ExprEnum::NotB(_) + | ExprEnum::Neg(_) + | ExprEnum::BitAndU(_) + | ExprEnum::BitAndS(_) + | ExprEnum::BitAndB(_) + | ExprEnum::BitOrU(_) + | ExprEnum::BitOrS(_) + | ExprEnum::BitOrB(_) + | ExprEnum::BitXorU(_) + | ExprEnum::BitXorS(_) + | ExprEnum::BitXorB(_) + | ExprEnum::AddU(_) + | ExprEnum::AddS(_) + | ExprEnum::SubU(_) + | ExprEnum::SubS(_) + | ExprEnum::MulU(_) + | ExprEnum::MulS(_) + | ExprEnum::DivU(_) + | ExprEnum::DivS(_) + | ExprEnum::RemU(_) + | ExprEnum::RemS(_) + | ExprEnum::DynShlU(_) + | ExprEnum::DynShlS(_) + | ExprEnum::DynShrU(_) + | ExprEnum::DynShrS(_) + | ExprEnum::FixedShlU(_) + | ExprEnum::FixedShlS(_) + | ExprEnum::FixedShrU(_) + | ExprEnum::FixedShrS(_) + | ExprEnum::CmpLtB(_) + | ExprEnum::CmpLeB(_) + | ExprEnum::CmpGtB(_) + | ExprEnum::CmpGeB(_) + | ExprEnum::CmpEqB(_) + | ExprEnum::CmpNeB(_) + | ExprEnum::CmpLtU(_) + | ExprEnum::CmpLeU(_) + | ExprEnum::CmpGtU(_) + | ExprEnum::CmpGeU(_) + | ExprEnum::CmpEqU(_) + | ExprEnum::CmpNeU(_) + | ExprEnum::CmpLtS(_) + | ExprEnum::CmpLeS(_) + | ExprEnum::CmpGtS(_) + | ExprEnum::CmpGeS(_) + | ExprEnum::CmpEqS(_) + | ExprEnum::CmpNeS(_) + | ExprEnum::CastUIntToUInt(_) + | ExprEnum::CastUIntToSInt(_) + | ExprEnum::CastSIntToUInt(_) + | ExprEnum::CastSIntToSInt(_) + | ExprEnum::CastBoolToUInt(_) + | ExprEnum::CastBoolToSInt(_) + | ExprEnum::CastUIntToBool(_) + | ExprEnum::CastSIntToBool(_) + | ExprEnum::CastBoolToSyncReset(_) + | ExprEnum::CastUIntToSyncReset(_) + | ExprEnum::CastSIntToSyncReset(_) + | ExprEnum::CastBoolToAsyncReset(_) + | ExprEnum::CastUIntToAsyncReset(_) + | ExprEnum::CastSIntToAsyncReset(_) + | ExprEnum::CastSyncResetToBool(_) + | ExprEnum::CastSyncResetToUInt(_) + | ExprEnum::CastSyncResetToSInt(_) + | ExprEnum::CastSyncResetToReset(_) + | ExprEnum::CastAsyncResetToBool(_) + | ExprEnum::CastAsyncResetToUInt(_) + | ExprEnum::CastAsyncResetToSInt(_) + | ExprEnum::CastAsyncResetToReset(_) + | ExprEnum::CastResetToBool(_) + | ExprEnum::CastResetToUInt(_) + | ExprEnum::CastResetToSInt(_) + | ExprEnum::CastBoolToClock(_) + | ExprEnum::CastUIntToClock(_) + | ExprEnum::CastSIntToClock(_) + | ExprEnum::CastClockToBool(_) + | ExprEnum::CastClockToUInt(_) + | ExprEnum::CastClockToSInt(_) + | ExprEnum::FieldAccess(_) + | ExprEnum::ArrayIndex(_) + | ExprEnum::DynArrayIndex(_) + | ExprEnum::ReduceBitAndU(_) + | ExprEnum::ReduceBitAndS(_) + | ExprEnum::ReduceBitOrU(_) + | ExprEnum::ReduceBitOrS(_) + | ExprEnum::ReduceBitXorU(_) + | ExprEnum::ReduceBitXorS(_) + | ExprEnum::SliceUInt(_) + | ExprEnum::SliceSInt(_) + | ExprEnum::CastToBits(_) + | ExprEnum::CastBitsTo(_) + | ExprEnum::ToTraceAsString(_) + | ExprEnum::TraceAsStringAsInner(_) => v.default_visit(self), + ExprEnum::VariantAccess(_) + | ExprEnum::ModuleIO(_) + | ExprEnum::Instance(_) + | ExprEnum::Wire(_) + | ExprEnum::Reg(_) + | ExprEnum::RegSync(_) + | ExprEnum::RegAsync(_) + | ExprEnum::MemPort(_) => { + if self.state.visible_exprs.contains(v) { + // no need to visit inner expressions, we already checked them before adding them to visible_exprs + Ok(()) + } else { + Err(InvalidExpr::ExprIsNotVisible(v.to_expr())) + } + } + } + } +} + impl Module { /// you generally should use the [`#[hdl_module]`][`crate::hdl_module`] proc-macro and [`ModuleBuilder`] instead #[track_caller] @@ -1999,6 +2263,7 @@ impl Module { AssertValidityState { module: self.canonical(), blocks: vec![], + visible_exprs: VisibleExprsStack::default(), target_states: HashMap::with_capacity_and_hasher(64, Default::default()), } .assert_validity();