1use crate::bound::{has_bound, InferredBound, Supertraits};
2use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes};
3use crate::parse::Item;
4use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
5use crate::verbatim::VerbatimFn;
6use proc_macro2::{Span, TokenStream};
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use std::collections::BTreeSet as Set;
9use std::mem;
10use syn::punctuated::Punctuated;
11use syn::visit_mut::{self, VisitMut};
12use syn::{
13 parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam,
14 Generics, Ident, ImplItem, Lifetime, LifetimeParam, Pat, PatIdent, PathArguments, Receiver,
15 ReturnType, Signature, Token, TraitItem, Type, TypePath, WhereClause,
16};
17
18impl ToTokens for Item {
19 fn to_tokens(&self, tokens: &mut TokenStream) {
20 match self {
21 Item::Trait(item) => item.to_tokens(tokens),
22 Item::Impl(item) => item.to_tokens(tokens),
23 }
24 }
25}
26
27#[derive(Clone, Copy)]
28enum Context<'a> {
29 Trait {
30 generics: &'a Generics,
31 supertraits: &'a Supertraits,
32 },
33 Impl {
34 impl_generics: &'a Generics,
35 associated_type_impl_traits: &'a Set<Ident>,
36 },
37}
38
39impl Context<'_> {
40 fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam> {
41 let generics = match self {
42 Context::Trait { generics, .. } => generics,
43 Context::Impl { impl_generics, .. } => impl_generics,
44 };
45 generics.params.iter().filter_map(move |param| {
46 if let GenericParam::Lifetime(param) = param {
47 if used.contains(¶m.lifetime) {
48 return Some(param);
49 }
50 }
51 None
52 })
53 }
54}
55
56pub fn expand(input: &mut Item, is_local: bool) {
57 match input {
58 Item::Trait(input) => {
59 let context = Context::Trait {
60 generics: &input.generics,
61 supertraits: &input.supertraits,
62 };
63 for inner in &mut input.items {
64 if let TraitItem::Fn(method) = inner {
65 let sig = &mut method.sig;
66 if sig.asyncness.is_some() {
67 let block = &mut method.default;
68 let mut has_self = has_self_in_sig(sig);
69 method.attrs.push(parse_quote!(#[must_use]));
70 if let Some(block) = block {
71 has_self |= has_self_in_block(block);
72 transform_block(context, sig, block);
73 method.attrs.push(lint_suppress_with_body());
74 } else {
75 method.attrs.push(lint_suppress_without_body());
76 }
77 let has_default = method.default.is_some();
78 transform_sig(context, sig, has_self, has_default, is_local);
79 }
80 }
81 }
82 }
83 Item::Impl(input) => {
84 let mut associated_type_impl_traits = Set::new();
85 for inner in &input.items {
86 if let ImplItem::Type(assoc) = inner {
87 if let Type::ImplTrait(_) = assoc.ty {
88 associated_type_impl_traits.insert(assoc.ident.clone());
89 }
90 }
91 }
92
93 let context = Context::Impl {
94 impl_generics: &input.generics,
95 associated_type_impl_traits: &associated_type_impl_traits,
96 };
97 for inner in &mut input.items {
98 match inner {
99 ImplItem::Fn(method) if method.sig.asyncness.is_some() => {
100 let sig = &mut method.sig;
101 let block = &mut method.block;
102 let has_self = has_self_in_sig(sig) || has_self_in_block(block);
103 transform_block(context, sig, block);
104 transform_sig(context, sig, has_self, false, is_local);
105 method.attrs.push(lint_suppress_with_body());
106 }
107 ImplItem::Verbatim(tokens) => {
108 let mut method = match syn::parse2::<VerbatimFn>(tokens.clone()) {
109 Ok(method) if method.sig.asyncness.is_some() => method,
110 _ => continue,
111 };
112 let sig = &mut method.sig;
113 let has_self = has_self_in_sig(sig);
114 transform_sig(context, sig, has_self, false, is_local);
115 method.attrs.push(lint_suppress_with_body());
116 *tokens = quote!(#method);
117 }
118 _ => {}
119 }
120 }
121 }
122 }
123}
124
125fn lint_suppress_with_body() -> Attribute {
126 parse_quote! {
127 #[allow(
128 elided_named_lifetimes,
129 clippy::async_yields_async,
130 clippy::diverging_sub_expression,
131 clippy::let_unit_value,
132 clippy::needless_arbitrary_self_type,
133 clippy::no_effect_underscore_binding,
134 clippy::shadow_same,
135 clippy::type_complexity,
136 clippy::type_repetition_in_bounds,
137 clippy::used_underscore_binding
138 )]
139 }
140}
141
142fn lint_suppress_without_body() -> Attribute {
143 parse_quote! {
144 #[allow(
145 elided_named_lifetimes,
146 clippy::type_complexity,
147 clippy::type_repetition_in_bounds
148 )]
149 }
150}
151
152fn transform_sig(
166 context: Context,
167 sig: &mut Signature,
168 has_self: bool,
169 has_default: bool,
170 is_local: bool,
171) {
172 sig.fn_token.span = sig.asyncness.take().unwrap().span;
173
174 let (ret_arrow, ret) = match &sig.output {
175 ReturnType::Default => (quote!(->), quote!(())),
176 ReturnType::Type(arrow, ret) => (quote!(#arrow), quote!(#ret)),
177 };
178
179 let mut lifetimes = CollectLifetimes::new();
180 for arg in &mut sig.inputs {
181 match arg {
182 FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
183 FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
184 }
185 }
186
187 for param in &mut sig.generics.params {
188 match param {
189 GenericParam::Type(param) => {
190 let param_name = ¶m.ident;
191 let span = match param.colon_token.take() {
192 Some(colon_token) => colon_token.span,
193 None => param_name.span(),
194 };
195 let bounds = mem::replace(&mut param.bounds, Punctuated::new());
196 where_clause_or_default(&mut sig.generics.where_clause)
197 .predicates
198 .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds));
199 }
200 GenericParam::Lifetime(param) => {
201 let param_name = ¶m.lifetime;
202 let span = match param.colon_token.take() {
203 Some(colon_token) => colon_token.span,
204 None => param_name.span(),
205 };
206 let bounds = mem::replace(&mut param.bounds, Punctuated::new());
207 where_clause_or_default(&mut sig.generics.where_clause)
208 .predicates
209 .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds));
210 }
211 GenericParam::Const(_) => {}
212 }
213 }
214
215 for param in context.lifetimes(&lifetimes.explicit) {
216 let param = ¶m.lifetime;
217 let span = param.span();
218 where_clause_or_default(&mut sig.generics.where_clause)
219 .predicates
220 .push(parse_quote_spanned!(span=> #param: 'async_trait));
221 }
222
223 if sig.generics.lt_token.is_none() {
224 sig.generics.lt_token = Some(Token));
225 }
226 if sig.generics.gt_token.is_none() {
227 sig.generics.gt_token = Some(Token));
228 }
229
230 for elided in lifetimes.elided {
231 sig.generics.params.push(parse_quote!(#elided));
232 where_clause_or_default(&mut sig.generics.where_clause)
233 .predicates
234 .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
235 }
236
237 sig.generics.params.push(parse_quote!('async_trait));
238
239 if has_self {
240 let bounds: &[InferredBound] = if is_local {
241 &[]
242 } else if let Some(receiver) = sig.receiver() {
243 match receiver.ty.as_ref() {
244 Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync],
246 Type::Path(ty)
248 if {
249 let segment = ty.path.segments.last().unwrap();
250 segment.ident == "Arc"
251 && match &segment.arguments {
252 PathArguments::AngleBracketed(arguments) => {
253 arguments.args.len() == 1
254 && match &arguments.args[0] {
255 GenericArgument::Type(Type::Path(arg)) => {
256 arg.path.is_ident("Self")
257 }
258 _ => false,
259 }
260 }
261 _ => false,
262 }
263 } =>
264 {
265 &[InferredBound::Sync, InferredBound::Send]
266 }
267 _ => &[InferredBound::Send],
268 }
269 } else {
270 &[InferredBound::Send]
271 };
272
273 let bounds = bounds.iter().filter(|bound| match context {
274 Context::Trait { supertraits, .. } => has_default && !has_bound(supertraits, bound),
275 Context::Impl { .. } => false,
276 });
277
278 where_clause_or_default(&mut sig.generics.where_clause)
279 .predicates
280 .push(parse_quote! {
281 Self: #(#bounds +)* 'async_trait
282 });
283 }
284
285 for (i, arg) in sig.inputs.iter_mut().enumerate() {
286 match arg {
287 FnArg::Receiver(receiver) => {
288 if receiver.reference.is_none() {
289 receiver.mutability = None;
290 }
291 }
292 FnArg::Typed(arg) => {
293 if match *arg.ty {
294 Type::Reference(_) => false,
295 _ => true,
296 } {
297 if let Pat::Ident(pat) = &mut *arg.pat {
298 pat.by_ref = None;
299 pat.mutability = None;
300 } else {
301 let positional = positional_arg(i, &arg.pat);
302 let m = mut_pat(&mut arg.pat);
303 arg.pat = parse_quote!(#m #positional);
304 }
305 }
306 AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty);
307 }
308 }
309 }
310
311 let bounds = if is_local {
312 quote!('async_trait)
313 } else {
314 quote!(::core::marker::Send + 'async_trait)
315 };
316 sig.output = parse_quote! {
317 #ret_arrow ::core::pin::Pin<Box<
318 dyn ::core::future::Future<Output = #ret> + #bounds
319 >>
320 };
321}
322
323fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
341 let mut replace_self = false;
342 let decls = sig
343 .inputs
344 .iter()
345 .enumerate()
346 .map(|(i, arg)| match arg {
347 FnArg::Receiver(Receiver {
348 self_token,
349 mutability,
350 ..
351 }) => {
352 replace_self = true;
353 let ident = Ident::new("__self", self_token.span);
354 quote!(let #mutability #ident = #self_token;)
355 }
356 FnArg::Typed(arg) => {
357 let attrs = arg.attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
362
363 if let Type::Reference(_) = *arg.ty {
364 quote!()
365 } else if let Pat::Ident(PatIdent {
366 ident, mutability, ..
367 }) = &*arg.pat
368 {
369 quote! {
370 #(#attrs)*
371 let #mutability #ident = #ident;
372 }
373 } else {
374 let pat = &arg.pat;
375 let ident = positional_arg(i, pat);
376 if let Pat::Wild(_) = **pat {
377 quote! {
378 #(#attrs)*
379 let #ident = #ident;
380 }
381 } else {
382 quote! {
383 #(#attrs)*
384 let #pat = {
385 let #ident = #ident;
386 #ident
387 };
388 }
389 }
390 }
391 }
392 })
393 .collect::<Vec<_>>();
394
395 if replace_self {
396 ReplaceSelf.visit_block_mut(block);
397 }
398
399 let stmts = &block.stmts;
400 let let_ret = match &mut sig.output {
401 ReturnType::Default => quote_spanned! {block.brace_token.span=>
402 #(#decls)*
403 let () = { #(#stmts)* };
404 },
405 ReturnType::Type(_, ret) => {
406 if contains_associated_type_impl_trait(context, ret) {
407 if decls.is_empty() {
408 quote!(#(#stmts)*)
409 } else {
410 quote!(#(#decls)* { #(#stmts)* })
411 }
412 } else {
413 quote! {
414 if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
415 #[allow(unreachable_code)]
416 return __ret;
417 }
418 #(#decls)*
419 let __ret: #ret = { #(#stmts)* };
420 #[allow(unreachable_code)]
421 __ret
422 }
423 }
424 }
425 };
426 let box_pin = quote_spanned!(block.brace_token.span=>
427 Box::pin(async move { #let_ret })
428 );
429 block.stmts = parse_quote!(#box_pin);
430}
431
432fn positional_arg(i: usize, pat: &Pat) -> Ident {
433 let span = syn::spanned::Spanned::span(pat).resolved_at(Span::mixed_site());
434 format_ident!("__arg{}", i, span = span)
435}
436
437fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
438 struct AssociatedTypeImplTraits<'a> {
439 set: &'a Set<Ident>,
440 contains: bool,
441 }
442
443 impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
444 fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
445 if ty.qself.is_none()
446 && ty.path.segments.len() == 2
447 && ty.path.segments[0].ident == "Self"
448 && self.set.contains(&ty.path.segments[1].ident)
449 {
450 self.contains = true;
451 }
452 visit_mut::visit_type_path_mut(self, ty);
453 }
454 }
455
456 match context {
457 Context::Trait { .. } => false,
458 Context::Impl {
459 associated_type_impl_traits,
460 ..
461 } => {
462 let mut visit = AssociatedTypeImplTraits {
463 set: associated_type_impl_traits,
464 contains: false,
465 };
466 visit.visit_type_mut(ret);
467 visit.contains
468 }
469 }
470}
471
472fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
473 clause.get_or_insert_with(|| WhereClause {
474 where_token: Default::default(),
475 predicates: Punctuated::new(),
476 })
477}