fayalite/crates/fayalite/src/module/transform/simplify_memories.rs

951 lines
36 KiB
Rust

// SPDX-License-Identifier: LGPL-3.0-or-later
// See Notices.txt for copyright information
use crate::{
annotations::TargetedAnnotation,
array::Array,
bundle::{Bundle, BundleType},
expr::{CastBitsTo, CastToBits, Expr, ExprEnum, ToExpr},
int::{Bool, SInt, Size, UInt},
intern::{Intern, Interned},
memory::{Mem, MemPort, PortType},
module::{
transform::visit::{Fold, Folder},
Block, Id, Module, NameId, ScopedNameId, Stmt, StmtConnect, StmtWire,
},
source_location::SourceLocation,
ty::{CanonicalType, Type},
util::MakeMutSlice,
wire::Wire,
};
use bitvec::{slice::BitSlice, vec::BitVec};
use hashbrown::HashMap;
use std::{
convert::Infallible,
fmt::Write,
ops::{Deref, DerefMut},
rc::Rc,
};
/// The element type of one output memory produced by this transform: either a
/// single `UInt`/`SInt`/`Bool`, or an array of one of those.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
enum SingleType {
/// A single unsigned integer.
UInt(UInt),
/// A single signed integer.
SInt(SInt),
/// A single `Bool`.
Bool(Bool),
/// An array of unsigned integers.
UIntArray(Array<UInt>),
/// An array of signed integers.
SIntArray(Array<SInt>),
/// An array of `Bool`s.
BoolArray(Array<Bool>),
}
impl SingleType {
fn is_array_type(self, array_type: Array) -> bool {
match self {
SingleType::UInt(_) | SingleType::SInt(_) | SingleType::Bool(_) => false,
SingleType::UIntArray(ty) => ty.as_dyn_array() == array_type,
SingleType::SIntArray(ty) => ty.as_dyn_array() == array_type,
SingleType::BoolArray(ty) => ty.as_dyn_array() == array_type,
}
}
fn array_len(self) -> usize {
match self {
SingleType::UInt(_ty) => 1,
SingleType::SInt(_ty) => 1,
SingleType::Bool(_ty) => 1,
SingleType::UIntArray(ty) => ty.len(),
SingleType::SIntArray(ty) => ty.len(),
SingleType::BoolArray(ty) => ty.len(),
}
}
}
/// Describes how a memory's element type is broken up into output memories.
///
/// The shape mirrors the element type: bundles and (non-flattenable) arrays
/// become interior nodes, everything else becomes a `Single` leaf.
#[derive(Clone, Debug)]
enum MemSplit {
/// A bundle element type; one entry per bundle field, in field order.
Bundle {
fields: Rc<[MemSplit]>,
},
/// A leaf that is stored in exactly one output memory.
Single {
/// The memory backing this leaf; `None` until the memory has been processed.
output_mem: Option<Mem>,
/// The element type of the output memory.
element_type: SingleType,
/// `true` only when the original element type can be kept as-is
/// (see `MemSplit::new` and `mark_changed_element_type`).
unchanged_element_type: bool,
},
/// An array element type whose elements are themselves split; one entry per
/// array element.
Array {
elements: Rc<[MemSplit]>,
},
}
impl MemSplit {
fn mark_changed_element_type(self) -> Self {
match self {
MemSplit::Bundle { fields: _ } => self,
MemSplit::Single {
output_mem,
element_type,
unchanged_element_type: _,
} => MemSplit::Single {
output_mem,
element_type,
unchanged_element_type: false,
},
MemSplit::Array { elements: _ } => self,
}
}
fn new(element_type: CanonicalType) -> Self {
match element_type {
CanonicalType::Bundle(bundle_ty) => MemSplit::Bundle {
fields: bundle_ty
.fields()
.into_iter()
.map(|field| Self::new(field.ty).mark_changed_element_type())
.collect(),
},
CanonicalType::Array(ty) => {
let element = MemSplit::new(ty.element());
if let Self::Single {
output_mem: _,
element_type,
unchanged_element_type,
} = element
{
match element_type {
SingleType::UInt(element_type) => Self::Single {
output_mem: None,
element_type: SingleType::UIntArray(Array::new_dyn(
element_type,
ty.len(),
)),
unchanged_element_type,
},
SingleType::SInt(element_type) => Self::Single {
output_mem: None,
element_type: SingleType::SIntArray(Array::new_dyn(
element_type,
ty.len(),
)),
unchanged_element_type,
},
SingleType::Bool(element_type) => Self::Single {
output_mem: None,
element_type: SingleType::BoolArray(Array::new_dyn(
element_type,
ty.len(),
)),
unchanged_element_type,
},
SingleType::UIntArray(element_type) => Self::Single {
output_mem: None,
element_type: SingleType::UIntArray(Array::new_dyn(
element_type.element(),
ty.len()
.checked_mul(element_type.len())
.expect("memory element type can't be too big"),
)),
unchanged_element_type: false,
},
SingleType::SIntArray(element_type) => Self::Single {
output_mem: None,
element_type: SingleType::SIntArray(Array::new_dyn(
element_type.element(),
ty.len()
.checked_mul(element_type.len())
.expect("memory element type can't be too big"),
)),
unchanged_element_type: false,
},
SingleType::BoolArray(element_type) => Self::Single {
output_mem: None,
element_type: SingleType::BoolArray(Array::new_dyn(
element_type.element(),
ty.len()
.checked_mul(element_type.len())
.expect("memory element type can't be too big"),
)),
unchanged_element_type: false,
},
}
} else {
let element = element.mark_changed_element_type();
Self::Array {
elements: (0..ty.len()).map(|_| element.clone()).collect(),
}
}
}
CanonicalType::UInt(ty) => Self::Single {
output_mem: None,
element_type: SingleType::UInt(ty),
unchanged_element_type: true,
},
CanonicalType::SInt(ty) => Self::Single {
output_mem: None,
element_type: SingleType::SInt(ty),
unchanged_element_type: true,
},
CanonicalType::Bool(ty) => Self::Single {
output_mem: None,
element_type: SingleType::Bool(ty),
unchanged_element_type: true,
},
CanonicalType::Enum(ty) => Self::Single {
output_mem: None,
element_type: SingleType::UInt(UInt::new_dyn(ty.type_properties().bit_width)),
unchanged_element_type: false,
},
CanonicalType::Clock(_)
| CanonicalType::AsyncReset(_)
| CanonicalType::SyncReset(_)
| CanonicalType::Reset(_) => unreachable!("memory element type is a storable type"),
}
}
}
/// Per-memory result of processing an input memory.
struct MemState {
/// For each port of the input memory (indexed by port index), the expression
/// that replaces uses of that port: the port itself when the memory is kept
/// unchanged, otherwise a wire.
replacement_ports: Box<[ExprEnum]>,
}
/// The view of the replacement wires and initial value at one level of the
/// element-type tree while a memory is being split.
struct SplitState<'a> {
/// Per port: the expression for the read-data part at this level, if the
/// port has read data.
wire_rdata: Box<[Option<Expr<CanonicalType>>]>,
/// Per port: the expression for the write-data part at this level, if the
/// port has write data.
wire_wdata: Box<[Option<Expr<CanonicalType>>]>,
/// Per port: the expression for the write-mask part at this level, if the
/// port has a write mask.
wire_wmask: Box<[Option<Expr<CanonicalType>>]>,
/// Per memory element: the bits of the initial value that belong to this
/// level; `None` when there is no initial value.
initial_value: Option<Box<[&'a BitSlice]>>,
}
impl<'a> SplitState<'a> {
    /// A state with no ports and no initial value, used as a temporary
    /// stand-in while a real state is moved out of the stack.
    fn placeholder() -> Self {
        Self::new_empty(0)
    }

