futures_macro/
join.rs

1//! The futures-rs `join!` macro implementation.
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::{format_ident, quote};
6use syn::parse::{Parse, ParseStream};
7use syn::{Expr, Ident, Token};
8
9#[derive(Default)]
10struct Join {
11    fut_exprs: Vec<Expr>,
12}
13
14impl Parse for Join {
15    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
16        let mut join = Self::default();
17
18        while !input.is_empty() {
19            join.fut_exprs.push(input.parse::<Expr>()?);
20
21            if !input.is_empty() {
22                input.parse::<Token![,]>()?;
23            }
24        }
25
26        Ok(join)
27    }
28}
29
30fn bind_futures(fut_exprs: Vec<Expr>, span: Span) -> (Vec<TokenStream2>, Vec<Ident>) {
31    let mut future_let_bindings = Vec::with_capacity(fut_exprs.len());
32    let future_names: Vec<_> = fut_exprs
33        .into_iter()
34        .enumerate()
35        .map(|(i, expr)| {
36            let name = format_ident!("_fut{}", i, span = span);
37            future_let_bindings.push(quote! {
38                // Move future into a local so that it is pinned in one place and
39                // is no longer accessible by the end user.
40                let mut #name = __futures_crate::future::maybe_done(#expr);
41                let mut #name = unsafe { __futures_crate::Pin::new_unchecked(&mut #name) };
42            });
43            name
44        })
45        .collect();
46
47    (future_let_bindings, future_names)
48}
49
50/// The `join!` macro.
51pub(crate) fn join(input: TokenStream) -> TokenStream {
52    let parsed = syn::parse_macro_input!(input as Join);
53
54    // should be def_site, but that's unstable
55    let span = Span::call_site();
56
57    let (future_let_bindings, future_names) = bind_futures(parsed.fut_exprs, span);
58
59    let poll_futures = future_names.iter().map(|fut| {
60        quote! {
61            __all_done &= __futures_crate::future::Future::poll(
62                #fut.as_mut(), __cx).is_ready();
63        }
64    });
65    let take_outputs = future_names.iter().map(|fut| {
66        quote! {
67            #fut.as_mut().take_output().unwrap(),
68        }
69    });
70
71    TokenStream::from(quote! { {
72        #( #future_let_bindings )*
73
74        __futures_crate::future::poll_fn(move |__cx: &mut __futures_crate::task::Context<'_>| {
75            let mut __all_done = true;
76            #( #poll_futures )*
77            if __all_done {
78                __futures_crate::task::Poll::Ready((
79                    #( #take_outputs )*
80                ))
81            } else {
82                __futures_crate::task::Poll::Pending
83            }
84        }).await
85    } })
86}
87
88/// The `try_join!` macro.
89pub(crate) fn try_join(input: TokenStream) -> TokenStream {
90    let parsed = syn::parse_macro_input!(input as Join);
91
92    // should be def_site, but that's unstable
93    let span = Span::call_site();
94
95    let (future_let_bindings, future_names) = bind_futures(parsed.fut_exprs, span);
96
97    let poll_futures = future_names.iter().map(|fut| {
98        quote! {
99            if __futures_crate::future::Future::poll(
100                #fut.as_mut(), __cx).is_pending()
101            {
102                __all_done = false;
103            } else if #fut.as_mut().output_mut().unwrap().is_err() {
104                // `.err().unwrap()` rather than `.unwrap_err()` so that we don't introduce
105                // a `T: Debug` bound.
106                // Also, for an error type of ! any code after `err().unwrap()` is unreachable.
107                #[allow(unreachable_code)]
108                return __futures_crate::task::Poll::Ready(
109                    __futures_crate::Err(
110                        #fut.as_mut().take_output().unwrap().err().unwrap()
111                    )
112                );
113            }
114        }
115    });
116    let take_outputs = future_names.iter().map(|fut| {
117        quote! {
118            // `.ok().unwrap()` rather than `.unwrap()` so that we don't introduce
119            // an `E: Debug` bound.
120            // Also, for an ok type of ! any code after `ok().unwrap()` is unreachable.
121            #[allow(unreachable_code)]
122            #fut.as_mut().take_output().unwrap().ok().unwrap(),
123        }
124    });
125
126    TokenStream::from(quote! { {
127        #( #future_let_bindings )*
128
129        #[allow(clippy::diverging_sub_expression)]
130        __futures_crate::future::poll_fn(move |__cx: &mut __futures_crate::task::Context<'_>| {
131            let mut __all_done = true;
132            #( #poll_futures )*
133            if __all_done {
134                __futures_crate::task::Poll::Ready(
135                    __futures_crate::Ok((
136                        #( #take_outputs )*
137                    ))
138                )
139            } else {
140                __futures_crate::task::Poll::Pending
141            }
142        }).await
143    } })
144}