tor_proto/util/keyed_futures_unordered.rs
1//! Provides [`KeyedFuturesUnordered`]
2
3// So that we can declare these things as if they were in their own crate.
4#![allow(unreachable_pub)]
5
6use std::{
7 collections::{hash_map, HashMap},
8 hash::Hash,
9 pin::Pin,
10 sync::Arc,
11 task::Poll,
12};
13
14use futures::future::FutureExt;
15use futures::{
16 channel::mpsc::{UnboundedReceiver, UnboundedSender},
17 Future,
18};
19use pin_project::pin_project;
20
21/// Waker for internal use in [`KeyedFuturesUnordered`]
22///
23/// When woken, it notifies the parent [`KeyedFuturesUnordered`] that the future
24/// for a corresponding key is ready to be polled.
25struct KeyedWaker<K> {
26 /// The key associated with this waker.
27 key: K,
28 /// Sender cloned from the parent [`KeyedFuturesUnordered`].
29 sender: UnboundedSender<K>,
30}
31
32impl<K> std::task::Wake for KeyedWaker<K>
33where
34 K: Clone,
35{
36 fn wake(self: Arc<Self>) {
37 self.sender
38 .unbounded_send(self.key.clone())
39 .unwrap_or_else(|e| {
40 if e.is_disconnected() {
41 // Other side has disappeared. Can safely ignore.
42 return;
43 }
44 // Shouldn't happen, but probably no need to `panic`.
45 tracing::error!("Bug: Unexpected send error: {e:?}");
46 });
47 }
48}
49
50/// Efficiently manages a dynamic set of futures as per
51/// [`futures::stream::FuturesUnordered`]. Unlike `FuturesUnordered`, each future
52/// has an associated key. This key is returned along with the future's output,
53/// and can be used to cancel and *remove* a future from the set.
54///
55/// Implements [`futures::Stream`], producing a stream of completed futures and
56/// their associated keys.
57///
58/// # Stream behavior
59///
60/// `Stream::poll_next` returns:
61/// * `Poll::Ready(None)` if there are no futures managed by this object.
62/// * `Poll::Ready(Some((key, output)))` with the key and output of a ready
63/// future when there is one.
64/// * `Poll::Pending` when there are futures managed by this object, but none
65/// are currently ready.
66///
67/// Unlike for a generic `Stream`, it *is* permitted to call `poll_next` again
68/// after having received `Poll::Ready(None)`. It will still behave as above
69/// (i.e. returning `Pending` or `Ready` if futures have since been inserted).
70#[derive(Debug)]
71#[pin_project]
72pub struct KeyedFuturesUnordered<K, F>
73where
74 F: Future,
75{
76 /// Receiver on which we're notified of keys that are ready to be polled.
77 #[pin]
78 notification_receiver: UnboundedReceiver<K>,
79 /// Sender on which to notify `notifications_receiver` that keys are ready
80 /// to be polled.
81 // In particular, keys are sent here:
82 // * When a future is inserted.
83 // * In `KeyedWaker`, which is the `Waker` we register with futures when we
84 // poll them internally.
85 notification_sender: UnboundedSender<K>,
86 /// Map of pending futures.
87 futures: HashMap<K, F>,
88}
89
90impl<K, F> KeyedFuturesUnordered<K, F>
91where
92 F: Future,
93 K: Eq + Hash + Clone,
94{
95 /// Create an empty [`KeyedFuturesUnordered`].
96 pub fn new() -> Self {
97 let (send, recv) = futures::channel::mpsc::unbounded();
98 Self {
99 notification_sender: send,
100 notification_receiver: recv,
101 futures: Default::default(),
102 }
103 }
104
105 /// Insert a future and associate it with `key`. Return an error if there is already an entry for `key`.
106 pub fn try_insert(&mut self, key: K, fut: F) -> Result<(), KeyAlreadyInsertedError<K, F>> {
107 let hash_map::Entry::Vacant(v) = self.futures.entry(key.clone()) else {
108 // Key is already present.
109 return Err(KeyAlreadyInsertedError { key, fut });
110 };
111 v.insert(fut);
112 // Immediately "notify" ourselves, to enqueue this key to be polled.
113 self.notification_sender
114 .unbounded_send(key)
115 // * Since the sender is unbounded, can't fail due to fullness.
116 // * Since we have our own copy of the receiver, can't be disconnected.
117 .expect("Unbounded send unexpectedly failed");
118 Ok(())
119 }
120
121 /// Remove the entry for `key`, if any, and return the corresponding future.
122 pub fn remove(&mut self, key: &K) -> Option<(K, F)> {
123 self.futures.remove_entry(key)
124 }
125
126 /// Get the future corresponding to `key`, if any.
127 ///
128 /// As for [`Self::get_mut`], removing or replacing its [`std::task::Waker`]
129 /// without waking it (e.g. using internal mutability) results in
130 /// unspecified (but sound) behavior.
131 #[allow(dead_code)]
132 pub fn get<'a>(&'a self, key: &K) -> Option<&'a F> {
133 self.futures.get(key)
134 }
135
136 /// Get the future corresponding to `key`, if any.
137 ///
138 /// The future should not be `poll`d, nor its registered
139 /// [`std::task::Waker`] otherwise removed or replaced (unless it is also
140 /// woken; see below). The result of doing either is unspecified (but
141 /// sound).
142 ///
143 /// This method is useful primarily when the future has other functionality
144 /// or data bundled with it besides its implementation of the `Future`
145 /// trait, though it *is* permitted to mutate the object in a way that
146 /// causes it to become ready (i.e. wakes and discards its registered
147 /// [`std::task::Waker`]`), or become unready (cause its next poll result to
148 /// be `Poll::Pending` when it otherwise would have been `Poll::Ready` and
149 /// may have already woken its registered `Waker`).
150 //
151 // More specifically:
152 // * If the waker is lost without being woken, we'll never
153 // poll this future again.
154 // * If our waker is woken *and* the caller polls the future to completion,
155 // we could end up polling it again after completion,
156 // breaking the `Future` contract.
157 #[allow(dead_code)]
158 pub fn get_mut<'a>(&'a mut self, key: &K) -> Option<&'a mut F> {
159 self.futures.get_mut(key)
160 }
161}
162
163impl<K, F> futures::Stream for KeyedFuturesUnordered<K, F>
164where
165 F: Future + Unpin,
166 K: Clone + Hash + Eq + Send + Sync + 'static,
167{
168 type Item = (K, F::Output);
169
170 fn poll_next(
171 self: Pin<&mut Self>,
172 cx: &mut std::task::Context<'_>,
173 ) -> Poll<Option<Self::Item>> {
174 if self.futures.is_empty() {
175 // Follow precedent of `FuturesUnordered` of returning None in this case.
176 // TODO: Consider breaking this precedent? This behavior is a bit
177 // odd, since the documentation of the Stream trait indicates that a
178 // stream shouldn't be polled again after returning None.
179 return Poll::Ready(None);
180 }
181 let mut self_ = self.project();
182 loop {
183 // Get the next pollable future, registering the caller's waker.
184 let key = match self_.notification_receiver.as_mut().poll_next(cx) {
185 Poll::Ready(key) => key.expect("Unexpected end of stream"),
186 Poll::Pending => {
187 // No more keys to try.
188 return Poll::Pending;
189 }
190 };
191 let Some(fut) = self_.futures.get_mut(&key) else {
192 // No future for this key. Presumably because it was removed
193 // from the map. Try the next key.
194 continue;
195 };
196 // Poll the future itself, using our own waker that will notify us
197 // that this key is ready.
198 let waker = std::task::Waker::from(Arc::new(KeyedWaker {
199 key: key.clone(),
200 sender: self_.notification_sender.clone(),
201 }));
202 match fut.poll_unpin(&mut std::task::Context::from_waker(&waker)) {
203 Poll::Ready(o) => {
204 // Remove and drop the future itself.
205 // We *could* return it along with the item, but this would
206 // be a departure from the interface of `FuturesUnordered`,
207 // and most futures are designed to be discarded after
208 // completion.
209 self_.futures.remove(&key);
210
211 return Poll::Ready(Some((key, o)));
212 }
213 Poll::Pending => {
214 // This future wasn't actually ready.
215 //
216 // This can happen, e.g. because:
217 // * This is our first time actually polling this future.
218 // * The futures waker was called spuriously.
219 // * This was actually a reused key, and we received the notification from
220 // a waker for a previous future registered with this key.
221 //
222 // Move on to the next key.
223 }
224 }
225 }
226 }
227}
228
229/// Error returned by [`KeyedFuturesUnordered::try_insert`].
230#[derive(Debug, thiserror::Error)]
231#[allow(clippy::exhaustive_structs)]
232pub struct KeyAlreadyInsertedError<K, F> {
233 /// Key that caller tried to insert.
234 #[allow(dead_code)]
235 pub key: K,
236 /// Future that caller tried to insert.
237 #[allow(dead_code)]
238 pub fut: F,
239}
240
#[cfg(test)]
mod tests {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::task::Waker;

