deduce_structural_eq_flags: use expressions' literal_bits to improve deduction around cast_bits_to
All checks were successful
/ test (pull_request) Successful in 4m16s

This commit is contained in:
Jacob Lifshay 2026-06-11 20:46:28 -07:00
parent e2ca80af97
commit 1b16118ce5
Signed by: programmerjake
SSH key fingerprint: SHA256:HnFTLGpSm4Q4Fj502oCFisjZSoakwEuTsJJMSke63RQ
2 changed files with 3097 additions and 800 deletions

View file

@ -2,28 +2,29 @@
// See Notices.txt for copyright information
use crate::{
bundle::BundleType,
enum_::EnumType,
bundle::{BundleField, BundleType},
enum_::{EnumType, EnumVariant},
expr::{
ExprEnum,
ExprEnum, ToLiteralBits,
ops::{
ArrayIndex, FieldAccess, StructuralEq, StructuralEqFlags, TraceAsStringAsInner,
VariantAccess,
},
target::TargetBase,
},
intern::{Intern, InternSlice, Interned, Memoize},
intern::{Intern, InternSlice, Interned, MemoizeGeneric},
module::{
ModuleBody, Stmt, StmtConnect, StmtDeclaration, StmtInstance, StmtReg, StmtWire,
transform::visit::{Fold, Folder, Visit, Visitor},
},
prelude::*,
util::{
HashMap,
BitSliceWriteWithBase, HashMap,
indented_print::{PushIndent, indented_println},
union_find_map::{Entry, UnionFindMap},
},
};
use bitvec::{order::Lsb0, view::BitView};
use std::{convert::Infallible, fmt};
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
@ -77,6 +78,181 @@ impl fmt::Debug for FlagsTree {
}
}
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
enum FlagsTreeSourceValue<'a> {
LiteralBits { bits: &'a BitSlice },
Flags { assume_padding_is_zeroed: bool },
}
impl<'a> fmt::Debug for FlagsTreeSourceValue<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::LiteralBits { bits } => f
.debug_struct("LiteralBits")
.field(
"bits",
&fmt::from_fn(|f| write!(f, "{:#b}", BitSliceWriteWithBase(bits))),
)
.finish(),
Self::Flags {
assume_padding_is_zeroed,
} => f
.debug_struct("Flags")
.field("assume_padding_is_zeroed", assume_padding_is_zeroed)
.finish(),
}
}
}
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
enum FlagsTreeSourceValueOwned {
LiteralBits { bits: Interned<BitSlice> },
Flags { assume_padding_is_zeroed: bool },
}
impl FlagsTreeSourceValueOwned {
fn as_ref(&self) -> FlagsTreeSourceValue<'_> {
match *self {
Self::LiteralBits { ref bits } => FlagsTreeSourceValue::LiteralBits { bits },
Self::Flags {
assume_padding_is_zeroed,
} => FlagsTreeSourceValue::Flags {
assume_padding_is_zeroed,
},
}
}
}
impl<'a> FlagsTreeSourceValue<'a> {
fn to_owned(self) -> FlagsTreeSourceValueOwned {
match self {
Self::LiteralBits { bits } => FlagsTreeSourceValueOwned::LiteralBits {
bits: bits.intern(),
},
Self::Flags {
assume_padding_is_zeroed,
} => FlagsTreeSourceValueOwned::Flags {
assume_padding_is_zeroed,
},
}
}
fn visit_array_body<T, R: FromIterator<T>>(
self,
element_ty: CanonicalType,
f: impl FnMut(Self) -> T,
) -> R {
match self {
Self::LiteralBits { bits } => {
let bit_width = element_ty.bit_width();
if bit_width == 0 {
R::from_iter([])
} else {
bits.chunks(bit_width)
.map(|bits| Self::LiteralBits { bits })
.map(f)
.collect()
}
}
Self::Flags { .. } => [self].into_iter().map(f).collect(),
}
}
fn visit_bundle_body<T, R: FromIterator<T>>(
self,
bundle: Bundle,
mut f: impl FnMut(usize, &BundleField, Self) -> T,
) -> R {
match self {
Self::LiteralBits { bits } => bundle
.fields()
.iter()
.enumerate()
.scan(bits, |bundle_bits, (field_index, field)| {
let field_bit_width = field.ty.bit_width();
let bits;
(bits, *bundle_bits) = bundle_bits.split_at(field_bit_width);
Some(f(field_index, field, Self::LiteralBits { bits }))
})
.collect(),
Self::Flags { .. } => bundle
.fields()
.iter()
.enumerate()
.map(move |(field_index, field)| f(field_index, field, self))
.collect(),
}
}
fn visit_enum_body<T, R: FromIterator<T>>(
self,
enum_: Enum,
debug_trace: bool,
mut f: impl FnMut(&EnumVariant, Self) -> T,
) -> (R, bool) {
let collected;
let padding_is_zeroed;
match self {
Self::LiteralBits { bits } => {
let mut discriminant = 0usize;
let discriminant_bit_width = enum_.discriminant_bit_width();
let (discriminant_bits, bits) = bits.split_at(discriminant_bit_width);
if debug_trace {
indented_println!(
"discriminant_bits={:#b}",
BitSliceWriteWithBase(discriminant_bits)
);
}
discriminant.view_bits_mut::<Lsb0>()[..discriminant_bit_width]
.copy_from_bitslice(discriminant_bits);
let (bits, padding) = bits.split_at(
enum_
.variants()
.get(discriminant)
.and_then(|variant| variant.ty)
.map_or(0, |variant_ty| variant_ty.bit_width()),
);
if debug_trace {
indented_println!(
"bits={:#b}\npadding={:#b}",
BitSliceWriteWithBase(bits),
BitSliceWriteWithBase(padding),
);
}
collected = enum_
.variants()
.iter()
.enumerate()
.map(|(variant_index, variant)| {
f(
variant,
if variant_index == discriminant {
Self::LiteralBits { bits }
} else {
Self::Flags {
assume_padding_is_zeroed: true,
}
},
)
})
.collect();
padding_is_zeroed = !padding.any();
if debug_trace {
indented_println!("padding_is_zeroed={padding_is_zeroed:?}");
}
}
Self::Flags {
assume_padding_is_zeroed,
} => {
collected = enum_
.variants()
.iter()
.map(|variant| f(variant, self))
.collect();
padding_is_zeroed = assume_padding_is_zeroed;
}
}
(collected, padding_is_zeroed)
}
}
impl FlagsTree {
fn contains_padding(&self) -> bool {
match self {
@ -97,8 +273,19 @@ impl FlagsTree {
Self::NoPadding => true,
}
}
fn new_inner(ty: CanonicalType, assume_padding_is_zeroed: bool) -> Interned<Self> {
match ty {
fn new_inner<'a>(
ty: &CanonicalType,
source_value: FlagsTreeSourceValue<'a>,
debug_trace: bool,
) -> Interned<Self> {
let _push_indent;
if debug_trace {
indented_println!("FlagsTree::new_inner()");
_push_indent = PushIndent::new();
indented_println!("ty: {ty:?}");
indented_println!("source_value: {source_value:?}");
}
let retval = match ty {
CanonicalType::UInt(_)
| CanonicalType::SInt(_)
| CanonicalType::Bool(_)
@ -109,33 +296,43 @@ impl FlagsTree {
| CanonicalType::PhantomConst(_)
| CanonicalType::DynSimOnly(_) => Self::NoPadding.intern_sized(),
CanonicalType::Array(ty) => {
if ty.is_empty() {
Self::NoPadding.intern_sized()
} else {
Self::new(ty.element(), assume_padding_is_zeroed)
}
let mut retval = None;
let element_ty = ty.element();
let () =
source_value.visit_array_body(element_ty, |source_value| match &mut retval {
Some(retval) => {
*retval = FlagsTree::new(element_ty, source_value, debug_trace)
.merged(*retval)
}
None => retval = Some(Self::new(element_ty, source_value, debug_trace)),
});
retval.unwrap_or_else(|| Self::NoPadding.intern_sized())
}
CanonicalType::Enum(ty) => {
let mut expected_bit_width = None;
let mut variants = Vec::with_capacity(ty.variants().len());
let mut contains_padding = false;
for variant in ty.variants() {
let variant_flags_tree =
variant.ty.map(|ty| Self::new(ty, assume_padding_is_zeroed));
variants.push(variant_flags_tree);
contains_padding |= variant_flags_tree.is_some_and(|v| v.contains_padding());
let bit_width = if let Some(ty) = variant.ty {
ty.bit_width()
} else {
0
};
if expected_bit_width
.replace(bit_width)
.is_some_and(|v| v != bit_width)
{
contains_padding = true;
}
}
let mut assume_padding_is_zeroed = true;
let (variants, padding_is_zeroed): (Vec<_>, _) =
source_value.visit_enum_body(*ty, debug_trace, |variant, source_value| {
let (variant_flags_tree, bit_width) = if let Some(variant_ty) = variant.ty {
let variant_flags_tree =
Self::new(variant_ty, source_value, debug_trace);
contains_padding |= variant_flags_tree.contains_padding();
assume_padding_is_zeroed &=
variant_flags_tree.assume_padding_is_zeroed();
(Some(variant_flags_tree), variant_ty.bit_width())
} else {
(None, 0)
};
if expected_bit_width
.replace(bit_width)
.is_some_and(|v| v != bit_width)
{
contains_padding = true;
}
variant_flags_tree
});
assume_padding_is_zeroed &= padding_is_zeroed;
if contains_padding {
Self::Enum {
variants: variants.intern_slice(),
@ -148,11 +345,14 @@ impl FlagsTree {
}
CanonicalType::Bundle(ty) => {
let mut contains_padding = false;
let fields = Vec::from_iter(ty.fields().iter().map(|field| {
let flags_tree = Self::new(field.ty, assume_padding_is_zeroed);
contains_padding |= flags_tree.contains_padding();
flags_tree
}));
let mut assume_padding_is_zeroed = true;
let fields: Vec<_> =
source_value.visit_bundle_body(*ty, |_field_index, field, source_value| {
let flags_tree = Self::new(field.ty, source_value, debug_trace);
contains_padding |= flags_tree.contains_padding();
assume_padding_is_zeroed &= flags_tree.assume_padding_is_zeroed();
flags_tree
});
if contains_padding {
Self::Bundle {
fields: fields.intern_slice(),
@ -163,31 +363,77 @@ impl FlagsTree {
Self::NoPadding.intern_sized()
}
}
CanonicalType::TraceAsString(ty) => Self::new(ty.inner_ty(), assume_padding_is_zeroed),
CanonicalType::TraceAsString(ty) => Self::new(ty.inner_ty(), source_value, debug_trace),
};
if debug_trace {
indented_println!("return: {retval:#?}");
}
retval
}
fn new(ty: CanonicalType, assume_padding_is_zeroed: bool) -> Interned<Self> {
fn new(
ty: CanonicalType,
source_value: FlagsTreeSourceValue<'_>,
debug_trace: bool,
) -> Interned<Self> {
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
struct MyMemoize {
assume_padding_is_zeroed: bool,
debug_trace: bool,
}
impl Memoize for MyMemoize {
type Input = CanonicalType;
type InputOwned = CanonicalType;
enum InputCow<'a> {
Borrowed {
ty: &'a CanonicalType,
source_value: FlagsTreeSourceValue<'a>,
},
Owned {
ty: CanonicalType,
source_value: FlagsTreeSourceValueOwned,
},
}
impl MemoizeGeneric for MyMemoize {
type InputRef<'a> = (&'a CanonicalType, FlagsTreeSourceValue<'a>);
type InputOwned = (CanonicalType, FlagsTreeSourceValueOwned);
type InputCow<'a> = InputCow<'a>;
type Output = Interned<FlagsTree>;
fn inner(self, input: &Self::Input) -> Self::Output {
let Self {
assume_padding_is_zeroed,
} = self;
let retval = FlagsTree::new_inner(*input, assume_padding_is_zeroed);
retval
fn input_eq(a: Self::InputRef<'_>, b: Self::InputRef<'_>) -> bool {
a == b
}
fn input_borrow(input: &Self::InputOwned) -> Self::InputRef<'_> {
let (ty, source_value) = input;
(ty, source_value.as_ref())
}
fn input_cow_into_owned(input: Self::InputCow<'_>) -> Self::InputOwned {
match input {
InputCow::Borrowed { ty, source_value } => (*ty, source_value.to_owned()),
InputCow::Owned { ty, source_value } => (ty, source_value),
}
}
fn input_cow_borrow<'a>(input: &'a Self::InputCow<'_>) -> Self::InputRef<'a> {
match input {
&InputCow::Borrowed { ty, source_value } => (ty, source_value),
InputCow::Owned { ty, source_value } => (ty, source_value.as_ref()),
}
}
fn input_cow_from_owned<'a>(input: Self::InputOwned) -> Self::InputCow<'a> {
let (ty, source_value) = input;
InputCow::Owned { ty, source_value }
}
fn input_cow_from_ref(input: Self::InputRef<'_>) -> Self::InputCow<'_> {
let (ty, source_value) = input;
InputCow::Borrowed { ty, source_value }
}
fn inner(self, input: Self::InputRef<'_>) -> Self::Output {
let (ty, source_value) = input;
FlagsTree::new_inner(ty, source_value, self.debug_trace)
}
}
MyMemoize {
assume_padding_is_zeroed,
}
.get_owned(ty)
MyMemoize { debug_trace }.get((&ty, source_value))
}
#[must_use]
fn merged(self, other: Interned<FlagsTree>) -> Interned<FlagsTree> {
@ -596,12 +842,6 @@ impl<T: Type> ExprOrUnknown<T> {
Self::Unknown(ty) => ExprOrUnknown::Unknown(map_ty(ty)),
}
}
fn map_unwrap<U>(self, map_ty: impl FnOnce(T) -> U, map_expr: impl FnOnce(Expr<T>) -> U) -> U {
match self {
Self::Expr(expr) => map_expr(expr),
Self::Unknown(ty) => map_ty(ty),
}
}
fn from_canonical(v: ExprOrUnknown<CanonicalType>) -> Self {
match v {
ExprOrUnknown::Expr(expr) => Self::Expr(Expr::from_canonical(expr)),
@ -735,7 +975,13 @@ impl State {
fn visit_expr_or_unknown(&mut self, expr: ExprOrUnknown<CanonicalType>) -> Interned<FlagsTree> {
match expr {
ExprOrUnknown::Expr(expr) => self.visit_canonical_expr(expr),
ExprOrUnknown::Unknown(ty) => FlagsTree::new(ty, false),
ExprOrUnknown::Unknown(ty) => FlagsTree::new(
ty,
FlagsTreeSourceValue::Flags {
assume_padding_is_zeroed: false,
},
self.debug_trace,
),
}
}
fn connect(
@ -880,7 +1126,13 @@ impl State {
let visited = self.exprs_visited.entry(expr_enum).or_insert(false);
let flags = *self.expr_flags.entry(expr_enum).or_insert_with(|| {
self.any_changes = true;
FlagsTree::new(ty, true)
FlagsTree::new(
ty,
FlagsTreeSourceValue::Flags {
assume_padding_is_zeroed: true,
},
self.debug_trace,
)
});
if std::mem::replace(visited, true) {
return flags;
@ -898,7 +1150,18 @@ impl State {
let (flags, _) = this.connect(ExprOrUnknown::Expr(expr), init);
*flags
};
let literal_bits = expr_enum.to_literal_bits();
let flags = match *expr_enum {
_ if literal_bits.is_ok() => {
let Ok(bits) = &literal_bits else {
unreachable!();
};
*FlagsTree::new(
ty,
FlagsTreeSourceValue::LiteralBits { bits },
self.debug_trace,
)
}
ExprEnum::UIntLiteral(_)
| ExprEnum::SIntLiteral(_)
| ExprEnum::BoolLiteral(_)
@ -948,7 +1211,13 @@ impl State {
}
*flags
}
ExprEnum::Uninit(_) => *FlagsTree::new(ty, false),
ExprEnum::Uninit(_) => *FlagsTree::new(
ty,
FlagsTreeSourceValue::Flags {
assume_padding_is_zeroed: false,
},
self.debug_trace,
),
ExprEnum::NotU(_)
| ExprEnum::NotS(_)
| ExprEnum::NotB(_)
@ -1113,7 +1382,13 @@ impl State {
| ExprEnum::SliceUInt(_)
| ExprEnum::SliceSInt(_)
| ExprEnum::CastToBits(_) => *flags,
ExprEnum::CastBitsTo(_) => *FlagsTree::new(ty, false),
ExprEnum::CastBitsTo(_) => *FlagsTree::new(
ty,
FlagsTreeSourceValue::Flags {
assume_padding_is_zeroed: false,
},
self.debug_trace,
),
ExprEnum::ToTraceAsString(expr) => {
self.visit_canonical_expr(Expr::canonical(expr.inner()));
// FlagsTree treats TraceAsString transparently, so just union them together.
@ -1143,8 +1418,20 @@ impl State {
ExprEnum::Reg(reg) => handle_reg(self, reg.init()),
ExprEnum::RegSync(reg) => handle_reg(self, reg.init()),
ExprEnum::RegAsync(reg) => handle_reg(self, reg.init()),
ExprEnum::MemPort(_) => *FlagsTree::new(ty, false),
ExprEnum::FormalInput(_) => *FlagsTree::new(ty, false),
ExprEnum::MemPort(_) => *FlagsTree::new(
ty,
FlagsTreeSourceValue::Flags {
assume_padding_is_zeroed: false,
},
self.debug_trace,
),
ExprEnum::FormalInput(_) => *FlagsTree::new(
ty,
FlagsTreeSourceValue::Flags {
assume_padding_is_zeroed: false,
},
self.debug_trace,
),
ExprEnum::SimIoForGlobal(_) => {
unreachable!("Module is known to not contain SimIoForGlobal from validation")
}

File diff suppressed because it is too large Load diff