tokio/runtime/scheduler/multi_thread/queue.rs

//! Run-queue structures to support a work-stealing scheduler

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::Arc;
use crate::runtime::scheduler::multi_thread::{Overflow, Stats};
use crate::runtime::task;

use std::mem::{self, MaybeUninit};
use std::ptr;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};

// Use wider integers when possible to increase ABA resilience.
//
// See issue #5041: <https://github.com/tokio-rs/tokio/issues/5041>.
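//
// Wider indices mean many more push/pop cycles are needed before a head value
// observed by a stealer can wrap back around to the same bit pattern, making
// it far less likely that a stale value slips through one of the
// `compare_exchange` operations on `head`.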
cfg_has_atomic_u64! {
    type UnsignedShort = u32;
    type UnsignedLong = u64;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU32;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU64;
}
cfg_not_has_atomic_u64! {
    type UnsignedShort = u16;
    type UnsignedLong = u32;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU16;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU32;
}

/// Producer handle. May only be used from a single thread.
pub(crate) struct Local<T: 'static> {
    inner: Arc<Inner<T>>,
}

/// Consumer handle. May be used from many threads.
pub(crate) struct Steal<T: 'static>(Arc<Inner<T>>);

pub(crate) struct Inner<T: 'static> {
    /// Concurrently updated by many threads.
    ///
    /// Contains two `UnsignedShort` values. The `UnsignedShort` in the `LSB`
    /// half is the "real" head of the queue. The `UnsignedShort` in the `MSB`
    /// half is set by a stealer in the process of stealing values. It
    /// represents the first value being stolen in the batch. The
    /// `UnsignedShort` indices are intentionally wider than strictly required
    /// for buffer indexing in order to provide ABA mitigation and make it
    /// possible to distinguish between full and empty buffers.
    ///
    /// When both `UnsignedShort` values are the same, there is no active
    /// stealer.
    ///
    /// Tracking an in-progress stealer prevents a wrapping scenario.
    head: AtomicUnsignedLong,

    /// Only updated by the producer thread but read by many threads.
    tail: AtomicUnsignedShort,

    /// Elements
    buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
}
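
// Bit layout of the packed `head` word (shown for the 64-bit case; the 32-bit
// fallback halves both widths):
//
//   bits 63..=32: "steal" index, the first slot of an in-progress steal
//   bits 31..=0:  "real" index, the actual head of the queue
//
// See `pack` and `unpack` at the bottom of this file.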

unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}

#[cfg(not(loom))]
const LOCAL_QUEUE_CAPACITY: usize = 256;

// Shrink the size of the local queue when using loom. This shouldn't impact
// logic, but allows loom to test more edge cases in a reasonable amount of
// time.
#[cfg(loom)]
const LOCAL_QUEUE_CAPACITY: usize = 4;

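// `LOCAL_QUEUE_CAPACITY` is a power of two, so masking a monotonically
// increasing position with `MASK` maps it to a slot index. For example, with a
// capacity of 256, positions 255, 256, and 257 map to slots 255, 0, and 1.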
const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;

// Constructing the fixed size array directly is very awkward. The only way to
// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as
// the contents are not Copy. The trick with defining a const doesn't work for
// generic types.
fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> {
    assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY);

    // safety: We check that the length is correct.
    unsafe { Box::from_raw(Box::into_raw(buffer).cast()) }
}

/// Creates a new local run-queue.
pub(crate) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
    let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY);

    for _ in 0..LOCAL_QUEUE_CAPACITY {
        buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
    }

    let inner = Arc::new(Inner {
        head: AtomicUnsignedLong::new(0),
        tail: AtomicUnsignedShort::new(0),
        buffer: make_fixed_size(buffer.into_boxed_slice()),
    });

    let local = Local {
        inner: inner.clone(),
    };

    let remote = Steal(inner);

    (remote, local)
}
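
// A rough usage sketch (the scheduler type `S` and the surrounding wiring are
// illustrative, not part of this module): the `Local` half stays on the worker
// thread that owns the queue, while clones of the `Steal` half are shared with
// the other workers.
//
//     let (steal, mut local) = local::<S>();
//     // The owning worker pushes and pops through `local`; a peer worker
//     // calls `steal.steal_into(&mut peer_local, &mut peer_stats)`.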

impl<T> Local<T> {
    /// Returns the number of entries in the queue.
    pub(crate) fn len(&self) -> usize {
        let (_, head) = unpack(self.inner.head.load(Acquire));
        // safety: this is the **only** thread that updates this cell.
        let tail = unsafe { self.inner.tail.unsync_load() };
        len(head, tail)
    }

    /// Returns how many tasks can still be pushed into the queue.
    pub(crate) fn remaining_slots(&self) -> usize {
        let (steal, _) = unpack(self.inner.head.load(Acquire));
        // safety: this is the **only** thread that updates this cell.
        let tail = unsafe { self.inner.tail.unsync_load() };

        LOCAL_QUEUE_CAPACITY - len(steal, tail)
    }

    pub(crate) fn max_capacity(&self) -> usize {
        LOCAL_QUEUE_CAPACITY
    }

    /// Returns `true` if there are entries in the queue.
    ///
    /// Kept separate from `is_stealable` so that refactors of `is_stealable`
    /// to "protect" some tasks from stealing won't affect this.
    pub(crate) fn has_tasks(&self) -> bool {
        self.len() != 0
    }

    /// Pushes a batch of tasks to the back of the queue. All tasks must fit
    /// in the local queue.
    ///
    /// # Panics
    ///
    /// The method panics if there is not enough capacity to fit the tasks in
    /// the queue.
    pub(crate) fn push_back(&mut self, tasks: impl ExactSizeIterator<Item = task::Notified<T>>) {
        let len = tasks.len();
        assert!(len <= LOCAL_QUEUE_CAPACITY);

        if len == 0 {
            // Nothing to do
            return;
        }

        let head = self.inner.head.load(Acquire);
        let (steal, _) = unpack(head);

        // safety: this is the **only** thread that updates this cell.
        let mut tail = unsafe { self.inner.tail.unsync_load() };

        if tail.wrapping_sub(steal) <= (LOCAL_QUEUE_CAPACITY - len) as UnsignedShort {
            // Yes, this if condition is structured a bit weird (first block
            // does nothing, second panics). It is this way to match
            // `push_back_or_overflow`.
        } else {
            panic!()
        }

        for task in tasks {
            let idx = tail as usize & MASK;

            self.inner.buffer[idx].with_mut(|ptr| {
                // Write the task to the slot
                //
                // Safety: There is only one producer and the above `if`
                // condition ensures we don't touch a cell if there is a
                // value, thus no consumer.
                unsafe {
                    ptr::write((*ptr).as_mut_ptr(), task);
                }
            });

            tail = tail.wrapping_add(1);
        }

        self.inner.tail.store(tail, Release);
    }

    /// Pushes a task to the back of the local queue. If there is not enough
    /// capacity in the queue, this triggers the overflow operation.
    ///
    /// When the queue overflows, half of the current contents of the queue is
    /// moved to the given injection queue. This frees up capacity for more
    /// tasks to be pushed into the local queue.
    pub(crate) fn push_back_or_overflow<O: Overflow<T>>(
        &mut self,
        mut task: task::Notified<T>,
        overflow: &O,
        stats: &mut Stats,
    ) {
        let tail = loop {
            let head = self.inner.head.load(Acquire);
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as UnsignedShort {
                // There is capacity for the task
                break tail;
            } else if steal != real {
                // Concurrently stealing, this will free up capacity, so only
                // push the task onto the inject queue
                overflow.push(task);
                return;
            } else {
                // Push the current task and half of the queue into the
                // inject queue.
                match self.push_overflow(task, real, tail, overflow, stats) {
                    Ok(_) => return,
                    // Lost the race, try again
                    Err(v) => {
                        task = v;
                    }
                }
            }
        };

        self.push_back_finish(task, tail);
    }

    // Second half of `push_back`
    fn push_back_finish(&self, task: task::Notified<T>, tail: UnsignedShort) {
        // Map the position to a slot index.
        let idx = tail as usize & MASK;

        self.inner.buffer[idx].with_mut(|ptr| {
            // Write the task to the slot
            //
            // Safety: There is only one producer and the capacity check
            // performed by the caller ensures we don't touch a cell holding a
            // value, thus no consumer.
            unsafe {
                ptr::write((*ptr).as_mut_ptr(), task);
            }
        });

        // Make the task available. Synchronizes with a load in
        // `steal_into2`.
        self.inner.tail.store(tail.wrapping_add(1), Release);
    }

    /// Moves a batch of tasks into the inject queue.
    ///
    /// This will temporarily make some of the tasks unavailable to stealers.
    /// Once `push_overflow` is done, a notification is sent out, so if other
    /// workers "missed" some of the tasks during a steal, they will get
    /// another opportunity.
    #[inline(never)]
    fn push_overflow<O: Overflow<T>>(
        &mut self,
        task: task::Notified<T>,
        head: UnsignedShort,
        tail: UnsignedShort,
        overflow: &O,
        stats: &mut Stats,
    ) -> Result<(), task::Notified<T>> {
        /// How many elements are we taking from the local queue.
        ///
        /// This is one less than the number of tasks pushed to the inject
        /// queue as we are also inserting the `task` argument.
        const NUM_TASKS_TAKEN: UnsignedShort = (LOCAL_QUEUE_CAPACITY / 2) as UnsignedShort;
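        // With the default capacity of 256 this takes 128 tasks from the
        // local queue which, together with the scheduled `task`, means 129
        // tasks are pushed to the inject queue.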

        assert_eq!(
            tail.wrapping_sub(head) as usize,
            LOCAL_QUEUE_CAPACITY,
            "queue is not full; tail = {tail}; head = {head}"
        );

        let prev = pack(head, head);

        // Claim a bunch of tasks
        //
        // We are claiming the tasks **before** reading them out of the buffer.
        // This is safe because only the **current** thread is able to push new
        // tasks.
        //
        // There isn't really any need for memory ordering... Relaxed would
        // work. This is because all tasks are pushed into the queue from the
        // current thread (or memory has been acquired if the local queue handle
        // moved).
        if self
            .inner
            .head
            .compare_exchange(
                prev,
                pack(
                    head.wrapping_add(NUM_TASKS_TAKEN),
                    head.wrapping_add(NUM_TASKS_TAKEN),
                ),
                Release,
                Relaxed,
            )
            .is_err()
        {
            // We failed to claim the tasks, losing the race. Return out of
            // this function and try the full `push` routine again. The queue
            // may not be full anymore.
            return Err(task);
        }

        /// An iterator that takes elements out of the run queue.
        struct BatchTaskIter<'a, T: 'static> {
            buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
            head: UnsignedLong,
            i: UnsignedLong,
        }
        impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
            type Item = task::Notified<T>;

            #[inline]
            fn next(&mut self) -> Option<task::Notified<T>> {
                if self.i == UnsignedLong::from(NUM_TASKS_TAKEN) {
                    None
                } else {
                    let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
                    let slot = &self.buffer[i_idx];

                    // safety: The CAS above grants us exclusive ownership of
                    // the task pointers in this range.
                    let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

                    self.i += 1;
                    Some(task)
                }
            }
        }

        // safety: The CAS above ensures that no consumer will look at these
        // values again, and we are the only producer.
        let batch_iter = BatchTaskIter {
            buffer: &self.inner.buffer,
            head: head as UnsignedLong,
            i: 0,
        };
        overflow.push_batch(batch_iter.chain(std::iter::once(task)));

        // Record that the queue overflowed.
        stats.incr_overflow_count();

        Ok(())
    }

    /// Pops a task from the local queue.
    pub(crate) fn pop(&mut self) -> Option<task::Notified<T>> {
        let mut head = self.inner.head.load(Acquire);

        let idx = loop {
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if real == tail {
                // queue is empty
                return None;
            }

            let next_real = real.wrapping_add(1);

            // If `steal == real` there are no concurrent stealers. Both `steal`
            // and `real` are updated.
            let next = if steal == real {
                pack(next_real, next_real)
            } else {
                assert_ne!(steal, next_real);
                pack(steal, next_real)
            };

            // Attempt to claim a task.
            let res = self
                .inner
                .head
                .compare_exchange(head, next, AcqRel, Acquire);

            match res {
                Ok(_) => break real as usize & MASK,
                Err(actual) => head = actual,
            }
        };

        Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
    }
}

impl<T> Steal<T> {
    /// Returns the number of entries in the queue.
    pub(crate) fn len(&self) -> usize {
        let (_, head) = unpack(self.0.head.load(Acquire));
        let tail = self.0.tail.load(Acquire);
        len(head, tail)
    }

    /// Returns `true` if the queue is empty.
    pub(crate) fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Steals half the tasks from `self` and places them into `dst`.
    pub(crate) fn steal_into(
        &self,
        dst: &mut Local<T>,
        dst_stats: &mut Stats,
    ) -> Option<task::Notified<T>> {
        // Safety: the caller is the only thread that mutates `dst.tail` and
        // holds a mutable reference.
        let dst_tail = unsafe { dst.inner.tail.unsync_load() };

        // To the caller, `dst` may **look** empty but still have values
        // contained in the buffer. If another thread is concurrently stealing
        // from `dst` there may not be enough capacity to steal.
        let (steal, _) = unpack(dst.inner.head.load(Acquire));

        if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as UnsignedShort / 2 {
            // we *could* try to steal less here, but for simplicity, we're just
            // going to abort.
            return None;
        }
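
        // With the default capacity of 256, the steal is aborted whenever
        // `dst` already holds more than 128 entries, since `steal_into2` may
        // move up to 128 tasks into `dst` below.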

        // Steal the tasks into `dst`'s buffer. This does not yet expose the
        // tasks in `dst`.
        let mut n = self.steal_into2(dst, dst_tail);

        if n == 0 {
            // No tasks were stolen
            return None;
        }

        dst_stats.incr_steal_count(n as u16);
        dst_stats.incr_steal_operations();

        // We are returning a task here
        n -= 1;

        let ret_pos = dst_tail.wrapping_add(n);
        let ret_idx = ret_pos as usize & MASK;

        // safety: the value was written as part of `steal_into2` and not
        // exposed to stealers, so no other thread can access it.
        let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        if n == 0 {
            // The `dst` queue is empty, but a single task was stolen
            return Some(ret);
        }

        // Make the stolen items available to consumers
        dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

        Some(ret)
    }

    // Steal tasks from `self`, placing them into `dst`. Returns the number of
    // tasks that were stolen.
    fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
        let mut prev_packed = self.0.head.load(Acquire);
        let mut next_packed;

        let n = loop {
            let (src_head_steal, src_head_real) = unpack(prev_packed);
            let src_tail = self.0.tail.load(Acquire);

            // If these two do not match, another thread is concurrently
            // stealing from the queue.
            if src_head_steal != src_head_real {
                return 0;
            }

            // Number of available tasks to steal
            let n = src_tail.wrapping_sub(src_head_real);
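            // Steal half of what is available, rounding up: e.g. five
            // available tasks yields `n = 5 - 5 / 2 = 3`.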
            let n = n - n / 2;

            if n == 0 {
                // No tasks available to steal
                return 0;
            }

            // Update the real head index to acquire the tasks.
            let steal_to = src_head_real.wrapping_add(n);
            assert_ne!(src_head_steal, steal_to);
            next_packed = pack(src_head_steal, steal_to);

            // Claim all those tasks. This is done by incrementing the "real"
            // head but not the steal. By doing this, no other thread is able to
            // steal from this queue until the current thread completes.
            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => break n,
                Err(actual) => prev_packed = actual,
            }
        };

        assert!(
            n <= LOCAL_QUEUE_CAPACITY as UnsignedShort / 2,
            "actual = {n}"
        );

        let (first, _) = unpack(next_packed);

        // Take all the tasks
        for i in 0..n {
            // Compute the positions
            let src_pos = first.wrapping_add(i);
            let dst_pos = dst_tail.wrapping_add(i);

            // Map to slots
            let src_idx = src_pos as usize & MASK;
            let dst_idx = dst_pos as usize & MASK;

            // Read the task
            //
            // safety: We acquired the task with the atomic exchange above.
            let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

            // Write the task to the new slot
            //
            // safety: the capacity check in `steal_into` ensures these slots
            // are unoccupied, and we are the only producer for `dst`.
            dst.inner.buffer[dst_idx]
                .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
        }

        let mut prev_packed = next_packed;

        // Update `src_head_steal` to match `src_head_real`, signalling that
        // the stealing routine is complete.
        loop {
            let head = unpack(prev_packed).1;
            next_packed = pack(head, head);

            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => return n,
                Err(actual) => {
                    let (actual_steal, actual_real) = unpack(actual);

                    assert_ne!(actual_steal, actual_real);

                    prev_packed = actual;
                }
            }
        }
    }
}

impl<T> Clone for Steal<T> {
    fn clone(&self) -> Steal<T> {
        Steal(self.0.clone())
    }
}

impl<T> Drop for Local<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            assert!(self.pop().is_none(), "queue not empty");
        }
    }
}

/// Calculates the length of the queue using the head and tail. The `head` can
/// be either the `steal` or the `real` head.
fn len(head: UnsignedShort, tail: UnsignedShort) -> usize {
    tail.wrapping_sub(head) as usize
}

/// Splits the head value into the real head and the index a stealer is
/// working on.
fn unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort) {
    let real = n & UnsignedShort::MAX as UnsignedLong;
    let steal = n >> (mem::size_of::<UnsignedShort>() * 8);

    (steal as UnsignedShort, real as UnsignedShort)
}

/// Joins the two head values.
fn pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong {
    (real as UnsignedLong) | ((steal as UnsignedLong) << (mem::size_of::<UnsignedShort>() * 8))
}
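
// For example, with 64-bit atomics (`UnsignedShort = u32`):
//
//   pack(3, 5)         == 0x0000_0003_0000_0005
//   unpack(pack(3, 5)) == (3, 5)
//
// and `len` stays correct across index wraparound, e.g.
// `len(UnsignedShort::MAX - 1, 2) == 4`.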

#[test]
fn test_local_queue_capacity() {
    assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
}