tor_proto/util/
keyed_futures_unordered.rs

1//! Provides [`KeyedFuturesUnordered`]
2
3// So that we can declare these things as if they were in their own crate.
4#![allow(unreachable_pub)]
5
6use std::{
7    collections::{hash_map, HashMap},
8    hash::Hash,
9    pin::Pin,
10    sync::Arc,
11    task::Poll,
12};
13
14use futures::future::FutureExt;
15use futures::{
16    channel::mpsc::{UnboundedReceiver, UnboundedSender},
17    Future,
18};
19use pin_project::pin_project;
20
/// Waker for internal use in [`KeyedFuturesUnordered`]
///
/// When woken, it notifies the parent [`KeyedFuturesUnordered`] that the future
/// for a corresponding key is ready to be polled.
///
/// The parent creates a fresh instance (wrapped in an [`Arc`]) each time it
/// polls one of its futures, and hands it to that future as its
/// [`std::task::Waker`].
struct KeyedWaker<K> {
    /// The key associated with this waker.
    ///
    /// A clone of it is sent over `sender` every time the waker is woken.
    key: K,
    /// Sender cloned from the parent [`KeyedFuturesUnordered`].
    sender: UnboundedSender<K>,
}
31
32impl<K> std::task::Wake for KeyedWaker<K>
33where
34    K: Clone,
35{
36    fn wake(self: Arc<Self>) {
37        self.sender
38            .unbounded_send(self.key.clone())
39            .unwrap_or_else(|e| {
40                if e.is_disconnected() {
41                    // Other side has disappeared. Can safely ignore.
42                    return;
43                }
44                // Shouldn't happen, but probably no need to `panic`.
45                tracing::error!("Bug: Unexpected send error: {e:?}");
46            });
47    }
48}
49
/// Efficiently manages a dynamic set of futures as per
/// [`futures::stream::FuturesUnordered`]. Unlike `FuturesUnordered`, each future
/// has an associated key. This key is returned along with the future's output,
/// and can be used to cancel and *remove* a future from the set.
///
/// Implements [`futures::Stream`], producing a stream of completed futures and
/// their associated keys.
///
/// # Stream behavior
///
/// `Stream::poll_next` returns:
/// * `Poll::Ready(None)` if there are no futures managed by this object.
/// * `Poll::Ready(Some((key, output)))` with the key and output of a ready
///    future when there is one.
/// * `Poll::Pending` when there are futures managed by this object, but none
///    are currently ready.
///
/// Unlike for a generic `Stream`, it *is* permitted to call `poll_next` again
/// after having received `Poll::Ready(None)`. It will still behave as above
/// (i.e. returning `Pending` or `Ready` if futures have since been inserted).
#[derive(Debug)]
#[pin_project]
pub struct KeyedFuturesUnordered<K, F>
where
    F: Future,
{
    /// Receiver on which we're notified of keys that are ready to be polled.
    ///
    /// A received key is only a hint: it may refer to a future that has since
    /// been removed, or to a different future that reused the same key.
    #[pin]
    notification_receiver: UnboundedReceiver<K>,
    /// Sender on which to notify `notification_receiver` that keys are ready
    /// to be polled.
    // In particular, keys are sent here:
    // * When a future is inserted.
    // * In `KeyedWaker`, which is the `Waker` we register with futures when we
    //   poll them internally.
    notification_sender: UnboundedSender<K>,
    /// Map of pending futures.
    ///
    /// An entry is dropped from the map when its future completes during
    /// `poll_next`, or when the caller removes it by key.
    futures: HashMap<K, F>,
}
89
90impl<K, F> KeyedFuturesUnordered<K, F>
91where
92    F: Future,
93    K: Eq + Hash + Clone,
94{
95    /// Create an empty [`KeyedFuturesUnordered`].
96    pub fn new() -> Self {
97        let (send, recv) = futures::channel::mpsc::unbounded();
98        Self {
99            notification_sender: send,
100            notification_receiver: recv,
101            futures: Default::default(),
102        }
103    }
104
105    /// Insert a future and associate it with `key`. Return an error if there is already an entry for `key`.
106    pub fn try_insert(&mut self, key: K, fut: F) -> Result<(), KeyAlreadyInsertedError<K, F>> {
107        let hash_map::Entry::Vacant(v) = self.futures.entry(key.clone()) else {
108            // Key is already present.
109            return Err(KeyAlreadyInsertedError { key, fut });
110        };
111        v.insert(fut);
112        // Immediately "notify" ourselves, to enqueue this key to be polled.
113        self.notification_sender
114            .unbounded_send(key)
115            // * Since the sender is unbounded, can't fail due to fullness.
116            // * Since we have our own copy of the receiver, can't be disconnected.
117            .expect("Unbounded send unexpectedly failed");
118        Ok(())
119    }
120
121    /// Remove the entry for `key`, if any, and return the corresponding future.
122    pub fn remove(&mut self, key: &K) -> Option<(K, F)> {
123        self.futures.remove_entry(key)
124    }
125
126    /// Get the future corresponding to `key`, if any.
127    ///
128    /// As for [`Self::get_mut`], removing or replacing its [`std::task::Waker`]
129    /// without waking it (e.g. using internal mutability) results in
130    /// unspecified (but sound) behavior.
131    #[allow(dead_code)]
132    pub fn get<'a>(&'a self, key: &K) -> Option<&'a F> {
133        self.futures.get(key)
134    }
135
136    /// Get the future corresponding to `key`, if any.
137    ///
138    /// The future should not be `poll`d, nor its registered
139    /// [`std::task::Waker`] otherwise removed or replaced (unless it is also
140    /// woken; see below). The result of doing either is unspecified (but
141    /// sound).
142    ///
143    /// This method is useful primarily when the future has other functionality
144    /// or data bundled with it besides its implementation of the `Future`
145    /// trait, though it *is* permitted to mutate the object in a way that
146    /// causes it to become ready (i.e. wakes and discards its registered
147    /// [`std::task::Waker`]`), or become unready (cause its next poll result to
148    /// be `Poll::Pending` when it otherwise would have been `Poll::Ready` and
149    /// may have already woken its registered `Waker`).
150    //
151    // More specifically:
152    // * If the waker is lost without being woken, we'll never
153    //   poll this future again.
154    // * If our waker is woken *and* the caller polls the future to completion,
155    //   we could end up polling it again after completion,
156    //   breaking the `Future` contract.
157    #[allow(dead_code)]
158    pub fn get_mut<'a>(&'a mut self, key: &K) -> Option<&'a mut F> {
159        self.futures.get_mut(key)
160    }
161}
162
163impl<K, F> futures::Stream for KeyedFuturesUnordered<K, F>
164where
165    F: Future + Unpin,
166    K: Clone + Hash + Eq + Send + Sync + 'static,
167{
168    type Item = (K, F::Output);
169
170    fn poll_next(
171        self: Pin<&mut Self>,
172        cx: &mut std::task::Context<'_>,
173    ) -> Poll<Option<Self::Item>> {
174        if self.futures.is_empty() {
175            // Follow precedent of `FuturesUnordered` of returning None in this case.
176            // TODO: Consider breaking this precedent? This behavior is a bit
177            // odd, since the documentation of the Stream trait indicates that a
178            // stream shouldn't be polled again after returning None.
179            return Poll::Ready(None);
180        }
181        let mut self_ = self.project();
182        loop {
183            // Get the next pollable future, registering the caller's waker.
184            let key = match self_.notification_receiver.as_mut().poll_next(cx) {
185                Poll::Ready(key) => key.expect("Unexpected end of stream"),
186                Poll::Pending => {
187                    // No more keys to try.
188                    return Poll::Pending;
189                }
190            };
191            let Some(fut) = self_.futures.get_mut(&key) else {
192                // No future for this key. Presumably because it was removed
193                // from the map. Try the next key.
194                continue;
195            };
196            // Poll the future itself, using our own waker that will notify us
197            // that this key is ready.
198            let waker = std::task::Waker::from(Arc::new(KeyedWaker {
199                key: key.clone(),
200                sender: self_.notification_sender.clone(),
201            }));
202            match fut.poll_unpin(&mut std::task::Context::from_waker(&waker)) {
203                Poll::Ready(o) => {
204                    // Remove and drop the future itself.
205                    // We *could* return it along with the item, but this would
206                    // be a departure from the interface of `FuturesUnordered`,
207                    // and most futures are designed to be discarded after
208                    // completion.
209                    self_.futures.remove(&key);
210
211                    return Poll::Ready(Some((key, o)));
212                }
213                Poll::Pending => {
214                    // This future wasn't actually ready.
215                    //
216                    // This can happen, e.g. because:
217                    // * This is our first time actually polling this future.
218                    // * The futures waker was called spuriously.
219                    // * This was actually a reused key, and we received the notification from
220                    //   a waker for a previous future registered with this key.
221                    //
222                    // Move on to the next key.
223                }
224            }
225        }
226    }
227}
228
/// Error returned by [`KeyedFuturesUnordered::try_insert`].
///
/// Carries both arguments of the rejected call, so that the caller regains
/// ownership of the key and of the future.
// NOTE(review): there is no `#[error(...)]` attribute on this type, so
// presumably the `thiserror` derive generates no `Display` impl for it.
// Confirm whether `std::error::Error` ends up implemented as intended.
#[derive(Debug, thiserror::Error)]
#[allow(clippy::exhaustive_structs)]
pub struct KeyAlreadyInsertedError<K, F> {
    /// Key that caller tried to insert.
    #[allow(dead_code)]
    pub key: K,
    /// Future that caller tried to insert.
    #[allow(dead_code)]
    pub fut: F,
}
240
#[cfg(test)]
mod tests {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::task::Waker;

