// tokio/runtime/blocking/pool.rs

//! Thread pool for blocking operations.

use crate::loom::sync::{Arc, Condvar, Mutex};
use crate::loom::thread;
use crate::runtime::blocking::schedule::BlockingSchedule;
use crate::runtime::blocking::{shutdown, BlockingTask};
use crate::runtime::builder::ThreadNameFn;
use crate::runtime::task::{self, JoinHandle};
use crate::runtime::{Builder, Callback, Handle, BOX_FUTURE_THRESHOLD};
use crate::util::metric_atomics::MetricAtomicUsize;
use crate::util::trace::{blocking_task, SpawnMeta};

use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::io;
use std::sync::atomic::Ordering;
use std::time::Duration;

19pub(crate) struct BlockingPool {
20    spawner: Spawner,
21    shutdown_rx: shutdown::Receiver,
22}
23
24#[derive(Clone)]
25pub(crate) struct Spawner {
26    inner: Arc<Inner>,
27}
28
29#[derive(Default)]
30pub(crate) struct SpawnerMetrics {
31    num_threads: MetricAtomicUsize,
32    num_idle_threads: MetricAtomicUsize,
33    queue_depth: MetricAtomicUsize,
34}
35
36impl SpawnerMetrics {
37    fn num_threads(&self) -> usize {
38        self.num_threads.load(Ordering::Relaxed)
39    }
40
41    fn num_idle_threads(&self) -> usize {
42        self.num_idle_threads.load(Ordering::Relaxed)
43    }
44
45    cfg_unstable_metrics! {
46        fn queue_depth(&self) -> usize {
47            self.queue_depth.load(Ordering::Relaxed)
48        }
49    }
50
51    fn inc_num_threads(&self) {
52        self.num_threads.increment();
53    }
54
55    fn dec_num_threads(&self) {
56        self.num_threads.decrement();
57    }
58
59    fn inc_num_idle_threads(&self) {
60        self.num_idle_threads.increment();
61    }
62
63    fn dec_num_idle_threads(&self) -> usize {
64        self.num_idle_threads.decrement()
65    }
66
67    fn inc_queue_depth(&self) {
68        self.queue_depth.increment();
69    }
70
71    fn dec_queue_depth(&self) {
72        self.queue_depth.decrement();
73    }
74}
75
76struct Inner {
77    /// State shared between worker threads.
78    shared: Mutex<Shared>,
79
80    /// Pool threads wait on this.
81    condvar: Condvar,
82
83    /// Spawned threads use this name.
84    thread_name: ThreadNameFn,
85
86    /// Spawned thread stack size.
87    stack_size: Option<usize>,
88
89    /// Call after a thread starts.
90    after_start: Option<Callback>,
91
92    /// Call before a thread stops.
93    before_stop: Option<Callback>,
94
95    // Maximum number of threads.
96    thread_cap: usize,
97
98    // Customizable wait timeout.
99    keep_alive: Duration,
100
101    // Metrics about the pool.
102    metrics: SpawnerMetrics,
103}
104
105struct Shared {
106    queue: VecDeque<Task>,
107    num_notify: u32,
108    shutdown: bool,
109    shutdown_tx: Option<shutdown::Sender>,
110    /// Prior to shutdown, we clean up `JoinHandles` by having each timed-out
111    /// thread join on the previous timed-out thread. This is not strictly
112    /// necessary but helps avoid Valgrind false positives, see
113    /// <https://github.com/tokio-rs/tokio/commit/646fbae76535e397ef79dbcaacb945d4c829f666>
114    /// for more information.
115    last_exiting_thread: Option<thread::JoinHandle<()>>,
116    /// This holds the `JoinHandles` for all running threads; on shutdown, the thread
117    /// calling shutdown handles joining on these.
118    worker_threads: HashMap<usize, thread::JoinHandle<()>>,
119    /// This is a counter used to iterate `worker_threads` in a consistent order (for loom's
120    /// benefit).
121    worker_thread_index: usize,
122}
123
124pub(crate) struct Task {
125    task: task::UnownedTask<BlockingSchedule>,
126    mandatory: Mandatory,
127}
128
/// Controls what happens to a queued task when the pool shuts down.
#[derive(PartialEq, Eq)]
pub(crate) enum Mandatory {
    /// The task is still run if it is drained from the queue during shutdown.
    #[cfg_attr(not(fs), allow(dead_code))]
    Mandatory,
    /// The task is shut down instead of run if it is drained from the queue
    /// during shutdown.
    NonMandatory,
}

