// educe/trait_handlers/partial_eq/partial_eq_enum.rs

use std::{fmt::Write, str::FromStr};

use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{Data, DeriveInput, Field, Fields, Generics, Ident, Meta, Variant};

use super::{
    super::TraitHandler,
    models::{FieldAttributeBuilder, TypeAttributeBuilder},
};
use crate::Trait;

13pub struct PartialEqEnumHandler;
14
15impl TraitHandler for PartialEqEnumHandler {
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        tokens: &mut TokenStream,
19        traits: &[Trait],
20        meta: &Meta,
21    ) {
22        let type_attribute = TypeAttributeBuilder {
23            enable_flag: true, enable_bound: true
24        }
25        .from_partial_eq_meta(meta);
26
27        let enum_name = ast.ident.to_string();
28
29        let bound = type_attribute
30            .bound
31            .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
32
33        let mut comparer_tokens = TokenStream::new();
34
35        let mut match_tokens = String::from("match self {");
36
37        if let Data::Enum(data) = &ast.data {
38            for variant in data.variants.iter() {
39                let _ = TypeAttributeBuilder {
40                    enable_flag: false, enable_bound: false
41                }
42                .from_attributes(&variant.attrs, traits);
43
44                let variant_ident = variant.ident.to_string();
45
46                match &variant.fields {
47                    Fields::Unit => {
48                        // TODO Unit
49                        match_tokens
50                            .write_fmt(format_args!(
51                                "{enum_name}::{variant_ident} => {{ if let \
52                                 {enum_name}::{variant_ident} = other {{ }} else {{ return false; \
53                                 }} }}",
54                                enum_name = enum_name,
55                                variant_ident = variant_ident
56                            ))
57                            .unwrap();
58                    },
59                    Fields::Named(fields) => {
60                        // TODO Struct
61                        let mut pattern_tokens = String::new();
62                        let mut pattern_2_tokens = String::new();
63                        let mut block_tokens = String::new();
64
65                        let mut field_attributes = Vec::new();
66                        let mut field_names = Vec::new();
67
68                        for field in fields.named.iter() {
69                            let field_attribute = FieldAttributeBuilder {
70                                enable_ignore: true,
71                                enable_impl:   true,
72                            }
73                            .from_attributes(&field.attrs, traits);
74
75                            let field_name = field.ident.as_ref().unwrap().to_string();
76
77                            if field_attribute.ignore {
78                                pattern_tokens
79                                    .write_fmt(format_args!(
80                                        "{field_name}: _,",
81                                        field_name = field_name
82                                    ))
83                                    .unwrap();
84                                pattern_2_tokens
85                                    .write_fmt(format_args!(
86                                        "{field_name}: _,",
87                                        field_name = field_name
88                                    ))
89                                    .unwrap();
90                                continue;
91                            }
92
93                            pattern_tokens
94                                .write_fmt(format_args!("{field_name},", field_name = field_name))
95                                .unwrap();
96                            pattern_2_tokens
97                                .write_fmt(format_args!(
98                                    "{field_name}: ___{field_name},",
99                                    field_name = field_name
100                                ))
101                                .unwrap();
102
103                            field_attributes.push(field_attribute);
104                            field_names.push(field_name);
105                        }
106
107                        for (index, field_attribute) in field_attributes.into_iter().enumerate() {
108                            let field_name = &field_names[index];
109
110                            let compare_trait = field_attribute.compare_trait;
111                            let compare_method = field_attribute.compare_method;
112
113                            match compare_trait {
114                                Some(compare_trait) => {
115                                    let compare_method = compare_method.unwrap();
116
117                                    block_tokens
118                                        .write_fmt(format_args!(
119                                            "if !{compare_trait}::{compare_method}({field_name}, \
120                                             ___{field_name}) {{ return false; }}",
121                                            compare_trait = compare_trait,
122                                            compare_method = compare_method,
123                                            field_name = field_name
124                                        ))
125                                        .unwrap();
126                                },
127                                None => match compare_method {
128                                    Some(compare_method) => {
129                                        block_tokens
130                                            .write_fmt(format_args!(
131                                                "if !{compare_method}({field_name}, \
132                                                 ___{field_name}) {{ return false; }}",
133                                                compare_method = compare_method,
134                                                field_name = field_name
135                                            ))
136                                            .unwrap();
137                                    },
138                                    None => {
139                                        block_tokens
140                                            .write_fmt(format_args!(
141                                                "if core::cmp::PartialEq::ne({field_name}, \
142                                                 ___{field_name}) {{ return false; }}",
143                                                field_name = field_name
144                                            ))
145                                            .unwrap();
146                                    },
147                                },
148                            }
149                        }
150
151                        match_tokens
152                            .write_fmt(format_args!(
153                                "{enum_name}::{variant_ident}{{ {pattern_tokens} }} => {{ if let \
154                                 {enum_name}::{variant_ident} {{ {pattern_2_tokens} }} = other {{ \
155                                 {block_tokens} }} else {{ return false; }} }}",
156                                enum_name = enum_name,
157                                variant_ident = variant_ident,
158                                pattern_tokens = pattern_tokens,
159                                pattern_2_tokens = pattern_2_tokens,
160                                block_tokens = block_tokens
161                            ))
162                            .unwrap();
163                    },
164                    Fields::Unnamed(fields) => {
165                        // TODO Tuple
166                        let mut pattern_tokens = String::new();
167                        let mut pattern_2_tokens = String::new();
168                        let mut block_tokens = String::new();
169
170                        let mut field_attributes = Vec::new();
171                        let mut field_names = Vec::new();
172
173                        for (index, field) in fields.unnamed.iter().enumerate() {
174                            let field_attribute = FieldAttributeBuilder {
175                                enable_ignore: true,
176                                enable_impl:   true,
177                            }
178                            .from_attributes(&field.attrs, traits);
179
180                            let field_name = format!("{}", index);
181
182                            if field_attribute.ignore {
183                                pattern_tokens.push_str("_,");
184                                pattern_2_tokens.push_str("_,");
185                                continue;
186                            }
187
188                            pattern_tokens
189                                .write_fmt(format_args!("_{field_name},", field_name = field_name))
190                                .unwrap();
191                            pattern_2_tokens
192                                .write_fmt(format_args!("__{field_name},", field_name = field_name))
193                                .unwrap();
194
195                            field_attributes.push(field_attribute);
196                            field_names.push(field_name);
197                        }
198
199                        for (index, field_attribute) in field_attributes.into_iter().enumerate() {
200                            let field_name = &field_names[index];
201
202                            let compare_trait = field_attribute.compare_trait;
203                            let compare_method = field_attribute.compare_method;
204
205                            match compare_trait {
206                                Some(compare_trait) => {
207                                    let compare_method = compare_method.unwrap();
208
209                                    block_tokens
210                                        .write_fmt(format_args!(
211                                            "if !{compare_trait}::{compare_method}(_{field_name}, \
212                                             __{field_name}) {{ return false; }}",
213                                            compare_trait = compare_trait,
214                                            compare_method = compare_method,
215                                            field_name = field_name
216                                        ))
217                                        .unwrap();
218                                },
219                                None => match compare_method {
220                                    Some(compare_method) => {
221                                        block_tokens
222                                            .write_fmt(format_args!(
223                                                "if !{compare_method}(_{field_name}, \
224                                                 __{field_name}) {{ return false; }}",
225                                                compare_method = compare_method,
226                                                field_name = field_name
227                                            ))
228                                            .unwrap();
229                                    },
230                                    None => {
231                                        block_tokens
232                                            .write_fmt(format_args!(
233                                                "if core::cmp::PartialEq::ne(_{field_name}, \
234                                                 __{field_name}) {{ return false; }}",
235                                                field_name = field_name
236                                            ))
237                                            .unwrap();
238                                    },
239                                },
240                            }
241                        }
242
243                        match_tokens
244                            .write_fmt(format_args!(
245                                "{enum_name}::{variant_ident}( {pattern_tokens} ) => {{ if let \
246                                 {enum_name}::{variant_ident} ( {pattern_2_tokens} ) = other {{ \
247                                 {block_tokens} }} else {{ return false; }} }}",
248                                enum_name = enum_name,
249                                variant_ident = variant_ident,
250                                pattern_tokens = pattern_tokens,
251                                pattern_2_tokens = pattern_2_tokens,
252                                block_tokens = block_tokens
253                            ))
254                            .unwrap();
255                    },
256                }
257            }
258        }
259
260        match_tokens.push('}');
261
262        comparer_tokens.extend(TokenStream::from_str(&match_tokens).unwrap());
263
264        let ident = &ast.ident;
265
266        let mut generics_cloned: Generics = ast.generics.clone();
267
268        let where_clause = generics_cloned.make_where_clause();
269
270        for where_predicate in bound {
271            where_clause.predicates.push(where_predicate);
272        }
273
274        let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
275
276        let compare_impl = quote! {
277            impl #impl_generics core::cmp::PartialEq for #ident #ty_generics #where_clause {
278                #[inline]
279                #[allow(clippy::unneeded_field_pattern)]
280                fn eq(&self, other: &Self) -> bool {
281                    #comparer_tokens
282
283                    true
284                }
285            }
286        };
287
288        tokens.extend(compare_impl);
289    }
290}