    /// A state with `ports_len` ports where every per-port entry is `None`
    /// and there is no initial value.
    fn new_empty(ports_len: usize) -> Self {
        Self {
            wire_rdata: std::iter::repeat_with(|| None).take(ports_len).collect(),
            wire_wdata: std::iter::repeat_with(|| None).take(ports_len).collect(),
            wire_wmask: std::iter::repeat_with(|| None).take(ports_len).collect(),
            initial_value: None,
        }
    }
}
/// A stack of `SplitState`s, one per level of the element-type tree currently
/// being visited. Popped entries stay in `values` so `push_map` can reuse them.
struct SplitStateStack<'a> {
/// Number of ports; used to size freshly created entries.
ports_len: usize,
/// Backing storage; entries above `top_index` are stale and may be reused.
values: Vec<SplitState<'a>>,
/// Index into `values` of the current top of the stack.
top_index: usize,
}
impl<'a> SplitStateStack<'a> {
    /// Creates a stack whose only entry is `value`.
    fn new(ports_len: usize, value: SplitState<'a>) -> Self {
        Self {
            ports_len,
            values: Vec::from([value]),
            top_index: 0,
        }
    }

    /// The entry currently on top of the stack.
    fn top(&mut self) -> &mut SplitState<'a> {
        let index = self.top_index;
        &mut self.values[index]
    }

    /// Drops the top entry (its storage is kept for reuse).
    ///
    /// Panics when only the bottom entry is left.
    fn pop(&mut self) {
        let below = self
            .top_index
            .checked_sub(1)
            .expect("there's always at least one entry in the stack");
        self.top_index = below;
    }

    /// Pushes a new entry derived from the current top: every present wire
    /// expression is passed through `wire_map`, and every initial-value
    /// element (if any) through `initial_value_element_map`.
    fn push_map(
        &mut self,
        mut wire_map: impl FnMut(Expr<CanonicalType>) -> Expr<CanonicalType>,
        mut initial_value_element_map: impl FnMut(&BitSlice) -> &BitSlice,
    ) {
        let next_index = self.top_index + 1;
        // reuse the storage of a previously popped entry when there is one
        let mut next = if let Some(slot) = self.values.get_mut(next_index) {
            std::mem::replace(slot, SplitState::placeholder())
        } else {
            SplitState::new_empty(self.ports_len)
        };
        let current = &self.values[self.top_index];
        {
            let mut remap = |dst: &mut [Option<Expr<CanonicalType>>],
                             src: &[Option<Expr<CanonicalType>>]| {
                for (dst, src) in dst.iter_mut().zip(src.iter().copied()) {
                    *dst = src.map(&mut wire_map);
                }
            };
            remap(&mut next.wire_rdata[..], &current.wire_rdata[..]);
            remap(&mut next.wire_wdata[..], &current.wire_wdata[..]);
            remap(&mut next.wire_wmask[..], &current.wire_wmask[..]);
        }
        if let Some(current_initial) = &current.initial_value {
            let next_initial = next.initial_value.get_or_insert_with(|| {
                Box::from_iter((0..current_initial.len()).map(|_| Default::default()))
            });
            for (dst, &src) in next_initial.iter_mut().zip(current_initial.iter()) {
                *dst = initial_value_element_map(src);
            }
        }
        self.top_index = next_index;
        match self.values.get_mut(next_index) {
            Some(slot) => *slot = next,
            None => {
                assert_eq!(next_index, self.values.len());
                self.values.push(next);
            }
        }
    }
}
/// Everything needed to split one input memory at one node of its
/// element-type tree; consumed by `split_mem`, which builds a new value of
/// this struct for each child node.
struct SplitMemState<'a, 'b> {
/// State of the module that owns the memory.
module_state: &'a mut ModuleState,
/// The memory being split.
input_mem: Mem,
/// Receives every output memory that gets created.
output_mems: &'a mut Vec<Mem>,
/// Receives the statements connecting the output memories.
output_stmts: &'a mut Vec<Stmt>,
/// The input element type at the current node.
element_type: CanonicalType,
/// The split description for the current node.
split: &'a mut MemSplit,
/// Name suffix accumulated along the path to the current node
/// (`_<field name>` / `_<index>` per level).
mem_name_path: &'a mut String,
/// Wire expressions and initial-value slices for the current node.
split_state_stack: &'a mut SplitStateStack<'b>,
/// Replacement ports of the input memory.
mem_state: &'a MemState,
}
impl SplitMemState<'_, '_> {
/// Recursively walks `self.split`, creating one output memory per `Single`
/// leaf and emitting the statements that connect it to the replacement wires.
///
/// For `Bundle` and `Array` nodes, each child is visited with the name path
/// extended by `_<field name>` or `_<index>` and with a split state narrowed
/// to that field/element.
fn split_mem(self) {
// remembered so each child starts from the same name prefix
let outer_mem_name_path_len = self.mem_name_path.len();
match self.split {
MemSplit::Bundle { fields } => {
let CanonicalType::Bundle(bundle_type) = self.element_type else {
unreachable!();
};
for ((field, field_offset), split) in bundle_type
.fields()
.into_iter()
.zip(bundle_type.field_offsets())
.zip(fields.make_mut_slice())
{
self.mem_name_path.truncate(outer_mem_name_path_len);
self.mem_name_path.push('_');
self.mem_name_path.push_str(&field.name);
let field_ty_bit_width = field.ty.bit_width();
// narrow the wires to this field and the initial value to this field's bits
self.split_state_stack.push_map(
|e: Expr<CanonicalType>| {
Expr::field(Expr::<Bundle>::from_canonical(e), &field.name)
},
|initial_value_element| {
&initial_value_element[field_offset..][..field_ty_bit_width]
},
);
SplitMemState {
module_state: self.module_state,
input_mem: self.input_mem,
output_mems: self.output_mems,
output_stmts: self.output_stmts,
element_type: field.ty,
split,
mem_name_path: self.mem_name_path,
split_state_stack: self.split_state_stack,
mem_state: self.mem_state,
}
.split_mem();
self.split_state_stack.pop();
}
}
MemSplit::Single {
output_mem,
element_type: single_type,
unchanged_element_type: _,
} => {
// creates the memory and connects its rdata/wdata/wmask
let new_mem = self.module_state.create_split_mem(
self.input_mem,
self.output_stmts,
self.element_type,
*single_type,
self.mem_name_path,
self.split_state_stack.top(),
);
// the remaining port fields are driven directly from the replacement wire
for (port, wire) in new_mem
.ports()
.into_iter()
.zip(self.mem_state.replacement_ports.iter())
{
let port_expr = port.to_expr();
let wire_expr = Expr::<Bundle>::from_canonical(wire.to_expr());
for name in [
Some("addr"),
Some("clk"),
Some("en"),
port.port_kind().wmode_name(),
] {
// `wmode_name()` is `None` for port kinds without a write-mode field
let Some(name) = name else {
continue;
};
self.output_stmts.push(
StmtConnect {
lhs: Expr::field(port_expr, name),
rhs: Expr::field(wire_expr, name),
source_location: port.source_location(),
}
.into(),
);
}
}
*output_mem = Some(new_mem);
self.output_mems.push(new_mem);
}
MemSplit::Array { elements } => {
let CanonicalType::Array(array_type) = self.element_type else {
unreachable!();
};
let element_type = array_type.element();
let element_bit_width = element_type.bit_width();
for (index, split) in elements.make_mut_slice().iter_mut().enumerate() {
self.mem_name_path.truncate(outer_mem_name_path_len);
write!(self.mem_name_path, "_{index}").unwrap();
// narrow the wires to this element and the initial value to this element's bits
self.split_state_stack.push_map(
|e| Expr::<Array>::from_canonical(e)[index],
|initial_value_element| {
&initial_value_element[index * element_bit_width..][..element_bit_width]
},
);
SplitMemState {
module_state: self.module_state,
input_mem: self.input_mem,
output_mems: self.output_mems,
output_stmts: self.output_stmts,
element_type,
split,
mem_name_path: self.mem_name_path,
split_state_stack: self.split_state_stack,
mem_state: self.mem_state,
}
.split_mem();
self.split_state_stack.pop();
}
}
}
}
}
/// Per-module state of the transform.
struct ModuleState {
/// The transformed module; `None` while the module is still being folded.
output_module: Option<Interned<Module<Bundle>>>,
/// State of every memory processed so far, keyed by the input memory's
/// scoped name.
memories: HashMap<ScopedNameId, MemState>,
}
impl ModuleState {
/// Connects the (possibly nested) array on the wire side to the flat array
/// on the output-memory-port side.
///
/// `input_array_types` lists the remaining nesting levels of the input type,
/// outermost first. Each recursion step indexes the wire expressions by the
/// current level's index and narrows the range
/// `memory_element_array_range_start..+memory_element_array_range_len` of the
/// flat port array. When no levels remain, the range must have length 1 and
/// the wire is connected to `port[memory_element_array_range_start]` using
/// the supplied `connect_*` callbacks.
#[allow(clippy::too_many_arguments)]
fn connect_split_mem_port_arrays(
output_stmts: &mut Vec<Stmt>,
input_array_types: &[Array],
memory_element_array_range_start: usize,
memory_element_array_range_len: usize,
wire_rdata: Option<Expr<CanonicalType>>,
wire_wdata: Option<Expr<CanonicalType>>,
wire_wmask: Option<Expr<CanonicalType>>,
port_rdata: Option<Expr<Array>>,
port_wdata: Option<Expr<Array>>,
port_wmask: Option<Expr<Array>>,
connect_rdata: impl Copy + Fn(&mut Vec<Stmt>, Expr<CanonicalType>, Expr<CanonicalType>),
connect_wdata: impl Copy + Fn(&mut Vec<Stmt>, Expr<CanonicalType>, Expr<CanonicalType>),
connect_wmask: impl Copy + Fn(&mut Vec<Stmt>, Expr<CanonicalType>, Expr<CanonicalType>),
) {
let Some((input_array_type, input_array_types_rest)) = input_array_types.split_first()
else {
// innermost level: exactly one element of the flat port array is left
assert_eq!(memory_element_array_range_len, 1);
if let (Some(wire), Some(port)) = (wire_rdata, port_rdata) {
connect_rdata(output_stmts, wire, port[memory_element_array_range_start]);
}
if let (Some(wire), Some(port)) = (wire_wdata, port_wdata) {
connect_wdata(output_stmts, wire, port[memory_element_array_range_start]);
}
if let (Some(wire), Some(port)) = (wire_wmask, port_wmask) {
connect_wmask(output_stmts, wire, port[memory_element_array_range_start]);
}
return;
};
if input_array_type.is_empty() {
return; // no need to connect zero-length arrays, also avoids division by zero
}
assert_eq!(memory_element_array_range_len % input_array_type.len(), 0);
// number of flat port elements covered by one element at this nesting level
let chunk_size = memory_element_array_range_len / input_array_type.len();
for index in 0..input_array_type.len() {
let map = |e| Expr::<Array>::from_canonical(e)[index];
let wire_rdata = wire_rdata.map(map);
let wire_wdata = wire_wdata.map(map);
let wire_wmask = wire_wmask.map(map);
Self::connect_split_mem_port_arrays(
output_stmts,
input_array_types_rest,
memory_element_array_range_start + chunk_size * index,
chunk_size,
wire_rdata,
wire_wdata,
wire_wmask,
port_rdata,
port_wdata,
port_wmask,
connect_rdata,
connect_wdata,
connect_wmask,
);
}
}
/// Emits the statements connecting one port's data fields (`rdata`, `wdata`,
/// `wmask`) between the replacement wire and the output memory port.
///
/// Read data flows port -> wire; write data and mask flow wire -> port.
/// Enum data is converted with bit casts since the output memory stores enums
/// as `UInt`. Array nesting of `input_element_type` is peeled off first; if
/// the wire and port types already line up the fields are connected whole,
/// otherwise they are connected element by element through
/// `connect_split_mem_port_arrays`.
///
/// Note: `self` is not used by the body.
#[allow(clippy::too_many_arguments)]
fn connect_split_mem_port(
&mut self,
output_stmts: &mut Vec<Stmt>,
mut input_element_type: CanonicalType,
single_type: SingleType,
source_location: SourceLocation,
wire_rdata: Option<Expr<CanonicalType>>,
wire_wdata: Option<Expr<CanonicalType>>,
wire_wmask: Option<Expr<CanonicalType>>,
port_rdata: Option<Expr<CanonicalType>>,
port_wdata: Option<Expr<CanonicalType>>,
port_wmask: Option<Expr<CanonicalType>>,
) {
// array types peeled off `input_element_type`, outermost first
let mut input_array_types = vec![];
// read direction: the wire is driven from the port
let connect_read = |output_stmts: &mut Vec<Stmt>,
wire_read: Expr<CanonicalType>,
port_read: Expr<CanonicalType>| {
output_stmts.push(
StmtConnect {
lhs: wire_read,
rhs: port_read,
source_location,
}
.into(),
);
};
// write direction: the port is driven from the wire
let connect_write = |output_stmts: &mut Vec<Stmt>,
wire_write: Expr<CanonicalType>,
port_write: Expr<CanonicalType>| {
output_stmts.push(
StmtConnect {
lhs: port_write,
rhs: wire_write,
source_location,
}
.into(),
);
};
// like `connect_read`, but casts the port's `UInt` bits to the wire's type
let connect_read_enum = |output_stmts: &mut Vec<Stmt>,
wire_read: Expr<CanonicalType>,
port_read: Expr<CanonicalType>| {
connect_read(
output_stmts,
wire_read,
Expr::<UInt>::from_canonical(port_read).cast_bits_to(Expr::ty(wire_read)),
);
};
// like `connect_write`, but casts the wire's value to bits first
let connect_write_enum =
|output_stmts: &mut Vec<Stmt>,
wire_write: Expr<CanonicalType>,
port_write: Expr<CanonicalType>| {
connect_write(
output_stmts,
Expr::canonical(wire_write.cast_to_bits()),
port_write,
);
};
loop {
match input_element_type {
CanonicalType::Bundle(_) => unreachable!("bundle types are always split"),
// enum leaf where there is no array nesting, or the outermost array type
// already equals the output memory's array type
CanonicalType::Enum(_)
if input_array_types
.first()
.map(|&v| single_type.is_array_type(v))
.unwrap_or(true) =>
{
if let (Some(wire_rdata), Some(port_rdata)) = (wire_rdata, port_rdata) {
connect_read_enum(output_stmts, wire_rdata, port_rdata);
}
if let (Some(wire_wdata), Some(port_wdata)) = (wire_wdata, port_wdata) {
connect_write_enum(output_stmts, wire_wdata, port_wdata);
}
// the mask is connected without any bit cast
if let (Some(wire_wmask), Some(port_wmask)) = (wire_wmask, port_wmask) {
connect_write(output_stmts, wire_wmask, port_wmask);
}
}
// enum leaf under array nesting that doesn't match the port type:
// connect element by element
CanonicalType::Enum(_) => Self::connect_split_mem_port_arrays(
output_stmts,
&input_array_types,
0,
single_type.array_len(),
wire_rdata,
wire_wdata,
wire_wmask,
port_rdata.map(Expr::from_canonical),
port_wdata.map(Expr::from_canonical),
port_wmask.map(Expr::from_canonical),
connect_read_enum,
connect_write_enum,
connect_write,
),
CanonicalType::Array(array_type) => {
// peel one array level and look at its element type next
input_array_types.push(array_type);
input_element_type = array_type.element();
continue;
}
// integer/bool leaf where wire and port types line up: connect whole fields
CanonicalType::UInt(_) | CanonicalType::SInt(_) | CanonicalType::Bool(_)
if input_array_types
.first()
.map(|&v| single_type.is_array_type(v))
.unwrap_or(true) =>
{
if let (Some(wire_rdata), Some(port_rdata)) = (wire_rdata, port_rdata) {
connect_read(output_stmts, wire_rdata, port_rdata);
}
if let (Some(wire_wdata), Some(port_wdata)) = (wire_wdata, port_wdata) {
connect_write(output_stmts, wire_wdata, port_wdata);
}
if let (Some(wire_wmask), Some(port_wmask)) = (wire_wmask, port_wmask) {
connect_write(output_stmts, wire_wmask, port_wmask);
}
}
// integer/bool leaf under non-matching array nesting: connect element by element
CanonicalType::UInt(_) | CanonicalType::SInt(_) | CanonicalType::Bool(_) => {
Self::connect_split_mem_port_arrays(
output_stmts,
&input_array_types,
0,
single_type.array_len(),
wire_rdata,
wire_wdata,
wire_wmask,
port_rdata.map(Expr::from_canonical),
port_wdata.map(Expr::from_canonical),
port_wmask.map(Expr::from_canonical),
connect_read,
connect_write,
connect_write,
)
}
CanonicalType::Clock(_)
| CanonicalType::AsyncReset(_)
| CanonicalType::SyncReset(_)
| CanonicalType::Reset(_) => unreachable!("memory element type is a storable type"),
}
break;
}
}
/// Creates the output memory for one `Single` leaf and connects the data
/// fields of each of its ports to the wires in `split_state`.
///
/// The new memory is named `<input memory name><mem_name_path>` with a fresh
/// id, has the same number of elements, ports, latencies, read-under-write
/// setting and memory annotations as `input_mem`, and uses `single_type` as
/// its element type. Its initial value (if any) is the concatenation of the
/// per-element slices in `split_state`. Port annotations are currently
/// dropped (see the TODO below).
fn create_split_mem(
&mut self,
input_mem: Mem,
output_stmts: &mut Vec<Stmt>,
input_element_type: CanonicalType,
single_type: SingleType,
mem_name_path: &str,
split_state: &SplitState<'_>,
) -> Mem {
let mem_name = NameId(
Intern::intern_owned(format!("{}{mem_name_path}", input_mem.scoped_name().1 .0)),
Id::new(),
);
// keep the new memory in the same scope as the input memory
let mem_name = ScopedNameId(input_mem.scoped_name().0, mem_name);
let output_element_type = match single_type {
SingleType::UInt(ty) => ty.canonical(),
SingleType::SInt(ty) => ty.canonical(),
SingleType::Bool(ty) => ty.canonical(),
SingleType::UIntArray(ty) => ty.canonical(),
SingleType::SIntArray(ty) => ty.canonical(),
SingleType::BoolArray(ty) => ty.canonical(),
};
let output_array_type = Array::new_dyn(output_element_type, input_mem.array_type().len());
let initial_value = split_state.initial_value.as_ref().map(|initial_value| {
let mut bits = BitVec::with_capacity(output_array_type.type_properties().bit_width);
for element in initial_value.iter() {
bits.extend_from_bitslice(element);
}
Intern::intern_owned(bits)
});
// one port per input port, differing only in the memory name and element type
let ports = input_mem
.ports()
.into_iter()
.map(|port| {
MemPort::new_unchecked(
mem_name,
port.source_location(),
port.port_name(),
port.addr_type(),
output_element_type,
)
})
.collect();
let output_mem = Mem::new_unchecked(
mem_name,
input_mem.source_location(),
output_array_type,
initial_value,
ports,
input_mem.read_latency(),
input_mem.write_latency(),
input_mem.read_under_write(),
input_mem
.port_annotations()
.iter()
.flat_map(|_v| -> Option<TargetedAnnotation> {
// TODO: map annotation target for memory port
None
})
.collect(),
input_mem.mem_annotations(),
);
for (index, port) in ports.into_iter().enumerate() {
let SplitState {
wire_rdata,
wire_wdata,
wire_wmask,
initial_value: _,
} = split_state;
let port_expr = port.to_expr();
// each of these is `None` when the port kind has no such field
let port_rdata = port
.port_kind()
.rdata_name()
.map(|name| Expr::field(port_expr, name));
let port_wdata = port
.port_kind()
.wdata_name()
.map(|name| Expr::field(port_expr, name));
let port_wmask = port
.port_kind()
.wmask_name()
.map(|name| Expr::field(port_expr, name));
self.connect_split_mem_port(
output_stmts,
input_element_type,
single_type,
port.source_location(),
wire_rdata[index],
wire_wdata[index],
wire_wmask[index],
port_rdata,
port_wdata,
port_wmask,
);
}
output_mem
}
/// Processes one input memory, appending the resulting memories to
/// `output_mems` and any new wires/connections to `output_stmts`, and records
/// the memory's replacement ports in `self.memories`.
///
/// A memory whose element type needs no change is passed through as-is.
/// Otherwise a wire named `<memory name>_<port name>` with the port's type is
/// created for every port, uses of the port are redirected to that wire, and
/// the memory is split via `SplitMemState::split_mem`.
fn process_mem(
&mut self,
input_mem: Mem,
output_mems: &mut Vec<Mem>,
output_stmts: &mut Vec<Stmt>,
) {
let element_type = input_mem.array_type().element();
let mut split = MemSplit::new(element_type);
let mem_state = match split {
MemSplit::Single {
ref mut output_mem,
element_type: _,
unchanged_element_type: true,
} => {
// no change necessary
*output_mem = Some(input_mem);
output_mems.push(input_mem);
MemState {
replacement_ports: input_mem
.ports()
.into_iter()
.map(ExprEnum::MemPort)
.collect(),
}
}
MemSplit::Single {
unchanged_element_type: false,
..
}
| MemSplit::Bundle { .. }
| MemSplit::Array { .. } => {
let mut replacement_ports = Vec::with_capacity(input_mem.ports().len());
let mut wire_port_rdata = Vec::with_capacity(input_mem.ports().len());
let mut wire_port_wdata = Vec::with_capacity(input_mem.ports().len());
let mut wire_port_wmask = Vec::with_capacity(input_mem.ports().len());
for port in input_mem.ports() {
let port_ty = port.ty();
let NameId(mem_name, _) = input_mem.scoped_name().1;
let port_name = port.port_name();
let wire_name = NameId(
Intern::intern_owned(format!("{mem_name}_{port_name}")),
Id::new(),
);
// the wire stands in for the original port, so it has the port's type
let wire = Wire::new_unchecked(
ScopedNameId(input_mem.scoped_name().0, wire_name),
port.source_location(),
port_ty,
);
let wire_expr = wire.to_expr();
let canonical_wire = wire.canonical();
output_stmts.push(
StmtWire {
annotations: Default::default(),
wire: canonical_wire,
}
.into(),
);
replacement_ports.push(ExprEnum::Wire(canonical_wire));
// data fields of the wire; `None` when the port kind lacks the field
wire_port_rdata.push(
port.port_kind()
.rdata_name()
.map(|name| Expr::field(wire_expr, name)),
);
wire_port_wdata.push(
port.port_kind()
.wdata_name()
.map(|name| Expr::field(wire_expr, name)),
);
wire_port_wmask.push(
port.port_kind()
.wmask_name()
.map(|name| Expr::field(wire_expr, name)),
);
}
let mem_state = MemState {
replacement_ports: replacement_ports.into_boxed_slice(),
};
SplitMemState {
module_state: self,
input_mem,
output_mems,
output_stmts,
element_type,
split: &mut split,
mem_name_path: &mut String::with_capacity(32),
split_state_stack: &mut SplitStateStack::new(
input_mem.ports().len(),
SplitState {
wire_rdata: wire_port_rdata.into_boxed_slice(),
wire_wdata: wire_port_wdata.into_boxed_slice(),
wire_wmask: wire_port_wmask.into_boxed_slice(),
// split the flat initial value into one slice per memory element
initial_value: input_mem.initial_value().as_ref().map(
|initial_value| {
if initial_value.len() == 0 {
// an empty initial value has no per-element slices; this
// also avoids calling `chunks` with a chunk size of zero
Box::new([])
} else {
Box::from_iter(initial_value.chunks(
initial_value.len() / input_mem.array_type().len(),
))
}
},
),
},
),
mem_state: &mem_state,
}
.split_mem();
mem_state
}
};
self.memories.insert(input_mem.scoped_name(), mem_state);
}
}
/// Top-level state of the transform; implements `Folder`.
#[derive(Default)]
struct State {
/// State of every module seen so far, keyed by the input module.
modules: HashMap<Interned<Module<Bundle>>, ModuleState>,
/// The input module currently being folded, if any.
current_module: Option<Interned<Module<Bundle>>>,
}
impl State {
fn module_state(&mut self) -> &mut ModuleState {
let current_module = self.current_module.unwrap();
self.modules.get_mut(&current_module).unwrap()
}
}
/// Guard that makes a module the current one in `State` and restores the
/// previous current module when dropped.
struct PushedState<'a> {
/// The wrapped state.
state: &'a mut State,
/// The value `current_module` had before this guard was created.
old_module: Option<Interned<Module<Bundle>>>,
}
impl<'a> PushedState<'a> {
    /// Makes `module` the current module of `state`, remembering the previous
    /// one so it can be restored on drop.
    fn push_module(state: &'a mut State, module: Interned<Module<Bundle>>) -> Self {
        let old_module = std::mem::replace(&mut state.current_module, Some(module));
        Self { state, old_module }
    }
}
impl Drop for PushedState<'_> {
    /// Restores the module that was current before this guard was created.
    fn drop(&mut self) {
        let previous = std::mem::take(&mut self.old_module);
        self.state.current_module = previous;
    }
}
impl Deref for PushedState<'_> {
    type Target = State;

    /// Gives shared access to the wrapped `State`.
    fn deref(&self) -> &State {
        &*self.state
    }
}
impl DerefMut for PushedState<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.state
}
}
impl Folder for State {
    type Error = Infallible;

