educe/trait_handlers/partial_ord/
partial_ord_enum.rs1use std::{collections::BTreeMap, fmt::Write, str::FromStr};
2
3use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{Data, DeriveInput, Fields, Generics, Meta};
6
7use super::{
8 super::TraitHandler,
9 models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::{panic, Trait};
/// Derives `core::cmp::PartialOrd` for enums; the code generation lives in its `TraitHandler` impl.
pub struct PartialOrdEnumHandler;

15impl TraitHandler for PartialOrdEnumHandler {
16 fn trait_meta_handler(
17 ast: &DeriveInput,
18 tokens: &mut TokenStream,
19 traits: &[Trait],
20 meta: &Meta,
21 ) {
22 let type_attribute = TypeAttributeBuilder {
23 enable_flag: true,
24 enable_bound: true,
25 rank: 0,
26 enable_rank: false,
27 }
28 .from_partial_ord_meta(meta);
29
30 let enum_name = ast.ident.to_string();
31
32 let bound = type_attribute
33 .bound
34 .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
35
36 let mut comparer_tokens = TokenStream::new();
37
38 let mut match_tokens = String::from("match self {");
39
40 let mut has_non_unit_or_custom_value = false;
41
42 if let Data::Enum(data) = &ast.data {
43 let mut variant_values = Vec::new();
44 let mut variant_idents = Vec::new();
45 let mut variants = Vec::new();
46
47 let mut variant_to_integer =
48 String::from("let variant_to_integer = |other: &Self| match other {");
49 let mut unit_to_integer =
50 String::from("let unit_to_integer = |other: &Self| match other {");
51
52 for (index, variant) in data.variants.iter().enumerate() {
53 let variant_attribute = TypeAttributeBuilder {
54 enable_flag: false,
55 enable_bound: false,
56 rank: isize::MIN + index as isize,
57 enable_rank: true,
58 }
59 .from_attributes(&variant.attrs, traits);
60
61 let value = variant_attribute.rank;
62
63 if variant_values.contains(&value) {
64 panic::reuse_a_value(value);
65 }
66
67 if value >= 0 {
68 has_non_unit_or_custom_value = true;
69 }
70
71 let variant_ident = variant.ident.to_string();
72
73 match &variant.fields {
74 Fields::Unit => {
75 unit_to_integer
77 .write_fmt(format_args!(
78 "{enum_name}::{variant_ident} => {enum_name}::{variant_ident} as \
79 isize,",
80 enum_name = enum_name,
81 variant_ident = variant_ident
82 ))
83 .unwrap();
84 variant_to_integer
85 .write_fmt(format_args!(
86 "{enum_name}::{variant_ident} => {value},",
87 enum_name = enum_name,
88 variant_ident = variant_ident,
89 value = value
90 ))
91 .unwrap();
92 },
93 Fields::Named(_) => {
94 has_non_unit_or_custom_value = true;
96
97 variant_to_integer
98 .write_fmt(format_args!(
99 "{enum_name}::{variant_ident} {{ .. }} => {value},",
100 enum_name = enum_name,
101 variant_ident = variant_ident,
102 value = value
103 ))
104 .unwrap();
105 },
106 Fields::Unnamed(fields) => {
107 has_non_unit_or_custom_value = true;
109
110 let mut pattern_tokens = String::new();
111
112 for _ in fields.unnamed.iter() {
113 pattern_tokens.push_str("_,");
114 }
115
116 variant_to_integer
117 .write_fmt(format_args!(
118 "{enum_name}::{variant_ident}( {pattern_tokens} ) => {value},",
119 enum_name = enum_name,
120 variant_ident = variant_ident,
121 pattern_tokens = pattern_tokens,
122 value = value
123 ))
124 .unwrap();
125 },
126 }
127
128 variant_values.push(value);
129 variant_idents.push(variant_ident);
130 variants.push(variant);
131 }
132
133 if has_non_unit_or_custom_value {
134 variant_to_integer.push_str("};");
135
136 comparer_tokens.extend(TokenStream::from_str(&variant_to_integer).unwrap());
137
138 for (index, variant) in variants.into_iter().enumerate() {
139 let variant_value = variant_values[index];
140 let variant_ident = &variant_idents[index];
141
142 match &variant.fields {
143 Fields::Unit => {
144 match_tokens
146 .write_fmt(format_args!(
147 "{enum_name}::{variant_ident} => {{ let other_value = \
148 variant_to_integer(other); return \
149 core::cmp::PartialOrd::partial_cmp(&{variant_value}, \
150 &other_value); }}",
151 enum_name = enum_name,
152 variant_ident = variant_ident,
153 variant_value = variant_value
154 ))
155 .unwrap();
156 },
157 Fields::Named(fields) => {
158 let mut pattern_tokens = String::new();
160 let mut pattern_2_tokens = String::new();
161 let mut block_tokens = String::new();
162
163 let mut field_attributes = BTreeMap::new();
164 let mut field_names = BTreeMap::new();
165
166 for (index, field) in fields.named.iter().enumerate() {
167 let field_attribute = FieldAttributeBuilder {
168 enable_ignore: true,
169 enable_impl: true,
170 rank: isize::MIN + index as isize,
171 enable_rank: true,
172 }
173 .from_attributes(&field.attrs, traits);
174
175 let field_name = field.ident.as_ref().unwrap().to_string();
176
177 if field_attribute.ignore {
178 pattern_tokens
179 .write_fmt(format_args!(
180 "{field_name}: _,",
181 field_name = field_name
182 ))
183 .unwrap();
184 pattern_2_tokens
185 .write_fmt(format_args!(
186 "{field_name}: _,",
187 field_name = field_name
188 ))
189 .unwrap();
190 continue;
191 }
192
193 let rank = field_attribute.rank;
194
195 if field_attributes.contains_key(&rank) {
196 panic::reuse_a_rank(rank);
197 }
198
199 pattern_tokens
200 .write_fmt(format_args!(
201 "{field_name},",
202 field_name = field_name
203 ))
204 .unwrap();
205 pattern_2_tokens
206 .write_fmt(format_args!(
207 "{field_name}: ___{field_name},",
208 field_name = field_name
209 ))
210 .unwrap();
211
212 field_attributes.insert(rank, field_attribute);
213 field_names.insert(rank, field_name);
214 }
215
216 for (index, field_attribute) in field_attributes {
217 let field_name = field_names.get(&index).unwrap();
218
219 let compare_trait = field_attribute.compare_trait;
220 let compare_method = field_attribute.compare_method;
221
222 match compare_trait {
223 Some(compare_trait) => {
224 let compare_method = compare_method.unwrap();
225
226 block_tokens.write_fmt(format_args!("match {compare_trait}::{compare_method}({field_name}, ___{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), Some(core::cmp::Ordering::Greater) => {{ return Some(core::cmp::Ordering::Greater); }}, Some(core::cmp::Ordering::Less) => {{ return Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}", compare_trait = compare_trait, compare_method = compare_method, field_name = field_name)).unwrap();
227 },
228 None => {
229 match compare_method {
230 Some(compare_method) => {
231 block_tokens.write_fmt(format_args!("match {compare_method}({field_name}, ___{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), Some(core::cmp::Ordering::Greater) => {{ return Some(core::cmp::Ordering::Greater); }}, Some(core::cmp::Ordering::Less) => {{ return Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}", compare_method = compare_method, field_name = field_name)).unwrap();
232 },
233 None => {
234 block_tokens.write_fmt(format_args!("match core::cmp::PartialOrd::partial_cmp({field_name}, ___{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), Some(core::cmp::Ordering::Greater) => {{ return Some(core::cmp::Ordering::Greater); }}, Some(core::cmp::Ordering::Less) => {{ return Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}", field_name = field_name)).unwrap();
235 },
236 }
237 },
238 }
239 }
240
241 match_tokens
242 .write_fmt(format_args!(
243 "{enum_name}::{variant_ident}{{ {pattern_tokens} }} => {{ if \
244 let {enum_name}::{variant_ident} {{ {pattern_2_tokens} }} = \
245 other {{ {block_tokens} }} else {{ let other_value = \
246 variant_to_integer(other); return \
247 core::cmp::PartialOrd::partial_cmp(&{variant_value}, \
248 &other_value); }} }}",
249 enum_name = enum_name,
250 variant_ident = variant_ident,
251 pattern_tokens = pattern_tokens,
252 pattern_2_tokens = pattern_2_tokens,
253 block_tokens = block_tokens,
254 variant_value = variant_value
255 ))
256 .unwrap();
257 },
258 Fields::Unnamed(fields) => {
259 let mut pattern_tokens = String::new();
261 let mut pattern_2_tokens = String::new();
262 let mut block_tokens = String::new();
263
264 let mut field_attributes = BTreeMap::new();
265 let mut field_names = BTreeMap::new();
266
267 for (index, field) in fields.unnamed.iter().enumerate() {
268 let field_attribute = FieldAttributeBuilder {
269 enable_ignore: true,
270 enable_impl: true,
271 rank: isize::MIN + index as isize,
272 enable_rank: true,
273 }
274 .from_attributes(&field.attrs, traits);
275
276 let field_name = format!("{}", index);
277
278 if field_attribute.ignore {
279 pattern_tokens.push_str("_,");
280 pattern_2_tokens.push_str("_,");
281 continue;
282 }
283
284 let rank = field_attribute.rank;
285
286 if field_attributes.contains_key(&rank) {
287 panic::reuse_a_rank(rank);
288 }
289
290 pattern_tokens
291 .write_fmt(format_args!(
292 "_{field_name},",
293 field_name = field_name
294 ))
295 .unwrap();
296 pattern_2_tokens
297 .write_fmt(format_args!(
298 "__{field_name},",
299 field_name = field_name
300 ))
301 .unwrap();
302
303 field_attributes.insert(rank, field_attribute);
304 field_names.insert(rank, field_name);
305 }
306
307 for (index, field_attribute) in field_attributes {
308 let field_name = field_names.get(&index).unwrap();
309
310 let compare_trait = field_attribute.compare_trait;
311 let compare_method = field_attribute.compare_method;
312
313 match compare_trait {
314 Some(compare_trait) => {
315 let compare_method = compare_method.unwrap();
316
317 block_tokens.write_fmt(format_args!("match {compare_trait}::{compare_method}(_{field_name}, __{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), Some(core::cmp::Ordering::Greater) => {{ return Some(core::cmp::Ordering::Greater); }}, Some(core::cmp::Ordering::Less) => {{ return Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}", compare_trait = compare_trait, compare_method = compare_method, field_name = field_name)).unwrap();
318 },
319 None => {
320 match compare_method {
321 Some(compare_method) => {
322 block_tokens.write_fmt(format_args!("match {compare_method}(_{field_name}, __{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), Some(core::cmp::Ordering::Greater) => {{ return Some(core::cmp::Ordering::Greater); }}, Some(core::cmp::Ordering::Less) => {{ return Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}", compare_method = compare_method, field_name = field_name)).unwrap();
323 },
324 None => {
325 block_tokens.write_fmt(format_args!("match core::cmp::PartialOrd::partial_cmp(_{field_name}, __{field_name}) {{ Some(core::cmp::Ordering::Equal) => (), Some(core::cmp::Ordering::Greater) => {{ return Some(core::cmp::Ordering::Greater); }}, Some(core::cmp::Ordering::Less) => {{ return Some(core::cmp::Ordering::Less); }}, None => {{ return None; }} }}", field_name = field_name)).unwrap();
326 },
327 }
328 },
329 }
330 }
331
332 match_tokens
333 .write_fmt(format_args!(
334 "{enum_name}::{variant_ident}( {pattern_tokens} ) => {{ if \
335 let {enum_name}::{variant_ident} ( {pattern_2_tokens} ) = \
336 other {{ {block_tokens} }} else {{ let other_value = \
337 variant_to_integer(other); return \
338 core::cmp::PartialOrd::partial_cmp(&{variant_value}, \
339 &other_value); }} }}",
340 enum_name = enum_name,
341 variant_ident = variant_ident,
342 pattern_tokens = pattern_tokens,
343 pattern_2_tokens = pattern_2_tokens,
344 block_tokens = block_tokens,
345 variant_value = variant_value
346 ))
347 .unwrap();
348 },
349 }
350 }
351 } else {
352 unit_to_integer.push_str("};");
353
354 comparer_tokens.extend(TokenStream::from_str(&unit_to_integer).unwrap());
355
356 for (index, _) in variants.into_iter().enumerate() {
357 let variant_ident = &variant_idents[index];
358
359 match_tokens
360 .write_fmt(format_args!(
361 "{enum_name}::{variant_ident} => {{ let other_value = \
362 unit_to_integer(other); return \
363 core::cmp::PartialOrd::partial_cmp(&({enum_name}::{variant_ident} as \
364 isize), &other_value); }}",
365 enum_name = enum_name,
366 variant_ident = variant_ident
367 ))
368 .unwrap();
369 }
370 }
371 }
372
373 match_tokens.push('}');
374
375 comparer_tokens.extend(TokenStream::from_str(&match_tokens).unwrap());
376
377 if has_non_unit_or_custom_value {
378 comparer_tokens.extend(quote!(Some(core::cmp::Ordering::Equal)));
379 }
380
381 let ident = &ast.ident;
382
383 let mut generics_cloned: Generics = ast.generics.clone();
384
385 let where_clause = generics_cloned.make_where_clause();
386
387 for where_predicate in bound {
388 where_clause.predicates.push(where_predicate);
389 }
390
391 let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
392
393 let compare_impl = quote! {
394 impl #impl_generics core::cmp::PartialOrd for #ident #ty_generics #where_clause {
395 #[inline]
396 #[allow(unreachable_code, clippy::unneeded_field_pattern)]
397 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
398 #comparer_tokens
399 }
400 }
401 };
402
403 tokens.extend(compare_impl);
404 }
405}