1use crate::raw_mutex::RawMutex;
9use core::num::NonZeroUsize;
10use lock_api::{self, GetThreadId};
11
12pub struct RawThreadId;
14
15unsafe impl GetThreadId for RawThreadId {
16 const INIT: RawThreadId = RawThreadId;
17
18 fn nonzero_thread_id(&self) -> NonZeroUsize {
19 thread_local!(static KEY: u8 = 0);
23 KEY.with(|x| {
24 NonZeroUsize::new(x as *const _ as usize)
25 .expect("thread-local variable address is null")
26 })
27 }
28}
29
/// A mutex that the thread already holding it may lock again (re-entrantly).
///
/// This is `lock_api::ReentrantMutex` instantiated with this crate's
/// [`RawMutex`] as the underlying lock and [`RawThreadId`] to identify the
/// owning thread.
pub type ReentrantMutex<T> = lock_api::ReentrantMutex<RawMutex, RawThreadId, T>;
42
43pub const fn const_reentrant_mutex<T>(val: T) -> ReentrantMutex<T> {
47 ReentrantMutex::const_new(
48 <RawMutex as lock_api::RawMutex>::INIT,
49 <RawThreadId as lock_api::GetThreadId>::INIT,
50 val,
51 )
52}
53
/// Guard returned by locking a [`ReentrantMutex`].
///
/// The lock (one level of re-entrancy) is released when the guard is dropped.
pub type ReentrantMutexGuard<'a, T> = lock_api::ReentrantMutexGuard<'a, RawMutex, RawThreadId, T>;

/// Guard that refers to a sub-part of the data protected by a
/// [`ReentrantMutex`].
///
/// It is produced by `lock_api`'s guard-mapping functions (e.g.
/// `ReentrantMutexGuard::map`); see the `lock_api` docs for details.
pub type MappedReentrantMutexGuard<'a, T> =
    lock_api::MappedReentrantMutexGuard<'a, RawMutex, RawThreadId, T>;
70
#[cfg(test)]
mod tests {
    use crate::ReentrantMutex;
    use crate::ReentrantMutexGuard;
    use std::cell::RefCell;
    use std::sync::mpsc::{channel, TryRecvError};
    use std::sync::Arc;
    use std::thread;

    #[cfg(feature = "serde")]
    use bincode::{deserialize, serialize};

    /// Nested locking on one thread succeeds and every guard sees the data.
    #[test]
    fn smoke() {
        let m = ReentrantMutex::new(2);
        {
            let a = m.lock();
            {
                let b = m.lock();
                {
                    let c = m.lock();
                    assert_eq!(*c, 2);
                }
                assert_eq!(*b, 2);
            }
            assert_eq!(*a, 2);
        }
    }

    /// Another thread must wait until every guard held by the owner is gone.
    #[test]
    fn is_mutex() {
        let m = Arc::new(ReentrantMutex::new(RefCell::new(0)));
        let m2 = m.clone();
        let lock = m.lock();
        let child = thread::spawn(move || {
            let lock = m2.lock();
            // Only runs after the main thread's 100 increments are complete.
            assert_eq!(*lock.borrow(), 4950);
        });
        for i in 0..100 {
            let lock = m.lock();
            *lock.borrow_mut() += i;
        }
        drop(lock);
        child.join().unwrap();
    }

    /// `try_lock` is re-entrant on the owning thread but fails on other threads.
    #[test]
    fn trylock_works() {
        let m = Arc::new(ReentrantMutex::new(()));
        let m2 = m.clone();
        let lock = m.try_lock();
        assert!(lock.is_some());
        let lock2 = m.try_lock();
        assert!(lock2.is_some());
        thread::spawn(move || {
            let lock = m2.try_lock();
            assert!(lock.is_none());
        })
        .join()
        .unwrap();
        assert!(m.try_lock().is_some());
    }

    #[test]
    fn test_reentrant_mutex_debug() {
        let mutex = ReentrantMutex::new(vec![0u8, 10]);

        assert_eq!(format!("{:?}", mutex), "ReentrantMutex { data: [0, 10] }");
    }

    /// `bump` lets a waiting thread acquire the lock while we hold a guard.
    #[test]
    fn test_reentrant_mutex_bump() {
        let mutex = Arc::new(ReentrantMutex::new(()));
        let mutex2 = mutex.clone();

        let mut guard = mutex.lock();

        let (tx, rx) = channel();

        let child = thread::spawn(move || {
            let _guard = mutex2.lock();
            tx.send(()).unwrap();
        });

        loop {
            match rx.try_recv() {
                Ok(()) => break,
                // The child is still waiting for the lock: yield it briefly.
                Err(TryRecvError::Empty) => ReentrantMutexGuard::bump(&mut guard),
                // The child dropped `tx` without sending (it panicked); stop
                // spinning and let `join` below report the failure.
                Err(TryRecvError::Disconnected) => break,
            }
        }
        drop(guard);
        child.join().unwrap();
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serde() {
        let contents: Vec<u8> = vec![0, 1, 2];
        let mutex = ReentrantMutex::new(contents.clone());

        let serialized = serialize(&mutex).unwrap();
        let deserialized: ReentrantMutex<Vec<u8>> = deserialize(&serialized).unwrap();

        assert_eq!(*(mutex.lock()), *(deserialized.lock()));
        assert_eq!(contents, *(deserialized.lock()));
    }
}