427 lines
15 KiB
Rust
427 lines
15 KiB
Rust
// SPDX-License-Identifier: LGPL-3.0-or-later
|
|
// See Notices.txt for copyright information
|
|
use proc_macro2::{Span, TokenStream};
|
|
use quote::{format_ident, quote, ToTokens};
|
|
use std::{collections::BTreeMap, fs};
|
|
use syn::{fold::Fold, parse_quote};
|
|
|
|
pub mod ast;
|
|
|
|
/// Converts a `CamelCase` name into `snake_case`.
///
/// An `_` is inserted before an ASCII uppercase letter only when the character
/// before it was not ASCII uppercase (and never at the very start). Runs of
/// uppercase letters are lowercased without separators, so `HTTPServer`
/// becomes `httpserver`. All other characters, including non-ASCII ones, are
/// copied through unchanged.
fn map_camel_case_to_snake_case(s: &str) -> String {
    let mut out = String::new();
    // An uppercase letter that has been seen but not yet written; it is written
    // (lowercased) once the following character is known.
    let mut pending_upper: Option<char> = None;
    // True when the most recently consumed character was not ASCII uppercase.
    let mut after_non_upper = false;
    for ch in s.chars() {
        if let Some(upper) = pending_upper.take() {
            out.push(upper.to_ascii_lowercase());
        }
        if ch.is_ascii_uppercase() {
            // `after_non_upper` is always false while an uppercase letter is
            // pending, so consecutive capitals never get a separator.
            if after_non_upper {
                out.push('_');
            }
            pending_upper = Some(ch);
            after_non_upper = false;
        } else {
            out.push(ch);
            after_non_upper = true;
        }
    }
    if let Some(upper) = pending_upper {
        out.push(upper.to_ascii_lowercase());
    }
    out
}
|
|
|
|
#[derive(Clone)]
|
|
struct DefinitionState {
|
|
fn_name_suffix: syn::Ident,
|
|
generics: syn::Generics,
|
|
fold_generics: syn::Generics,
|
|
folder_generics: syn::Generics,
|
|
visit_generics: syn::Generics,
|
|
visitor_generics: syn::Generics,
|
|
}
|
|
|
|
impl DefinitionState {
|
|
fn folder_fn_name(&self) -> syn::Ident {
|
|
format_ident!("fold_{}", self.fn_name_suffix)
|
|
}
|
|
fn visitor_fn_name(&self) -> syn::Ident {
|
|
format_ident!("visit_{}", self.fn_name_suffix)
|
|
}
|
|
fn folder_fn(&self, path: &ast::Path) -> TokenStream {
|
|
let folder_fn_name = self.folder_fn_name();
|
|
let (impl_generics, type_generics, where_clause) = self.folder_generics.split_for_impl();
|
|
quote! {
|
|
fn #folder_fn_name #impl_generics(&mut self, v: #path #type_generics) -> Result<#path #type_generics, Self::Error> #where_clause {
|
|
Fold::default_fold(v, self)
|
|
}
|
|
}
|
|
}
|
|
fn visitor_fn(&self, path: &ast::Path) -> TokenStream {
|
|
let visitor_fn_name = self.visitor_fn_name();
|
|
let (impl_generics, type_generics, where_clause) = self.visitor_generics.split_for_impl();
|
|
quote! {
|
|
fn #visitor_fn_name #impl_generics(&mut self, v: &#path #type_generics) -> Result<(), Self::Error> #where_clause {
|
|
Visit::default_visit(v, self)
|
|
}
|
|
}
|
|
}
|
|
fn fold_impl(&self, path: &ast::Path, body: impl ToTokens) -> TokenStream {
|
|
let folder_fn_name = self.folder_fn_name();
|
|
let (_, self_type_generics, _) = self.generics.split_for_impl();
|
|
let (trait_impl_generics, _, trait_where_clause) = self.fold_generics.split_for_impl();
|
|
quote! {
|
|
#[automatically_derived]
|
|
#[allow(clippy::init_numbered_fields)]
|
|
impl #trait_impl_generics Fold<State> for #path #self_type_generics #trait_where_clause {
|
|
fn fold(self, state: &mut State) -> Result<Self, State::Error> {
|
|
state.#folder_fn_name(self)
|
|
}
|
|
fn default_fold(self, state: &mut State) -> Result<Self, State::Error> {
|
|
#body
|
|
}
|
|
}
|
|
}
|
|
}
|
|
fn visit_impl(&self, path: &ast::Path, body: impl ToTokens) -> TokenStream {
|
|
let visitor_fn_name = self.visitor_fn_name();
|
|
let (_, self_type_generics, _) = self.generics.split_for_impl();
|
|
let (trait_impl_generics, _, trait_where_clause) = self.visit_generics.split_for_impl();
|
|
quote! {
|
|
#[automatically_derived]
|
|
impl #trait_impl_generics Visit<State> for #path #self_type_generics #trait_where_clause {
|
|
fn visit(&self, state: &mut State) -> Result<(), State::Error> {
|
|
state.#visitor_fn_name(self)
|
|
}
|
|
fn default_visit(&self, state: &mut State) -> Result<(), State::Error> {
|
|
#body
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
struct GenerateState<'a> {
|
|
def_states: BTreeMap<&'a ast::Path, DefinitionState>,
|
|
definitions: &'a ast::Definitions,
|
|
}
|
|
|
|
struct MapStateToSelf;
|
|
|
|
impl syn::fold::Fold for MapStateToSelf {
|
|
fn fold_ident(&mut self, i: syn::Ident) -> syn::Ident {
|
|
if i == "State" {
|
|
syn::Ident::new("Self", i.span())
|
|
} else {
|
|
i
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'a> GenerateState<'a> {
|
|
fn make_definition_state(&mut self, path: &'a ast::Path) -> syn::Result<()> {
|
|
let ast::Definition {
|
|
fn_name_suffix,
|
|
generics,
|
|
fold_where,
|
|
visit_where,
|
|
data: _,
|
|
} = self.definitions.types.get(path).ok_or_else(|| {
|
|
syn::Error::new(
|
|
Span::call_site(),
|
|
format!("can't find named type: {path:?}"),
|
|
)
|
|
})?;
|
|
let fn_name_suffix = fn_name_suffix
|
|
.as_ref()
|
|
.map(syn::Ident::from)
|
|
.unwrap_or_else(|| format_ident!("{}", map_camel_case_to_snake_case(&path.last().0)));
|
|
let generics = generics.clone().map(|v| v.0).unwrap_or_default();
|
|
let mut fold_generics = generics.clone();
|
|
let mut folder_generics = generics.clone();
|
|
fold_generics
|
|
.params
|
|
.insert(0, parse_quote! {State: ?Sized + Folder});
|
|
if let Some(fold_where) = fold_where {
|
|
fold_generics
|
|
.make_where_clause()
|
|
.predicates
|
|
.extend(fold_where.0.iter().cloned());
|
|
folder_generics.make_where_clause().predicates.extend(
|
|
fold_where
|
|
.0
|
|
.iter()
|
|
.cloned()
|
|
.map(|v| MapStateToSelf.fold_where_predicate(v)),
|
|
);
|
|
}
|
|
let mut visit_generics = generics.clone();
|
|
let mut visitor_generics = generics.clone();
|
|
visit_generics
|
|
.params
|
|
.insert(0, parse_quote! {State: ?Sized + Visitor});
|
|
if let Some(visit_where) = visit_where {
|
|
visit_generics
|
|
.make_where_clause()
|
|
.predicates
|
|
.extend(visit_where.0.iter().cloned());
|
|
visitor_generics.make_where_clause().predicates.extend(
|
|
visit_where
|
|
.0
|
|
.iter()
|
|
.cloned()
|
|
.map(|v| MapStateToSelf.fold_where_predicate(v)),
|
|
);
|
|
}
|
|
self.def_states.insert(
|
|
path,
|
|
DefinitionState {
|
|
fn_name_suffix,
|
|
generics,
|
|
fold_generics,
|
|
folder_generics,
|
|
visit_generics,
|
|
visitor_generics,
|
|
},
|
|
);
|
|
Ok(())
|
|
}
|
|
fn new(ast: &'a ast::Definitions) -> syn::Result<Self> {
|
|
let mut retval = GenerateState {
|
|
def_states: BTreeMap::new(),
|
|
definitions: ast,
|
|
};
|
|
let ast::Definitions { types } = ast;
|
|
for path in types.keys() {
|
|
retval.make_definition_state(path)?;
|
|
}
|
|
Ok(retval)
|
|
}
|
|
}
|
|
|
|
pub fn generate(ast: &ast::Definitions) -> syn::Result<String> {
|
|
let state = GenerateState::new(ast)?;
|
|
let mut visitor_fns = vec![];
|
|
let mut visit_impls = vec![];
|
|
let mut folder_fns = vec![];
|
|
let mut fold_impls = vec![];
|
|
for (&def_path, def_state) in state.def_states.iter() {
|
|
folder_fns.push(def_state.folder_fn(def_path));
|
|
visitor_fns.push(def_state.visitor_fn(def_path));
|
|
let fold_body;
|
|
let visit_body;
|
|
let ast::Definition {
|
|
fn_name_suffix: _,
|
|
generics: _,
|
|
fold_where: _,
|
|
visit_where: _,
|
|
data,
|
|
} = ast.types.get(def_path).unwrap();
|
|
match data {
|
|
ast::Data::ManualImpl => {
|
|
continue;
|
|
}
|
|
ast::Data::Opaque => {
|
|
fold_body = quote! {
|
|
let _ = state;
|
|
Ok(self)
|
|
};
|
|
visit_body = quote! {
|
|
let _ = state;
|
|
Ok(())
|
|
};
|
|
}
|
|
ast::Data::Struct(ast::Fields {
|
|
constructor,
|
|
fields,
|
|
}) => {
|
|
let mut visit_members = vec![];
|
|
let mut fold_members = vec![];
|
|
for (field_name, field) in fields {
|
|
let fold_member_name = if constructor.is_some() {
|
|
None
|
|
} else {
|
|
let member = field_name.to_member();
|
|
if member.is_none() {
|
|
return Err(syn::Error::new(Span::call_site(), format!("struct must have `$constructor` since it contains a non-plain field: {def_path:?} {field_name:?}")));
|
|
}
|
|
member
|
|
};
|
|
let fold_member_name = fold_member_name.as_slice();
|
|
let fold_member = match field {
|
|
ast::Field::Opaque => {
|
|
quote! {
|
|
#(#fold_member_name:)* self.#field_name
|
|
}
|
|
}
|
|
ast::Field::Visible => {
|
|
visit_members.push(quote! {
|
|
Visit::visit(&self.#field_name, state)?;
|
|
});
|
|
quote! {
|
|
#(#fold_member_name:)* Fold::fold(self.#field_name, state)?
|
|
}
|
|
}
|
|
ast::Field::RefVisible => {
|
|
visit_members.push(quote! {
|
|
Visit::visit(self.#field_name, state)?;
|
|
});
|
|
quote! {
|
|
#(#fold_member_name:)* Fold::fold(self.#field_name.clone(), state)?
|
|
}
|
|
}
|
|
};
|
|
fold_members.push(fold_member);
|
|
}
|
|
let match_members = constructor
|
|
.is_none()
|
|
.then(|| {
|
|
fields
|
|
.keys()
|
|
.map(|k| k.to_member())
|
|
.collect::<Option<Vec<_>>>()
|
|
.map(|members| {
|
|
if members.is_empty() {
|
|
quote! {
|
|
let _ = state;
|
|
let Self {} = self;
|
|
}
|
|
} else {
|
|
quote! {
|
|
let Self {
|
|
#(#members: _,)*
|
|
} = self;
|
|
}
|
|
}
|
|
})
|
|
})
|
|
.flatten();
|
|
visit_body = quote! {
|
|
#match_members
|
|
#(#visit_members)*
|
|
Ok(())
|
|
};
|
|
let fold_body_tail = if let Some(constructor) = constructor {
|
|
quote! {
|
|
Ok(#constructor(#(#fold_members),*))
|
|
}
|
|
} else {
|
|
quote! {
|
|
Ok(Self {
|
|
#(#fold_members,)*
|
|
})
|
|
}
|
|
};
|
|
fold_body = quote! {
|
|
#match_members
|
|
#fold_body_tail
|
|
};
|
|
}
|
|
ast::Data::Enum(ast::Variants { variants }) => {
|
|
let mut fold_arms = vec![];
|
|
let mut visit_arms = vec![];
|
|
let mut state_unused = true;
|
|
for (variant_name, variant_field) in variants {
|
|
let fold_arm;
|
|
let visit_arm;
|
|
match variant_field {
|
|
Some(ast::Field::Visible) => {
|
|
state_unused = false;
|
|
fold_arm = quote! {
|
|
Self::#variant_name(v) => Fold::fold(v, state).map(Self::#variant_name),
|
|
};
|
|
visit_arm = quote! {
|
|
Self::#variant_name(v) => Visit::visit(v, state),
|
|
};
|
|
}
|
|
Some(ast::Field::RefVisible) => {
|
|
return Err(syn::Error::new(
|
|
Span::call_site(),
|
|
"enum variant field must not be RefVisible",
|
|
));
|
|
}
|
|
Some(ast::Field::Opaque) => {
|
|
fold_arm = quote! {
|
|
Self::#variant_name(_) => Ok(self),
|
|
};
|
|
visit_arm = quote! {
|
|
Self::#variant_name(_) => Ok(()),
|
|
};
|
|
}
|
|
None => {
|
|
fold_arm = quote! {
|
|
Self::#variant_name => Ok(self),
|
|
};
|
|
visit_arm = quote! {
|
|
Self::#variant_name => Ok(()),
|
|
};
|
|
}
|
|
}
|
|
fold_arms.push(fold_arm);
|
|
visit_arms.push(visit_arm);
|
|
}
|
|
let ignore_state = state_unused.then(|| {
|
|
quote! {
|
|
let _ = state;
|
|
}
|
|
});
|
|
visit_body = quote! {
|
|
#ignore_state
|
|
match self {
|
|
#(#visit_arms)*
|
|
}
|
|
};
|
|
fold_body = quote! {
|
|
#ignore_state
|
|
match self {
|
|
#(#fold_arms)*
|
|
}
|
|
};
|
|
}
|
|
}
|
|
fold_impls.push(def_state.fold_impl(def_path, fold_body));
|
|
visit_impls.push(def_state.visit_impl(def_path, visit_body));
|
|
}
|
|
Ok(prettyplease::unparse(&parse_quote! {
|
|
pub trait Visitor {
|
|
type Error;
|
|
|
|
#(#visitor_fns)*
|
|
}
|
|
|
|
#(#visit_impls)*
|
|
|
|
pub trait Folder {
|
|
type Error;
|
|
|
|
#(#folder_fns)*
|
|
}
|
|
|
|
#(#fold_impls)*
|
|
}))
|
|
}
|
|
|
|
pub fn error_at_call_site<T: std::fmt::Display>(e: T) -> syn::Error {
|
|
syn::Error::new(Span::call_site(), e)
|
|
}
|
|
|
|
pub fn parse_and_generate(path: impl AsRef<std::path::Path>) -> syn::Result<String> {
|
|
let input = fs::read_to_string(path).map_err(error_at_call_site)?;
|
|
let ast: ast::Definitions = serde_json::from_str(&input).map_err(error_at_call_site)?;
|
|
generate(&ast)
|
|
}
|