    use futures::{executor::block_on, future::poll_fn, StreamExt as _};
    use oneshot_fused_workaround as oneshot;
    use tor_rtmock::MockRuntime;

    use super::*;

    /// Key type used by the tests.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Key(u64);

    /// Output type produced by the test futures.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
    struct Value(u64);

    /// Simple future for testing. Supports comparison, and can be mutated directly to become ready.
    #[derive(Debug, Clone)]
    struct ValueFut<V> {
        /// Value that will be produced when ready.
        value: Option<V>,
        /// Whether this is ready.
        // We use a distinct flag here instead of a None value so that pending
        // instances are still unequal if they have different values.
        ready: bool,
        /// Waker stored by the most recent `poll` that returned `Poll::Pending`,
        /// unless `make_ready` has since taken it.
        waker: Option<Waker>,
    }

    impl<V> std::cmp::PartialEq for ValueFut<V>
    where
        V: std::cmp::PartialEq,
    {
        /// Compare the `value` and `ready` fields only.
        fn eq(&self, other: &Self) -> bool {
            // Ignores the waker, which isn't comparable
            self.value == other.value && self.ready == other.ready
        }
    }

    impl<V> std::cmp::Eq for ValueFut<V> where V: std::cmp::Eq {}

    impl<V> ValueFut<V> {
        /// Create a future that yields `value` the first time it is polled.
        fn ready(value: V) -> Self {
            Self {
                value: Some(value),
                ready: true,
                waker: None,
            }
        }
        /// Create a future that stays pending until `make_ready` is called,
        /// and yields `value` after that.
        fn pending(value: V) -> Self {
            Self {
                value: Some(value),
                ready: false,
                waker: None,
            }
        }
        /// Mark this future as ready, and wake the stored waker if there is one.
        fn make_ready(&mut self) {
            self.ready = true;
            if let Some(waker) = self.waker.take() {
                waker.wake();
            }
        }
    }

