educe/trait_handlers/partial_ord/partial_ord_struct.rs

1use std::{collections::BTreeMap, str::FromStr};
2
3use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{Data, DeriveInput, Generics, Meta};
6
7use super::{
8    super::TraitHandler,
9    models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::{panic, Trait};
12
13pub struct PartialOrdStructHandler;
14
15impl TraitHandler for PartialOrdStructHandler {
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        tokens: &mut TokenStream,
19        traits: &[Trait],
20        meta: &Meta,
21    ) {
22        let type_attribute = TypeAttributeBuilder {
23            enable_flag:  true,
24            enable_bound: true,
25            rank:         0,
26            enable_rank:  false,
27        }
28        .from_partial_ord_meta(meta);
29
30        let bound = type_attribute
31            .bound
32            .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
33
34        let mut comparer_tokens = TokenStream::new();
35
36        if let Data::Struct(data) = &ast.data {
37            let mut field_attributes = BTreeMap::new();
38            let mut field_names = BTreeMap::new();
39
40            for (index, field) in data.fields.iter().enumerate() {
41                let field_attribute = FieldAttributeBuilder {
42                    enable_ignore: true,
43                    enable_impl:   true,
44                    rank:          isize::MIN + index as isize,
45                    enable_rank:   true,
46                }
47                .from_attributes(&field.attrs, traits);
48
49                if field_attribute.ignore {
50                    continue;
51                }
52
53                let rank = field_attribute.rank;
54
55                if field_attributes.contains_key(&rank) {
56                    panic::reuse_a_rank(rank);
57                }
58
59                let field_name = if let Some(ident) = field.ident.as_ref() {
60                    ident.to_string()
61                } else {
62                    format!("{}", index)
63                };
64
65                field_attributes.insert(rank, field_attribute);
66                field_names.insert(rank, field_name);
67            }
68
69            for (index, field_attribute) in field_attributes {
70                let field_name = field_names.get(&index).unwrap();
71
72                let compare_trait = field_attribute.compare_trait;
73                let compare_method = field_attribute.compare_method;
74
75                match compare_trait {
76                    Some(compare_trait) => {
77                        let compare_method = compare_method.unwrap();
78
79                        let statement = format!(
80                            "match {compare_trait}::{compare_method}(&self.{field_name}, \
81                             &other.{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), \
82                             Some(core::cmp::Ordering::Greater) => {{ return \
83                             Some(core::cmp::Ordering::Greater); }}, \
84                             Some(core::cmp::Ordering::Less) => {{ return \
85                             Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}",
86                            compare_trait = compare_trait,
87                            compare_method = compare_method,
88                            field_name = field_name
89                        );
90
91                        comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
92                    },
93                    None => match compare_method {
94                        Some(compare_method) => {
95                            let statement = format!(
96                                "match {compare_method}(&self.{field_name}, &other.{field_name}) \
97                                 {{ Some(core::cmp::Ordering::Equal) => (), \
98                                 Some(core::cmp::Ordering::Greater) => {{ return \
99                                 Some(core::cmp::Ordering::Greater); }}, \
100                                 Some(core::cmp::Ordering::Less) => {{ return \
101                                 Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} \
102                                 }}",
103                                compare_method = compare_method,
104                                field_name = field_name
105                            );
106
107                            comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
108                        },
109                        None => {
110                            let statement = format!(
111                                "match core::cmp::PartialOrd::partial_cmp(&self.{field_name}, \
112                                 &other.{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), \
113                                 Some(core::cmp::Ordering::Greater) => {{ return \
114                                 Some(core::cmp::Ordering::Greater); }}, \
115                                 Some(core::cmp::Ordering::Less) => {{ return \
116                                 Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} \
117                                 }}",
118                                field_name = field_name
119                            );
120
121                            comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
122                        },
123                    },
124                }
125            }
126        }
127
128        let ident = &ast.ident;
129
130        let mut generics_cloned: Generics = ast.generics.clone();
131
132        let where_clause = generics_cloned.make_where_clause();
133
134        for where_predicate in bound {
135            where_clause.predicates.push(where_predicate);
136        }
137
138        let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
139
140        let compare_impl = quote! {
141            impl #impl_generics core::cmp::PartialOrd for #ident #ty_generics #where_clause {
142                #[inline]
143                fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
144                    #comparer_tokens
145
146                    Some(core::cmp::Ordering::Equal)
147                }
148            }
149        };
150
151        tokens.extend(compare_impl);
152    }
153}