tokio/sync/task/
atomic_waker.rs

1#![cfg_attr(any(loom, not(feature = "sync")), allow(dead_code, unreachable_pub))]
2
3use crate::loom::cell::UnsafeCell;
4use crate::loom::hint;
5use crate::loom::sync::atomic::AtomicUsize;
6
7use std::fmt;
8use std::panic::{resume_unwind, AssertUnwindSafe, RefUnwindSafe, UnwindSafe};
9use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
10use std::task::Waker;
11
/// A cell that hands a task's `Waker` from the task to whoever wakes it.
///
/// `AtomicWaker` arbitrates between one side that stores a waker
/// (`register`) and any number of sides that want to fire it (`wake`),
/// including the case where the consumer is being migrated to a new logical
/// task while a computation completes on another thread.
///
/// Usage protocol:
///
/// * The consumer calls `register` **before** it checks the result of the
///   computation.
/// * The producer calls `wake` **after** it has produced the result (note
///   that this differs from the usual `thread::park` pattern).
///
/// Calling `wake` before any `register` is permitted and is a no-op. The same
/// `AtomicWaker` can be reused for any number of `register` / `wake` calls.
pub(crate) struct AtomicWaker {
    /// Lock word built from the `REGISTERING` and `WAKING` bits; guards
    /// access to `waker`.
    state: AtomicUsize,
    /// The most recently registered waker, if it has not been taken yet.
    waker: UnsafeCell<Option<Waker>>,
}
31
32impl RefUnwindSafe for AtomicWaker {}
33impl UnwindSafe for AtomicWaker {}
34
35// `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell
36// stores a `Waker` value produced by calls to `register` and many threads can
37// race to take the waker by calling `wake`.
38//
39// If a new `Waker` instance is produced by calling `register` before an existing
40// one is consumed, then the existing one is overwritten.
41//
42// While `AtomicWaker` is single-producer, the implementation ensures memory
43// safety. In the event of concurrent calls to `register`, there will be a
44// single winner whose waker will get stored in the cell. The losers will not
45// have their tasks woken. As such, callers should ensure to add synchronization
46// to calls to `register`.
47//
48// The implementation uses a single `AtomicUsize` value to coordinate access to
49// the `Waker` cell. There are two bits that are operated on independently. These
50// are represented by `REGISTERING` and `WAKING`.
51//
52// The `REGISTERING` bit is set when a producer enters the critical section. The
53// `WAKING` bit is set when a consumer enters the critical section. Neither
54// bit being set is represented by `WAITING`.
55//
56// A thread obtains an exclusive lock on the waker cell by transitioning the
57// state from `WAITING` to `REGISTERING` or `WAKING`, depending on the
58// operation the thread wishes to perform. When this transition is made, it is
59// guaranteed that no other thread will access the waker cell.
60//
61// # Registering
62//
63// On a call to `register`, an attempt to transition the state from WAITING to
64// REGISTERING is made. On success, the caller obtains a lock on the waker cell.
65//
66// If the lock is obtained, then the thread sets the waker cell to the waker
67// provided as an argument. Then it attempts to transition the state back from
68// `REGISTERING` -> `WAITING`.
69//
70// If this transition is successful, then the registering process is complete
71// and the next call to `wake` will observe the waker.
72//
73// If the transition fails, then there was a concurrent call to `wake` that
74// was unable to access the waker cell (due to the registering thread holding the
75// lock). To handle this, the registering thread removes the waker it just set
76// from the cell and calls `wake` on it. This call to wake represents the
77// attempt to wake by the other thread (that set the `WAKING` bit). The
78// state is then transitioned from `REGISTERING | WAKING` back to `WAITING`.
79// This transition must succeed because, at this point, the state cannot be
80// transitioned by another thread.
81//
82// # Waking
83//
84// On a call to `wake`, an attempt to transition the state from `WAITING` to
85// `WAKING` is made. On success, the caller obtains a lock on the waker cell.
86//
87// If the lock is obtained, then the thread takes ownership of the current value
88// in the waker cell, and calls `wake` on it. The state is then transitioned
89// back to `WAITING`. This transition must succeed as, at this point, the state
90// cannot be transitioned by another thread.
91//
92// If the thread is unable to obtain the lock, the `WAKING` bit is still set.
93// This is because it has either been set by the current thread but the previous
94// value included the `REGISTERING` bit **or** a concurrent thread is in the
95// `WAKING` critical section. Either way, no action must be taken.
96//
97// If the current thread is the only concurrent call to `wake` and another
98// thread is in the `register` critical section, when the other thread **exits**
99// the `register` critical section, it will observe the `WAKING` bit and
100// handle the waker itself.
101//
// If another thread is in the `wake` critical section, then it will handle
// waking the caller task.
104//
105// # A potential race (is safely handled).
106//
107// Imagine the following situation:
108//
109// * Thread A obtains the `wake` lock and wakes a task.
110//
111// * Before thread A releases the `wake` lock, the woken task is scheduled.
112//
113// * Thread B attempts to wake the task. In theory this should result in the
114//   task being woken, but it cannot because thread A still holds the wake
115//   lock.
116//
117// This case is handled by requiring users of `AtomicWaker` to call `register`
118// **before** attempting to observe the application state change that resulted
119// in the task being woken. The wakers also change the application state
120// before calling wake.
121//
122// Because of this, the task will do one of two things.
123//
124// 1) Observe the application state change that Thread B is waking on. In
125//    this case, it is OK for Thread B's wake to be lost.
126//
// 2) Call register before attempting to observe the application state. Since
//    Thread A still holds the `wake` lock, the call to `register` will result
//    in the task waking itself and getting scheduled again.
130
/// No bit set: nobody holds the waker cell.
const WAITING: usize = 0b00;

