From af0f1f608acd1a6e83032d9c7ea2453bb6453c7a Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 10 Jun 2026 01:24:37 -0700 Subject: [PATCH] WIP add deduce_structural_eq_flags transform --- crates/fayalite/src/module/transform.rs | 1 + .../transform/deduce_structural_eq_flags.rs | 577 ++++++++++++++++++ crates/fayalite/src/util.rs | 6 +- crates/fayalite/src/util/map_trait.rs | 463 ++++++++++++++ crates/fayalite/src/util/union_find_map.rs | 355 +++++++++++ 5 files changed, 1400 insertions(+), 2 deletions(-) create mode 100644 crates/fayalite/src/module/transform/deduce_structural_eq_flags.rs create mode 100644 crates/fayalite/src/util/map_trait.rs create mode 100644 crates/fayalite/src/util/union_find_map.rs diff --git a/crates/fayalite/src/module/transform.rs b/crates/fayalite/src/module/transform.rs index 063a1a3..7eec4cf 100644 --- a/crates/fayalite/src/module/transform.rs +++ b/crates/fayalite/src/module/transform.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: LGPL-3.0-or-later // See Notices.txt for copyright information pub mod deduce_resets; +pub mod deduce_structural_eq_flags; pub mod simplify_enums; pub mod simplify_memories; pub mod visit; diff --git a/crates/fayalite/src/module/transform/deduce_structural_eq_flags.rs b/crates/fayalite/src/module/transform/deduce_structural_eq_flags.rs new file mode 100644 index 0000000..c76ebba --- /dev/null +++ b/crates/fayalite/src/module/transform/deduce_structural_eq_flags.rs @@ -0,0 +1,577 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +// See Notices.txt for copyright information + +use crate::{ + bundle::BundleType, + enum_::EnumType, + expr::{ + ExprEnum, + ops::{ArrayIndex, FieldAccess, StructuralEqFlags, TraceAsStringAsInner, VariantAccess}, + }, + intern::{Intern, InternSlice, Interned, Memoize}, + module::{ + Block, ModuleBody, + transform::visit::{Fold, Folder, Visit, Visitor}, + }, + prelude::*, + util::{ + HashMap, + union_find_map::{Entry, UnionFindMap}, + }, +}; +use std::convert::Infallible; + +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +enum FlagsTree { + Enum { + variants: Interned<[Option>]>, + /// invariant -- if this is true all children must also have [`FlagsTree::assume_padding_is_zeroed()`] return true + assume_padding_is_zeroed: bool, + }, + Bundle { + fields: Interned<[Interned]>, + /// invariant -- if this is true all children must also have [`FlagsTree::assume_padding_is_zeroed()`] return true + assume_padding_is_zeroed: bool, + }, + NoPadding, +} + +impl FlagsTree { + fn contains_padding(&self) -> bool { + matches!(self, Self::NoPadding) + } + fn assume_padding_is_zeroed(&self) -> bool { + match *self { + Self::Enum { + assume_padding_is_zeroed, + .. + } + | Self::Bundle { + assume_padding_is_zeroed, + .. + } => assume_padding_is_zeroed, + Self::NoPadding => true, + } + } + fn new(ty: CanonicalType, assume_padding_is_zeroed: bool) -> Interned { + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + struct MyMemoize { + assume_padding_is_zeroed: bool, + } + impl Memoize for MyMemoize { + type Input = CanonicalType; + type InputOwned = CanonicalType; + type Output = Interned; + + fn inner(self, input: &Self::Input) -> Self::Output { + match input { + CanonicalType::UInt(_) + | CanonicalType::SInt(_) + | CanonicalType::Bool(_) + | CanonicalType::AsyncReset(_) + | CanonicalType::SyncReset(_) + | CanonicalType::Reset(_) + | CanonicalType::Clock(_) + | CanonicalType::PhantomConst(_) + | CanonicalType::DynSimOnly(_) => FlagsTree::NoPadding, + CanonicalType::Array(ty) => { + if ty.is_empty() { + FlagsTree::NoPadding + } else { + return FlagsTree::new(ty.element(), self.assume_padding_is_zeroed); + } + } + CanonicalType::Enum(ty) => { + let mut expected_bit_width = None; + let mut variants = Vec::with_capacity(ty.variants().len()); + let mut contains_padding = false; + for variant in ty.variants() { + let variant_flags_tree = variant + .ty + .map(|ty| FlagsTree::new(ty, self.assume_padding_is_zeroed)); + variants.push(variant_flags_tree); + contains_padding |= + variant_flags_tree.is_some_and(|v| v.contains_padding()); + let bit_width = if let Some(ty) = variant.ty { + ty.bit_width() + } else { + 0 + }; + if expected_bit_width + .replace(bit_width) + .is_some_and(|v| v != bit_width) + { + contains_padding = true; + } + } + if contains_padding { + FlagsTree::Enum { + variants: variants.intern_slice(), + assume_padding_is_zeroed: self.assume_padding_is_zeroed, + } + } else { + FlagsTree::NoPadding + } + } + CanonicalType::Bundle(ty) => { + let mut contains_padding = false; + let fields = Vec::from_iter(ty.fields().iter().map(|field| { + let flags_tree = + FlagsTree::new(field.ty, self.assume_padding_is_zeroed); + contains_padding |= flags_tree.contains_padding(); + flags_tree + })); + if contains_padding { + FlagsTree::Bundle { + fields: fields.intern_slice(), + assume_padding_is_zeroed: self.assume_padding_is_zeroed, + } + } else { + FlagsTree::NoPadding + } + } + CanonicalType::TraceAsString(ty) => { + return FlagsTree::new(ty.inner_ty(), self.assume_padding_is_zeroed); + } + } + .intern_sized() + } + } + MyMemoize { + assume_padding_is_zeroed, + } + .get_owned(ty) + } + #[must_use] + fn merged(self, other: Interned) -> Interned { + if self == *other { + return other; + } + match (self, *other) { + ( + Self::Enum { + variants: l_variants, + assume_padding_is_zeroed: l_assume_padding_is_zeroed, + }, + Self::Enum { + variants: r_variants, + assume_padding_is_zeroed: r_assume_padding_is_zeroed, + }, + ) => { + let variants = Interned::from_iter(l_variants.iter().zip(&r_variants).map( + |(&l_variant, &r_variant)| { + l_variant + .zip(r_variant) + .map(|(l_variant, r_variant)| l_variant.merged(r_variant)) + }, + )); + let assume_padding_is_zeroed = + l_assume_padding_is_zeroed & r_assume_padding_is_zeroed; + Self::Enum { + variants, + assume_padding_is_zeroed, + } + .intern_sized() + } + (Self::Enum { .. }, _) => unreachable!("mismatched types"), + ( + Self::Bundle { + fields: l_fields, + assume_padding_is_zeroed: l_assume_padding_is_zeroed, + }, + Self::Bundle { + fields: r_fields, + assume_padding_is_zeroed: r_assume_padding_is_zeroed, + }, + ) => { + let fields = Interned::from_iter( + l_fields + .iter() + .zip(&r_fields) + .map(|(&l_field, &r_field)| l_field.merged(r_field)), + ); + let assume_padding_is_zeroed = + l_assume_padding_is_zeroed & r_assume_padding_is_zeroed; + Self::Bundle { + fields, + assume_padding_is_zeroed, + } + .intern_sized() + } + (Self::Bundle { .. }, _) => unreachable!("mismatched types"), + (Self::NoPadding, _) => { + unreachable!("NoPadding is always caught by the early return above") + } + } + } +} + +struct State { + root_module: Interned>, + any_changes: bool, + expr_flags: UnionFindMap, Interned>, + expr_visited: HashMap, bool>, +} + +const OPTIMISTIC_FLAGS: StructuralEqFlags = StructuralEqFlags { + assume_padding_is_zeroed: true, +}; + +impl State { + fn new(root_module: Interned>) -> Self { + Self { + root_module, + any_changes: false, + expr_flags: UnionFindMap::default(), + expr_visited: HashMap::default(), + } + } + fn merge_expr_flags( + &mut self, + expr: Interned, + new_flags: FlagsTree, + ) -> Interned { + match self.expr_flags.entry(expr) { + Entry::Vacant(entry) => { + let new_flags = new_flags.intern_sized(); + entry.insert(new_flags); + self.any_changes = true; + new_flags + } + Entry::Occupied(mut entry) => { + let new_flags = new_flags.merged(*entry.get()); + self.any_changes |= new_flags != *entry.get(); + entry.insert(new_flags); + new_flags + } + } + } + fn union_exprs( + &mut self, + expr1: Interned, + expr2: Interned, + ) -> Interned { + let (unioned, value) = self + .expr_flags + .union(&expr1, &expr2, |_, v1, _, v2| v1.merged(v2)); + self.any_changes |= unioned; + *value + } + fn connect( + &mut self, + lhs: Expr, + rhs: Expr, + ) -> (Interned, Interned) { + let lhs_flags = self.visit_canonical_expr(lhs); + let rhs_flags = self.visit_canonical_expr(rhs); + if lhs_flags == rhs_flags { + return (lhs_flags, rhs_flags); + } + match lhs.ty() { + CanonicalType::UInt(_) + | CanonicalType::SInt(_) + | CanonicalType::Bool(_) + | CanonicalType::AsyncReset(_) + | CanonicalType::SyncReset(_) + | CanonicalType::Reset(_) + | CanonicalType::Clock(_) + | CanonicalType::PhantomConst(_) + | CanonicalType::DynSimOnly(_) => { + unreachable!("flags are always FlagsTree::NoPadding") + } + CanonicalType::Array(lhs_ty) => { + let lhs = Expr::from_canonical(lhs); + let rhs = Expr::::from_canonical(rhs); + assert_eq!(lhs_ty.len(), rhs.ty().len()); + assert!(lhs_ty.len() > 0); + // FlagsTree treats arrays transparently, so the returned flags don't need to be adjusted. + // All array indexing operations are unioned together so just arbitrarily use index 0 + self.connect( + ArrayIndex::new(lhs, 0).to_expr(), + ArrayIndex::new(rhs, 0).to_expr(), + ) + } + CanonicalType::Enum(lhs_ty) => { + let lhs = Expr::from_canonical(lhs); + let rhs = Expr::::from_canonical(rhs); + assert_eq!(lhs_ty.variants().len(), rhs.ty().variants().len()); + let mut lhs_assume_padding_is_zeroed = lhs_flags.assume_padding_is_zeroed(); + let mut lhs_variants = Vec::with_capacity(lhs_ty.variants().len()); + for (variant_index, (lhs_variant, rhs_variant)) in lhs_ty + .variants() + .into_iter() + .zip(rhs.ty().variants()) + .enumerate() + { + assert_eq!(lhs_variant.ty.is_some(), rhs_variant.ty.is_some()); + if lhs_variant.ty.is_some() { + let (lhs_field_flags, _rhs_field_flags) = self.connect( + VariantAccess::new_by_index(lhs, variant_index).to_expr(), + VariantAccess::new_by_index(rhs, variant_index).to_expr(), + ); + lhs_variants.push(Some(lhs_field_flags)); + lhs_assume_padding_is_zeroed &= lhs_field_flags.assume_padding_is_zeroed(); + } else { + lhs_variants.push(None); + } + } + let lhs_flags = self.merge_expr_flags( + Expr::expr_enum(lhs), + FlagsTree::Enum { + variants: lhs_variants.intern_slice(), + assume_padding_is_zeroed: lhs_assume_padding_is_zeroed, + }, + ); + (lhs_flags, rhs_flags) + } + CanonicalType::Bundle(lhs_ty) => { + let lhs = Expr::from_canonical(lhs); + let rhs = Expr::::from_canonical(rhs); + assert_eq!(lhs_ty.fields().len(), rhs.ty().fields().len()); + let mut lhs_assume_padding_is_zeroed = lhs_flags.assume_padding_is_zeroed(); + let mut rhs_assume_padding_is_zeroed = rhs_flags.assume_padding_is_zeroed(); + let mut lhs_fields = Vec::with_capacity(lhs_ty.fields().len()); + let mut rhs_fields = Vec::with_capacity(lhs_ty.fields().len()); + for (field_index, (lhs_field, rhs_field)) in lhs_ty + .fields() + .into_iter() + .zip(rhs.ty().fields()) + .enumerate() + { + assert_eq!(lhs_field.flipped, rhs_field.flipped); + let lhs_field_flags; + let rhs_field_flags; + if lhs_field.flipped { + // flipped, so exchange lhs/rhs when recursively calling connect + (rhs_field_flags, lhs_field_flags) = self.connect( + FieldAccess::new_by_index(rhs, field_index).to_expr(), + FieldAccess::new_by_index(lhs, field_index).to_expr(), + ); + } else { + (lhs_field_flags, rhs_field_flags) = self.connect( + FieldAccess::new_by_index(lhs, field_index).to_expr(), + FieldAccess::new_by_index(rhs, field_index).to_expr(), + ); + } + lhs_fields.push(lhs_field_flags); + rhs_fields.push(rhs_field_flags); + lhs_assume_padding_is_zeroed &= lhs_field_flags.assume_padding_is_zeroed(); + rhs_assume_padding_is_zeroed &= rhs_field_flags.assume_padding_is_zeroed(); + } + let lhs_flags = self.merge_expr_flags( + Expr::expr_enum(lhs), + FlagsTree::Bundle { + fields: lhs_fields.intern_slice(), + assume_padding_is_zeroed: lhs_assume_padding_is_zeroed, + }, + ); + let rhs_flags = self.merge_expr_flags( + Expr::expr_enum(rhs), + FlagsTree::Bundle { + fields: rhs_fields.intern_slice(), + assume_padding_is_zeroed: rhs_assume_padding_is_zeroed, + }, + ); + (lhs_flags, rhs_flags) + } + CanonicalType::TraceAsString(_) => { + let lhs = Expr::from_canonical(lhs); + let rhs = Expr::from_canonical(rhs); + // FlagsTree treats TraceAsString transparently, so the returned flags don't need to be adjusted. + // this expression and the inner expression are unioned together + self.connect( + TraceAsStringAsInner::new(lhs).to_expr(), + TraceAsStringAsInner::new(rhs).to_expr(), + ) + } + } + } + fn visit_canonical_expr(&mut self, expr: Expr) -> Interned { + let expr_enum = Expr::expr_enum(expr); + let ty = expr.ty(); + let visited = self.expr_visited.entry(expr_enum).or_insert(false); + let flags = *self.expr_flags.entry(expr_enum).or_insert_with(|| { + self.any_changes = true; + FlagsTree::new(ty, true) + }); + if std::mem::replace(visited, true) { + return flags; + } + let mut merge_and_default_visit = |flags| { + self.merge_expr_flags(expr_enum, flags); + let Ok(()) = expr_enum.default_visit(self); + *self.expr_flags.find_mut(&expr_enum) + }; + match *expr_enum { + ExprEnum::UIntLiteral(_) + | ExprEnum::SIntLiteral(_) + | ExprEnum::BoolLiteral(_) + | ExprEnum::PhantomConst(_) => merge_and_default_visit(*flags), + ExprEnum::BundleLiteral(bundle_literal) => todo!(), + ExprEnum::ArrayLiteral(array_literal) => todo!(), + ExprEnum::EnumLiteral(enum_literal) => todo!(), + ExprEnum::Uninit(_) => merge_and_default_visit(*FlagsTree::new(ty, false)), + ExprEnum::NotU(_) + | ExprEnum::NotS(_) + | ExprEnum::NotB(_) + | ExprEnum::Neg(_) + | ExprEnum::BitAndU(_) + | ExprEnum::BitAndS(_) + | ExprEnum::BitAndB(_) + | ExprEnum::BitOrU(_) + | ExprEnum::BitOrS(_) + | ExprEnum::BitOrB(_) + | ExprEnum::BitXorU(_) + | ExprEnum::BitXorS(_) + | ExprEnum::BitXorB(_) + | ExprEnum::AddU(_) + | ExprEnum::AddS(_) + | ExprEnum::SubU(_) + | ExprEnum::SubS(_) + | ExprEnum::MulU(_) + | ExprEnum::MulS(_) + | ExprEnum::DivU(_) + | ExprEnum::DivS(_) + | ExprEnum::RemU(_) + | ExprEnum::RemS(_) + | ExprEnum::DynShlU(_) + | ExprEnum::DynShlS(_) + | ExprEnum::DynShrU(_) + | ExprEnum::DynShrS(_) + | ExprEnum::FixedShlU(_) + | ExprEnum::FixedShlS(_) + | ExprEnum::FixedShrU(_) + | ExprEnum::FixedShrS(_) + | ExprEnum::CmpLtB(_) + | ExprEnum::CmpLeB(_) + | ExprEnum::CmpGtB(_) + | ExprEnum::CmpGeB(_) + | ExprEnum::CmpEqB(_) + | ExprEnum::CmpNeB(_) + | ExprEnum::CmpLtU(_) + | ExprEnum::CmpLeU(_) + | ExprEnum::CmpGtU(_) + | ExprEnum::CmpGeU(_) + | ExprEnum::CmpEqU(_) + | ExprEnum::CmpNeU(_) + | ExprEnum::CmpLtS(_) + | ExprEnum::CmpLeS(_) + | ExprEnum::CmpGtS(_) + | ExprEnum::CmpGeS(_) + | ExprEnum::CmpEqS(_) + | ExprEnum::CmpNeS(_) + | ExprEnum::CastUIntToUInt(_) + | ExprEnum::CastUIntToSInt(_) + | ExprEnum::CastSIntToUInt(_) + | ExprEnum::CastSIntToSInt(_) + | ExprEnum::CastBoolToUInt(_) + | ExprEnum::CastBoolToSInt(_) + | ExprEnum::CastUIntToBool(_) + | ExprEnum::CastSIntToBool(_) + | ExprEnum::CastBoolToSyncReset(_) + | ExprEnum::CastUIntToSyncReset(_) + | ExprEnum::CastSIntToSyncReset(_) + | ExprEnum::CastBoolToAsyncReset(_) + | ExprEnum::CastUIntToAsyncReset(_) + | ExprEnum::CastSIntToAsyncReset(_) + | ExprEnum::CastSyncResetToBool(_) + | ExprEnum::CastSyncResetToUInt(_) + | ExprEnum::CastSyncResetToSInt(_) + | ExprEnum::CastSyncResetToReset(_) + | ExprEnum::CastAsyncResetToBool(_) + | ExprEnum::CastAsyncResetToUInt(_) + | ExprEnum::CastAsyncResetToSInt(_) + | ExprEnum::CastAsyncResetToReset(_) + | ExprEnum::CastResetToBool(_) + | ExprEnum::CastResetToUInt(_) + | ExprEnum::CastResetToSInt(_) + | ExprEnum::CastBoolToClock(_) + | ExprEnum::CastUIntToClock(_) + | ExprEnum::CastSIntToClock(_) + | ExprEnum::CastClockToBool(_) + | ExprEnum::CastClockToUInt(_) + | ExprEnum::CastClockToSInt(_) => merge_and_default_visit(*flags), + ExprEnum::FieldAccess(field_access) => todo!(), + ExprEnum::VariantAccess(variant_access) => todo!(), + ExprEnum::ArrayIndex(array_index) => todo!(), + ExprEnum::DynArrayIndex(dyn_array_index) => todo!(), + ExprEnum::ReduceBitAndU(_) + | ExprEnum::ReduceBitAndS(_) + | ExprEnum::ReduceBitOrU(_) + | ExprEnum::ReduceBitOrS(_) + | ExprEnum::ReduceBitXorU(_) + | ExprEnum::ReduceBitXorS(_) + | ExprEnum::SliceUInt(_) + | ExprEnum::SliceSInt(_) + | ExprEnum::CastToBits(_) => merge_and_default_visit(*flags), + ExprEnum::CastBitsTo(_) => merge_and_default_visit(*FlagsTree::new(ty, false)), + ExprEnum::ToTraceAsString(to_trace_as_string) => todo!(), + ExprEnum::TraceAsStringAsInner(trace_as_string_as_inner) => todo!(), + ExprEnum::StructuralEq(structural_eq) => merge_and_default_visit(*flags), + ExprEnum::ModuleIO(module_io) => merge_and_default_visit(*flags), + ExprEnum::Instance(instance) => todo!(), + ExprEnum::Wire(wire) => merge_and_default_visit(*flags), + ExprEnum::Reg(reg) => todo!(), + ExprEnum::RegSync(reg) => todo!(), + ExprEnum::RegAsync(reg) => todo!(), + ExprEnum::MemPort(_) => merge_and_default_visit(*FlagsTree::new(ty, false)), + ExprEnum::FormalInput(_) => merge_and_default_visit(*FlagsTree::new(ty, false)), + ExprEnum::SimIoForGlobal(_) => { + unreachable!("Module is known to not contain SimIoForGlobal from validation") + } + } + } +} + +impl Visitor for State { + type Error = Infallible; + + fn visit_expr_enum(&mut self, expr_enum: &ExprEnum) -> Result<(), Self::Error> { + self.visit_canonical_expr(expr_enum.to_expr()); + Ok(()) + } + + fn visit_block(&mut self, v: &Block) -> Result<(), Self::Error> { + todo!() + } + + fn visit_module(&mut self, v: &Module) -> Result<(), Self::Error> { + let module = v.canonical(); + let external = match module.body() { + ModuleBody::Normal(_) => false, + ModuleBody::Extern(_) => true, + }; + if external || *self.root_module == module { + for module_io in module.module_io() { + let module_io = module_io.module_io.to_expr(); + // all main/external module I/O can have non-zeroed padding + self.merge_expr_flags( + Expr::expr_enum(module_io), + *FlagsTree::new(module_io.ty(), false), + ); + } + } + module.default_visit(self) + } +} + +impl Folder for State { + type Error = Infallible; +} + +pub fn deduce_structural_eq_flags(module: Interned>) -> Interned> { + // the algorithm proceeds in two stages: + // 1. Visitor for State: a fixed-point data-flow algorithm to determine what flags should be + // 2. Folder for State: transforming the StructuralEq operations to have the deduced flags + let mut state = State::new(module); + loop { + state.expr_visited.values_mut().for_each(|v| *v = false); + state.any_changes = false; + let Ok(()) = module.visit(&mut state); + if !state.any_changes { + break; + } + } + let Ok(retval) = module.fold(&mut state); + retval +} diff --git a/crates/fayalite/src/util.rs b/crates/fayalite/src/util.rs index 6845d3c..e804643 100644 --- a/crates/fayalite/src/util.rs +++ b/crates/fayalite/src/util.rs @@ -16,8 +16,8 @@ pub type DefaultBuildHasher = test_hasher::DefaultBuildHasher; #[cfg(not(feature = "unstable-test-hasher"))] pub(crate) type DefaultBuildHasher = hashbrown::DefaultHashBuilder; -pub(crate) type HashMap = hashbrown::HashMap; -pub(crate) type HashSet = hashbrown::HashSet; +pub(crate) type HashMap = hashbrown::HashMap; +pub(crate) type HashSet = hashbrown::HashSet; #[doc(inline)] pub use const_bool::{ConstBool, ConstBoolDispatch, ConstBoolDispatchTag, GenericConstBool}; @@ -44,6 +44,8 @@ pub use misc::{ pub(crate) use misc::{InternedStrCompareAsStr, chain, copy_le_bytes_to_bitslice}; pub mod job_server; +pub mod map_trait; pub mod prefix_sum; pub mod ready_valid; pub(crate) mod serde_by_id; +pub mod union_find_map; diff --git a/crates/fayalite/src/util/map_trait.rs b/crates/fayalite/src/util/map_trait.rs new file mode 100644 index 0000000..6fc065c --- /dev/null +++ b/crates/fayalite/src/util/map_trait.rs @@ -0,0 +1,463 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +// See Notices.txt for copyright information + +use std::fmt; + +pub enum Entry<'a, M: Map + 'a> { + Vacant(M::VacantEntry<'a>), + Occupied(M::OccupiedEntry<'a>), +} + +impl<'a, M: Map + 'a> Entry<'a, M> { + pub fn and_modify(mut self, f: F) -> Self { + if let Self::Occupied(entry) = &mut self { + f(entry.get_mut()); + } + self + } + pub fn insert_entry(self, v: M::Value) -> M::OccupiedEntry<'a> { + match self { + Self::Vacant(entry) => entry.insert_entry(v), + Self::Occupied(mut entry) => { + entry.insert(v); + entry + } + } + } + pub fn key(&self) -> &M::Key { + match self { + Self::Vacant(entry) => entry.key(), + Self::Occupied(entry) => entry.key(), + } + } + pub fn or_default(self) -> &'a mut M::Value + where + M::Value: Default, + { + self.or_insert_with(Default::default) + } + pub fn or_insert(self, v: M::Value) -> &'a mut M::Value { + match self { + Self::Vacant(entry) => entry.insert(v), + Self::Occupied(entry) => entry.into_mut(), + } + } + pub fn or_insert_with M::Value>(self, f: F) -> &'a mut M::Value { + match self { + Self::Vacant(entry) => entry.insert(f()), + Self::Occupied(entry) => entry.into_mut(), + } + } + pub fn or_insert_with_key M::Value>(self, f: F) -> &'a mut M::Value { + match self { + Self::Vacant(entry) => { + let v = f(entry.key()); + entry.insert(v) + } + Self::Occupied(entry) => entry.into_mut(), + } + } +} + +impl<'a, M: Map: fmt::Debug, VacantEntry<'a>: fmt::Debug> + 'a> fmt::Debug + for Entry<'a, M> +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Vacant(v) => f.debug_tuple("Vacant").field(v).finish(), + Self::Occupied(v) => f.debug_tuple("Occupied").field(v).finish(), + } + } +} + +pub trait VacantEntry<'a>: Sized { + type Map: Map = Self> + 'a; + fn insert(self, v: ::Value) -> &'a mut ::Value; + fn insert_entry(self, v: ::Value) -> ::OccupiedEntry<'a>; + fn into_key(self) -> ::Key; + fn key(&self) -> &::Key; +} + +pub trait OccupiedEntry<'a>: Sized { + type Map: Map = Self> + 'a; + fn get(&self) -> &::Value; + fn get_mut(&mut self) -> &mut ::Value; + fn insert(&mut self, v: ::Value) -> ::Value; + fn into_mut(self) -> &'a mut ::Value; + fn key(&self) -> &::Key; + fn remove(self) -> ::Value; + fn remove_entry(self) -> (::Key, ::Value); +} + +pub trait Map: + Sized + + IntoIterator::Key, ::Value)> + + Extend<(::Key, ::Value)> + + FromIterator<(::Key, ::Value)> +{ + type Key; + type Value; + type IntoKeys: Iterator; + type IntoValues: Iterator; + type Iter<'a>: Iterator + where + Self: 'a, + Self::Key: 'a, + Self::Value: 'a; + type IterMut<'a>: Iterator + where + Self: 'a, + Self::Key: 'a, + Self::Value: 'a; + type Keys<'a>: Iterator + where + Self: 'a, + Self::Key: 'a; + type Values<'a>: Iterator + where + Self: 'a, + Self::Value: 'a; + type ValuesMut<'a>: Iterator + where + Self: 'a, + Self::Value: 'a; + type OccupiedEntry<'a>: OccupiedEntry<'a, Map = Self> + where + Self: 'a; + type VacantEntry<'a>: VacantEntry<'a, Map = Self> + where + Self: 'a; + fn clear(&mut self); + fn entry(&mut self, k: Self::Key) -> Entry<'_, Self>; + fn insert(&mut self, k: Self::Key, v: Self::Value) -> Option; + fn into_keys(self) -> Self::IntoKeys; + fn into_values(self) -> Self::IntoValues; + fn is_empty(&self) -> bool; + fn iter(&self) -> Self::Iter<'_>; + fn iter_mut(&mut self) -> Self::IterMut<'_>; + fn keys(&self) -> Self::Keys<'_>; + fn len(&self) -> usize; + fn retain bool>(&mut self, f: F); + fn values(&self) -> Self::Values<'_>; + fn values_mut(&mut self) -> Self::ValuesMut<'_>; +} + +pub trait MapGet: Map { + fn contains_key(&self, k: &Q) -> bool; + fn get(&self, k: &Q) -> Option<&Self::Value>; + fn get_mut(&mut self, k: &Q) -> Option<&mut Self::Value>; + fn remove(&mut self, k: &Q) -> Option; + fn remove_entry(&mut self, k: &Q) -> Option<(Self::Key, Self::Value)>; +} + +mod hash_map { + use super::*; + use crate::util::HashMap; + use hashbrown::{Equivalent, hash_map}; + use std::hash::{BuildHasher, Hash}; + + impl Map for HashMap { + type Key = K; + type Value = V; + type IntoKeys = hash_map::IntoKeys; + type IntoValues = hash_map::IntoValues; + type Iter<'a> + = hash_map::Iter<'a, K, V> + where + Self: 'a, + Self::Key: 'a, + Self::Value: 'a; + type IterMut<'a> + = hash_map::IterMut<'a, K, V> + where + Self: 'a, + Self::Key: 'a, + Self::Value: 'a; + type Keys<'a> + = hash_map::Keys<'a, K, V> + where + Self: 'a, + Self::Key: 'a; + type Values<'a> + = hash_map::Values<'a, K, V> + where + Self: 'a, + Self::Value: 'a; + type ValuesMut<'a> + = hash_map::ValuesMut<'a, K, V> + where + Self: 'a, + Self::Value: 'a; + type OccupiedEntry<'a> + = hash_map::OccupiedEntry<'a, K, V, H> + where + Self: 'a; + type VacantEntry<'a> + = hash_map::VacantEntry<'a, K, V, H> + where + Self: 'a; + fn clear(&mut self) { + self.clear(); + } + fn entry(&mut self, k: Self::Key) -> Entry<'_, Self> { + use hash_map::Entry::*; + match self.entry(k) { + Occupied(entry) => Entry::Occupied(entry), + Vacant(entry) => Entry::Vacant(entry), + } + } + fn insert(&mut self, k: Self::Key, v: Self::Value) -> Option { + self.insert(k, v) + } + fn into_keys(self) -> Self::IntoKeys { + self.into_keys() + } + fn into_values(self) -> Self::IntoValues { + self.into_values() + } + fn is_empty(&self) -> bool { + self.is_empty() + } + fn iter(&self) -> Self::Iter<'_> { + self.iter() + } + fn iter_mut(&mut self) -> Self::IterMut<'_> { + self.iter_mut() + } + fn keys(&self) -> Self::Keys<'_> { + self.keys() + } + fn len(&self) -> usize { + self.len() + } + fn retain bool>(&mut self, f: F) { + self.retain(f); + } + fn values(&self) -> Self::Values<'_> { + self.values() + } + fn values_mut(&mut self) -> Self::ValuesMut<'_> { + self.values_mut() + } + } + + impl> MapGet + for HashMap + { + fn contains_key(&self, k: &Q) -> bool { + self.contains_key(k) + } + fn get(&self, k: &Q) -> Option<&Self::Value> { + self.get(k) + } + fn get_mut(&mut self, k: &Q) -> Option<&mut Self::Value> { + self.get_mut(k) + } + fn remove(&mut self, k: &Q) -> Option { + self.remove(k) + } + fn remove_entry(&mut self, k: &Q) -> Option<(Self::Key, Self::Value)> { + self.remove_entry(k) + } + } + + impl<'a, K: Eq + Hash, V, H: BuildHasher + Default> VacantEntry<'a> + for hash_map::VacantEntry<'a, K, V, H> + { + type Map = HashMap; + fn insert(self, v: ::Value) -> &'a mut ::Value { + self.insert(v) + } + fn insert_entry( + self, + v: ::Value, + ) -> ::OccupiedEntry<'a> { + self.insert_entry(v) + } + fn into_key(self) -> ::Key { + self.into_key() + } + fn key(&self) -> &::Key { + self.key() + } + } + + impl<'a, K: Eq + Hash, V, H: BuildHasher + Default> OccupiedEntry<'a> + for hash_map::OccupiedEntry<'a, K, V, H> + { + type Map = HashMap; + fn get(&self) -> &::Value { + self.get() + } + fn get_mut(&mut self) -> &mut ::Value { + self.get_mut() + } + fn insert(&mut self, v: ::Value) -> ::Value { + self.insert(v) + } + fn into_mut(self) -> &'a mut ::Value { + self.into_mut() + } + fn key(&self) -> &::Key { + self.key() + } + fn remove(self) -> ::Value { + self.remove() + } + fn remove_entry(self) -> (::Key, ::Value) { + self.remove_entry() + } + } +} + +mod btree_map { + use super::*; + use std::collections::{BTreeMap, btree_map}; + + impl Map for BTreeMap { + type Key = K; + type Value = V; + type IntoKeys = btree_map::IntoKeys; + type IntoValues = btree_map::IntoValues; + type Iter<'a> + = btree_map::Iter<'a, K, V> + where + Self: 'a, + Self::Key: 'a, + Self::Value: 'a; + type IterMut<'a> + = btree_map::IterMut<'a, K, V> + where + Self: 'a, + Self::Key: 'a, + Self::Value: 'a; + type Keys<'a> + = btree_map::Keys<'a, K, V> + where + Self: 'a, + Self::Key: 'a; + type Values<'a> + = btree_map::Values<'a, K, V> + where + Self: 'a, + Self::Value: 'a; + type ValuesMut<'a> + = btree_map::ValuesMut<'a, K, V> + where + Self: 'a, + Self::Value: 'a; + type OccupiedEntry<'a> + = btree_map::OccupiedEntry<'a, K, V> + where + Self: 'a; + type VacantEntry<'a> + = btree_map::VacantEntry<'a, K, V> + where + Self: 'a; + fn clear(&mut self) { + self.clear(); + } + fn entry(&mut self, k: Self::Key) -> Entry<'_, Self> { + use btree_map::Entry::*; + match self.entry(k) { + Occupied(entry) => Entry::Occupied(entry), + Vacant(entry) => Entry::Vacant(entry), + } + } + fn insert(&mut self, k: Self::Key, v: Self::Value) -> Option { + self.insert(k, v) + } + fn into_keys(self) -> Self::IntoKeys { + self.into_keys() + } + fn into_values(self) -> Self::IntoValues { + self.into_values() + } + fn is_empty(&self) -> bool { + self.is_empty() + } + fn iter(&self) -> Self::Iter<'_> { + self.iter() + } + fn iter_mut(&mut self) -> Self::IterMut<'_> { + self.iter_mut() + } + fn keys(&self) -> Self::Keys<'_> { + self.keys() + } + fn len(&self) -> usize { + self.len() + } + fn retain bool>(&mut self, f: F) { + self.retain(f); + } + fn values(&self) -> Self::Values<'_> { + self.values() + } + fn values_mut(&mut self) -> Self::ValuesMut<'_> { + self.values_mut() + } + } + + impl, V, Q: ?Sized + Ord> MapGet for BTreeMap { + fn contains_key(&self, k: &Q) -> bool { + self.contains_key(k) + } + fn get(&self, k: &Q) -> Option<&Self::Value> { + self.get(k) + } + fn get_mut(&mut self, k: &Q) -> Option<&mut Self::Value> { + self.get_mut(k) + } + fn remove(&mut self, k: &Q) -> Option { + self.remove(k) + } + fn remove_entry(&mut self, k: &Q) -> Option<(Self::Key, Self::Value)> { + self.remove_entry(k) + } + } + + impl<'a, K: Ord, V> VacantEntry<'a> for btree_map::VacantEntry<'a, K, V> { + type Map = BTreeMap; + fn insert(self, v: ::Value) -> &'a mut ::Value { + self.insert(v) + } + fn insert_entry( + self, + v: ::Value, + ) -> ::OccupiedEntry<'a> { + self.insert_entry(v) + } + fn into_key(self) -> ::Key { + self.into_key() + } + fn key(&self) -> &::Key { + self.key() + } + } + + impl<'a, K: Ord, V> OccupiedEntry<'a> for btree_map::OccupiedEntry<'a, K, V> { + type Map = BTreeMap; + fn get(&self) -> &::Value { + self.get() + } + fn get_mut(&mut self) -> &mut ::Value { + self.get_mut() + } + fn insert(&mut self, v: ::Value) -> ::Value { + self.insert(v) + } + fn into_mut(self) -> &'a mut ::Value { + self.into_mut() + } + fn key(&self) -> &::Key { + self.key() + } + fn remove(self) -> ::Value { + self.remove() + } + fn remove_entry(self) -> (::Key, ::Value) { + self.remove_entry() + } + } +} diff --git a/crates/fayalite/src/util/union_find_map.rs b/crates/fayalite/src/util/union_find_map.rs new file mode 100644 index 0000000..6591036 --- /dev/null +++ b/crates/fayalite/src/util/union_find_map.rs @@ -0,0 +1,355 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +// See Notices.txt for copyright information + +use crate::util::{ + HashMap, + map_trait::{self, Map, MapGet, OccupiedEntry as _, VacantEntry as _}, +}; +use petgraph::unionfind::UnionFind; +use std::{collections::BTreeMap, fmt, marker::PhantomData}; + +pub struct UnionFindMap> { + uf: UnionFind, + keys_to_indexes: M, + values: Vec>, + _phantom: PhantomData, +} + +impl> fmt::Debug + for UnionFindMap +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut indexes_to_keys = vec![None; self.len()]; + for (k, &index) in self.keys_to_indexes.iter() { + indexes_to_keys[index] = Some(k); + } + let mut debug_map = f.debug_map(); + for (index, key) in indexes_to_keys.into_iter().enumerate() { + if let Some(key) = key { + debug_map.key(key); + } else { + debug_map.key(&fmt::from_fn(|f| { + f.write_str("<>") + })); + } + let set_index = self.uf.find(index); + debug_map.value(&fmt::from_fn(|f| { + write!(f, "@{set_index} ")?; + if set_index == index { + let Some(value) = &self.values[index] else { + unreachable!(); + }; + value.fmt(f) + } else { + Ok(()) + } + })); + } + debug_map.finish() + } +} + +impl> UnionFindMap { + /// returns the number of keys, not the number of sets/values + pub fn len(&self) -> usize { + self.values.len() + } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + pub fn capacity(&self) -> usize { + self.values.capacity() + } + #[track_caller] + pub fn equiv(&self, k1: &K1, k2: &K2) -> bool + where + M: MapGet + MapGet, + { + self.try_equiv(k1, k2).expect("key not found") + } + pub fn try_equiv(&self, k1: &K1, k2: &K2) -> Option + where + M: MapGet + MapGet, + { + let &index1 = self.keys_to_indexes.get(k1)?; + let &index2 = self.keys_to_indexes.get(k2)?; + Some(self.uf.equiv(index1, index2)) + } + #[track_caller] + pub fn find(&self, k: &Q) -> &V + where + M: MapGet, + { + self.try_find(k).expect("key not found") + } + pub fn try_find(&self, k: &Q) -> Option<&V> + where + M: MapGet, + { + let &index = self.keys_to_indexes.get(k)?; + self.values[self.uf.find(index)].as_ref() + } + #[track_caller] + pub fn find_mut(&mut self, k: &Q) -> &mut V + where + M: MapGet, + { + self.try_find_mut(k).expect("key not found") + } + pub fn try_find_mut(&mut self, k: &Q) -> Option<&mut V> + where + M: MapGet, + { + let &index = self.keys_to_indexes.get(k)?; + self.values[self.uf.find_mut(index)].as_mut() + } + /// inserts a new key as a new set, otherwise replaces the value for the set containing the passed-in key + pub fn insert(&mut self, k: K, v: V) -> Option { + match self.entry(k) { + Entry::Vacant(entry) => { + entry.insert(v); + None + } + Entry::Occupied(mut entry) => Some(entry.insert(v)), + } + } + pub fn entry(&mut self, k: K) -> Entry<'_, K, V, M> { + match self.keys_to_indexes.entry(k) { + map_trait::Entry::Vacant(keys_to_indexes_entry) => Entry::Vacant(VacantEntry { + keys_to_indexes_entry, + uf: &mut self.uf, + values: &mut self.values, + }), + map_trait::Entry::Occupied(keys_to_indexes_entry) => { + let set_index = self.uf.find_mut(*keys_to_indexes_entry.get()); + Entry::Occupied(OccupiedEntry { + keys_to_indexes_entry, + set_index, + uf: &mut self.uf, + values: &mut self.values, + }) + } + } + } + /// Unify the two sets containing `k1` and `k2`. + /// If the sets were the same, returns `Some((false, value))`, + /// otherwise calling `merge` to merge their values and returning `Some((true, value))`. + /// Returns `None` if either of the keys weren't found. + pub fn try_union( + &mut self, + k1: &K1, + k2: &K2, + merge: F, + ) -> Option<(bool, &mut V)> + where + M: MapGet + MapGet, + F: FnOnce(&K1, V, &K2, V) -> V, + { + let &index1 = self.keys_to_indexes.get(k1)?; + let &index2 = self.keys_to_indexes.get(k2)?; + let index1 = self.uf.find_mut(index1); + let index2 = self.uf.find_mut(index2); + if index1 == index2 { + return Some((false, self.values[index1].as_mut()?)); + } + assert!(self.uf.union(index1, index2)); + let v1 = self.values[index1].take().expect("known to be Some"); + let v2 = self.values[index2].take().expect("known to be Some"); + let dest = &mut self.values[self.uf.find_mut(index1)]; + let dest = dest.insert(merge(k1, v1, k2, v2)); + Some((true, dest)) + } + /// Unify the two sets containing `k1` and `k2`. + /// If the sets were the same, returns `(false, value)`, + /// otherwise calling `merge` to merge their values and returning `(true, value)`. + /// panics if either of the keys weren't found. + #[track_caller] + pub fn union(&mut self, k1: &K1, k2: &K2, merge: F) -> (bool, &mut V) + where + M: MapGet + MapGet, + F: FnOnce(&K1, V, &K2, V) -> V, + { + self.try_union(k1, k2, merge).expect("key not found") + } +} + +impl UnionFindMap { + pub fn new() -> Self { + Self::with_hasher(Default::default()) + } + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacity_and_hasher(capacity, Default::default()) + } +} + +impl UnionFindMap> { + pub const fn new_btree() -> Self { + Self { + uf: UnionFind::new_empty(), + keys_to_indexes: BTreeMap::new(), + values: Vec::new(), + _phantom: PhantomData, + } + } +} + +impl UnionFindMap> { + pub const fn with_hasher(hash_builder: H) -> Self { + Self { + uf: UnionFind::new_empty(), + keys_to_indexes: HashMap::with_hasher(hash_builder), + values: Vec::new(), + _phantom: PhantomData, + } + } + pub fn with_capacity_and_hasher(capacity: usize, hash_builder: H) -> Self { + Self { + uf: UnionFind::with_capacity(capacity), + keys_to_indexes: HashMap::with_capacity_and_hasher(capacity, hash_builder), + values: Vec::with_capacity(capacity), + _phantom: PhantomData, + } + } +} + +impl Default for UnionFindMap { + fn default() -> Self { + Self { + uf: UnionFind::new_empty(), + keys_to_indexes: M::default(), + values: Vec::new(), + _phantom: PhantomData, + } + } +} + +pub struct OccupiedEntry<'a, K, V, M: Map + 'a> { + keys_to_indexes_entry: M::OccupiedEntry<'a>, + set_index: usize, + uf: &'a mut UnionFind, + values: &'a mut [Option], +} + +impl<'a, K, V, M: Map + 'a> OccupiedEntry<'a, K, V, M> { + pub fn get(&self) -> &V { + let Some(v) = &self.values[self.set_index] else { + unreachable!() + }; + v + } + pub fn get_mut(&mut self) -> &mut V { + let Some(v) = &mut self.values[self.set_index] else { + unreachable!() + }; + v + } + /// replaces the value for this set + pub fn insert(&mut self, v: V) -> V { + std::mem::replace(self.get_mut(), v) + } + pub fn into_mut(self) -> &'a mut V { + let Some(v) = &mut self.values[self.set_index] else { + unreachable!() + }; + v + } + pub fn key(&self) -> &K { + self.keys_to_indexes_entry.key() + } +} + +pub struct VacantEntry<'a, K, V, M: Map + 'a> { + keys_to_indexes_entry: M::VacantEntry<'a>, + uf: &'a mut UnionFind, + values: &'a mut Vec>, +} + +impl<'a, K, V, M: Map + 'a> VacantEntry<'a, K, V, M> { + /// inserts a new key as a new set + pub fn insert(self, v: V) -> &'a mut V { + self.insert_entry(v).into_mut() + } + /// inserts a new key as a new set + pub fn insert_entry(self, v: V) -> OccupiedEntry<'a, K, V, M> { + let Self { + keys_to_indexes_entry, + uf, + values, + } = self; + let set_index = uf.new_set(); + values.push(Some(v)); + OccupiedEntry { + keys_to_indexes_entry: keys_to_indexes_entry.insert_entry(set_index), + set_index, + uf, + values, + } + } + pub fn into_key(self) -> K { + self.keys_to_indexes_entry.into_key() + } + pub fn key(&self) -> &K { + self.keys_to_indexes_entry.key() + } +} + +pub enum Entry<'a, K, V, M: Map + 'a> { + Vacant(VacantEntry<'a, K, V, M>), + Occupied(OccupiedEntry<'a, K, V, M>), +} + +impl<'a, K, V, M: Map + 'a> Entry<'a, K, V, M> { + pub fn and_modify(mut self, f: F) -> Self { + if let Self::Occupied(entry) = &mut self { + f(entry.get_mut()); + } + self + } + /// inserts a new key as a new set, otherwise replaces the value for the set containing the passed-in key + pub fn insert_entry(self, v: V) -> OccupiedEntry<'a, K, V, M> { + match self { + Self::Vacant(entry) => entry.insert_entry(v), + Self::Occupied(mut entry) => { + entry.insert(v); + entry + } + } + } + pub fn key(&self) -> &K { + match self { + Self::Vacant(entry) => entry.key(), + Self::Occupied(entry) => entry.key(), + } + } + /// inserts a new key as a new set + pub fn or_default(self) -> &'a mut V + where + V: Default, + { + self.or_insert_with(V::default) + } + /// inserts a new key as a new set + pub fn or_insert(self, v: V) -> &'a mut V { + match self { + Self::Vacant(entry) => entry.insert(v), + Self::Occupied(entry) => entry.into_mut(), + } + } + /// inserts a new key as a new set + pub fn or_insert_with V>(self, f: F) -> &'a mut V { + match self { + Self::Vacant(entry) => entry.insert(f()), + Self::Occupied(entry) => entry.into_mut(), + } + } + /// inserts a new key as a new set + pub fn or_insert_with_key V>(self, f: F) -> &'a mut V { + match self { + Self::Vacant(entry) => { + let v = f(entry.key()); + entry.insert(v) + } + Self::Occupied(entry) => entry.into_mut(), + } + } +}