educe/trait_handlers/partial_eq/partial_eq_struct.rs

1use std::str::FromStr;
2
3use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{Data, DeriveInput, Generics, Meta};
6
7use super::{
8    super::TraitHandler,
9    models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::Trait;
12
13pub struct PartialEqStructHandler;
14
15impl TraitHandler for PartialEqStructHandler {
16    fn trait_meta_handler(
17        ast: &DeriveInput,
18        tokens: &mut TokenStream,
19        traits: &[Trait],
20        meta: &Meta,
21    ) {
22        let type_attribute = TypeAttributeBuilder {
23            enable_flag: true, enable_bound: true
24        }
25        .from_partial_eq_meta(meta);
26
27        let bound = type_attribute
28            .bound
29            .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
30
31        let mut comparer_tokens = TokenStream::new();
32
33        if let Data::Struct(data) = &ast.data {
34            for (index, field) in data.fields.iter().enumerate() {
35                let field_attribute = FieldAttributeBuilder {
36                    enable_ignore: true,
37                    enable_impl:   true,
38                }
39                .from_attributes(&field.attrs, traits);
40
41                if field_attribute.ignore {
42                    continue;
43                }
44
45                let compare_trait = field_attribute.compare_trait;
46                let compare_method = field_attribute.compare_method;
47
48                let field_name = if let Some(ident) = field.ident.as_ref() {
49                    ident.to_string()
50                } else {
51                    format!("{}", index)
52                };
53
54                match compare_trait {
55                    Some(compare_trait) => {
56                        let compare_method = compare_method.unwrap();
57
58                        let statement = format!(
59                            "if !{compare_trait}::{compare_method}(&self.{field_name}, \
60                             &other.{field_name}) {{ return false }}",
61                            compare_trait = compare_trait,
62                            compare_method = compare_method,
63                            field_name = field_name
64                        );
65
66                        comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
67                    },
68                    None => match compare_method {
69                        Some(compare_method) => {
70                            let statement = format!(
71                                "if !{compare_method}(&self.{field_name}, &other.{field_name}) {{ \
72                                 return false; }}",
73                                compare_method = compare_method,
74                                field_name = field_name
75                            );
76
77                            comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
78                        },
79                        None => {
80                            let statement = format!(
81                                "if core::cmp::PartialEq::ne(&self.{field_name}, \
82                                 &other.{field_name}) {{ return false; }}",
83                                field_name = field_name
84                            );
85
86                            comparer_tokens.extend(TokenStream::from_str(&statement).unwrap());
87                        },
88                    },
89                }
90            }
91        }
92
93        let ident = &ast.ident;
94
95        let mut generics_cloned: Generics = ast.generics.clone();
96
97        let where_clause = generics_cloned.make_where_clause();
98
99        for where_predicate in bound {
100            where_clause.predicates.push(where_predicate);
101        }
102
103        let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
104
105        let compare_impl = quote! {
106            impl #impl_generics core::cmp::PartialEq for #ident #ty_generics #where_clause {
107                #[inline]
108                fn eq(&self, other: &Self) -> bool {
109                    #comparer_tokens
110
111                    true
112                }
113            }
114        };
115
116        tokens.extend(compare_impl);
117    }
118}