educe/trait_handlers/deref/deref_enum.rs

1use std::{fmt::Write, str::FromStr};
2
3use proc_macro2::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{Data, DeriveInput, Fields, Meta};
6
7use super::{
8    super::TraitHandler,
9    models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::{panic, Trait};
12
/// Trait handler that derives `core::ops::Deref` for enums.
pub struct DerefEnumHandler;
14
15impl TraitHandler for DerefEnumHandler {
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        tokens: &mut TokenStream,
19        traits: &[Trait],
20        meta: &Meta,
21    ) {
22        let _ = TypeAttributeBuilder {
23            enable_flag: true
24        }
25        .from_deref_meta(meta);
26
27        let enum_name = ast.ident.to_string();
28
29        let mut ty_all = TokenStream::new();
30        let mut deref_tokens = TokenStream::new();
31
32        let mut match_tokens = String::from("match self {");
33
34        if let Data::Enum(data) = &ast.data {
35            for variant in data.variants.iter() {
36                let _ = TypeAttributeBuilder {
37                    enable_flag: false
38                }
39                .from_attributes(&variant.attrs, traits);
40
41                let variant_ident = variant.ident.to_string();
42
43                match &variant.fields {
44                    Fields::Unit => {
45                        // TODO Unit
46                        panic::deref_cannot_support_unit_variant();
47                    },
48                    Fields::Named(fields) => {
49                        // TODO Struct
50                        let mut pattern_tokens = String::new();
51                        let mut block_tokens = String::new();
52
53                        let mut ty = TokenStream::new();
54
55                        let mut counter = 0;
56
57                        for field in fields.named.iter() {
58                            let field_attribute = FieldAttributeBuilder {
59                                enable_flag: true
60                            }
61                            .from_attributes(&field.attrs, traits);
62
63                            if field_attribute.flag {
64                                if !ty.is_empty() {
65                                    panic::multiple_deref_fields_of_variant(&variant_ident);
66                                }
67
68                                let field_name = field.ident.as_ref().unwrap().to_string();
69
70                                ty.extend(field.ty.clone().into_token_stream());
71
72                                if ty_all.is_empty() {
73                                    ty_all.extend(field.ty.clone().into_token_stream());
74                                }
75
76                                block_tokens
77                                    .write_fmt(format_args!(
78                                        "return {field_name};",
79                                        field_name = field_name
80                                    ))
81                                    .unwrap();
82                                pattern_tokens
83                                    .write_fmt(format_args!(
84                                        "{field_name}, ..",
85                                        field_name = field_name
86                                    ))
87                                    .unwrap();
88                            }
89
90                            counter += 1;
91                        }
92
93                        if ty.is_empty() {
94                            if counter == 1 {
95                                let field = fields.named.iter().next().unwrap();
96
97                                let field_name = field.ident.as_ref().unwrap().to_string();
98
99                                ty.extend(field.ty.clone().into_token_stream());
100
101                                if ty_all.is_empty() {
102                                    ty_all.extend(field.ty.clone().into_token_stream());
103                                }
104
105                                block_tokens
106                                    .write_fmt(format_args!(
107                                        "return {field_name};",
108                                        field_name = field_name
109                                    ))
110                                    .unwrap();
111                                pattern_tokens
112                                    .write_fmt(format_args!(
113                                        "{field_name}, ..",
114                                        field_name = field_name
115                                    ))
116                                    .unwrap();
117                            } else {
118                                panic::no_deref_field_of_variant(&variant_ident);
119                            }
120                        }
121
122                        match_tokens
123                            .write_fmt(format_args!(
124                                "{enum_name}::{variant_ident} {{ {pattern_tokens} }} => {{ \
125                                 {block_tokens} }}",
126                                enum_name = enum_name,
127                                variant_ident = variant_ident,
128                                pattern_tokens = pattern_tokens,
129                                block_tokens = block_tokens
130                            ))
131                            .unwrap();
132                    },
133                    Fields::Unnamed(fields) => {
134                        // TODO Tuple
135                        let mut pattern_tokens = String::new();
136                        let mut block_tokens = String::new();
137
138                        let mut ty = TokenStream::new();
139
140                        let mut counter = 0;
141
142                        for (index, field) in fields.unnamed.iter().enumerate() {
143                            let field_attribute = FieldAttributeBuilder {
144                                enable_flag: true
145                            }
146                            .from_attributes(&field.attrs, traits);
147
148                            if field_attribute.flag {
149                                if !ty.is_empty() {
150                                    panic::multiple_deref_fields_of_variant(&variant_ident);
151                                }
152
153                                let field_name = format!("{}", index);
154
155                                ty.extend(field.ty.clone().into_token_stream());
156
157                                if ty_all.is_empty() {
158                                    ty_all.extend(field.ty.clone().into_token_stream());
159                                }
160
161                                block_tokens
162                                    .write_fmt(format_args!(
163                                        "return _{field_name};",
164                                        field_name = field_name
165                                    ))
166                                    .unwrap();
167                                pattern_tokens
168                                    .write_fmt(format_args!(
169                                        "_{field_name},",
170                                        field_name = field_name
171                                    ))
172                                    .unwrap();
173                            } else {
174                                pattern_tokens.push_str("_,");
175                            }
176
177                            counter += 1;
178                        }
179
180                        if ty.is_empty() {
181                            if counter == 1 {
182                                let field = fields.unnamed.iter().next().unwrap();
183
184                                let field_name = String::from("0");
185
186                                ty.extend(field.ty.clone().into_token_stream());
187
188                                if ty_all.is_empty() {
189                                    ty_all.extend(field.ty.clone().into_token_stream());
190                                }
191
192                                block_tokens
193                                    .write_fmt(format_args!(
194                                        "return _{field_name};",
195                                        field_name = field_name
196                                    ))
197                                    .unwrap();
198
199                                pattern_tokens.clear();
200                                pattern_tokens
201                                    .write_fmt(format_args!(
202                                        "_{field_name}",
203                                        field_name = field_name
204                                    ))
205                                    .unwrap();
206                            } else {
207                                panic::no_deref_field_of_variant(&variant_ident);
208                            }
209                        }
210
211                        match_tokens
212                            .write_fmt(format_args!(
213                                "{enum_name}::{variant_ident}( {pattern_tokens} ) => {{ \
214                                 {block_tokens} }}",
215                                enum_name = enum_name,
216                                variant_ident = variant_ident,
217                                pattern_tokens = pattern_tokens,
218                                block_tokens = block_tokens
219                            ))
220                            .unwrap();
221                    },
222                }
223            }
224        }
225
226        match_tokens.push('}');
227
228        deref_tokens.extend(TokenStream::from_str(&match_tokens).unwrap());
229
230        let ident = &ast.ident;
231
232        let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
233
234        let deref_impl = quote! {
235            impl #impl_generics core::ops::Deref for #ident #ty_generics #where_clause {
236                type Target = #ty_all;
237
238                #[inline]
239                fn deref(&self) -> &Self::Target {
240                    #deref_tokens
241                }
242            }
243        };
244
245        tokens.extend(deref_impl);
246    }
247}