    use futures::{executor::block_on, future::poll_fn, StreamExt as _};
    use oneshot_fused_workaround as oneshot;
    use tor_rtmock::MockRuntime;

    use super::*;

    /// Key type for the sets under test.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Key(u64);

    /// Output type produced by the futures under test.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
    struct Value(u64);

    /// Manually-driven future for tests.
    ///
    /// Compares by value and readiness, and can be flipped to ready from the
    /// outside with [`ValueFut::make_ready`], which also wakes whichever waker
    /// last saw it return `Pending`.
    #[derive(Debug, Clone)]
    struct ValueFut<V> {
        /// Output handed back once ready; taken (left `None`) on completion.
        value: Option<V>,
        /// Whether the next poll completes.
        // Kept separate from `value` so that two pending futures holding
        // different values still compare unequal.
        ready: bool,
        /// Waker from the most recent `Pending` poll, if any.
        waker: Option<Waker>,
    }

    impl<V: PartialEq> PartialEq for ValueFut<V> {
        fn eq(&self, other: &Self) -> bool {
            // `Waker` has no notion of equality, so it is left out.
            self.value == other.value && self.ready == other.ready
        }
    }

    impl<V: Eq> Eq for ValueFut<V> {}

    impl<V> ValueFut<V> {
        /// Shared constructor behind [`Self::ready`] and [`Self::pending`].
        fn new(value: V, ready: bool) -> Self {
            Self {
                value: Some(value),
                ready,
                waker: None,
            }
        }

