diff --git a/crates/fayalite/src/enum_.rs b/crates/fayalite/src/enum_.rs index 384414c..13724ef 100644 --- a/crates/fayalite/src/enum_.rs +++ b/crates/fayalite/src/enum_.rs @@ -7,14 +7,14 @@ use crate::{ int::Bool, intern::{Intern, Interned}, module::{ - enum_match_variants_helper, EnumMatchVariantAndInactiveScopeImpl, - EnumMatchVariantsIterImpl, Scope, + connect, enum_match_variants_helper, incomplete_wire, wire, + EnumMatchVariantAndInactiveScopeImpl, EnumMatchVariantsIterImpl, Scope, }, source_location::SourceLocation, ty::{CanonicalType, MatchVariantAndInactiveScope, StaticType, Type, TypeProperties}, }; use hashbrown::HashMap; -use std::{fmt, iter::FusedIterator}; +use std::{convert::Infallible, fmt, iter::FusedIterator}; #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub struct EnumVariant { @@ -364,3 +364,308 @@ pub fn HdlSome(value: impl ToExpr) -> Expr> { let value = value.to_expr(); HdlOption[Expr::ty(value)].HdlSome(value) } + +impl HdlOption { + #[track_caller] + pub fn try_map( + expr: Expr, + f: impl FnOnce(Expr) -> Result, E>, + ) -> Result>, E> { + Self::try_and_then(expr, |v| Ok(HdlSome(f(v)?))) + } + #[track_caller] + pub fn map( + expr: Expr, + f: impl FnOnce(Expr) -> Expr, + ) -> Expr> { + Self::and_then(expr, |v| HdlSome(f(v))) + } + #[hdl] + #[track_caller] + pub fn try_and_then( + expr: Expr, + f: impl FnOnce(Expr) -> Result>, E>, + ) -> Result>, E> { + // manually run match steps so we can extract the return type to construct HdlNone + type Wrap = T; + #[hdl] + let mut and_then_out = incomplete_wire(); + let mut iter = Self::match_variants(expr, SourceLocation::caller()); + let none = iter.next().unwrap(); + let some = iter.next().unwrap(); + assert!(iter.next().is_none()); + let (Wrap::<::MatchVariant>::HdlSome(value), some_scope) = + Self::match_activate_scope(some) + else { + unreachable!(); + }; + let value = f(value).map_err(|e| { + and_then_out.complete(()); // avoid error + e + })?; + let and_then_out = and_then_out.complete(Expr::ty(value)); + connect(and_then_out, value); + drop(some_scope); + let (Wrap::<::MatchVariant>::HdlNone, none_scope) = + Self::match_activate_scope(none) + else { + unreachable!(); + }; + connect(and_then_out, Expr::ty(and_then_out).HdlNone()); + drop(none_scope); + Ok(and_then_out) + } + #[track_caller] + pub fn and_then( + expr: Expr, + f: impl FnOnce(Expr) -> Expr>, + ) -> Expr> { + match Self::try_and_then(expr, |v| Ok::<_, Infallible>(f(v))) { + Ok(v) => v, + Err(e) => match e {}, + } + } + #[hdl] + #[track_caller] + pub fn and(expr: Expr, opt_b: Expr>) -> Expr> { + #[hdl] + let and_out = wire(Expr::ty(opt_b)); + connect(and_out, Expr::ty(opt_b).HdlNone()); + #[hdl] + if let HdlSome(_) = expr { + connect(and_out, opt_b); + } + and_out + } + #[hdl] + #[track_caller] + pub fn try_filter( + expr: Expr, + f: impl FnOnce(Expr) -> Result, E>, + ) -> Result, E> { + #[hdl] + let filtered = wire(Expr::ty(expr)); + connect(filtered, Expr::ty(expr).HdlNone()); + let mut f = Some(f); + #[hdl] + if let HdlSome(v) = expr { + #[hdl] + if f.take().unwrap()(v)? { + connect(filtered, HdlSome(v)); + } + } + Ok(filtered) + } + #[hdl] + #[track_caller] + pub fn filter(expr: Expr, f: impl FnOnce(Expr) -> Expr) -> Expr { + match Self::try_filter(expr, |v| Ok::<_, Infallible>(f(v))) { + Ok(v) => v, + Err(e) => match e {}, + } + } + #[hdl] + #[track_caller] + pub fn try_inspect( + expr: Expr, + f: impl FnOnce(Expr) -> Result<(), E>, + ) -> Result, E> { + let mut f = Some(f); + #[hdl] + if let HdlSome(v) = expr { + f.take().unwrap()(v)?; + } + Ok(expr) + } + #[hdl] + #[track_caller] + pub fn inspect(expr: Expr, f: impl FnOnce(Expr)) -> Expr { + let mut f = Some(f); + #[hdl] + if let HdlSome(v) = expr { + f.take().unwrap()(v); + } + expr + } + #[hdl] + #[track_caller] + pub fn is_none(expr: Expr) -> Expr { + #[hdl] + let is_none_out: Bool = wire(); + connect(is_none_out, false); + #[hdl] + if let HdlNone = expr { + connect(is_none_out, true); + } + is_none_out + } + #[hdl] + #[track_caller] + pub fn is_some(expr: Expr) -> Expr { + #[hdl] + let is_some_out: Bool = wire(); + connect(is_some_out, false); + #[hdl] + if let HdlSome(_) = expr { + connect(is_some_out, true); + } + is_some_out + } + #[hdl] + #[track_caller] + pub fn map_or( + expr: Expr, + default: Expr, + f: impl FnOnce(Expr) -> Expr, + ) -> Expr { + #[hdl] + let mapped = wire(Expr::ty(default)); + let mut f = Some(f); + #[hdl] + match expr { + HdlSome(v) => connect(mapped, f.take().unwrap()(v)), + HdlNone => connect(mapped, default), + } + mapped + } + #[hdl] + #[track_caller] + pub fn map_or_else( + expr: Expr, + default: impl FnOnce() -> Expr, + f: impl FnOnce(Expr) -> Expr, + ) -> Expr { + #[hdl] + let mut mapped = incomplete_wire(); + let mut default = Some(default); + let mut f = Some(f); + let mut retval = None; + #[hdl] + match expr { + HdlSome(v) => { + let v = f.take().unwrap()(v); + let mapped = *retval.get_or_insert_with(|| mapped.complete(Expr::ty(v))); + connect(mapped, v); + } + HdlNone => { + let v = default.take().unwrap()(); + let mapped = *retval.get_or_insert_with(|| mapped.complete(Expr::ty(v))); + connect(mapped, v); + } + } + retval.unwrap() + } + #[hdl] + #[track_caller] + pub fn or(expr: Expr, opt_b: Expr) -> Expr { + #[hdl] + let or_out = wire(Expr::ty(expr)); + connect(or_out, opt_b); + #[hdl] + if let HdlSome(_) = expr { + connect(or_out, expr); + } + or_out + } + #[hdl] + #[track_caller] + pub fn or_else(expr: Expr, f: impl FnOnce() -> Expr) -> Expr { + #[hdl] + let or_else_out = wire(Expr::ty(expr)); + connect(or_else_out, f()); + #[hdl] + if let HdlSome(_) = expr { + connect(or_else_out, expr); + } + or_else_out + } + #[hdl] + #[track_caller] + pub fn unwrap_or(expr: Expr, default: Expr) -> Expr { + #[hdl] + let unwrap_or_else_out = wire(Expr::ty(default)); + connect(unwrap_or_else_out, default); + #[hdl] + if let HdlSome(v) = expr { + connect(unwrap_or_else_out, v); + } + unwrap_or_else_out + } + #[hdl] + #[track_caller] + pub fn unwrap_or_else(expr: Expr, f: impl FnOnce() -> Expr) -> Expr { + #[hdl] + let unwrap_or_else_out = wire(Expr::ty(expr).HdlSome); + connect(unwrap_or_else_out, f()); + #[hdl] + if let HdlSome(v) = expr { + connect(unwrap_or_else_out, v); + } + unwrap_or_else_out + } + #[hdl] + #[track_caller] + pub fn xor(expr: Expr, opt_b: Expr) -> Expr { + #[hdl] + let xor_out = wire(Expr::ty(expr)); + #[hdl] + if let HdlSome(_) = expr { + #[hdl] + if let HdlNone = opt_b { + connect(xor_out, expr); + } else { + connect(xor_out, Expr::ty(expr).HdlNone()); + } + } else { + connect(xor_out, opt_b); + } + xor_out + } + #[hdl] + #[track_caller] + pub fn zip(expr: Expr, other: Expr>) -> Expr> { + #[hdl] + let zip_out = wire(HdlOption[(Expr::ty(expr).HdlSome, Expr::ty(other).HdlSome)]); + connect(zip_out, Expr::ty(zip_out).HdlNone()); + #[hdl] + if let HdlSome(l) = expr { + #[hdl] + if let HdlSome(r) = other { + connect(zip_out, HdlSome((l, r))); + } + } + zip_out + } +} + +impl HdlOption> { + #[hdl] + #[track_caller] + pub fn flatten(expr: Expr) -> Expr> { + #[hdl] + let flattened = wire(Expr::ty(expr).HdlSome); + #[hdl] + match expr { + HdlSome(v) => connect(flattened, v), + HdlNone => connect(flattened, Expr::ty(expr).HdlSome.HdlNone()), + } + flattened + } +} + +impl HdlOption<(T, U)> { + #[hdl] + #[track_caller] + pub fn unzip(expr: Expr) -> Expr<(HdlOption, HdlOption)> { + let (t, u) = Expr::ty(expr).HdlSome; + #[hdl] + let unzipped = wire((HdlOption[t], HdlOption[u])); + connect(unzipped, (HdlOption[t].HdlNone(), HdlOption[u].HdlNone())); + #[hdl] + if let HdlSome(v) = expr { + connect(unzipped.0, HdlSome(v.0)); + connect(unzipped.1, HdlSome(v.1)); + } + unzipped + } +}