redb/tree_store/page_store/cached_file.rs

use crate::tree_store::page_store::base::PageHint;
use crate::tree_store::page_store::lru_cache::LRUCache;
use crate::{DatabaseError, Result, StorageBackend, StorageError};
use std::ops::{Index, IndexMut};
use std::slice::SliceIndex;
#[cfg(feature = "cache_metrics")]
use std::sync::atomic::AtomicU64;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};

pub(super) struct WritablePage {
    buffer: Arc<Mutex<LRUWriteCache>>,
    offset: u64,
    data: Arc<[u8]>,
}

impl WritablePage {
    pub(super) fn mem(&self) -> &[u8] {
        &self.data
    }

    pub(super) fn mem_mut(&mut self) -> &mut [u8] {
        Arc::get_mut(&mut self.data).unwrap()
    }
}

impl Drop for WritablePage {
    fn drop(&mut self) {
        self.buffer
            .lock()
            .unwrap()
            .return_value(self.offset, self.data.clone());
    }
}

impl<I: SliceIndex<[u8]>> Index<I> for WritablePage {
    type Output = I::Output;

    fn index(&self, index: I) -> &Self::Output {
        self.mem().index(index)
    }
}

impl<I: SliceIndex<[u8]>> IndexMut<I> for WritablePage {
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        self.mem_mut().index_mut(index)
    }
}
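// Write cache whose buffers can be temporarily checked out: take_value() lends the Arc out
// (leaving None in the slot) and return_value() puts it back when the WritablePage is dropped.
// pop_lowest_priority() skips entries that are currently checked out.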
#[derive(Default)]
struct LRUWriteCache {
    cache: LRUCache<Option<Arc<[u8]>>>,
}

impl LRUWriteCache {
    fn new() -> Self {
        Self {
            cache: Default::default(),
        }
    }

    fn insert(&mut self, key: u64, value: Arc<[u8]>) {
        assert!(self.cache.insert(key, Some(value)).is_none());
    }

    fn get(&self, key: u64) -> Option<&Arc<[u8]>> {
        self.cache.get(key).map(|x| x.as_ref().unwrap())
    }

    fn remove(&mut self, key: u64) -> Option<Arc<[u8]>> {
        if let Some(value) = self.cache.remove(key) {
            assert!(value.is_some());
            return value;
        }
        None
    }

    fn return_value(&mut self, key: u64, value: Arc<[u8]>) {
        assert!(self.cache.get_mut(key).unwrap().replace(value).is_none());
    }

    fn take_value(&mut self, key: u64) -> Option<Arc<[u8]>> {
        if let Some(value) = self.cache.get_mut(key) {
            let result = value.take().unwrap();
            return Some(result);
        }
        None
    }

    fn pop_lowest_priority(&mut self) -> Option<(u64, Arc<[u8]>)> {
        for _ in 0..self.cache.len() {
            if let Some((k, v)) = self.cache.pop_lowest_priority() {
                if let Some(v_inner) = v {
                    return Some((k, v_inner));
                }

                // Value is borrowed by take_value(). We can't evict it, so put it back.
                self.cache.insert(k, v);
            } else {
                break;
            }
        }
        None
    }

    fn clear(&mut self) {
        self.cache.clear();
    }
}
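// Wraps the user-provided StorageBackend and latches the first I/O failure: once any
// operation has failed, all subsequent operations return StorageError::PreviousIo.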
#[derive(Debug)]
struct CheckedBackend {
    file: Box<dyn StorageBackend>,
    io_failed: AtomicBool,
}

impl CheckedBackend {
    fn new(file: Box<dyn StorageBackend>) -> Self {
        Self {
            file,
            io_failed: AtomicBool::new(false),
        }
    }

    fn check_failure(&self) -> Result<()> {
        if self.io_failed.load(Ordering::Acquire) {
            Err(StorageError::PreviousIo)
        } else {
            Ok(())
        }
    }

    fn len(&self) -> Result<u64> {
        self.check_failure()?;
        let result = self.file.len();
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn read(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
        self.check_failure()?;
        let result = self.file.read(offset, len);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn set_len(&self, len: u64) -> Result<()> {
        self.check_failure()?;
        let result = self.file.set_len(len);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn sync_data(&self, eventual: bool) -> Result<()> {
        self.check_failure()?;
        let result = self.file.sync_data(eventual);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn write(&self, offset: u64, data: &[u8]) -> Result<()> {
        self.check_failure()?;
        let result = self.file.write(offset, data);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }
}
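// A page cache layered over a StorageBackend: a striped LRU read cache behind RwLocks, plus a
// single write buffer of dirty pages that flush_write_buffer() writes out and then promotes
// into the read cache.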
pub(super) struct PagedCachedFile {
    file: CheckedBackend,
    page_size: u64,
    max_read_cache_bytes: usize,
    read_cache_bytes: AtomicUsize,
    max_write_buffer_bytes: usize,
    write_buffer_bytes: AtomicUsize,
    #[cfg(feature = "cache_metrics")]
    reads_total: AtomicU64,
    #[cfg(feature = "cache_metrics")]
    reads_hits: AtomicU64,
    read_cache: Vec<RwLock<LRUCache<Arc<[u8]>>>>,
    // TODO: maybe move this cache to WriteTransaction?
    write_buffer: Arc<Mutex<LRUWriteCache>>,
}

impl PagedCachedFile {
    pub(super) fn new(
        file: Box<dyn StorageBackend>,
        page_size: u64,
        max_read_cache_bytes: usize,
        max_write_buffer_bytes: usize,
    ) -> Result<Self, DatabaseError> {
        let read_cache = (0..Self::lock_stripes())
            .map(|_| RwLock::new(LRUCache::new()))
            .collect();

        Ok(Self {
            file: CheckedBackend::new(file),
            page_size,
            max_read_cache_bytes,
            read_cache_bytes: AtomicUsize::new(0),
            max_write_buffer_bytes,
            write_buffer_bytes: AtomicUsize::new(0),
            #[cfg(feature = "cache_metrics")]
            reads_total: Default::default(),
            #[cfg(feature = "cache_metrics")]
            reads_hits: Default::default(),
            read_cache,
            write_buffer: Arc::new(Mutex::new(LRUWriteCache::new())),
        })
    }

    pub(crate) fn check_io_errors(&self) -> Result {
        self.file.check_failure()
    }

    pub(crate) fn raw_file_len(&self) -> Result<u64> {
        self.file.len()
    }

    const fn lock_stripes() -> u64 {
        131
    }

    fn flush_write_buffer(&self) -> Result {
        let mut write_buffer = self.write_buffer.lock().unwrap();

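        // First pass: write every buffered page through to the backend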
        for (offset, buffer) in write_buffer.cache.iter() {
            self.file.write(*offset, buffer.as_ref().unwrap())?;
        }
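        // Second pass: move the flushed buffers into the read cache, stopping once the
        // read cache budget would be exceeded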
        for (offset, buffer) in write_buffer.cache.iter_mut() {
            let buffer = buffer.take().unwrap();
            let cache_size = self
                .read_cache_bytes
                .fetch_add(buffer.len(), Ordering::AcqRel);

            if cache_size + buffer.len() <= self.max_read_cache_bytes {
                let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
                let mut lock = self.read_cache[cache_slot].write().unwrap();
                if let Some(replaced) = lock.insert(*offset, buffer) {
                    // A race could cause us to replace an existing buffer
                    self.read_cache_bytes
                        .fetch_sub(replaced.len(), Ordering::AcqRel);
                }
            } else {
                self.read_cache_bytes
                    .fetch_sub(buffer.len(), Ordering::AcqRel);
                break;
            }
        }
        self.write_buffer_bytes.store(0, Ordering::Release);
        write_buffer.clear();

        Ok(())
    }

    // Caller should invalidate all cached pages that are no longer valid
    pub(super) fn resize(&self, len: u64) -> Result {
        // TODO: be more fine-grained about this invalidation
        self.invalidate_cache_all();

        self.file.set_len(len)
    }

    pub(super) fn flush(&self, #[allow(unused_variables)] eventual: bool) -> Result {
        self.flush_write_buffer()?;

        self.file.sync_data(eventual)
    }

    // Makes writes visible to readers, but does not guarantee any durability
    pub(super) fn write_barrier(&self) -> Result {
        self.flush_write_buffer()
    }

    // Read directly from the file, ignoring any cached data
    pub(super) fn read_direct(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
        self.file.read(offset, len)
    }

    // Read with caching. Caller must not read overlapping ranges without first calling invalidate_cache().
    // Doing so will not cause UB, but is a logic error.
    pub(super) fn read(&self, offset: u64, len: usize, hint: PageHint) -> Result<Arc<[u8]>> {
        debug_assert_eq!(0, offset % self.page_size);
        #[cfg(feature = "cache_metrics")]
        self.reads_total.fetch_add(1, Ordering::AcqRel);

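        // A dirty page may still be sitting in the write buffer, so consult it first unless
        // the caller has hinted that the page is clean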
        if !matches!(hint, PageHint::Clean) {
            let lock = self.write_buffer.lock().unwrap();
            if let Some(cached) = lock.get(offset) {
                #[cfg(feature = "cache_metrics")]
                self.reads_hits.fetch_add(1, Ordering::Release);
                debug_assert_eq!(cached.len(), len);
                return Ok(cached.clone());
            }
        }

        let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
        {
            let read_lock = self.read_cache[cache_slot].read().unwrap();
            if let Some(cached) = read_lock.get(offset) {
                #[cfg(feature = "cache_metrics")]
                self.reads_hits.fetch_add(1, Ordering::Release);
                debug_assert_eq!(cached.len(), len);
                return Ok(cached.clone());
            }
        }

        let buffer: Arc<[u8]> = self.read_direct(offset, len)?.into();
        let cache_size = self.read_cache_bytes.fetch_add(len, Ordering::AcqRel);
        let mut write_lock = self.read_cache[cache_slot].write().unwrap();
        let cache_size = if let Some(replaced) = write_lock.insert(offset, buffer.clone()) {
            // A race could cause us to replace an existing buffer
            self.read_cache_bytes
                .fetch_sub(replaced.len(), Ordering::AcqRel)
        } else {
            cache_size
        };
        let mut removed = 0;
        if cache_size + len > self.max_read_cache_bytes {
            while removed < len {
                if let Some((_, v)) = write_lock.pop_lowest_priority() {
                    removed += v.len();
                } else {
                    break;
                }
            }
        }
        if removed > 0 {
            self.read_cache_bytes.fetch_sub(removed, Ordering::AcqRel);
        }

        Ok(buffer)
    }

    // Discard pending writes to the given range
    pub(super) fn cancel_pending_write(&self, offset: u64, _len: usize) {
        assert_eq!(0, offset % self.page_size);
        if let Some(removed) = self.write_buffer.lock().unwrap().remove(offset) {
            self.write_buffer_bytes
                .fetch_sub(removed.len(), Ordering::Release);
        }
    }

    // Invalidate any caching of the given range. After this call overlapping reads of the range are allowed
    //
    // NOTE: Invalidating a cached region in subsections is permitted, as long as all subsections are invalidated
    pub(super) fn invalidate_cache(&self, offset: u64, len: usize) {
        let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
        let mut lock = self.read_cache[cache_slot].write().unwrap();
        if let Some(removed) = lock.remove(offset) {
            assert_eq!(len, removed.len());
            self.read_cache_bytes
                .fetch_sub(removed.len(), Ordering::AcqRel);
        }
    }

    pub(super) fn invalidate_cache_all(&self) {
        for cache_slot in 0..self.read_cache.len() {
            let mut lock = self.read_cache[cache_slot].write().unwrap();
            while let Some((_, removed)) = lock.pop_lowest_priority() {
                self.read_cache_bytes
                    .fetch_sub(removed.len(), Ordering::AcqRel);
            }
        }
    }

    // If overwrite is true, the page is initialized to zero
    // The returned WritablePage checks the page's buffer out of the write cache and returns it when dropped
    pub(super) fn write(&self, offset: u64, len: usize, overwrite: bool) -> Result<WritablePage> {
        assert_eq!(0, offset % self.page_size);
        let mut lock = self.write_buffer.lock().unwrap();

        // TODO: allow hint that page is known to be dirty and will not be in the read cache
        let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
        let existing = {
            let mut lock = self.read_cache[cache_slot].write().unwrap();
            if let Some(removed) = lock.remove(offset) {
                assert_eq!(
                    len,
                    removed.len(),
                    "cache inconsistency {len} != {} for offset {offset}",
                    removed.len()
                );
                self.read_cache_bytes
                    .fetch_sub(removed.len(), Ordering::AcqRel);
                Some(removed)
            } else {
                None
            }
        };

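        // Check the page out of the write buffer if it is already there; otherwise charge it
        // against the write buffer budget (flushing other buffered pages to the backend if the
        // budget would be exceeded) and initialize it from the read cache copy, zeroes, or the file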
        let data = if let Some(removed) = lock.take_value(offset) {
            removed
        } else {
            let previous = self.write_buffer_bytes.fetch_add(len, Ordering::AcqRel);
            if previous + len > self.max_write_buffer_bytes {
                let mut removed_bytes = 0;
                while removed_bytes < len {
                    if let Some((offset, buffer)) = lock.pop_lowest_priority() {
                        let removed_len = buffer.len();
                        let result = self.file.write(offset, &buffer);
                        if result.is_err() {
                            lock.insert(offset, buffer);
                        }
                        result?;
                        self.write_buffer_bytes
                            .fetch_sub(removed_len, Ordering::Release);
                        removed_bytes += removed_len;
                    } else {
                        break;
                    }
                }
            }
            let result = if let Some(data) = existing {
                data
            } else if overwrite {
                vec![0; len].into()
            } else {
                self.read_direct(offset, len)?.into()
            };
            lock.insert(offset, result);
            lock.take_value(offset).unwrap()
        };
        Ok(WritablePage {
            buffer: self.write_buffer.clone(),
            offset,
            data,
        })
    }
}

#[cfg(test)]
mod test {
    use crate::backends::InMemoryBackend;
    use crate::tree_store::page_store::cached_file::PagedCachedFile;
    use crate::tree_store::PageHint;
    use crate::StorageBackend;
    use std::sync::atomic::Ordering;
    use std::sync::Arc;

    // TODO: Switch to threaded wasi build for the tests
    #[cfg(not(target_os = "wasi"))]
    #[test]
    fn cache_leak() {
        let backend = InMemoryBackend::new();
        backend.set_len(1024).unwrap();
        let cached_file = PagedCachedFile::new(Box::new(backend), 128, 1024, 128).unwrap();
        let cached_file = Arc::new(cached_file);

        let t1 = {
            let cached_file = cached_file.clone();
            std::thread::spawn(move || {
                for _ in 0..1000 {
                    cached_file.read(0, 128, PageHint::None).unwrap();
                    cached_file.invalidate_cache(0, 128);
                }
            })
        };
        let t2 = {
            let cached_file = cached_file.clone();
            std::thread::spawn(move || {
                for _ in 0..1000 {
                    cached_file.read(0, 128, PageHint::None).unwrap();
                    cached_file.invalidate_cache(0, 128);
                }
            })
        };

        t1.join().unwrap();
        t2.join().unwrap();
        cached_file.invalidate_cache(0, 128);
        assert_eq!(cached_file.read_cache_bytes.load(Ordering::Acquire), 0);
    }
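
    // Sketch of an additional check (test name and scenario are illustrative, not part of the
    // original suite): data written through a WritablePage should become visible to read()
    // once write_barrier() flushes the write buffer, with no durability implied.
    #[test]
    fn write_then_read_visibility() {
        let backend = InMemoryBackend::new();
        backend.set_len(1024).unwrap();
        let cached_file = PagedCachedFile::new(Box::new(backend), 128, 1024, 128).unwrap();

        {
            // Zero-initialize the first page and fill it with a marker byte
            let mut page = cached_file.write(0, 128, true).unwrap();
            page.mem_mut().fill(0xab);
        }
        // Make the buffered write visible to readers
        cached_file.write_barrier().unwrap();

        let page = cached_file.read(0, 128, PageHint::None).unwrap();
        assert_eq!(page.as_ref(), &[0xab_u8; 128][..]);
    }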
}