1use crate::loom::sync::{Arc, Condvar, Mutex};
4use crate::loom::thread;
5use crate::runtime::blocking::schedule::BlockingSchedule;
6use crate::runtime::blocking::{shutdown, BlockingTask};
7use crate::runtime::builder::ThreadNameFn;
8use crate::runtime::task::{self, JoinHandle};
9use crate::runtime::{Builder, Callback, Handle, BOX_FUTURE_THRESHOLD};
10use crate::util::metric_atomics::MetricAtomicUsize;
11use crate::util::trace::{blocking_task, SpawnMeta};
12
13use std::collections::{HashMap, VecDeque};
14use std::fmt;
15use std::io;
16use std::sync::atomic::Ordering;
17use std::time::Duration;
18
/// The blocking thread pool owned by a runtime.
///
/// Holds the `Spawner` used to submit tasks and the receiving half of the
/// shutdown channel, which `shutdown` waits on before joining worker threads.
pub(crate) struct BlockingPool {
    // Handle used to submit work; cloned out via `spawner()`.
    spawner: Spawner,
    // Waited on during shutdown; every worker thread holds a clone of the
    // matching sender and drops it when its thread finishes.
    shutdown_rx: shutdown::Receiver,
}

/// Cloneable handle for submitting tasks to the blocking pool.
///
/// All clones share the same pool state through an `Arc<Inner>`.
#[derive(Clone)]
pub(crate) struct Spawner {
    inner: Arc<Inner>,
}

/// Counters describing the pool: worker threads, idle worker threads and
/// queued tasks.
#[derive(Default)]
pub(crate) struct SpawnerMetrics {
    // Incremented when a worker thread is spawned, decremented when it exits.
    num_threads: MetricAtomicUsize,
    // Workers counted as idle and not yet claimed by `spawn_task`.
    num_idle_threads: MetricAtomicUsize,
    // Tasks pushed onto the queue and not yet popped by a worker.
    queue_depth: MetricAtomicUsize,
}
35
36impl SpawnerMetrics {
37 fn num_threads(&self) -> usize {
38 self.num_threads.load(Ordering::Relaxed)
39 }
40
41 fn num_idle_threads(&self) -> usize {
42 self.num_idle_threads.load(Ordering::Relaxed)
43 }
44
45 cfg_unstable_metrics! {
46 fn queue_depth(&self) -> usize {
47 self.queue_depth.load(Ordering::Relaxed)
48 }
49 }
50
51 fn inc_num_threads(&self) {
52 self.num_threads.increment();
53 }
54
55 fn dec_num_threads(&self) {
56 self.num_threads.decrement();
57 }
58
59 fn inc_num_idle_threads(&self) {
60 self.num_idle_threads.increment();
61 }
62
63 fn dec_num_idle_threads(&self) -> usize {
64 self.num_idle_threads.decrement()
65 }
66
67 fn inc_queue_depth(&self) {
68 self.queue_depth.increment();
69 }
70
71 fn dec_queue_depth(&self) {
72 self.queue_depth.decrement();
73 }
74}
75
/// State shared by every `Spawner` clone and every worker thread of one pool.
struct Inner {
    // Mutable pool state; every read/write of `Shared` goes through this lock.
    shared: Mutex<Shared>,

    // Idle workers wait on this: `spawn_task` wakes one with `notify_one`,
    // and `BlockingPool::shutdown` wakes all with `notify_all`.
    condvar: Condvar,

    // Called once per spawned worker thread to produce its name.
    thread_name: ThreadNameFn,

    // Stack size for worker threads; `None` leaves the thread builder's default.
    stack_size: Option<usize>,

    // Invoked at the start of each worker's `run`, before the lock is taken.
    after_start: Option<Callback>,

    // Invoked at the end of each worker's `run`, after the lock is released.
    before_stop: Option<Callback>,

    // Upper bound on worker threads; `spawn_task` spawns no more once
    // `num_threads` reaches it.
    thread_cap: usize,

    // How long an idle worker waits for work before exiting (outside shutdown).
    keep_alive: Duration,

    // Thread / idle / queue counters.
    metrics: SpawnerMetrics,
}

