async_trait/
expand.rs

1use crate::bound::{has_bound, InferredBound, Supertraits};
2use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes};
3use crate::parse::Item;
4use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
5use crate::verbatim::VerbatimFn;
6use proc_macro2::{Span, TokenStream};
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use std::collections::BTreeSet as Set;
9use std::mem;
10use syn::punctuated::Punctuated;
11use syn::visit_mut::{self, VisitMut};
12use syn::{
13    parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam,
14    Generics, Ident, ImplItem, Lifetime, LifetimeParam, Pat, PatIdent, PathArguments, Receiver,
15    ReturnType, Signature, Token, TraitItem, Type, TypePath, WhereClause,
16};
17
18impl ToTokens for Item {
19    fn to_tokens(&self, tokens: &mut TokenStream) {
20        match self {
21            Item::Trait(item) => item.to_tokens(tokens),
22            Item::Impl(item) => item.to_tokens(tokens),
23        }
24    }
25}
26
27#[derive(Clone, Copy)]
28enum Context<'a> {
29    Trait {
30        generics: &'a Generics,
31        supertraits: &'a Supertraits,
32    },
33    Impl {
34        impl_generics: &'a Generics,
35        associated_type_impl_traits: &'a Set<Ident>,
36    },
37}
38
39impl Context<'_> {
40    fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam> {
41        let generics = match self {
42            Context::Trait { generics, .. } => generics,
43            Context::Impl { impl_generics, .. } => impl_generics,
44        };
45        generics.params.iter().filter_map(move |param| {
46            if let GenericParam::Lifetime(param) = param {
47                if used.contains(&param.lifetime) {
48                    return Some(param);
49                }
50            }
51            None
52        })
53    }
54}
55
56pub fn expand(input: &mut Item, is_local: bool) {
57    match input {
58        Item::Trait(input) => {
59            let context = Context::Trait {
60                generics: &input.generics,
61                supertraits: &input.supertraits,
62            };
63            for inner in &mut input.items {
64                if let TraitItem::Fn(method) = inner {
65                    let sig = &mut method.sig;
66                    if sig.asyncness.is_some() {
67                        let block = &mut method.default;
68                        let mut has_self = has_self_in_sig(sig);
69                        method.attrs.push(parse_quote!(#[must_use]));
70                        if let Some(block) = block {
71                            has_self |= has_self_in_block(block);
72                            transform_block(context, sig, block);
73                            method.attrs.push(lint_suppress_with_body());
74                        } else {
75                            method.attrs.push(lint_suppress_without_body());
76                        }
77                        let has_default = method.default.is_some();
78                        transform_sig(context, sig, has_self, has_default, is_local);
79                    }
80                }
81            }
82        }
83        Item::Impl(input) => {
84            let mut associated_type_impl_traits = Set::new();
85            for inner in &input.items {
86                if let ImplItem::Type(assoc) = inner {
87                    if let Type::ImplTrait(_) = assoc.ty {
88                        associated_type_impl_traits.insert(assoc.ident.clone());
89                    }
90                }
91            }
92
93            let context = Context::Impl {
94                impl_generics: &input.generics,
95                associated_type_impl_traits: &associated_type_impl_traits,
96            };
97            for inner in &mut input.items {
98                match inner {
99                    ImplItem::Fn(method) if method.sig.asyncness.is_some() => {
100                        let sig = &mut method.sig;
101                        let block = &mut method.block;
102                        let has_self = has_self_in_sig(sig) || has_self_in_block(block);
103                        transform_block(context, sig, block);
104                        transform_sig(context, sig, has_self, false, is_local);
105                        method.attrs.push(lint_suppress_with_body());
106                    }
107                    ImplItem::Verbatim(tokens) => {
108                        let mut method = match syn::parse2::<VerbatimFn>(tokens.clone()) {
109                            Ok(method) if method.sig.asyncness.is_some() => method,
110                            _ => continue,
111                        };
112                        let sig = &mut method.sig;
113                        let has_self = has_self_in_sig(sig);
114                        transform_sig(context, sig, has_self, false, is_local);
115                        method.attrs.push(lint_suppress_with_body());
116                        *tokens = quote!(#method);
117                    }
118                    _ => {}
119                }
120            }
121        }
122    }
123}
124
125fn lint_suppress_with_body() -> Attribute {
126    parse_quote! {
127        #[allow(
128            elided_named_lifetimes,
129            clippy::async_yields_async,
130            clippy::diverging_sub_expression,
131            clippy::let_unit_value,
132            clippy::needless_arbitrary_self_type,
133            clippy::no_effect_underscore_binding,
134            clippy::shadow_same,
135            clippy::type_complexity,
136            clippy::type_repetition_in_bounds,
137            clippy::used_underscore_binding
138        )]
139    }
140}
141
142fn lint_suppress_without_body() -> Attribute {
143    parse_quote! {
144        #[allow(
145            elided_named_lifetimes,
146            clippy::type_complexity,
147            clippy::type_repetition_in_bounds
148        )]
149    }
150}
151
152// Input:
153//     async fn f<T>(&self, x: &T) -> Ret;
154//
155// Output:
156//     fn f<'life0, 'life1, 'async_trait, T>(
157//         &'life0 self,
158//         x: &'life1 T,
159//     ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
160//     where
161//         'life0: 'async_trait,
162//         'life1: 'async_trait,
163//         T: 'async_trait,
164//         Self: Sync + 'async_trait;
165fn transform_sig(
166    context: Context,
167    sig: &mut Signature,
168    has_self: bool,
169    has_default: bool,
170    is_local: bool,
171) {
172    sig.fn_token.span = sig.asyncness.take().unwrap().span;
173
174    let (ret_arrow, ret) = match &sig.output {
175        ReturnType::Default => (quote!(->), quote!(())),
176        ReturnType::Type(arrow, ret) => (quote!(#arrow), quote!(#ret)),
177    };
178
179    let mut lifetimes = CollectLifetimes::new();
180    for arg in &mut sig.inputs {
181        match arg {
182            FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
183            FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
184        }
185    }
186
187    for param in &mut sig.generics.params {
188        match param {
189            GenericParam::Type(param) => {
190                let param_name = &param.ident;
191                let span = match param.colon_token.take() {
192                    Some(colon_token) => colon_token.span,
193                    None => param_name.span(),
194                };
195                let bounds = mem::replace(&mut param.bounds, Punctuated::new());
196                where_clause_or_default(&mut sig.generics.where_clause)
197                    .predicates
198                    .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds));
199            }
200            GenericParam::Lifetime(param) => {
201                let param_name = &param.lifetime;
202                let span = match param.colon_token.take() {
203                    Some(colon_token) => colon_token.span,
204                    None => param_name.span(),
205                };
206                let bounds = mem::replace(&mut param.bounds, Punctuated::new());
207                where_clause_or_default(&mut sig.generics.where_clause)
208                    .predicates
209                    .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds));
210            }
211            GenericParam::Const(_) => {}
212        }
213    }
214
215    for param in context.lifetimes(&lifetimes.explicit) {
216        let param = &param.lifetime;
217        let span = param.span();
218        where_clause_or_default(&mut sig.generics.where_clause)
219            .predicates
220            .push(parse_quote_spanned!(span=> #param: 'async_trait));
221    }
222
223    if sig.generics.lt_token.is_none() {
224        sig.generics.lt_token = Some(Token![<](sig.ident.span()));
225    }
226    if sig.generics.gt_token.is_none() {
227        sig.generics.gt_token = Some(Token![>](sig.paren_token.span.join()));
228    }
229
230    for elided in lifetimes.elided {
231        sig.generics.params.push(parse_quote!(#elided));
232        where_clause_or_default(&mut sig.generics.where_clause)
233            .predicates
234            .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
235    }
236
237    sig.generics.params.push(parse_quote!('async_trait));
238
239    if has_self {
240        let bounds: &[InferredBound] = if is_local {
241            &[]
242        } else if let Some(receiver) = sig.receiver() {
243            match receiver.ty.as_ref() {
244                // self: &Self
245                Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync],
246                // self: Arc<Self>
247                Type::Path(ty)
248                    if {
249                        let segment = ty.path.segments.last().unwrap();
250                        segment.ident == "Arc"
251                            && match &segment.arguments {
252                                PathArguments::AngleBracketed(arguments) => {
253                                    arguments.args.len() == 1
254                                        && match &arguments.args[0] {
255                                            GenericArgument::Type(Type::Path(arg)) => {
256                                                arg.path.is_ident("Self")
257                                            }
258                                            _ => false,
259                                        }
260                                }
261                                _ => false,
262                            }
263                    } =>
264                {
265                    &[InferredBound::Sync, InferredBound::Send]
266                }
267                _ => &[InferredBound::Send],
268            }
269        } else {
270            &[InferredBound::Send]
271        };
272
273        let bounds = bounds.iter().filter(|bound| match context {
274            Context::Trait { supertraits, .. } => has_default && !has_bound(supertraits, bound),
275            Context::Impl { .. } => false,
276        });
277
278        where_clause_or_default(&mut sig.generics.where_clause)
279            .predicates
280            .push(parse_quote! {
281                Self: #(#bounds +)* 'async_trait
282            });
283    }
284
285    for (i, arg) in sig.inputs.iter_mut().enumerate() {
286        match arg {
287            FnArg::Receiver(receiver) => {
288                if receiver.reference.is_none() {
289                    receiver.mutability = None;
290                }
291            }
292            FnArg::Typed(arg) => {
293                if match *arg.ty {
294                    Type::Reference(_) => false,
295                    _ => true,
296                } {
297                    if let Pat::Ident(pat) = &mut *arg.pat {
298                        pat.by_ref = None;
299                        pat.mutability = None;
300                    } else {
301                        let positional = positional_arg(i, &arg.pat);
302                        let m = mut_pat(&mut arg.pat);
303                        arg.pat = parse_quote!(#m #positional);
304                    }
305                }
306                AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty);
307            }
308        }
309    }
310
311    let bounds = if is_local {
312        quote!('async_trait)
313    } else {
314        quote!(::core::marker::Send + 'async_trait)
315    };
316    sig.output = parse_quote! {
317        #ret_arrow ::core::pin::Pin<Box<
318            dyn ::core::future::Future<Output = #ret> + #bounds
319        >>
320    };
321}
322
323// Input:
324//     async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
325//         self + x + a + b
326//     }
327//
328// Output:
329//     Box::pin(async move {
330//         let ___ret: Ret = {
331//             let __self = self;
332//             let x = x;
333//             let (a, b) = __arg1;
334//
335//             __self + x + a + b
336//         };
337//
338//         ___ret
339//     })
340fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
341    let mut replace_self = false;
342    let decls = sig
343        .inputs
344        .iter()
345        .enumerate()
346        .map(|(i, arg)| match arg {
347            FnArg::Receiver(Receiver {
348                self_token,
349                mutability,
350                ..
351            }) => {
352                replace_self = true;
353                let ident = Ident::new("__self", self_token.span);
354                quote!(let #mutability #ident = #self_token;)
355            }
356            FnArg::Typed(arg) => {
357                // If there is a #[cfg(...)] attribute that selectively enables
358                // the parameter, forward it to the variable.
359                //
360                // This is currently not applied to the `self` parameter.
361                let attrs = arg.attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
362
363                if let Type::Reference(_) = *arg.ty {
364                    quote!()
365                } else if let Pat::Ident(PatIdent {
366                    ident, mutability, ..
367                }) = &*arg.pat
368                {
369                    quote! {
370                        #(#attrs)*
371                        let #mutability #ident = #ident;
372                    }
373                } else {
374                    let pat = &arg.pat;
375                    let ident = positional_arg(i, pat);
376                    if let Pat::Wild(_) = **pat {
377                        quote! {
378                            #(#attrs)*
379                            let #ident = #ident;
380                        }
381                    } else {
382                        quote! {
383                            #(#attrs)*
384                            let #pat = {
385                                let #ident = #ident;
386                                #ident
387                            };
388                        }
389                    }
390                }
391            }
392        })
393        .collect::<Vec<_>>();
394
395    if replace_self {
396        ReplaceSelf.visit_block_mut(block);
397    }
398
399    let stmts = &block.stmts;
400    let let_ret = match &mut sig.output {
401        ReturnType::Default => quote_spanned! {block.brace_token.span=>
402            #(#decls)*
403            let () = { #(#stmts)* };
404        },
405        ReturnType::Type(_, ret) => {
406            if contains_associated_type_impl_trait(context, ret) {
407                if decls.is_empty() {
408                    quote!(#(#stmts)*)
409                } else {
410                    quote!(#(#decls)* { #(#stmts)* })
411                }
412            } else {
413                quote! {
414                    if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
415                        #[allow(unreachable_code)]
416                        return __ret;
417                    }
418                    #(#decls)*
419                    let __ret: #ret = { #(#stmts)* };
420                    #[allow(unreachable_code)]
421                    __ret
422                }
423            }
424        }
425    };
426    let box_pin = quote_spanned!(block.brace_token.span=>
427        Box::pin(async move { #let_ret })
428    );
429    block.stmts = parse_quote!(#box_pin);
430}
431
432fn positional_arg(i: usize, pat: &Pat) -> Ident {
433    let span = syn::spanned::Spanned::span(pat).resolved_at(Span::mixed_site());
434    format_ident!("__arg{}", i, span = span)
435}
436
437fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
438    struct AssociatedTypeImplTraits<'a> {
439        set: &'a Set<Ident>,
440        contains: bool,
441    }
442
443    impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
444        fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
445            if ty.qself.is_none()
446                && ty.path.segments.len() == 2
447                && ty.path.segments[0].ident == "Self"
448                && self.set.contains(&ty.path.segments[1].ident)
449            {
450                self.contains = true;
451            }
452            visit_mut::visit_type_path_mut(self, ty);
453        }
454    }
455
456    match context {
457        Context::Trait { .. } => false,
458        Context::Impl {
459            associated_type_impl_traits,
460            ..
461        } => {
462            let mut visit = AssociatedTypeImplTraits {
463                set: associated_type_impl_traits,
464                contains: false,
465            };
466            visit.visit_type_mut(ret);
467            visit.contains
468        }
469    }
470}
471
472fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
473    clause.get_or_insert_with(|| WhereClause {
474        where_token: Default::default(),
475        predicates: Punctuated::new(),
476    })
477}