/// Reasons a task could not be handed to the blocking pool.
pub(crate) enum SpawnError {
    /// Pool is shutting down and the task was not scheduled.
    ShuttingDown,
    /// There are no worker threads available to take the task
    /// and the OS failed to spawn a new one.
    NoThreads(io::Error),
}

impl From<SpawnError> for io::Error {
    /// `ShuttingDown` becomes an `ErrorKind::Other` error. `NoThreads`
    /// forwards the underlying OS error unchanged.
    fn from(e: SpawnError) -> Self {
        match e {
            SpawnError::ShuttingDown => {
                io::Error::new(io::ErrorKind::Other, "blocking pool shutting down")
            }
            SpawnError::NoThreads(e) => e,
        }
    }
}

155impl Task {
156    pub(crate) fn new(task: task::UnownedTask<BlockingSchedule>, mandatory: Mandatory) -> Task {
157        Task { task, mandatory }
158    }
159
160    fn run(self) {
161        self.task.run();
162    }
163
164    fn shutdown_or_run_if_mandatory(self) {
165        match self.mandatory {
166            Mandatory::NonMandatory => self.task.shutdown(),
167            Mandatory::Mandatory => self.task.run(),
168        }
169    }
170}
171
/// Default time an idle worker waits for new work before exiting. Used when
/// the runtime builder does not set `keep_alive`.
const KEEP_ALIVE: Duration = Duration::from_secs(10);

