educe/trait_handlers/default/default_union.rs

1use std::{fmt::Write, str::FromStr};
2
3use proc_macro2::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{Data, DeriveInput, Generics, Lit, Meta};
6
7use super::{
8    super::TraitHandler,
9    models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::{panic, Trait};
12
/// Derive handler that emits a `core::default::Default` implementation (and, when
/// requested, an inherent `new` constructor) for a union.
pub struct DefaultUnionHandler;
14
15impl TraitHandler for DefaultUnionHandler {
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        tokens: &mut TokenStream,
19        traits: &[Trait],
20        meta: &Meta,
21    ) {
22        let type_attribute = TypeAttributeBuilder {
23            enable_flag:       true,
24            enable_new:        true,
25            enable_expression: true,
26            enable_bound:      true,
27        }
28        .from_default_meta(meta);
29
30        let bound = type_attribute
31            .bound
32            .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
33
34        let mut builder_tokens = TokenStream::new();
35
36        if let Data::Union(data) = &ast.data {
37            match type_attribute.expression {
38                Some(expression) => {
39                    for field in data.fields.named.iter() {
40                        let _ = FieldAttributeBuilder {
41                            enable_flag:       false,
42                            enable_literal:    false,
43                            enable_expression: false,
44                        }
45                        .from_attributes(&field.attrs, traits);
46                    }
47
48                    builder_tokens.extend(quote!(#expression));
49                },
50                None => {
51                    let ident = ast.ident.to_string();
52
53                    let (field_name, field_attribute, typ) = {
54                        let fields = &data.fields.named;
55
56                        if fields.len() == 1 {
57                            let field = &fields[0];
58
59                            let field_attribute = FieldAttributeBuilder {
60                                enable_flag:       true,
61                                enable_literal:    true,
62                                enable_expression: true,
63                            }
64                            .from_attributes(&field.attrs, traits);
65
66                            let field_name = field.ident.as_ref().unwrap().to_string();
67
68                            (
69                                field_name,
70                                field_attribute,
71                                field.ty.clone().into_token_stream().to_string(),
72                            )
73                        } else {
74                            let mut fields_iter = fields.iter();
75
76                            loop {
77                                let field = fields_iter.next();
78
79                                match field {
80                                    Some(field) => {
81                                        let field_attribute = FieldAttributeBuilder {
82                                            enable_flag:       true,
83                                            enable_literal:    true,
84                                            enable_expression: true,
85                                        }
86                                        .from_attributes(&field.attrs, traits);
87
88                                        if field_attribute.flag
89                                            || field_attribute.literal.is_some()
90                                            || field_attribute.expression.is_some()
91                                        {
92                                            let field_name =
93                                                field.ident.as_ref().unwrap().to_string();
94
95                                            loop {
96                                                let field = fields_iter.next();
97
98                                                match field {
99                                                    Some(field) => {
100                                                        let field_attribute =
101                                                            FieldAttributeBuilder {
102                                                                enable_flag:       true,
103                                                                enable_literal:    true,
104                                                                enable_expression: true,
105                                                            }
106                                                            .from_attributes(&field.attrs, traits);
107
108                                                        if field_attribute.flag
109                                                            || field_attribute.literal.is_some()
110                                                            || field_attribute.expression.is_some()
111                                                        {
112                                                            panic::multiple_default_fields();
113                                                        }
114                                                    },
115                                                    None => break,
116                                                }
117                                            }
118
119                                            break (
120                                                field_name,
121                                                field_attribute,
122                                                field.ty.clone().into_token_stream().to_string(),
123                                            );
124                                        }
125                                    },
126                                    None => panic::no_default_field(),
127                                }
128                            }
129                        }
130                    };
131
132                    let mut union_tokens = format!(
133                        "{ident} {{ {field_name}: ",
134                        ident = ident,
135                        field_name = field_name
136                    );
137
138                    match field_attribute.literal {
139                        Some(value) => match &value {
140                            Lit::Str(s) => {
141                                union_tokens
142                                    .write_fmt(format_args!(
143                                        "core::convert::Into::into({s})",
144                                        s = s.into_token_stream()
145                                    ))
146                                    .unwrap();
147                            },
148                            _ => {
149                                union_tokens.push_str(&value.into_token_stream().to_string());
150                            },
151                        },
152                        None => match field_attribute.expression {
153                            Some(expression) => {
154                                union_tokens.push_str(&expression);
155                            },
156                            None => {
157                                union_tokens
158                                    .write_fmt(format_args!(
159                                        "<{typ} as core::default::Default>::default()",
160                                        typ = typ
161                                    ))
162                                    .unwrap();
163                            },
164                        },
165                    }
166
167                    union_tokens.push('}');
168
169                    builder_tokens.extend(TokenStream::from_str(&union_tokens).unwrap());
170                },
171            }
172        }
173
174        let ident = &ast.ident;
175
176        let mut generics_cloned: Generics = ast.generics.clone();
177
178        let where_clause = generics_cloned.make_where_clause();
179
180        for where_predicate in bound {
181            where_clause.predicates.push(where_predicate);
182        }
183
184        let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
185
186        let default_impl = quote! {
187            impl #impl_generics core::default::Default for #ident #ty_generics #where_clause {
188                #[inline]
189                fn default() -> Self {
190                    #builder_tokens
191                }
192            }
193        };
194
195        tokens.extend(default_impl);
196
197        if type_attribute.new {
198            let new_impl = quote! {
199                impl #impl_generics #ident #ty_generics #where_clause {
200                    /// Returns the "default value" for a type.
201                    #[inline]
202                    pub fn new() -> Self {
203                        <Self as core::default::Default>::default()
204                    }
205                }
206            };
207
208            tokens.extend(new_impl);
209        }
210    }
211}