forked from libre-chip/fayalite
		
	
		
			
				
	
	
		
			447 lines
		
	
	
	
		
			16 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
			
		
		
	
	
			447 lines
		
	
	
	
		
			16 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
| // SPDX-License-Identifier: LGPL-3.0-or-later
 | |
| // See Notices.txt for copyright information
 | |
| use proc_macro2::{Span, TokenStream};
 | |
| use quote::{ToTokens, format_ident, quote};
 | |
| use std::{collections::BTreeMap, fs};
 | |
| use syn::{fold::Fold, parse_quote};
 | |
| 
 | |
| pub mod ast;
 | |
| 
 | |
/// Converts a `CamelCase` identifier to `snake_case`.
///
/// Every ASCII uppercase letter is lowercased. An `_` is inserted before it only
/// when the preceding character was not an ASCII uppercase letter. Runs of
/// capitals therefore collapse without separators (`HTTPServer` -> `httpserver`).
/// All other characters, including existing `_`, are copied through unchanged.
fn map_camel_case_to_snake_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    // True when the last character emitted came from a non-uppercase input char.
    let mut after_non_upper = false;
    for ch in s.chars() {
        if ch.is_ascii_uppercase() {
            if after_non_upper {
                out.push('_');
            }
            out.push(ch.to_ascii_lowercase());
            after_non_upper = false;
        } else {
            out.push(ch);
            after_non_upper = true;
        }
    }
    out
}
 | |
| 
 | |
| #[derive(Clone)]
 | |
| struct DefinitionState {
 | |
|     fn_name_suffix: syn::Ident,
 | |
|     generics: syn::Generics,
 | |
|     fold_generics: syn::Generics,
 | |
|     folder_generics: syn::Generics,
 | |
|     visit_generics: syn::Generics,
 | |
|     visitor_generics: syn::Generics,
 | |
| }
 | |
| 
 | |
| impl DefinitionState {
 | |
|     fn folder_fn_name(&self) -> syn::Ident {
 | |
|         format_ident!("fold_{}", self.fn_name_suffix)
 | |
|     }
 | |
|     fn visitor_fn_name(&self) -> syn::Ident {
 | |
|         format_ident!("visit_{}", self.fn_name_suffix)
 | |
|     }
 | |
|     fn folder_fn(&self, path: &ast::Path) -> TokenStream {
 | |
|         let folder_fn_name = self.folder_fn_name();
 | |
|         let (impl_generics, type_generics, where_clause) = self.folder_generics.split_for_impl();
 | |
|         quote! {
 | |
|             fn #folder_fn_name #impl_generics(
 | |
|                 &mut self,
 | |
|                 v: #path #type_generics,
 | |
|             ) -> Result<#path #type_generics, Self::Error>
 | |
|             #where_clause
 | |
|             {
 | |
|                 Fold::default_fold(v, self)
 | |
|             }
 | |
|         }
 | |
|     }
 | |
|     fn visitor_fn(&self, path: &ast::Path) -> TokenStream {
 | |
|         let visitor_fn_name = self.visitor_fn_name();
 | |
|         let (impl_generics, type_generics, where_clause) = self.visitor_generics.split_for_impl();
 | |
|         quote! {
 | |
|             fn #visitor_fn_name #impl_generics(
 | |
|                 &mut self,
 | |
|                 v: &#path #type_generics,
 | |
|             ) -> Result<(), Self::Error>
 | |
|             #where_clause
 | |
|             {
 | |
|                 Visit::default_visit(v, self)
 | |
|             }
 | |
|         }
 | |
|     }
 | |
|     fn fold_impl(&self, path: &ast::Path, body: impl ToTokens) -> TokenStream {
 | |
|         let folder_fn_name = self.folder_fn_name();
 | |
|         let (_, self_type_generics, _) = self.generics.split_for_impl();
 | |
|         let (trait_impl_generics, _, trait_where_clause) = self.fold_generics.split_for_impl();
 | |
|         quote! {
 | |
|             #[automatically_derived]
 | |
|             #[allow(clippy::init_numbered_fields)]
 | |
|             impl #trait_impl_generics Fold<State> for #path #self_type_generics
 | |
|             #trait_where_clause
 | |
|             {
 | |
|                 fn fold(self, state: &mut State) -> Result<Self, State::Error> {
 | |
|                     state.#folder_fn_name(self)
 | |
|                 }
 | |
|                 fn default_fold(self, state: &mut State) -> Result<Self, State::Error> {
 | |
|                     #body
 | |
|                 }
 | |
|             }
 | |
|         }
 | |
|     }
 | |
