strum_macros/macros/
enum_discriminants.rs

1use proc_macro2::{Span, TokenStream, TokenTree};
2use quote::{quote, ToTokens};
3use syn::parse_quote;
4use syn::{Data, DeriveInput, Fields};
5
6use crate::helpers::{non_enum_error, strum_discriminants_passthrough_error, HasTypeProperties};
7
/// Whitelist of variant attribute names that survive onto the generated discriminant enum.
///
/// Any variant attribute whose path is not one of these names is dropped when the
/// discriminant variants are emitted. Such attributes may belong to other `proc_macro`s applied
/// to the main enum and could fail to compile if copied across. `strum_discriminants` is the
/// one special entry: its contents are unwrapped and emitted as a plain attribute rather than
/// copied verbatim.
const ATTRIBUTES_TO_COPY: &[&str] = &[
    // Documentation and conditional compilation.
    "doc",
    "cfg",
    // Lint levels.
    "allow",
    "deny",
    // Proxy attribute whose contents are passed through to the discriminant variant.
    "strum_discriminants",
];
13
14pub fn enum_discriminants_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
15    let name = &ast.ident;
16    let vis = &ast.vis;
17
18    let variants = match &ast.data {
19        Data::Enum(v) => &v.variants,
20        _ => return Err(non_enum_error()),
21    };
22
23    // Derives for the generated enum
24    let type_properties = ast.get_type_properties()?;
25    let strum_module_path = type_properties.crate_module_path();
26
27    let derives = type_properties.discriminant_derives;
28
29    let derives = quote! {
30        #[derive(Clone, Copy, Debug, PartialEq, Eq, #(#derives),*)]
31    };
32
33    // Create #[doc] attrs for new generated type.
34    let docs = type_properties.discriminant_docs;
35
36    let docs = quote! {
37        #(#[doc = #docs])*
38    };
39
40    // Work out the name
41    let default_name = syn::Ident::new(&format!("{}Discriminants", name), Span::call_site());
42
43    let discriminants_name = type_properties.discriminant_name.unwrap_or(default_name);
44    let discriminants_vis = type_properties
45        .discriminant_vis
46        .as_ref()
47        .unwrap_or_else(|| &vis);
48
49    // Pass through all other attributes
50    let pass_though_attributes = type_properties.discriminant_others;
51
52    let repr = type_properties.enum_repr.map(|repr| quote!(#[repr(#repr)]));
53
54    // Add the variants without fields, but exclude the `strum` meta item
55    let mut discriminants = Vec::new();
56    for variant in variants {
57        let ident = &variant.ident;
58        let discriminant = variant
59            .discriminant
60            .as_ref()
61            .map(|(_, expr)| quote!( = #expr));
62
63        // Don't copy across the "strum" meta attribute. Only passthrough the whitelisted
64        // attributes and proxy `#[strum_discriminants(...)]` attributes
65        let attrs = variant
66            .attrs
67            .iter()
68            .filter(|attr| {
69                ATTRIBUTES_TO_COPY
70                    .iter()
71                    .any(|attr_whitelisted| attr.path().is_ident(attr_whitelisted))
72            })
73            .map(|attr| {
74                if attr.path().is_ident("strum_discriminants") {
75                    let mut ts = attr.meta.require_list()?.to_token_stream().into_iter();
76
77                    // Discard strum_discriminants(...)
78                    let _ = ts.next();
79
80                    let passthrough_group = ts
81                        .next()
82                        .ok_or_else(|| strum_discriminants_passthrough_error(attr))?;
83                    let passthrough_attribute = match passthrough_group {
84                        TokenTree::Group(ref group) => group.stream(),
85                        _ => {
86                            return Err(strum_discriminants_passthrough_error(&passthrough_group));
87                        }
88                    };
89                    if passthrough_attribute.is_empty() {
90                        return Err(strum_discriminants_passthrough_error(&passthrough_group));
91                    }
92                    Ok(quote! { #[#passthrough_attribute] })
93                } else {
94                    Ok(attr.to_token_stream())
95                }
96            })
97            .collect::<Result<Vec<_>, _>>()?;
98
99        discriminants.push(quote! { #(#attrs)* #ident #discriminant});
100    }
101
102    // Ideally:
103    //
104    // * For `Copy` types, we `impl From<TheEnum> for TheEnumDiscriminants`
105    // * For `!Copy` types, we `impl<'enum> From<&'enum TheEnum> for TheEnumDiscriminants`
106    //
107    // That way we ensure users are not able to pass a `Copy` type by reference. However, the
108    // `#[derive(..)]` attributes are not in the parsed tokens, so we are not able to check if a
109    // type is `Copy`, so we just implement both.
110    //
111    // See <https://github.com/dtolnay/syn/issues/433>
112    // ---
113    // let is_copy = unique_meta_list(type_meta.iter(), "derive")
114    //     .map(extract_list_metas)
115    //     .map(|metas| {
116    //         metas
117    //             .filter_map(get_meta_ident)
118    //             .any(|derive| derive.to_string() == "Copy")
119    //     }).unwrap_or(false);
120
121    let arms = variants
122        .iter()
123        .map(|variant| {
124            let ident = &variant.ident;
125            let params = match &variant.fields {
126                Fields::Unit => quote! {},
127                Fields::Unnamed(_fields) => {
128                    quote! { (..) }
129                }
130                Fields::Named(_fields) => {
131                    quote! { { .. } }
132                }
133            };
134
135            quote! { #name::#ident #params => #discriminants_name::#ident }
136        })
137        .collect::<Vec<_>>();
138
139    let from_fn_body = if variants.is_empty() {
140        //this method on empty enum is impossible to be called. it is therefor left empty
141        quote! { unreachable!()}
142    } else {
143        quote! { match val { #(#arms),* } }
144    };
145
146    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
147    let impl_from = quote! {
148        #[automatically_derived]
149        impl #impl_generics ::core::convert::From< #name #ty_generics > for #discriminants_name #where_clause {
150            #[inline]
151            fn from(val: #name #ty_generics) -> #discriminants_name {
152                #from_fn_body
153            }
154        }
155    };
156    let impl_from_ref = {
157        let mut generics = ast.generics.clone();
158
159        let lifetime = parse_quote!('_enum);
160        let enum_life = quote! { & #lifetime };
161        generics.params.push(lifetime);
162
163        // Shadows the earlier `impl_generics`
164        let (impl_generics, _, _) = generics.split_for_impl();
165
166        quote! {
167            #[automatically_derived]
168            impl #impl_generics ::core::convert::From< #enum_life #name #ty_generics > for #discriminants_name #where_clause {
169                #[inline]
170                fn from(val: #enum_life #name #ty_generics) -> #discriminants_name {
171                    #from_fn_body
172                }
173            }
174        }
175    };
176
177    // For now, only implement IntoDiscriminant if the user has not overriden the visibility.
178    let impl_into_discriminant = match type_properties.discriminant_vis {
179        // If the visibilty is unspecified or `pub` then we implement IntoDiscriminant
180        None | Some(syn::Visibility::Public(..)) => quote! {
181            #[automatically_derived]
182            impl #impl_generics #strum_module_path::IntoDiscriminant for #name #ty_generics #where_clause {
183                type Discriminant = #discriminants_name;
184
185                #[inline]
186                fn discriminant(&self) -> Self::Discriminant {
187                    <Self::Discriminant as ::core::convert::From<&Self>>::from(self)
188                }
189            }
190        },
191        // If it's something restricted such as `pub(super)` then we skip implementing the
192        // trait for now. There are certainly scenarios where they could be equivalent, but
193        // as a heuristic, if someone is overriding the visibility, it's because they want
194        // the discriminant type to be less visible than the original type.
195        _ => quote! {},
196    };
197
198    Ok(quote! {
199        #docs
200        #derives
201        #repr
202        #(#[ #pass_though_attributes ])*
203        #discriminants_vis enum #discriminants_name {
204            #(#discriminants),*
205        }
206
207        #impl_into_discriminant
208        #impl_from
209        #impl_from_ref
210    })
211}