redb/tree_store/page_store/cached_file.rs

use crate::tree_store::page_store::base::PageHint;
use crate::tree_store::page_store::lru_cache::LRUCache;
use crate::{CacheStats, DatabaseError, Result, StorageBackend, StorageError};
use std::ops::{Index, IndexMut};
use std::slice::SliceIndex;
#[cfg(feature = "cache_metrics")]
use std::sync::atomic::AtomicU64;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};

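// A mutable page buffer checked out of the shared write buffer. While this handle is alive,
// the underlying `Arc<[u8]>` is taken out of its `LRUWriteCache` slot (so `mem_mut()` holds
// the only reference); `Drop` returns the buffer to the cache via `return_value()`.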
pub(super) struct WritablePage {
    buffer: Arc<Mutex<LRUWriteCache>>,
    offset: u64,
    data: Arc<[u8]>,
}

impl WritablePage {
    pub(super) fn mem(&self) -> &[u8] {
        &self.data
    }

    pub(super) fn mem_mut(&mut self) -> &mut [u8] {
        Arc::get_mut(&mut self.data).unwrap()
    }
}

impl Drop for WritablePage {
    fn drop(&mut self) {
        self.buffer
            .lock()
            .unwrap()
            .return_value(self.offset, self.data.clone());
    }
}

impl<I: SliceIndex<[u8]>> Index<I> for WritablePage {
    type Output = I::Output;

    fn index(&self, index: I) -> &Self::Output {
        self.mem().index(index)
    }
}

impl<I: SliceIndex<[u8]>> IndexMut<I> for WritablePage {
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        self.mem_mut().index_mut(index)
    }
}

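// Write buffer keyed by page offset. A slot holding `None` marks a buffer that is currently
// checked out through `take_value()` (i.e. borrowed by a live `WritablePage`); it cannot be
// evicted until `return_value()` puts it back.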
#[derive(Default)]
struct LRUWriteCache {
    cache: LRUCache<Option<Arc<[u8]>>>,
}

impl LRUWriteCache {
    fn new() -> Self {
        Self {
            cache: Default::default(),
        }
    }

    fn insert(&mut self, key: u64, value: Arc<[u8]>) {
        assert!(self.cache.insert(key, Some(value)).is_none());
    }

    fn get(&self, key: u64) -> Option<&Arc<[u8]>> {
        self.cache.get(key).map(|x| x.as_ref().unwrap())
    }

    fn remove(&mut self, key: u64) -> Option<Arc<[u8]>> {
        if let Some(value) = self.cache.remove(key) {
            assert!(value.is_some());
            return value;
        }
        None
    }

    fn return_value(&mut self, key: u64, value: Arc<[u8]>) {
        assert!(self.cache.get_mut(key).unwrap().replace(value).is_none());
    }

    fn take_value(&mut self, key: u64) -> Option<Arc<[u8]>> {
        if let Some(value) = self.cache.get_mut(key) {
            let result = value.take().unwrap();
            return Some(result);
        }
        None
    }

    fn pop_lowest_priority(&mut self) -> Option<(u64, Arc<[u8]>)> {
        for _ in 0..self.cache.len() {
            if let Some((k, v)) = self.cache.pop_lowest_priority() {
                if let Some(v_inner) = v {
                    return Some((k, v_inner));
                }

                // Value is borrowed by take_value(). We can't evict it, so put it back.
                self.cache.insert(k, v);
            } else {
                break;
            }
        }
        None
    }

    fn clear(&mut self) {
        self.cache.clear();
    }
}

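// Wraps the user-provided `StorageBackend` and latches the first I/O failure: once any
// operation fails, every later call returns `StorageError::PreviousIo` instead of touching
// the backend again.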
#[derive(Debug)]
struct CheckedBackend {
    file: Box<dyn StorageBackend>,
    io_failed: AtomicBool,
}

impl CheckedBackend {
    fn new(file: Box<dyn StorageBackend>) -> Self {
        Self {
            file,
            io_failed: AtomicBool::new(false),
        }
    }

    fn check_failure(&self) -> Result<()> {
        if self.io_failed.load(Ordering::Acquire) {
            Err(StorageError::PreviousIo)
        } else {
            Ok(())
        }
    }

    fn len(&self) -> Result<u64> {
        self.check_failure()?;
        let result = self.file.len();
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn read(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
        self.check_failure()?;
        let result = self.file.read(offset, len);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn set_len(&self, len: u64) -> Result<()> {
        self.check_failure()?;
        let result = self.file.set_len(len);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn sync_data(&self, eventual: bool) -> Result<()> {
        self.check_failure()?;
        let result = self.file.sync_data(eventual);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn write(&self, offset: u64, data: &[u8]) -> Result<()> {
        self.check_failure()?;
        let result = self.file.write(offset, data);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }
}

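// Page cache layered over a `StorageBackend`: a lock-striped LRU read cache plus a single
// write buffer whose contents are written out and promoted into the read cache on flush.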
pub(super) struct PagedCachedFile {
    file: CheckedBackend,
    page_size: u64,
    max_read_cache_bytes: usize,
    read_cache_bytes: AtomicUsize,
    max_write_buffer_bytes: usize,
    write_buffer_bytes: AtomicUsize,
    #[cfg(feature = "cache_metrics")]
    reads_total: AtomicU64,
    #[cfg(feature = "cache_metrics")]
    reads_hits: AtomicU64,
    #[cfg(feature = "cache_metrics")]
    evictions: AtomicU64,
    read_cache: Vec<RwLock<LRUCache<Arc<[u8]>>>>,
    // TODO: maybe move this cache to WriteTransaction?
    write_buffer: Arc<Mutex<LRUWriteCache>>,
}

impl PagedCachedFile {
    pub(super) fn new(
        file: Box<dyn StorageBackend>,
        page_size: u64,
        max_read_cache_bytes: usize,
        max_write_buffer_bytes: usize,
    ) -> Result<Self, DatabaseError> {
        let read_cache = (0..Self::lock_stripes())
            .map(|_| RwLock::new(LRUCache::new()))
            .collect();

        Ok(Self {
            file: CheckedBackend::new(file),
            page_size,
            max_read_cache_bytes,
            read_cache_bytes: AtomicUsize::new(0),
            max_write_buffer_bytes,
            write_buffer_bytes: AtomicUsize::new(0),
            #[cfg(feature = "cache_metrics")]
            reads_total: Default::default(),
            #[cfg(feature = "cache_metrics")]
            reads_hits: Default::default(),
            #[cfg(feature = "cache_metrics")]
            evictions: Default::default(),
            read_cache,
            write_buffer: Arc::new(Mutex::new(LRUWriteCache::new())),
        })
    }

    #[allow(clippy::unused_self)]
    pub(crate) fn cache_stats(&self) -> CacheStats {
        CacheStats {
            #[cfg(not(feature = "cache_metrics"))]
            evictions: 0,
            #[cfg(feature = "cache_metrics")]
            evictions: self.evictions.load(Ordering::Acquire),
        }
    }

    pub(crate) fn check_io_errors(&self) -> Result {
        self.file.check_failure()
    }

    pub(crate) fn raw_file_len(&self) -> Result<u64> {
        self.file.len()
    }

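    // Number of stripes in the read cache. 131 is prime, which helps spread page-aligned
    // offsets across the stripes rather than clustering them on a few slots.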
    const fn lock_stripes() -> u64 {
        131
    }

    fn flush_write_buffer(&self) -> Result {
        let mut write_buffer = self.write_buffer.lock().unwrap();

        for (offset, buffer) in write_buffer.cache.iter() {
            self.file.write(*offset, buffer.as_ref().unwrap())?;
        }
        for (offset, buffer) in write_buffer.cache.iter_mut() {
            let buffer = buffer.take().unwrap();
            let cache_size = self
                .read_cache_bytes
                .fetch_add(buffer.len(), Ordering::AcqRel);

            if cache_size + buffer.len() <= self.max_read_cache_bytes {
                let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
                let mut lock = self.read_cache[cache_slot].write().unwrap();
                if let Some(replaced) = lock.insert(*offset, buffer) {
                    // A race could cause us to replace an existing buffer
                    self.read_cache_bytes
                        .fetch_sub(replaced.len(), Ordering::AcqRel);
                }
            } else {
                self.read_cache_bytes
                    .fetch_sub(buffer.len(), Ordering::AcqRel);
                break;
            }
        }
        self.write_buffer_bytes.store(0, Ordering::Release);
        write_buffer.clear();

        Ok(())
    }

    // Caller should invalidate all cached pages that are no longer valid
    pub(super) fn resize(&self, len: u64) -> Result {
        // TODO: be more fine-grained about this invalidation
        self.invalidate_cache_all();

        self.file.set_len(len)
    }

    pub(super) fn flush(&self, #[allow(unused_variables)] eventual: bool) -> Result {
        self.flush_write_buffer()?;

        self.file.sync_data(eventual)
    }

    // Make writes visible to readers, but does not guarantee any durability
    pub(super) fn write_barrier(&self) -> Result {
        self.flush_write_buffer()
    }

    // Read directly from the file, ignoring any cached data
    pub(super) fn read_direct(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
        self.file.read(offset, len)
    }

    // Read with caching. Caller must not read overlapping ranges without first calling invalidate_cache().
    // Doing so will not cause UB, but is a logic error.
    pub(super) fn read(&self, offset: u64, len: usize, hint: PageHint) -> Result<Arc<[u8]>> {
        debug_assert_eq!(0, offset % self.page_size);
        #[cfg(feature = "cache_metrics")]
        self.reads_total.fetch_add(1, Ordering::AcqRel);

        if !matches!(hint, PageHint::Clean) {
            let lock = self.write_buffer.lock().unwrap();
            if let Some(cached) = lock.get(offset) {
                #[cfg(feature = "cache_metrics")]
                self.reads_hits.fetch_add(1, Ordering::Release);
                debug_assert_eq!(cached.len(), len);
                return Ok(cached.clone());
            }
        }

        let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
        {
            let read_lock = self.read_cache[cache_slot].read().unwrap();
            if let Some(cached) = read_lock.get(offset) {
                #[cfg(feature = "cache_metrics")]
                self.reads_hits.fetch_add(1, Ordering::Release);
                debug_assert_eq!(cached.len(), len);
                return Ok(cached.clone());
            }
        }

        let buffer: Arc<[u8]> = self.read_direct(offset, len)?.into();
        let cache_size = self.read_cache_bytes.fetch_add(len, Ordering::AcqRel);
        let mut write_lock = self.read_cache[cache_slot].write().unwrap();
        let cache_size = if let Some(replaced) = write_lock.insert(offset, buffer.clone()) {
            // A race could cause us to replace an existing buffer
            self.read_cache_bytes
                .fetch_sub(replaced.len(), Ordering::AcqRel)
        } else {
            cache_size
        };
        let mut removed = 0;
        if cache_size + len > self.max_read_cache_bytes {
            while removed < len {
                if let Some((_, v)) = write_lock.pop_lowest_priority() {
                    #[cfg(feature = "cache_metrics")]
                    {
                        self.evictions.fetch_add(1, Ordering::Relaxed);
                    }
                    removed += v.len();
                } else {
                    break;
                }
            }
        }
        if removed > 0 {
            self.read_cache_bytes.fetch_sub(removed, Ordering::AcqRel);
        }

        Ok(buffer)
    }

    // Discard pending writes to the given range
    pub(super) fn cancel_pending_write(&self, offset: u64, _len: usize) {
        assert_eq!(0, offset % self.page_size);
        if let Some(removed) = self.write_buffer.lock().unwrap().remove(offset) {
            self.write_buffer_bytes
                .fetch_sub(removed.len(), Ordering::Release);
        }
    }

    // Invalidate any caching of the given range. After this call overlapping reads of the range are allowed
    //
    // NOTE: Invalidating a cached region in subsections is permitted, as long as all subsections are invalidated
    pub(super) fn invalidate_cache(&self, offset: u64, len: usize) {
        let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
        let mut lock = self.read_cache[cache_slot].write().unwrap();
        if let Some(removed) = lock.remove(offset) {
            assert_eq!(len, removed.len());
            self.read_cache_bytes
                .fetch_sub(removed.len(), Ordering::AcqRel);
        }
    }

    pub(super) fn invalidate_cache_all(&self) {
        for cache_slot in 0..self.read_cache.len() {
            let mut lock = self.read_cache[cache_slot].write().unwrap();
            while let Some((_, removed)) = lock.pop_lowest_priority() {
                self.read_cache_bytes
                    .fetch_sub(removed.len(), Ordering::AcqRel);
            }
        }
    }

    // If overwrite is true, the page is initialized to zero instead of being read from the backend
    pub(super) fn write(&self, offset: u64, len: usize, overwrite: bool) -> Result<WritablePage> {
        assert_eq!(0, offset % self.page_size);
        let mut lock = self.write_buffer.lock().unwrap();

        // TODO: allow hint that page is known to be dirty and will not be in the read cache
        let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
        let existing = {
            let mut lock = self.read_cache[cache_slot].write().unwrap();
            if let Some(removed) = lock.remove(offset) {
                assert_eq!(
                    len,
                    removed.len(),
                    "cache inconsistency {len} != {} for offset {offset}",
                    removed.len()
                );
                self.read_cache_bytes
                    .fetch_sub(removed.len(), Ordering::AcqRel);
                Some(removed)
            } else {
                None
            }
        };

        let data = if let Some(removed) = lock.take_value(offset) {
            removed
        } else {
            let previous = self.write_buffer_bytes.fetch_add(len, Ordering::AcqRel);
            if previous + len > self.max_write_buffer_bytes {
                let mut removed_bytes = 0;
                while removed_bytes < len {
                    if let Some((offset, buffer)) = lock.pop_lowest_priority() {
                        let removed_len = buffer.len();
                        let result = self.file.write(offset, &buffer);
                        if result.is_err() {
                            lock.insert(offset, buffer);
                        }
                        result?;
                        self.write_buffer_bytes
                            .fetch_sub(removed_len, Ordering::Release);
                        #[cfg(feature = "cache_metrics")]
                        {
                            self.evictions.fetch_add(1, Ordering::Relaxed);
                        }
                        removed_bytes += removed_len;
                    } else {
                        break;
                    }
                }
            }
            let result = if let Some(data) = existing {
                data
            } else if overwrite {
                vec![0; len].into()
            } else {
                self.read_direct(offset, len)?.into()
            };
            lock.insert(offset, result);
            lock.take_value(offset).unwrap()
        };
        Ok(WritablePage {
            buffer: self.write_buffer.clone(),
            offset,
            data,
        })
    }
}

#[cfg(test)]
mod test {
    use crate::StorageBackend;
    use crate::backends::InMemoryBackend;
    use crate::tree_store::PageHint;
    use crate::tree_store::page_store::cached_file::PagedCachedFile;
    use std::sync::Arc;
    use std::sync::atomic::Ordering;

    #[test]
    fn cache_leak() {
        let backend = InMemoryBackend::new();
        backend.set_len(1024).unwrap();
        let cached_file = PagedCachedFile::new(Box::new(backend), 128, 1024, 128).unwrap();
        let cached_file = Arc::new(cached_file);

        let t1 = {
            let cached_file = cached_file.clone();
            std::thread::spawn(move || {
                for _ in 0..1000 {
                    cached_file.read(0, 128, PageHint::None).unwrap();
                    cached_file.invalidate_cache(0, 128);
                }
            })
        };
        let t2 = {
            let cached_file = cached_file.clone();
            std::thread::spawn(move || {
                for _ in 0..1000 {
                    cached_file.read(0, 128, PageHint::None).unwrap();
                    cached_file.invalidate_cache(0, 128);
                }
            })
        };

        t1.join().unwrap();
        t2.join().unwrap();
        cached_file.invalidate_cache(0, 128);
        assert_eq!(cached_file.read_cache_bytes.load(Ordering::Acquire), 0);
    }
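
    // Minimal usage sketch of the write path: check out a WritablePage, mutate it in place,
    // drop it to return the buffer to the write cache, then write_barrier() flushes it to the
    // backend and promotes it into the read cache. The test name and constants are illustrative.
    #[test]
    fn write_then_read() {
        let backend = InMemoryBackend::new();
        backend.set_len(1024).unwrap();
        let cached_file = PagedCachedFile::new(Box::new(backend), 128, 1024, 1024).unwrap();

        {
            // overwrite=true: the page is zero-initialized instead of read from the backend
            let mut page = cached_file.write(0, 128, true).unwrap();
            page[0] = 42;
            // dropping `page` returns its buffer to the write cache
        }
        cached_file.write_barrier().unwrap();

        let data = cached_file.read(0, 128, PageHint::None).unwrap();
        assert_eq!(data[0], 42);
    }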
}