educe/trait_handlers/ord/
ord_struct.rs

1use std::{collections::BTreeMap, str::FromStr};
2
3use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{Data, DeriveInput, Generics, Meta};
6
7use super::{
8    super::TraitHandler,
9    models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::{panic, Trait};
12
13pub struct OrdStructHandler;
14
15impl TraitHandler for OrdStructHandler {
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        tokens: &mut TokenStream,
19        traits: &[Trait],
20        meta: &Meta,
21    ) {
22        let type_attribute = TypeAttributeBuilder {
23            enable_flag:  true,
24            enable_bound: true,
25            rank:         0,
26            enable_rank:  false,
27        }
28        .from_ord_meta(meta);
29
30        let bound = type_attribute
31            .bound
32            .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
33
34        let mut comparer_tokens = TokenStream::new();
35
36        if let Data::Struct(data) = &ast.data {
37            let mut field_attributes = BTreeMap::new();
38            let mut field_names = BTreeMap::new();
39
40            for (index, field) in data.fields.iter().enumerate() {
41                let field_attribute = FieldAttributeBuilder {
42                    enable_ignore: true,
43                    enable_impl:   true,
44                    rank:          isize::MIN + index as isize,
45                    enable_rank:   true,
46                }
47                .from_attributes(&field.attrs, traits);
48
49                if field_attribute.ignore {
50                    continue;
51                }
52
53                let rank = field_attribute.rank;
54
55                if field_attributes.contains_key(&rank) {
56                    panic::reuse_a_rank(rank);
57                }
58
59                let field_name = if let Some(ident) = field.ident.as_ref() {
60                    ident.to_string()
61                } else {
62                    format!("{}", index)
63                };
64
65                field_attributes.insert(rank, field_attribute);
66                field_names.insert(rank, field_name);
67            }
68
69            for (index, field_attribute) in field_attributes {
70                let field_name = field_names.get(&index).unwrap();
71
72                let compare_trait = field_attribute.compare_trait;
73                let compare_method = field_attribute.compare_method;
74
75                match compare_trait {
76                    Some(compare_trait) => {
77                        let compare_method = compare_method.unwrap();
78
79                        let statement = format!(
80                            "match {compare_trait}::{compare_method}(&self.{field_name}, \
81                             &other.{field_name}) {{ core::cmp::Ordering::Equal => (), \
82                             core::cmp::Ordering::Greater => {{ return \
83                             core::cmp::Ordering::Greater; }}, core::cmp::Ordering::Less => {{ \
84                             return core::cmp::Ordering::Less; }} }}",
85                            compare_trait = compare_trait,
86                            compare_method = compare_method,
87                            field_name = field_name
88                        );
89
90                        comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
91                    },
92                    None => match compare_method {
93                        Some(compare_method) => {
94                            let statement = format!(
95                                "match {compare_method}(&self.{field_name}, &other.{field_name}) \
96                                 {{ core::cmp::Ordering::Equal => (), \
97                                 core::cmp::Ordering::Greater => {{ return \
98                                 core::cmp::Ordering::Greater; }}, core::cmp::Ordering::Less => \
99                                 {{ return core::cmp::Ordering::Less; }} }}",
100                                compare_method = compare_method,
101                                field_name = field_name
102                            );
103
104                            comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
105                        },
106                        None => {
107                            let statement = format!(
108                                "match core::cmp::Ord::cmp(&self.{field_name}, \
109                                 &other.{field_name}) {{ core::cmp::Ordering::Equal => (), \
110                                 core::cmp::Ordering::Greater => {{ return \
111                                 core::cmp::Ordering::Greater; }}, core::cmp::Ordering::Less => \
112                                 {{ return core::cmp::Ordering::Less; }} }}",
113                                field_name = field_name
114                            );
115
116                            comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
117                        },
118                    },
119                }
120            }
121        }
122
123        let ident = &ast.ident;
124
125        let mut generics_cloned: Generics = ast.generics.clone();
126
127        let where_clause = generics_cloned.make_where_clause();
128
129        for where_predicate in bound {
130            where_clause.predicates.push(where_predicate);
131        }
132
133        let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
134
135        let compare_impl = quote! {
136            impl #impl_generics core::cmp::Ord for #ident #ty_generics #where_clause {
137                #[inline]
138                fn cmp(&self, other: &Self) -> core::cmp::Ordering {
139                    #comparer_tokens
140
141                    core::cmp::Ordering::Equal
142                }
143            }
144        };
145
146        tokens.extend(compare_impl);
147    }
148}