/// Mutable pool state, protected by `Inner::shared`.
struct Shared {
    // Tasks waiting for a worker, in FIFO order.
    queue: VecDeque<Task>,
    // Wake-ups posted by `spawn_task` and not yet consumed by a worker; lets a
    // waking worker tell a real notification from a spurious condvar wake-up.
    num_notify: u32,
    // Set once by `BlockingPool::shutdown`; never reset.
    shutdown: bool,
    // The pool's own shutdown sender. A clone is moved into every spawned
    // worker; set to `None` when shutdown begins.
    shutdown_tx: Option<shutdown::Sender>,
    // Handle of the most recent worker that exited after its keep-alive
    // timeout. It is joined by the next worker to exit that way, or by
    // `BlockingPool::shutdown`.
    last_exiting_thread: Option<thread::JoinHandle<()>>,
    // Handles of live workers keyed by worker id. A worker removes its own
    // entry on keep-alive exit; `BlockingPool::shutdown` takes and joins the rest.
    worker_threads: HashMap<usize, thread::JoinHandle<()>>,
    // Id to assign to the next successfully spawned worker.
    worker_thread_index: usize,
}
123
/// A queued unit of blocking work, together with whether it must still run
/// when the pool shuts down.
pub(crate) struct Task {
    task: task::UnownedTask<BlockingSchedule>,
    mandatory: Mandatory,
}

/// Whether a task still queued when the pool shuts down must run
/// (`Mandatory`) or may be shut down without running (`NonMandatory`).
///
/// Tasks submitted after shutdown has begun are shut down regardless of this
/// flag (see `Spawner::spawn_task`).
#[derive(PartialEq, Eq)]
pub(crate) enum Mandatory {
    // Only constructed by the `fs`-gated mandatory spawn path, hence the
    // dead-code allowance when `fs` is disabled.
    #[cfg_attr(not(fs), allow(dead_code))]
    Mandatory,
    NonMandatory,
}
135
/// Why `Spawner::spawn_task` could not hand a task to a worker.
pub(crate) enum SpawnError {
    /// The pool is already shutting down; the task was shut down instead of
    /// being queued.
    ShuttingDown,
    /// No worker could take the task and the OS refused to spawn a new
    /// worker thread; carries the OS error.
    NoThreads(io::Error),
}

