tor_rtcompat/task.rs
1//! Functions for task management that don't belong inside the Runtime
2//! trait.
3
4use std::future::Future;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8/// Yield execution back to the runtime temporarily, so that other
9/// tasks can run.
10#[must_use = "yield_now returns a future that must be .awaited on."]
11pub fn yield_now() -> YieldFuture {
12 // TODO: There are functions similar to this in tokio and
13 // async_std and futures_lite. It would be lovely if futures had
14 // one too. If it does, we should probably use it.
15 YieldFuture { first_time: true }
16}
17
/// The future produced by [`yield_now()`].
///
/// Its first poll yields `Poll::Pending`; every poll after that yields
/// `Poll::Ready(())`.
#[must_use = "Futures do nothing unless .awaited on."]
#[derive(Debug)]
pub struct YieldFuture {
    /// Set until `poll` is called for the first time; that call clears it.
    first_time: bool,
}
27
28impl Future for YieldFuture {
29 type Output = ();
30 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
31 if self.first_time {
32 self.first_time = false;
33 cx.waker().wake_by_ref();
34 Poll::Pending
35 } else {
36 Poll::Ready(())
37 }
38 }
39}
40
#[cfg(all(
    test,
    any(feature = "native-tls", feature = "rustls"),
    any(feature = "tokio", feature = "async-std"),
    not(miri), // this typically results in use of a yield syscall
))]
mod test {
    use super::yield_now;
    use crate::test_with_all_runtimes;

    use std::sync::atomic::{AtomicBool, Ordering::SeqCst};

    /// Repeatedly try to change `flag` from `from` to `to`, yielding after
    /// every attempt, and return once 10 attempts have succeeded.
    async fn flip_ten_times(flag: &AtomicBool, from: bool, to: bool) {
        let mut successes = 0_usize;
        while successes < 10 {
            if flag.compare_exchange(from, to, SeqCst, SeqCst).is_ok() {
                successes += 1;
            }
            yield_now().await;
        }
    }

    #[test]
    fn test_yield() {
        test_with_all_runtimes!(|_| async {
            let flag = AtomicBool::new(false);

            // Two loops run concurrently inside this one task. Each tries to
            // set `flag` to its favorite value and finishes after succeeding
            // 10 times. Each loop can only succeed again after the other one
            // has run. Without `yield_now`, the first loop would therefore
            // spin forever without returning control to `join!`.
            futures::join!(
                flip_ten_times(&flag, false, true),
                flip_ten_times(&flag, true, false)
            );
            std::io::Result::Ok(())
        });
    }
}