// educe/trait_handlers/partial_ord/partial_ord_enum.rs

use std::{collections::BTreeMap, fmt::Write, str::FromStr};

use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Generics, Meta};

use super::{
    super::TraitHandler,
    models::{FieldAttributeBuilder, TypeAttributeBuilder},
};
use crate::{panic, Trait};

/// Generates `core::cmp::PartialOrd` implementations for enums via its `TraitHandler`
/// implementation. It carries no state; it only exists to host that impl.
pub struct PartialOrdEnumHandler;

15impl TraitHandler for PartialOrdEnumHandler {
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        tokens: &mut TokenStream,
19        traits: &[Trait],
20        meta: &Meta,
21    ) {
22        let type_attribute = TypeAttributeBuilder {
23            enable_flag:  true,
24            enable_bound: true,
25            rank:         0,
26            enable_rank:  false,
27        }
28        .from_partial_ord_meta(meta);
29
30        let enum_name = ast.ident.to_string();
31
32        let bound = type_attribute
33            .bound
34            .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
35
36        let mut comparer_tokens = TokenStream::new();
37
38        let mut match_tokens = String::from("match self {");
39
40        let mut has_non_unit_or_custom_value = false;
41
42        if let Data::Enum(data) = &ast.data {
43            let mut variant_values = Vec::new();
44            let mut variant_idents = Vec::new();
45            let mut variants = Vec::new();
46
47            let mut variant_to_integer =
48                String::from("let variant_to_integer = |other: &Self| match other {");
49            let mut unit_to_integer =
50                String::from("let unit_to_integer = |other: &Self| match other {");
51
52            for (index, variant) in data.variants.iter().enumerate() {
53                let variant_attribute = TypeAttributeBuilder {
54                    enable_flag:  false,
55                    enable_bound: false,
56                    rank:         isize::MIN + index as isize,
57                    enable_rank:  true,
58                }
59                .from_attributes(&variant.attrs, traits);
60
61                let value = variant_attribute.rank;
62
63                if variant_values.contains(&value) {
64                    panic::reuse_a_value(value);
65                }
66
67                if value >= 0 {
68                    has_non_unit_or_custom_value = true;
69                }
70
71                let variant_ident = variant.ident.to_string();
72
73                match &variant.fields {
74                    Fields::Unit => {
75                        // TODO Unit
76                        unit_to_integer
77                            .write_fmt(format_args!(
78                                "{enum_name}::{variant_ident} => {enum_name}::{variant_ident} as \
79                                 isize,",
80                                enum_name = enum_name,
81                                variant_ident = variant_ident
82                            ))
83                            .unwrap();
84                        variant_to_integer
85                            .write_fmt(format_args!(
86                                "{enum_name}::{variant_ident} => {value},",
87                                enum_name = enum_name,
88                                variant_ident = variant_ident,
89                                value = value
90                            ))
91                            .unwrap();
92                    },
93                    Fields::Named(_) => {
94                        // TODO Struct
95                        has_non_unit_or_custom_value = true;
96
97                        variant_to_integer
98                            .write_fmt(format_args!(
99                                "{enum_name}::{variant_ident} {{ .. }} => {value},",
100                                enum_name = enum_name,
101                                variant_ident = variant_ident,
102                                value = value
103                            ))
104                            .unwrap();
105                    },
106                    Fields::Unnamed(fields) => {
107                        // TODO Tuple
108                        has_non_unit_or_custom_value = true;
109
110                        let mut pattern_tokens = String::new();
111
112                        for _ in fields.unnamed.iter() {
113                            pattern_tokens.push_str("_,");
114                        }
115
116                        variant_to_integer
117                            .write_fmt(format_args!(
118                                "{enum_name}::{variant_ident}( {pattern_tokens} ) => {value},",
119                                enum_name = enum_name,
120                                variant_ident = variant_ident,
121                                pattern_tokens = pattern_tokens,
122                                value = value
123                            ))
124                            .unwrap();
125                    },
126                }
127
128                variant_values.push(value);
129                variant_idents.push(variant_ident);
130                variants.push(variant);
131            }
132
133            if has_non_unit_or_custom_value {
134                variant_to_integer.push_str("};");
135
136                comparer_tokens.extend(TokenStream::from_str(&variant_to_integer).unwrap());
137
138                for (index, variant) in variants.into_iter().enumerate() {
139                    let variant_value = variant_values[index];
140                    let variant_ident = &variant_idents[index];
141
142                    match &variant.fields {
143                        Fields::Unit => {
144                            // TODO Unit
145                            match_tokens
146                                .write_fmt(format_args!(
147                                    "{enum_name}::{variant_ident} => {{ let other_value = \
148                                     variant_to_integer(other); return \
149                                     core::cmp::PartialOrd::partial_cmp(&{variant_value}, \
150                                     &other_value); }}",
151                                    enum_name = enum_name,
152                                    variant_ident = variant_ident,
153                                    variant_value = variant_value
154                                ))
155                                .unwrap();
156                        },
157                        Fields::Named(fields) => {
158                            // TODO Struct
159                            let mut pattern_tokens = String::new();
160                            let mut pattern_2_tokens = String::new();
161                            let mut block_tokens = String::new();
162
163                            let mut field_attributes = BTreeMap::new();
164                            let mut field_names = BTreeMap::new();
165
166                            for (index, field) in fields.named.iter().enumerate() {
167                                let field_attribute = FieldAttributeBuilder {
168                                    enable_ignore: true,
169                                    enable_impl:   true,
170                                    rank:          isize::MIN + index as isize,
171                                    enable_rank:   true,
172                                }
173                                .from_attributes(&field.attrs, traits);
174
175                                let field_name = field.ident.as_ref().unwrap().to_string();
176
177                                if field_attribute.ignore {
178                                    pattern_tokens
179                                        .write_fmt(format_args!(
180                                            "{field_name}: _,",
181                                            field_name = field_name
182                                        ))
183                                        .unwrap();
184                                    pattern_2_tokens
185                                        .write_fmt(format_args!(
186                                            "{field_name}: _,",
187                                            field_name = field_name
188                                        ))
189                                        .unwrap();
190                                    continue;
191                                }
192
193                                let rank = field_attribute.rank;
194
195                                if field_attributes.contains_key(&rank) {
196                                    panic::reuse_a_rank(rank);
197                                }
198
199                                pattern_tokens
200                                    .write_fmt(format_args!(
201                                        "{field_name},",
202                                        field_name = field_name
203                                    ))
204                                    .unwrap();
205                                pattern_2_tokens
206                                    .write_fmt(format_args!(
207                                        "{field_name}: ___{field_name},",
208                                        field_name = field_name
209                                    ))
210                                    .unwrap();
211
212                                field_attributes.insert(rank, field_attribute);
213                                field_names.insert(rank, field_name);
214                            }
215
216                            for (index, field_attribute) in field_attributes {
217                                let field_name = field_names.get(&index).unwrap();
218
219                                let compare_trait = field_attribute.compare_trait;
220                                let compare_method = field_attribute.compare_method;
221
222                                match compare_trait {
223                                    Some(compare_trait) => {
224                                        let compare_method = compare_method.unwrap();
225
226                                        block_tokens.write_fmt(format_args!("match {compare_trait}::{compare_method}({field_name}, ___{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), Some(core::cmp::Ordering::Greater) => {{ return Some(core::cmp::Ordering::Greater); }}, Some(core::cmp::Ordering::Less) => {{ return Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}", compare_trait = compare_trait, compare_method = compare_method, field_name = field_name)).unwrap();
227                                    },
228                                    None => {
229                                        match compare_method {
230                                            Some(compare_method) => {
231                                                block_tokens.write_fmt(format_args!("match {compare_method}({field_name}, ___{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), Some(core::cmp::Ordering::Greater) => {{ return Some(core::cmp::Ordering::Greater); }}, Some(core::cmp::Ordering::Less) => {{ return Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}", compare_method = compare_method, field_name = field_name)).unwrap();
232                                            },
233                                            None => {
234                                                block_tokens.write_fmt(format_args!("match core::cmp::PartialOrd::partial_cmp({field_name}, ___{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), Some(core::cmp::Ordering::Greater) => {{ return Some(core::cmp::Ordering::Greater); }}, Some(core::cmp::Ordering::Less) => {{ return Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}", field_name = field_name)).unwrap();
235                                            },
236                                        }
237                                    },
238                                }
239                            }
240
241                            match_tokens
242                                .write_fmt(format_args!(
243                                    "{enum_name}::{variant_ident}{{ {pattern_tokens} }} => {{ if \
244                                     let {enum_name}::{variant_ident} {{ {pattern_2_tokens} }} = \
245                                     other {{ {block_tokens} }} else {{ let other_value = \
246                                     variant_to_integer(other); return \
247                                     core::cmp::PartialOrd::partial_cmp(&{variant_value}, \
248                                     &other_value); }} }}",
249                                    enum_name = enum_name,
250                                    variant_ident = variant_ident,
251                                    pattern_tokens = pattern_tokens,
252                                    pattern_2_tokens = pattern_2_tokens,
253                                    block_tokens = block_tokens,
254                                    variant_value = variant_value
255                                ))
256                                .unwrap();
257                        },
258                        Fields::Unnamed(fields) => {
259                            // TODO Tuple
260                            let mut pattern_tokens = String::new();
261                            let mut pattern_2_tokens = String::new();
262                            let mut block_tokens = String::new();
263
264                            let mut field_attributes = BTreeMap::new();
265                            let mut field_names = BTreeMap::new();
266
267                            for (index, field) in fields.unnamed.iter().enumerate() {
268                                let field_attribute = FieldAttributeBuilder {
269                                    enable_ignore: true,
270                                    enable_impl:   true,
271                                    rank:          isize::MIN + index as isize,
272                                    enable_rank:   true,
273                                }
274                                .from_attributes(&field.attrs, traits);
275
276                                let field_name = format!("{}", index);
277
278                                if field_attribute.ignore {
279                                    pattern_tokens.push_str("_,");
280                                    pattern_2_tokens.push_str("_,");
281                                    continue;
282                                }
283
284                                let rank = field_attribute.rank;
285
286                                if field_attributes.contains_key(&rank) {
287                                    panic::reuse_a_rank(rank);
288                                }
289
290                                pattern_tokens
291                                    .write_fmt(format_args!(
292                                        "_{field_name},",
293                                        field_name = field_name
294                                    ))
295                                    .unwrap();
296                                pattern_2_tokens
297                                    .write_fmt(format_args!(
298                                        "__{field_name},",
299                                        field_name = field_name
300                                    ))
301                                    .unwrap();
302
303                                field_attributes.insert(rank, field_attribute);
304                                field_names.insert(rank, field_name);
305                            }
306
307                            for (index, field_attribute) in field_attributes {
308                                let field_name = field_names.get(&index).unwrap();
309
310                                let compare_trait = field_attribute.compare_trait;
311                                let compare_method = field_attribute.compare_method;
312
313                                match compare_trait {
314                                    Some(compare_trait) => {
315                                        let compare_method = compare_method.unwrap();
316
317                                        block_tokens.write_fmt(format_args!("match {compare_trait}::{compare_method}(_{field_name}, __{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), Some(core::cmp::Ordering::Greater) => {{ return Some(core::cmp::Ordering::Greater); }}, Some(core::cmp::Ordering::Less) => {{ return Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}", compare_trait = compare_trait, compare_method = compare_method, field_name = field_name)).unwrap();
318                                    },
319                                    None => {
320                                        match compare_method {
321                                            Some(compare_method) => {
322                                                block_tokens.write_fmt(format_args!("match {compare_method}(_{field_name}, __{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), Some(core::cmp::Ordering::Greater) => {{ return Some(core::cmp::Ordering::Greater); }}, Some(core::cmp::Ordering::Less) => {{ return Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}", compare_method = compare_method, field_name = field_name)).unwrap();
323                                            },
324                                            None => {
325                                                block_tokens.write_fmt(format_args!("match core::cmp::PartialOrd::partial_cmp(_{field_name}, __{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), Some(core::cmp::Ordering::Greater) => {{ return Some(core::cmp::Ordering::Greater); }}, Some(core::cmp::Ordering::Less) => {{ return Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}", field_name = field_name)).unwrap();
326                                            },
327                                        }
328                                    },
329                                }
330                            }
331
332                            match_tokens
333                                .write_fmt(format_args!(
334                                    "{enum_name}::{variant_ident}( {pattern_tokens} ) => {{ if \
335                                     let {enum_name}::{variant_ident} ( {pattern_2_tokens} ) = \
336                                     other {{ {block_tokens} }} else {{ let other_value = \
337                                     variant_to_integer(other); return \
338                                     core::cmp::PartialOrd::partial_cmp(&{variant_value}, \
339                                     &other_value); }} }}",
340                                    enum_name = enum_name,
341                                    variant_ident = variant_ident,
342                                    pattern_tokens = pattern_tokens,
343                                    pattern_2_tokens = pattern_2_tokens,
344                                    block_tokens = block_tokens,
345                                    variant_value = variant_value
346                                ))
347                                .unwrap();
348                        },
349                    }
350                }
351            } else {
352                unit_to_integer.push_str("};");
353
354                comparer_tokens.extend(TokenStream::from_str(&unit_to_integer).unwrap());
355
356                for (index, _) in variants.into_iter().enumerate() {
357                    let variant_ident = &variant_idents[index];
358
359                    match_tokens
360                        .write_fmt(format_args!(
361                            "{enum_name}::{variant_ident} => {{ let other_value = \
362                             unit_to_integer(other); return \
363                             core::cmp::PartialOrd::partial_cmp(&({enum_name}::{variant_ident} as \
364                             isize), &other_value); }}",
365                            enum_name = enum_name,
366                            variant_ident = variant_ident
367                        ))
368                        .unwrap();
369                }
370            }
371        }
372
373        match_tokens.push('}');
374
375        comparer_tokens.extend(TokenStream::from_str(&match_tokens).unwrap());
376
377        if has_non_unit_or_custom_value {
378            comparer_tokens.extend(quote!(Some(core::cmp::Ordering::Equal)));
379        }
380
381        let ident = &ast.ident;
382
383        let mut generics_cloned: Generics = ast.generics.clone();
384
385        let where_clause = generics_cloned.make_where_clause();
386
387        for where_predicate in bound {
388            where_clause.predicates.push(where_predicate);
389        }
390
391        let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
392
393        let compare_impl = quote! {
394            impl #impl_generics core::cmp::PartialOrd for #ident #ty_generics #where_clause {
395                #[inline]
396                #[allow(unreachable_code, clippy::unneeded_field_pattern)]
397                fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
398                    #comparer_tokens
399                }
400            }
401        };
402
403        tokens.extend(compare_impl);
404    }
405}