fayalite/src/module: check that expressions are visible where they are used, e.g. erroring when a wire is inside an if but used outside.
All checks were successful
/ test (pull_request) Successful in 4m21s
/ test (push) Successful in 4m54s

This commit is contained in:
Jacob Lifshay 2026-06-01 23:10:43 -07:00
parent b1116c4a1a
commit 31353862ce
Signed by: programmerjake
SSH key fingerprint: SHA256:HnFTLGpSm4Q4Fj502oCFisjZSoakwEuTsJJMSke63RQ

View file

@ -8,7 +8,7 @@ use crate::{
clock::{Clock, ClockDomain},
enum_::{Enum, EnumMatchVariantsIter, EnumType},
expr::{
Expr, Flow, ToExpr, ValueType,
Expr, ExprEnum, Flow, ToExpr, ValueType,
ops::VariantAccess,
target::{
GetTarget, Target, TargetBase, TargetPathArrayElement, TargetPathBundleField,
@ -20,6 +20,7 @@ use crate::{
int::{Bool, DynSize, Size},
intern::{Intern, Interned},
memory::{Mem, MemBuilder, MemBuilderTarget, PortName},
module::transform::visit::{Visit, Visitor},
platform::PlatformIOBuilder,
reg::Reg,
reset::{AsyncReset, Reset, ResetType, ResetTypeDispatch, SyncReset},
@ -1598,9 +1599,54 @@ impl TargetState {
}
}
struct VisibleExprsStack {
buf: Vec<HashSet<ExprEnum>>,
len: usize,
}
impl VisibleExprsStack {
fn top(&mut self) -> &mut HashSet<ExprEnum> {
&mut self.buf[self.len - 1]
}
fn slice(&self) -> &[HashSet<ExprEnum>] {
&self.buf[..self.len]
}
fn contains(&self, v: &ExprEnum) -> bool {
self.slice().iter().any(|i| i.contains(v))
}
fn push_empty(&mut self) {
#[cold]
fn push_empty_cold(stack: &mut VisibleExprsStack) {
stack.buf.push(HashSet::default());
assert_eq!(stack.buf.len(), stack.len)
}
self.len += 1;
if self.len > self.buf.len() {
push_empty_cold(self)
}
}
fn pop(&mut self) {
let Some(new_len) = self.len.checked_sub(1) else {
unreachable!("visible exprs stack underflow");
};
self.buf[new_len].clear();
self.len = new_len;
}
}
impl Default for VisibleExprsStack {
fn default() -> Self {
Self {
buf: Vec::new(),
len: 0,
}
}
}
struct AssertValidityState {
module: Module<Bundle>,
blocks: Vec<Block>,
visible_exprs: VisibleExprsStack,
target_states: HashMap<Interned<TargetBase>, TargetState>,
}
@ -1771,6 +1817,7 @@ impl AssertValidityState {
}
}
}
#[track_caller]
fn process_conditional_sub_blocks(
&mut self,
parent_block: usize,
@ -1784,17 +1831,40 @@ impl AssertValidityState {
}
}
#[track_caller]
fn assert_expr_validity<T: Type>(&mut self, expr: Expr<T>, source_location: SourceLocation) {
let mut visitor = AssertExprValidity { state: self };
match visitor.visit_expr(&expr) {
Ok(()) => {}
Err(e) => match e {
InvalidExpr::ExprIsNotVisible(expr) => {
if let Some(target) = expr.target() {
panic!(
"at {source_location}: expression isn't visible here, it's defined:\n\
at {}: {expr:?}",
target.base().source_location(),
);
} else {
panic!("at {source_location}: expression isn't visible here: {expr:?}");
}
}
},
}
}
#[track_caller]
fn assert_subtree_validity(&mut self, block: usize) {
self.visible_exprs.push_empty();
let module = self.module;
if block == 0 {
for module_io in &*module.module_io {
self.insert_new_base(TargetBase::intern_sized(module_io.module_io.into()), block);
self.visible_exprs.top().insert(module_io.module_io.into());
}
}
let Block { memories, stmts } = self.blocks[block];
for m in memories {
for port in m.ports() {
self.insert_new_base(TargetBase::intern_sized(port.into()), block);
self.visible_exprs.top().insert(port.into());
}
}
for stmt in stmts {
@ -1808,44 +1878,104 @@ impl AssertValidityState {
} = connect;
self.set_connect_side_written(lhs, source_location, true, block);
self.set_connect_side_written(rhs, source_location, false, block);
self.assert_expr_validity(lhs, source_location);
self.assert_expr_validity(rhs, source_location);
}
Stmt::Formal(formal) => {
let StmtFormal {
kind: _,
clk,
pred,
en,
text: _,
source_location,
} = formal;
self.assert_expr_validity(clk, source_location);
self.assert_expr_validity(pred, source_location);
self.assert_expr_validity(en, source_location);
}
Stmt::Formal(_) => {}
Stmt::If(if_stmt) => {
let sub_blocks = if_stmt.blocks.map(|block| self.make_block_index(block));
let StmtIf {
cond,
source_location,
blocks: sub_blocks,
} = if_stmt;
self.assert_expr_validity(cond, source_location);
let sub_blocks = sub_blocks.map(|block| self.make_block_index(block));
self.process_conditional_sub_blocks(block, sub_blocks)
}
Stmt::Match(match_stmt) => {
match_stmt.assert_validity();
let StmtMatch {
expr,
source_location,
blocks: sub_blocks,
} = match_stmt;
self.assert_expr_validity(expr, source_location);
let sub_blocks = Vec::from_iter(
match_stmt
.blocks
sub_blocks
.into_iter()
.map(|block| self.make_block_index(block)),
);
self.process_conditional_sub_blocks(block, sub_blocks.iter().copied())
self.visible_exprs.push_empty();
let visible_exprs_top = self.visible_exprs.top();
for variant_index in 0..expr.ty().variants().len() {
visible_exprs_top
.insert(<VariantAccess>::new_by_index(expr, variant_index).into());
}
self.process_conditional_sub_blocks(block, sub_blocks.iter().copied());
self.visible_exprs.pop();
}
Stmt::Declaration(StmtDeclaration::Wire(StmtWire {
annotations: _,
wire,
})) => self.insert_new_base(TargetBase::intern_sized(wire.into()), block),
})) => {
self.insert_new_base(TargetBase::intern_sized(wire.into()), block);
self.visible_exprs.top().insert(wire.into());
}
Stmt::Declaration(StmtDeclaration::Reg(StmtReg {
annotations: _,
reg,
})) => self.insert_new_base(TargetBase::intern_sized(reg.into()), block),
})) => {
self.assert_expr_validity(reg.clock_domain(), reg.source_location());
if let Some(init) = reg.init() {
self.assert_expr_validity(init, reg.source_location());
}
self.insert_new_base(TargetBase::intern_sized(reg.into()), block);
self.visible_exprs.top().insert(reg.into());
}
Stmt::Declaration(StmtDeclaration::RegSync(StmtReg {
annotations: _,
reg,
})) => self.insert_new_base(TargetBase::intern_sized(reg.into()), block),
})) => {
self.assert_expr_validity(reg.clock_domain(), reg.source_location());
if let Some(init) = reg.init() {
self.assert_expr_validity(init, reg.source_location());
}
self.insert_new_base(TargetBase::intern_sized(reg.into()), block);
self.visible_exprs.top().insert(reg.into());
}
Stmt::Declaration(StmtDeclaration::RegAsync(StmtReg {
annotations: _,
reg,
})) => self.insert_new_base(TargetBase::intern_sized(reg.into()), block),
})) => {
self.assert_expr_validity(reg.clock_domain(), reg.source_location());
if let Some(init) = reg.init() {
self.assert_expr_validity(init, reg.source_location());
}
self.insert_new_base(TargetBase::intern_sized(reg.into()), block);
self.visible_exprs.top().insert(reg.into());
}
Stmt::Declaration(StmtDeclaration::Instance(StmtInstance {
annotations: _,
instance,
})) => self.insert_new_base(TargetBase::intern_sized(instance.into()), block),
})) => {
self.insert_new_base(TargetBase::intern_sized(instance.into()), block);
self.visible_exprs.top().insert(instance.into());
}
}
}
self.visible_exprs.pop();
}
#[track_caller]
fn assert_validity(&mut self) {
@ -1874,6 +2004,140 @@ impl AssertValidityState {
}
}
struct AssertExprValidity<'a> {
state: &'a mut AssertValidityState,
}
enum InvalidExpr {
ExprIsNotVisible(Expr<CanonicalType>),
}
impl transform::visit::Visitor for AssertExprValidity<'_> {
type Error = InvalidExpr;
fn visit_expr_enum(&mut self, v: &ExprEnum) -> Result<(), Self::Error> {
match v {
ExprEnum::UIntLiteral(_)
| ExprEnum::SIntLiteral(_)
| ExprEnum::BoolLiteral(_)
| ExprEnum::PhantomConst(_)
| ExprEnum::BundleLiteral(_)
| ExprEnum::ArrayLiteral(_)
| ExprEnum::EnumLiteral(_)
| ExprEnum::Uninit(_)
| ExprEnum::NotU(_)
| ExprEnum::NotS(_)
| ExprEnum::NotB(_)
| ExprEnum::Neg(_)
| ExprEnum::BitAndU(_)
| ExprEnum::BitAndS(_)
| ExprEnum::BitAndB(_)
| ExprEnum::BitOrU(_)
| ExprEnum::BitOrS(_)
| ExprEnum::BitOrB(_)
| ExprEnum::BitXorU(_)
| ExprEnum::BitXorS(_)
| ExprEnum::BitXorB(_)
| ExprEnum::AddU(_)
| ExprEnum::AddS(_)
| ExprEnum::SubU(_)
| ExprEnum::SubS(_)
| ExprEnum::MulU(_)
| ExprEnum::MulS(_)
| ExprEnum::DivU(_)
| ExprEnum::DivS(_)
| ExprEnum::RemU(_)
| ExprEnum::RemS(_)
| ExprEnum::DynShlU(_)
| ExprEnum::DynShlS(_)
| ExprEnum::DynShrU(_)
| ExprEnum::DynShrS(_)
| ExprEnum::FixedShlU(_)
| ExprEnum::FixedShlS(_)
| ExprEnum::FixedShrU(_)
| ExprEnum::FixedShrS(_)
| ExprEnum::CmpLtB(_)
| ExprEnum::CmpLeB(_)
| ExprEnum::CmpGtB(_)
| ExprEnum::CmpGeB(_)
| ExprEnum::CmpEqB(_)
| ExprEnum::CmpNeB(_)
| ExprEnum::CmpLtU(_)
| ExprEnum::CmpLeU(_)
| ExprEnum::CmpGtU(_)
| ExprEnum::CmpGeU(_)
| ExprEnum::CmpEqU(_)
| ExprEnum::CmpNeU(_)
| ExprEnum::CmpLtS(_)
| ExprEnum::CmpLeS(_)
| ExprEnum::CmpGtS(_)
| ExprEnum::CmpGeS(_)
| ExprEnum::CmpEqS(_)
| ExprEnum::CmpNeS(_)
| ExprEnum::CastUIntToUInt(_)
| ExprEnum::CastUIntToSInt(_)
| ExprEnum::CastSIntToUInt(_)
| ExprEnum::CastSIntToSInt(_)
| ExprEnum::CastBoolToUInt(_)
| ExprEnum::CastBoolToSInt(_)
| ExprEnum::CastUIntToBool(_)
| ExprEnum::CastSIntToBool(_)
| ExprEnum::CastBoolToSyncReset(_)
| ExprEnum::CastUIntToSyncReset(_)
| ExprEnum::CastSIntToSyncReset(_)
| ExprEnum::CastBoolToAsyncReset(_)
| ExprEnum::CastUIntToAsyncReset(_)
| ExprEnum::CastSIntToAsyncReset(_)
| ExprEnum::CastSyncResetToBool(_)
| ExprEnum::CastSyncResetToUInt(_)
| ExprEnum::CastSyncResetToSInt(_)
| ExprEnum::CastSyncResetToReset(_)
| ExprEnum::CastAsyncResetToBool(_)
| ExprEnum::CastAsyncResetToUInt(_)
| ExprEnum::CastAsyncResetToSInt(_)
| ExprEnum::CastAsyncResetToReset(_)
| ExprEnum::CastResetToBool(_)
| ExprEnum::CastResetToUInt(_)
| ExprEnum::CastResetToSInt(_)
| ExprEnum::CastBoolToClock(_)
| ExprEnum::CastUIntToClock(_)
| ExprEnum::CastSIntToClock(_)
| ExprEnum::CastClockToBool(_)
| ExprEnum::CastClockToUInt(_)
| ExprEnum::CastClockToSInt(_)
| ExprEnum::FieldAccess(_)
| ExprEnum::ArrayIndex(_)
| ExprEnum::DynArrayIndex(_)
| ExprEnum::ReduceBitAndU(_)
| ExprEnum::ReduceBitAndS(_)
| ExprEnum::ReduceBitOrU(_)
| ExprEnum::ReduceBitOrS(_)
| ExprEnum::ReduceBitXorU(_)
| ExprEnum::ReduceBitXorS(_)
| ExprEnum::SliceUInt(_)
| ExprEnum::SliceSInt(_)
| ExprEnum::CastToBits(_)
| ExprEnum::CastBitsTo(_)
| ExprEnum::ToTraceAsString(_)
| ExprEnum::TraceAsStringAsInner(_) => v.default_visit(self),
ExprEnum::VariantAccess(_)
| ExprEnum::ModuleIO(_)
| ExprEnum::Instance(_)
| ExprEnum::Wire(_)
| ExprEnum::Reg(_)
| ExprEnum::RegSync(_)
| ExprEnum::RegAsync(_)
| ExprEnum::MemPort(_) => {
if self.state.visible_exprs.contains(v) {
// no need to visit inner expressions, we already checked them before adding them to visible_exprs
Ok(())
} else {
Err(InvalidExpr::ExprIsNotVisible(v.to_expr()))
}
}
}
}
}
impl<T: BundleType> Module<T> {
/// you generally should use the [`#[hdl_module]`][`crate::hdl_module`] proc-macro and [`ModuleBuilder`] instead
#[track_caller]
@ -1999,6 +2263,7 @@ impl<T: BundleType> Module<T> {
AssertValidityState {
module: self.canonical(),
blocks: vec![],
visible_exprs: VisibleExprsStack::default(),
target_states: HashMap::with_capacity_and_hasher(64, Default::default()),
}
.assert_validity();