1use syn::parse_quote;
2use syn::spanned::Spanned;
3
4use crate::utils::deref_expr;
5use crate::utils::generics_declaration_to_generics;
6use crate::utils::signature_to_method_call;
7use crate::utils::trait_to_generic_ident;
8
9pub fn derive(trait_: &syn::ItemTrait) -> syn::Result<syn::ItemImpl> {
10 let trait_ident = &trait_.ident;
12 let generic_type = trait_to_generic_ident(&trait_);
13
14 let trait_generics = &trait_.generics;
19 let where_clause = &trait_.generics.where_clause;
20 let mut impl_generics = trait_generics.clone();
21
22 let mut trait_generic_names = trait_generics.clone();
24 trait_generic_names.params = generics_declaration_to_generics(&trait_generics.params)?;
25
26 impl_generics.params.push(syn::GenericParam::Type(
27 parse_quote!(#generic_type: #trait_ident #trait_generic_names + ?Sized),
28 ));
29
30 let mut methods: Vec<syn::ImplItemFn> = Vec::new();
32 let mut assoc_types: Vec<syn::ImplItemType> = Vec::new();
33 for item in trait_.items.iter() {
34 if let syn::TraitItem::Fn(ref m) = item {
35 if let Some(r) = m.sig.receiver() {
36 let err = if r.colon_token.is_some() {
37 Some("cannot derive `Arc` for a trait declaring methods with arbitrary receiver types")
38 } else if r.mutability.is_some() {
39 Some("cannot derive `Arc` for a trait declaring `&mut self` methods")
40 } else if r.reference.is_none() {
41 Some("cannot derive `Arc` for a trait declaring `self` methods")
42 } else {
43 None
44 };
45 if let Some(msg) = err {
46 return Err(syn::Error::new(r.span(), msg));
47 }
48 }
49
50 let mut call = signature_to_method_call(&m.sig)?;
51 call.receiver = Box::new(deref_expr(deref_expr(*call.receiver)));
52
53 let signature = &m.sig;
54 let item = parse_quote!(#[inline] #signature { #call });
55 methods.push(item)
56 }
57
58 if let syn::TraitItem::Type(t) = item {
59 let t_ident = &t.ident;
60 let attrs = &t.attrs;
61
62 let t_generics = &t.generics;
63 let where_clause = &t.generics.where_clause;
64 let mut t_generic_names = t_generics.clone();
65 t_generic_names.params = generics_declaration_to_generics(&t_generics.params)?;
66
67 let item = parse_quote!( #(#attrs)* type #t_ident #t_generics = <#generic_type as #trait_ident #trait_generic_names>::#t_ident #t_generic_names #where_clause ; );
68 assoc_types.push(item);
69 }
70 }
71
72 Ok(parse_quote!(
73 #[automatically_derived]
74 impl #impl_generics #trait_ident #trait_generic_names for std::sync::Arc<#generic_type> #where_clause {
75 #(#assoc_types)*
76 #(#methods)*
77 }
78 ))
79}
80
#[cfg(test)]
mod tests {
    mod derive {
        use syn::parse_quote;

        /// Expands `input` with the `Arc` derive and checks the result equals `expected`.
        #[track_caller]
        fn assert_expands(input: syn::ItemTrait, expected: syn::ItemImpl) {
            let actual = super::super::derive(&input).unwrap();
            assert_eq!(actual, expected);
        }

        /// Checks that the `Arc` derive refuses to expand `input`.
        #[track_caller]
        fn assert_rejected(input: syn::ItemTrait) {
            let outcome = super::super::derive(&input);
            assert!(outcome.is_err());
        }

        #[test]
        fn empty() {
            assert_expands(
                parse_quote!(trait Trait {}),
                parse_quote!(
                    #[automatically_derived]
                    impl<T: Trait + ?Sized> Trait for std::sync::Arc<T> {}
                ),
            );
        }

        #[test]
        fn receiver_ref() {
            assert_expands(
                parse_quote!(
                    trait Trait {
                        fn my_method(&self);
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T: Trait + ?Sized> Trait for std::sync::Arc<T> {
                        #[inline]
                        fn my_method(&self) {
                            (*(*self)).my_method()
                        }
                    }
                ),
            );
        }

        #[test]
        fn receiver_mut() {
            assert_rejected(parse_quote!(
                trait Trait {
                    fn my_method(&mut self);
                }
            ));
        }

        #[test]
        fn receiver_self() {
            assert_rejected(parse_quote!(
                trait Trait {
                    fn my_method(self);
                }
            ));
        }

        #[test]
        fn receiver_arbitrary() {
            assert_rejected(parse_quote!(
                trait Trait {
                    fn my_method(self: Box<Self>);
                }
            ));
        }

        #[test]
        fn generics() {
            assert_expands(
                parse_quote!(trait MyTrait<T> {}),
                parse_quote!(
                    #[automatically_derived]
                    impl<T, MT: MyTrait<T> + ?Sized> MyTrait<T> for std::sync::Arc<MT> {}
                ),
            );
        }

        #[test]
        fn generics_bounded() {
            assert_expands(
                parse_quote!(trait MyTrait<T: 'static + Send> {}),
                parse_quote!(
                    #[automatically_derived]
                    impl<T: 'static + Send, MT: MyTrait<T> + ?Sized> MyTrait<T> for std::sync::Arc<MT> {}
                ),
            );
        }

        #[test]
        fn generics_lifetime() {
            assert_expands(
                parse_quote!(trait MyTrait<'a, 'b: 'a, T: 'static + Send> {}),
                parse_quote!(
                    #[automatically_derived]
                    impl<'a, 'b: 'a, T: 'static + Send, MT: MyTrait<'a, 'b, T> + ?Sized>
                        MyTrait<'a, 'b, T> for std::sync::Arc<MT>
                    {
                    }
                ),
            );
        }

        #[test]
        fn associated_types() {
            assert_expands(
                parse_quote!(
                    trait MyTrait {
                        type Return;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_bound() {
            assert_expands(
                parse_quote!(
                    trait MyTrait {
                        type Return: Clone;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_dodgy_name() {
            assert_expands(
                parse_quote!(
                    trait MyTrait {
                        type r#type;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        type r#type = <MT as MyTrait>::r#type;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_attrs() {
            assert_expands(
                parse_quote!(
                    trait MyTrait {
                        #[cfg(target_arch = "wasm32")]
                        type Return;
                        #[cfg(not(target_arch = "wasm32"))]
                        type Return: Send;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        #[cfg(target_arch = "wasm32")]
                        type Return = <MT as MyTrait>::Return;
                        #[cfg(not(target_arch = "wasm32"))]
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_and_generics() {
            assert_expands(
                parse_quote!(
                    trait MyTrait<T> {
                        type Return;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T, MT: MyTrait<T> + ?Sized> MyTrait<T> for std::sync::Arc<MT> {
                        type Return = <MT as MyTrait<T>>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_type_generics() {
            assert_expands(
                parse_quote!(
                    trait MyTrait {
                        type Return<T>;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        type Return<T> = <MT as MyTrait>::Return<T>;
                    }
                ),
            );
        }

        #[test]
        fn associated_type_generics_bounded() {
            assert_expands(
                parse_quote!(
                    trait MyTrait {
                        type Return<T: 'static + Send>;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        type Return<T: 'static + Send> = <MT as MyTrait>::Return<T>;
                    }
                ),
            );
        }

        #[test]
        fn associated_type_generics_lifetimes() {
            assert_expands(
                parse_quote!(
                    trait MyTrait {
                        type Return<'a>
                        where
                            Self: 'a;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for std::sync::Arc<MT> {
                        type Return<'a> = <MT as MyTrait>::Return<'a>
                        where
                            Self: 'a;
                    }
                ),
            );
        }
    }
}