blanket/derive/
rc.rs

1use syn::{parse_quote, spanned::Spanned};
2
3use crate::utils::{
4    deref_expr, generics_declaration_to_generics, signature_to_method_call, trait_to_generic_ident,
5};
6
7pub fn derive(trait_: &syn::ItemTrait) -> syn::Result<syn::ItemImpl> {
8    // build an identifier for the generic type used for the implementation
9    let trait_ident = &trait_.ident;
10    let generic_type = trait_to_generic_ident(&trait_);
11
12    // build the generics for the impl block:
13    // we use the same generics as the trait itself, plus
14    // a generic type that implements the trait for which we provide the
15    // blanket implementation
16    let trait_generics = &trait_.generics;
17    let where_clause = &trait_.generics.where_clause;
18    let mut impl_generics = trait_generics.clone();
19
20    // we must however remove the generic type bounds, to avoid repeating them
21    let mut trait_generic_names = trait_generics.clone();
22    trait_generic_names.params = generics_declaration_to_generics(&trait_generics.params)?;
23
24    impl_generics.params.push(syn::GenericParam::Type(
25        parse_quote!(#generic_type: #trait_ident #trait_generic_names + ?Sized),
26    ));
27
28    // build the methods
29    let mut methods: Vec<syn::ImplItemFn> = Vec::new();
30    let mut assoc_types: Vec<syn::ImplItemType> = Vec::new();
31    for item in trait_.items.iter() {
32        if let syn::TraitItem::Fn(ref m) = item {
33            if let Some(r) = m.sig.receiver() {
34                let err = if r.colon_token.is_some() {
35                    Some("cannot derive `Rc` for a trait declaring methods with arbitrary receiver types")
36                } else if r.mutability.is_some() {
37                    Some("cannot derive `Rc` for a trait declaring `&mut self` methods")
38                } else if r.reference.is_none() {
39                    Some("cannot derive `Rc` for a trait declaring `self` methods")
40                } else {
41                    None
42                };
43                if let Some(msg) = err {
44                    return Err(syn::Error::new(r.span(), msg));
45                }
46            }
47
48            let mut call = signature_to_method_call(&m.sig)?;
49            call.receiver = Box::new(deref_expr(deref_expr(*call.receiver)));
50
51            let signature = &m.sig;
52            let item = parse_quote!(#[inline] #signature { #call });
53            methods.push(item)
54        }
55
56        if let syn::TraitItem::Type(t) = item {
57            let t_ident = &t.ident;
58            let attrs = &t.attrs;
59
60            let t_generics = &t.generics;
61            let where_clause = &t.generics.where_clause;
62            let mut t_generic_names = t_generics.clone();
63            t_generic_names.params = generics_declaration_to_generics(&t_generics.params)?;
64
65            let item = parse_quote!( #(#attrs)* type #t_ident #t_generics = <#generic_type as #trait_ident #trait_generic_names>::#t_ident #t_generic_names #where_clause ; );
66            assoc_types.push(item);
67        }
68    }
69
70    Ok(parse_quote!(
71        #[automatically_derived]
72        impl #impl_generics #trait_ident #trait_generic_names for std::rc::Rc<#generic_type> #where_clause {
73            #(#assoc_types)*
74            #(#methods)*
75        }
76    ))
77}
78
#[cfg(test)]
mod tests {
    mod derive {

        use syn::parse_quote;

        use super::super::derive as derive_rc;

        /// Asserts that deriving `Rc` for `trait_` succeeds and produces
        /// exactly the `expected` impl block.
        #[track_caller]
        fn assert_derives(trait_: syn::ItemTrait, expected: syn::ItemImpl) {
            assert_eq!(derive_rc(&trait_).unwrap(), expected);
        }

        /// Asserts that deriving `Rc` for `trait_` is refused with an error.
        #[track_caller]
        fn assert_rejected(trait_: syn::ItemTrait) {
            assert!(derive_rc(&trait_).is_err());
        }

        // --- receivers ---------------------------------------------------

        #[test]
        fn empty() {
            assert_derives(
                parse_quote!(
                    trait Trait {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T: Trait + ?Sized> Trait for std::rc::Rc<T> {}
                ),
            );
        }

        #[test]
        fn receiver_ref() {
            assert_derives(
                parse_quote!(
                    trait Trait {
                        fn my_method(&self);
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T: Trait + ?Sized> Trait for std::rc::Rc<T> {
                        #[inline]
                        fn my_method(&self) {
                            (*(*self)).my_method()
                        }
                    }
                ),
            );
        }

        #[test]
        fn receiver_mut() {
            assert_rejected(parse_quote!(
                trait Trait {
                    fn my_method(&mut self);
                }
            ));
        }

        #[test]
        fn receiver_self() {
            assert_rejected(parse_quote!(
                trait Trait {
                    fn my_method(self);
                }
            ));
        }

        #[test]
        fn receiver_arbitrary() {
            assert_rejected(parse_quote!(
                trait Trait {
                    fn my_method(self: Box<Self>);
                }
            ));
        }

        // --- trait generics ----------------------------------------------

        #[test]
        fn generics() {
            assert_derives(
                parse_quote!(
                    trait MyTrait<T> {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T, MT: MyTrait<T> + ?Sized> MyTrait<T> for std::rc::Rc<MT> {}
                ),
            );
        }

        #[test]
        fn generics_bounded() {
            assert_derives(
                parse_quote!(
                    trait MyTrait<T: 'static + Send> {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T: 'static + Send, MT: MyTrait<T> + ?Sized> MyTrait<T> for std::rc::Rc<MT> {}
                ),
            );
        }

        #[test]
        fn generics_lifetime() {
            assert_derives(
                parse_quote!(
                    trait MyTrait<'a, 'b: 'a, T: 'static + Send> {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<'a, 'b: 'a, T: 'static + Send, MT: MyTrait<'a, 'b, T> + ?Sized>
                        MyTrait<'a, 'b, T> for std::rc::Rc<MT>
                    {
                    }
                ),
            );
        }

        // --- associated types --------------------------------------------

        #[test]
        fn associated_types() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        type Return;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::rc::Rc<MT> {
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_bound() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        type Return: Clone;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::rc::Rc<MT> {
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_dodgy_name() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        type r#type;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::rc::Rc<MT> {
                        type r#type = <MT as MyTrait>::r#type;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_attrs() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        #[cfg(target_arch = "wasm32")]
                        type Return;
                        #[cfg(not(target_arch = "wasm32"))]
                        type Return: Send;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::rc::Rc<MT> {
                        #[cfg(target_arch = "wasm32")]
                        type Return = <MT as MyTrait>::Return;
                        #[cfg(not(target_arch = "wasm32"))]
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_and_generics() {
            assert_derives(
                parse_quote!(
                    trait MyTrait<T> {
                        type Return;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T, MT: MyTrait<T> + ?Sized> MyTrait<T> for std::rc::Rc<MT> {
                        type Return = <MT as MyTrait<T>>::Return;
                    }
                ),
            );
        }

        // --- generic associated types --------------------------------------

        #[test]
        fn associated_type_generics() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        type Return<T>;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::rc::Rc<MT> {
                        type Return<T> = <MT as MyTrait>::Return<T>;
                    }
                ),
            );
        }

        #[test]
        fn associated_type_generics_bounded() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        type Return<T: 'static + Send>;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::rc::Rc<MT> {
                        type Return<T: 'static + Send> = <MT as MyTrait>::Return<T>;
                    }
                ),
            );
        }

        #[test]
        fn associated_type_generics_lifetimes() {
            assert_derives(
                parse_quote!(
                    trait MyTrait {
                        type Return<'a>
                        where
                            Self: 'a;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::rc::Rc<MT> {
                        type Return<'a> = <MT as MyTrait>::Return<'a>
                        where
                            Self: 'a;
                    }
                ),
            );
        }
    }
}