educe/trait_handlers/ord/
ord_enum.rs

1use std::{collections::BTreeMap, fmt::Write, str::FromStr};
2
3use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{Data, DeriveInput, Fields, Generics, Meta};
6
7use super::{
8    super::TraitHandler,
9    models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::{panic, Trait};
12
13pub struct OrdEnumHandler;
14
15impl TraitHandler for OrdEnumHandler {
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        tokens: &mut TokenStream,
19        traits: &[Trait],
20        meta: &Meta,
21    ) {
22        let type_attribute = TypeAttributeBuilder {
23            enable_flag:  true,
24            enable_bound: true,
25            rank:         0,
26            enable_rank:  false,
27        }
28        .from_ord_meta(meta);
29
30        let enum_name = ast.ident.to_string();
31
32        let bound = type_attribute
33            .bound
34            .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
35
36        let mut comparer_tokens = TokenStream::new();
37
38        let mut match_tokens = String::from("match self {");
39
40        let mut has_non_unit_or_custom_value = false;
41
42        if let Data::Enum(data) = &ast.data {
43            let mut variant_values = Vec::new();
44            let mut variant_idents = Vec::new();
45            let mut variants = Vec::new();
46
47            let mut variant_to_integer =
48                String::from("let variant_to_integer = |other: &Self| match other {");
49            let mut unit_to_integer =
50                String::from("let unit_to_integer = |other: &Self| match other {");
51
52            for (index, variant) in data.variants.iter().enumerate() {
53                let variant_attribute = TypeAttributeBuilder {
54                    enable_flag:  false,
55                    enable_bound: false,
56                    rank:         isize::MIN + index as isize,
57                    enable_rank:  true,
58                }
59                .from_attributes(&variant.attrs, traits);
60
61                let value = variant_attribute.rank;
62
63                if variant_values.contains(&value) {
64                    panic::reuse_a_value(value);
65                }
66
67                if value >= 0 {
68                    has_non_unit_or_custom_value = true;
69                }
70
71                let variant_ident = variant.ident.to_string();
72
73                match &variant.fields {
74                    Fields::Unit => {
75                        // TODO Unit
76                        unit_to_integer
77                            .write_fmt(format_args!(
78                                "{enum_name}::{variant_ident} => {enum_name}::{variant_ident} as \
79                                 isize,",
80                                enum_name = enum_name,
81                                variant_ident = variant_ident
82                            ))
83                            .unwrap();
84                        variant_to_integer
85                            .write_fmt(format_args!(
86                                "{enum_name}::{variant_ident} => {value},",
87                                enum_name = enum_name,
88                                variant_ident = variant_ident,
89                                value = value
90                            ))
91                            .unwrap();
92                    },
93                    Fields::Named(_) => {
94                        // TODO Struct
95                        has_non_unit_or_custom_value = true;
96
97                        variant_to_integer
98                            .write_fmt(format_args!(
99                                "{enum_name}::{variant_ident} {{ .. }} => {value},",
100                                enum_name = enum_name,
101                                variant_ident = variant_ident,
102                                value = value
103                            ))
104                            .unwrap();
105                    },
106                    Fields::Unnamed(fields) => {
107                        // TODO Tuple
108                        has_non_unit_or_custom_value = true;
109
110                        let mut pattern_tokens = String::new();
111
112                        for _ in fields.unnamed.iter() {
113                            pattern_tokens.push_str("_,");
114                        }
115
116                        variant_to_integer
117                            .write_fmt(format_args!(
118                                "{enum_name}::{variant_ident}( {pattern_tokens} ) => {value},",
119                                enum_name = enum_name,
120                                variant_ident = variant_ident,
121                                pattern_tokens = pattern_tokens,
122                                value = value
123                            ))
124                            .unwrap();
125                    },
126                }
127
128                variant_values.push(value);
129                variant_idents.push(variant_ident);
130                variants.push(variant);
131            }
132
133            if has_non_unit_or_custom_value {
134                variant_to_integer.push_str("};");
135
136                comparer_tokens.extend(TokenStream::from_str(&variant_to_integer).unwrap());
137
138                for (index, variant) in variants.into_iter().enumerate() {
139                    let variant_value = variant_values[index];
140                    let variant_ident = &variant_idents[index];
141
142                    match &variant.fields {
143                        Fields::Unit => {
144                            // TODO Unit
145                            match_tokens
146                                .write_fmt(format_args!(
147                                    "{enum_name}::{variant_ident} => {{ let other_value = \
148                                     variant_to_integer(other); return \
149                                     core::cmp::Ord::cmp(&{variant_value}, &other_value); }}",
150                                    enum_name = enum_name,
151                                    variant_ident = variant_ident,
152                                    variant_value = variant_value
153                                ))
154                                .unwrap();
155                        },
156                        Fields::Named(fields) => {
157                            // TODO Struct
158                            let mut pattern_tokens = String::new();
159                            let mut pattern_2_tokens = String::new();
160                            let mut block_tokens = String::new();
161
162                            let mut field_attributes = BTreeMap::new();
163                            let mut field_names = BTreeMap::new();
164
165                            for (index, field) in fields.named.iter().enumerate() {
166                                let field_attribute = FieldAttributeBuilder {
167                                    enable_ignore: true,
168                                    enable_impl:   true,
169                                    rank:          isize::MIN + index as isize,
170                                    enable_rank:   true,
171                                }
172                                .from_attributes(&field.attrs, traits);
173
174                                let field_name = field.ident.as_ref().unwrap().to_string();
175
176                                if field_attribute.ignore {
177                                    pattern_tokens
178                                        .write_fmt(format_args!(
179                                            "{field_name}: _,",
180                                            field_name = field_name
181                                        ))
182                                        .unwrap();
183                                    pattern_2_tokens
184                                        .write_fmt(format_args!(
185                                            "{field_name}: _,",
186                                            field_name = field_name
187                                        ))
188                                        .unwrap();
189                                    continue;
190                                }
191
192                                let rank = field_attribute.rank;
193
194                                if field_attributes.contains_key(&rank) {
195                                    panic::reuse_a_rank(rank);
196                                }
197
198                                pattern_tokens
199                                    .write_fmt(format_args!(
200                                        "{field_name},",
201                                        field_name = field_name
202                                    ))
203                                    .unwrap();
204                                pattern_2_tokens
205                                    .write_fmt(format_args!(
206                                        "{field_name}: ___{field_name},",
207                                        field_name = field_name
208                                    ))
209                                    .unwrap();
210
211                                field_attributes.insert(rank, field_attribute);
212                                field_names.insert(rank, field_name);
213                            }
214
215                            for (index, field_attribute) in field_attributes {
216                                let field_name = field_names.get(&index).unwrap();
217
218                                let compare_trait = field_attribute.compare_trait;
219                                let compare_method = field_attribute.compare_method;
220
221                                match compare_trait {
222                                    Some(compare_trait) => {
223                                        let compare_method = compare_method.unwrap();
224
225                                        block_tokens.write_fmt(format_args!("match {compare_trait}::{compare_method}({field_name}, ___{field_name}) {{ core::cmp::Ordering::Equal => (), core::cmp::Ordering::Greater => {{ return core::cmp::Ordering::Greater; }}, core::cmp::Ordering::Less => {{ return core::cmp::Ordering::Less; }} }}", compare_trait = compare_trait, compare_method = compare_method, field_name = field_name)).unwrap();
226                                    },
227                                    None => match compare_method {
228                                        Some(compare_method) => {
229                                            block_tokens
230                                                .write_fmt(format_args!(
231                                                    "match {compare_method}({field_name}, \
232                                                     ___{field_name}) {{ \
233                                                     core::cmp::Ordering::Equal => (), \
234                                                     core::cmp::Ordering::Greater => {{ return \
235                                                     core::cmp::Ordering::Greater; }}, \
236                                                     core::cmp::Ordering::Less => {{ return \
237                                                     core::cmp::Ordering::Less; }} }}",
238                                                    compare_method = compare_method,
239                                                    field_name = field_name
240                                                ))
241                                                .unwrap();
242                                        },
243                                        None => {
244                                            block_tokens
245                                                .write_fmt(format_args!(
246                                                    "match core::cmp::Ord::cmp({field_name}, \
247                                                     ___{field_name}) {{ \
248                                                     core::cmp::Ordering::Equal => (), \
249                                                     core::cmp::Ordering::Greater => {{ return \
250                                                     core::cmp::Ordering::Greater; }}, \
251                                                     core::cmp::Ordering::Less => {{ return \
252                                                     core::cmp::Ordering::Less; }} }}",
253                                                    field_name = field_name
254                                                ))
255                                                .unwrap();
256                                        },
257                                    },
258                                }
259                            }
260
261                            match_tokens
262                                .write_fmt(format_args!(
263                                    "{enum_name}::{variant_ident}{{ {pattern_tokens} }} => {{ if \
264                                     let {enum_name}::{variant_ident} {{ {pattern_2_tokens} }} = \
265                                     other {{ {block_tokens} }} else {{ let other_value = \
266                                     variant_to_integer(other); return \
267                                     core::cmp::Ord::cmp(&{variant_value}, &other_value); }} }}",
268                                    enum_name = enum_name,
269                                    variant_ident = variant_ident,
270                                    pattern_tokens = pattern_tokens,
271                                    pattern_2_tokens = pattern_2_tokens,
272                                    block_tokens = block_tokens,
273                                    variant_value = variant_value
274                                ))
275                                .unwrap();
276                        },
277                        Fields::Unnamed(fields) => {
278                            // TODO Tuple
279                            let mut pattern_tokens = String::new();
280                            let mut pattern_2_tokens = String::new();
281                            let mut block_tokens = String::new();
282
283                            let mut field_attributes = BTreeMap::new();
284                            let mut field_names = BTreeMap::new();
285
286                            for (index, field) in fields.unnamed.iter().enumerate() {
287                                let field_attribute = FieldAttributeBuilder {
288                                    enable_ignore: true,
289                                    enable_impl:   true,
290                                    rank:          isize::MIN + index as isize,
291                                    enable_rank:   true,
292                                }
293                                .from_attributes(&field.attrs, traits);
294
295                                let field_name = format!("{}", index);
296
297                                if field_attribute.ignore {
298                                    pattern_tokens.push_str("_,");
299                                    pattern_2_tokens.push_str("_,");
300                                    continue;
301                                }
302
303                                let rank = field_attribute.rank;
304
305                                if field_attributes.contains_key(&rank) {
306                                    panic::reuse_a_rank(rank);
307                                }
308
309                                pattern_tokens
310                                    .write_fmt(format_args!(
311                                        "_{field_name},",
312                                        field_name = field_name
313                                    ))
314                                    .unwrap();
315                                pattern_2_tokens
316                                    .write_fmt(format_args!(
317                                        "__{field_name},",
318                                        field_name = field_name
319                                    ))
320                                    .unwrap();
321
322                                field_attributes.insert(rank, field_attribute);
323                                field_names.insert(rank, field_name);
324                            }
325
326                            for (index, field_attribute) in field_attributes {
327                                let field_name = field_names.get(&index).unwrap();
328
329                                let compare_trait = field_attribute.compare_trait;
330                                let compare_method = field_attribute.compare_method;
331
332                                match compare_trait {
333                                    Some(compare_trait) => {
334                                        let compare_method = compare_method.unwrap();
335
336                                        block_tokens.write_fmt(format_args!("match {compare_trait}::{compare_method}(_{field_name}, __{field_name}) {{ core::cmp::Ordering::Equal => (), core::cmp::Ordering::Greater => {{ return core::cmp::Ordering::Greater; }}, core::cmp::Ordering::Less => {{ return core::cmp::Ordering::Less; }} }}", compare_trait = compare_trait, compare_method = compare_method, field_name = field_name)).unwrap();
337                                    },
338                                    None => match compare_method {
339                                        Some(compare_method) => {
340                                            block_tokens
341                                                .write_fmt(format_args!(
342                                                    "match {compare_method}(_{field_name}, \
343                                                     __{field_name}) {{ \
344                                                     core::cmp::Ordering::Equal => (), \
345                                                     core::cmp::Ordering::Greater => {{ return \
346                                                     core::cmp::Ordering::Greater; }}, \
347                                                     core::cmp::Ordering::Less => {{ return \
348                                                     core::cmp::Ordering::Less; }} }}",
349                                                    compare_method = compare_method,
350                                                    field_name = field_name
351                                                ))
352                                                .unwrap();
353                                        },
354                                        None => {
355                                            block_tokens
356                                                .write_fmt(format_args!(
357                                                    "match core::cmp::Ord::cmp(_{field_name}, \
358                                                     __{field_name}) {{ \
359                                                     core::cmp::Ordering::Equal => (), \
360                                                     core::cmp::Ordering::Greater => {{ return \
361                                                     core::cmp::Ordering::Greater; }}, \
362                                                     core::cmp::Ordering::Less => {{ return \
363                                                     core::cmp::Ordering::Less; }} }}",
364                                                    field_name = field_name
365                                                ))
366                                                .unwrap();
367                                        },
368                                    },
369                                }
370                            }
371
372                            match_tokens
373                                .write_fmt(format_args!(
374                                    "{enum_name}::{variant_ident}( {pattern_tokens} ) => {{ if \
375                                     let {enum_name}::{variant_ident} ( {pattern_2_tokens} ) = \
376                                     other {{ {block_tokens} }} else {{ let other_value = \
377                                     variant_to_integer(other); return \
378                                     core::cmp::Ord::cmp(&{variant_value}, &other_value); }} }}",
379                                    enum_name = enum_name,
380                                    variant_ident = variant_ident,
381                                    pattern_tokens = pattern_tokens,
382                                    pattern_2_tokens = pattern_2_tokens,
383                                    block_tokens = block_tokens,
384                                    variant_value = variant_value
385                                ))
386                                .unwrap();
387                        },
388                    }
389                }
390            } else {
391                unit_to_integer.push_str("};");
392
393                comparer_tokens.extend(TokenStream::from_str(&unit_to_integer).unwrap());
394
395                for (index, _) in variants.into_iter().enumerate() {
396                    let variant_ident = &variant_idents[index];
397
398                    match_tokens
399                        .write_fmt(format_args!(
400                            "{enum_name}::{variant_ident} => {{ let other_value = \
401                             unit_to_integer(other); return \
402                             core::cmp::Ord::cmp(&({enum_name}::{variant_ident} as isize), \
403                             &other_value); }}",
404                            enum_name = enum_name,
405                            variant_ident = variant_ident
406                        ))
407                        .unwrap();
408                }
409            }
410        }
411
412        match_tokens.push('}');
413
414        comparer_tokens.extend(TokenStream::from_str(&match_tokens).unwrap());
415
416        if has_non_unit_or_custom_value {
417            comparer_tokens.extend(quote!(core::cmp::Ordering::Equal));
418        }
419
420        let ident = &ast.ident;
421
422        let mut generics_cloned: Generics = ast.generics.clone();
423
424        let where_clause = generics_cloned.make_where_clause();
425
426        for where_predicate in bound {
427            where_clause.predicates.push(where_predicate);
428        }
429
430        let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
431
432        let compare_impl = quote! {
433            impl #impl_generics core::cmp::Ord for #ident #ty_generics #where_clause {
434                #[inline]
435                #[allow(unreachable_code, clippy::unneeded_field_pattern)]
436                fn cmp(&self, other: &Self) -> core::cmp::Ordering {
437                    #comparer_tokens
438                }
439            }
440        };
441
442        tokens.extend(compare_impl);
443    }
444}