fayalite/crates/fayalite-visit-gen/src/ast.rs

614 lines
16 KiB
Rust

// SPDX-License-Identifier: LGPL-3.0-or-later
// See Notices.txt for copyright information
use indexmap::IndexMap;
use proc_macro2::{Span, TokenStream};
use quote::{IdentFragment, ToTokens, TokenStreamExt};
use serde::{Deserialize, Serialize};
use std::{
fmt::{self, Write},
iter::FusedIterator,
str::FromStr,
};
use thiserror::Error;
/// Implements `TryFrom<&str>` and `TryFrom<String>` for `$ty` by delegating to
/// its [`FromStr`] implementation, reusing `FromStr::Err` as the error type.
///
/// The `TryFrom<String>` impl is what `#[serde(try_from = "String")]` needs.
macro_rules! impl_try_from_str {
    ($ty:ty) => {
        impl TryFrom<&'_ str> for $ty {
            type Error = <Self as FromStr>::Err;
            fn try_from(text: &str) -> Result<Self, Self::Error> {
                <Self as FromStr>::from_str(text)
            }
        }
        impl TryFrom<String> for $ty {
            type Error = <Self as FromStr>::Err;
            fn try_from(text: String) -> Result<Self, Self::Error> {
                // Route through the `&str` conversion so both share one parser.
                <Self as TryFrom<&str>>::try_from(text.as_str())
            }
        }
    };
}
/// A single identifier: an ASCII letter or `_`, followed by ASCII
/// alphanumerics or `_`.
///
/// That shape is only checked when parsing via `FromStr`. The public field
/// itself is unchecked. Serialized as a plain string.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
#[serde(into = "String", try_from = "String")]
pub struct Ident(pub String);
impl ToTokens for Ident {
    /// Emits this identifier as a `syn::Ident` token.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let as_syn: syn::Ident = self.into();
        as_syn.to_tokens(tokens);
    }
}
impl IdentFragment for Ident {
    /// Writes the raw identifier text (no padding/width handling), for use in
    /// `format_ident!`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.0.as_str())
    }
}
impl Ident {
    /// Returns `true` if `ch` may start an identifier: `_` or an ASCII letter.
    pub fn is_start_char(ch: char) -> bool {
        matches!(ch, '_' | 'a'..='z' | 'A'..='Z')
    }
    /// Returns `true` if `ch` may appear after the first character of an
    /// identifier: anything accepted by [`Ident::is_start_char`], or an ASCII digit.
    pub fn is_continue_char(ch: char) -> bool {
        Self::is_start_char(ch) || ch.is_ascii_digit()
    }
    /// Returns `true` if `v` is a non-empty identifier. The first character must
    /// satisfy [`Ident::is_start_char`] and every later one [`Ident::is_continue_char`].
    /// A lone `_` is accepted. Non-ASCII characters are rejected.
    pub fn is_ident(v: &str) -> bool {
        let mut chars = v.chars();
        match chars.next() {
            Some(first) if Self::is_start_char(first) => chars.all(Self::is_continue_char),
            _ => false,
        }
    }
}
impl From<Ident> for Path {
fn from(value: Ident) -> Self {
Path(value.0)
}
}
impl From<Ident> for String {
fn from(value: Ident) -> Self {
value.0
}
}
impl From<Ident> for syn::Ident {
fn from(value: Ident) -> Self {
From::from(&value)
}
}
impl From<&'_ Ident> for syn::Ident {
fn from(value: &Ident) -> Self {
syn::Ident::new(&value.0, Span::call_site())
}
}
/// Error returned when a string is not a valid [`Ident`].
#[derive(Clone, Debug, Error)]
#[error("invalid identifier")]
pub struct IdentParseError;
impl_try_from_str!(Ident);
impl FromStr for Ident {
    type Err = IdentParseError;
    /// Accepts exactly the strings for which [`Ident::is_ident`] returns `true`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if !Self::is_ident(s) {
            return Err(IdentParseError);
        }
        Ok(Self(s.to_owned()))
    }
}
/// A `::`-separated sequence of identifiers, e.g. `Module` or
/// `Module::new_unchecked`. Serialized as a plain string.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
#[serde(into = "String", try_from = "String")]
pub struct Path(String);
impl Path {
    /// Returns an iterator over this path's segments, front to back.
    pub fn iter(&self) -> PathIter<'_> {
        PathIter(&self.0)
    }
    /// Returns the final segment, e.g. `new_unchecked` for `Module::new_unchecked`.
    ///
    /// # Panics
    ///
    /// Panics if the path is empty. Paths from `FromStr` are never empty. One
    /// built via `From<Ident>` from an `Ident` whose public field is `""` would be.
    pub fn last(&self) -> Ident {
        self.iter()
            .next_back()
            .expect("`Path` must contain at least one segment")
    }
    /// Returns `true` if `s` is one or more identifiers joined by `::`.
    pub fn is_path(s: &str) -> bool {
        // `"".split("::")` yields a single `""`, which is not an identifier, so
        // the empty string is rejected without a separate check.
        s.split("::").all(Ident::is_ident)
    }
}
/// Iterator over the `::`-separated segments of a path string. Each segment
/// is yielded as an owned [`Ident`], and iteration works from either end.
#[derive(Debug, Clone)]
pub struct PathIter<'a>(&'a str);
impl Iterator for PathIter<'_> {
    type Item = Ident;
    fn next(&mut self) -> Option<Self::Item> {
        if self.0.is_empty() {
            return None;
        }
        // With no `::` left, the whole remainder is the final segment.
        let (segment, remainder) = self.0.split_once("::").unwrap_or((self.0, ""));
        self.0 = remainder;
        Some(Ident(segment.into()))
    }
    fn last(mut self) -> Option<Self::Item> {
        // Take the back segment directly instead of walking the whole path.
        self.next_back()
    }
}
// Once exhausted the remainder stays empty, so `next` keeps returning `None`.
impl FusedIterator for PathIter<'_> {}
impl DoubleEndedIterator for PathIter<'_> {
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.0.is_empty() {
            return None;
        }
        // With no `::` left, the whole remainder is the first (and only) segment.
        let (remainder, segment) = self.0.rsplit_once("::").unwrap_or(("", self.0));
        self.0 = remainder;
        Some(Ident(segment.into()))
    }
}
impl ToTokens for Path {
fn to_tokens(&self, tokens: &mut TokenStream) {
tokens.append_separated(self.iter(), <syn::Token![::]>::default());
}
}
#[derive(Clone, Debug, Error)]
#[error("invalid path")]
pub struct PathParseError;
impl From<Path> for String {
fn from(value: Path) -> Self {
value.0
}
}
impl_try_from_str!(Path);
impl FromStr for Path {
    type Err = PathParseError;
    /// Parses `value` as one or more identifiers joined by `::`.
    ///
    /// Validation is delegated to [`Path::is_path`], so parsing and the
    /// predicate can never disagree.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        if Self::is_path(value) {
            Ok(Self(value.into()))
        } else {
            Err(PathParseError)
        }
    }
}
/// A set of type descriptions keyed by each type's [`Path`].
/// The `BTreeMap` keeps entries (and thus the serialized output) sorted by path.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Definitions {
pub types: std::collections::BTreeMap<Path, Definition>,
}
/// Description of a single type.
///
/// Every `Option` field defaults to `None` when absent from the input and is
/// omitted from the serialized output when `None`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Definition {
/// Optional identifier suffix. By name, presumably appended to generated
/// function names for this type — TODO confirm against the generator.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub fn_name_suffix: Option<Ident>,
/// The type's generic parameters and optional `where` clause.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub generics: Option<Generics>,
/// Extra `where` predicates. By name, presumably for the generated fold code
/// — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub fold_where: Option<WherePredicates>,
/// Extra `where` predicates. By name, presumably for the generated visit code
/// — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub visit_where: Option<WherePredicates>,
/// The type's contents.
pub data: Data,
}
/// The kind of a described type.
///
/// Serialized internally tagged: the variant name goes under a `$kind` key in
/// the same object as the variant's own contents.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "$kind")]
pub enum Data {
/// Carries no field data. By name, the implementation is presumably
/// hand-written rather than generated — confirm.
ManualImpl,
/// Carries no field data. By name, presumably treated as opaque (contents
/// not traversed) — confirm.
Opaque,
/// An enum, described by its variants.
Enum(Variants),
/// A struct, described by its fields.
Struct(Fields),
}
/// A named field (`foo`) or a getter method (`foo()`).
///
/// Serialized as the identifier, with a `()` suffix when `is_getter` is set.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
#[serde(into = "String", try_from = "String")]
pub struct FieldNameIdent {
    pub ident: Ident,
    pub is_getter: bool,
}
impl FieldNameIdent {
    /// Returns this field as a `syn::Member`.
    ///
    /// Returns `None` for getters, which are method calls rather than members.
    pub fn to_member(&self) -> Option<syn::Member> {
        if self.is_getter {
            return None;
        }
        let name = syn::Ident::from(&self.ident);
        Some(syn::Member::from(name))
    }
}
impl ToTokens for FieldNameIdent {
    /// Emits the identifier. For getters it is followed by an empty `()`.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        self.ident.to_tokens(tokens);
        if !self.is_getter {
            return;
        }
        syn::token::Paren::default().surround(tokens, |_| {});
    }
}
impl From<FieldNameIdent> for String {
fn from(value: FieldNameIdent) -> Self {
let mut retval = value.ident.0;
if value.is_getter {
retval.push_str("()");
}
retval
}
}
#[derive(Clone, Debug, Error)]
#[error("invalid field name")]
pub struct FieldNameParseError;
impl_try_from_str!(FieldNameIdent);
impl FromStr for FieldNameIdent {
type Err = FieldNameParseError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
let ident = value.strip_suffix("()");
let is_getter = ident.is_some();
let ident = ident.unwrap_or(value);
if let Ok(ident) = ident.parse() {
Ok(Self { ident, is_getter })
} else {
Err(FieldNameParseError)
}
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, Hash)]
#[serde(into = "String", try_from = "String")]
pub struct WherePredicates(pub syn::punctuated::Punctuated<syn::WherePredicate, syn::Token![,]>);
#[derive(Debug, Error)]
#[error("invalid `where` predicates")]
pub struct WherePredicatesParseError;
impl_try_from_str!(WherePredicates);
impl FromStr for WherePredicates {
type Err = WherePredicatesParseError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
Ok(Self(
syn::parse::Parser::parse_str(syn::punctuated::Punctuated::parse_terminated, value)
.map_err(|_| WherePredicatesParseError)?,
))
}
}
impl From<WherePredicates> for String {
fn from(value: WherePredicates) -> Self {
value.0.into_token_stream().to_string()
}
}
impl From<WherePredicates> for syn::WhereClause {
fn from(value: WherePredicates) -> Self {
syn::WhereClause {
where_token: Default::default(),
predicates: value.0,
}
}
}
impl From<syn::WhereClause> for WherePredicates {
fn from(value: syn::WhereClause) -> Self {
Self(value.predicates)
}
}
impl ToTokens for WherePredicates {
fn to_tokens(&self, tokens: &mut TokenStream) {
self.0.to_tokens(tokens);
}
}
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum SerializedGenerics {
Where {
generics: String,
#[serde(rename = "where")]
where_predicates: WherePredicates,
},
NoWhere(String),
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, Hash)]
#[serde(into = "SerializedGenerics", try_from = "SerializedGenerics")]
pub struct Generics(pub syn::Generics);
impl ToTokens for Generics {
fn to_tokens(&self, tokens: &mut TokenStream) {
self.0.to_tokens(tokens);
}
}
impl From<Generics> for SerializedGenerics {
fn from(mut value: Generics) -> Self {
match value.0.where_clause.take() {
Some(where_clause) => Self::Where {
generics: value.0.into_token_stream().to_string(),
where_predicates: where_clause.into(),
},
None => Self::NoWhere(value.0.into_token_stream().to_string()),
}
}
}
#[derive(Debug, Error)]
#[error("invalid generics")]
pub struct GenericsParseError;
impl TryFrom<SerializedGenerics> for Generics {
type Error = GenericsParseError;
fn try_from(value: SerializedGenerics) -> Result<Self, Self::Error> {
let (generics, where_clause) = match value {
SerializedGenerics::Where {
generics,
where_predicates,
} => (generics, Some(where_predicates.into())),
SerializedGenerics::NoWhere(generics) => (generics, None),
};
let Ok(mut generics) = syn::parse_str::<syn::Generics>(&generics) else {
return Err(GenericsParseError);
};
generics.where_clause = where_clause;
Ok(Self(generics))
}
}
/// A path optionally followed by angle-bracketed generic arguments, e.g.
/// `Module::new_unchecked` or `Foo::<T>`.
///
/// Serialized as `Path<args>` (no `::` before `<`).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[serde(into = "String", try_from = "String")]
pub struct PathWithGenerics {
    pub path: Path,
    pub generics: Option<syn::AngleBracketedGenericArguments>,
}
impl ToTokens for PathWithGenerics {
    /// Emits `path`, or `path::<args>` (turbofish) when generics are present.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let Self { path, generics } = self;
        path.to_tokens(tokens);
        if let Some(generics) = generics {
            // `generics` is public, so it may already carry its own leading `::`
            // (`colon2_token`). Add one only when it doesn't, so `:: ::<..>`
            // is never emitted.
            if generics.colon2_token.is_none() {
                <syn::Token![::]>::default().to_tokens(tokens);
            }
            generics.to_tokens(tokens);
        }
    }
}
impl From<PathWithGenerics> for String {
    /// Renders `path` followed directly by the generic arguments, if any.
    fn from(value: PathWithGenerics) -> Self {
        let PathWithGenerics { path, generics } = value;
        let mut retval = String::from(path);
        if let Some(mut generics) = generics {
            // Drop any leading `::` so the output always has the `Path<..>` shape
            // that `FromStr` splits at `<`. A rendered `::` would be followed by
            // whitespace, which defeats the `::` stripping there.
            generics.colon2_token = None;
            write!(retval, "{}", generics.to_token_stream())
                .expect("writing to a String cannot fail");
        }
        retval
    }
}
#[derive(Clone, Debug, Error)]
#[error("invalid path with optional generics")]
pub struct PathWithGenericsParseError;
impl_try_from_str!(PathWithGenerics);
impl FromStr for PathWithGenerics {
type Err = PathWithGenericsParseError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
let (path, generics) = if let Some(lt_pos) = value.find('<') {
let (path, generics) = value.split_at(lt_pos);
let path = path.strip_suffix("::").unwrap_or(path);
match syn::parse_str(generics) {
Ok(generics) => (path, Some(generics)),
Err(_) => return Err(PathWithGenericsParseError),
}
} else {
(value, None)
};
if let Ok(path) = path.parse() {
Ok(Self { path, generics })
} else {
Err(PathWithGenericsParseError)
}
}
}
/// A field name: a tuple index (`0`, `1`, ...) or a named field / getter.
///
/// Serialized as a plain string: the index in decimal, or the ident form.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
#[serde(into = "String", try_from = "String")]
pub enum FieldName {
    Index(usize),
    Ident(FieldNameIdent),
}
impl FieldName {
    /// Returns this field as a `syn::Member`, or `None` for getters.
    pub fn to_member(&self) -> Option<syn::Member> {
        match *self {
            Self::Index(index) => Some(syn::Member::from(index)),
            Self::Ident(ref ident) => ident.to_member(),
        }
    }
}
impl ToTokens for FieldName {
    /// Emits a tuple index token or the (possibly getter) identifier.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        match *self {
            Self::Index(index) => syn::Index::from(index).to_tokens(tokens),
            Self::Ident(ref ident) => ident.to_tokens(tokens),
        }
    }
}
impl From<FieldName> for String {
    fn from(value: FieldName) -> Self {
        match value {
            FieldName::Index(index) => index.to_string(),
            FieldName::Ident(ident) => String::from(ident),
        }
    }
}
impl_try_from_str!(FieldName);
impl FromStr for FieldName {
type Err = FieldNameParseError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
if !value.is_empty()
&& value
.trim_start_matches(|ch: char| ch.is_ascii_digit())
.is_empty()
{
if let Ok(index) = value.parse() {
Ok(Self::Index(index))
} else {
Err(FieldNameParseError)
}
} else {
value.parse().map(Self::Ident)
}
}
}
/// The fields of a struct description.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Fields {
/// Optional path (with optional generic arguments) stored under the
/// `$constructor` key. It defaults to `None` when absent and is omitted when `None`.
/// By name, presumably the function used to rebuild the value — confirm.
#[serde(
default,
rename = "$constructor",
skip_serializing_if = "Option::is_none"
)]
pub constructor: Option<PathWithGenerics>,
/// Field name -> field kind, flattened into the same JSON object as
/// `$constructor`. `IndexMap` preserves the input order of the fields.
#[serde(flatten)]
pub fields: IndexMap<FieldName, Field>,
}
/// The variants of an enum description, serialized directly as a map
/// (`serde(transparent)`) from variant name to an optional field kind.
/// NOTE(review): `None` presumably denotes a variant with no payload — confirm.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Variants {
pub variants: IndexMap<Ident, Option<Field>>,
}
/// How a field (or variant payload) is treated. The exact semantics of each
/// kind are defined by the code that consumes this AST, not visible here.
/// Serialized as the variant name string, e.g. `"Visible"`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum Field {
Opaque,
Visible,
RefVisible,
}
// Tests pinning the JSON produced by serializing the AST types.
#[cfg(test)]
mod tests {
use crate::ast;
/// Serializes a small `Definitions` value with `serde_json::to_string_pretty`
/// and compares it to the exact expected text.
///
/// Covers: `None` options being omitted, generics with a split-out `where`,
/// the `$kind` tag, `$constructor`, getter field names, and index field names.
#[test]
fn test_serialize() {
let definitions = ast::Definitions {
types: FromIterator::from_iter([
(
ast::Path("Module".into()),
ast::Definition {
fn_name_suffix: None,
// Built from the private wire form so the `where` predicates are
// stored separately from the parameter list.
generics: Some(
ast::SerializedGenerics::Where {
generics: "<T: BundleValue>".into(),
where_predicates: "T::Type: BundleType<Value = T>,"
.parse()
.unwrap(),
}
.try_into()
.unwrap(),
),
fold_where: None,
visit_where: None,
data: ast::Data::Struct(ast::Fields {
constructor: Some("Module::new_unchecked".parse().unwrap()),
// Trailing `()` marks a getter-style field.
fields: FromIterator::from_iter([(
"name_id()".parse().unwrap(),
ast::Field::Visible,
)]),
}),
},
),
(
ast::Path("NameId".into()),
ast::Definition {
fn_name_suffix: None,
generics: None,
fold_where: None,
visit_where: None,
data: ast::Data::Struct(ast::Fields {
constructor: None,
// Tuple-struct fields are named by index.
fields: FromIterator::from_iter([
("0".try_into().unwrap(), ast::Field::Opaque),
("1".try_into().unwrap(), ast::Field::Opaque),
]),
}),
},
),
]),
};
let definitions_str = serde_json::to_string_pretty(&definitions).unwrap();
// Printed so the actual output shows up in the test log on failure.
println!("{definitions_str}");
// Generics/where strings are token-stream renderings, hence the spacing.
assert_eq!(
definitions_str,
r#"{
"types": {
"Module": {
"generics": {
"generics": "< T : BundleValue >",
"where": "T :: Type : BundleType < Value = T > ,"
},
"data": {
"$kind": "Struct",
"$constructor": "Module::new_unchecked",
"name_id()": "Visible"
}
},
"NameId": {
"data": {
"$kind": "Struct",
"0": "Opaque",
"1": "Opaque"
}
}
}
}"#
);
}
}