thread_local/
lib.rs

1// Copyright 2017 Amanieu d'Antras
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! Per-object thread-local storage
9//!
10//! This library provides the `ThreadLocal` type which allows a separate copy of
11//! an object to be used for each thread. This allows for per-object
12//! thread-local storage, unlike the standard library's `thread_local!` macro
13//! which only allows static thread-local storage.
14//!
15//! Per-thread objects are not destroyed when a thread exits. Instead, objects
16//! are only destroyed when the `ThreadLocal` containing them is destroyed.
17//!
//! You can also iterate over the thread-local values of all threads in a
//! `ThreadLocal` object using the `iter_mut` and `into_iter` methods. These
//! require exclusive access to the `ThreadLocal` object (a mutable borrow or
//! ownership), which guarantees that you are the only thread currently
//! accessing it. If `T` implements `Sync`, the `iter` method can additionally
//! be used through a shared reference.
22//!
//! Note that since thread IDs are recycled when a thread exits, it is possible
//! for one thread to retrieve the object of another thread. Since this can only
//! occur after the original thread has exited, this does not lead to any race
//! conditions.
26//!
27//! # Examples
28//!
29//! Basic usage of `ThreadLocal`:
30//!
31//! ```rust
32//! use thread_local::ThreadLocal;
33//! let tls: ThreadLocal<u32> = ThreadLocal::new();
34//! assert_eq!(tls.get(), None);
35//! assert_eq!(tls.get_or(|| 5), &5);
36//! assert_eq!(tls.get(), Some(&5));
37//! ```
38//!
39//! Combining thread-local values into a single result:
40//!
41//! ```rust
42//! use thread_local::ThreadLocal;
43//! use std::sync::Arc;
44//! use std::cell::Cell;
45//! use std::thread;
46//!
47//! let tls = Arc::new(ThreadLocal::new());
48//!
49//! // Create a bunch of threads to do stuff
50//! for _ in 0..5 {
51//!     let tls2 = tls.clone();
52//!     thread::spawn(move || {
53//!         // Increment a counter to count some event...
54//!         let cell = tls2.get_or(|| Cell::new(0));
55//!         cell.set(cell.get() + 1);
56//!     }).join().unwrap();
57//! }
58//!
59//! // Once all threads are done, collect the counter values and return the
60//! // sum of all thread-local counter values.
61//! let tls = Arc::try_unwrap(tls).unwrap();
62//! let total = tls.into_iter().fold(0, |x, y| x + y.get());
63//! assert_eq!(total, 5);
64//! ```
65
66#![warn(missing_docs)]
67#![allow(clippy::mutex_atomic)]
68#![cfg_attr(feature = "nightly", feature(thread_local))]
69
70mod cached;
71mod thread_id;
72mod unreachable;
73
74#[allow(deprecated)]
75pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};
76
77use std::cell::UnsafeCell;
78use std::fmt;
79use std::iter::FusedIterator;
80use std::mem;
81use std::mem::MaybeUninit;
82use std::panic::UnwindSafe;
83use std::ptr;
84use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
85use thread_id::Thread;
86use unreachable::UncheckedResultExt;
87
/// Number of bits in a pointer-sized integer on the target platform.
///
/// Derived from `usize::BITS` instead of one `#[cfg(target_pointer_width)]`
/// constant per supported width, so every target width is covered. It stays a
/// `u8` because callers widen it with `usize::from(POINTER_WIDTH)`.
const POINTER_WIDTH: u8 = usize::BITS as u8;

/// The total number of buckets stored in each thread local.
///
/// Bucket `i` holds `2^i` entries, so all buckets combined can hold
/// `2^(POINTER_WIDTH - 1) - 1` (that is, `isize::MAX`) entries.
const BUCKETS: usize = (POINTER_WIDTH - 1) as usize;
99
/// Thread-local variable wrapper
///
/// See the [module-level documentation](index.html) for more.
pub struct ThreadLocal<T: Send> {
    /// The buckets in the thread local. The nth bucket contains `2^n`
    /// elements. Each bucket is lazily allocated.
    // Under shared access a bucket pointer only ever changes from null to
    // non-null (the compare-exchange in `insert`); buckets are freed in `Drop`.
    buckets: [AtomicPtr<Entry<T>>; BUCKETS],