impl From<SpawnError> for io::Error {
    /// `NoThreads` unwraps to the original OS error; `ShuttingDown` becomes an
    /// `ErrorKind::Other` error with a fixed message.
    fn from(err: SpawnError) -> Self {
        let message = match err {
            SpawnError::NoThreads(os_err) => return os_err,
            SpawnError::ShuttingDown => "blocking pool shutting down",
        };
        io::Error::new(io::ErrorKind::Other, message)
    }
}
154
155impl Task {
156 pub(crate) fn new(task: task::UnownedTask<BlockingSchedule>, mandatory: Mandatory) -> Task {
157 Task { task, mandatory }
158 }
159
160 fn run(self) {
161 self.task.run();
162 }
163
164 fn shutdown_or_run_if_mandatory(self) {
165 match self.mandatory {
166 Mandatory::NonMandatory => self.task.shutdown(),
167 Mandatory::Mandatory => self.task.run(),
168 }
169 }
170}
171
/// Idle keep-alive used when the runtime builder does not set one: an idle
/// worker waits this long (10 seconds) for new work before exiting.
const KEEP_ALIVE: Duration = Duration::from_millis(10_000);
173
174#[track_caller]
178#[cfg_attr(target_os = "wasi", allow(dead_code))]
179pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
180where
181 F: FnOnce() -> R + Send + 'static,
182 R: Send + 'static,
183{
184 let rt = Handle::current();
185 rt.spawn_blocking(func)
186}
187
188cfg_fs! {
189 #[cfg_attr(any(
190 all(loom, not(test)), test
192 ), allow(dead_code))]
193 pub(crate) fn spawn_mandatory_blocking<F, R>(func: F) -> Option<JoinHandle<R>>
198 where
199 F: FnOnce() -> R + Send + 'static,
200 R: Send + 'static,
201 {
202 let rt = Handle::current();
203 rt.inner.blocking_spawner().spawn_mandatory_blocking(&rt, func)
204 }
205}
206
207impl BlockingPool {
210 pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool {
211 let (shutdown_tx, shutdown_rx) = shutdown::channel();
212 let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE);
213
214 BlockingPool {
215 spawner: Spawner {
216 inner: Arc::new(Inner {
217 shared: Mutex::new(Shared {
218 queue: VecDeque::new(),
219 num_notify: 0,
220 shutdown: false,
221 shutdown_tx: Some(shutdown_tx),
222 last_exiting_thread: None,
223 worker_threads: HashMap::new(),
224 worker_thread_index: 0,
225 }),
226 condvar: Condvar::new(),
227 thread_name: builder.thread_name.clone(),
228 stack_size: builder.thread_stack_size,
229 after_start: builder.after_start.clone(),
230 before_stop: builder.before_stop.clone(),
231 thread_cap,
232 keep_alive,
233 metrics: SpawnerMetrics::default(),
234 }),
235 },
236 shutdown_rx,
237 }
238 }
239
240 pub(crate) fn spawner(&self) -> &Spawner {
241 &self.spawner
242 }
243
244 pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) {
245 let mut shared = self.spawner.inner.shared.lock();
246
247 if shared.shutdown {
251 return;
252 }
253
254 shared.shutdown = true;
255 shared.shutdown_tx = None;
256 self.spawner.inner.condvar.notify_all();
257
258 let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread);
259 let workers = std::mem::take(&mut shared.worker_threads);
260
261 drop(shared);
262
263 if self.shutdown_rx.wait(timeout) {
264 let _ = last_exited_thread.map(thread::JoinHandle::join);
265
266 #[cfg(loom)]
269 let workers: Vec<(usize, thread::JoinHandle<()>)> = {
270 let mut workers: Vec<_> = workers.into_iter().collect();
271 workers.sort_by_key(|(id, _)| *id);
272 workers
273 };
274
275 for (_id, handle) in workers {
276 let _ = handle.join();
277 }
278 }
279 }
280}
281
282impl Drop for BlockingPool {
283 fn drop(&mut self) {
284 self.shutdown(None);
285 }
286}
287
288impl fmt::Debug for BlockingPool {
289 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
290 fmt.debug_struct("BlockingPool").finish()
291 }
292}
293
294impl Spawner {
297 #[track_caller]
298 pub(crate) fn spawn_blocking<F, R>(&self, rt: &Handle, func: F) -> JoinHandle<R>
299 where
300 F: FnOnce() -> R + Send + 'static,
301 R: Send + 'static,
302 {
303 let fn_size = std::mem::size_of::<F>();
304 let (join_handle, spawn_result) = if fn_size > BOX_FUTURE_THRESHOLD {
305 self.spawn_blocking_inner(
306 Box::new(func),
307 Mandatory::NonMandatory,
308 SpawnMeta::new_unnamed(fn_size),
309 rt,
310 )
311 } else {
312 self.spawn_blocking_inner(
313 func,
314 Mandatory::NonMandatory,
315 SpawnMeta::new_unnamed(fn_size),
316 rt,
317 )
318 };
319
320 match spawn_result {
321 Ok(()) => join_handle,
322 Err(SpawnError::ShuttingDown) => join_handle,
324 Err(SpawnError::NoThreads(e)) => {
325 panic!("OS can't spawn worker thread: {e}")
326 }
327 }
328 }
329
330 cfg_fs! {
331 #[track_caller]
332 #[cfg_attr(any(
333 all(loom, not(test)), test
335 ), allow(dead_code))]
336 pub(crate) fn spawn_mandatory_blocking<F, R>(&self, rt: &Handle, func: F) -> Option<JoinHandle<R>>
337 where
338 F: FnOnce() -> R + Send + 'static,
339 R: Send + 'static,
340 {
341 let fn_size = std::mem::size_of::<F>();
342 let (join_handle, spawn_result) = if fn_size > BOX_FUTURE_THRESHOLD {
343 self.spawn_blocking_inner(
344 Box::new(func),
345 Mandatory::Mandatory,
346 SpawnMeta::new_unnamed(fn_size),
347 rt,
348 )
349 } else {
350 self.spawn_blocking_inner(
351 func,
352 Mandatory::Mandatory,
353 SpawnMeta::new_unnamed(fn_size),
354 rt,
355 )
356 };
357
358 if spawn_result.is_ok() {
359 Some(join_handle)
360 } else {
361 None
362 }
363 }
364 }
365
366 #[track_caller]
367 pub(crate) fn spawn_blocking_inner<F, R>(
368 &self,
369 func: F,
370 is_mandatory: Mandatory,
371 spawn_meta: SpawnMeta<'_>,
372 rt: &Handle,
373 ) -> (JoinHandle<R>, Result<(), SpawnError>)
374 where
375 F: FnOnce() -> R + Send + 'static,
376 R: Send + 'static,
377 {
378 let id = task::Id::next();
379 let fut =
380 blocking_task::<F, BlockingTask<F>>(BlockingTask::new(func), spawn_meta, id.as_u64());
381
382 let (task, handle) = task::unowned(fut, BlockingSchedule::new(rt), id);
383
384 let spawned = self.spawn_task(Task::new(task, is_mandatory), rt);
385 (handle, spawned)
386 }
387
388 fn spawn_task(&self, task: Task, rt: &Handle) -> Result<(), SpawnError> {
389 let mut shared = self.inner.shared.lock();
390
391 if shared.shutdown {
392 task.task.shutdown();
396
397 return Err(SpawnError::ShuttingDown);
399 }
400
401 shared.queue.push_back(task);
402 self.inner.metrics.inc_queue_depth();
403
404 if self.inner.metrics.num_idle_threads() == 0 {
405 if self.inner.metrics.num_threads() == self.inner.thread_cap {
408 } else {
410 assert!(shared.shutdown_tx.is_some());
411 let shutdown_tx = shared.shutdown_tx.clone();
412
413 if let Some(shutdown_tx) = shutdown_tx {
414 let id = shared.worker_thread_index;
415
416 match self.spawn_thread(shutdown_tx, rt, id) {
417 Ok(handle) => {
418 self.inner.metrics.inc_num_threads();
419 shared.worker_thread_index += 1;
420 shared.worker_threads.insert(id, handle);
421 }
422 Err(ref e)
423 if is_temporary_os_thread_error(e)
424 && self.inner.metrics.num_threads() > 0 =>
425 {
426 }
430 Err(e) => {
431 return Err(SpawnError::NoThreads(e));
434 }
435 }
436 }
437 }
438 } else {
439 self.inner.metrics.dec_num_idle_threads();
445 shared.num_notify += 1;
446 self.inner.condvar.notify_one();
447 }
448
449 Ok(())
450 }
451
452 fn spawn_thread(
453 &self,
454 shutdown_tx: shutdown::Sender,
455 rt: &Handle,
456 id: usize,
457 ) -> io::Result<thread::JoinHandle<()>> {
458 let mut builder = thread::Builder::new().name((self.inner.thread_name)());
459
460 if let Some(stack_size) = self.inner.stack_size {
461 builder = builder.stack_size(stack_size);
462 }
463
464 let rt = rt.clone();
465
466 builder.spawn(move || {
467 let _enter = rt.enter();
469 rt.inner.blocking_spawner().inner.run(id);
470 drop(shutdown_tx);
471 })
472 }
473}
474
475cfg_unstable_metrics! {
476 impl Spawner {
477 pub(crate) fn num_threads(&self) -> usize {
478 self.inner.metrics.num_threads()
479 }
480
481 pub(crate) fn num_idle_threads(&self) -> usize {
482 self.inner.metrics.num_idle_threads()
483 }
484
485 pub(crate) fn queue_depth(&self) -> usize {
486 self.inner.metrics.queue_depth()
487 }
488 }
489}
490
/// Reports whether a failure to spawn a worker thread is treated as transient.
///
/// Only `io::ErrorKind::WouldBlock` counts. `spawn_task` tolerates such a
/// failure as long as at least one worker thread already exists.
#[inline]
fn is_temporary_os_thread_error(error: &io::Error) -> bool {
    error.kind() == io::ErrorKind::WouldBlock
}
496
impl Inner {
    /// Main loop of one worker thread, called from the thread started by
    /// `Spawner::spawn_thread` with that thread's id.
    ///
    /// The worker alternates between two phases:
    /// - a busy phase, running queued tasks with the lock released;
    /// - an idle phase, waiting on `condvar` for up to `keep_alive`.
    ///
    /// It exits when an idle wait times out while the pool is not shutting
    /// down, or after draining the queue once shutdown is observed.
    /// `after_start` runs before the loop and `before_stop` after it, both
    /// without the lock held.
    fn run(&self, worker_thread_id: usize) {
        if let Some(f) = &self.after_start {
            f();
        }

        let mut shared = self.shared.lock();
        // Handle of a previously exited worker that this thread will join
        // after releasing the lock (only set on keep-alive exit).
        let mut join_on_thread = None;

        'main: loop {
            // BUSY: run every queued task, releasing the lock while each runs.
            while let Some(task) = shared.queue.pop_front() {
                self.metrics.dec_queue_depth();
                drop(shared);
                task.run();

                shared = self.shared.lock();
            }

            // IDLE: count this worker as available to `spawn_task`.
            self.metrics.inc_num_idle_threads();

            while !shared.shutdown {
                let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap();

                shared = lock_result.0;
                let timeout_result = lock_result.1;

                // A pending `num_notify` means `spawn_task` claimed an idle
                // worker (and already decremented the idle count). Consume it
                // and go back to BUSY.
                if shared.num_notify != 0 {
                    shared.num_notify -= 1;
                    break;
                }

                // Keep-alive expired outside of shutdown: leave the pool.
                // Remove our own handle from `worker_threads` and park it in
                // `last_exiting_thread`. Whatever was parked there before is
                // joined by us after the lock is released. During shutdown this
                // is skipped, and `BlockingPool::shutdown` joins the threads.
                if !shared.shutdown && timeout_result.timed_out() {
                    let my_handle = shared.worker_threads.remove(&worker_thread_id);
                    join_on_thread = std::mem::replace(&mut shared.last_exiting_thread, my_handle);

                    break 'main;
                }

                // Otherwise: spurious wake-up (or shutdown just set); re-check.
            }

            if shared.shutdown {
                // Drain the queue: mandatory tasks run, others are shut down.
                while let Some(task) = shared.queue.pop_front() {
                    self.metrics.dec_queue_depth();
                    drop(shared);

                    task.shutdown_or_run_if_mandatory();

                    shared = self.shared.lock();
                }

                // Re-increment the idle count before exiting; the exit path
                // below decrements it once. When this worker got here by
                // consuming a `num_notify` wake-up, `spawn_task` already
                // decremented the idle count on its behalf, and this undoes it.
                // NOTE(review): if the idle wait ended only because `shutdown`
                // was set (no `num_notify` consumed), the count was never
                // decremented by a spawner, so this looks like it leaves
                // `num_idle_threads` one higher after the thread exits.
                // Confirm whether that matters after shutdown.
                self.metrics.inc_num_idle_threads();
                break;
            }
        }

        // Thread exit.
        self.metrics.dec_num_threads();

        // Remove this worker from the idle count. The assert catches a
        // wrap-around, i.e. the count was already zero.
        let prev_idle = self.metrics.dec_num_idle_threads();
        assert!(
            prev_idle >= self.metrics.num_idle_threads(),
            "num_idle_threads underflowed on thread exit"
        );

        // Last worker to leave during shutdown signals the condvar.
        // NOTE(review): no waiter for this notification is visible in this
        // file; presumably kept for a waiter elsewhere — confirm.
        if shared.shutdown && self.metrics.num_threads() == 0 {
            self.condvar.notify_one();
        }

        drop(shared);

        if let Some(f) = &self.before_stop {
            f();
        }

        // Join the previously exited worker outside the lock.
        if let Some(handle) = join_on_thread {
            let _ = handle.join();
        }
    }
}
596
597impl fmt::Debug for Spawner {
598 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
599 fmt.debug_struct("blocking::Spawner").finish()
600 }
601}