|     fn visit_impl(&self, path: &ast::Path, body: impl ToTokens) -> TokenStream {
 | |
|         let visitor_fn_name = self.visitor_fn_name();
 | |
|         let (_, self_type_generics, _) = self.generics.split_for_impl();
 | |
|         let (trait_impl_generics, _, trait_where_clause) = self.visit_generics.split_for_impl();
 | |
|         quote! {
 | |
|             #[automatically_derived]
 | |
|             impl #trait_impl_generics Visit<State> for #path #self_type_generics
 | |
|             #trait_where_clause
 | |
|             {
 | |
|                 fn visit(&self, state: &mut State) -> Result<(), State::Error> {
 | |
|                     state.#visitor_fn_name(self)
 | |
|                 }
 | |
|                 fn default_visit(&self, state: &mut State) -> Result<(), State::Error> {
 | |
|                     #body
 | |
|                 }
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| struct GenerateState<'a> {
 | |
|     def_states: BTreeMap<&'a ast::Path, DefinitionState>,
 | |
|     definitions: &'a ast::Definitions,
 | |
| }
 | |
| 
 | |
| struct MapStateToSelf;
 | |
| 
 | |
| impl syn::fold::Fold for MapStateToSelf {
 | |
|     fn fold_ident(&mut self, i: syn::Ident) -> syn::Ident {
 | |
|         if i == "State" {
 | |
|             syn::Ident::new("Self", i.span())
 | |
|         } else {
 | |
|             i
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| impl<'a> GenerateState<'a> {
 | |
|     fn make_definition_state(&mut self, path: &'a ast::Path) -> syn::Result<()> {
 | |
|         let ast::Definition {
 | |
|             fn_name_suffix,
 | |
|             generics,
 | |
|             fold_where,
 | |
|             visit_where,
 | |
|             data: _,
 | |
|         } = self.definitions.types.get(path).ok_or_else(|| {
 | |
|             syn::Error::new(
 | |
|                 Span::call_site(),
 | |
|                 format!("can't find named type: {path:?}"),
 | |
|             )
 | |
|         })?;
 | |
|         let fn_name_suffix = fn_name_suffix
 | |
|             .as_ref()
 | |
|             .map(syn::Ident::from)
 | |
|             .unwrap_or_else(|| format_ident!("{}", map_camel_case_to_snake_case(&path.last().0)));
 | |
|         let generics = generics.clone().map(|v| v.0).unwrap_or_default();
 | |
|         let mut fold_generics = generics.clone();
 | |
|         let mut folder_generics = generics.clone();
 | |
|         fold_generics
 | |
|             .params
 | |
|             .insert(0, parse_quote! {State: ?Sized + Folder});
 | |
|         if let Some(fold_where) = fold_where {
 | |
|             fold_generics
 | |
|                 .make_where_clause()
 | |
|                 .predicates
 | |
|                 .extend(fold_where.0.iter().cloned());
 | |
|             folder_generics.make_where_clause().predicates.extend(
 | |
|                 fold_where
 | |
|                     .0
 | |
|                     .iter()
 | |
|                     .cloned()
 | |
|                     .map(|v| MapStateToSelf.fold_where_predicate(v)),
 | |
|             );
 | |
|         }
 | |
|         let mut visit_generics = generics.clone();
 | |
|         let mut visitor_generics = generics.clone();
 | |
|         visit_generics
 | |
|             .params
 | |
|             .insert(0, parse_quote! {State: ?Sized + Visitor});
 | |
|         if let Some(visit_where) = visit_where {
 | |
|             visit_generics
 | |
|                 .make_where_clause()
 | |
|                 .predicates
 | |
|                 .extend(visit_where.0.iter().cloned());
 | |
|             visitor_generics.make_where_clause().predicates.extend(
 | |
|                 visit_where
 | |
|                     .0
 | |
|                     .iter()
 | |
|                     .cloned()
 | |
|                     .map(|v| MapStateToSelf.fold_where_predicate(v)),
 | |
|             );
 | |
|         }
 | |
|         self.def_states.insert(
 | |
|             path,
 | |
|             DefinitionState {
 | |
|                 fn_name_suffix,
 | |
|                 generics,
 | |
|                 fold_generics,
 | |
|                 folder_generics,
 | |
|                 visit_generics,
 | |
|                 visitor_generics,
 | |
|             },
 | |
|         );
 | |
|         Ok(())
 | |
|     }
 | |
|     fn new(ast: &'a ast::Definitions) -> syn::Result<Self> {
 | |
|         let mut retval = GenerateState {
 | |
|             def_states: BTreeMap::new(),
 | |
|             definitions: ast,
 | |
|         };
 | |
|         let ast::Definitions { types } = ast;
 | |
|         for path in types.keys() {
 | |
|             retval.make_definition_state(path)?;
 | |
|         }
 | |
|         Ok(retval)
 | |
|     }
 | |
| }
 | |
| 
 | |
| pub fn generate(ast: &ast::Definitions) -> syn::Result<String> {
 | |
|     let state = GenerateState::new(ast)?;
 | |
|     let mut visitor_fns = vec![];
 | |
|     let mut visit_impls = vec![];
 | |
|     let mut folder_fns = vec![];
 | |
|     let mut fold_impls = vec![];
 | |
|     for (&def_path, def_state) in state.def_states.iter() {
 | |
|         folder_fns.push(def_state.folder_fn(def_path));
 | |
|         visitor_fns.push(def_state.visitor_fn(def_path));
 | |
|         let fold_body;
 | |
|         let visit_body;
 | |
|         let ast::Definition {
 | |
|             fn_name_suffix: _,
 | |
|             generics: _,
 | |
|             fold_where: _,
 | |
|             visit_where: _,
 | |
|             data,
 | |
|         } = ast.types.get(def_path).unwrap();
 | |
|         match data {
 | |
|             ast::Data::ManualImpl => {
 | |
|                 continue;
 | |
|             }
 | |
|             ast::Data::Opaque => {
 | |
|                 fold_body = quote! {
 | |
|                     let _ = state;
 | |
|                     Ok(self)
 | |
|                 };
 | |
|                 visit_body = quote! {
 | |
|                     let _ = state;
 | |
|                     Ok(())
 | |
|                 };
 | |
|             }
 | |
|             ast::Data::Struct(ast::Fields {
 | |
|                 constructor,
 | |
|                 fields,
 | |
|             }) => {
 | |
|                 let mut visit_members = vec![];
 | |
|                 let mut fold_members = vec![];
 | |
|                 for (field_name, field) in fields {
 | |
|                     let fold_member_name = if constructor.is_some() {
 | |
|                         None
 | |
|                     } else {
 | |
|                         let member = field_name.to_member();
 | |
|                         if member.is_none() {
 | |
|                             return Err(syn::Error::new(
 | |
|                                 Span::call_site(),
 | |
|                                 format!(
 | |
|                                     "struct must have `$constructor` since it contains a \
 | |
|                                     non-plain field: {def_path:?} {field_name:?}"
 | |
|                                 ),
 | |
|                             ));
 | |
|                         }
 | |
|                         member
 | |
|                     };
 | |
|                     let fold_member_name = fold_member_name.as_slice();
 | |
|                     let fold_member = match field {
 | |
|                         ast::Field::Opaque => {
 | |
|                             quote! {
 | |
|                                 #(#fold_member_name:)* self.#field_name
 | |
|                             }
 | |
|                         }
 | |
|                         ast::Field::Visible => {
 | |
|                             visit_members.push(quote! {
 | |
|                                 Visit::visit(&self.#field_name, state)?;
 | |
|                             });
 | |
|                             quote! {
 | |
|                                 #(#fold_member_name:)* Fold::fold(self.#field_name, state)?
 | |
|                             }
 | |
|                         }
 | |
|                         ast::Field::RefVisible => {
 | |
|                             visit_members.push(quote! {
 | |
|                                 Visit::visit(self.#field_name, state)?;
 | |
|                             });
 | |
|                             quote! {
 | |
|                                 #(#fold_member_name:)* Fold::fold(self.#field_name.clone(), state)?
 | |
|                             }
 | |
|                         }
 | |
|                     };
 | |
|                     fold_members.push(fold_member);
 | |
|                 }
 | |
|                 let match_members = constructor
 | |
|                     .is_none()
 | |
|                     .then(|| {
 | |
|                         fields
 | |
|                             .keys()
 | |
|                             .map(|k| k.to_member())
 | |
|                             .collect::<Option<Vec<_>>>()
 | |
|                             .map(|members| {
 | |
|                                 if members.is_empty() {
 | |
|                                     quote! {
 | |
|                                         let _ = state;
 | |
|                                         let Self {} = self;
 | |
|                                     }
 | |
|                                 } else {
 | |
|                                     quote! {
 | |
|                                         let Self {
 | |
|                                             #(#members: _,)*
 | |
|                                         } = self;
 | |
|                                     }
 | |
|                                 }
 | |
|                             })
 | |
|                     })
 | |
|                     .flatten();
 | |
|                 visit_body = quote! {
 | |
|                     #match_members
 | |
|                     #(#visit_members)*
 | |
|                     Ok(())
 | |
|                 };
 | |
|                 let fold_body_tail = if let Some(constructor) = constructor {
 | |
|                     quote! {
 | |
|                         Ok(#constructor(#(#fold_members),*))
 | |
|                     }
 | |
|                 } else {
 | |
|                     quote! {
 | |
|                         Ok(Self {
 | |
|                             #(#fold_members,)*
 | |
|                         })
 | |
|                     }
 | |
|                 };
 | |
|                 fold_body = quote! {
 | |
|                     #match_members
 | |
|                     #fold_body_tail
 | |
|                 };
 | |
|             }
 | |
|             ast::Data::Enum(ast::Variants { variants }) => {
 | |
|                 let mut fold_arms = vec![];
 | |
|                 let mut visit_arms = vec![];
 | |
|                 let mut state_unused = true;
 | |
|                 for (variant_name, variant_field) in variants {
 | |
|                     let fold_arm;
 | |
|                     let visit_arm;
 | |
|                     match variant_field {
 | |
|                         Some(ast::Field::Visible) => {
 | |
|                             state_unused = false;
 | |
|                             fold_arm = quote! {
 | |
|                                 Self::#variant_name(v) => Fold::fold(v, state)
 | |
|                                                             .map(Self::#variant_name),
 | |
|                             };
 | |
|                             visit_arm = quote! {
 | |
|                                 Self::#variant_name(v) => Visit::visit(v, state),
 | |
|                             };
 | |
|                         }
 | |
|                         Some(ast::Field::RefVisible) => {
 | |
|                             return Err(syn::Error::new(
 | |
|                                 Span::call_site(),
 | |
|                                 "enum variant field must not be RefVisible",
 | |
|                             ));
 | |
|                         }
 | |
|                         Some(ast::Field::Opaque) => {
 | |
|                             fold_arm = quote! {
 | |
|                                 Self::#variant_name(_) => Ok(self),
 | |
|                             };
 | |
|                             visit_arm = quote! {
 | |
|                                 Self::#variant_name(_) => Ok(()),
 | |
|                             };
 | |
|                         }
 | |
|                         None => {
 | |
|                             fold_arm = quote! {
 | |
|                                 Self::#variant_name => Ok(self),
 | |
|                             };
 | |
|                             visit_arm = quote! {
 | |
|                                 Self::#variant_name => Ok(()),
 | |
|                             };
 | |
|                         }
 | |
|                     }
 | |
|                     fold_arms.push(fold_arm);
 | |
|                     visit_arms.push(visit_arm);
 | |
|                 }
 | |
|                 let ignore_state = state_unused.then(|| {
 | |
|                     quote! {
 | |
|                         let _ = state;
 | |
|                     }
 | |
|                 });
 | |
|                 visit_body = quote! {
 | |
|                     #ignore_state
 | |
|                     match self {
 | |
|                         #(#visit_arms)*
 | |
|                     }
 | |
|                 };
 | |
|                 fold_body = quote! {
 | |
|                     #ignore_state
 | |
|                     match self {
 | |
|                         #(#fold_arms)*
 | |
|                     }
 | |
|                 };
 | |
|             }
 | |
|         }
 | |
|         fold_impls.push(def_state.fold_impl(def_path, fold_body));
 | |
|         visit_impls.push(def_state.visit_impl(def_path, visit_body));
 | |
|     }
 | |
|     Ok(prettyplease::unparse(&parse_quote! {
 | |
|         pub trait Visitor {
 | |
|             type Error;
 | |
| 
 | |
|             #(#visitor_fns)*
 | |
|         }
 | |
| 
 | |
|         #(#visit_impls)*
 | |
| 
 | |
|         pub trait Folder {
 | |
|             type Error;
 | |
| 
 | |
|             #(#folder_fns)*
 | |
|         }
 | |
| 
 | |
|         #(#fold_impls)*
 | |
|     }))
 | |
| }
 | |
| 
 | |
| pub fn error_at_call_site<T: std::fmt::Display>(e: T) -> syn::Error {
 | |
|     syn::Error::new(Span::call_site(), e)
 | |
| }
 | |
| 
 | |
| pub fn parse_and_generate(path: impl AsRef<std::path::Path>) -> syn::Result<String> {
 | |
|     let input = fs::read_to_string(path).map_err(error_at_call_site)?;
 | |
|     let ast: ast::Definitions = serde_json::from_str(&input).map_err(error_at_call_site)?;
 | |
|     generate(&ast)
 | |
| }
 |