educe/trait_handlers/ord/
ord_enum.rs1use std::{collections::BTreeMap, fmt::Write, str::FromStr};
2
3use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{Data, DeriveInput, Fields, Generics, Meta};
6
7use super::{
8 super::TraitHandler,
9 models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::{panic, Trait};
12
13pub struct OrdEnumHandler;
14
15impl TraitHandler for OrdEnumHandler {
16 fn trait_meta_handler(
17 ast: &DeriveInput,
18 tokens: &mut TokenStream,
19 traits: &[Trait],
20 meta: &Meta,
21 ) {
22 let type_attribute = TypeAttributeBuilder {
23 enable_flag: true,
24 enable_bound: true,
25 rank: 0,
26 enable_rank: false,
27 }
28 .from_ord_meta(meta);
29
30 let enum_name = ast.ident.to_string();
31
32 let bound = type_attribute
33 .bound
34 .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
35
36 let mut comparer_tokens = TokenStream::new();
37
38 let mut match_tokens = String::from("match self {");
39
40 let mut has_non_unit_or_custom_value = false;
41
42 if let Data::Enum(data) = &ast.data {
43 let mut variant_values = Vec::new();
44 let mut variant_idents = Vec::new();
45 let mut variants = Vec::new();
46
47 let mut variant_to_integer =
48 String::from("let variant_to_integer = |other: &Self| match other {");
49 let mut unit_to_integer =
50 String::from("let unit_to_integer = |other: &Self| match other {");
51
52 for (index, variant) in data.variants.iter().enumerate() {
53 let variant_attribute = TypeAttributeBuilder {
54 enable_flag: false,
55 enable_bound: false,
56 rank: isize::MIN + index as isize,
57 enable_rank: true,
58 }
59 .from_attributes(&variant.attrs, traits);
60
61 let value = variant_attribute.rank;
62
63 if variant_values.contains(&value) {
64 panic::reuse_a_value(value);
65 }
66
67 if value >= 0 {
68 has_non_unit_or_custom_value = true;
69 }
70
71 let variant_ident = variant.ident.to_string();
72
73 match &variant.fields {
74 Fields::Unit => {
75 unit_to_integer
77 .write_fmt(format_args!(
78 "{enum_name}::{variant_ident} => {enum_name}::{variant_ident} as \
79 isize,",
80 enum_name = enum_name,
81 variant_ident = variant_ident
82 ))
83 .unwrap();
84 variant_to_integer
85 .write_fmt(format_args!(
86 "{enum_name}::{variant_ident} => {value},",
87 enum_name = enum_name,
88 variant_ident = variant_ident,
89 value = value
90 ))
91 .unwrap();
92 },
93 Fields::Named(_) => {
94 has_non_unit_or_custom_value = true;
96
97 variant_to_integer
98 .write_fmt(format_args!(
99 "{enum_name}::{variant_ident} {{ .. }} => {value},",
100 enum_name = enum_name,
101 variant_ident = variant_ident,
102 value = value
103 ))
104 .unwrap();
105 },
106 Fields::Unnamed(fields) => {
107 has_non_unit_or_custom_value = true;
109
110 let mut pattern_tokens = String::new();
111
112 for _ in fields.unnamed.iter() {
113 pattern_tokens.push_str("_,");
114 }
115
116 variant_to_integer
117 .write_fmt(format_args!(
118 "{enum_name}::{variant_ident}( {pattern_tokens} ) => {value},",
119 enum_name = enum_name,
120 variant_ident = variant_ident,
121 pattern_tokens = pattern_tokens,
122 value = value
123 ))
124 .unwrap();
125 },
126 }
127
128 variant_values.push(value);
129 variant_idents.push(variant_ident);
130 variants.push(variant);
131 }
132
133 if has_non_unit_or_custom_value {
134 variant_to_integer.push_str("};");
135
136 comparer_tokens.extend(TokenStream::from_str(&variant_to_integer).unwrap());
137
138 for (index, variant) in variants.into_iter().enumerate() {
139 let variant_value = variant_values[index];
140 let variant_ident = &variant_idents[index];
141
142 match &variant.fields {
143 Fields::Unit => {
144 match_tokens
146 .write_fmt(format_args!(
147 "{enum_name}::{variant_ident} => {{ let other_value = \
148 variant_to_integer(other); return \
149 core::cmp::Ord::cmp(&{variant_value}, &other_value); }}",
150 enum_name = enum_name,
151 variant_ident = variant_ident,
152 variant_value = variant_value
153 ))
154 .unwrap();
155 },
156 Fields::Named(fields) => {
157 let mut pattern_tokens = String::new();
159 let mut pattern_2_tokens = String::new();
160 let mut block_tokens = String::new();
161
162 let mut field_attributes = BTreeMap::new();
163 let mut field_names = BTreeMap::new();
164
165 for (index, field) in fields.named.iter().enumerate() {
166 let field_attribute = FieldAttributeBuilder {
167 enable_ignore: true,
168 enable_impl: true,
169 rank: isize::MIN + index as isize,
170 enable_rank: true,
171 }
172 .from_attributes(&field.attrs, traits);
173
174 let field_name = field.ident.as_ref().unwrap().to_string();
175
176 if field_attribute.ignore {
177 pattern_tokens
178 .write_fmt(format_args!(
179 "{field_name}: _,",
180 field_name = field_name
181 ))
182 .unwrap();
183 pattern_2_tokens
184 .write_fmt(format_args!(
185 "{field_name}: _,",
186 field_name = field_name
187 ))
188 .unwrap();
189 continue;
190 }
191
192 let rank = field_attribute.rank;
193
194 if field_attributes.contains_key(&rank) {
195 panic::reuse_a_rank(rank);
196 }
197
198 pattern_tokens
199 .write_fmt(format_args!(
200 "{field_name},",
201 field_name = field_name
202 ))
203 .unwrap();
204 pattern_2_tokens
205 .write_fmt(format_args!(
206 "{field_name}: ___{field_name},",
207 field_name = field_name
208 ))
209 .unwrap();
210
211 field_attributes.insert(rank, field_attribute);
212 field_names.insert(rank, field_name);
213 }
214
215 for (index, field_attribute) in field_attributes {
216 let field_name = field_names.get(&index).unwrap();
217
218 let compare_trait = field_attribute.compare_trait;
219 let compare_method = field_attribute.compare_method;
220
221 match compare_trait {
222 Some(compare_trait) => {
223 let compare_method = compare_method.unwrap();
224
225 block_tokens.write_fmt(format_args!("match {compare_trait}::{compare_method}({field_name}, ___{field_name}) {{ core::cmp::Ordering::Equal => (), core::cmp::Ordering::Greater => {{ return core::cmp::Ordering::Greater; }}, core::cmp::Ordering::Less => {{ return core::cmp::Ordering::Less; }} }}", compare_trait = compare_trait, compare_method = compare_method, field_name = field_name)).unwrap();
226 },
227 None => match compare_method {
228 Some(compare_method) => {
229 block_tokens
230 .write_fmt(format_args!(
231 "match {compare_method}({field_name}, \
232 ___{field_name}) {{ \
233 core::cmp::Ordering::Equal => (), \
234 core::cmp::Ordering::Greater => {{ return \
235 core::cmp::Ordering::Greater; }}, \
236 core::cmp::Ordering::Less => {{ return \
237 core::cmp::Ordering::Less; }} }}",
238 compare_method = compare_method,
239 field_name = field_name
240 ))
241 .unwrap();
242 },
243 None => {
244 block_tokens
245 .write_fmt(format_args!(
246 "match core::cmp::Ord::cmp({field_name}, \
247 ___{field_name}) {{ \
248 core::cmp::Ordering::Equal => (), \
249 core::cmp::Ordering::Greater => {{ return \
250 core::cmp::Ordering::Greater; }}, \
251 core::cmp::Ordering::Less => {{ return \
252 core::cmp::Ordering::Less; }} }}",
253 field_name = field_name
254 ))
255 .unwrap();
256 },
257 },
258 }
259 }
260
261 match_tokens
262 .write_fmt(format_args!(
263 "{enum_name}::{variant_ident}{{ {pattern_tokens} }} => {{ if \
264 let {enum_name}::{variant_ident} {{ {pattern_2_tokens} }} = \
265 other {{ {block_tokens} }} else {{ let other_value = \
266 variant_to_integer(other); return \
267 core::cmp::Ord::cmp(&{variant_value}, &other_value); }} }}",
268 enum_name = enum_name,
269 variant_ident = variant_ident,
270 pattern_tokens = pattern_tokens,
271 pattern_2_tokens = pattern_2_tokens,
272 block_tokens = block_tokens,
273 variant_value = variant_value
274 ))
275 .unwrap();
276 },
277 Fields::Unnamed(fields) => {
278 let mut pattern_tokens = String::new();
280 let mut pattern_2_tokens = String::new();
281 let mut block_tokens = String::new();
282
283 let mut field_attributes = BTreeMap::new();
284 let mut field_names = BTreeMap::new();
285
286 for (index, field) in fields.unnamed.iter().enumerate() {
287 let field_attribute = FieldAttributeBuilder {
288 enable_ignore: true,
289 enable_impl: true,
290 rank: isize::MIN + index as isize,
291 enable_rank: true,
292 }
293 .from_attributes(&field.attrs, traits);
294
295 let field_name = format!("{}", index);
296
297 if field_attribute.ignore {
298 pattern_tokens.push_str("_,");
299 pattern_2_tokens.push_str("_,");
300 continue;
301 }
302
303 let rank = field_attribute.rank;
304
305 if field_attributes.contains_key(&rank) {
306 panic::reuse_a_rank(rank);
307 }
308
309 pattern_tokens
310 .write_fmt(format_args!(
311 "_{field_name},",
312 field_name = field_name
313 ))
314 .unwrap();
315 pattern_2_tokens
316 .write_fmt(format_args!(
317 "__{field_name},",
318 field_name = field_name
319 ))
320 .unwrap();
321
322 field_attributes.insert(rank, field_attribute);
323 field_names.insert(rank, field_name);
324 }
325
326 for (index, field_attribute) in field_attributes {
327 let field_name = field_names.get(&index).unwrap();
328
329 let compare_trait = field_attribute.compare_trait;
330 let compare_method = field_attribute.compare_method;
331
332 match compare_trait {
333 Some(compare_trait) => {
334 let compare_method = compare_method.unwrap();
335
336 block_tokens.write_fmt(format_args!("match {compare_trait}::{compare_method}(_{field_name}, __{field_name}) {{ core::cmp::Ordering::Equal => (), core::cmp::Ordering::Greater => {{ return core::cmp::Ordering::Greater; }}, core::cmp::Ordering::Less => {{ return core::cmp::Ordering::Less; }} }}", compare_trait = compare_trait, compare_method = compare_method, field_name = field_name)).unwrap();
337 },
338 None => match compare_method {
339 Some(compare_method) => {
340 block_tokens
341 .write_fmt(format_args!(
342 "match {compare_method}(_{field_name}, \
343 __{field_name}) {{ \
344 core::cmp::Ordering::Equal => (), \
345 core::cmp::Ordering::Greater => {{ return \
346 core::cmp::Ordering::Greater; }}, \
347 core::cmp::Ordering::Less => {{ return \
348 core::cmp::Ordering::Less; }} }}",
349 compare_method = compare_method,
350 field_name = field_name
351 ))
352 .unwrap();
353 },
354 None => {
355 block_tokens
356 .write_fmt(format_args!(
357 "match core::cmp::Ord::cmp(_{field_name}, \
358 __{field_name}) {{ \
359 core::cmp::Ordering::Equal => (), \
360 core::cmp::Ordering::Greater => {{ return \
361 core::cmp::Ordering::Greater; }}, \
362 core::cmp::Ordering::Less => {{ return \
363 core::cmp::Ordering::Less; }} }}",
364 field_name = field_name
365 ))
366 .unwrap();
367 },
368 },
369 }
370 }
371
372 match_tokens
373 .write_fmt(format_args!(
374 "{enum_name}::{variant_ident}( {pattern_tokens} ) => {{ if \
375 let {enum_name}::{variant_ident} ( {pattern_2_tokens} ) = \
376 other {{ {block_tokens} }} else {{ let other_value = \
377 variant_to_integer(other); return \
378 core::cmp::Ord::cmp(&{variant_value}, &other_value); }} }}",
379 enum_name = enum_name,
380 variant_ident = variant_ident,
381 pattern_tokens = pattern_tokens,
382 pattern_2_tokens = pattern_2_tokens,
383 block_tokens = block_tokens,
384 variant_value = variant_value
385 ))
386 .unwrap();
387 },
388 }
389 }
390 } else {
391 unit_to_integer.push_str("};");
392
393 comparer_tokens.extend(TokenStream::from_str(&unit_to_integer).unwrap());
394
395 for (index, _) in variants.into_iter().enumerate() {
396 let variant_ident = &variant_idents[index];
397
398 match_tokens
399 .write_fmt(format_args!(
400 "{enum_name}::{variant_ident} => {{ let other_value = \
401 unit_to_integer(other); return \
402 core::cmp::Ord::cmp(&({enum_name}::{variant_ident} as isize), \
403 &other_value); }}",
404 enum_name = enum_name,
405 variant_ident = variant_ident
406 ))
407 .unwrap();
408 }
409 }
410 }
411
412 match_tokens.push('}');
413
414 comparer_tokens.extend(TokenStream::from_str(&match_tokens).unwrap());
415
416 if has_non_unit_or_custom_value {
417 comparer_tokens.extend(quote!(core::cmp::Ordering::Equal));
418 }
419
420 let ident = &ast.ident;
421
422 let mut generics_cloned: Generics = ast.generics.clone();
423
424 let where_clause = generics_cloned.make_where_clause();
425
426 for where_predicate in bound {
427 where_clause.predicates.push(where_predicate);
428 }
429
430 let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
431
432 let compare_impl = quote! {
433 impl #impl_generics core::cmp::Ord for #ident #ty_generics #where_clause {
434 #[inline]
435 #[allow(unreachable_code, clippy::unneeded_field_pattern)]
436 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
437 #comparer_tokens
438 }
439 }
440 };
441
442 tokens.extend(compare_impl);
443 }
444}