derive_more_impl/
not_like.rs

1use crate::utils::{
2    add_extra_type_param_bound_op_output, named_to_vec, unnamed_to_vec,
3};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote, ToTokens};
6use std::iter;
7use syn::{Data, DataEnum, DeriveInput, Field, Fields, Ident, Index};
8
9pub fn expand(input: &DeriveInput, trait_name: &str) -> TokenStream {
10    let trait_ident = format_ident!("{trait_name}");
11    let method_name = trait_name.to_lowercase();
12    let method_ident = format_ident!("{method_name}");
13    let input_type = &input.ident;
14
15    let generics = add_extra_type_param_bound_op_output(&input.generics, &trait_ident);
16    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
17
18    let (output_type, block) = match input.data {
19        Data::Struct(ref data_struct) => match data_struct.fields {
20            Fields::Unnamed(ref fields) => (
21                quote! { #input_type #ty_generics },
22                tuple_content(input_type, &unnamed_to_vec(fields), &method_ident),
23            ),
24            Fields::Named(ref fields) => (
25                quote! { #input_type #ty_generics },
26                struct_content(input_type, &named_to_vec(fields), &method_ident),
27            ),
28            _ => panic!("Unit structs cannot use derive({trait_name})"),
29        },
30        Data::Enum(ref data_enum) => {
31            enum_output_type_and_content(input, data_enum, &method_ident)
32        }
33
34        _ => panic!("Only structs and enums can use derive({trait_name})"),
35    };
36
37    quote! {
38        #[allow(unreachable_code)] // omit warnings for `!` and other unreachable types
39        #[automatically_derived]
40        impl #impl_generics derive_more::core::ops::#trait_ident
41         for #input_type #ty_generics #where_clause {
42            type Output = #output_type;
43
44            #[inline]
45            fn #method_ident(self) -> #output_type {
46                #block
47            }
48        }
49    }
50}
51
52fn tuple_content<T: ToTokens>(
53    input_type: &T,
54    fields: &[&Field],
55    method_ident: &Ident,
56) -> TokenStream {
57    let mut exprs = vec![];
58
59    for i in 0..fields.len() {
60        let i = Index::from(i);
61        // generates `self.0.add()`
62        let expr = quote! { self.#i.#method_ident() };
63        exprs.push(expr);
64    }
65
66    quote! { #input_type(#(#exprs),*) }
67}
68
69fn struct_content(
70    input_type: &Ident,
71    fields: &[&Field],
72    method_ident: &Ident,
73) -> TokenStream {
74    let mut exprs = vec![];
75
76    for field in fields {
77        // It's safe to unwrap because struct fields always have an identifier
78        let field_id = field.ident.as_ref();
79        // generates `x: self.x.not()`
80        let expr = quote! { #field_id: self.#field_id.#method_ident() };
81        exprs.push(expr)
82    }
83
84    quote! { #input_type{#(#exprs),*} }
85}
86
87fn enum_output_type_and_content(
88    input: &DeriveInput,
89    data_enum: &DataEnum,
90    method_ident: &Ident,
91) -> (TokenStream, TokenStream) {
92    let input_type = &input.ident;
93    let (_, ty_generics, _) = input.generics.split_for_impl();
94    let mut matches = vec![];
95    let mut method_iter = iter::repeat(method_ident);
96    // If the enum contains unit types that means it can error.
97    let has_unit_type = data_enum.variants.iter().any(|v| v.fields == Fields::Unit);
98
99    for variant in &data_enum.variants {
100        let subtype = &variant.ident;
101        let subtype = quote! { #input_type::#subtype };
102
103        match variant.fields {
104            Fields::Unnamed(ref fields) => {
105                // The pattern that is outputted should look like this:
106                // (Subtype(vars)) => Ok(TypePath(exprs))
107                let size = unnamed_to_vec(fields).len();
108                let vars: &Vec<_> =
109                    &(0..size).map(|i| format_ident!("__{i}")).collect();
110                let method_iter = method_iter.by_ref();
111                let mut body = quote! { #subtype(#(#vars.#method_iter()),*) };
112                if has_unit_type {
113                    body = quote! { derive_more::core::result::Result::Ok(#body) }
114                }
115                let matcher = quote! {
116                    #subtype(#(#vars),*) => {
117                        #body
118                    }
119                };
120                matches.push(matcher);
121            }
122            Fields::Named(ref fields) => {
123                // The pattern that is outputted should look like this:
124                // (Subtype{a: __l_a, ...} => {
125                //     Ok(Subtype{a: __l_a.neg(__r_a), ...})
126                // }
127                let field_vec = named_to_vec(fields);
128                let size = field_vec.len();
129                let field_names: &Vec<_> = &field_vec
130                    .iter()
131                    .map(|f| f.ident.as_ref().unwrap())
132                    .collect();
133                let vars: &Vec<_> =
134                    &(0..size).map(|i| format_ident!("__{i}")).collect();
135                let method_iter = method_iter.by_ref();
136                let mut body = quote! {
137                    #subtype{#(#field_names: #vars.#method_iter()),*}
138                };
139                if has_unit_type {
140                    body = quote! { derive_more::core::result::Result::Ok(#body) }
141                }
142                let matcher = quote! {
143                    #subtype{#(#field_names: #vars),*} => {
144                        #body
145                    }
146                };
147                matches.push(matcher);
148            }
149            Fields::Unit => {
150                let operation_name = method_ident.to_string();
151                matches.push(quote! {
152                    #subtype => derive_more::core::result::Result::Err(
153                        derive_more::UnitError::new(#operation_name)
154                    )
155                });
156            }
157        }
158    }
159
160    let body = quote! {
161        match self {
162            #(#matches),*
163        }
164    };
165
166    let output_type = if has_unit_type {
167        quote! {
168            derive_more::core::result::Result<#input_type #ty_generics, derive_more::UnitError>
169        }
170    } else {
171        quote! { #input_type #ty_generics }
172    };
173
174    (output_type, body)
175}