// educe/trait_handlers/hash/hash_enum.rs

use std::{
    fmt::{Display, Write},
    str::FromStr,
};

use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Generics, Meta};

use super::{
    super::TraitHandler,
    models::{FieldAttributeBuilder, TypeAttributeBuilder},
};
use crate::Trait;
12
13pub struct HashEnumHandler;
14
15impl TraitHandler for HashEnumHandler {
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        tokens: &mut TokenStream,
19        traits: &[Trait],
20        meta: &Meta,
21    ) {
22        let type_attribute = TypeAttributeBuilder {
23            enable_flag: true, enable_bound: true
24        }
25        .from_hash_meta(meta);
26
27        let enum_name = ast.ident.to_string();
28
29        let bound = type_attribute
30            .bound
31            .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
32
33        let mut hasher_tokens = TokenStream::new();
34
35        let mut match_tokens = String::from("match self {");
36
37        if let Data::Enum(data) = &ast.data {
38            let has_non_unit = {
39                let mut non_unit = false;
40
41                for variant in data.variants.iter() {
42                    let _ = TypeAttributeBuilder {
43                        enable_flag: false, enable_bound: false
44                    }
45                    .from_attributes(&variant.attrs, traits);
46
47                    match &variant.fields {
48                        Fields::Named(_) | Fields::Unnamed(_) => {
49                            non_unit = true;
50
51                            break;
52                        },
53                        _ => (),
54                    }
55                }
56
57                non_unit
58            };
59
60            if has_non_unit {
61                for (index, variant) in data.variants.iter().enumerate() {
62                    let variant_ident = variant.ident.to_string();
63
64                    match &variant.fields {
65                        Fields::Unit => {
66                            // TODO Unit
67                            match_tokens
68                                .write_fmt(format_args!(
69                                    "{enum_name}::{variant_ident} => {{ \
70                                     core::hash::Hash::hash(&{index}, state); }}",
71                                    enum_name = enum_name,
72                                    variant_ident = variant_ident,
73                                    index = index
74                                ))
75                                .unwrap();
76                        },
77                        Fields::Named(fields) => {
78                            // TODO Struct
79                            let mut pattern_tokens = String::new();
80                            let mut block_tokens = String::new();
81
82                            block_tokens
83                                .write_fmt(format_args!(
84                                    "core::hash::Hash::hash(&{index}, state);",
85                                    index = index
86                                ))
87                                .unwrap();
88
89                            for field in fields.named.iter() {
90                                let field_attribute = FieldAttributeBuilder {
91                                    enable_ignore: true,
92                                    enable_impl:   true,
93                                }
94                                .from_attributes(&field.attrs, traits);
95
96                                let field_name = field.ident.as_ref().unwrap().to_string();
97
98                                if field_attribute.ignore {
99                                    pattern_tokens
100                                        .write_fmt(format_args!(
101                                            "{field_name}: _,",
102                                            field_name = field_name
103                                        ))
104                                        .unwrap();
105                                    continue;
106                                }
107
108                                let hash_trait = field_attribute.hash_trait;
109                                let hash_method = field_attribute.hash_method;
110
111                                pattern_tokens
112                                    .write_fmt(format_args!(
113                                        "{field_name},",
114                                        field_name = field_name
115                                    ))
116                                    .unwrap();
117
118                                match hash_trait {
119                                    Some(hash_trait) => {
120                                        let hash_method = hash_method.unwrap();
121
122                                        block_tokens
123                                            .write_fmt(format_args!(
124                                                "{hash_trait}::{hash_method}({field_name}, state);",
125                                                hash_trait = hash_trait,
126                                                hash_method = hash_method,
127                                                field_name = field_name
128                                            ))
129                                            .unwrap();
130                                    },
131                                    None => match hash_method {
132                                        Some(hash_method) => {
133                                            block_tokens
134                                                .write_fmt(format_args!(
135                                                    "{hash_method}({field_name}, state);",
136                                                    hash_method = hash_method,
137                                                    field_name = field_name
138                                                ))
139                                                .unwrap();
140                                        },
141                                        None => {
142                                            block_tokens
143                                                .write_fmt(format_args!(
144                                                    "core::hash::Hash::hash({field_name}, state);",
145                                                    field_name = field_name
146                                                ))
147                                                .unwrap();
148                                        },
149                                    },
150                                }
151                            }
152
153                            match_tokens
154                                .write_fmt(format_args!(
155                                    "{enum_name}::{variant_ident} {{ {pattern_tokens} }} => {{ \
156                                     {block_tokens} }}",
157                                    enum_name = enum_name,
158                                    variant_ident = variant_ident,
159                                    pattern_tokens = pattern_tokens,
160                                    block_tokens = block_tokens
161                                ))
162                                .unwrap();
163                        },
164                        Fields::Unnamed(fields) => {
165                            // TODO Tuple
166                            let mut pattern_tokens = String::new();
167                            let mut block_tokens = String::new();
168
169                            block_tokens
170                                .write_fmt(format_args!(
171                                    "core::hash::Hash::hash(&{index}, state);",
172                                    index = index
173                                ))
174                                .unwrap();
175
176                            for (index, field) in fields.unnamed.iter().enumerate() {
177                                let field_attribute = FieldAttributeBuilder {
178                                    enable_ignore: true,
179                                    enable_impl:   true,
180                                }
181                                .from_attributes(&field.attrs, traits);
182
183                                if field_attribute.ignore {
184                                    pattern_tokens.push_str("_,");
185                                    continue;
186                                }
187
188                                let hash_trait = field_attribute.hash_trait;
189                                let hash_method = field_attribute.hash_method;
190
191                                let field_name = format!("{}", index);
192
193                                pattern_tokens
194                                    .write_fmt(format_args!(
195                                        "_{field_name},",
196                                        field_name = field_name
197                                    ))
198                                    .unwrap();
199
200                                match hash_trait {
201                                    Some(hash_trait) => {
202                                        let hash_method = hash_method.unwrap();
203
204                                        block_tokens
205                                            .write_fmt(format_args!(
206                                                "{hash_trait}::{hash_method}(_{field_name}, \
207                                                 state);",
208                                                hash_trait = hash_trait,
209                                                hash_method = hash_method,
210                                                field_name = field_name
211                                            ))
212                                            .unwrap();
213                                    },
214                                    None => match hash_method {
215                                        Some(hash_method) => {
216                                            block_tokens
217                                                .write_fmt(format_args!(
218                                                    "{hash_method}(_{field_name}, state);",
219                                                    hash_method = hash_method,
220                                                    field_name = field_name
221                                                ))
222                                                .unwrap();
223                                        },
224                                        None => {
225                                            block_tokens
226                                                .write_fmt(format_args!(
227                                                    "core::hash::Hash::hash(_{field_name}, state);",
228                                                    field_name = field_name
229                                                ))
230                                                .unwrap();
231                                        },
232                                    },
233                                }
234                            }
235
236                            match_tokens
237                                .write_fmt(format_args!(
238                                    "{enum_name}::{variant_ident}( {pattern_tokens} ) => {{ \
239                                     {block_tokens} }}",
240                                    enum_name = enum_name,
241                                    variant_ident = variant_ident,
242                                    pattern_tokens = pattern_tokens,
243                                    block_tokens = block_tokens
244                                ))
245                                .unwrap();
246                        },
247                    }
248                }
249            } else {
250                for variant in data.variants.iter() {
251                    let variant_ident = variant.ident.to_string();
252
253                    match_tokens
254                        .write_fmt(format_args!(
255                            "{enum_name}::{variant_ident} => {{ \
256                             core::hash::Hash::hash(&({enum_name}::{variant_ident} as isize), \
257                             state); }}",
258                            enum_name = enum_name,
259                            variant_ident = variant_ident
260                        ))
261                        .unwrap();
262                }
263            }
264        }
265
266        match_tokens.push('}');
267
268        hasher_tokens.extend(TokenStream::from_str(&match_tokens).unwrap());
269
270        let ident = &ast.ident;
271
272        let mut generics_cloned: Generics = ast.generics.clone();
273
274        let where_clause = generics_cloned.make_where_clause();
275
276        for where_predicate in bound {
277            where_clause.predicates.push(where_predicate);
278        }
279
280        let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
281
282        let hash_impl = quote! {
283            impl #impl_generics core::hash::Hash for #ident #ty_generics #where_clause {
284                #[inline]
285                #[allow(clippy::unneeded_field_pattern)]
286                fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
287                    #hasher_tokens
288                }
289            }
290        };
291
292        tokens.extend(hash_impl);
293    }
294}