futures_macro/
select.rs

1//! The futures-rs `select!` macro implementation.
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::{format_ident, quote};
6use syn::parse::{Parse, ParseStream};
7use syn::{parse_quote, Expr, Ident, Pat, Token};
8
// Custom keywords recognised by the `select!` parser. `complete` is not a
// Rust keyword, so a dedicated token type is generated for it; the parser
// peeks for and consumes it via `kw::complete`.
mod kw {
    syn::custom_keyword!(complete);
}
12
// Parsed form of a `select!` / `select_biased!` invocation.
struct Select {
    // Handler expression of the `complete => ...` case, if one was given.
    complete: Option<Expr>,
    // Handler expression of the `default => ...` case, if one was given.
    default: Option<Expr>,
    // Future expressions of the `<pat> = <expr> => ...` cases, in source order.
    normal_fut_exprs: Vec<Expr>,
    // `(pattern, handler)` pairs of the same cases; index `i` here belongs to
    // `normal_fut_exprs[i]`.
    normal_fut_handlers: Vec<(Pat, Expr)>,
}
20
// Which kind of case head was just parsed, before its `=> <handler>` part.
//
// `Normal` carries a full `Pat` and `Expr` while the other variants are unit
// variants, which is what the `large_enum_variant` lint would complain about.
// The value is only a short-lived local inside `Select::parse`, so the lint
// is silenced rather than boxing the payload.
#[allow(clippy::large_enum_variant)]
enum CaseKind {
    // `complete => ...`
    Complete,
    // `default => ...`
    Default,
    // `<pat> = <future expr> => ...`
    Normal(Pat, Expr),
}
27
28impl Parse for Select {
29    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
30        let mut select = Self {
31            complete: None,
32            default: None,
33            normal_fut_exprs: vec![],
34            normal_fut_handlers: vec![],
35        };
36
37        while !input.is_empty() {
38            let case_kind = if input.peek(kw::complete) {
39                // `complete`
40                if select.complete.is_some() {
41                    return Err(input.error("multiple `complete` cases found, only one allowed"));
42                }
43                input.parse::<kw::complete>()?;
44                CaseKind::Complete
45            } else if input.peek(Token![default]) {
46                // `default`
47                if select.default.is_some() {
48                    return Err(input.error("multiple `default` cases found, only one allowed"));
49                }
50                input.parse::<Ident>()?;
51                CaseKind::Default
52            } else {
53                // `<pat> = <expr>`
54                let pat = Pat::parse_multi_with_leading_vert(input)?;
55                input.parse::<Token![=]>()?;
56                let expr = input.parse()?;
57                CaseKind::Normal(pat, expr)
58            };
59
60            // `=> <expr>`
61            input.parse::<Token![=>]>()?;
62            let expr = Expr::parse_with_earlier_boundary_rule(input)?;
63
64            // Commas after the expression are only optional if it's a `Block`
65            // or it is the last branch in the `match`.
66            let is_block = match expr {
67                Expr::Block(_) => true,
68                _ => false,
69            };
70            if is_block || input.is_empty() {
71                input.parse::<Option<Token![,]>>()?;
72            } else {
73                input.parse::<Token![,]>()?;
74            }
75
76            match case_kind {
77                CaseKind::Complete => select.complete = Some(expr),
78                CaseKind::Default => select.default = Some(expr),
79                CaseKind::Normal(pat, fut_expr) => {
80                    select.normal_fut_exprs.push(fut_expr);
81                    select.normal_fut_handlers.push((pat, expr));
82                }
83            }
84        }
85
86        Ok(select)
87    }
88}
89
90// Enum over all the cases in which the `select!` waiting has completed and the result
91// can be processed.
92//
93// `enum __PrivResult<_1, _2, ...> { _1(_1), _2(_2), ..., Complete }`
94fn declare_result_enum(
95    result_ident: Ident,
96    variants: usize,
97    complete: bool,
98    span: Span,
99) -> (Vec<Ident>, syn::ItemEnum) {
100    // "_0", "_1", "_2"
101    let variant_names: Vec<Ident> =
102        (0..variants).map(|num| format_ident!("_{}", num, span = span)).collect();
103
104    let type_parameters = &variant_names;
105    let variants = &variant_names;
106
107    let complete_variant = if complete { Some(quote!(Complete)) } else { None };
108
109    let enum_item = parse_quote! {
110        enum #result_ident<#(#type_parameters,)*> {
111            #(
112                #variants(#type_parameters),
113            )*
114            #complete_variant
115        }
116    };
117
118    (variant_names, enum_item)
119}
120
121/// The `select!` macro.
122pub(crate) fn select(input: TokenStream) -> TokenStream {
123    select_inner(input, true)
124}
125
126/// The `select_biased!` macro.
127pub(crate) fn select_biased(input: TokenStream) -> TokenStream {
128    select_inner(input, false)
129}
130
131fn select_inner(input: TokenStream, random: bool) -> TokenStream {
132    let parsed = syn::parse_macro_input!(input as Select);
133
134    // should be def_site, but that's unstable
135    let span = Span::call_site();
136
137    let enum_ident = Ident::new("__PrivResult", span);
138
139    let (variant_names, enum_item) = declare_result_enum(
140        enum_ident.clone(),
141        parsed.normal_fut_exprs.len(),
142        parsed.complete.is_some(),
143        span,
144    );
145
146    // bind non-`Ident` future exprs w/ `let`
147    let mut future_let_bindings = Vec::with_capacity(parsed.normal_fut_exprs.len());
148    let bound_future_names: Vec<_> = parsed
149        .normal_fut_exprs
150        .into_iter()
151        .zip(variant_names.iter())
152        .map(|(expr, variant_name)| {
153            match expr {
154                syn::Expr::Path(path) => {
155                    // Don't bind futures that are already a path.
156                    // This prevents creating redundant stack space
157                    // for them.
158                    // Passing Futures by path requires those Futures to implement Unpin.
159                    // We check for this condition here in order to be able to
160                    // safely use Pin::new_unchecked(&mut #path) later on.
161                    future_let_bindings.push(quote! {
162                        __futures_crate::async_await::assert_fused_future(&#path);
163                        __futures_crate::async_await::assert_unpin(&#path);
164                    });
165                    path
166                }
167                _ => {
168                    // Bind and pin the resulting Future on the stack. This is
169                    // necessary to support direct select! calls on !Unpin
170                    // Futures. The Future is not explicitly pinned here with
171                    // a Pin call, but assumed as pinned. The actual Pin is
172                    // created inside the poll() function below to defer the
173                    // creation of the temporary pointer, which would otherwise
174                    // increase the size of the generated Future.
175                    // Safety: This is safe since the lifetime of the Future
176                    // is totally constraint to the lifetime of the select!
177                    // expression, and the Future can't get moved inside it
178                    // (it is shadowed).
179                    future_let_bindings.push(quote! {
180                        let mut #variant_name = #expr;
181                    });
182                    parse_quote! { #variant_name }
183                }
184            }
185        })
186        .collect();
187
188    // For each future, make an `&mut dyn FnMut(&mut Context<'_>) -> Option<Poll<__PrivResult<...>>`
189    // to use for polling that individual future. These will then be put in an array.
190    let poll_functions = bound_future_names.iter().zip(variant_names.iter()).map(
191        |(bound_future_name, variant_name)| {
192            // Below we lazily create the Pin on the Future below.
193            // This is done in order to avoid allocating memory in the generator
194            // for the Pin variable.
195            // Safety: This is safe because one of the following condition applies:
196            // 1. The Future is passed by the caller by name, and we assert that
197            //    it implements Unpin.
198            // 2. The Future is created in scope of the select! function and will
199            //    not be moved for the duration of it. It is thereby stack-pinned
200            quote! {
201                let mut #variant_name = |__cx: &mut __futures_crate::task::Context<'_>| {
202                    let mut #bound_future_name = unsafe {
203                        __futures_crate::Pin::new_unchecked(&mut #bound_future_name)
204                    };
205                    if __futures_crate::future::FusedFuture::is_terminated(&#bound_future_name) {
206                        __futures_crate::None
207                    } else {
208                        __futures_crate::Some(__futures_crate::future::FutureExt::poll_unpin(
209                            &mut #bound_future_name,
210                            __cx,
211                        ).map(#enum_ident::#variant_name))
212                    }
213                };
214                let #variant_name: &mut dyn FnMut(
215                    &mut __futures_crate::task::Context<'_>
216                ) -> __futures_crate::Option<__futures_crate::task::Poll<_>> = &mut #variant_name;
217            }
218        },
219    );
220
221    let none_polled = if parsed.complete.is_some() {
222        quote! {
223            __futures_crate::task::Poll::Ready(#enum_ident::Complete)
224        }
225    } else {
226        quote! {
227            panic!("all futures in select! were completed,\
228                    but no `complete =>` handler was provided")
229        }
230    };
231
232    let branches = parsed.normal_fut_handlers.into_iter().zip(variant_names.iter()).map(
233        |((pat, expr), variant_name)| {
234            quote! {
235                #enum_ident::#variant_name(#pat) => #expr,
236            }
237        },
238    );
239    let branches = quote! { #( #branches )* };
240
241    let complete_branch = parsed.complete.map(|complete_expr| {
242        quote! {
243            #enum_ident::Complete => { #complete_expr },
244        }
245    });
246
247    let branches = quote! {
248        #branches
249        #complete_branch
250    };
251
252    let await_select_fut = if parsed.default.is_some() {
253        // For select! with default this returns the Poll result
254        quote! {
255            __poll_fn(&mut __futures_crate::task::Context::from_waker(
256                __futures_crate::task::noop_waker_ref()
257            ))
258        }
259    } else {
260        quote! {
261            __futures_crate::future::poll_fn(__poll_fn).await
262        }
263    };
264
265    let execute_result_expr = if let Some(default_expr) = &parsed.default {
266        // For select! with default __select_result is a Poll, otherwise not
267        quote! {
268            match __select_result {
269                __futures_crate::task::Poll::Ready(result) => match result {
270                    #branches
271                },
272                _ => #default_expr
273            }
274        }
275    } else {
276        quote! {
277            match __select_result {
278                #branches
279            }
280        }
281    };
282
283    let shuffle = if random {
284        quote! {
285            __futures_crate::async_await::shuffle(&mut __select_arr);
286        }
287    } else {
288        quote!()
289    };
290
291    TokenStream::from(quote! { {
292        #enum_item
293
294        let __select_result = {
295            #( #future_let_bindings )*
296
297            let mut __poll_fn = |__cx: &mut __futures_crate::task::Context<'_>| {
298                let mut __any_polled = false;
299
300                #( #poll_functions )*
301
302                let mut __select_arr = [#( #variant_names ),*];
303                #shuffle
304                for poller in &mut __select_arr {
305                    let poller: &mut &mut dyn FnMut(
306                        &mut __futures_crate::task::Context<'_>
307                    ) -> __futures_crate::Option<__futures_crate::task::Poll<_>> = poller;
308                    match poller(__cx) {
309                        __futures_crate::Some(x @ __futures_crate::task::Poll::Ready(_)) =>
310                            return x,
311                        __futures_crate::Some(__futures_crate::task::Poll::Pending) => {
312                            __any_polled = true;
313                        }
314                        __futures_crate::None => {}
315                    }
316                }
317
318                if !__any_polled {
319                    #none_polled
320                } else {
321                    __futures_crate::task::Poll::Pending
322                }
323            };
324
325            #await_select_fut
326        };
327
328        #execute_result_expr
329    } })
330}