blanket/derive/
box.rs

1use syn::{parse_quote, spanned::Spanned};
2
3use crate::utils::{
4    deref_expr, generics_declaration_to_generics, signature_to_method_call, trait_to_generic_ident,
5};
6
7pub fn derive(trait_: &syn::ItemTrait) -> syn::Result<syn::ItemImpl> {
8    // build an identifier for the generic type used for the implementation
9    let trait_ident = &trait_.ident;
10    let generic_type = trait_to_generic_ident(&trait_);
11
12    // build the generics for the impl block:
13    // we use the same generics as the trait itself, plus
14    // a generic type that implements the trait for which we provide the
15    // blanket implementation
16    let trait_generics = &trait_.generics;
17    let where_clause = &trait_.generics.where_clause;
18    let mut impl_generics = trait_generics.clone();
19
20    // we must however remove the generic type bounds, to avoid repeating them
21    let mut trait_generic_names = trait_generics.clone();
22    trait_generic_names.params = generics_declaration_to_generics(&trait_generics.params)?;
23
24    impl_generics.params.push(syn::GenericParam::Type(
25        parse_quote!(#generic_type: #trait_ident #trait_generic_names),
26    ));
27
28    // build the methods
29    let mut methods: Vec<syn::ImplItemFn> = Vec::new();
30    let mut assoc_types: Vec<syn::ImplItemType> = Vec::new();
31    for item in trait_.items.iter() {
32        if let syn::TraitItem::Fn(ref m) = item {
33            let signature = &m.sig;
34            let mut call = signature_to_method_call(signature)?;
35
36            if let Some(r) = m.sig.receiver() {
37                let err = if r.colon_token.is_some() {
38                    Some("cannot derive `Box` for a trait declaring methods with arbitrary receiver types")
39                } else if r.reference.is_some() {
40                    call.receiver = Box::new(deref_expr(deref_expr(*call.receiver)));
41                    None
42                } else {
43                    call.receiver = Box::new(deref_expr(*call.receiver));
44                    None
45                };
46                if let Some(msg) = err {
47                    return Err(syn::Error::new(r.span(), msg));
48                }
49            } else {
50                unimplemented!()
51            }
52
53            let item = parse_quote!(#[inline] #signature { #call });
54            methods.push(item)
55        }
56
57        if let syn::TraitItem::Type(t) = item {
58            let t_ident = &t.ident;
59            let attrs = &t.attrs;
60
61            let t_generics = &t.generics;
62            let where_clause = &t.generics.where_clause;
63            let mut t_generic_names = t_generics.clone();
64            t_generic_names.params = generics_declaration_to_generics(&t_generics.params)?;
65
66            let item = parse_quote!( #(#attrs)* type #t_ident #t_generics = <#generic_type as #trait_ident #trait_generic_names>::#t_ident #t_generic_names #where_clause ; );
67            assoc_types.push(item);
68        }
69    }
70
71    // generate the impl block
72    Ok(parse_quote!(
73        #[automatically_derived]
74        impl #impl_generics #trait_ident #trait_generic_names for Box<#generic_type> #where_clause {
75            #(#assoc_types)*
76            #(#methods)*
77        }
78    ))
79}
80
#[cfg(test)]
mod tests {
    mod derive {

        use syn::parse_quote;

        /// Runs the `Box` derive on `input` and compares the generated impl
        /// block with `expected`, panicking at the caller's location on
        /// mismatch or on a derive error.
        #[track_caller]
        fn assert_derives(input: syn::ItemTrait, expected: syn::ItemImpl) {
            let derived = super::super::derive(&input).unwrap();
            assert_eq!(derived, expected);
        }

        /// A trait without items yields an empty impl block.
        #[test]
        fn empty() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait> MyTrait for Box<MT> {}
                ),
            );
        }

        /// `&self` methods dereference twice: the reference, then the box.
        #[test]
        fn receiver_ref() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        fn my_method(&self);
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait> MyTrait for Box<MT> {
                        #[inline]
                        fn my_method(&self) {
                            (*(*self)).my_method()
                        }
                    }
                ),
            );
        }

        /// `&mut self` methods dereference twice as well.
        #[test]
        fn receiver_mut() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        fn my_method(&mut self);
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait> MyTrait for Box<MT> {
                        #[inline]
                        fn my_method(&mut self) {
                            (*(*self)).my_method()
                        }
                    }
                ),
            );
        }

        /// By-value `self` methods dereference the box once.
        #[test]
        fn receiver_self() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        fn my_method(self);
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait> MyTrait for Box<MT> {
                        #[inline]
                        fn my_method(self) {
                            (*self).my_method()
                        }
                    }
                ),
            );
        }

        /// Explicitly typed receivers are rejected with an error.
        #[test]
        fn receiver_arbitrary() {
            let input: syn::ItemTrait = parse_quote!(
                trait MyTrait {
                    fn my_method(self: Box<Self>);
                }
            );
            let outcome = super::super::derive(&input);
            assert!(outcome.is_err());
        }

        /// Trait type parameters are carried over to the impl block.
        #[test]
        fn generics() {
            assert_derives(
                parse_quote!(
                    trait MyTrait<T> {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T, MT: MyTrait<T>> MyTrait<T> for Box<MT> {}
                ),
            );
        }

        /// Bounds stay on the impl generics but are not repeated in the trait path.
        #[test]
        fn generics_bounded() {
            assert_derives(
                parse_quote!(
                    trait MyTrait<T: 'static + Send> {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T: 'static + Send, MT: MyTrait<T>> MyTrait<T> for Box<MT> {}
                ),
            );
        }

        /// Lifetime parameters (with bounds) are handled like type parameters.
        #[test]
        fn generics_lifetime() {
            assert_derives(
                parse_quote!(
                    trait MyTrait<'a, 'b: 'a, T: 'static + Send> {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<'a, 'b: 'a, T: 'static + Send, MT: MyTrait<'a, 'b, T>> MyTrait<'a, 'b, T> for Box<MT> {}
                ),
            );
        }

        /// Associated types project onto the boxed type's implementation.
        #[test]
        fn associated_types() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        type Return;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait> MyTrait for Box<MT> {
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        /// Bounds on an associated type are dropped in the impl.
        #[test]
        fn associated_types_bound() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        type Return: Clone;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait> MyTrait for Box<MT> {
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        /// Raw identifiers survive as associated type names.
        #[test]
        fn associated_types_dodgy_name() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        type r#type;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait> MyTrait for Box<MT> {
                        type r#type = <MT as MyTrait>::r#type;
                    }
                ),
            );
        }

        /// Attributes on associated types are copied onto the generated items.
        #[test]
        fn associated_types_attrs() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        #[cfg(target_arch = "wasm32")]
                        type Return;
                        #[cfg(not(target_arch = "wasm32"))]
                        type Return: Send;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait> MyTrait for Box<MT> {
                        #[cfg(target_arch = "wasm32")]
                        type Return = <MT as MyTrait>::Return;
                        #[cfg(not(target_arch = "wasm32"))]
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        /// The projection path includes the trait's own generic arguments.
        #[test]
        fn associated_types_and_generics() {
            assert_derives(
                parse_quote!(
                    trait MyTrait<T> {
                        type Return;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T, MT: MyTrait<T>> MyTrait<T> for Box<MT> {
                        type Return = <MT as MyTrait<T>>::Return;
                    }
                ),
            );
        }

        /// Generic associated types forward their parameters.
        #[test]
        fn associated_type_generics() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        type Return<T>;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait> MyTrait for Box<MT> {
                        type Return<T> = <MT as MyTrait>::Return<T>;
                    }
                ),
            );
        }

        /// Bounds on GAT parameters stay on the declaration side only.
        #[test]
        fn associated_type_generics_bounded() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        type Return<T: 'static + Send>;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait> MyTrait for Box<MT> {
                        type Return<T: 'static + Send> = <MT as MyTrait>::Return<T>;
                    }
                ),
            );
        }

        /// A GAT's `where` clause is reproduced after the projection.
        #[test]
        fn associated_type_generics_lifetimes() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        type Return<'a>
                        where
                            Self: 'a;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait> MyTrait for Box<MT> {
                        type Return<'a> = <MT as MyTrait>::Return<'a>
                        where
                            Self: 'a;
                    }
                ),
            );
        }
    }
}
393}