add validation of connects and matches when validating module

this is useful for catching errors in transformation passes
This commit is contained in:
Jacob Lifshay 2024-09-30 21:20:35 -07:00
parent d2ba313f0f
commit 1e2831da47
Signed by: programmerjake
SSH key fingerprint: SHA256:B1iRVvUJkvd7upMIiMqn6OyxvD2SgJkAH3ZnUOj6z+c

View file

@ -185,6 +185,40 @@ pub struct StmtConnect {
pub source_location: SourceLocation, pub source_location: SourceLocation,
} }
impl StmtConnect {
#[track_caller]
fn assert_validity_with_original_types(&self, lhs_orig_ty: impl Type, rhs_orig_ty: impl Type) {
let Self {
lhs,
rhs,
source_location,
} = *self;
assert!(
Expr::ty(lhs).can_connect(Expr::ty(rhs)),
"can't connect types that are not equivalent:\nlhs type:\n{lhs_orig_ty:?}\nrhs type:\n{rhs_orig_ty:?}\nat: {source_location}",
);
assert!(
matches!(Expr::flow(lhs), Flow::Sink | Flow::Duplex),
"can't connect to source, connect lhs must have sink or duplex flow\nat: {source_location}"
);
assert!(
lhs.target().is_some(),
"can't connect to non-target\nat: {source_location}"
);
match Expr::flow(rhs) {
Flow::Source | Flow::Duplex => {}
Flow::Sink => assert!(
Expr::ty(rhs).is_passive(),
"can't connect from sink with non-passive type\nat: {source_location}"
),
}
}
#[track_caller]
fn assert_validity(&self) {
self.assert_validity_with_original_types(Expr::ty(self.lhs), Expr::ty(self.rhs));
}
}
impl fmt::Debug for StmtConnect { impl fmt::Debug for StmtConnect {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self { let Self {
@ -283,6 +317,13 @@ pub struct StmtMatch<S: ModuleBuildingStatus = ModuleBuilt> {
pub blocks: Interned<[S::Block]>, pub blocks: Interned<[S::Block]>,
} }
impl StmtMatch {
#[track_caller]
fn assert_validity(&self) {
assert_eq!(Expr::ty(self.expr).variants().len(), self.blocks.len());
}
}
impl<S: ModuleBuildingStatus> fmt::Debug for StmtMatch<S> { impl<S: ModuleBuildingStatus> fmt::Debug for StmtMatch<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self { let Self {
@ -1657,11 +1698,13 @@ impl AssertValidityState {
} }
for stmt in stmts { for stmt in stmts {
match stmt { match stmt {
Stmt::Connect(StmtConnect { Stmt::Connect(connect) => {
lhs, connect.assert_validity();
rhs, let StmtConnect {
source_location, lhs,
}) => { rhs,
source_location,
} = connect;
self.set_connect_side_written(lhs, source_location, true, block); self.set_connect_side_written(lhs, source_location, true, block);
self.set_connect_side_written(rhs, source_location, false, block); self.set_connect_side_written(rhs, source_location, false, block);
} }
@ -1671,6 +1714,7 @@ impl AssertValidityState {
self.process_conditional_sub_blocks(block, sub_blocks) self.process_conditional_sub_blocks(block, sub_blocks)
} }
Stmt::Match(match_stmt) => { Stmt::Match(match_stmt) => {
match_stmt.assert_validity();
let sub_blocks = Vec::from_iter( let sub_blocks = Vec::from_iter(
match_stmt match_stmt
.blocks .blocks
@ -2517,24 +2561,12 @@ pub fn connect_any_with_loc<Lhs: ToExpr, Rhs: ToExpr>(
let rhs_orig = rhs.to_expr(); let rhs_orig = rhs.to_expr();
let lhs = Expr::canonical(lhs_orig); let lhs = Expr::canonical(lhs_orig);
let rhs = Expr::canonical(rhs_orig); let rhs = Expr::canonical(rhs_orig);
assert!( let connect = StmtConnect {
Expr::ty(lhs).can_connect(Expr::ty(rhs)), lhs,
"can't connect types that are not equivalent:\nlhs type:\n{:?}\nrhs type:\n{:?}", rhs,
Expr::ty(lhs_orig), source_location,
Expr::ty(rhs_orig) };
); connect.assert_validity_with_original_types(Expr::ty(lhs_orig), Expr::ty(rhs_orig));
assert!(
matches!(Expr::flow(lhs), Flow::Sink | Flow::Duplex),
"can't connect to source, connect lhs must have sink or duplex flow"
);
assert!(lhs.target().is_some(), "can't connect to non-target");
match Expr::flow(rhs) {
Flow::Source | Flow::Duplex => {}
Flow::Sink => assert!(
Expr::ty(rhs).is_passive(),
"can't connect from sink with non-passive type"
),
}
ModuleBuilder::with(|m| { ModuleBuilder::with(|m| {
m.impl_ m.impl_
.borrow_mut() .borrow_mut()
@ -2542,14 +2574,7 @@ pub fn connect_any_with_loc<Lhs: ToExpr, Rhs: ToExpr>(
.builder_normal_body() .builder_normal_body()
.block(m.block_stack.top()) .block(m.block_stack.top())
.stmts .stmts
.push( .push(connect.into());
StmtConnect {
lhs,
rhs,
source_location,
}
.into(),
);
}); });
} }