educe/trait_handlers/partial_eq/
partial_eq_struct.rs1use std::str::FromStr;
2
3use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{Data, DeriveInput, Generics, Meta};
6
7use super::{
8 super::TraitHandler,
9 models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::Trait;
12
/// Unit type that carries the `TraitHandler` implementation which emits a
/// `core::cmp::PartialEq` impl for struct inputs.
pub struct PartialEqStructHandler;
14
15impl TraitHandler for PartialEqStructHandler {
16 fn trait_meta_handler(
17 ast: &DeriveInput,
18 tokens: &mut TokenStream,
19 traits: &[Trait],
20 meta: &Meta,
21 ) {
22 let type_attribute = TypeAttributeBuilder {
23 enable_flag: true, enable_bound: true
24 }
25 .from_partial_eq_meta(meta);
26
27 let bound = type_attribute
28 .bound
29 .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
30
31 let mut comparer_tokens = TokenStream::new();
32
33 if let Data::Struct(data) = &ast.data {
34 for (index, field) in data.fields.iter().enumerate() {
35 let field_attribute = FieldAttributeBuilder {
36 enable_ignore: true,
37 enable_impl: true,
38 }
39 .from_attributes(&field.attrs, traits);
40
41 if field_attribute.ignore {
42 continue;
43 }
44
45 let compare_trait = field_attribute.compare_trait;
46 let compare_method = field_attribute.compare_method;
47
48 let field_name = if let Some(ident) = field.ident.as_ref() {
49 ident.to_string()
50 } else {
51 format!("{}", index)
52 };
53
54 match compare_trait {
55 Some(compare_trait) => {
56 let compare_method = compare_method.unwrap();
57
58 let statement = format!(
59 "if !{compare_trait}::{compare_method}(&self.{field_name}, \
60 &other.{field_name}) {{ return false }}",
61 compare_trait = compare_trait,
62 compare_method = compare_method,
63 field_name = field_name
64 );
65
66 comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
67 },
68 None => match compare_method {
69 Some(compare_method) => {
70 let statement = format!(
71 "if !{compare_method}(&self.{field_name}, &other.{field_name}) {{ \
72 return false; }}",
73 compare_method = compare_method,
74 field_name = field_name
75 );
76
77 comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
78 },
79 None => {
80 let statement = format!(
81 "if core::cmp::PartialEq::ne(&self.{field_name}, \
82 &other.{field_name}) {{ return false; }}",
83 field_name = field_name
84 );
85
86 comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
87 },
88 },
89 }
90 }
91 }
92
93 let ident = &ast.ident;
94
95 let mut generics_cloned: Generics = ast.generics.clone();
96
97 let where_clause = generics_cloned.make_where_clause();
98
99 for where_predicate in bound {
100 where_clause.predicates.push(where_predicate);
101 }
102
103 let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
104
105 let compare_impl = quote! {
106 impl #impl_generics core::cmp::PartialEq for #ident #ty_generics #where_clause {
107 #[inline]
108 fn eq(&self, other: &Self) -> bool {
109 #comparer_tokens
110
111 true
112 }
113 }
114 };
115
116 tokens.extend(compare_impl);
117 }
118}