/// Bit set while a `register` call is writing a new waker into the cell.
const REGISTERING: usize = 0b01;

/// Bit set while a `wake` call is taking the registered waker out of the cell
/// (or has requested a wake from a concurrent `register`).
const WAKING: usize = 0b10;
139
140impl AtomicWaker {
141    /// Create an `AtomicWaker`
142    pub(crate) fn new() -> AtomicWaker {
143        AtomicWaker {
144            state: AtomicUsize::new(WAITING),
145            waker: UnsafeCell::new(None),
146        }
147    }
148
149    /*
150    /// Registers the current waker to be notified on calls to `wake`.
151    pub(crate) fn register(&self, waker: Waker) {
152        self.do_register(waker);
153    }
154    */
155
156    /// Registers the provided waker to be notified on calls to `wake`.
157    ///
158    /// The new waker will take place of any previous wakers that were registered
159    /// by previous calls to `register`. Any calls to `wake` that happen after
160    /// a call to `register` (as defined by the memory ordering rules), will
161    /// wake the `register` caller's task.
162    ///
163    /// It is safe to call `register` with multiple other threads concurrently
164    /// calling `wake`. This will result in the `register` caller's current
165    /// task being woken once.
166    ///
167    /// This function is safe to call concurrently, but this is generally a bad
168    /// idea. Concurrent calls to `register` will attempt to register different
169    /// tasks to be woken. One of the callers will win and have its task set,
170    /// but there is no guarantee as to which caller will succeed.
171    pub(crate) fn register_by_ref(&self, waker: &Waker) {
172        self.do_register(waker);
173    }
174
175    fn do_register<W>(&self, waker: W)
176    where
177        W: WakerRef,
178    {
179        fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> std::thread::Result<R> {
180            std::panic::catch_unwind(AssertUnwindSafe(f))
181        }
182
183        match self
184            .state
185            .compare_exchange(WAITING, REGISTERING, Acquire, Acquire)
186            .unwrap_or_else(|x| x)
187        {
188            WAITING => {
189                unsafe {
190                    // If `into_waker` panics (because it's code outside of
191                    // AtomicWaker) we need to prime a guard that is called on
192                    // unwind to restore the waker to a WAITING state. Otherwise
193                    // any future calls to register will incorrectly be stuck
194                    // believing it's being updated by someone else.
195                    let new_waker_or_panic = catch_unwind(move || waker.into_waker());
196
197                    // Set the field to contain the new waker, or if
198                    // `into_waker` panicked, leave the old value.
199                    let mut maybe_panic = None;
200                    let mut old_waker = None;
201                    match new_waker_or_panic {
202                        Ok(new_waker) => {
203                            old_waker = self.waker.with_mut(|t| (*t).take());
204                            self.waker.with_mut(|t| *t = Some(new_waker));
205                        }
206                        Err(panic) => maybe_panic = Some(panic),
207                    }
208
209                    // Release the lock. If the state transitioned to include
210                    // the `WAKING` bit, this means that a wake has been
211                    // called concurrently, so we have to remove the waker and
212                    // wake it.`
213                    //
214                    // Start by assuming that the state is `REGISTERING` as this
215                    // is what we jut set it to.
216                    let res = self
217                        .state
218                        .compare_exchange(REGISTERING, WAITING, AcqRel, Acquire);
219
220                    match res {
221                        Ok(_) => {
222                            // We don't want to give the caller the panic if it
223                            // was someone else who put in that waker.
224                            let _ = catch_unwind(move || {
225                                drop(old_waker);
226                            });
227                        }
228                        Err(actual) => {
229                            // This branch can only be reached if a
230                            // concurrent thread called `wake`. In this
231                            // case, `actual` **must** be `REGISTERING |
232                            // WAKING`.
233                            debug_assert_eq!(actual, REGISTERING | WAKING);
234
235                            // Take the waker to wake once the atomic operation has
236                            // completed.
237                            let mut waker = self.waker.with_mut(|t| (*t).take());
238
239                            // Just swap, because no one could change state
240                            // while state == `Registering | `Waking`
241                            self.state.swap(WAITING, AcqRel);
242
243                            // If `into_waker` panicked, then the waker in the
244                            // waker slot is actually the old waker.
245                            if maybe_panic.is_some() {
246                                old_waker = waker.take();
247                            }
248
249                            // We don't want to give the caller the panic if it
250                            // was someone else who put in that waker.
251                            if let Some(old_waker) = old_waker {
252                                let _ = catch_unwind(move || {
253                                    old_waker.wake();
254                                });
255                            }
256
257                            // The atomic swap was complete, now wake the waker
258                            // and return.
259                            //
260                            // If this panics, we end up in a consumed state and
261                            // return the panic to the caller.
262                            if let Some(waker) = waker {
263                                debug_assert!(maybe_panic.is_none());
264                                waker.wake();
265                            }
266                        }
267                    }
268
269                    if let Some(panic) = maybe_panic {
270                        // If `into_waker` panicked, return the panic to the caller.
271                        resume_unwind(panic);
272                    }
273                }
274            }
275            WAKING => {
276                // Currently in the process of waking the task, i.e.,
277                // `wake` is currently being called on the old waker.
278                // So, we call wake on the new waker.
279                //
280                // If this panics, someone else is responsible for restoring the
281                // state of the waker.
282                waker.wake();
283
284                // This is equivalent to a spin lock, so use a spin hint.
285                hint::spin_loop();
286            }
287            state => {
288                // In this case, a concurrent thread is holding the
289                // "registering" lock. This probably indicates a bug in the
290                // caller's code as racing to call `register` doesn't make much
291                // sense.
292                //
293                // We just want to maintain memory safety. It is ok to drop the
294                // call to `register`.
295                debug_assert!(state == REGISTERING || state == REGISTERING | WAKING);
296            }
297        }
298    }
299
300    /// Wakes the task that last called `register`.
301    ///
302    /// If `register` has not been called yet, then this does nothing.
303    pub(crate) fn wake(&self) {
304        if let Some(waker) = self.take_waker() {
305            // If wake panics, we've consumed the waker which is a legitimate
306            // outcome.
307            waker.wake();
308        }
309    }
310
311    /// Attempts to take the `Waker` value out of the `AtomicWaker` with the
312    /// intention that the caller will wake the task later.
313    pub(crate) fn take_waker(&self) -> Option<Waker> {
314        // AcqRel ordering is used in order to acquire the value of the `waker`
315        // cell as well as to establish a `release` ordering with whatever
316        // memory the `AtomicWaker` is associated with.
317        match self.state.fetch_or(WAKING, AcqRel) {
318            WAITING => {
319                // The waking lock has been acquired.
320                let waker = unsafe { self.waker.with_mut(|t| (*t).take()) };
321
322                // Release the lock
323                self.state.fetch_and(!WAKING, Release);
324
325                waker
326            }
327            state => {
328                // There is a concurrent thread currently updating the
329                // associated waker.
330                //
331                // Nothing more to do as the `WAKING` bit has been set. It
332                // doesn't matter if there are concurrent registering threads or
333                // not.
334                //
335                debug_assert!(
336                    state == REGISTERING || state == REGISTERING | WAKING || state == WAKING
337                );
338                None
339            }
340        }
341    }
342}
343
344impl Default for AtomicWaker {
345    fn default() -> Self {
346        AtomicWaker::new()
347    }
348}
349
350impl fmt::Debug for AtomicWaker {
351    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
352        write!(fmt, "AtomicWaker")
353    }
354}
355
356unsafe impl Send for AtomicWaker {}
357unsafe impl Sync for AtomicWaker {}
358
/// Abstraction over "something that can act as a waker", letting
/// `do_register` accept both an owned `Waker` and a `&Waker`.
trait WakerRef {
    /// Wakes the associated task, consuming `self`.
    fn wake(self);

    /// Produces an owned `Waker`, consuming `self`.
    fn into_waker(self) -> Waker;
}

/// Owned wakers are used as they are.
impl WakerRef for Waker {
    fn wake(self) {
        Waker::wake(self);
    }

    fn into_waker(self) -> Waker {
        self
    }
}

/// Borrowed wakers are woken by reference and cloned only when an owned
/// value is required.
impl WakerRef for &Waker {
    fn wake(self) {
        Waker::wake_by_ref(self);
    }

    fn into_waker(self) -> Waker {
        Waker::clone(self)
    }
}