1use syn::{parse_quote, spanned::Spanned};
2
3use crate::utils::{
4 deref_expr, generics_declaration_to_generics, signature_to_method_call, trait_to_generic_ident,
5};
6
7pub fn derive(trait_: &syn::ItemTrait) -> syn::Result<syn::ItemImpl> {
8 let trait_ident = &trait_.ident;
10 let generic_type = trait_to_generic_ident(&trait_);
11
12 let trait_generics = &trait_.generics;
17 let where_clause = &trait_.generics.where_clause;
18 let mut impl_generics = trait_generics.clone();
19
20 let mut trait_generic_names = trait_generics.clone();
22 trait_generic_names.params = generics_declaration_to_generics(&trait_generics.params)?;
23
24 impl_generics.params.push(syn::GenericParam::Type(
25 parse_quote!(#generic_type: #trait_ident #trait_generic_names + ?Sized),
26 ));
27
28 let mut methods: Vec<syn::ImplItemFn> = Vec::new();
30 let mut assoc_types: Vec<syn::ImplItemType> = Vec::new();
31 for item in trait_.items.iter() {
32 if let syn::TraitItem::Fn(ref m) = item {
33 if let Some(r) = m.sig.receiver() {
34 let err = if r.colon_token.is_some() {
35 Some("cannot derive `Ref` for a trait declaring methods with arbitrary receiver types")
36 } else if r.mutability.is_some() {
37 Some("cannot derive `Ref` for a trait declaring `&mut self` methods")
38 } else if r.reference.is_none() {
39 Some("cannot derive `Ref` for a trait declaring `self` methods")
40 } else {
41 None
42 };
43 if let Some(msg) = err {
44 return Err(syn::Error::new(r.span(), msg));
45 }
46 }
47
48 let mut call = signature_to_method_call(&m.sig)?;
49 call.receiver = Box::new(deref_expr(deref_expr(*call.receiver)));
50
51 let signature = &m.sig;
52 let item = parse_quote!(#[inline] #signature { #call });
53 methods.push(item)
54 }
55
56 if let syn::TraitItem::Type(t) = item {
57 let t_ident = &t.ident;
58 let attrs = &t.attrs;
59
60 let t_generics = &t.generics;
61 let where_clause = &t.generics.where_clause;
62 let mut t_generic_names = t_generics.clone();
63 t_generic_names.params = generics_declaration_to_generics(&t_generics.params)?;
64
65 let item = parse_quote!( #(#attrs)* type #t_ident #t_generics = <#generic_type as #trait_ident #trait_generic_names>::#t_ident #t_generic_names #where_clause ; );
66 assoc_types.push(item);
67 }
68 }
69
70 Ok(parse_quote!(
71 #[automatically_derived]
72 impl #impl_generics #trait_ident #trait_generic_names for &#generic_type #where_clause {
73 #(#assoc_types)*
74 #(#methods)*
75 }
76 ))
77}
78
79#[cfg(test)]
80mod tests {
81 mod derive {
82
83 use syn::parse_quote;
84
85 #[test]
86 fn empty() {
87 let trait_ = parse_quote!(
88 trait Trait {}
89 );
90 assert_eq!(
91 super::super::derive(&trait_).unwrap(),
92 parse_quote!(
93 #[automatically_derived]
94 impl<T: Trait + ?Sized> Trait for &T {}
95 )
96 );
97 }
98
99 #[test]
100 fn receiver_ref() {
101 let trait_ = parse_quote!(
102 trait Trait {
103 fn my_method(&self);
104 }
105 );
106 assert_eq!(
107 super::super::derive(&trait_).unwrap(),
108 parse_quote!(
109 #[automatically_derived]
110 impl<T: Trait + ?Sized> Trait for &T {
111 #[inline]
112 fn my_method(&self) {
113 (*(*self)).my_method()
114 }
115 }
116 )
117 );
118 }
119
120 #[test]
121 fn receiver_mut() {
122 let trait_ = parse_quote!(
123 trait Trait {
124 fn my_method(&mut self);
125 }
126 );
127 assert!(super::super::derive(&trait_).is_err());
128 }
129
130 #[test]
131 fn receiver_self() {
132 let trait_ = parse_quote!(
133 trait Trait {
134 fn my_method(self);
135 }
136 );
137 assert!(super::super::derive(&trait_).is_err());
138 }
139
140 #[test]
141 fn receiver_arbitrary() {
142 let trait_ = parse_quote!(
143 trait Trait {
144 fn my_method(self: Box<Self>);
145 }
146 );
147 assert!(super::super::derive(&trait_).is_err());
148 }
149
150 #[test]
151 fn generics() {
152 let trait_ = parse_quote!(
153 trait MyTrait<T> {}
154 );
155 let derived = super::super::derive(&trait_).unwrap();
156
157 assert_eq!(
158 derived,
159 parse_quote!(
160 #[automatically_derived]
161 impl<T, MT: MyTrait<T> + ?Sized> MyTrait<T> for &MT {}
162 )
163 );
164 }
165
166 #[test]
167 fn generics_bounded() {
168 let trait_ = parse_quote!(
169 trait MyTrait<T: 'static + Send> {}
170 );
171 let derived = super::super::derive(&trait_).unwrap();
172
173 assert_eq!(
174 derived,
175 parse_quote!(
176 #[automatically_derived]
177 impl<T: 'static + Send, MT: MyTrait<T> + ?Sized> MyTrait<T> for &MT {}
178 )
179 );
180 }
181
182 #[test]
183 fn generics_lifetime() {
184 let trait_ = parse_quote!(
185 trait MyTrait<'a, 'b: 'a, T: 'static + Send> {}
186 );
187 let derived = super::super::derive(&trait_).unwrap();
188
189 assert_eq!(
190 derived,
191 parse_quote!(
192 #[automatically_derived]
193 impl<'a, 'b: 'a, T: 'static + Send, MT: MyTrait<'a, 'b, T> + ?Sized>
194 MyTrait<'a, 'b, T> for &MT
195 {
196 }
197 )
198 );
199 }
200
201 #[test]
202 fn associated_types() {
203 let trait_ = parse_quote!(
204 trait MyTrait {
205 type Return;
206 }
207 );
208 let derived = super::super::derive(&trait_).unwrap();
209
210 assert_eq!(
211 derived,
212 parse_quote!(
213 #[automatically_derived]
214 impl<MT: MyTrait + ?Sized> MyTrait for &MT {
215 type Return = <MT as MyTrait>::Return;
216 }
217 )
218 );
219 }
220
221 #[test]
222 fn associated_types_bound() {
223 let trait_ = parse_quote!(
224 trait MyTrait {
225 type Return: Clone;
226 }
227 );
228 let derived = super::super::derive(&trait_).unwrap();
229
230 assert_eq!(
231 derived,
232 parse_quote!(
233 #[automatically_derived]
234 impl<MT: MyTrait + ?Sized> MyTrait for &MT {
235 type Return = <MT as MyTrait>::Return;
236 }
237 )
238 );
239 }
240
241 #[test]
242 fn associated_types_dodgy_name() {
243 let trait_ = parse_quote!(
244 trait MyTrait {
245 type r#type;
246 }
247 );
248 let derived = super::super::derive(&trait_).unwrap();
249
250 assert_eq!(
251 derived,
252 parse_quote!(
253 #[automatically_derived]
254 impl<MT: MyTrait + ?Sized> MyTrait for &MT {
255 type r#type = <MT as MyTrait>::r#type;
256 }
257 )
258 );
259 }
260
261 #[test]
262 fn associated_types_attrs() {
263 let trait_ = parse_quote!(
264 trait MyTrait {
265 #[cfg(target_arch = "wasm32")]
266 type Return;
267 #[cfg(not(target_arch = "wasm32"))]
268 type Return: Send;
269 }
270 );
271 let derived = super::super::derive(&trait_).unwrap();
272
273 assert_eq!(
274 derived,
275 parse_quote!(
276 #[automatically_derived]
277 impl<MT: MyTrait + ?Sized> MyTrait for &MT {
278 #[cfg(target_arch = "wasm32")]
279 type Return = <MT as MyTrait>::Return;
280 #[cfg(not(target_arch = "wasm32"))]
281 type Return = <MT as MyTrait>::Return;
282 }
283 )
284 );
285 }
286
287 #[test]
288 fn associated_types_and_generics() {
289 let trait_ = parse_quote!(
290 trait MyTrait<T> {
291 type Return;
292 }
293 );
294 let derived = super::super::derive(&trait_).unwrap();
295
296 assert_eq!(
297 derived,
298 parse_quote!(
299 #[automatically_derived]
300 impl<T, MT: MyTrait<T> + ?Sized> MyTrait<T> for &MT {
301 type Return = <MT as MyTrait<T>>::Return;
302 }
303 )
304 );
305 }
306
307 #[test]
308 fn associated_type_generics() {
309 let trait_ = parse_quote!(
310 trait MyTrait {
311 type Return<T>;
312 }
313 );
314 let derived = super::super::derive(&trait_).unwrap();
315
316 assert_eq!(
317 derived,
318 parse_quote!(
319 #[automatically_derived]
320 impl<MT: MyTrait + ?Sized> MyTrait for &MT {
321 type Return<T> = <MT as MyTrait>::Return<T>;
322 }
323 )
324 );
325 }
326
327 #[test]
328 fn associated_type_generics_bounded() {
329 let trait_ = parse_quote!(
330 trait MyTrait {
331 type Return<T: 'static + Send>;
332 }
333 );
334 let derived = super::super::derive(&trait_).unwrap();
335
336 assert_eq!(
337 derived,
338 parse_quote!(
339 #[automatically_derived]
340 impl<MT: MyTrait + ?Sized> MyTrait for &MT {
341 type Return<T: 'static + Send> = <MT as MyTrait>::Return<T>;
342 }
343 )
344 );
345 }
346
347 #[test]
348 fn associated_type_generics_lifetimes() {
349 let trait_ = parse_quote!(
350 trait MyTrait {
351 type Return<'a>
352 where
353 Self: 'a;
354 }
355 );
356 let derived = super::super::derive(&trait_).unwrap();
357
358 assert_eq!(
359 derived,
360 parse_quote!(
361 #[automatically_derived]
362 impl<MT: MyTrait + ?Sized> MyTrait for &MT {
363 type Return<'a> = <MT as MyTrait>::Return<'a>
364 where
365 Self: 'a;
366 }
367 )
368 );
369 }
370 }
371}