blanket/derive/
arc.rs

1use syn::parse_quote;
2use syn::spanned::Spanned;
3
4use crate::utils::deref_expr;
5use crate::utils::generics_declaration_to_generics;
6use crate::utils::signature_to_method_call;
7use crate::utils::trait_to_generic_ident;
8
9pub fn derive(trait_: &syn::ItemTrait) -> syn::Result<syn::ItemImpl> {
10    // build an identifier for the generic type used for the implementation
11    let trait_ident = &trait_.ident;
12    let generic_type = trait_to_generic_ident(&trait_);
13
14    // build the generics for the impl block:
15    // we use the same generics as the trait itself, plus
16    // a generic type that implements the trait for which we provide the
17    // blanket implementation
18    let trait_generics = &trait_.generics;
19    let where_clause = &trait_.generics.where_clause;
20    let mut impl_generics = trait_generics.clone();
21
22    // we must however remove the generic type bounds, to avoid repeating them
23    let mut trait_generic_names = trait_generics.clone();
24    trait_generic_names.params = generics_declaration_to_generics(&trait_generics.params)?;
25
26    impl_generics.params.push(syn::GenericParam::Type(
27        parse_quote!(#generic_type: #trait_ident #trait_generic_names + ?Sized),
28    ));
29
30    // build the methods & associated types.
31    let mut methods: Vec<syn::ImplItemFn> = Vec::new();
32    let mut assoc_types: Vec<syn::ImplItemType> = Vec::new();
33    for item in trait_.items.iter() {
34        if let syn::TraitItem::Fn(ref m) = item {
35            if let Some(r) = m.sig.receiver() {
36                let err = if r.colon_token.is_some() {
37                    Some("cannot derive `Arc` for a trait declaring methods with arbitrary receiver types")
38                } else if r.mutability.is_some() {
39                    Some("cannot derive `Arc` for a trait declaring `&mut self` methods")
40                } else if r.reference.is_none() {
41                    Some("cannot derive `Arc` for a trait declaring `self` methods")
42                } else {
43                    None
44                };
45                if let Some(msg) = err {
46                    return Err(syn::Error::new(r.span(), msg));
47                }
48            }
49
50            let mut call = signature_to_method_call(&m.sig)?;
51            call.receiver = Box::new(deref_expr(deref_expr(*call.receiver)));
52
53            let signature = &m.sig;
54            let item = parse_quote!(#[inline] #signature { #call });
55            methods.push(item)
56        }
57
58        if let syn::TraitItem::Type(t) = item {
59            let t_ident = &t.ident;
60            let attrs = &t.attrs;
61
62            let t_generics = &t.generics;
63            let where_clause = &t.generics.where_clause;
64            let mut t_generic_names = t_generics.clone();
65            t_generic_names.params = generics_declaration_to_generics(&t_generics.params)?;
66
67            let item = parse_quote!( #(#attrs)* type #t_ident #t_generics = <#generic_type as #trait_ident #trait_generic_names>::#t_ident #t_generic_names #where_clause ; );
68            assoc_types.push(item);
69        }
70    }
71
72    Ok(parse_quote!(
73        #[automatically_derived]
74        impl #impl_generics #trait_ident #trait_generic_names for std::sync::Arc<#generic_type> #where_clause {
75            #(#assoc_types)*
76            #(#methods)*
77        }
78    ))
79}
80
#[cfg(test)]
mod tests {
    /// Snapshot tests: each case compares the generated `impl` block with the
    /// expected syntax tree (`syn` equality ignores spans).
    mod derive {

        use syn::parse_quote;

        use super::super::derive;

        #[test]
        fn empty() {
            let trait_ = parse_quote!(
                trait Trait {}
            );
            assert_eq!(
                derive(&trait_).unwrap(),
                parse_quote!(
                    #[automatically_derived]
                    impl<T: Trait + ?Sized> Trait for std::sync::Arc<T> {}
                )
            );
        }

        #[test]
        fn receiver_ref() {
            let trait_ = parse_quote!(
                trait Trait {
                    fn my_method(&self);
                }
            );
            assert_eq!(
                derive(&trait_).unwrap(),
                parse_quote!(
                    #[automatically_derived]
                    impl<T: Trait + ?Sized> Trait for std::sync::Arc<T> {
                        #[inline]
                        fn my_method(&self) {
                            (*(*self)).my_method()
                        }
                    }
                )
            );
        }

        #[test]
        fn receiver_mut() {
            let trait_ = parse_quote!(
                trait Trait {
                    fn my_method(&mut self);
                }
            );
            assert!(derive(&trait_).is_err());
        }

        #[test]
        fn receiver_self() {
            let trait_ = parse_quote!(
                trait Trait {
                    fn my_method(self);
                }
            );
            assert!(derive(&trait_).is_err());
        }

        #[test]
        fn receiver_mut_self() {
            // a by-value receiver with a `mut` binding is still a `self`
            // method and cannot be called through a shared `Arc` borrow
            let trait_ = parse_quote!(
                trait Trait {
                    fn my_method(mut self);
                }
            );
            assert!(derive(&trait_).is_err());
        }

        #[test]
        fn receiver_arbitrary() {
            let trait_ = parse_quote!(
                trait Trait {
                    fn my_method(self: Box<Self>);
                }
            );
            assert!(derive(&trait_).is_err());
        }

        #[test]
        fn generics() {
            let trait_ = parse_quote!(
                trait MyTrait<T> {}
            );
            let derived = derive(&trait_).unwrap();

            assert_eq!(
                derived,
                parse_quote!(
                    #[automatically_derived]
                    impl<T, MT: MyTrait<T> + ?Sized> MyTrait<T> for std::sync::Arc<MT> {}
                )
            );
        }

        #[test]
        fn generics_bounded() {
            let trait_ = parse_quote!(
                trait MyTrait<T: 'static + Send> {}
            );
            let derived = derive(&trait_).unwrap();

            assert_eq!(
                derived,
                parse_quote!(
                    #[automatically_derived]
                    impl<T: 'static + Send, MT: MyTrait<T> + ?Sized> MyTrait<T> for std::sync::Arc<MT> {}
                )
            );
        }

        #[test]
        fn generics_where_clause() {
            // the trait-level `where` clause must be carried over to the impl
            let trait_ = parse_quote!(
                trait MyTrait<T>
                where
                    T: Clone,
                {
                }
            );
            let derived = derive(&trait_).unwrap();

            assert_eq!(
                derived,
                parse_quote!(
                    #[automatically_derived]
                    impl<T, MT: MyTrait<T> + ?Sized> MyTrait<T> for std::sync::Arc<MT>
                    where
                        T: Clone,
                    {
                    }
                )
            );
        }

        #[test]
        fn generics_lifetime() {
            let trait_ = parse_quote!(
                trait MyTrait<'a, 'b: 'a, T: 'static + Send> {}
            );
            let derived = derive(&trait_).unwrap();

            assert_eq!(
                derived,
                parse_quote!(
                    #[automatically_derived]
                    impl<'a, 'b: 'a, T: 'static + Send, MT: MyTrait<'a, 'b, T> + ?Sized>
                        MyTrait<'a, 'b, T> for std::sync::Arc<MT>
                    {
                    }
                )
            );
        }

        #[test]
        fn associated_types() {
            let trait_ = parse_quote!(
                trait MyTrait {
                    type Return;
                }
            );
            let derived = derive(&trait_).unwrap();

            assert_eq!(
                derived,
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        type Return = <MT as MyTrait>::Return;
                    }
                )
            );
        }

        #[test]
        fn associated_types_bound() {
            // bounds on the trait's associated type are not repeated in the impl
            let trait_ = parse_quote!(
                trait MyTrait {
                    type Return: Clone;
                }
            );
            let derived = derive(&trait_).unwrap();

            assert_eq!(
                derived,
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        type Return = <MT as MyTrait>::Return;
                    }
                )
            );
        }

        #[test]
        fn associated_types_dodgy_name() {
            // raw identifiers must survive the round-trip through `quote`
            let trait_ = parse_quote!(
                trait MyTrait {
                    type r#type;
                }
            );
            let derived = derive(&trait_).unwrap();

            assert_eq!(
                derived,
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        type r#type = <MT as MyTrait>::r#type;
                    }
                )
            );
        }

        #[test]
        fn associated_types_attrs() {
            // `cfg` attributes select which declaration is compiled, so they
            // must be copied onto the forwarded associated types
            let trait_ = parse_quote!(
                trait MyTrait {
                    #[cfg(target_arch = "wasm32")]
                    type Return;
                    #[cfg(not(target_arch = "wasm32"))]
                    type Return: Send;
                }
            );
            let derived = derive(&trait_).unwrap();

            assert_eq!(
                derived,
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        #[cfg(target_arch = "wasm32")]
                        type Return = <MT as MyTrait>::Return;
                        #[cfg(not(target_arch = "wasm32"))]
                        type Return = <MT as MyTrait>::Return;
                    }
                )
            );
        }

        #[test]
        fn associated_types_and_generics() {
            let trait_ = parse_quote!(
                trait MyTrait<T> {
                    type Return;
                }
            );
            let derived = derive(&trait_).unwrap();

            assert_eq!(
                derived,
                parse_quote!(
                    #[automatically_derived]
                    impl<T, MT: MyTrait<T> + ?Sized> MyTrait<T> for std::sync::Arc<MT> {
                        type Return = <MT as MyTrait<T>>::Return;
                    }
                )
            );
        }

        #[test]
        fn associated_type_generics() {
            let trait_ = parse_quote!(
                trait MyTrait {
                    type Return<T>;
                }
            );
            let derived = derive(&trait_).unwrap();

            assert_eq!(
                derived,
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        type Return<T> = <MT as MyTrait>::Return<T>;
                    }
                )
            );
        }

        #[test]
        fn associated_type_generics_bounded() {
            // bounds stay on the GAT declaration but not on its use site
            let trait_ = parse_quote!(
                trait MyTrait {
                    type Return<T: 'static + Send>;
                }
            );
            let derived = derive(&trait_).unwrap();

            assert_eq!(
                derived,
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        type Return<T: 'static + Send> = <MT as MyTrait>::Return<T>;
                    }
                )
            );
        }

        #[test]
        fn associated_type_generics_lifetimes() {
            let trait_ = parse_quote!(
                trait MyTrait {
                    type Return<'a>
                    where
                        Self: 'a;
                }
            );
            let derived = derive(&trait_).unwrap();

            assert_eq!(
                derived,
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        type Return<'a> = <MT as MyTrait>::Return<'a>
                        where
                            Self: 'a;
                    }
                )
            );
        }
    }
}