        fn ready(value: V) -> Self {
            Self::new(value, true)
        }

        fn pending(value: V) -> Self {
            Self::new(value, false)
        }

        fn make_ready(&mut self) {
            self.ready = true;
            if let Some(w) = self.waker.take() {
                w.wake();
            }
        }
    }

    impl<V: Unpin> Future for ValueFut<V> {
        type Output = V;

        fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
            if self.ready {
                let value = self.value.take().expect("Polled future after it was ready");
                return Poll::Ready(value);
            }
            // Remember whom to wake once `make_ready` is called.
            self.waker = Some(cx.waker().clone());
            Poll::Pending
        }
    }

    #[test]
    fn test_empty() {
        block_on(poll_fn(|cx| {
            let mut set = KeyedFuturesUnordered::<Key, ValueFut<Value>>::new();

            // With nothing in the set (ready or pending), the stream reports
            // `Poll::Ready(None)`, matching `FuturesUnordered`.
            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));

            // There is no entry to look up.
            assert!(set.get(&Key(0)).is_none());
            assert!(set.get_mut(&Key(0)).is_none());

            Poll::Ready(())
        }));
    }

    #[test]
    fn test_one_pending_future() {
        block_on(poll_fn(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            set.try_insert(Key(0), ValueFut::pending(Value(0))).unwrap();

            // Something is stored but nothing is ready, so `Poll::Pending`, as
            // for `FuturesUnordered`; a second poll sees the same state.
            for _ in 0..2 {
                assert_eq!(set.poll_next_unpin(cx), Poll::Pending);
            }

            // The stored future is still reachable.
            let expected = ValueFut::pending(Value(0));
            assert_eq!(set.get(&Key(0)), Some(&expected));
            assert_eq!(set.get_mut(&Key(0)), Some(&mut expected.clone()));

            Poll::Ready(())
        }));
    }

    #[test]
    fn test_one_ready_future() {
        block_on(poll_fn(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            set.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();

            // The entry is visible before the stream has polled it.
            assert_eq!(set.get(&Key(0)), Some(&ValueFut::ready(Value(1))));
            assert_eq!(set.get_mut(&Key(0)), Some(&mut ValueFut::ready(Value(1))));

            // A ready future is yielded together with its key...
            assert_eq!(
                set.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(1))))
            );

            // ...after which the set is empty again.
            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));
            assert!(set.get(&Key(0)).is_none());
            assert!(set.get_mut(&Key(0)).is_none());

            Poll::Ready(())
        }));
    }

    #[test]
    fn test_one_pending_then_ready_future() {
        block_on(poll_fn(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            let (tx, rx) = oneshot::channel::<Value>();
            set.try_insert(Key(0), rx).unwrap();

            // Nothing has been sent yet, but the entry exists.
            assert_eq!(set.poll_next_unpin(cx), Poll::Pending);
            assert!(set.get(&Key(0)).is_some());
            assert!(set.get_mut(&Key(0)).is_some());

            tx.send(Value(1)).unwrap();

            // The oneshot receiver has now completed.
            assert_eq!(
                set.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Ok(Value(1)))))
            );

            // Back to empty.
            assert!(set.get(&Key(0)).is_none());
            assert!(set.get_mut(&Key(0)).is_none());
            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));

            Poll::Ready(())
        }));
    }

    /// Inserts `fut` under `Key(0)`, removes it again, and checks that the
    /// removed entry comes back intact and the set is then empty.
    fn assert_insert_then_remove(fut: ValueFut<Value>) {
        block_on(poll_fn(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            set.try_insert(Key(0), fut.clone()).unwrap();
            assert_eq!(set.remove(&Key(0)), Some((Key(0), fut.clone())));
            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));
            Poll::Ready(())
        }));
    }

    #[test]
    fn test_remove_pending() {
        assert_insert_then_remove(ValueFut::pending(Value(0)));
    }

    #[test]
    fn test_remove_ready() {
        assert_insert_then_remove(ValueFut::ready(Value(1)));
    }

    #[test]
    fn test_remove_and_reuse_ready() {
        block_on(poll_fn(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            set.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();
            assert_eq!(
                set.remove(&Key(0)),
                Some((Key(0), ValueFut::ready(Value(1))))
            );
            set.try_insert(Key(0), ValueFut::ready(Value(2))).unwrap();

            // Only the replacement's output should come back.
            assert_eq!(
                set.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(2))))
            );
            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));

            Poll::Ready(())
        }));
    }

    #[test]
    fn test_remove_and_reuse_pending_then_ready() {
        block_on(poll_fn(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            set.try_insert(Key(0), ValueFut::pending(Value(1))).unwrap();
            let (_, mut removed) = set.remove(&Key(0)).unwrap();
            set.try_insert(Key(0), ValueFut::pending(Value(2))).unwrap();

            // The removed future was never polled by the set, so making it
            // ready wakes nothing; the set must still report only the state
            // of its replacement (which is pending, despite two queued
            // notifications for `Key(0)`).
            removed.make_ready();
            assert_eq!(set.poll_next_unpin(cx), Poll::Pending);

            // Now make the replacement ready.
            set.get_mut(&Key(0)).unwrap().make_ready();

            // Only the replacement's output should come back.
            assert_eq!(
                set.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(2))))
            );
            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));

            Poll::Ready(())
        }));
    }

    #[test]
    fn test_async() {
        MockRuntime::test_with_various(|rt| async move {
            let mut set = KeyedFuturesUnordered::new();

            for i in 0..10 {
                let (tx, rx) = oneshot::channel();
                set.try_insert(Key(i), rx).unwrap();
                rt.spawn_identified(format!("sender-{i}"), async move {
                    tx.send(Value(i)).unwrap();
                });
            }

            let mut got: Vec<(Key, Value)> = set
                .map(|(k, res)| (k, res.unwrap()))
                .collect()
                .await;
            got.sort();

            let expected: Vec<_> = (0..10).map(|i| (Key(i), Value(i))).collect();
            assert_eq!(got, expected);
        });
    }
}