    /// Folds a module once and caches the result; later requests for the same
    /// input module return the cached output.
    ///
    /// Panics if the module is requested again while it is still being folded.
    fn fold_module<T: BundleType>(&mut self, v: Module<T>) -> Result<Module<T>, Self::Error> {
        let input_module: Interned<_> = v.canonical().intern_sized();
        if let Some(already_seen) = self.modules.get(&input_module) {
            let finished = already_seen
                .output_module
                .expect("modules can't be mutually recursive");
            return Ok(Module::from_canonical(*finished));
        }
        let fresh_state = ModuleState {
            output_module: None,
            memories: HashMap::new(),
        };
        self.modules.insert(input_module, fresh_state);
        // `pushed` restores the previous current module when it goes out of scope
        let mut pushed = PushedState::push_module(self, input_module);
        let output_module = input_module.default_fold(&mut *pushed)?;
        pushed.module_state().output_module = Some(output_module);
        Ok(Module::from_canonical(*output_module))
    }

    /// Never called: memories are handled in `fold_block`.
    fn fold_mem<Element: Type, Len: Size>(
        &mut self,
        _v: Mem<Element, Len>,
    ) -> Result<Mem<Element, Len>, Self::Error> {
        unreachable!()
    }

    /// Processes all memories of the block first (which may emit new wires and
    /// connections), then folds the block's own statements after them.
    fn fold_block(&mut self, v: Block) -> Result<Block, Self::Error> {
        let mut memories = Vec::new();
        let mut stmts = Vec::new();
        {
            let module_state = self.module_state();
            for mem in v.memories {
                module_state.process_mem(mem, &mut memories, &mut stmts);
            }
        }
        for stmt in v.stmts {
            let folded = match stmt.fold(self) {
                Ok(folded) => folded,
                Err(never) => match never {},
            };
            stmts.push(folded);
        }
        Ok(Block {
            memories: Intern::intern_owned(memories),
            stmts: Intern::intern_owned(stmts),
        })
    }

    /// Replaces a memory-port expression by the replacement recorded for that
    /// port; every other expression is folded the default way.
    fn fold_expr_enum(&mut self, v: ExprEnum) -> Result<ExprEnum, Self::Error> {
        match v {
            ExprEnum::MemPort(mem_port) => {
                let mem_state = self
                    .module_state()
                    .memories
                    .get(&mem_port.mem_name())
                    .expect("all uses of a memory must come after the memory is declared");
                Ok(mem_state.replacement_ports[mem_port.port_index()])
            }
            other => other.default_fold(self),
        }
    }

    /// Never called: memory ports are handled in `fold_expr_enum`.
    fn fold_mem_port<T: PortType>(&mut self, _v: MemPort<T>) -> Result<MemPort<T>, Self::Error> {
        unreachable!()
    }
}
/// Runs the memory-simplification transform on `module` and returns the
/// transformed module.
pub fn simplify_memories(module: Interned<Module<Bundle>>) -> Interned<Module<Bundle>> {
    let mut state = State::default();
    match module.fold(&mut state) {
        Ok(simplified) => simplified,
        Err(never) => match never {},
    }
}