#![warn(missing_docs)]
#![allow(clippy::mutex_atomic)]
#![cfg_attr(feature = "nightly", feature(thread_local))]

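//! Per-object thread-local storage.
//!
//! A [`ThreadLocal`] keeps an independent copy of its contained value for each
//! thread that accesses it, so a single shared object can hand out per-thread
//! state without locking. Values are created lazily on first access from each
//! thread and remain alive until the `ThreadLocal` is dropped or cleared.
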
mod cached;
mod thread_id;
mod unreachable;

#[allow(deprecated)]
pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};

use std::cell::UnsafeCell;
use std::fmt;
use std::iter::FusedIterator;
use std::mem;
use std::mem::MaybeUninit;
use std::panic::UnwindSafe;
use std::ptr;
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
use thread_id::Thread;
use unreachable::UncheckedResultExt;

#[cfg(target_pointer_width = "16")]
const POINTER_WIDTH: u8 = 16;
#[cfg(target_pointer_width = "32")]
const POINTER_WIDTH: u8 = 32;
#[cfg(target_pointer_width = "64")]
const POINTER_WIDTH: u8 = 64;

const BUCKETS: usize = (POINTER_WIDTH - 1) as usize;

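// Note: the doctest below assumes this crate is built under the name `thread_local`.
/// A thread-local variable wrapper.
///
/// Each thread that accesses a `ThreadLocal` sees its own independent copy of
/// the contained value. Values are created lazily via [`get_or`] and
/// [`get_or_try`], and they all stay alive until the `ThreadLocal` itself is
/// dropped or cleared.
///
/// Internally, values are stored in lazily allocated buckets of entries, where
/// bucket `i` holds `1 << i` entries; a thread's id determines which bucket
/// and index it uses.
///
/// # Examples
///
/// ```
/// use std::cell::Cell;
/// use thread_local::ThreadLocal;
///
/// let tls: ThreadLocal<Cell<u32>> = ThreadLocal::new();
/// tls.get_or(|| Cell::new(0)).set(5);
/// assert_eq!(tls.get_or(|| Cell::new(0)).get(), 5);
/// ```
///
/// [`get_or`]: ThreadLocal::get_or
/// [`get_or_try`]: ThreadLocal::get_or_try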
pub struct ThreadLocal<T: Send> {
    buckets: [AtomicPtr<Entry<T>>; BUCKETS],

    values: AtomicUsize,
}

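// A single entry in a bucket. `present` is set (with Release ordering) once
// `value` has been initialized by its owning thread.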
struct Entry<T> {
    present: AtomicBool,
    value: UnsafeCell<MaybeUninit<T>>,
}

impl<T> Drop for Entry<T> {
    fn drop(&mut self) {
        unsafe {
            if *self.present.get_mut() {
                ptr::drop_in_place((*self.value.get()).as_mut_ptr());
            }
        }
    }
}

unsafe impl<T: Send> Sync for ThreadLocal<T> {}

impl<T: Send> Default for ThreadLocal<T> {
    fn default() -> ThreadLocal<T> {
        ThreadLocal::new()
    }
}

impl<T: Send> Drop for ThreadLocal<T> {
    fn drop(&mut self) {
        for (i, bucket) in self.buckets.iter_mut().enumerate() {
            let bucket_ptr = *bucket.get_mut();

            let this_bucket_size = 1 << i;

            if bucket_ptr.is_null() {
                continue;
            }

            unsafe { deallocate_bucket(bucket_ptr, this_bucket_size) };
        }
    }
}

impl<T: Send> ThreadLocal<T> {
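    /// Creates a new, empty `ThreadLocal` without allocating any storage.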
    pub const fn new() -> ThreadLocal<T> {
        let buckets = [ptr::null_mut::<Entry<T>>(); BUCKETS];
        Self {
            buckets: unsafe { mem::transmute(buckets) },
            values: AtomicUsize::new(0),
        }
    }

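    /// Creates a `ThreadLocal` with buckets pre-allocated to hold values for at
    /// least `capacity` threads, so those threads can insert without allocating.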
    pub fn with_capacity(capacity: usize) -> ThreadLocal<T> {
        let allocated_buckets = usize::from(POINTER_WIDTH) - (capacity.leading_zeros() as usize);

        let mut buckets = [ptr::null_mut(); BUCKETS];
        for (i, bucket) in buckets[..allocated_buckets].iter_mut().enumerate() {
            *bucket = allocate_bucket::<T>(1 << i);
        }

        Self {
            buckets: unsafe { mem::transmute(buckets) },
            values: AtomicUsize::new(0),
        }
    }

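    /// Returns the element for the current thread, if it exists.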
    pub fn get(&self) -> Option<&T> {
        self.get_inner(thread_id::get())
    }

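    /// Returns the element for the current thread, or creates it with `create`
    /// if it doesn't exist yet.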
    pub fn get_or<F>(&self, create: F) -> &T
    where
        F: FnOnce() -> T,
    {
        unsafe {
            self.get_or_try(|| Ok::<T, ()>(create()))
                .unchecked_unwrap_ok()
        }
    }

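    /// Returns the element for the current thread, or creates it with `create`
    /// if it doesn't exist yet. If `create` fails, the error is returned and no
    /// value is stored.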
    pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        let thread = thread_id::get();
        if let Some(val) = self.get_inner(thread) {
            return Ok(val);
        }

        Ok(self.insert(thread, create()?))
    }

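    // Fast path: look up the current thread's entry without allocating.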
    fn get_inner(&self, thread: Thread) -> Option<&T> {
        let bucket_ptr =
            unsafe { self.buckets.get_unchecked(thread.bucket) }.load(Ordering::Acquire);
        if bucket_ptr.is_null() {
            return None;
        }
        unsafe {
            let entry = &*bucket_ptr.add(thread.index);
            if entry.present.load(Ordering::Relaxed) {
                Some(&*(&*entry.value.get()).as_ptr())
            } else {
                None
            }
        }
    }

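    // Slow path: allocate the thread's bucket if necessary, then initialize its
    // entry. `#[cold]` keeps this out of the inlined fast path of `get_or`.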
    #[cold]
    fn insert(&self, thread: Thread, data: T) -> &T {
        let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
        let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);

        let bucket_ptr = if bucket_ptr.is_null() {
            let new_bucket = allocate_bucket(thread.bucket_size);

            match bucket_atomic_ptr.compare_exchange(
                ptr::null_mut(),
                new_bucket,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => new_bucket,
                Err(bucket_ptr) => {
                    unsafe { deallocate_bucket(new_bucket, thread.bucket_size) }
                    bucket_ptr
                }
            }
        } else {
            bucket_ptr
        };

        let entry = unsafe { &*bucket_ptr.add(thread.index) };
        let value_ptr = entry.value.get();
        unsafe { value_ptr.write(MaybeUninit::new(data)) };
        entry.present.store(true, Ordering::Release);

        self.values.fetch_add(1, Ordering::Release);

        unsafe { &*(&*value_ptr).as_ptr() }
    }

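    /// Returns an iterator over the values of all threads that have accessed
    /// this `ThreadLocal`. Requires `T: Sync` because other threads' values are
    /// read while those threads may still be using them; values inserted
    /// concurrently during iteration may or may not be yielded.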
    pub fn iter(&self) -> Iter<'_, T>
    where
        T: Sync,
    {
        Iter {
            thread_local: self,
            raw: RawIter::new(),
        }
    }

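    /// Returns a mutable iterator over the values of all threads that have
    /// accessed this `ThreadLocal`. The `&mut self` receiver guarantees no other
    /// thread is accessing the values concurrently, so `T: Sync` is not needed.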
    pub fn iter_mut(&mut self) -> IterMut<T> {
        IterMut {
            thread_local: self,
            raw: RawIter::new(),
        }
    }

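    /// Removes and drops all thread-specific values. The `&mut self` receiver
    /// guarantees no other thread is accessing the values concurrently.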
    pub fn clear(&mut self) {
        *self = ThreadLocal::new();
    }
}

impl<T: Send> IntoIterator for ThreadLocal<T> {
    type Item = T;
    type IntoIter = IntoIter<T>;

    fn into_iter(self) -> IntoIter<T> {
        IntoIter {
            thread_local: self,
            raw: RawIter::new(),
        }
    }
}

impl<'a, T: Send + Sync> IntoIterator for &'a ThreadLocal<T> {
    type Item = &'a T;
    type IntoIter = Iter<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}

impl<'a, T: Send> IntoIterator for &'a mut ThreadLocal<T> {
    type Item = &'a mut T;
    type IntoIter = IterMut<'a, T>;

    fn into_iter(self) -> IterMut<'a, T> {
        self.iter_mut()
    }
}

impl<T: Send + Default> ThreadLocal<T> {
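    /// Returns the element for the current thread, or creates a default one if
    /// it doesn't exist yet.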
    pub fn get_or_default(&self) -> &T {
        self.get_or(Default::default)
    }
}

impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get())
    }
}

impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}

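// Cursor state shared by all three iterator types: walks the buckets in order
// and tracks how many present entries have been yielded so far.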
#[derive(Debug)]
struct RawIter {
    yielded: usize,
    bucket: usize,
    bucket_size: usize,
    index: usize,
}
impl RawIter {
    #[inline]
    fn new() -> Self {
        Self {
            yielded: 0,
            bucket: 0,
            bucket_size: 1,
            index: 0,
        }
    }

    fn next<'a, T: Send + Sync>(&mut self, thread_local: &'a ThreadLocal<T>) -> Option<&'a T> {
        while self.bucket < BUCKETS {
            let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
            let bucket = bucket.load(Ordering::Acquire);

            if !bucket.is_null() {
                while self.index < self.bucket_size {
                    let entry = unsafe { &*bucket.add(self.index) };
                    self.index += 1;
                    if entry.present.load(Ordering::Acquire) {
                        self.yielded += 1;
                        return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
                    }
                }
            }

            self.next_bucket();
        }
        None
    }
    fn next_mut<'a, T: Send>(
        &mut self,
        thread_local: &'a mut ThreadLocal<T>,
    ) -> Option<&'a mut Entry<T>> {
        if *thread_local.values.get_mut() == self.yielded {
            return None;
        }

        loop {
            let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
            let bucket = *bucket.get_mut();

            if !bucket.is_null() {
                while self.index < self.bucket_size {
                    let entry = unsafe { &mut *bucket.add(self.index) };
                    self.index += 1;
                    if *entry.present.get_mut() {
                        self.yielded += 1;
                        return Some(entry);
                    }
                }
            }

            self.next_bucket();
        }
    }

    #[inline]
    fn next_bucket(&mut self) {
        self.bucket_size <<= 1;
        self.bucket += 1;
        self.index = 0;
    }

    fn size_hint<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
        let total = thread_local.values.load(Ordering::Acquire);
        (total - self.yielded, None)
    }
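    // Only called from `IterMut` and `IntoIter`, which have exclusive access to
    // the `ThreadLocal`, so a plain non-atomic read of `values` is sound and the
    // count cannot change mid-iteration; this allows an exact upper bound.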
    fn size_hint_frozen<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
        let total = unsafe { *(&thread_local.values as *const AtomicUsize as *const usize) };
        let remaining = total - self.yielded;
        (remaining, Some(remaining))
    }
}

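/// Iterator over the contents of a `ThreadLocal`, yielding `&T`.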
#[derive(Debug)]
pub struct Iter<'a, T: Send + Sync> {
    thread_local: &'a ThreadLocal<T>,
    raw: RawIter,
}

impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
    type Item = &'a T;
    fn next(&mut self) -> Option<Self::Item> {
        self.raw.next(self.thread_local)
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint(self.thread_local)
    }
}
impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}

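/// Mutable iterator over the contents of a `ThreadLocal`, yielding `&mut T`.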
pub struct IterMut<'a, T: Send> {
    thread_local: &'a mut ThreadLocal<T>,
    raw: RawIter,
}

impl<'a, T: Send> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;
    fn next(&mut self) -> Option<&'a mut T> {
        self.raw
            .next_mut(self.thread_local)
            .map(|entry| unsafe { &mut *(&mut *entry.value.get()).as_mut_ptr() })
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint_frozen(self.thread_local)
    }
}

impl<T: Send> ExactSizeIterator for IterMut<'_, T> {}
impl<T: Send> FusedIterator for IterMut<'_, T> {}

impl<'a, T: Send + fmt::Debug> fmt::Debug for IterMut<'a, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IterMut").field("raw", &self.raw).finish()
    }
}

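/// An iterator that moves out of a `ThreadLocal`, yielding each stored `T`.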
#[derive(Debug)]
pub struct IntoIter<T: Send> {
    thread_local: ThreadLocal<T>,
    raw: RawIter,
}

impl<T: Send> Iterator for IntoIter<T> {
    type Item = T;
    fn next(&mut self) -> Option<T> {
        self.raw.next_mut(&mut self.thread_local).map(|entry| {
            *entry.present.get_mut() = false;
            unsafe {
                std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init()
            }
        })
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint_frozen(&self.thread_local)
    }
}

impl<T: Send> ExactSizeIterator for IntoIter<T> {}
impl<T: Send> FusedIterator for IntoIter<T> {}

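// Allocates a boxed slice of `size` uninitialized entries and leaks it as a raw
// pointer to its first element; freed again by `deallocate_bucket`.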
fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
    Box::into_raw(
        (0..size)
            .map(|_| Entry::<T> {
                present: AtomicBool::new(false),
                value: UnsafeCell::new(MaybeUninit::uninit()),
            })
            .collect(),
    ) as *mut _
}

unsafe fn deallocate_bucket<T>(bucket: *mut Entry<T>, size: usize) {
    let _ = Box::from_raw(std::slice::from_raw_parts_mut(bucket, size));
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::cell::RefCell;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering::Relaxed;
    use std::sync::Arc;
    use std::thread;

    fn make_create() -> Arc<dyn Fn() -> usize + Send + Sync> {
        let count = AtomicUsize::new(0);
        Arc::new(move || count.fetch_add(1, Relaxed))
    }

    #[test]
    fn same_thread() {
        let create = make_create();
        let mut tls = ThreadLocal::new();
        assert_eq!(None, tls.get());
        assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls));
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls));
        tls.clear();
        assert_eq!(None, tls.get());
    }

    #[test]
    fn different_thread() {
        let create = make_create();
        let tls = Arc::new(ThreadLocal::new());
        assert_eq!(None, tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());

        let tls2 = tls.clone();
        let create2 = create.clone();
        thread::spawn(move || {
            assert_eq!(None, tls2.get());
            assert_eq!(1, *tls2.get_or(|| create2()));
            assert_eq!(Some(&1), tls2.get());
        })
        .join()
        .unwrap();

        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
    }

    #[test]
    fn iter() {
        let tls = Arc::new(ThreadLocal::new());
        tls.get_or(|| Box::new(1));

        let tls2 = tls.clone();
        thread::spawn(move || {
            tls2.get_or(|| Box::new(2));
            let tls3 = tls2.clone();
            thread::spawn(move || {
                tls3.get_or(|| Box::new(3));
            })
            .join()
            .unwrap();
            drop(tls2);
        })
        .join()
        .unwrap();

        let mut tls = Arc::try_unwrap(tls).unwrap();

        let mut v = tls.iter().map(|x| **x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);

        let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);

        let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);
    }

    #[test]
    fn miri_iter_soundness_check() {
        let tls = Arc::new(ThreadLocal::new());
        let _local = tls.get_or(|| Box::new(1));

        let tls2 = tls.clone();
        let join_1 = thread::spawn(move || {
            let _tls = tls2.get_or(|| Box::new(2));
            let iter = tls2.iter();
            for item in iter {
                println!("{:?}", item);
            }
        });

        let iter = tls.iter();
        for item in iter {
            println!("{:?}", item);
        }

        join_1.join().ok();
    }

    #[test]
    fn test_drop() {
        let local = ThreadLocal::new();
        struct Dropped(Arc<AtomicUsize>);
        impl Drop for Dropped {
            fn drop(&mut self) {
                self.0.fetch_add(1, Relaxed);
            }
        }

        let dropped = Arc::new(AtomicUsize::new(0));
        local.get_or(|| Dropped(dropped.clone()));
        assert_eq!(dropped.load(Relaxed), 0);
        drop(local);
        assert_eq!(dropped.load(Relaxed), 1);
    }

    #[test]
    fn test_earlyreturn_buckets() {
        struct Dropped(Arc<AtomicUsize>);
        impl Drop for Dropped {
            fn drop(&mut self) {
                self.0.fetch_add(1, Relaxed);
            }
        }
        let dropped = Arc::new(AtomicUsize::new(0));

        let thread = Thread::new(1234);
        assert!(thread.bucket > 1);

        let mut local = ThreadLocal::new();
        local.insert(thread, Dropped(dropped.clone()));

        let item = local.iter().next().unwrap();
        assert_eq!(item.0.load(Relaxed), 0);
        let item = local.iter_mut().next().unwrap();
        assert_eq!(item.0.load(Relaxed), 0);
        drop(local);
        assert_eq!(dropped.load(Relaxed), 1);
    }

    #[test]
    fn is_sync() {
        fn foo<T: Sync>() {}
        foo::<ThreadLocal<String>>();
        foo::<ThreadLocal<RefCell<String>>>();
    }
}