// tokio/macros/try_join.rs

macro_rules! doc {
    ($try_join:item) => {
        /// Waits on multiple concurrent branches, returning when **all** branches
        /// complete with `Ok(_)` or on the first `Err(_)`.
        ///
        /// The `try_join!` macro must be used inside of async functions, closures, and
        /// blocks.
        ///
        /// Similar to [`join!`], the `try_join!` macro takes a list of async
        /// expressions and evaluates them concurrently on the same task. Each async
        /// expression evaluates to a future and the futures from each expression are
        /// multiplexed on the current task. The `try_join!` macro returns when **all**
        /// branches return with `Ok` or when the **first** branch returns with `Err`.
        ///
        /// [`join!`]: macro@join
        ///
        /// # Notes
        ///
        /// The supplied futures are stored inline and do not require allocating a
        /// `Vec`.
        ///
        /// As soon as a branch completes with `Err`, that error is returned; the
        /// remaining branches are not polled again and are dropped when the
        /// `try_join!` expression finishes.
        ///
        /// ### Runtime characteristics
        ///
        /// By running all async expressions on the current task, the expressions are
        /// able to run **concurrently** but not in **parallel**. This means all
        /// expressions are run on the same thread and if one branch blocks the thread,
        /// all other expressions will be unable to continue. If parallelism is
        /// required, spawn each async expression using [`tokio::spawn`] and pass the
        /// join handle to `try_join!`.
        ///
        /// [`tokio::spawn`]: crate::spawn
        ///
        /// ### Fairness
        ///
        /// Each time the `try_join!` future is polled, it starts polling from a
        /// different branch, so that one branch consuming the whole budget does not
        /// prevent the other branches from getting a chance to make progress.
        ///
        /// # Examples
        ///
        /// Basic `try_join` with two branches.
        ///
        /// ```
        /// async fn do_stuff_async() -> Result<(), &'static str> {
        ///     // async work
        /// # Ok(())
        /// }
        ///
        /// async fn more_async_work() -> Result<(), &'static str> {
        ///     // more here
        /// # Ok(())
        /// }
        ///
        /// #[tokio::main]
        /// async fn main() {
        ///     let res = tokio::try_join!(
        ///         do_stuff_async(),
        ///         more_async_work());
        ///
        ///     match res {
        ///          Ok((first, second)) => {
        ///              // do something with the values
        ///          }
        ///          Err(err) => {
        ///              println!("processing failed; error = {}", err);
        ///          }
        ///     }
        /// }
        /// ```
        ///
        /// Using `try_join!` with spawned tasks.
        ///
        /// ```
        /// use tokio::task::JoinHandle;
        ///
        /// async fn do_stuff_async() -> Result<(), &'static str> {
        ///     // async work
        /// # Err("failed")
        /// }
        ///
        /// async fn more_async_work() -> Result<(), &'static str> {
        ///     // more here
        /// # Ok(())
        /// }
        ///
        /// async fn flatten<T>(handle: JoinHandle<Result<T, &'static str>>) -> Result<T, &'static str> {
        ///     match handle.await {
        ///         Ok(Ok(result)) => Ok(result),
        ///         Ok(Err(err)) => Err(err),
        ///         Err(_) => Err("handling failed"),
        ///     }
        /// }
        ///
        /// #[tokio::main]
        /// async fn main() {
        ///     let handle1 = tokio::spawn(do_stuff_async());
        ///     let handle2 = tokio::spawn(more_async_work());
        ///     match tokio::try_join!(flatten(handle1), flatten(handle2)) {
        ///         Ok(val) => {
        ///             // do something with the values
        ///         }
        ///         Err(err) => {
        ///             println!("Failed with {}.", err);
        ///             # assert_eq!(err, "failed");
        ///         }
        ///     }
        /// }
        /// ```
        #[macro_export]
        #[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
        $try_join
    };
}
108
// Documentation-only stand-in: when building docs, rustdoc renders this
// simplified matcher instead of the internal normalization rules of the
// `#[cfg(not(doc))]` definition. Like the real macro, it allows an optional
// trailing comma after the last expression.
#[cfg(doc)]
doc! {macro_rules! try_join {
    ($($future:expr),* $(,)?) => { unimplemented!() }
}}
113
114#[cfg(not(doc))]
115doc! {macro_rules! try_join {
116    (@ {
117        // One `_` for each branch in the `try_join!` macro. This is not used once
118        // normalization is complete.
119        ( $($count:tt)* )
120
121        // The expression `0+1+1+ ... +1` equal to the number of branches.
122        ( $($total:tt)* )
123
124        // Normalized try_join! branches
125        $( ( $($skip:tt)* ) $e:expr, )*
126
127    }) => {{
128        use $crate::macros::support::{maybe_done, poll_fn, Future, Pin};
129        use $crate::macros::support::Poll::{Ready, Pending};
130
131        // Safety: nothing must be moved out of `futures`. This is to satisfy
132        // the requirement of `Pin::new_unchecked` called below.
133        //
134        // We can't use the `pin!` macro for this because `futures` is a tuple
135        // and the standard library provides no way to pin-project to the fields
136        // of a tuple.
137        let mut futures = ( $( maybe_done($e), )* );
138
139        // This assignment makes sure that the `poll_fn` closure only has a
140        // reference to the futures, instead of taking ownership of them. This
141        // mitigates the issue described in
142        // <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
143        let mut futures = &mut futures;
144
145        // Each time the future created by poll_fn is polled, a different future will be polled first
146        // to ensure every future passed to join! gets a chance to make progress even if
147        // one of the futures consumes the whole budget.
148        //
149        // This is number of futures that will be skipped in the first loop
150        // iteration the next time.
151        let mut skip_next_time: u32 = 0;
152
153        poll_fn(move |cx| {
154            const COUNT: u32 = $($total)*;
155
156            let mut is_pending = false;
157
158            let mut to_run = COUNT;
159
160            // The number of futures that will be skipped in the first loop iteration
161            let mut skip = skip_next_time;
162
163            skip_next_time = if skip + 1 == COUNT { 0 } else { skip + 1 };
164
165            // This loop runs twice and the first `skip` futures
166            // are not polled in the first iteration.
167            loop {
168            $(
169                if skip == 0 {
170                    if to_run == 0 {
171                        // Every future has been polled
172                        break;
173                    }
174                    to_run -= 1;
175
176                    // Extract the future for this branch from the tuple.
177                    let ( $($skip,)* fut, .. ) = &mut *futures;
178
179                    // Safety: future is stored on the stack above
180                    // and never moved.
181                    let mut fut = unsafe { Pin::new_unchecked(fut) };
182
183                    // Try polling
184                    if fut.as_mut().poll(cx).is_pending() {
185                        is_pending = true;
186                    } else if fut.as_mut().output_mut().expect("expected completed future").is_err() {
187                        return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap()))
188                    }
189                } else {
190                    // Future skipped, one less future to skip in the next iteration
191                    skip -= 1;
192                }
193            )*
194            }
195
196            if is_pending {
197                Pending
198            } else {
199                Ready(Ok(($({
200                    // Extract the future for this branch from the tuple.
201                    let ( $($skip,)* fut, .. ) = &mut futures;
202
203                    // Safety: future is stored on the stack above
204                    // and never moved.
205                    let mut fut = unsafe { Pin::new_unchecked(fut) };
206
207                    fut
208                        .take_output()
209                        .expect("expected completed future")
210                        .ok()
211                        .expect("expected Ok(_)")
212                },)*)))
213            }
214        }).await
215    }};
216
217    // ===== Normalize =====
218
219    (@ { ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
220      $crate::try_join!(@{ ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
221    };
222
223    // ===== Entry point =====
224
225    ( $($e:expr),+ $(,)?) => {
226        $crate::try_join!(@{ () (0) } $($e,)*)
227    };
228
229    () => { async { Ok(()) }.await }
230}}