    impl<V> Future for ValueFut<V>
    where
        V: Unpin,
    {
        type Output = V;

        /// Return the value if `ready` is set; otherwise store the caller's
        /// waker and return `Poll::Pending`.
        ///
        /// Panics if polled again after having returned `Poll::Ready`.
        fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
            if !self.ready {
                self.waker.replace(cx.waker().clone());
                Poll::Pending
            } else {
                Poll::Ready(self.value.take().expect("Polled future after it was ready"))
            }
        }
    }

    /// An empty set reports end-of-stream and has nothing to look up.
    #[test]
    fn test_empty() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::<Key, ValueFut<Value>>::new();

            // When there are no futures in the set (ready or pending), returns
            // `Poll::Ready(None)` as for `FuturesUnordered`.
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));

            // Nothing to get.
            assert_eq!(kfu.get(&Key(0)), None);
            assert_eq!(kfu.get_mut(&Key(0)), None);

            Poll::Ready(())
        }));
    }

    /// A single pending future makes the stream pending, and stays retrievable.
    #[test]
    fn test_one_pending_future() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();

            kfu.try_insert(Key(0), ValueFut::pending(Value(0))).unwrap();

            // When there are futures in the set, but none are ready, returns
            // `Poll::Pending`, as for `FuturesUnordered`
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);

            // State should be unchanged; same result if we poll again.
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);

            // We should be able to get the future.
            assert_eq!(kfu.get(&Key(0)), Some(&ValueFut::pending(Value(0))));
            assert_eq!(kfu.get_mut(&Key(0)), Some(&mut ValueFut::pending(Value(0))));

            Poll::Ready(())
        }));
    }

    /// A single ready future is yielded once, after which the set is empty.
    #[test]
    fn test_one_ready_future() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();

            kfu.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();

            // Should be able to get the future before it's polled.
            assert_eq!(kfu.get(&Key(0)), Some(&ValueFut::ready(Value(1))));
            assert_eq!(kfu.get_mut(&Key(0)), Some(&mut ValueFut::ready(Value(1))));

            // When there is a ready future, returns it.
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(1))))
            );

            // After having returned the ready future, should be empty again.
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            assert_eq!(kfu.get(&Key(0)), None);
            assert_eq!(kfu.get_mut(&Key(0)), None);

            Poll::Ready(())
        }));
    }

    /// A oneshot receiver is pending until its sender fires, then is yielded.
    #[test]
    fn test_one_pending_then_ready_future() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            let (send, recv) = oneshot::channel::<Value>();
            kfu.try_insert(Key(0), recv).unwrap();

            // Nothing ready yet.
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);

            // Should be able to get it.
            assert!(kfu.get(&Key(0)).is_some());
            assert!(kfu.get_mut(&Key(0)).is_some());

            send.send(Value(1)).unwrap();

            // oneshot future should be ready.
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Ok(Value(1)))))
            );

            // Empty again.
            assert!(kfu.get(&Key(0)).is_none());
            assert!(kfu.get_mut(&Key(0)).is_none());
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));

            Poll::Ready(())
        }));
    }

    /// Removing a pending future returns it and leaves the set empty.
    #[test]
    fn test_remove_pending() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::pending(Value(0))).unwrap();
            assert_eq!(
                kfu.remove(&Key(0)),
                Some((Key(0), ValueFut::pending(Value(0))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            Poll::Ready(())
        }));
    }

    /// Removing a ready future before it is polled returns it unpolled and
    /// leaves the set empty.
    #[test]
    fn test_remove_ready() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();
            assert_eq!(
                kfu.remove(&Key(0)),
                Some((Key(0), ValueFut::ready(Value(1))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            Poll::Ready(())
        }));
    }

    /// After removing a future and reusing its key for another ready future,
    /// only the replacement's output is yielded.
    #[test]
    fn test_remove_and_reuse_ready() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();
            assert_eq!(
                kfu.remove(&Key(0)),
                Some((Key(0), ValueFut::ready(Value(1))))
            );
            kfu.try_insert(Key(0), ValueFut::ready(Value(2))).unwrap();

            // We should get back *only* the second value.
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(2))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));

            Poll::Ready(())
        }));
    }

    /// After removing a pending future and reusing its key, readiness of the
    /// removed future is not observable; only the replacement is yielded once
    /// it becomes ready.
    #[test]
    fn test_remove_and_reuse_pending_then_ready() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::pending(Value(1))).unwrap();
            let (_key, mut removed_value) = kfu.remove(&Key(0)).unwrap();
            kfu.try_insert(Key(0), ValueFut::pending(Value(2))).unwrap();

            // Make the *removed* future ready before polling again. This should
            // cause an internal spurious wakeup, but not be visible from the
            // user's perspective.
            //
            // Note: the removed future was never polled, so it holds no waker
            // to wake. The extra queued notification for this key originates
            // from the first `try_insert`.
            removed_value.make_ready();
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);

            // Make the future that we replaced it with become ready.
            kfu.get_mut(&Key(0)).unwrap().make_ready();

            // We should now get back *only* the second value.
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(2))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));

            Poll::Ready(())
        }));
    }

    /// Insert ten oneshot receivers, spawn one task per sender, and check that
    /// collecting the stream yields every `(key, value)` pair.
    #[test]
    fn test_async() {
        MockRuntime::test_with_various(|rt| async move {
            let mut kfu = KeyedFuturesUnordered::new();

            for i in 0..10 {
                let (send, recv) = oneshot::channel();
                kfu.try_insert(Key(i), recv).unwrap();
                rt.spawn_identified(format!("sender-{i}"), async move {
                    send.send(Value(i)).unwrap();
                });
            }

            let values = kfu.collect::<Vec<_>>().await;
            // Completion order is not asserted, so sort before comparing.
            let mut values = values
                .into_iter()
                .map(|(k, v)| (k, v.unwrap()))
                .collect::<Vec<_>>();
            values.sort();

            let expected_values = (0..10).map(|i| (Key(i), Value(i))).collect::<Vec<_>>();
            assert_eq!(values, expected_values);
        });
    }
}