174/// Runs the provided function on an executor dedicated to blocking operations.
175/// Tasks will be scheduled as non-mandatory, meaning they may not get executed
176/// in case of runtime shutdown.
177#[track_caller]
178#[cfg_attr(target_os = "wasi", allow(dead_code))]
179pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
180where
181    F: FnOnce() -> R + Send + 'static,
182    R: Send + 'static,
183{
184    let rt = Handle::current();
185    rt.spawn_blocking(func)
186}
187
188cfg_fs! {
189    #[cfg_attr(any(
190        all(loom, not(test)), // the function is covered by loom tests
191        test
192    ), allow(dead_code))]
193    /// Runs the provided function on an executor dedicated to blocking
194    /// operations. Tasks will be scheduled as mandatory, meaning they are
195    /// guaranteed to run unless a shutdown is already taking place. In case a
196    /// shutdown is already taking place, `None` will be returned.
197    pub(crate) fn spawn_mandatory_blocking<F, R>(func: F) -> Option<JoinHandle<R>>
198    where
199        F: FnOnce() -> R + Send + 'static,
200        R: Send + 'static,
201    {
202        let rt = Handle::current();
203        rt.inner.blocking_spawner().spawn_mandatory_blocking(&rt, func)
204    }
205}
206
207// ===== impl BlockingPool =====
208
209impl BlockingPool {
210    pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool {
211        let (shutdown_tx, shutdown_rx) = shutdown::channel();
212        let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE);
213
214        BlockingPool {
215            spawner: Spawner {
216                inner: Arc::new(Inner {
217                    shared: Mutex::new(Shared {
218                        queue: VecDeque::new(),
219                        num_notify: 0,
220                        shutdown: false,
221                        shutdown_tx: Some(shutdown_tx),
222                        last_exiting_thread: None,
223                        worker_threads: HashMap::new(),
224                        worker_thread_index: 0,
225                    }),
226                    condvar: Condvar::new(),
227                    thread_name: builder.thread_name.clone(),
228                    stack_size: builder.thread_stack_size,
229                    after_start: builder.after_start.clone(),
230                    before_stop: builder.before_stop.clone(),
231                    thread_cap,
232                    keep_alive,
233                    metrics: SpawnerMetrics::default(),
234                }),
235            },
236            shutdown_rx,
237        }
238    }
239
240    pub(crate) fn spawner(&self) -> &Spawner {
241        &self.spawner
242    }
243
244    pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) {
245        let mut shared = self.spawner.inner.shared.lock();
246
247        // The function can be called multiple times. First, by explicitly
248        // calling `shutdown` then by the drop handler calling `shutdown`. This
249        // prevents shutting down twice.
250        if shared.shutdown {
251            return;
252        }
253
254        shared.shutdown = true;
255        shared.shutdown_tx = None;
256        self.spawner.inner.condvar.notify_all();
257
258        let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread);
259        let workers = std::mem::take(&mut shared.worker_threads);
260
261        drop(shared);
262
263        if self.shutdown_rx.wait(timeout) {
264            let _ = last_exited_thread.map(thread::JoinHandle::join);
265
266            // Loom requires that execution be deterministic, so sort by thread ID before joining.
267            // (HashMaps use a randomly-seeded hash function, so the order is nondeterministic)
268            #[cfg(loom)]
269            let workers: Vec<(usize, thread::JoinHandle<()>)> = {
270                let mut workers: Vec<_> = workers.into_iter().collect();
271                workers.sort_by_key(|(id, _)| *id);
272                workers
273            };
274
275            for (_id, handle) in workers {
276                let _ = handle.join();
277            }
278        }
279    }
280}
281
282impl Drop for BlockingPool {
283    fn drop(&mut self) {
284        self.shutdown(None);
285    }
286}
287
288impl fmt::Debug for BlockingPool {
289    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
290        fmt.debug_struct("BlockingPool").finish()
291    }
292}
293
294// ===== impl Spawner =====
295
296impl Spawner {
297    #[track_caller]
298    pub(crate) fn spawn_blocking<F, R>(&self, rt: &Handle, func: F) -> JoinHandle<R>
299    where
300        F: FnOnce() -> R + Send + 'static,
301        R: Send + 'static,
302    {
303        let fn_size = std::mem::size_of::<F>();
304        let (join_handle, spawn_result) = if fn_size > BOX_FUTURE_THRESHOLD {
305            self.spawn_blocking_inner(
306                Box::new(func),
307                Mandatory::NonMandatory,
308                SpawnMeta::new_unnamed(fn_size),
309                rt,
310            )
311        } else {
312            self.spawn_blocking_inner(
313                func,
314                Mandatory::NonMandatory,
315                SpawnMeta::new_unnamed(fn_size),
316                rt,
317            )
318        };
319
320        match spawn_result {
321            Ok(()) => join_handle,
322            // Compat: do not panic here, return the join_handle even though it will never resolve
323            Err(SpawnError::ShuttingDown) => join_handle,
324            Err(SpawnError::NoThreads(e)) => {
325                panic!("OS can't spawn worker thread: {e}")
326            }
327        }
328    }
329
330    cfg_fs! {
331        #[track_caller]
332        #[cfg_attr(any(
333            all(loom, not(test)), // the function is covered by loom tests
334            test
335        ), allow(dead_code))]
336        pub(crate) fn spawn_mandatory_blocking<F, R>(&self, rt: &Handle, func: F) -> Option<JoinHandle<R>>
337        where
338            F: FnOnce() -> R + Send + 'static,
339            R: Send + 'static,
340        {
341            let fn_size = std::mem::size_of::<F>();
342            let (join_handle, spawn_result) = if fn_size > BOX_FUTURE_THRESHOLD {
343                self.spawn_blocking_inner(
344                    Box::new(func),
345                    Mandatory::Mandatory,
346                    SpawnMeta::new_unnamed(fn_size),
347                    rt,
348                )
349            } else {
350                self.spawn_blocking_inner(
351                    func,
352                    Mandatory::Mandatory,
353                    SpawnMeta::new_unnamed(fn_size),
354                    rt,
355                )
356            };
357
358            if spawn_result.is_ok() {
359                Some(join_handle)
360            } else {
361                None
362            }
363        }
364    }
365
366    #[track_caller]
367    pub(crate) fn spawn_blocking_inner<F, R>(
368        &self,
369        func: F,
370        is_mandatory: Mandatory,
371        spawn_meta: SpawnMeta<'_>,
372        rt: &Handle,
373    ) -> (JoinHandle<R>, Result<(), SpawnError>)
374    where
375        F: FnOnce() -> R + Send + 'static,
376        R: Send + 'static,
377    {
378        let id = task::Id::next();
379        let fut =
380            blocking_task::<F, BlockingTask<F>>(BlockingTask::new(func), spawn_meta, id.as_u64());
381
382        let (task, handle) = task::unowned(fut, BlockingSchedule::new(rt), id);
383
384        let spawned = self.spawn_task(Task::new(task, is_mandatory), rt);
385        (handle, spawned)
386    }
387
388    fn spawn_task(&self, task: Task, rt: &Handle) -> Result<(), SpawnError> {
389        let mut shared = self.inner.shared.lock();
390
391        if shared.shutdown {
392            // Shutdown the task: it's fine to shutdown this task (even if
393            // mandatory) because it was scheduled after the shutdown of the
394            // runtime began.
395            task.task.shutdown();
396
397            // no need to even push this task; it would never get picked up
398            return Err(SpawnError::ShuttingDown);
399        }
400
401        shared.queue.push_back(task);
402        self.inner.metrics.inc_queue_depth();
403
404        if self.inner.metrics.num_idle_threads() == 0 {
405            // No threads are able to process the task.
406
407            if self.inner.metrics.num_threads() == self.inner.thread_cap {
408                // At max number of threads
409            } else {
410                assert!(shared.shutdown_tx.is_some());
411                let shutdown_tx = shared.shutdown_tx.clone();
412
413                if let Some(shutdown_tx) = shutdown_tx {
414                    let id = shared.worker_thread_index;
415
416                    match self.spawn_thread(shutdown_tx, rt, id) {
417                        Ok(handle) => {
418                            self.inner.metrics.inc_num_threads();
419                            shared.worker_thread_index += 1;
420                            shared.worker_threads.insert(id, handle);
421                        }
422                        Err(ref e)
423                            if is_temporary_os_thread_error(e)
424                                && self.inner.metrics.num_threads() > 0 =>
425                        {
426                            // OS temporarily failed to spawn a new thread.
427                            // The task will be picked up eventually by a currently
428                            // busy thread.
429                        }
430                        Err(e) => {
431                            // The OS refused to spawn the thread and there is no thread
432                            // to pick up the task that has just been pushed to the queue.
433                            return Err(SpawnError::NoThreads(e));
434                        }
435                    }
436                }
437            }
438        } else {
439            // Notify an idle worker thread. The notification counter
440            // is used to count the needed amount of notifications
441            // exactly. Thread libraries may generate spurious
442            // wakeups, this counter is used to keep us in a
443            // consistent state.
444            self.inner.metrics.dec_num_idle_threads();
445            shared.num_notify += 1;
446            self.inner.condvar.notify_one();
447        }
448
449        Ok(())
450    }
451
452    fn spawn_thread(
453        &self,
454        shutdown_tx: shutdown::Sender,
455        rt: &Handle,
456        id: usize,
457    ) -> io::Result<thread::JoinHandle<()>> {
458        let mut builder = thread::Builder::new().name((self.inner.thread_name)());
459
460        if let Some(stack_size) = self.inner.stack_size {
461            builder = builder.stack_size(stack_size);
462        }
463
464        let rt = rt.clone();
465
466        builder.spawn(move || {
467            // Only the reference should be moved into the closure
468            let _enter = rt.enter();
469            rt.inner.blocking_spawner().inner.run(id);
470            drop(shutdown_tx);
471        })
472    }
473}
474
475cfg_unstable_metrics! {
476    impl Spawner {
477        pub(crate) fn num_threads(&self) -> usize {
478            self.inner.metrics.num_threads()
479        }
480
481        pub(crate) fn num_idle_threads(&self) -> usize {
482            self.inner.metrics.num_idle_threads()
483        }
484
485        pub(crate) fn queue_depth(&self) -> usize {
486            self.inner.metrics.queue_depth()
487        }
488    }
489}
490
/// Tells whether an error returned when spawning a thread is temporary.
///
/// Only `io::ErrorKind::WouldBlock` counts as temporary. On such an error,
/// `Spawner::spawn_task` relies on an existing worker to pick up the queued
/// task instead of failing.
#[inline]
fn is_temporary_os_thread_error(error: &io::Error) -> bool {
    matches!(error.kind(), io::ErrorKind::WouldBlock)
}

