educe/trait_handlers/partial_eq/
partial_eq_enum.rs1use std::{fmt::Write, str::FromStr};
2
3use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{Data, DeriveInput, Fields, Generics, Meta};
6
7use super::{
8 super::TraitHandler,
9 models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::Trait;
12
/// Trait handler that generates a `core::cmp::PartialEq` implementation for enums.
///
/// It carries no state; all the work happens in its `TraitHandler` implementation.
pub struct PartialEqEnumHandler;
14
15impl TraitHandler for PartialEqEnumHandler {
16 fn trait_meta_handler(
17 ast: &DeriveInput,
18 tokens: &mut TokenStream,
19 traits: &[Trait],
20 meta: &Meta,
21 ) {
22 let type_attribute = TypeAttributeBuilder {
23 enable_flag: true, enable_bound: true
24 }
25 .from_partial_eq_meta(meta);
26
27 let enum_name = ast.ident.to_string();
28
29 let bound = type_attribute
30 .bound
31 .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
32
33 let mut comparer_tokens = TokenStream::new();
34
35 let mut match_tokens = String::from("match self {");
36
37 if let Data::Enum(data) = &ast.data {
38 for variant in data.variants.iter() {
39 let _ = TypeAttributeBuilder {
40 enable_flag: false, enable_bound: false
41 }
42 .from_attributes(&variant.attrs, traits);
43
44 let variant_ident = variant.ident.to_string();
45
46 match &variant.fields {
47 Fields::Unit => {
48 match_tokens
50 .write_fmt(format_args!(
51 "{enum_name}::{variant_ident} => {{ if let \
52 {enum_name}::{variant_ident} = other {{ }} else {{ return false; \
53 }} }}",
54 enum_name = enum_name,
55 variant_ident = variant_ident
56 ))
57 .unwrap();
58 },
59 Fields::Named(fields) => {
60 let mut pattern_tokens = String::new();
62 let mut pattern_2_tokens = String::new();
63 let mut block_tokens = String::new();
64
65 let mut field_attributes = Vec::new();
66 let mut field_names = Vec::new();
67
68 for field in fields.named.iter() {
69 let field_attribute = FieldAttributeBuilder {
70 enable_ignore: true,
71 enable_impl: true,
72 }
73 .from_attributes(&field.attrs, traits);
74
75 let field_name = field.ident.as_ref().unwrap().to_string();
76
77 if field_attribute.ignore {
78 pattern_tokens
79 .write_fmt(format_args!(
80 "{field_name}: _,",
81 field_name = field_name
82 ))
83 .unwrap();
84 pattern_2_tokens
85 .write_fmt(format_args!(
86 "{field_name}: _,",
87 field_name = field_name
88 ))
89 .unwrap();
90 continue;
91 }
92
93 pattern_tokens
94 .write_fmt(format_args!("{field_name},", field_name = field_name))
95 .unwrap();
96 pattern_2_tokens
97 .write_fmt(format_args!(
98 "{field_name}: ___{field_name},",
99 field_name = field_name
100 ))
101 .unwrap();
102
103 field_attributes.push(field_attribute);
104 field_names.push(field_name);
105 }
106
107 for (index, field_attribute) in field_attributes.into_iter().enumerate() {
108 let field_name = &field_names[index];
109
110 let compare_trait = field_attribute.compare_trait;
111 let compare_method = field_attribute.compare_method;
112
113 match compare_trait {
114 Some(compare_trait) => {
115 let compare_method = compare_method.unwrap();
116
117 block_tokens
118 .write_fmt(format_args!(
119 "if !{compare_trait}::{compare_method}({field_name}, \
120 ___{field_name}) {{ return false; }}",
121 compare_trait = compare_trait,
122 compare_method = compare_method,
123 field_name = field_name
124 ))
125 .unwrap();
126 },
127 None => match compare_method {
128 Some(compare_method) => {
129 block_tokens
130 .write_fmt(format_args!(
131 "if !{compare_method}({field_name}, \
132 ___{field_name}) {{ return false; }}",
133 compare_method = compare_method,
134 field_name = field_name
135 ))
136 .unwrap();
137 },
138 None => {
139 block_tokens
140 .write_fmt(format_args!(
141 "if core::cmp::PartialEq::ne({field_name}, \
142 ___{field_name}) {{ return false; }}",
143 field_name = field_name
144 ))
145 .unwrap();
146 },
147 },
148 }
149 }
150
151 match_tokens
152 .write_fmt(format_args!(
153 "{enum_name}::{variant_ident}{{ {pattern_tokens} }} => {{ if let \
154 {enum_name}::{variant_ident} {{ {pattern_2_tokens} }} = other {{ \
155 {block_tokens} }} else {{ return false; }} }}",
156 enum_name = enum_name,
157 variant_ident = variant_ident,
158 pattern_tokens = pattern_tokens,
159 pattern_2_tokens = pattern_2_tokens,
160 block_tokens = block_tokens
161 ))
162 .unwrap();
163 },
164 Fields::Unnamed(fields) => {
165 let mut pattern_tokens = String::new();
167 let mut pattern_2_tokens = String::new();
168 let mut block_tokens = String::new();
169
170 let mut field_attributes = Vec::new();
171 let mut field_names = Vec::new();
172
173 for (index, field) in fields.unnamed.iter().enumerate() {
174 let field_attribute = FieldAttributeBuilder {
175 enable_ignore: true,
176 enable_impl: true,
177 }
178 .from_attributes(&field.attrs, traits);
179
180 let field_name = format!("{}", index);
181
182 if field_attribute.ignore {
183 pattern_tokens.push_str("_,");
184 pattern_2_tokens.push_str("_,");
185 continue;
186 }
187
188 pattern_tokens
189 .write_fmt(format_args!("_{field_name},", field_name = field_name))
190 .unwrap();
191 pattern_2_tokens
192 .write_fmt(format_args!("__{field_name},", field_name = field_name))
193 .unwrap();
194
195 field_attributes.push(field_attribute);
196 field_names.push(field_name);
197 }
198
199 for (index, field_attribute) in field_attributes.into_iter().enumerate() {
200 let field_name = &field_names[index];
201
202 let compare_trait = field_attribute.compare_trait;
203 let compare_method = field_attribute.compare_method;
204
205 match compare_trait {
206 Some(compare_trait) => {
207 let compare_method = compare_method.unwrap();
208
209 block_tokens
210 .write_fmt(format_args!(
211 "if !{compare_trait}::{compare_method}(_{field_name}, \
212 __{field_name}) {{ return false; }}",
213 compare_trait = compare_trait,
214 compare_method = compare_method,
215 field_name = field_name
216 ))
217 .unwrap();
218 },
219 None => match compare_method {
220 Some(compare_method) => {
221 block_tokens
222 .write_fmt(format_args!(
223 "if !{compare_method}(_{field_name}, \
224 __{field_name}) {{ return false; }}",
225 compare_method = compare_method,
226 field_name = field_name
227 ))
228 .unwrap();
229 },
230 None => {
231 block_tokens
232 .write_fmt(format_args!(
233 "if core::cmp::PartialEq::ne(_{field_name}, \
234 __{field_name}) {{ return false; }}",
235 field_name = field_name
236 ))
237 .unwrap();
238 },
239 },
240 }
241 }
242
243 match_tokens
244 .write_fmt(format_args!(
245 "{enum_name}::{variant_ident}( {pattern_tokens} ) => {{ if let \
246 {enum_name}::{variant_ident} ( {pattern_2_tokens} ) = other {{ \
247 {block_tokens} }} else {{ return false; }} }}",
248 enum_name = enum_name,
249 variant_ident = variant_ident,
250 pattern_tokens = pattern_tokens,
251 pattern_2_tokens = pattern_2_tokens,
252 block_tokens = block_tokens
253 ))
254 .unwrap();
255 },
256 }
257 }
258 }
259
260 match_tokens.push('}');
261
262 comparer_tokens.extend(TokenStream::from_str(&match_tokens).unwrap());
263
264 let ident = &ast.ident;
265
266 let mut generics_cloned: Generics = ast.generics.clone();
267
268 let where_clause = generics_cloned.make_where_clause();
269
270 for where_predicate in bound {
271 where_clause.predicates.push(where_predicate);
272 }
273
274 let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
275
276 let compare_impl = quote! {
277 impl #impl_generics core::cmp::PartialEq for #ident #ty_generics #where_clause {
278 #[inline]
279 #[allow(clippy::unneeded_field_pattern)]
280 fn eq(&self, other: &Self) -> bool {
281 #comparer_tokens
282
283 true
284 }
285 }
286 };
287
288 tokens.extend(compare_impl);
289 }
290}