cpu/crates/cpu/src/reg_alloc.rs

626 lines
24 KiB
Rust

// SPDX-License-Identifier: LGPL-3.0-or-later
// See Notices.txt for copyright information
use crate::{
config::CpuConfig,
instruction::{
COMMON_MOP_SRC_LEN, MOp, MOpDestReg, MOpRegNum, MOpTrait, MoveRegMOp, PRegNum,
RenameTableName, UnitOutRegNum,
},
unit::{
GlobalState, TrapData, UnitMOp, UnitOutput, UnitOutputWrite, UnitResult,
UnitResultCompleted, UnitTrait,
unit_base::{UnitForwardingInfo, UnitInput},
},
util::tree_reduce::tree_reduce_with_state,
};
use fayalite::{
memory::{WriteStruct, splat_mask},
module::{instance_with_loc, memory_with_loc, wire_with_loc},
prelude::*,
util::ready_valid::ReadyValid,
};
use std::{
collections::{BTreeMap, VecDeque},
num::NonZeroUsize,
};
pub mod unit_free_regs_tracker;
#[hdl]
/// A micro-operation delivered by the fetch/decode stage, along with the `pc` of the
/// instruction it was decoded from.
pub struct FetchedDecodedMOp {
    /// the decoded micro-operation (its registers are not yet renamed)
    pub mop: MOp,
    /// `true` if `pc` doesn't have to be related to the previous instruction's `pc`.
    /// Set this to stop checking whether the current instruction is the one that is
    /// supposed to run next (e.g. after a branch mis-prediction).
    pub is_unrelated_pc: Bool,
    /// address of the instruction this micro-operation was decoded from
    pub pc: UInt<64>,
}
#[hdl]
/// A special request carried by `FetchDecodeInterface::fetch_decode_special_op`, which
/// flows in the opposite direction to the decoded instructions (it is `#[hdl(flip)]`).
pub enum FetchDecodeSpecialOp {
    /// report a trap, described by the contained `TrapData`
    Trap(TrapData),
    /// presumably requests an instruction-cache flush -- NOTE(review): confirm against
    /// the fetch/decode implementation
    ICacheFlush,
}
#[hdl]
/// Connection between the fetch/decode stage and `reg_alloc`.
///
/// `FetchWidth` is the number of fetch slots, i.e. the maximum number of decoded
/// micro-operations that can be handed over per cycle.
pub struct FetchDecodeInterface<FetchWidth: Size> {
    /// one ready/valid channel per fetch slot carrying a decoded micro-operation
    pub decoded_insns: ArrayType<ReadyValid<FetchedDecodedMOp>, FetchWidth>,
    /// flipped: driven in the opposite direction to `decoded_insns`
    /// (`reg_alloc` drives its `data`)
    #[hdl(flip)]
    pub fetch_decode_special_op: ReadyValid<FetchDecodeSpecialOp>,
}
#[hdl]
/// What the reorder buffer records about an instruction after renaming.
struct ROBRenamedInsn<UnitNumWidth: Size, OutRegNumWidth: Size> {
    /// the instruction's destination register(s) as decoded, before renaming
    mop_dest: MOpDestReg,
    /// physical register allocated to hold the instruction's result
    p_dest: PRegNum<UnitNumWidth, OutRegNumWidth>,
}
#[hdl]
/// One slot of the reorder buffer.
struct ROBEntry<UnitNumWidth: Size, OutRegNumWidth: Size> {
    renamed_insn: ROBRenamedInsn<UnitNumWidth, OutRegNumWidth>,
    /// cleared when the entry is inserted; set once some unit reports writing
    /// `renamed_insn.p_dest`
    dest_written: Bool,
}
#[hdl_module]
/// Reorder buffer (ROB) recording renamed instructions in the order they are accepted.
///
/// The ROB is a circular buffer of `config.rob_size` entries. The valid entries are the
/// indexes from `rob_valid_start` (inclusive) to `rob_valid_end` (exclusive), wrapping
/// around at `config.rob_size`; the buffer is empty when the two indexes are equal.
/// Because equal indexes mean "empty", one entry is always left unused so a full buffer
/// can be told apart from an empty one, giving a usable capacity of `config.rob_size - 1`.
///
/// Each cycle, up to `config.fetch_width` instructions are accepted from
/// `renamed_insns_in` (slot `i` is ready only while at least `i + 1` entries are free) and
/// appended at `rob_valid_end`. An entry's `dest_written` flag is set when a unit reports,
/// through `unit_forwarding_info.unit_output_writes`, that it wrote the entry's `p_dest`.
///
/// Retirement is not implemented yet: `rob_valid_start` is never advanced.
///
/// # Panics
///
/// Panics during elaboration if `config.rob_size < config.fetch_width`.
fn rob(config: &CpuConfig) {
    #[hdl]
    let cd: ClockDomain = m.input();
    #[hdl]
    let renamed_insns_in: Array<ReadyValid<ROBRenamedInsn<DynSize, DynSize>>> = m.input(
        Array[ReadyValid[ROBRenamedInsn[config.unit_num_width()][config.out_reg_num_width]]]
            [config.fetch_width.get()],
    );
    #[hdl]
    let unit_forwarding_info: UnitForwardingInfo<DynSize, DynSize, DynSize> =
        m.input(config.unit_forwarding_info());
    // Checked up front: computing the next `rob_valid_end` below uses a single
    // wrap-around subtraction, which only works if at most `rob_size` instructions
    // can be appended per cycle.
    assert!(
        config.rob_size >= config.fetch_width,
        "rob_size ({}) is too small for fetch_width = {} -- next_write_index would overflow",
        config.rob_size,
        config.fetch_width,
    );
    let rob_entry_ty = ROBEntry[config.unit_num_width()][config.out_reg_num_width];
    #[hdl]
    let rob = reg_builder()
        .clock_domain(cd)
        .no_reset(Array[rob_entry_ty][config.rob_size.get()]);
    #[hdl]
    let rob_valid_start = reg_builder()
        .clock_domain(cd)
        .reset(UInt::range(0..config.rob_size.get()).zero());
    #[hdl]
    let rob_valid_end = reg_builder()
        .clock_domain(cd)
        .reset(UInt::range(0..config.rob_size.get()).zero());
    // Number of entries that can still be filled. One entry is kept in reserve so that
    // `rob_valid_start == rob_valid_end` always means "empty", never "full".
    #[hdl]
    let free_space = wire(UInt::range_inclusive(0..=config.rob_size.get()));
    #[hdl]
    if rob_valid_end.cmp_lt(rob_valid_start) {
        // rob_valid_end wrapped around but start didn't: the free entries are
        // `rob_valid_end..rob_valid_start`, minus the reserved entry
        connect_any(free_space, rob_valid_start - rob_valid_end - 1usize);
    } else {
        // the occupied entries are `rob_valid_start..rob_valid_end`, so
        // `rob_size - (end - start)` entries are free, minus the reserved entry
        connect_any(
            free_space,
            rob_valid_start + (config.rob_size.get() - 1) - rob_valid_end,
        );
    }
    // Mark entries whose destination register is written this cycle.
    // This is connected *before* the insertion logic below so that, under last-connect
    // semantics, a newly inserted entry always starts with `dest_written: false` rather
    // than inheriting a match against the stale entry it replaces.
    // NOTE(review): entries outside `rob_valid_start..rob_valid_end` are still compared;
    // that's harmless since they are fully overwritten when reused.
    // TODO: optimize better, O(rob_size * unit_count) is too big here
    for rob_index in 0..config.rob_size.get() {
        for unit_index in 0..config.non_const_unit_nums().len() {
            #[hdl]
            if let HdlSome(unit_output_write) = unit_forwarding_info.unit_output_writes[unit_index]
            {
                #[hdl]
                let UnitOutputWrite::<_> {
                    which: unit_out_reg,
                    value: _,
                } = unit_output_write;
                let p_reg_num = #[hdl]
                PRegNum::<_, _> {
                    unit_num: config.unit_num().from_index(unit_index),
                    unit_out_reg,
                };
                #[hdl]
                if rob[rob_index].renamed_insn.p_dest.cmp_eq(p_reg_num) {
                    connect(rob[rob_index].dest_written, true);
                }
            }
        }
    }
    // Write index for the current fetch slot together with the (Rust-side) range of
    // values it can take; each slot's index can be at most one larger than the previous.
    struct IndexAndRange {
        index: Expr<UInt>,
        range: std::ops::Range<usize>,
    }
    let mut next_write_index = IndexAndRange {
        index: rob_valid_end,
        range: 0..config.rob_size.get(),
    };
    for fetch_index in 0..config.fetch_width.get() {
        let write_index = next_write_index;
        let next_write_index_range = write_index.range.start..write_index.range.end + 1;
        next_write_index = IndexAndRange {
            index: wire_with_loc(
                &format!("next_write_index_{fetch_index}"),
                SourceLocation::caller(),
                UInt::range(next_write_index_range.clone()),
            ),
            range: next_write_index_range,
        };
        connect(
            renamed_insns_in[fetch_index].ready,
            fetch_index.cmp_lt(free_space),
        );
        #[hdl]
        if let HdlSome(renamed_insn) = ReadyValid::firing_data(renamed_insns_in[fetch_index]) {
            // the write index may exceed `rob_size` (not yet wrapped), so reduce it
            // modulo `rob_size` when selecting the entry
            for i in write_index.range.clone() {
                #[hdl]
                if write_index.index.cmp_eq(i) {
                    connect(
                        rob[i % config.rob_size.get()],
                        #[hdl]
                        ROBEntry {
                            renamed_insn,
                            dest_written: false,
                        },
                    );
                }
            }
        }
        // TODO: optimize write_index chain better
        connect_any(
            next_write_index.index,
            write_index.index
                + ReadyValid::firing(renamed_insns_in[fetch_index]).cast_to_static::<UInt<1>>(),
        );
    }
    // wrap the final write index back into `0..rob_size`; one subtraction is enough
    // because `rob_size >= fetch_width` was asserted above
    #[hdl]
    if next_write_index.index.cmp_lt(config.rob_size.get()) {
        connect_any(rob_valid_end, next_write_index.index);
    } else {
        connect_any(
            rob_valid_end,
            next_write_index.index - config.rob_size.get(),
        );
    }
}
#[hdl_module]
/// Combined register allocator, register renaming, unit selection, and retire handling.
///
/// Each cycle, for every fetch slot of `fetch_decode_interface.decoded_insns`:
/// - the units able to execute the decoded `MOp`'s kind are looked up, and one unit is
///   selected among those that are not already taken by an earlier slot, have a free
///   output register, and are ready for input (lower indexes are preferred -- assuming
///   `tree_reduce_with_state` combines its leaves in order);
/// - the selected unit's free-register tracker supplies the physical destination
///   register (`PRegNum`);
/// - source registers are renamed by reading the rename tables, with bypassing from
///   earlier slots in the same cycle, and the rename tables are updated with the new
///   destination mapping;
/// - the renamed `MOp` is sent to the selected unit.
///
/// Unit results are shared with every unit and with the ROB through a common
/// `UnitForwardingInfo` wire.
///
/// This is still work in progress (see the `TODO`/`FIXME` comments): nothing is inserted
/// into the ROB yet, traps/special ops/cancellation are not handled, and the
/// free-register trackers are fed placeholder data.
pub fn reg_alloc(config: &CpuConfig) {
    #[hdl]
    let cd: ClockDomain = m.input();
    #[hdl]
    let fetch_decode_interface: FetchDecodeInterface<DynSize> =
        m.input(FetchDecodeInterface[config.fetch_width.get()]);
    #[hdl]
    let global_state: GlobalState = m.input();
    // TODO: propagate traps, branch mis-predictions, and special ops
    // (for now no special op is ever requested from fetch/decode)
    connect(
        fetch_decode_interface.fetch_decode_special_op.data,
        HdlNone(),
    );
    // TODO: finish
    #[hdl]
    let rob = instance(rob(config));
    connect(rob.cd, cd);
    // One memory per distinct rename table, mapping a register number within the
    // table's `reg_range()` (offset by the range start) to the `PRegNum` holding it.
    // Several register kinds may share a table, so tables are de-duplicated by name.
    let mut rename_table_mems = BTreeMap::<RenameTableName, MemBuilder<_>>::new();
    for reg_kind in MOpDestReg::REG_KINDS {
        for &rename_table_name in reg_kind.rename_table_names() {
            rename_table_mems
                .entry(rename_table_name)
                .or_insert_with(|| {
                    let mut mem = memory_with_loc(
                        &format!("{}_mem", rename_table_name.as_str()),
                        config.p_reg_num(),
                        SourceLocation::caller(),
                    );
                    mem.depth(rename_table_name.reg_range().len());
                    mem
                });
        }
    }
    // per fetch slot: which units could take the slot's instruction
    #[hdl]
    let available_units = wire(Array[Array[Bool][config.units.len()]][config.fetch_width.get()]);
    // per fetch slot: the unit chosen for the slot's instruction, if any
    #[hdl]
    let selected_unit_indexes =
        wire(Array[HdlOption[UInt[config.unit_num_width()]]][config.fetch_width.get()]);
    // per fetch slot: the instruction after register renaming, if any
    #[hdl]
    let renamed_mops =
        wire(Array[HdlOption[UnitInput[config.renamed_mop_in_unit()]]][config.fetch_width.get()]);
    // per fetch slot: the physical destination register allocated, if any
    #[hdl]
    let renamed_mops_out_reg = wire(Array[HdlOption[config.p_reg_num()]][config.fetch_width.get()]);
    for fetch_index in 0..config.fetch_width.get() {
        // TODO: finish
        connect(
            rob.renamed_insns_in[fetch_index].data,
            rob.ty().renamed_insns_in.element().data.HdlNone(),
        );
        // TODO: finish
        connect(
            fetch_decode_interface.decoded_insns[fetch_index].ready,
            true,
        );
        // slots are accepted in order: a slot is only ready if every earlier slot is
        for prev_fetch_index in 0..fetch_index {
            #[hdl]
            if !fetch_decode_interface.decoded_insns[prev_fetch_index].ready {
                connect(
                    fetch_decode_interface.decoded_insns[fetch_index].ready,
                    false,
                );
            }
        }
        // defaults; overridden below when the slot holds an instruction
        connect(
            available_units[fetch_index],
            repeat(false, config.units.len()),
        );
        connect(
            renamed_mops[fetch_index],
            renamed_mops.ty().element().HdlNone(),
        );
        // A lookup request into the rename tables: `addr` is the register to look up
        // and `data` (flipped) is the `PRegNum` it currently maps to.
        #[hdl]
        struct RenameTableReadPort<T> {
            addr: MOpRegNum,
            #[hdl(flip)]
            data: T,
        }
        // One lookup wire per source operand. Each rename table gets its own read port
        // for every source operand, enabled only when `addr` falls in that table's range.
        let rename_table_read_ports: [_; COMMON_MOP_SRC_LEN] = std::array::from_fn(|src_index| {
            let wire = wire_with_loc(
                &format!("rename_{fetch_index}_src_{src_index}"),
                SourceLocation::caller(),
                RenameTableReadPort[config.p_reg_num()],
            );
            connect(wire.addr, MOpRegNum::const_zero());
            connect(wire.data, config.p_reg_num().const_zero());
            for (&rename_table_name, mem) in &mut rename_table_mems {
                let read_port = mem.new_read_port();
                connect(read_port.clk, cd.clk);
                connect_any(read_port.addr, 0u8);
                connect(read_port.en, false);
                let reg_range = rename_table_name.reg_range();
                #[hdl]
                if wire.addr.value.cmp_ge(reg_range.start) & wire.addr.value.cmp_lt(reg_range.end) {
                    connect_any(read_port.addr, wire.addr.value - reg_range.start);
                    connect(read_port.en, true);
                    connect(wire.data, read_port.data);
                    // Bypass: if an earlier slot in this same cycle renames the register
                    // being read, use that slot's new mapping instead of the table's.
                    // Later `prev_fetch_index` values connect last, so the closest
                    // earlier slot wins.
                    for prev_fetch_index in 0..fetch_index {
                        #[hdl]
                        if let HdlSome(decoded_insn) =
                            fetch_decode_interface.decoded_insns[prev_fetch_index].data
                        {
                            #[hdl]
                            if let HdlSome(renamed_mop_out_reg) =
                                renamed_mops_out_reg[prev_fetch_index]
                            {
                                let dest_reg = MOpTrait::dest_reg(decoded_insn.mop);
                                for (dest_reg, reg_kind) in MOpDestReg::regs(dest_reg)
                                    .into_iter()
                                    .zip(MOpDestReg::REG_KINDS)
                                {
                                    if reg_kind.rename_table_names().contains(&rename_table_name) {
                                        #[hdl]
                                        if dest_reg.value.cmp_eq(wire.addr.value) {
                                            connect(wire.data, renamed_mop_out_reg);
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
            wire
        });
        // One write port per (register kind, rename table) pair for this slot, queued
        // per table. They are pushed here in `REG_KINDS` x `rename_table_names()` order
        // and popped in the same order below, so each register kind gets its own port.
        let mut rename_table_write_ports = BTreeMap::<RenameTableName, VecDeque<_>>::new();
        for reg_kind in MOpDestReg::REG_KINDS {
            for &rename_table_name in reg_kind.rename_table_names() {
                let mem = rename_table_mems
                    .get_mut(&rename_table_name)
                    .expect("already added all RenameTableName values");
                let write_ports = rename_table_write_ports
                    .entry(rename_table_name)
                    .or_default();
                let write_port_ = mem.new_write_port();
                let table_name = rename_table_name.as_str();
                let write_port = wire_with_loc(
                    &format!("{table_name}_{fetch_index}_{}", reg_kind.reg_name()),
                    SourceLocation::caller(),
                    write_port_.ty(),
                );
                connect(write_port_, write_port);
                write_ports.push_back(write_port);
                // default: write disabled
                connect_any(
                    write_port,
                    #[hdl]
                    WriteStruct::<_, _> {
                        addr: 0_hdl_u0,
                        en: false,
                        clk: cd.clk,
                        data: write_port.data.ty().uninit(),
                        mask: splat_mask(config.p_reg_num(), true.to_expr()),
                    },
                );
            }
        }
        #[hdl]
        if let HdlSome(decoded_insn) = fetch_decode_interface.decoded_insns[fetch_index].data {
            connect(
                available_units[fetch_index],
                config.available_units_for_kind(MOp::kind(decoded_insn.mop)),
            );
            // renaming only happens once a destination register has been allocated
            #[hdl]
            if let HdlSome(renamed_mop_out_reg) = renamed_mops_out_reg[fetch_index] {
                let dest_reg = MOpTrait::dest_reg(decoded_insn.mop);
                // replace every source register with the `PRegNum` read from the
                // rename tables (through the lookup wires built above)
                let renamed_mop = UnitMOp::try_with_transformed_move_op(
                    MOpTrait::map_regs(
                        decoded_insn.mop,
                        renamed_mop_out_reg.unit_out_reg,
                        config.p_reg_num_width(),
                        &mut |src_reg, src_index| {
                            connect(
                                rename_table_read_ports[src_index].addr,
                                #[hdl]
                                MOpRegNum { value: src_reg },
                            );
                            rename_table_read_ports[src_index].data.cast_to_bits()
                        },
                    ),
                    config.renamed_mop_in_unit().TransformedMove,
                    |renamed_mop, renamed_move_op: Expr<MoveRegMOp<_, _>>| {
                        // TODO: finish handling MoveRegMOp
                        connect(renamed_mop, renamed_mop.ty().HdlNone());
                    },
                );
                connect(
                    renamed_mops[fetch_index],
                    HdlOption::map(renamed_mop, |mop| {
                        #[hdl]
                        UnitInput::<_> {
                            mop,
                            pc: decoded_insn.pc,
                        }
                    }),
                );
                // record the new mapping for each destination register in every rename
                // table its kind uses (only the table whose range contains it is written)
                for (reg, reg_kind) in MOpDestReg::regs(dest_reg)
                    .into_iter()
                    .zip(MOpDestReg::REG_KINDS)
                {
                    for &rename_table_name in reg_kind.rename_table_names() {
                        let Some(write_ports) =
                            rename_table_write_ports.get_mut(&rename_table_name)
                        else {
                            unreachable!();
                        };
                        let Some(write_port) = write_ports.pop_front() else {
                            unreachable!();
                        };
                        let reg_range = rename_table_name.reg_range();
                        #[hdl]
                        if reg.value.cmp_ge(reg_range.start) & reg.value.cmp_lt(reg_range.end) {
                            connect(write_port.data, renamed_mop_out_reg);
                            // kinds with a fixed register number always write that entry
                            if let Some(fixed_reg_num) = reg_kind.fixed_reg_num() {
                                connect_any(write_port.addr, fixed_reg_num - reg_range.start);
                            } else {
                                connect_any(write_port.addr, reg.value - reg_range.start);
                            }
                            connect(write_port.en, true);
                        }
                    }
                }
            }
        }
        // Pick a unit from `available_units[fetch_index]`: each leaf is `HdlSome(index)`
        // if that unit is available, and each node keeps `l` unless it is `HdlNone`,
        // so the earlier-ordered available unit wins.
        connect(
            selected_unit_indexes[fetch_index],
            tree_reduce_with_state(
                0..config.units.len(),
                &mut 0usize,
                |_state, unit_index| {
                    let selected_unit_index_leaf = wire_with_loc(
                        &format!("selected_unit_index_leaf_{fetch_index}_{unit_index}"),
                        SourceLocation::caller(),
                        HdlOption[UInt[config.unit_num_width()]],
                    );
                    connect(
                        selected_unit_index_leaf,
                        selected_unit_index_leaf.ty().HdlNone(),
                    );
                    let unit_index_wire = wire_with_loc(
                        &format!("unit_index_{fetch_index}_{unit_index}"),
                        SourceLocation::caller(),
                        UInt[config.unit_num_width()],
                    );
                    connect_any(unit_index_wire, unit_index);
                    #[hdl]
                    if available_units[fetch_index][unit_index] {
                        connect(selected_unit_index_leaf, HdlSome(unit_index_wire))
                    }
                    selected_unit_index_leaf
                },
                |state, l, r| {
                    // `state` only numbers the node wires so their names are unique
                    let selected_unit_index_node = wire_with_loc(
                        &format!("selected_unit_index_node_{fetch_index}_{state}"),
                        SourceLocation::caller(),
                        l.ty(),
                    );
                    *state += 1;
                    connect(selected_unit_index_node, l);
                    #[hdl]
                    if let HdlNone = l {
                        connect(selected_unit_index_node, r);
                    }
                    selected_unit_index_node
                },
            )
            .expect("expected at least one unit"),
        );
    }
    // must come after to override connects in loop above
    // (a unit chosen by an earlier slot is not available to any later slot this cycle)
    for fetch_index in 0..config.fetch_width.get() {
        // TODO: handle assigning multiple instructions to a unit at a time
        for later_fetch_index in fetch_index + 1..config.fetch_width.get() {
            #[hdl]
            if let HdlSome(selected_unit_index) = selected_unit_indexes[fetch_index] {
                connect(
                    available_units[later_fetch_index][selected_unit_index],
                    false,
                );
            }
        }
    }
    // default: no destination register allocated; overridden per unit below
    connect(
        renamed_mops_out_reg,
        repeat(
            HdlOption[config.p_reg_num()].HdlNone(),
            config.fetch_width.get(),
        ),
    );
    // results and freed registers from all units, shared with every unit and the ROB
    #[hdl]
    let unit_forwarding_info = wire(config.unit_forwarding_info());
    connect(rob.unit_forwarding_info, unit_forwarding_info);
    for (unit_index, unit_config) in config.units.iter().enumerate() {
        let dyn_unit = unit_config.kind.unit(config, unit_index);
        let unit = instance_with_loc(
            &format!("unit_{unit_index}"),
            dyn_unit.module(),
            SourceLocation::caller(),
        );
        connect(dyn_unit.cd(unit), cd);
        connect(dyn_unit.global_state(unit), global_state);
        let unit_to_reg_alloc = dyn_unit.unit_to_reg_alloc(unit);
        // TODO: handle assigning multiple instructions to a unit at a time
        let assign_to_unit_at_once = NonZeroUsize::new(1).unwrap();
        // TODO: handle retiring multiple instructions from a unit at a time
        let retire_from_unit_at_once = NonZeroUsize::new(1).unwrap();
        // tracks which of this unit's output registers are free for allocation
        let unit_free_regs_tracker = instance_with_loc(
            &format!("unit_{unit_index}_free_regs_tracker"),
            unit_free_regs_tracker::unit_free_regs_tracker(
                retire_from_unit_at_once,
                assign_to_unit_at_once,
                config.out_reg_num_width,
            ),
            SourceLocation::caller(),
        );
        connect(unit_free_regs_tracker.cd, cd);
        // TODO: finish
        connect(
            unit_free_regs_tracker.free_in[0].data,
            HdlOption[UInt[config.out_reg_num_width]].uninit(), // FIXME: just for debugging
        );
        // defaults: take no register from the tracker and send nothing to the unit
        connect(unit_free_regs_tracker.alloc_out[0].ready, false);
        connect(
            unit_to_reg_alloc.input.data,
            unit_to_reg_alloc.input.ty().data.HdlNone(),
        );
        for fetch_index in 0..config.fetch_width.get() {
            // the unit can't be used without a free output register...
            #[hdl]
            if let HdlNone = unit_free_regs_tracker.alloc_out[0].data {
                // must come after to override connects in loops above
                connect(available_units[fetch_index][unit_index], false);
            }
            // ...or when it can't accept input
            #[hdl]
            if !unit_to_reg_alloc.input.ready {
                // must come after to override connects in loops above
                connect(available_units[fetch_index][unit_index], false);
            }
            #[hdl]
            if let HdlSome(selected_unit_index) = selected_unit_indexes[fetch_index] {
                #[hdl]
                if selected_unit_index.cmp_eq(unit_index) {
                    // this slot was assigned to this unit: take the offered register
                    connect(unit_free_regs_tracker.alloc_out[0].ready, true);
                    // convert the renamed mop into this unit's own mop type
                    #[hdl]
                    if let HdlSome(renamed_mop) =
                        HdlOption::and_then(renamed_mops[fetch_index], |v| {
                            #[hdl]
                            let UnitInput::<_> { mop, pc } = v;
                            let mop = dyn_unit.extract_mop(mop);
                            HdlOption::map(mop, |mop| {
                                #[hdl]
                                UnitInput::<_> { mop, pc }
                            })
                        })
                    {
                        connect(unit_to_reg_alloc.input.data, HdlSome(renamed_mop));
                    } else {
                        // not expected to happen; placeholder input until asserted
                        connect(
                            unit_to_reg_alloc.input.data,
                            HdlSome(unit_to_reg_alloc.input.ty().data.HdlSome.uninit()),
                        );
                        // FIXME: add hdl_assert(cd.clk, false.to_expr(), "");
                    }
                    // the slot's physical destination register is the offered one
                    #[hdl]
                    if let HdlSome(unit_out_reg) = unit_free_regs_tracker.alloc_out[0].data {
                        let unit_num = config.unit_num().from_index(unit_index);
                        let unit_out_reg = #[hdl]
                        UnitOutRegNum {
                            value: unit_out_reg,
                        };
                        connect(
                            renamed_mops_out_reg[fetch_index],
                            HdlSome(
                                #[hdl]
                                PRegNum {
                                    unit_num,
                                    unit_out_reg,
                                },
                            ),
                        );
                    }
                }
            }
        }
        connect(unit_to_reg_alloc.unit_forwarding_info, unit_forwarding_info);
        // default: this unit writes no output register this cycle
        connect(
            unit_forwarding_info.unit_output_writes[unit_index],
            unit_forwarding_info
                .ty()
                .unit_output_writes
                .element()
                .HdlNone(),
        );
        // registers handed back to the tracker are reported as freed
        connect(
            unit_forwarding_info.unit_reg_frees[unit_index],
            HdlOption::map(
                ReadyValid::firing_data(unit_free_regs_tracker.free_in[0]),
                |value| {
                    #[hdl]
                    UnitOutRegNum::<_> { value }
                },
            ),
        );
        // a completed result becomes this unit's output register write
        #[hdl]
        if let HdlSome(output) = unit_to_reg_alloc.output {
            #[hdl]
            let UnitOutput::<_, _> { which, result } = output;
            #[hdl]
            match result {
                UnitResult::<_>::Completed(completed) => {
                    #[hdl]
                    let UnitResultCompleted::<_> { value, extra_out } = completed;
                    connect(
                        unit_forwarding_info.unit_output_writes[unit_index],
                        HdlSome(
                            #[hdl]
                            UnitOutputWrite::<_> { which, value },
                        ),
                    );
                    // TODO: handle extra_out
                }
                UnitResult::<_>::Trap(trap_data) => {
                    // TODO: handle traps
                }
            }
        }
        // TODO: handle cancellation
        connect(
            unit_to_reg_alloc.cancel_input,
            HdlOption[config.unit_cancel_input()].HdlNone(),
        );
    }
}