    /// The number of values in the thread local. This can be less than the real number of values,
    /// but is never more.
    // `insert` bumps this only after marking the entry present, so a concurrent
    // reader can observe an entry before it is counted.
    values: AtomicUsize,
}
112
/// One slot in a bucket, holding at most one thread's value.
struct Entry<T> {
    // Whether `value` is initialized: set to `true` by `insert` after writing
    // the value, reset to `false` by `IntoIter` after moving the value out.
    present: AtomicBool,
    // The slot's value; only initialized while `present` is `true`.
    value: UnsafeCell<MaybeUninit<T>>,
}
117
118impl<T> Drop for Entry<T> {
119    fn drop(&mut self) {
120        unsafe {
121            if *self.present.get_mut() {
122                ptr::drop_in_place((*self.value.get()).as_mut_ptr());
123            }
124        }
125    }
126}
127
// ThreadLocal is always Sync, even if T isn't
//
// SAFETY: `get` / `get_or*` only return the slot selected by the calling
// thread's ID (`thread_id::get()`), so sharing `&ThreadLocal<T>` does not let
// two running threads reach the same value through them (IDs are only recycled
// after a thread exits, see the module docs). The only shared-access method
// that exposes other threads' values, `iter`, requires `T: Sync`. `T: Send` is
// required because values may be dropped or moved out (`into_iter`) on
// whichever thread ends up owning the `ThreadLocal`.
unsafe impl<T: Send> Sync for ThreadLocal<T> {}
130
131impl<T: Send> Default for ThreadLocal<T> {
132    fn default() -> ThreadLocal<T> {
133        ThreadLocal::new()
134    }
135}
136
137impl<T: Send> Drop for ThreadLocal<T> {
138    fn drop(&mut self) {
139        // Free each non-null bucket
140        for (i, bucket) in self.buckets.iter_mut().enumerate() {
141            let bucket_ptr = *bucket.get_mut();
142
143            let this_bucket_size = 1 << i;
144
145            if bucket_ptr.is_null() {
146                continue;
147            }
148
149            unsafe { deallocate_bucket(bucket_ptr, this_bucket_size) };
150        }
151    }
152}
153
154impl<T: Send> ThreadLocal<T> {
155    /// Creates a new empty `ThreadLocal`.
156    pub const fn new() -> ThreadLocal<T> {
157        let buckets = [ptr::null_mut::<Entry<T>>(); BUCKETS];
158        Self {
159            buckets: unsafe { mem::transmute(buckets) },
160            values: AtomicUsize::new(0),
161        }
162    }
163
164    /// Creates a new `ThreadLocal` with an initial capacity. If less than the capacity threads
165    /// access the thread local it will never reallocate. The capacity may be rounded up to the
166    /// nearest power of two.
167    pub fn with_capacity(capacity: usize) -> ThreadLocal<T> {
168        let allocated_buckets = usize::from(POINTER_WIDTH) - (capacity.leading_zeros() as usize);
169
170        let mut buckets = [ptr::null_mut(); BUCKETS];
171        for (i, bucket) in buckets[..allocated_buckets].iter_mut().enumerate() {
172            *bucket = allocate_bucket::<T>(1 << i);
173        }
174
175        Self {
176            // Safety: AtomicPtr has the same representation as a pointer and arrays have the same
177            // representation as a sequence of their inner type.
178            buckets: unsafe { mem::transmute(buckets) },
179            values: AtomicUsize::new(0),
180        }
181    }
182
183    /// Returns the element for the current thread, if it exists.
184    pub fn get(&self) -> Option<&T> {
185        self.get_inner(thread_id::get())
186    }
187
188    /// Returns the element for the current thread, or creates it if it doesn't
189    /// exist.
190    pub fn get_or<F>(&self, create: F) -> &T
191    where
192        F: FnOnce() -> T,
193    {
194        unsafe {
195            self.get_or_try(|| Ok::<T, ()>(create()))
196                .unchecked_unwrap_ok()
197        }
198    }
199
200    /// Returns the element for the current thread, or creates it if it doesn't
201    /// exist. If `create` fails, that error is returned and no element is
202    /// added.
203    pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E>
204    where
205        F: FnOnce() -> Result<T, E>,
206    {
207        let thread = thread_id::get();
208        if let Some(val) = self.get_inner(thread) {
209            return Ok(val);
210        }
211
212        Ok(self.insert(thread, create()?))
213    }
214
215    fn get_inner(&self, thread: Thread) -> Option<&T> {
216        let bucket_ptr =
217            unsafe { self.buckets.get_unchecked(thread.bucket) }.load(Ordering::Acquire);
218        if bucket_ptr.is_null() {
219            return None;
220        }
221        unsafe {
222            let entry = &*bucket_ptr.add(thread.index);
223            if entry.present.load(Ordering::Relaxed) {
224                Some(&*(&*entry.value.get()).as_ptr())
225            } else {
226                None
227            }
228        }
229    }
230
231    #[cold]
232    fn insert(&self, thread: Thread, data: T) -> &T {
233        let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
234        let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);
235
236        // If the bucket doesn't already exist, we need to allocate it
237        let bucket_ptr = if bucket_ptr.is_null() {
238            let new_bucket = allocate_bucket(thread.bucket_size);
239
240            match bucket_atomic_ptr.compare_exchange(
241                ptr::null_mut(),
242                new_bucket,
243                Ordering::AcqRel,
244                Ordering::Acquire,
245            ) {
246                Ok(_) => new_bucket,
247                // If the bucket value changed (from null), that means
248                // another thread stored a new bucket before we could,
249                // and we can free our bucket and use that one instead
250                Err(bucket_ptr) => {
251                    unsafe { deallocate_bucket(new_bucket, thread.bucket_size) }
252                    bucket_ptr
253                }
254            }
255        } else {
256            bucket_ptr
257        };
258
259        // Insert the new element into the bucket
260        let entry = unsafe { &*bucket_ptr.add(thread.index) };
261        let value_ptr = entry.value.get();
262        unsafe { value_ptr.write(MaybeUninit::new(data)) };
263        entry.present.store(true, Ordering::Release);
264
265        self.values.fetch_add(1, Ordering::Release);
266
267        unsafe { &*(&*value_ptr).as_ptr() }
268    }
269
270    /// Returns an iterator over the local values of all threads in unspecified
271    /// order.
272    ///
273    /// This call can be done safely, as `T` is required to implement [`Sync`].
274    pub fn iter(&self) -> Iter<'_, T>
275    where
276        T: Sync,
277    {
278        Iter {
279            thread_local: self,
280            raw: RawIter::new(),
281        }
282    }
283
284    /// Returns a mutable iterator over the local values of all threads in
285    /// unspecified order.
286    ///
287    /// Since this call borrows the `ThreadLocal` mutably, this operation can
288    /// be done safely---the mutable borrow statically guarantees no other
289    /// threads are currently accessing their associated values.
290    pub fn iter_mut(&mut self) -> IterMut<T> {
291        IterMut {
292            thread_local: self,
293            raw: RawIter::new(),
294        }
295    }
296
297    /// Removes all thread-specific values from the `ThreadLocal`, effectively
298    /// reseting it to its original state.
299    ///
300    /// Since this call borrows the `ThreadLocal` mutably, this operation can
301    /// be done safely---the mutable borrow statically guarantees no other
302    /// threads are currently accessing their associated values.
303    pub fn clear(&mut self) {
304        *self = ThreadLocal::new();
305    }
306}
307
308impl<T: Send> IntoIterator for ThreadLocal<T> {
309    type Item = T;
310    type IntoIter = IntoIter<T>;
311
312    fn into_iter(self) -> IntoIter<T> {
313        IntoIter {
314            thread_local: self,
315            raw: RawIter::new(),
316        }
317    }
318}
319
320impl<'a, T: Send + Sync> IntoIterator for &'a ThreadLocal<T> {
321    type Item = &'a T;
322    type IntoIter = Iter<'a, T>;
323
324    fn into_iter(self) -> Self::IntoIter {
325        self.iter()
326    }
327}
328
329impl<'a, T: Send> IntoIterator for &'a mut ThreadLocal<T> {
330    type Item = &'a mut T;
331    type IntoIter = IterMut<'a, T>;
332
333    fn into_iter(self) -> IterMut<'a, T> {
334        self.iter_mut()
335    }
336}
337
338impl<T: Send + Default> ThreadLocal<T> {
339    /// Returns the element for the current thread, or creates a default one if
340    /// it doesn't exist.
341    pub fn get_or_default(&self) -> &T {
342        self.get_or(Default::default)
343    }
344}
345
346impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
347    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
348        write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get())
349    }
350}
351
// A `ThreadLocal` is `UnwindSafe` whenever the stored `T` is.
// NOTE(review): presumably implemented by hand because the raw bucket pointers
// (to entries containing `UnsafeCell`) keep the auto-impl from applying — confirm.
impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}
353
354#[derive(Debug)]
355struct RawIter {
356    yielded: usize,
357    bucket: usize,
358    bucket_size: usize,
359    index: usize,
360}
361impl RawIter {
362    #[inline]
363    fn new() -> Self {
364        Self {
365            yielded: 0,
366            bucket: 0,
367            bucket_size: 1,
368            index: 0,
369        }
370    }
371
372    fn next<'a, T: Send + Sync>(&mut self, thread_local: &'a ThreadLocal<T>) -> Option<&'a T> {
373        while self.bucket < BUCKETS {
374            let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
375            let bucket = bucket.load(Ordering::Acquire);
376
377            if !bucket.is_null() {
378                while self.index < self.bucket_size {
379                    let entry = unsafe { &*bucket.add(self.index) };
380                    self.index += 1;
381                    if entry.present.load(Ordering::Acquire) {
382                        self.yielded += 1;
383                        return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
384                    }
385                }
386            }
387
388            self.next_bucket();
389        }
390        None
391    }
392    fn next_mut<'a, T: Send>(
393        &mut self,
394        thread_local: &'a mut ThreadLocal<T>,
395    ) -> Option<&'a mut Entry<T>> {
396        if *thread_local.values.get_mut() == self.yielded {
397            return None;
398        }
399
400        loop {
401            let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
402            let bucket = *bucket.get_mut();
403
404            if !bucket.is_null() {
405                while self.index < self.bucket_size {
406                    let entry = unsafe { &mut *bucket.add(self.index) };
407                    self.index += 1;
408                    if *entry.present.get_mut() {
409                        self.yielded += 1;
410                        return Some(entry);
411                    }
412                }
413            }
414
415            self.next_bucket();
416        }
417    }
418
419    #[inline]
420    fn next_bucket(&mut self) {
421        self.bucket_size <<= 1;
422        self.bucket += 1;
423        self.index = 0;
424    }
425
426    fn size_hint<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
427        let total = thread_local.values.load(Ordering::Acquire);
428        (total - self.yielded, None)
429    }
430    fn size_hint_frozen<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
431        let total = unsafe { *(&thread_local.values as *const AtomicUsize as *const usize) };
432        let remaining = total - self.yielded;
433        (remaining, Some(remaining))
434    }
435}
436
437/// Iterator over the contents of a `ThreadLocal`.
438#[derive(Debug)]
439pub struct Iter<'a, T: Send + Sync> {
440    thread_local: &'a ThreadLocal<T>,
441    raw: RawIter,
442}
443
444impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
445    type Item = &'a T;
446    fn next(&mut self) -> Option<Self::Item> {
447        self.raw.next(self.thread_local)
448    }
449    fn size_hint(&self) -> (usize, Option<usize>) {
450        self.raw.size_hint(self.thread_local)
451    }
452}
453impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}
454
455/// Mutable iterator over the contents of a `ThreadLocal`.
456pub struct IterMut<'a, T: Send> {
457    thread_local: &'a mut ThreadLocal<T>,
458    raw: RawIter,
459}
460
461impl<'a, T: Send> Iterator for IterMut<'a, T> {
462    type Item = &'a mut T;
463    fn next(&mut self) -> Option<&'a mut T> {
464        self.raw
465            .next_mut(self.thread_local)
466            .map(|entry| unsafe { &mut *(&mut *entry.value.get()).as_mut_ptr() })
467    }
468    fn size_hint(&self) -> (usize, Option<usize>) {
469        self.raw.size_hint_frozen(self.thread_local)
470    }
471}
472
473impl<T: Send> ExactSizeIterator for IterMut<'_, T> {}
474impl<T: Send> FusedIterator for IterMut<'_, T> {}
475
476// Manual impl so we don't call Debug on the ThreadLocal, as doing so would create a reference to
477// this thread's value that potentially aliases with a mutable reference we have given out.
478impl<'a, T: Send + fmt::Debug> fmt::Debug for IterMut<'a, T> {
479    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
480        f.debug_struct("IterMut").field("raw", &self.raw).finish()
481    }
482}
483
484/// An iterator that moves out of a `ThreadLocal`.
485#[derive(Debug)]
486pub struct IntoIter<T: Send> {
487    thread_local: ThreadLocal<T>,
488    raw: RawIter,
489}
490
491impl<T: Send> Iterator for IntoIter<T> {
492    type Item = T;
493    fn next(&mut self) -> Option<T> {
494        self.raw.next_mut(&mut self.thread_local).map(|entry| {
495            *entry.present.get_mut() = false;
496            unsafe {
497                std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init()
498            }
499        })
500    }
501    fn size_hint(&self) -> (usize, Option<usize>) {
502        self.raw.size_hint_frozen(&self.thread_local)
503    }
504}
505
506impl<T: Send> ExactSizeIterator for IntoIter<T> {}
507impl<T: Send> FusedIterator for IntoIter<T> {}
508
509fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
510    Box::into_raw(
511        (0..size)
512            .map(|_| Entry::<T> {
513                present: AtomicBool::new(false),
514                value: UnsafeCell::new(MaybeUninit::uninit()),
515            })
516            .collect(),
517    ) as *mut _
518}
519
520unsafe fn deallocate_bucket<T>(bucket: *mut Entry<T>, size: usize) {
521    let _ = Box::from_raw(std::slice::from_raw_parts_mut(bucket, size));
522}
523
// Unit tests for `ThreadLocal`: per-thread isolation, iteration, dropping of
// stored values, and the `Sync` guarantee.
#[cfg(test)]
mod tests {
    use super::*;