497impl Inner {
498    fn run(&self, worker_thread_id: usize) {
499        if let Some(f) = &self.after_start {
500            f();
501        }
502
503        let mut shared = self.shared.lock();
504        let mut join_on_thread = None;
505
506        'main: loop {
507            // BUSY
508            while let Some(task) = shared.queue.pop_front() {
509                self.metrics.dec_queue_depth();
510                drop(shared);
511                task.run();
512
513                shared = self.shared.lock();
514            }
515
516            // IDLE
517            self.metrics.inc_num_idle_threads();
518
519            while !shared.shutdown {
520                let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap();
521
522                shared = lock_result.0;
523                let timeout_result = lock_result.1;
524
525                if shared.num_notify != 0 {
526                    // We have received a legitimate wakeup,
527                    // acknowledge it by decrementing the counter
528                    // and transition to the BUSY state.
529                    shared.num_notify -= 1;
530                    break;
531                }
532
533                // Even if the condvar "timed out", if the pool is entering the
534                // shutdown phase, we want to perform the cleanup logic.
535                if !shared.shutdown && timeout_result.timed_out() {
536                    // We'll join the prior timed-out thread's JoinHandle after dropping the lock.
537                    // This isn't done when shutting down, because the thread calling shutdown will
538                    // handle joining everything.
539                    let my_handle = shared.worker_threads.remove(&worker_thread_id);
540                    join_on_thread = std::mem::replace(&mut shared.last_exiting_thread, my_handle);
541
542                    break 'main;
543                }
544
545                // Spurious wakeup detected, go back to sleep.
546            }
547
548            if shared.shutdown {
549                // Drain the queue
550                while let Some(task) = shared.queue.pop_front() {
551                    self.metrics.dec_queue_depth();
552                    drop(shared);
553
554                    task.shutdown_or_run_if_mandatory();
555
556                    shared = self.shared.lock();
557                }
558
559                // Work was produced, and we "took" it (by decrementing num_notify).
560                // This means that num_idle was decremented once for our wakeup.
561                // But, since we are exiting, we need to "undo" that, as we'll stay idle.
562                self.metrics.inc_num_idle_threads();
563                // NOTE: Technically we should also do num_notify++ and notify again,
564                // but since we're shutting down anyway, that won't be necessary.
565                break;
566            }
567        }
568
569        // Thread exit
570        self.metrics.dec_num_threads();
571
572        // num_idle should now be tracked exactly, panic
573        // with a descriptive message if it is not the
574        // case.
575        let prev_idle = self.metrics.dec_num_idle_threads();
576        assert!(
577            prev_idle >= self.metrics.num_idle_threads(),
578            "num_idle_threads underflowed on thread exit"
579        );
580
581        if shared.shutdown && self.metrics.num_threads() == 0 {
582            self.condvar.notify_one();
583        }
584
585        drop(shared);
586
587        if let Some(f) = &self.before_stop {
588            f();
589        }
590
591        if let Some(handle) = join_on_thread {
592            let _ = handle.join();
593        }
594    }
595}
596
597impl fmt::Debug for Spawner {
598    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
599        fmt.debug_struct("blocking::Spawner").finish()
600    }
601}