educe/trait_handlers/default/default_enum.rs

1use std::{fmt::Write, str::FromStr};
2
3use proc_macro2::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{Data, DeriveInput, Fields, Generics, Lit, Meta};
6
7use super::{
8    super::TraitHandler,
9    models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::{panic, Trait};
12
/// `TraitHandler` that derives `Default` for enums, optionally also emitting an inherent
/// `pub fn new()` that forwards to `Default::default()`.
pub struct DefaultEnumHandler;
14
15impl TraitHandler for DefaultEnumHandler {
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        tokens: &mut TokenStream,
19        traits: &[Trait],
20        meta: &Meta,
21    ) {
22        let type_attribute = TypeAttributeBuilder {
23            enable_flag:       true,
24            enable_new:        true,
25            enable_expression: true,
26            enable_bound:      true,
27        }
28        .from_default_meta(meta);
29
30        let bound = type_attribute
31            .bound
32            .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
33
34        let mut builder_tokens = TokenStream::new();
35
36        if let Data::Enum(data) = &ast.data {
37            match type_attribute.expression {
38                Some(expression) => {
39                    for variant in data.variants.iter() {
40                        let _ = TypeAttributeBuilder {
41                            enable_flag:       false,
42                            enable_new:        false,
43                            enable_expression: false,
44                            enable_bound:      false,
45                        }
46                        .from_attributes(&variant.attrs, traits);
47
48                        ensure_fields_no_attribute(&variant.fields, traits);
49                    }
50
51                    builder_tokens.extend(quote!(#expression));
52                },
53                None => {
54                    let variant = {
55                        let variants = &data.variants;
56
57                        if variants.len() == 1 {
58                            let variant = &variants[0];
59
60                            let _ = TypeAttributeBuilder {
61                                enable_flag:       true,
62                                enable_new:        false,
63                                enable_expression: false,
64                                enable_bound:      false,
65                            }
66                            .from_attributes(&variant.attrs, traits);
67
68                            variant
69                        } else {
70                            let mut variants_iter = variants.iter();
71
72                            loop {
73                                let variant = variants_iter.next();
74
75                                match variant {
76                                    Some(variant) => {
77                                        let variant_attribute = TypeAttributeBuilder {
78                                            enable_flag:       true,
79                                            enable_new:        false,
80                                            enable_expression: false,
81                                            enable_bound:      false,
82                                        }
83                                        .from_attributes(&variant.attrs, traits);
84
85                                        if variant_attribute.flag {
86                                            loop {
87                                                let variant = variants_iter.next();
88
89                                                match variant {
90                                                    Some(variant) => {
91                                                        let variant_attribute = TypeAttributeBuilder {
92                                                            enable_flag: true,
93                                                            enable_new: false,
94                                                            enable_expression: false,
95                                                            enable_bound: false,
96                                                        }.from_attributes(&variant.attrs, traits);
97
98                                                        if variant_attribute.flag {
99                                                            panic::multiple_default_variants();
100                                                        } else {
101                                                            ensure_fields_no_attribute(
102                                                                &variant.fields,
103                                                                traits,
104                                                            );
105                                                        }
106                                                    },
107                                                    None => break,
108                                                }
109                                            }
110
111                                            break variant;
112                                        } else {
113                                            ensure_fields_no_attribute(&variant.fields, traits);
114                                        }
115                                    },
116                                    None => panic::no_default_variant(),
117                                }
118                            }
119                        }
120                    };
121
122                    let enum_name = ast.ident.to_string();
123                    let variant_name = variant.ident.to_string();
124
125                    let mut enum_tokens = format!(
126                        "{enum_name}::{variant_name}",
127                        enum_name = enum_name,
128                        variant_name = variant_name
129                    );
130
131                    match &variant.fields {
132                        Fields::Unit => (), // TODO Unit
133                        Fields::Named(fields) => {
134                            // TODO Struct
135                            enum_tokens.push('{');
136
137                            for field in fields.named.iter() {
138                                let field_attribute = FieldAttributeBuilder {
139                                    enable_flag:       false,
140                                    enable_literal:    true,
141                                    enable_expression: true,
142                                }
143                                .from_attributes(&field.attrs, traits);
144
145                                let field_name = field.ident.as_ref().unwrap().to_string();
146
147                                enum_tokens
148                                    .write_fmt(format_args!(
149                                        "{field_name}: ",
150                                        field_name = field_name
151                                    ))
152                                    .unwrap();
153
154                                match field_attribute.literal {
155                                    Some(value) => match &value {
156                                        Lit::Str(s) => {
157                                            enum_tokens
158                                                .write_fmt(format_args!(
159                                                    "core::convert::Into::into({s})",
160                                                    s = s.into_token_stream()
161                                                ))
162                                                .unwrap();
163                                        },
164                                        _ => {
165                                            enum_tokens
166                                                .push_str(&value.into_token_stream().to_string());
167                                        },
168                                    },
169                                    None => {
170                                        match field_attribute.expression {
171                                            Some(expression) => {
172                                                enum_tokens.push_str(&expression);
173                                            },
174                                            None => {
175                                                let typ = field
176                                                    .ty
177                                                    .clone()
178                                                    .into_token_stream()
179                                                    .to_string();
180
181                                                enum_tokens.write_fmt(format_args!("<{typ} as core::default::Default>::default()", typ = typ)).unwrap();
182                                            },
183                                        }
184                                    },
185                                }
186
187                                enum_tokens.push(',');
188                            }
189
190                            enum_tokens.push('}');
191                        },
192                        Fields::Unnamed(fields) => {
193                            // TODO Tuple
194                            enum_tokens.push('(');
195
196                            for field in fields.unnamed.iter() {
197                                let field_attribute = FieldAttributeBuilder {
198                                    enable_flag:       false,
199                                    enable_literal:    true,
200                                    enable_expression: true,
201                                }
202                                .from_attributes(&field.attrs, traits);
203
204                                match field_attribute.literal {
205                                    Some(value) => match &value {
206                                        Lit::Str(s) => {
207                                            enum_tokens
208                                                .write_fmt(format_args!(
209                                                    "core::convert::Into::into({s})",
210                                                    s = s.into_token_stream()
211                                                ))
212                                                .unwrap();
213                                        },
214                                        _ => {
215                                            enum_tokens
216                                                .push_str(&value.into_token_stream().to_string());
217                                        },
218                                    },
219                                    None => {
220                                        match field_attribute.expression {
221                                            Some(expression) => {
222                                                enum_tokens.push_str(&expression);
223                                            },
224                                            None => {
225                                                let typ = field
226                                                    .ty
227                                                    .clone()
228                                                    .into_token_stream()
229                                                    .to_string();
230
231                                                enum_tokens.write_fmt(format_args!("<{typ} as core::default::Default>::default()", typ = typ)).unwrap();
232                                            },
233                                        }
234                                    },
235                                }
236
237                                enum_tokens.push(',');
238                            }
239
240                            enum_tokens.push(')');
241                        },
242                    }
243
244                    builder_tokens.extend(TokenStream::from_str(&enum_tokens).unwrap());
245                },
246            }
247        }
248
249        let ident = &ast.ident;
250
251        let mut generics_cloned: Generics = ast.generics.clone();
252
253        let where_clause = generics_cloned.make_where_clause();
254
255        for where_predicate in bound {
256            where_clause.predicates.push(where_predicate);
257        }
258
259        let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
260
261        let default_impl = quote! {
262            impl #impl_generics core::default::Default for #ident #ty_generics #where_clause {
263                #[inline]
264                fn default() -> Self {
265                    #builder_tokens
266                }
267            }
268        };
269
270        tokens.extend(default_impl);
271
272        if type_attribute.new {
273            let new_impl = quote! {
274                impl #impl_generics #ident #ty_generics #where_clause {
275                    /// Returns the "default value" for a type.
276                    #[inline]
277                    pub fn new() -> Self {
278                        <Self as core::default::Default>::default()
279                    }
280                }
281            };
282
283            tokens.extend(new_impl);
284        }
285    }
286}
287
288fn ensure_fields_no_attribute(fields: &Fields, traits: &[Trait]) {
289    match fields {
290        Fields::Unit => (),
291        Fields::Named(fields) => {
292            for field in fields.named.iter() {
293                let _ = FieldAttributeBuilder {
294                    enable_flag:       false,
295                    enable_literal:    false,
296                    enable_expression: false,
297                }
298                .from_attributes(&field.attrs, traits);
299            }
300        },
301        Fields::Unnamed(fields) => {
302            for field in fields.unnamed.iter() {
303                let _ = FieldAttributeBuilder {
304                    enable_flag:       false,
305                    enable_literal:    false,
306                    enable_expression: false,
307                }
308                .from_attributes(&field.attrs, traits);
309            }
310        },
311    }
312}