blanket/derive/
mut.rs

1use syn::{parse_quote, spanned::Spanned};
2
3use crate::utils::{
4    deref_expr, generics_declaration_to_generics, signature_to_method_call, trait_to_generic_ident,
5};
6
7pub fn derive(trait_: &syn::ItemTrait) -> syn::Result<syn::ItemImpl> {
8    // build an identifier for the generic type used for the implementation
9    let trait_ident = &trait_.ident;
10    let generic_type = trait_to_generic_ident(&trait_);
11
12    // build the generics for the impl block:
13    // we use the same generics as the trait itself, plus
14    // a generic type that implements the trait for which we provide the
15    // blanket implementation
16    let trait_generics = &trait_.generics;
17    let where_clause = &trait_.generics.where_clause;
18    let mut impl_generics = trait_generics.clone();
19
20    // we must however remove the generic type bounds, to avoid repeating them
21    let mut trait_generic_names = trait_generics.clone();
22    trait_generic_names.params = generics_declaration_to_generics(&trait_generics.params)?;
23
24    impl_generics.params.push(syn::GenericParam::Type(
25        parse_quote!(#generic_type: #trait_ident #trait_generic_names + ?Sized),
26    ));
27
28    // build the methods
29    let mut methods: Vec<syn::ImplItemFn> = Vec::new();
30    let mut assoc_types: Vec<syn::ImplItemType> = Vec::new();
31    for item in trait_.items.iter() {
32        if let syn::TraitItem::Fn(ref m) = item {
33            if let Some(r) = m.sig.receiver() {
34                let err = if r.colon_token.is_some() {
35                    Some("cannot derive `Mut` for a trait declaring methods with arbitrary receiver types")
36                } else if r.reference.is_none() {
37                    Some("cannot derive `Mut` for a trait declaring `self` methods")
38                } else {
39                    None
40                };
41                if let Some(msg) = err {
42                    return Err(syn::Error::new(r.span(), msg));
43                }
44            }
45
46            let mut call = signature_to_method_call(&m.sig)?;
47            call.receiver = Box::new(deref_expr(deref_expr(*call.receiver)));
48
49            let signature = &m.sig;
50            let item = parse_quote!(#[inline] #signature { #call });
51            methods.push(item)
52        }
53
54        if let syn::TraitItem::Type(t) = item {
55            let t_ident = &t.ident;
56            let attrs = &t.attrs;
57
58            let t_generics = &t.generics;
59            let where_clause = &t.generics.where_clause;
60            let mut t_generic_names = t_generics.clone();
61            t_generic_names.params = generics_declaration_to_generics(&t_generics.params)?;
62
63            let item = parse_quote!( #(#attrs)* type #t_ident #t_generics = <#generic_type as #trait_ident #trait_generic_names>::#t_ident #t_generic_names #where_clause ; );
64            assoc_types.push(item);
65        }
66    }
67
68    Ok(parse_quote!(
69        #[automatically_derived]
70        impl #impl_generics #trait_ident #trait_generic_names for &mut #generic_type #where_clause {
71            #(#assoc_types)*
72            #(#methods)*
73        }
74    ))
75}
76
#[cfg(test)]
mod tests {
    mod derive {

        use syn::parse_quote;

        /// Runs the derive on `input` and checks that the produced impl block
        /// is exactly `expected`.
        fn assert_derived(input: syn::ItemTrait, expected: syn::ItemImpl) {
            let actual = super::super::derive(&input).unwrap();
            assert_eq!(actual, expected);
        }

        /// Checks that the derive refuses `input` with an error.
        fn assert_rejected(input: syn::ItemTrait) {
            assert!(super::super::derive(&input).is_err());
        }

        #[test]
        fn empty() {
            assert_derived(
                parse_quote!(
                    trait MyTrait {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {}
                ),
            );
        }

        #[test]
        fn receiver_mut() {
            assert_derived(
                parse_quote!(
                    trait MyTrait {
                        fn my_method(&mut self);
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        #[inline]
                        fn my_method(&mut self) {
                            (*(*self)).my_method()
                        }
                    }
                ),
            );
        }

        #[test]
        fn receiver_self() {
            assert_rejected(parse_quote!(
                trait MyTrait {
                    fn my_method(self);
                }
            ));
        }

        #[test]
        fn receiver_arbitrary() {
            assert_rejected(parse_quote!(
                trait MyTrait {
                    fn my_method(self: Box<Self>);
                }
            ));
        }

        #[test]
        fn generics() {
            assert_derived(
                parse_quote!(
                    trait Trait<T> {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T, T_: Trait<T> + ?Sized> Trait<T> for &mut T_ {}
                ),
            );
        }

        #[test]
        fn generics_bounded() {
            assert_derived(
                parse_quote!(
                    trait Trait<T: 'static + Send> {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T: 'static + Send, T_: Trait<T> + ?Sized> Trait<T> for &mut T_ {}
                ),
            );
        }

        #[test]
        fn generics_lifetime() {
            assert_derived(
                parse_quote!(
                    trait Trait<'a, 'b: 'a, T: 'static + Send> {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<'a, 'b: 'a, T: 'static + Send, T_: Trait<'a, 'b, T> + ?Sized>
                        Trait<'a, 'b, T> for &mut T_
                    {
                    }
                ),
            );
        }

        #[test]
        fn associated_types() {
            assert_derived(
                parse_quote!(
                    trait MyTrait {
                        type Return;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_bound() {
            assert_derived(
                parse_quote!(
                    trait MyTrait {
                        type Return: Clone;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_dodgy_name() {
            assert_derived(
                parse_quote!(
                    trait MyTrait {
                        type r#type;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        type r#type = <MT as MyTrait>::r#type;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_attrs() {
            assert_derived(
                parse_quote!(
                    trait MyTrait {
                        #[cfg(target_arch = "wasm32")]
                        type Return;
                        #[cfg(not(target_arch = "wasm32"))]
                        type Return: Send;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        #[cfg(target_arch = "wasm32")]
                        type Return = <MT as MyTrait>::Return;
                        #[cfg(not(target_arch = "wasm32"))]
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_and_generics() {
            assert_derived(
                parse_quote!(
                    trait MyTrait<T> {
                        type Return;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T, MT: MyTrait<T> + ?Sized> MyTrait<T> for &mut MT {
                        type Return = <MT as MyTrait<T>>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_type_generics() {
            assert_derived(
                parse_quote!(
                    trait MyTrait {
                        type Return<T>;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        type Return<T> = <MT as MyTrait>::Return<T>;
                    }
                ),
            );
        }

        #[test]
        fn associated_type_generics_bounded() {
            assert_derived(
                parse_quote!(
                    trait MyTrait {
                        type Return<T: 'static + Send>;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        type Return<T: 'static + Send> = <MT as MyTrait>::Return<T>;
                    }
                ),
            );
        }

        #[test]
        fn associated_type_generics_lifetimes() {
            assert_derived(
                parse_quote!(
                    trait MyTrait {
                        type Return<'a>
                        where
                            Self: 'a;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        type Return<'a> = <MT as MyTrait>::Return<'a>
                        where
                            Self: 'a;
                    }
                ),
            );
        }
    }
}