tokio/runtime/scheduler/multi_thread/queue.rs
1//! Run-queue structures to support a work-stealing scheduler
2
3use crate::loom::cell::UnsafeCell;
4use crate::loom::sync::Arc;
5use crate::runtime::scheduler::multi_thread::{Overflow, Stats};
6use crate::runtime::task;
7
8use std::mem::{self, MaybeUninit};
9use std::ptr;
10use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
11
12// Use wider integers when possible to increase ABA resilience.
13//
14// See issue #5041: <https://github.com/tokio-rs/tokio/issues/5041>.
15cfg_has_atomic_u64! {
16 type UnsignedShort = u32;
17 type UnsignedLong = u64;
18 type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU32;
19 type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU64;
20}
21cfg_not_has_atomic_u64! {
22 type UnsignedShort = u16;
23 type UnsignedLong = u32;
24 type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU16;
25 type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU32;
26}
27
28/// Producer handle. May only be used from a single thread.
29pub(crate) struct Local<T: 'static> {
30 inner: Arc<Inner<T>>,
31}
32
33/// Consumer handle. May be used from many threads.
34pub(crate) struct Steal<T: 'static>(Arc<Inner<T>>);
35
36pub(crate) struct Inner<T: 'static> {
37 /// Concurrently updated by many threads.
38 ///
39 /// Contains two `UnsignedShort` values. The `LSB` byte is the "real" head of
40 /// the queue. The `UnsignedShort` in the `MSB` is set by a stealer in process
41 /// of stealing values. It represents the first value being stolen in the
42 /// batch. The `UnsignedShort` indices are intentionally wider than strictly
43 /// required for buffer indexing in order to provide ABA mitigation and make
44 /// it possible to distinguish between full and empty buffers.
45 ///
46 /// When both `UnsignedShort` values are the same, there is no active
47 /// stealer.
48 ///
49 /// Tracking an in-progress stealer prevents a wrapping scenario.
50 head: AtomicUnsignedLong,
51
52 /// Only updated by producer thread but read by many threads.
53 tail: AtomicUnsignedShort,
54
55 /// Elements
56 buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
57}
58
59unsafe impl<T> Send for Inner<T> {}
60unsafe impl<T> Sync for Inner<T> {}
61
62#[cfg(not(loom))]
63const LOCAL_QUEUE_CAPACITY: usize = 256;
64
65// Shrink the size of the local queue when using loom. This shouldn't impact
66// logic, but allows loom to test more edge cases in a reasonable a mount of
67// time.
68#[cfg(loom)]
69const LOCAL_QUEUE_CAPACITY: usize = 4;
70
71const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;
72
73// Constructing the fixed size array directly is very awkward. The only way to
74// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as
75// the contents are not Copy. The trick with defining a const doesn't work for
76// generic types.
77fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> {
78 assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY);
79
80 // safety: We check that the length is correct.
81 unsafe { Box::from_raw(Box::into_raw(buffer).cast()) }
82}
83
84/// Create a new local run-queue
85pub(crate) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
86 let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY);
87
88 for _ in 0..LOCAL_QUEUE_CAPACITY {
89 buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
90 }
91
92 let inner = Arc::new(Inner {
93 head: AtomicUnsignedLong::new(0),
94 tail: AtomicUnsignedShort::new(0),
95 buffer: make_fixed_size(buffer.into_boxed_slice()),
96 });
97
98 let local = Local {
99 inner: inner.clone(),
100 };
101
102 let remote = Steal(inner);
103
104 (remote, local)
105}
106
107impl<T> Local<T> {
108 /// Returns the number of entries in the queue
109 pub(crate) fn len(&self) -> usize {
110 let (_, head) = unpack(self.inner.head.load(Acquire));
111 // safety: this is the **only** thread that updates this cell.
112 let tail = unsafe { self.inner.tail.unsync_load() };
113 len(head, tail)
114 }
115
116 /// How many tasks can be pushed into the queue
117 pub(crate) fn remaining_slots(&self) -> usize {
118 let (steal, _) = unpack(self.inner.head.load(Acquire));
119 // safety: this is the **only** thread that updates this cell.
120 let tail = unsafe { self.inner.tail.unsync_load() };
121
122 LOCAL_QUEUE_CAPACITY - len(steal, tail)
123 }
124
125 pub(crate) fn max_capacity(&self) -> usize {
126 LOCAL_QUEUE_CAPACITY
127 }
128
129 /// Returns false if there are any entries in the queue
130 ///
131 /// Separate to `is_stealable` so that refactors of `is_stealable` to "protect"
132 /// some tasks from stealing won't affect this
133 pub(crate) fn has_tasks(&self) -> bool {
134 self.len() != 0
135 }
136
137 /// Pushes a batch of tasks to the back of the queue. All tasks must fit in
138 /// the local queue.
139 ///
140 /// # Panics
141 ///
142 /// The method panics if there is not enough capacity to fit in the queue.
143 pub(crate) fn push_back(&mut self, tasks: impl ExactSizeIterator<Item = task::Notified<T>>) {
144 let len = tasks.len();
145 assert!(len <= LOCAL_QUEUE_CAPACITY);
146
147 if len == 0 {
148 // Nothing to do
149 return;
150 }
151
152 let head = self.inner.head.load(Acquire);
153 let (steal, _) = unpack(head);
154
155 // safety: this is the **only** thread that updates this cell.
156 let mut tail = unsafe { self.inner.tail.unsync_load() };
157
158 if tail.wrapping_sub(steal) <= (LOCAL_QUEUE_CAPACITY - len) as UnsignedShort {
159 // Yes, this if condition is structured a bit weird (first block
160 // does nothing, second returns an error). It is this way to match
161 // `push_back_or_overflow`.
162 } else {
163 panic!()
164 }
165
166 for task in tasks {
167 let idx = tail as usize & MASK;
168
169 self.inner.buffer[idx].with_mut(|ptr| {
170 // Write the task to the slot
171 //
172 // Safety: There is only one producer and the above `if`
173 // condition ensures we don't touch a cell if there is a
174 // value, thus no consumer.
175 unsafe {
176 ptr::write((*ptr).as_mut_ptr(), task);
177 }
178 });
179
180 tail = tail.wrapping_add(1);
181 }
182
183 self.inner.tail.store(tail, Release);
184 }
185
186 /// Pushes a task to the back of the local queue, if there is not enough
187 /// capacity in the queue, this triggers the overflow operation.
188 ///
189 /// When the queue overflows, half of the current contents of the queue is
190 /// moved to the given Injection queue. This frees up capacity for more
191 /// tasks to be pushed into the local queue.
192 pub(crate) fn push_back_or_overflow<O: Overflow<T>>(
193 &mut self,
194 mut task: task::Notified<T>,
195 overflow: &O,
196 stats: &mut Stats,
197 ) {
198 let tail = loop {
199 let head = self.inner.head.load(Acquire);
200 let (steal, real) = unpack(head);
201
202 // safety: this is the **only** thread that updates this cell.
203 let tail = unsafe { self.inner.tail.unsync_load() };
204
205 if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as UnsignedShort {
206 // There is capacity for the task
207 break tail;
208 } else if steal != real {
209 // Concurrently stealing, this will free up capacity, so only
210 // push the task onto the inject queue
211 overflow.push(task);
212 return;
213 } else {
214 // Push the current task and half of the queue into the
215 // inject queue.
216 match self.push_overflow(task, real, tail, overflow, stats) {
217 Ok(_) => return,
218 // Lost the race, try again
219 Err(v) => {
220 task = v;
221 }
222 }
223 }
224 };
225
226 self.push_back_finish(task, tail);
227 }
228
229 // Second half of `push_back`
230 fn push_back_finish(&self, task: task::Notified<T>, tail: UnsignedShort) {
231 // Map the position to a slot index.
232 let idx = tail as usize & MASK;
233
234 self.inner.buffer[idx].with_mut(|ptr| {
235 // Write the task to the slot
236 //
237 // Safety: There is only one producer and the above `if`
238 // condition ensures we don't touch a cell if there is a
239 // value, thus no consumer.
240 unsafe {
241 ptr::write((*ptr).as_mut_ptr(), task);
242 }
243 });
244
245 // Make the task available. Synchronizes with a load in
246 // `steal_into2`.
247 self.inner.tail.store(tail.wrapping_add(1), Release);
248 }
249
250 /// Moves a batch of tasks into the inject queue.
251 ///
252 /// This will temporarily make some of the tasks unavailable to stealers.
253 /// Once `push_overflow` is done, a notification is sent out, so if other
254 /// workers "missed" some of the tasks during a steal, they will get
255 /// another opportunity.
256 #[inline(never)]
257 fn push_overflow<O: Overflow<T>>(
258 &mut self,
259 task: task::Notified<T>,
260 head: UnsignedShort,
261 tail: UnsignedShort,
262 overflow: &O,
263 stats: &mut Stats,
264 ) -> Result<(), task::Notified<T>> {
265 /// How many elements are we taking from the local queue.
266 ///
267 /// This is one less than the number of tasks pushed to the inject
268 /// queue as we are also inserting the `task` argument.
269 const NUM_TASKS_TAKEN: UnsignedShort = (LOCAL_QUEUE_CAPACITY / 2) as UnsignedShort;
270
271 assert_eq!(
272 tail.wrapping_sub(head) as usize,
273 LOCAL_QUEUE_CAPACITY,
274 "queue is not full; tail = {tail}; head = {head}"
275 );
276
277 let prev = pack(head, head);
278
279 // Claim a bunch of tasks
280 //
281 // We are claiming the tasks **before** reading them out of the buffer.
282 // This is safe because only the **current** thread is able to push new
283 // tasks.
284 //
285 // There isn't really any need for memory ordering... Relaxed would
286 // work. This is because all tasks are pushed into the queue from the
287 // current thread (or memory has been acquired if the local queue handle
288 // moved).
289 if self
290 .inner
291 .head
292 .compare_exchange(
293 prev,
294 pack(
295 head.wrapping_add(NUM_TASKS_TAKEN),
296 head.wrapping_add(NUM_TASKS_TAKEN),
297 ),
298 Release,
299 Relaxed,
300 )
301 .is_err()
302 {
303 // We failed to claim the tasks, losing the race. Return out of
304 // this function and try the full `push` routine again. The queue
305 // may not be full anymore.
306 return Err(task);
307 }
308
309 /// An iterator that takes elements out of the run queue.
310 struct BatchTaskIter<'a, T: 'static> {
311 buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
312 head: UnsignedLong,
313 i: UnsignedLong,
314 }
315 impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
316 type Item = task::Notified<T>;
317
318 #[inline]
319 fn next(&mut self) -> Option<task::Notified<T>> {
320 if self.i == UnsignedLong::from(NUM_TASKS_TAKEN) {
321 None
322 } else {
323 let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
324 let slot = &self.buffer[i_idx];
325
326 // safety: Our CAS from before has assumed exclusive ownership
327 // of the task pointers in this range.
328 let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });
329
330 self.i += 1;
331 Some(task)
332 }
333 }
334 }
335
336 // safety: The CAS above ensures that no consumer will look at these
337 // values again, and we are the only producer.
338 let batch_iter = BatchTaskIter {
339 buffer: &self.inner.buffer,
340 head: head as UnsignedLong,
341 i: 0,
342 };
343 overflow.push_batch(batch_iter.chain(std::iter::once(task)));
344
345 // Add 1 to factor in the task currently being scheduled.
346 stats.incr_overflow_count();
347
348 Ok(())
349 }
350
351 /// Pops a task from the local queue.
352 pub(crate) fn pop(&mut self) -> Option<task::Notified<T>> {
353 let mut head = self.inner.head.load(Acquire);
354
355 let idx = loop {
356 let (steal, real) = unpack(head);
357
358 // safety: this is the **only** thread that updates this cell.
359 let tail = unsafe { self.inner.tail.unsync_load() };
360
361 if real == tail {
362 // queue is empty
363 return None;
364 }
365
366 let next_real = real.wrapping_add(1);
367
368 // If `steal == real` there are no concurrent stealers. Both `steal`
369 // and `real` are updated.
370 let next = if steal == real {
371 pack(next_real, next_real)
372 } else {
373 assert_ne!(steal, next_real);
374 pack(steal, next_real)
375 };
376
377 // Attempt to claim a task.
378 let res = self
379 .inner
380 .head
381 .compare_exchange(head, next, AcqRel, Acquire);
382
383 match res {
384 Ok(_) => break real as usize & MASK,
385 Err(actual) => head = actual,
386 }
387 };
388
389 Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
390 }
391}
392
393impl<T> Steal<T> {
394 /// Returns the number of entries in the queue
395 pub(crate) fn len(&self) -> usize {
396 let (_, head) = unpack(self.0.head.load(Acquire));
397 let tail = self.0.tail.load(Acquire);
398 len(head, tail)
399 }
400
401 /// Return true if the queue is empty,
402 /// false if there are any entries in the queue
403 pub(crate) fn is_empty(&self) -> bool {
404 self.len() == 0
405 }
406
407 /// Steals half the tasks from self and place them into `dst`.
408 pub(crate) fn steal_into(
409 &self,
410 dst: &mut Local<T>,
411 dst_stats: &mut Stats,
412 ) -> Option<task::Notified<T>> {
413 // Safety: the caller is the only thread that mutates `dst.tail` and
414 // holds a mutable reference.
415 let dst_tail = unsafe { dst.inner.tail.unsync_load() };
416
417 // To the caller, `dst` may **look** empty but still have values
418 // contained in the buffer. If another thread is concurrently stealing
419 // from `dst` there may not be enough capacity to steal.
420 let (steal, _) = unpack(dst.inner.head.load(Acquire));
421
422 if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as UnsignedShort / 2 {
423 // we *could* try to steal less here, but for simplicity, we're just
424 // going to abort.
425 return None;
426 }
427
428 // Steal the tasks into `dst`'s buffer. This does not yet expose the
429 // tasks in `dst`.
430 let mut n = self.steal_into2(dst, dst_tail);
431
432 if n == 0 {
433 // No tasks were stolen
434 return None;
435 }
436
437 dst_stats.incr_steal_count(n as u16);
438 dst_stats.incr_steal_operations();
439
440 // We are returning a task here
441 n -= 1;
442
443 let ret_pos = dst_tail.wrapping_add(n);
444 let ret_idx = ret_pos as usize & MASK;
445
446 // safety: the value was written as part of `steal_into2` and not
447 // exposed to stealers, so no other thread can access it.
448 let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });
449
450 if n == 0 {
451 // The `dst` queue is empty, but a single task was stolen
452 return Some(ret);
453 }
454
455 // Make the stolen items available to consumers
456 dst.inner.tail.store(dst_tail.wrapping_add(n), Release);
457
458 Some(ret)
459 }
460
461 // Steal tasks from `self`, placing them into `dst`. Returns the number of
462 // tasks that were stolen.
463 fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
464 let mut prev_packed = self.0.head.load(Acquire);
465 let mut next_packed;
466
467 let n = loop {
468 let (src_head_steal, src_head_real) = unpack(prev_packed);
469 let src_tail = self.0.tail.load(Acquire);
470
471 // If these two do not match, another thread is concurrently
472 // stealing from the queue.
473 if src_head_steal != src_head_real {
474 return 0;
475 }
476
477 // Number of available tasks to steal
478 let n = src_tail.wrapping_sub(src_head_real);
479 let n = n - n / 2;
480
481 if n == 0 {
482 // No tasks available to steal
483 return 0;
484 }
485
486 // Update the real head index to acquire the tasks.
487 let steal_to = src_head_real.wrapping_add(n);
488 assert_ne!(src_head_steal, steal_to);
489 next_packed = pack(src_head_steal, steal_to);
490
491 // Claim all those tasks. This is done by incrementing the "real"
492 // head but not the steal. By doing this, no other thread is able to
493 // steal from this queue until the current thread completes.
494 let res = self
495 .0
496 .head
497 .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);
498
499 match res {
500 Ok(_) => break n,
501 Err(actual) => prev_packed = actual,
502 }
503 };
504
505 assert!(
506 n <= LOCAL_QUEUE_CAPACITY as UnsignedShort / 2,
507 "actual = {n}"
508 );
509
510 let (first, _) = unpack(next_packed);
511
512 // Take all the tasks
513 for i in 0..n {
514 // Compute the positions
515 let src_pos = first.wrapping_add(i);
516 let dst_pos = dst_tail.wrapping_add(i);
517
518 // Map to slots
519 let src_idx = src_pos as usize & MASK;
520 let dst_idx = dst_pos as usize & MASK;
521
522 // Read the task
523 //
524 // safety: We acquired the task with the atomic exchange above.
525 let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });
526
527 // Write the task to the new slot
528 //
529 // safety: `dst` queue is empty and we are the only producer to
530 // this queue.
531 dst.inner.buffer[dst_idx]
532 .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
533 }
534
535 let mut prev_packed = next_packed;
536
537 // Update `src_head_steal` to match `src_head_real` signalling that the
538 // stealing routine is complete.
539 loop {
540 let head = unpack(prev_packed).1;
541 next_packed = pack(head, head);
542
543 let res = self
544 .0
545 .head
546 .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);
547
548 match res {
549 Ok(_) => return n,
550 Err(actual) => {
551 let (actual_steal, actual_real) = unpack(actual);
552
553 assert_ne!(actual_steal, actual_real);
554
555 prev_packed = actual;
556 }
557 }
558 }
559 }
560}
561
562impl<T> Clone for Steal<T> {
563 fn clone(&self) -> Steal<T> {
564 Steal(self.0.clone())
565 }
566}
567
568impl<T> Drop for Local<T> {
569 fn drop(&mut self) {
570 if !std::thread::panicking() {
571 assert!(self.pop().is_none(), "queue not empty");
572 }
573 }
574}
575
576/// Calculate the length of the queue using the head and tail.
577/// The `head` can be the `steal` or `real` head.
578fn len(head: UnsignedShort, tail: UnsignedShort) -> usize {
579 tail.wrapping_sub(head) as usize
580}
581
582/// Split the head value into the real head and the index a stealer is working
583/// on.
584fn unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort) {
585 let real = n & UnsignedShort::MAX as UnsignedLong;
586 let steal = n >> (mem::size_of::<UnsignedShort>() * 8);
587
588 (steal as UnsignedShort, real as UnsignedShort)
589}
590
591/// Join the two head values
592fn pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong {
593 (real as UnsignedLong) | ((steal as UnsignedLong) << (mem::size_of::<UnsignedShort>() * 8))
594}
595
596#[test]
597fn test_local_queue_capacity() {
598 assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
599}