// SPDX-License-Identifier: LGPL-3.0-or-later // See Notices.txt for copyright information use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, ToTokens}; use std::{collections::BTreeMap, fs}; use syn::{fold::Fold, parse_quote}; pub mod ast; fn map_camel_case_to_snake_case(s: &str) -> String { #[derive(Clone, Copy, PartialEq, Eq)] enum State { Start, Lowercase, PushedUpper(char), } let mut state = State::Start; let mut retval = String::new(); for ch in s.chars() { state = match ch { 'A'..='Z' => { match state { State::Start => {} State::Lowercase => retval.push('_'), State::PushedUpper(upper) => retval.push(upper.to_ascii_lowercase()), } State::PushedUpper(ch) } _ => { match state { State::PushedUpper(upper) => { retval.push(upper.to_ascii_lowercase()); } State::Start | State::Lowercase => {} } retval.push(ch); State::Lowercase } }; } match state { State::Lowercase | State::Start => {} State::PushedUpper(upper) => retval.push(upper.to_ascii_lowercase()), } retval } #[derive(Clone)] struct DefinitionState { fn_name_suffix: syn::Ident, generics: syn::Generics, fold_generics: syn::Generics, folder_generics: syn::Generics, visit_generics: syn::Generics, visitor_generics: syn::Generics, } impl DefinitionState { fn folder_fn_name(&self) -> syn::Ident { format_ident!("fold_{}", self.fn_name_suffix) } fn visitor_fn_name(&self) -> syn::Ident { format_ident!("visit_{}", self.fn_name_suffix) } fn folder_fn(&self, path: &ast::Path) -> TokenStream { let folder_fn_name = self.folder_fn_name(); let (impl_generics, type_generics, where_clause) = self.folder_generics.split_for_impl(); quote! { fn #folder_fn_name #impl_generics( &mut self, v: #path #type_generics, ) -> Result<#path #type_generics, Self::Error> #where_clause { Fold::default_fold(v, self) } } } fn visitor_fn(&self, path: &ast::Path) -> TokenStream { let visitor_fn_name = self.visitor_fn_name(); let (impl_generics, type_generics, where_clause) = self.visitor_generics.split_for_impl(); quote! { fn #visitor_fn_name #impl_generics( &mut self, v: &#path #type_generics, ) -> Result<(), Self::Error> #where_clause { Visit::default_visit(v, self) } } } fn fold_impl(&self, path: &ast::Path, body: impl ToTokens) -> TokenStream { let folder_fn_name = self.folder_fn_name(); let (_, self_type_generics, _) = self.generics.split_for_impl(); let (trait_impl_generics, _, trait_where_clause) = self.fold_generics.split_for_impl(); quote! { #[automatically_derived] #[allow(clippy::init_numbered_fields)] impl #trait_impl_generics Fold for #path #self_type_generics #trait_where_clause { fn fold(self, state: &mut State) -> Result { state.#folder_fn_name(self) } fn default_fold(self, state: &mut State) -> Result { #body } } } } fn visit_impl(&self, path: &ast::Path, body: impl ToTokens) -> TokenStream { let visitor_fn_name = self.visitor_fn_name(); let (_, self_type_generics, _) = self.generics.split_for_impl(); let (trait_impl_generics, _, trait_where_clause) = self.visit_generics.split_for_impl(); quote! { #[automatically_derived] impl #trait_impl_generics Visit for #path #self_type_generics #trait_where_clause { fn visit(&self, state: &mut State) -> Result<(), State::Error> { state.#visitor_fn_name(self) } fn default_visit(&self, state: &mut State) -> Result<(), State::Error> { #body } } } } } struct GenerateState<'a> { def_states: BTreeMap<&'a ast::Path, DefinitionState>, definitions: &'a ast::Definitions, } struct MapStateToSelf; impl syn::fold::Fold for MapStateToSelf { fn fold_ident(&mut self, i: syn::Ident) -> syn::Ident { if i == "State" { syn::Ident::new("Self", i.span()) } else { i } } } impl<'a> GenerateState<'a> { fn make_definition_state(&mut self, path: &'a ast::Path) -> syn::Result<()> { let ast::Definition { fn_name_suffix, generics, fold_where, visit_where, data: _, } = self.definitions.types.get(path).ok_or_else(|| { syn::Error::new( Span::call_site(), format!("can't find named type: {path:?}"), ) })?; let fn_name_suffix = fn_name_suffix .as_ref() .map(syn::Ident::from) .unwrap_or_else(|| format_ident!("{}", map_camel_case_to_snake_case(&path.last().0))); let generics = generics.clone().map(|v| v.0).unwrap_or_default(); let mut fold_generics = generics.clone(); let mut folder_generics = generics.clone(); fold_generics .params .insert(0, parse_quote! {State: ?Sized + Folder}); if let Some(fold_where) = fold_where { fold_generics .make_where_clause() .predicates .extend(fold_where.0.iter().cloned()); folder_generics.make_where_clause().predicates.extend( fold_where .0 .iter() .cloned() .map(|v| MapStateToSelf.fold_where_predicate(v)), ); } let mut visit_generics = generics.clone(); let mut visitor_generics = generics.clone(); visit_generics .params .insert(0, parse_quote! {State: ?Sized + Visitor}); if let Some(visit_where) = visit_where { visit_generics .make_where_clause() .predicates .extend(visit_where.0.iter().cloned()); visitor_generics.make_where_clause().predicates.extend( visit_where .0 .iter() .cloned() .map(|v| MapStateToSelf.fold_where_predicate(v)), ); } self.def_states.insert( path, DefinitionState { fn_name_suffix, generics, fold_generics, folder_generics, visit_generics, visitor_generics, }, ); Ok(()) } fn new(ast: &'a ast::Definitions) -> syn::Result { let mut retval = GenerateState { def_states: BTreeMap::new(), definitions: ast, }; let ast::Definitions { types } = ast; for path in types.keys() { retval.make_definition_state(path)?; } Ok(retval) } } pub fn generate(ast: &ast::Definitions) -> syn::Result { let state = GenerateState::new(ast)?; let mut visitor_fns = vec![]; let mut visit_impls = vec![]; let mut folder_fns = vec![]; let mut fold_impls = vec![]; for (&def_path, def_state) in state.def_states.iter() { folder_fns.push(def_state.folder_fn(def_path)); visitor_fns.push(def_state.visitor_fn(def_path)); let fold_body; let visit_body; let ast::Definition { fn_name_suffix: _, generics: _, fold_where: _, visit_where: _, data, } = ast.types.get(def_path).unwrap(); match data { ast::Data::ManualImpl => { continue; } ast::Data::Opaque => { fold_body = quote! { let _ = state; Ok(self) }; visit_body = quote! { let _ = state; Ok(()) }; } ast::Data::Struct(ast::Fields { constructor, fields, }) => { let mut visit_members = vec![]; let mut fold_members = vec![]; for (field_name, field) in fields { let fold_member_name = if constructor.is_some() { None } else { let member = field_name.to_member(); if member.is_none() { return Err(syn::Error::new( Span::call_site(), format!( "struct must have `$constructor` since it contains a \ non-plain field: {def_path:?} {field_name:?}" ), )); } member }; let fold_member_name = fold_member_name.as_slice(); let fold_member = match field { ast::Field::Opaque => { quote! { #(#fold_member_name:)* self.#field_name } } ast::Field::Visible => { visit_members.push(quote! { Visit::visit(&self.#field_name, state)?; }); quote! { #(#fold_member_name:)* Fold::fold(self.#field_name, state)? } } ast::Field::RefVisible => { visit_members.push(quote! { Visit::visit(self.#field_name, state)?; }); quote! { #(#fold_member_name:)* Fold::fold(self.#field_name.clone(), state)? } } }; fold_members.push(fold_member); } let match_members = constructor .is_none() .then(|| { fields .keys() .map(|k| k.to_member()) .collect::>>() .map(|members| { if members.is_empty() { quote! { let _ = state; let Self {} = self; } } else { quote! { let Self { #(#members: _,)* } = self; } } }) }) .flatten(); visit_body = quote! { #match_members #(#visit_members)* Ok(()) }; let fold_body_tail = if let Some(constructor) = constructor { quote! { Ok(#constructor(#(#fold_members),*)) } } else { quote! { Ok(Self { #(#fold_members,)* }) } }; fold_body = quote! { #match_members #fold_body_tail }; } ast::Data::Enum(ast::Variants { variants }) => { let mut fold_arms = vec![]; let mut visit_arms = vec![]; let mut state_unused = true; for (variant_name, variant_field) in variants { let fold_arm; let visit_arm; match variant_field { Some(ast::Field::Visible) => { state_unused = false; fold_arm = quote! { Self::#variant_name(v) => Fold::fold(v, state) .map(Self::#variant_name), }; visit_arm = quote! { Self::#variant_name(v) => Visit::visit(v, state), }; } Some(ast::Field::RefVisible) => { return Err(syn::Error::new( Span::call_site(), "enum variant field must not be RefVisible", )); } Some(ast::Field::Opaque) => { fold_arm = quote! { Self::#variant_name(_) => Ok(self), }; visit_arm = quote! { Self::#variant_name(_) => Ok(()), }; } None => { fold_arm = quote! { Self::#variant_name => Ok(self), }; visit_arm = quote! { Self::#variant_name => Ok(()), }; } } fold_arms.push(fold_arm); visit_arms.push(visit_arm); } let ignore_state = state_unused.then(|| { quote! { let _ = state; } }); visit_body = quote! { #ignore_state match self { #(#visit_arms)* } }; fold_body = quote! { #ignore_state match self { #(#fold_arms)* } }; } } fold_impls.push(def_state.fold_impl(def_path, fold_body)); visit_impls.push(def_state.visit_impl(def_path, visit_body)); } Ok(prettyplease::unparse(&parse_quote! { pub trait Visitor { type Error; #(#visitor_fns)* } #(#visit_impls)* pub trait Folder { type Error; #(#folder_fns)* } #(#fold_impls)* })) } pub fn error_at_call_site(e: T) -> syn::Error { syn::Error::new(Span::call_site(), e) } pub fn parse_and_generate(path: impl AsRef) -> syn::Result { let input = fs::read_to_string(path).map_err(error_at_call_site)?; let ast: ast::Definitions = serde_json::from_str(&input).map_err(error_at_call_site)?; generate(&ast) }