    use std::cell::RefCell;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering::Relaxed;
    use std::sync::Arc;
    use std::thread;

    // Returns a shareable closure yielding 0, 1, 2, ... on successive calls,
    // so tests can tell which call created a given value.
    fn make_create() -> Arc<dyn Fn() -> usize + Send + Sync> {
        let count = AtomicUsize::new(0);
        Arc::new(move || count.fetch_add(1, Relaxed))
    }

    // Repeated `get_or` on one thread runs `create` only once; `Debug` shows the
    // local value; `clear` empties the container.
    #[test]
    fn same_thread() {
        let create = make_create();
        let mut tls = ThreadLocal::new();
        assert_eq!(None, tls.get());
        assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls));
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls));
        tls.clear();
        assert_eq!(None, tls.get());
    }

    // A second thread gets its own value and does not disturb the first
    // thread's value.
    #[test]
    fn different_thread() {
        let create = make_create();
        let tls = Arc::new(ThreadLocal::new());
        assert_eq!(None, tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());

        let tls2 = tls.clone();
        let create2 = create.clone();
        thread::spawn(move || {
            assert_eq!(None, tls2.get());
            assert_eq!(1, *tls2.get_or(|| create2()));
            assert_eq!(Some(&1), tls2.get());
        })
        .join()
        .unwrap();

        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
    }

    // Values from three threads (including a nested spawn) are all visible
    // through `iter`, `iter_mut` and `into_iter`.
    #[test]
    fn iter() {
        let tls = Arc::new(ThreadLocal::new());
        tls.get_or(|| Box::new(1));

        let tls2 = tls.clone();
        thread::spawn(move || {
            tls2.get_or(|| Box::new(2));
            let tls3 = tls2.clone();
            thread::spawn(move || {
                tls3.get_or(|| Box::new(3));
            })
            .join()
            .unwrap();
            drop(tls2);
        })
        .join()
        .unwrap();

        let mut tls = Arc::try_unwrap(tls).unwrap();

        let mut v = tls.iter().map(|x| **x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);

        let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);

        let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);
    }

    // Shared iteration running concurrently with an insert from another thread;
    // meant to be run under Miri to catch aliasing/data-race issues.
    #[test]
    fn miri_iter_soundness_check() {
        let tls = Arc::new(ThreadLocal::new());
        let _local = tls.get_or(|| Box::new(1));

        let tls2 = tls.clone();
        let join_1 = thread::spawn(move || {
            let _tls = tls2.get_or(|| Box::new(2));
            let iter = tls2.iter();
            for item in iter {
                println!("{:?}", item);
            }
        });

        let iter = tls.iter();
        for item in iter {
            println!("{:?}", item);
        }

        join_1.join().ok();
    }

    // Dropping the `ThreadLocal` drops the stored value exactly once.
    #[test]
    fn test_drop() {
        let local = ThreadLocal::new();
        struct Dropped(Arc<AtomicUsize>);
        impl Drop for Dropped {
            fn drop(&mut self) {
                self.0.fetch_add(1, Relaxed);
            }
        }

        let dropped = Arc::new(AtomicUsize::new(0));
        local.get_or(|| Dropped(dropped.clone()));
        assert_eq!(dropped.load(Relaxed), 0);
        drop(local);
        assert_eq!(dropped.load(Relaxed), 1);
    }

    // A value in a later bucket, with earlier buckets still null, must be found
    // by iteration and freed by `Drop`.
    #[test]
    fn test_earlyreturn_buckets() {
        struct Dropped(Arc<AtomicUsize>);
        impl Drop for Dropped {
            fn drop(&mut self) {
                self.0.fetch_add(1, Relaxed);
            }
        }
        let dropped = Arc::new(AtomicUsize::new(0));

        // We use a high `id` here to guarantee that a lazily allocated bucket somewhere in the middle is used.
        // Neither iteration nor `Drop` must early-return on `null` buckets that are used for lower `buckets`.
        let thread = Thread::new(1234);
        assert!(thread.bucket > 1);

        let mut local = ThreadLocal::new();
        local.insert(thread, Dropped(dropped.clone()));

        let item = local.iter().next().unwrap();
        assert_eq!(item.0.load(Relaxed), 0);
        let item = local.iter_mut().next().unwrap();
        assert_eq!(item.0.load(Relaxed), 0);
        drop(local);
        assert_eq!(dropped.load(Relaxed), 1);
    }

    // Compile-time check: `ThreadLocal<T>` is `Sync` even when `T` is not.
    #[test]
    fn is_sync() {
        fn foo<T: Sync>() {}
        foo::<ThreadLocal<String>>();
        foo::<ThreadLocal<RefCell<String>>>();
    }
}