educe/trait_handlers/ord/
ord_struct.rs1use std::{collections::BTreeMap, str::FromStr};
2
3use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{Data, DeriveInput, Generics, Meta};
6
7use super::{
8 super::TraitHandler,
9 models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::{panic, Trait};
12
13pub struct OrdStructHandler;
14
15impl TraitHandler for OrdStructHandler {
16 fn trait_meta_handler(
17 ast: &DeriveInput,
18 tokens: &mut TokenStream,
19 traits: &[Trait],
20 meta: &Meta,
21 ) {
22 let type_attribute = TypeAttributeBuilder {
23 enable_flag: true,
24 enable_bound: true,
25 rank: 0,
26 enable_rank: false,
27 }
28 .from_ord_meta(meta);
29
30 let bound = type_attribute
31 .bound
32 .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
33
34 let mut comparer_tokens = TokenStream::new();
35
36 if let Data::Struct(data) = &ast.data {
37 let mut field_attributes = BTreeMap::new();
38 let mut field_names = BTreeMap::new();
39
40 for (index, field) in data.fields.iter().enumerate() {
41 let field_attribute = FieldAttributeBuilder {
42 enable_ignore: true,
43 enable_impl: true,
44 rank: isize::MIN + index as isize,
45 enable_rank: true,
46 }
47 .from_attributes(&field.attrs, traits);
48
49 if field_attribute.ignore {
50 continue;
51 }
52
53 let rank = field_attribute.rank;
54
55 if field_attributes.contains_key(&rank) {
56 panic::reuse_a_rank(rank);
57 }
58
59 let field_name = if let Some(ident) = field.ident.as_ref() {
60 ident.to_string()
61 } else {
62 format!("{}", index)
63 };
64
65 field_attributes.insert(rank, field_attribute);
66 field_names.insert(rank, field_name);
67 }
68
69 for (index, field_attribute) in field_attributes {
70 let field_name = field_names.get(&index).unwrap();
71
72 let compare_trait = field_attribute.compare_trait;
73 let compare_method = field_attribute.compare_method;
74
75 match compare_trait {
76 Some(compare_trait) => {
77 let compare_method = compare_method.unwrap();
78
79 let statement = format!(
80 "match {compare_trait}::{compare_method}(&self.{field_name}, \
81 &other.{field_name}) {{ core::cmp::Ordering::Equal => (), \
82 core::cmp::Ordering::Greater => {{ return \
83 core::cmp::Ordering::Greater; }}, core::cmp::Ordering::Less => {{ \
84 return core::cmp::Ordering::Less; }} }}",
85 compare_trait = compare_trait,
86 compare_method = compare_method,
87 field_name = field_name
88 );
89
90 comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
91 },
92 None => match compare_method {
93 Some(compare_method) => {
94 let statement = format!(
95 "match {compare_method}(&self.{field_name}, &other.{field_name}) \
96 {{ core::cmp::Ordering::Equal => (), \
97 core::cmp::Ordering::Greater => {{ return \
98 core::cmp::Ordering::Greater; }}, core::cmp::Ordering::Less => \
99 {{ return core::cmp::Ordering::Less; }} }}",
100 compare_method = compare_method,
101 field_name = field_name
102 );
103
104 comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
105 },
106 None => {
107 let statement = format!(
108 "match core::cmp::Ord::cmp(&self.{field_name}, \
109 &other.{field_name}) {{ core::cmp::Ordering::Equal => (), \
110 core::cmp::Ordering::Greater => {{ return \
111 core::cmp::Ordering::Greater; }}, core::cmp::Ordering::Less => \
112 {{ return core::cmp::Ordering::Less; }} }}",
113 field_name = field_name
114 );
115
116 comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
117 },
118 },
119 }
120 }
121 }
122
123 let ident = &ast.ident;
124
125 let mut generics_cloned: Generics = ast.generics.clone();
126
127 let where_clause = generics_cloned.make_where_clause();
128
129 for where_predicate in bound {
130 where_clause.predicates.push(where_predicate);
131 }
132
133 let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
134
135 let compare_impl = quote! {
136 impl #impl_generics core::cmp::Ord for #ident #ty_generics #where_clause {
137 #[inline]
138 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
139 #comparer_tokens
140
141 core::cmp::Ordering::Equal
142 }
143 }
144 };
145
146 tokens.extend(compare_impl);
147 }
148}