derive_more_impl/
add_like.rs

1use crate::add_helpers::{struct_exprs, tuple_exprs};
2use crate::utils::{
3    add_extra_type_param_bound_op_output, field_idents, named_to_vec, numbered_vars,
4    unnamed_to_vec,
5};
6use proc_macro2::TokenStream;
7use quote::{format_ident, quote, ToTokens};
8use std::iter;
9use syn::{Data, DataEnum, DeriveInput, Field, Fields, Ident};
10
11pub fn expand(input: &DeriveInput, trait_name: &str) -> TokenStream {
12    let trait_name = trait_name.trim_end_matches("Self");
13    let trait_ident = format_ident!("{trait_name}");
14    let method_name = trait_name.to_lowercase();
15    let method_ident = format_ident!("{method_name}");
16    let input_type = &input.ident;
17
18    let generics = add_extra_type_param_bound_op_output(&input.generics, &trait_ident);
19    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
20
21    let (output_type, block) = match input.data {
22        Data::Struct(ref data_struct) => match data_struct.fields {
23            Fields::Unnamed(ref fields) => (
24                quote! { #input_type #ty_generics },
25                tuple_content(input_type, &unnamed_to_vec(fields), &method_ident),
26            ),
27            Fields::Named(ref fields) => (
28                quote! { #input_type #ty_generics },
29                struct_content(input_type, &named_to_vec(fields), &method_ident),
30            ),
31            _ => panic!("Unit structs cannot use derive({trait_name})"),
32        },
33        Data::Enum(ref data_enum) => (
34            quote! {
35                derive_more::core::result::Result<#input_type #ty_generics, derive_more::BinaryError>
36            },
37            enum_content(input_type, data_enum, &method_ident),
38        ),
39
40        _ => panic!("Only structs and enums can use derive({trait_name})"),
41    };
42
43    quote! {
44        #[automatically_derived]
45        impl #impl_generics derive_more::core::ops::#trait_ident
46         for #input_type #ty_generics #where_clause {
47            type Output = #output_type;
48
49            #[inline]
50            #[track_caller]
51            fn #method_ident(self, rhs: #input_type #ty_generics) -> #output_type {
52                #block
53            }
54        }
55    }
56}
57
58fn tuple_content<T: ToTokens>(
59    input_type: &T,
60    fields: &[&Field],
61    method_ident: &Ident,
62) -> TokenStream {
63    let exprs = tuple_exprs(fields, method_ident);
64    quote! { #input_type(#(#exprs),*) }
65}
66
67fn struct_content(
68    input_type: &Ident,
69    fields: &[&Field],
70    method_ident: &Ident,
71) -> TokenStream {
72    // It's safe to unwrap because struct fields always have an identifier
73    let exprs = struct_exprs(fields, method_ident);
74    let field_names = field_idents(fields);
75
76    quote! { #input_type{#(#field_names: #exprs),*} }
77}
78
79#[allow(clippy::cognitive_complexity)]
80fn enum_content(
81    input_type: &Ident,
82    data_enum: &DataEnum,
83    method_ident: &Ident,
84) -> TokenStream {
85    let mut matches = vec![];
86    let mut method_iter = iter::repeat(method_ident);
87
88    for variant in &data_enum.variants {
89        let subtype = &variant.ident;
90        let subtype = quote! { #input_type::#subtype };
91
92        match variant.fields {
93            Fields::Unnamed(ref fields) => {
94                // The pattern that is outputted should look like this:
95                // (Subtype(left_vars), TypePath(right_vars)) => Ok(TypePath(exprs))
96                let size = unnamed_to_vec(fields).len();
97                let l_vars = &numbered_vars(size, "l_");
98                let r_vars = &numbered_vars(size, "r_");
99                let method_iter = method_iter.by_ref();
100                let matcher = quote! {
101                    (#subtype(#(#l_vars),*),
102                     #subtype(#(#r_vars),*)) => {
103                        derive_more::core::result::Result::Ok(
104                            #subtype(#(#l_vars.#method_iter(#r_vars)),*)
105                        )
106                    }
107                };
108                matches.push(matcher);
109            }
110            Fields::Named(ref fields) => {
111                // The pattern that is outputted should look like this:
112                // (Subtype{a: __l_a, ...}, Subtype{a: __r_a, ...} => {
113                //     Ok(Subtype{a: __l_a.add(__r_a), ...})
114                // }
115                let field_vec = named_to_vec(fields);
116                let size = field_vec.len();
117                let field_names = &field_idents(&field_vec);
118                let l_vars = &numbered_vars(size, "l_");
119                let r_vars = &numbered_vars(size, "r_");
120                let method_iter = method_iter.by_ref();
121                let matcher = quote! {
122                    (#subtype{#(#field_names: #l_vars),*},
123                     #subtype{#(#field_names: #r_vars),*}) => {
124                        derive_more::core::result::Result::Ok(#subtype{
125                            #(#field_names: #l_vars.#method_iter(#r_vars)),*
126                        })
127                    }
128                };
129                matches.push(matcher);
130            }
131            Fields::Unit => {
132                let operation_name = method_ident.to_string();
133                matches.push(quote! {
134                    (#subtype, #subtype) => derive_more::core::result::Result::Err(
135                        derive_more::BinaryError::Unit(
136                            derive_more::UnitError::new(#operation_name)
137                        )
138                    )
139                });
140            }
141        }
142    }
143
144    if data_enum.variants.len() > 1 {
145        // In the strange case where there's only one enum variant this is would be an unreachable
146        // match.
147        let operation_name = method_ident.to_string();
148        matches.push(quote! {
149            _ => derive_more::core::result::Result::Err(derive_more::BinaryError::Mismatch(
150                derive_more::WrongVariantError::new(#operation_name)
151            ))
152        });
153    }
154    quote! {
155        match (self, rhs) {
156            #(#matches),*
157        }
158    }
159}