strum_macros/macros/
enum_discriminants.rs1use proc_macro2::{Span, TokenStream, TokenTree};
2use quote::{quote, ToTokens};
3use syn::parse_quote;
4use syn::{Data, DeriveInput, Fields};
5
6use crate::helpers::{non_enum_error, strum_discriminants_passthrough_error, HasTypeProperties};
7
/// Names of the variant attributes that are carried over from the variants of the
/// source enum onto the variants of the generated discriminant enum.
///
/// Variant attributes whose path is not one of these identifiers are dropped.
/// `strum_discriminants` is listed here so that it survives the filter, but it is
/// not copied verbatim: its parenthesised contents are re-emitted as an attribute.
const ATTRIBUTES_TO_COPY: &[&str] = &["doc", "cfg", "allow", "deny", "strum_discriminants"];
13
14pub fn enum_discriminants_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
15 let name = &ast.ident;
16 let vis = &ast.vis;
17
18 let variants = match &ast.data {
19 Data::Enum(v) => &v.variants,
20 _ => return Err(non_enum_error()),
21 };
22
23 let type_properties = ast.get_type_properties()?;
25 let strum_module_path = type_properties.crate_module_path();
26
27 let derives = type_properties.discriminant_derives;
28
29 let derives = quote! {
30 #[derive(Clone, Copy, Debug, PartialEq, Eq, #(#derives),*)]
31 };
32
33 let docs = type_properties.discriminant_docs;
35
36 let docs = quote! {
37 #(#[doc = #docs])*
38 };
39
40 let default_name = syn::Ident::new(&format!("{}Discriminants", name), Span::call_site());
42
43 let discriminants_name = type_properties.discriminant_name.unwrap_or(default_name);
44 let discriminants_vis = type_properties
45 .discriminant_vis
46 .as_ref()
47 .unwrap_or_else(|| &vis);
48
49 let pass_though_attributes = type_properties.discriminant_others;
51
52 let repr = type_properties.enum_repr.map(|repr| quote!(#[repr(#repr)]));
53
54 let mut discriminants = Vec::new();
56 for variant in variants {
57 let ident = &variant.ident;
58 let discriminant = variant
59 .discriminant
60 .as_ref()
61 .map(|(_, expr)| quote!( = #expr));
62
63 let attrs = variant
66 .attrs
67 .iter()
68 .filter(|attr| {
69 ATTRIBUTES_TO_COPY
70 .iter()
71 .any(|attr_whitelisted| attr.path().is_ident(attr_whitelisted))
72 })
73 .map(|attr| {
74 if attr.path().is_ident("strum_discriminants") {
75 let mut ts = attr.meta.require_list()?.to_token_stream().into_iter();
76
77 let _ = ts.next();
79
80 let passthrough_group = ts
81 .next()
82 .ok_or_else(|| strum_discriminants_passthrough_error(attr))?;
83 let passthrough_attribute = match passthrough_group {
84 TokenTree::Group(ref group) => group.stream(),
85 _ => {
86 return Err(strum_discriminants_passthrough_error(&passthrough_group));
87 }
88 };
89 if passthrough_attribute.is_empty() {
90 return Err(strum_discriminants_passthrough_error(&passthrough_group));
91 }
92 Ok(quote! { #[#passthrough_attribute] })
93 } else {
94 Ok(attr.to_token_stream())
95 }
96 })
97 .collect::<Result<Vec<_>, _>>()?;
98
99 discriminants.push(quote! { #(#attrs)* #ident #discriminant});
100 }
101
102 let arms = variants
122 .iter()
123 .map(|variant| {
124 let ident = &variant.ident;
125 let params = match &variant.fields {
126 Fields::Unit => quote! {},
127 Fields::Unnamed(_fields) => {
128 quote! { (..) }
129 }
130 Fields::Named(_fields) => {
131 quote! { { .. } }
132 }
133 };
134
135 quote! { #name::#ident #params => #discriminants_name::#ident }
136 })
137 .collect::<Vec<_>>();
138
139 let from_fn_body = if variants.is_empty() {
140 quote! { unreachable!()}
142 } else {
143 quote! { match val { #(#arms),* } }
144 };
145
146 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
147 let impl_from = quote! {
148 #[automatically_derived]
149 impl #impl_generics ::core::convert::From< #name #ty_generics > for #discriminants_name #where_clause {
150 #[inline]
151 fn from(val: #name #ty_generics) -> #discriminants_name {
152 #from_fn_body
153 }
154 }
155 };
156 let impl_from_ref = {
157 let mut generics = ast.generics.clone();
158
159 let lifetime = parse_quote!('_enum);
160 let enum_life = quote! { & #lifetime };
161 generics.params.push(lifetime);
162
163 let (impl_generics, _, _) = generics.split_for_impl();
165
166 quote! {
167 #[automatically_derived]
168 impl #impl_generics ::core::convert::From< #enum_life #name #ty_generics > for #discriminants_name #where_clause {
169 #[inline]
170 fn from(val: #enum_life #name #ty_generics) -> #discriminants_name {
171 #from_fn_body
172 }
173 }
174 }
175 };
176
177 let impl_into_discriminant = match type_properties.discriminant_vis {
179 None | Some(syn::Visibility::Public(..)) => quote! {
181 #[automatically_derived]
182 impl #impl_generics #strum_module_path::IntoDiscriminant for #name #ty_generics #where_clause {
183 type Discriminant = #discriminants_name;
184
185 #[inline]
186 fn discriminant(&self) -> Self::Discriminant {
187 <Self::Discriminant as ::core::convert::From<&Self>>::from(self)
188 }
189 }
190 },
191 _ => quote! {},
196 };
197
198 Ok(quote! {
199 #docs
200 #derives
201 #repr
202 #(#[ #pass_though_attributes ])*
203 #discriminants_vis enum #discriminants_name {
204 #(#discriminants),*
205 }
206
207 #impl_into_discriminant
208 #impl_from
209 #impl_from_ref
210 })
211}