cpu/crates/cpu/src/util/tree_reduce.rs
Jacob Lifshay cb5855589f
All checks were successful
/ test (push) Successful in 46m24s
WIP adding register allocator
2024-10-14 21:20:42 -07:00

153 lines
4.4 KiB
Rust

// SPDX-License-Identifier: LGPL-3.0-or-later
// See Notices.txt for copyright information
/// One step of a balanced tree reduction, as produced by [`TreeReduceOps`].
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum TreeReduceOp {
    /// Take the next input item.
    Input,
    /// Combine the two most recently produced values; the older one is the left operand.
    Reduce,
}
/// A subtree that has been started but not yet merged into its parent.
#[derive(Copy, Clone, Debug)]
struct Entry {
    /// Index of the first input covered by this subtree.
    start: usize,
    /// Number of levels merged so far; a complete subtree covers `2^depth` inputs.
    depth: u32,
}
/// Iterator over the [`TreeReduceOp`]s that reduce `len` inputs as a balanced binary tree.
///
/// The ops are emitted in post order: for `len > 0` there are exactly `len` [`Input`]s and
/// `len - 1` [`Reduce`]s, and for `len == 0` there are none. Every subtree of `size > 1`
/// inputs is split so that its left half covers `size.next_power_of_two() / 2` inputs.
///
/// [`Input`]: TreeReduceOp::Input
/// [`Reduce`]: TreeReduceOp::Reduce
#[derive(Clone, Debug)]
pub struct TreeReduceOps {
    /// Total number of inputs to reduce.
    len: usize,
    /// Started subtrees, left to right, covering consecutive input ranges.
    stack: Vec<Entry>,
}
impl TreeReduceOps {
    /// Creates the op sequence for reducing `len` inputs.
    pub fn new(len: usize) -> Self {
        TreeReduceOps {
            len,
            stack: Vec::new(),
        }
    }
    /// Number of ops still to be yielded, or `None` if that count does not fit in `usize`.
    fn remaining_ops(&self) -> Option<usize> {
        let Some(last) = self.stack.last() else {
            // Nothing yielded yet: `len` inputs plus `len - 1` reduces.
            return match self.len {
                0 => Some(0),
                len => len.checked_mul(2).map(|ops| ops - 1),
            };
        };
        // The stack covers consecutive inputs, so the (clamped) end of the last entry's range
        // is the number of inputs yielded so far. A shift or add that overflows means the
        // range already extends past `len`.
        let consumed = 1usize
            .checked_shl(last.depth)
            .and_then(|width| last.start.checked_add(width))
            .map_or(self.len, |end| end.min(self.len));
        // Each remaining input costs an `Input` plus the `Reduce` that folds it in; the
        // entries already on the stack need `stack.len() - 1` more reduces to collapse.
        (self.len - consumed)
            .checked_mul(2)?
            .checked_add(self.stack.len() - 1)
    }
}
impl Iterator for TreeReduceOps {
    type Item = TreeReduceOp;
    fn next(&mut self) -> Option<Self::Item> {
        let len = self.len;
        match *self.stack {
            // First op: start the leftmost leaf.
            [] if len != 0 => {
                self.stack.push(Entry { start: 0, depth: 0 });
                Some(TreeReduceOp::Input)
            }
            // Two sibling subtrees of equal size: merge them into their parent.
            [.., ref mut second_last, last] if second_last.depth == last.depth => {
                second_last.depth += 1;
                self.stack.pop();
                Some(TreeReduceOp::Reduce)
            }
            // More inputs follow the last subtree: start a new leaf right after it.
            // For `len > 2^(usize::BITS - 1)` the root ends with `depth == usize::BITS`,
            // where `1 << depth` would overflow; a subtree that wide always reaches past
            // `len`, so no input can follow it.
            [.., last]
                if last.depth < usize::BITS && len - last.start > 1usize << last.depth =>
            {
                let start = last.start + (1usize << last.depth);
                self.stack.push(Entry { start, depth: 0 });
                Some(TreeReduceOp::Input)
            }
            // All inputs taken: fold the trailing (smaller) subtree into the one before it.
            [.., ref mut second_last, _] => {
                second_last.depth += 1;
                self.stack.pop();
                Some(TreeReduceOp::Reduce)
            }
            // Empty input, or a single subtree covering every input.
            _ => None,
        }
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.remaining_ops();
        (remaining.unwrap_or(usize::MAX), remaining)
    }
}
// Once `next` returns `None` the state no longer changes, so it keeps returning `None`.
impl std::iter::FusedIterator for TreeReduceOps {}
/// Reduces the items of `iter` as a balanced binary tree, threading `state` through every call.
///
/// Each item is converted with `input(state, item)`. Results are then combined with
/// `reduce(state, left, right)` in the order given by [`TreeReduceOps`]. Returns `None` when
/// `iter` is empty.
///
/// # Panics
///
/// Panics if `iter` yields fewer items than its `len()` reported. Only `len()` items are ever
/// requested from `iter`.
#[track_caller]
pub fn tree_reduce_with_state<S, I, R>(
    iter: impl IntoIterator<IntoIter: ExactSizeIterator, Item = I>,
    state: &mut S,
    mut input: impl FnMut(&mut S, I) -> R,
    mut reduce: impl FnMut(&mut S, R, R) -> R,
) -> Option<R> {
    let mut items = iter.into_iter();
    // Values produced so far that are still waiting to be combined, oldest first.
    let mut pending: Vec<R> = Vec::new();
    for op in TreeReduceOps::new(items.len()) {
        let value = match op {
            TreeReduceOp::Input => {
                let item = items
                    .next()
                    .expect("inconsistent iterator len() and next()");
                input(state, item)
            }
            TreeReduceOp::Reduce => {
                // The newest value is the right operand; the one beneath it is the left.
                let (Some(rhs), Some(lhs)) = (pending.pop(), pending.pop()) else {
                    unreachable!();
                };
                reduce(state, lhs, rhs)
            }
        };
        pending.push(value);
    }
    pending.pop()
}
/// Reduces the items of `iter` pairwise as a balanced binary tree using `reduce(left, right)`.
///
/// Returns `None` if `iter` is empty. Each operand is a run of adjacent items kept in their
/// original order, so for an associative `reduce` the result equals a left-to-right fold.
///
/// # Panics
///
/// Panics if `iter` yields fewer items than its `len()` reported. `#[track_caller]` attributes
/// that panic to the caller of `tree_reduce`, just as for `tree_reduce_with_state`.
#[track_caller]
pub fn tree_reduce<T>(
    iter: impl IntoIterator<Item = T, IntoIter: ExactSizeIterator>,
    mut reduce: impl FnMut(T, T) -> T,
) -> Option<T> {
    tree_reduce_with_state(iter, &mut (), |_, v| v, move |_, l, r| reduce(l, r))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ops::Range;
    /// Reference implementation: appends the ops for `range` by explicit recursion, splitting
    /// each range of two or more items at `len.next_power_of_two() / 2`.
    fn recursive_tree_reduce(range: Range<usize>, ops: &mut Vec<TreeReduceOp>) {
        if range.len() == 1 {
            ops.push(TreeReduceOp::Input);
            return;
        }
        if range.is_empty() {
            return;
        }
        let pow2_len = range.len().next_power_of_two();
        let split = range.start + pow2_len / 2;
        recursive_tree_reduce(range.start..split, ops);
        recursive_tree_reduce(split..range.end, ops);
        ops.push(TreeReduceOp::Reduce);
    }
    #[test]
    fn test_tree_reduce() {
        // References in a `const` are implicitly `'static` (clippy::redundant_static_lifetimes).
        const EXPECTED: &[&[TreeReduceOp]] = {
            use TreeReduceOp::{Input as I, Reduce as R};
            &[
                &[],
                &[I],
                &[I, I, R],
                &[I, I, R, I, R],
                &[I, I, R, I, I, R, R],
                &[I, I, R, I, I, R, R, I, R],
                &[I, I, R, I, I, R, R, I, I, R, R],
                &[I, I, R, I, I, R, R, I, I, R, I, R, R],
                &[I, I, R, I, I, R, R, I, I, R, I, I, R, R, R],
            ]
        };
        for len in 0..64 {
            let mut expected = vec![];
            recursive_tree_reduce(0..len, &mut expected);
            if let Some(&expected2) = EXPECTED.get(len) {
                assert_eq!(*expected, *expected2, "len={len}");
            }
            assert_eq!(
                TreeReduceOps::new(len).collect::<Vec<_>>(),
                expected,
                "len={len}"
            );
            let seq: Vec<_> = (0..len).collect();
            assert_eq!(
                seq,
                tree_reduce(seq.iter().map(|&v| vec![v]), |mut l, r| {
                    l.extend_from_slice(&r);
                    l
                })
                .unwrap_or_default()
            );
        }
    }
    #[test]
    fn test_tree_reduce_with_state() {
        for len in 0..64usize {
            // (number of `input` calls, number of `reduce` calls)
            let mut counts = (0usize, 0usize);
            // Each value is the half-open index range it covers.
            let result = tree_reduce_with_state(
                0..len,
                &mut counts,
                |counts: &mut (usize, usize), index: usize| {
                    counts.0 += 1;
                    (index, index + 1)
                },
                |counts: &mut (usize, usize),
                 (left_start, left_end): (usize, usize),
                 (right_start, right_end): (usize, usize)| {
                    counts.1 += 1;
                    // Operands must be adjacent ranges, left before right.
                    assert_eq!(left_end, right_start, "len={len}");
                    (left_start, right_end)
                },
            );
            assert_eq!(result, (len != 0).then_some((0, len)), "len={len}");
            assert_eq!(counts, (len, len.saturating_sub(1)), "len={len}");
        }
    }
}