1use syn::{parse_quote, spanned::Spanned};
2
3use crate::utils::{
4 deref_expr, generics_declaration_to_generics, signature_to_method_call, trait_to_generic_ident,
5};
6
7pub fn derive(trait_: &syn::ItemTrait) -> syn::Result<syn::ItemImpl> {
8 let trait_ident = &trait_.ident;
10 let generic_type = trait_to_generic_ident(&trait_);
11
12 let trait_generics = &trait_.generics;
17 let where_clause = &trait_.generics.where_clause;
18 let mut impl_generics = trait_generics.clone();
19
20 let mut trait_generic_names = trait_generics.clone();
22 trait_generic_names.params = generics_declaration_to_generics(&trait_generics.params)?;
23
24 impl_generics.params.push(syn::GenericParam::Type(
25 parse_quote!(#generic_type: #trait_ident #trait_generic_names + ?Sized),
26 ));
27
28 let mut methods: Vec<syn::ImplItemFn> = Vec::new();
30 let mut assoc_types: Vec<syn::ImplItemType> = Vec::new();
31 for item in trait_.items.iter() {
32 if let syn::TraitItem::Fn(ref m) = item {
33 if let Some(r) = m.sig.receiver() {
34 let err = if r.colon_token.is_some() {
35 Some("cannot derive `Mut` for a trait declaring methods with arbitrary receiver types")
36 } else if r.reference.is_none() {
37 Some("cannot derive `Mut` for a trait declaring `self` methods")
38 } else {
39 None
40 };
41 if let Some(msg) = err {
42 return Err(syn::Error::new(r.span(), msg));
43 }
44 }
45
46 let mut call = signature_to_method_call(&m.sig)?;
47 call.receiver = Box::new(deref_expr(deref_expr(*call.receiver)));
48
49 let signature = &m.sig;
50 let item = parse_quote!(#[inline] #signature { #call });
51 methods.push(item)
52 }
53
54 if let syn::TraitItem::Type(t) = item {
55 let t_ident = &t.ident;
56 let attrs = &t.attrs;
57
58 let t_generics = &t.generics;
59 let where_clause = &t.generics.where_clause;
60 let mut t_generic_names = t_generics.clone();
61 t_generic_names.params = generics_declaration_to_generics(&t_generics.params)?;
62
63 let item = parse_quote!( #(#attrs)* type #t_ident #t_generics = <#generic_type as #trait_ident #trait_generic_names>::#t_ident #t_generic_names #where_clause ; );
64 assoc_types.push(item);
65 }
66 }
67
68 Ok(parse_quote!(
69 #[automatically_derived]
70 impl #impl_generics #trait_ident #trait_generic_names for &mut #generic_type #where_clause {
71 #(#assoc_types)*
72 #(#methods)*
73 }
74 ))
75}
76
#[cfg(test)]
mod tests {
    mod derive {

        use syn::parse_quote;

        /// Runs the derive on `input` and checks the generated impl block
        /// is exactly `expected`.
        #[track_caller]
        fn assert_expands_to(input: syn::ItemTrait, expected: syn::ItemImpl) {
            let actual = super::super::derive(&input).unwrap();
            assert_eq!(actual, expected);
        }

        /// Checks that the derive reports an error for `input`.
        #[track_caller]
        fn assert_rejected(input: syn::ItemTrait) {
            assert!(super::super::derive(&input).is_err());
        }

        #[test]
        fn empty() {
            assert_expands_to(
                parse_quote!(
                    trait MyTrait {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {}
                ),
            );
        }

        #[test]
        fn receiver_mut() {
            assert_expands_to(
                parse_quote!(
                    trait MyTrait {
                        fn my_method(&mut self);
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        #[inline]
                        fn my_method(&mut self) {
                            (*(*self)).my_method()
                        }
                    }
                ),
            );
        }

        #[test]
        fn receiver_self() {
            assert_rejected(parse_quote!(
                trait MyTrait {
                    fn my_method(self);
                }
            ));
        }

        #[test]
        fn receiver_arbitrary() {
            assert_rejected(parse_quote!(
                trait MyTrait {
                    fn my_method(self: Box<Self>);
                }
            ));
        }

        #[test]
        fn generics() {
            assert_expands_to(
                parse_quote!(
                    trait Trait<T> {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T, T_: Trait<T> + ?Sized> Trait<T> for &mut T_ {}
                ),
            );
        }

        #[test]
        fn generics_bounded() {
            assert_expands_to(
                parse_quote!(
                    trait Trait<T: 'static + Send> {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T: 'static + Send, T_: Trait<T> + ?Sized> Trait<T> for &mut T_ {}
                ),
            );
        }

        #[test]
        fn generics_lifetime() {
            assert_expands_to(
                parse_quote!(
                    trait Trait<'a, 'b: 'a, T: 'static + Send> {}
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<'a, 'b: 'a, T: 'static + Send, T_: Trait<'a, 'b, T> + ?Sized>
                        Trait<'a, 'b, T> for &mut T_
                    {
                    }
                ),
            );
        }

        #[test]
        fn associated_types() {
            assert_expands_to(
                parse_quote!(
                    trait MyTrait {
                        type Return;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_bound() {
            assert_expands_to(
                parse_quote!(
                    trait MyTrait {
                        type Return: Clone;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_dodgy_name() {
            assert_expands_to(
                parse_quote!(
                    trait MyTrait {
                        type r#type;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        type r#type = <MT as MyTrait>::r#type;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_attrs() {
            assert_expands_to(
                parse_quote!(
                    trait MyTrait {
                        #[cfg(target_arch = "wasm32")]
                        type Return;
                        #[cfg(not(target_arch = "wasm32"))]
                        type Return: Send;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        #[cfg(target_arch = "wasm32")]
                        type Return = <MT as MyTrait>::Return;
                        #[cfg(not(target_arch = "wasm32"))]
                        type Return = <MT as MyTrait>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_types_and_generics() {
            assert_expands_to(
                parse_quote!(
                    trait MyTrait<T> {
                        type Return;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<T, MT: MyTrait<T> + ?Sized> MyTrait<T> for &mut MT {
                        type Return = <MT as MyTrait<T>>::Return;
                    }
                ),
            );
        }

        #[test]
        fn associated_type_generics() {
            assert_expands_to(
                parse_quote!(
                    trait MyTrait {
                        type Return<T>;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        type Return<T> = <MT as MyTrait>::Return<T>;
                    }
                ),
            );
        }

        #[test]
        fn associated_type_generics_bounded() {
            assert_expands_to(
                parse_quote!(
                    trait MyTrait {
                        type Return<T: 'static + Send>;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        type Return<T: 'static + Send> = <MT as MyTrait>::Return<T>;
                    }
                ),
            );
        }

        #[test]
        fn associated_type_generics_lifetimes() {
            assert_expands_to(
                parse_quote!(
                    trait MyTrait {
                        type Return<'a>
                        where
                            Self: 'a;
                    }
                ),
                parse_quote!(
                    #[automatically_derived]
                    impl<MT: MyTrait + ?Sized> MyTrait for &mut MT {
                        type Return<'a> = <MT as MyTrait>::Return<'a>
                        where
                            Self: 'a;
                    }
                ),
            );
        }
    }
}