redb/tree_store/page_store/cached_file.rs

use crate::tree_store::page_store::base::PageHint;
use crate::tree_store::page_store::lru_cache::LRUCache;
use crate::{CacheStats, DatabaseError, Result, StorageBackend, StorageError};
use std::ops::{Index, IndexMut};
use std::slice::SliceIndex;
#[cfg(feature = "cache_metrics")]
use std::sync::atomic::AtomicU64;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
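// A dirty page checked out of the write buffer for mutation. Dropping it
// returns the (possibly modified) buffer to the cache via return_value().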
pub(super) struct WritablePage {
    buffer: Arc<Mutex<LRUWriteCache>>,
    offset: u64,
    data: Arc<[u8]>,
}

impl WritablePage {
    pub(super) fn mem(&self) -> &[u8] {
        &self.data
    }

    pub(super) fn mem_mut(&mut self) -> &mut [u8] {
        Arc::get_mut(&mut self.data).unwrap()
    }
}

impl Drop for WritablePage {
    fn drop(&mut self) {
        self.buffer
            .lock()
            .unwrap()
            .return_value(self.offset, self.data.clone());
    }
}

impl<I: SliceIndex<[u8]>> Index<I> for WritablePage {
    type Output = I::Output;

    fn index(&self, index: I) -> &Self::Output {
        self.mem().index(index)
    }
}

impl<I: SliceIndex<[u8]>> IndexMut<I> for WritablePage {
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        self.mem_mut().index_mut(index)
    }
}
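// LRU cache of dirty page buffers. A `None` entry means the buffer is currently
// checked out as a WritablePage and must not be evicted until it is returned.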
#[derive(Default)]
struct LRUWriteCache {
    cache: LRUCache<Option<Arc<[u8]>>>,
}

impl LRUWriteCache {
    fn new() -> Self {
        Self {
            cache: Default::default(),
        }
    }

    fn insert(&mut self, key: u64, value: Arc<[u8]>) {
        assert!(self.cache.insert(key, Some(value)).is_none());
    }

    fn get(&self, key: u64) -> Option<&Arc<[u8]>> {
        self.cache.get(key).map(|x| x.as_ref().unwrap())
    }

    fn remove(&mut self, key: u64) -> Option<Arc<[u8]>> {
        if let Some(value) = self.cache.remove(key) {
            assert!(value.is_some());
            return value;
        }
        None
    }

    fn return_value(&mut self, key: u64, value: Arc<[u8]>) {
        assert!(self.cache.get_mut(key).unwrap().replace(value).is_none());
    }

    fn take_value(&mut self, key: u64) -> Option<Arc<[u8]>> {
        if let Some(value) = self.cache.get_mut(key) {
            let result = value.take().unwrap();
            return Some(result);
        }
        None
    }

    fn pop_lowest_priority(&mut self) -> Option<(u64, Arc<[u8]>)> {
        for _ in 0..self.cache.len() {
            if let Some((k, v)) = self.cache.pop_lowest_priority() {
                if let Some(v_inner) = v {
                    return Some((k, v_inner));
                }

                // Value is borrowed by take_value(). We can't evict it, so put it back.
                self.cache.insert(k, v);
            } else {
                break;
            }
        }
        None
    }

    fn clear(&mut self) {
        self.cache.clear();
    }
}
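// Wraps the user-provided StorageBackend and latches the first I/O error:
// once any operation fails, every subsequent call returns StorageError::PreviousIo.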
#[derive(Debug)]
struct CheckedBackend {
    file: Box<dyn StorageBackend>,
    io_failed: AtomicBool,
}

impl CheckedBackend {
    fn new(file: Box<dyn StorageBackend>) -> Self {
        Self {
            file,
            io_failed: AtomicBool::new(false),
        }
    }

    fn set_failure(&self) {
        self.io_failed.store(true, Ordering::Release);
    }

    fn check_failure(&self) -> Result<()> {
        if self.io_failed.load(Ordering::Acquire) {
            Err(StorageError::PreviousIo)
        } else {
            Ok(())
        }
    }

    fn len(&self) -> Result<u64> {
        self.check_failure()?;
        let result = self.file.len();
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn read(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
        self.check_failure()?;
        let result = self.file.read(offset, len);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn set_len(&self, len: u64) -> Result<()> {
        self.check_failure()?;
        let result = self.file.set_len(len);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn sync_data(&self, eventual: bool) -> Result<()> {
        self.check_failure()?;
        let result = self.file.sync_data(eventual);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn write(&self, offset: u64, data: &[u8]) -> Result<()> {
        self.check_failure()?;
        let result = self.file.write(offset, data);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }
}
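// Page-granular caching layer over the storage backend: a lock-striped LRU read
// cache plus a write buffer that batches dirty pages until the next flush or
// write barrier.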
pub(super) struct PagedCachedFile {
    file: CheckedBackend,
    page_size: u64,
    max_read_cache_bytes: usize,
    read_cache_bytes: AtomicUsize,
    max_write_buffer_bytes: usize,
    write_buffer_bytes: AtomicUsize,
    #[cfg(feature = "cache_metrics")]
    reads_total: AtomicU64,
    #[cfg(feature = "cache_metrics")]
    reads_hits: AtomicU64,
    #[cfg(feature = "cache_metrics")]
    evictions: AtomicU64,
    read_cache: Vec<RwLock<LRUCache<Arc<[u8]>>>>,
    // TODO: maybe move this cache to WriteTransaction?
    write_buffer: Arc<Mutex<LRUWriteCache>>,
}

impl PagedCachedFile {
    pub(super) fn new(
        file: Box<dyn StorageBackend>,
        page_size: u64,
        max_read_cache_bytes: usize,
        max_write_buffer_bytes: usize,
    ) -> Result<Self, DatabaseError> {
        let read_cache = (0..Self::lock_stripes())
            .map(|_| RwLock::new(LRUCache::new()))
            .collect();

        Ok(Self {
            file: CheckedBackend::new(file),
            page_size,
            max_read_cache_bytes,
            read_cache_bytes: AtomicUsize::new(0),
            max_write_buffer_bytes,
            write_buffer_bytes: AtomicUsize::new(0),
            #[cfg(feature = "cache_metrics")]
            reads_total: Default::default(),
            #[cfg(feature = "cache_metrics")]
            reads_hits: Default::default(),
            #[cfg(feature = "cache_metrics")]
            evictions: Default::default(),
            read_cache,
            write_buffer: Arc::new(Mutex::new(LRUWriteCache::new())),
        })
    }

    #[allow(clippy::unused_self)]
    pub(crate) fn cache_stats(&self) -> CacheStats {
        CacheStats {
            #[cfg(not(feature = "cache_metrics"))]
            evictions: 0,
            #[cfg(feature = "cache_metrics")]
            evictions: self.evictions.load(Ordering::Acquire),
        }
    }

    pub(crate) fn check_io_errors(&self) -> Result {
        self.file.check_failure()
    }

    pub(crate) fn set_irrecoverable_io_error(&self) {
        self.file.set_failure();
    }

    pub(crate) fn raw_file_len(&self) -> Result<u64> {
        self.file.len()
    }
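    // Number of lock stripes in the read cache; a page offset maps to a stripe
    // via `offset % lock_stripes()`.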
    const fn lock_stripes() -> u64 {
        131
    }
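    // Write every buffered dirty page back to the backend, then promote the
    // buffers into the read cache while the read-cache budget allows.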
    fn flush_write_buffer(&self) -> Result {
        let mut write_buffer = self.write_buffer.lock().unwrap();

        for (offset, buffer) in write_buffer.cache.iter() {
            self.file.write(*offset, buffer.as_ref().unwrap())?;
        }
        for (offset, buffer) in write_buffer.cache.iter_mut() {
            let buffer = buffer.take().unwrap();
            let cache_size = self
                .read_cache_bytes
                .fetch_add(buffer.len(), Ordering::AcqRel);

            if cache_size + buffer.len() <= self.max_read_cache_bytes {
                let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
                let mut lock = self.read_cache[cache_slot].write().unwrap();
                if let Some(replaced) = lock.insert(*offset, buffer) {
                    // A race could cause us to replace an existing buffer
                    self.read_cache_bytes
                        .fetch_sub(replaced.len(), Ordering::AcqRel);
                }
            } else {
                self.read_cache_bytes
                    .fetch_sub(buffer.len(), Ordering::AcqRel);
                break;
            }
        }
        self.write_buffer_bytes.store(0, Ordering::Release);
        write_buffer.clear();

        Ok(())
    }

    // Caller should invalidate all cached pages that are no longer valid
    pub(super) fn resize(&self, len: u64) -> Result {
        // TODO: be more fine-grained about this invalidation
        self.invalidate_cache_all();

        self.file.set_len(len)
    }

    pub(super) fn flush(&self, #[allow(unused_variables)] eventual: bool) -> Result {
        self.flush_write_buffer()?;

        self.file.sync_data(eventual)
    }

    // Makes writes visible to readers, but does not guarantee any durability
    pub(super) fn write_barrier(&self) -> Result {
        self.flush_write_buffer()
    }

    // Read directly from the file, ignoring any cached data
    pub(super) fn read_direct(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
        self.file.read(offset, len)
    }

    // Read with caching. Caller must not read overlapping ranges without first calling invalidate_cache().
    // Doing so will not cause UB, but is a logic error.
    pub(super) fn read(&self, offset: u64, len: usize, hint: PageHint) -> Result<Arc<[u8]>> {
        debug_assert_eq!(0, offset % self.page_size);
        #[cfg(feature = "cache_metrics")]
        self.reads_total.fetch_add(1, Ordering::AcqRel);

        if !matches!(hint, PageHint::Clean) {
            let lock = self.write_buffer.lock().unwrap();
            if let Some(cached) = lock.get(offset) {
                #[cfg(feature = "cache_metrics")]
                self.reads_hits.fetch_add(1, Ordering::Release);
                debug_assert_eq!(cached.len(), len);
                return Ok(cached.clone());
            }
        }

        let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
        {
            let read_lock = self.read_cache[cache_slot].read().unwrap();
            if let Some(cached) = read_lock.get(offset) {
                #[cfg(feature = "cache_metrics")]
                self.reads_hits.fetch_add(1, Ordering::Release);
                debug_assert_eq!(cached.len(), len);
                return Ok(cached.clone());
            }
        }
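        // Cache miss: read the page from the backend, publish it in the read
        // cache, and evict lowest-priority entries if the cache has grown past
        // max_read_cache_bytes.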
        let buffer: Arc<[u8]> = self.read_direct(offset, len)?.into();
        let cache_size = self.read_cache_bytes.fetch_add(len, Ordering::AcqRel);
        let mut write_lock = self.read_cache[cache_slot].write().unwrap();
        let cache_size = if let Some(replaced) = write_lock.insert(offset, buffer.clone()) {
            // A race could cause us to replace an existing buffer
            self.read_cache_bytes
                .fetch_sub(replaced.len(), Ordering::AcqRel)
        } else {
            cache_size
        };
        let mut removed = 0;
        if cache_size + len > self.max_read_cache_bytes {
            while removed < len {
                if let Some((_, v)) = write_lock.pop_lowest_priority() {
                    #[cfg(feature = "cache_metrics")]
                    {
                        self.evictions.fetch_add(1, Ordering::Relaxed);
                    }
                    removed += v.len();
                } else {
                    break;
                }
            }
        }
        if removed > 0 {
            self.read_cache_bytes.fetch_sub(removed, Ordering::AcqRel);
        }

        Ok(buffer)
    }

    // Discard pending writes to the given range
    pub(super) fn cancel_pending_write(&self, offset: u64, _len: usize) {
        assert_eq!(0, offset % self.page_size);
        if let Some(removed) = self.write_buffer.lock().unwrap().remove(offset) {
            self.write_buffer_bytes
                .fetch_sub(removed.len(), Ordering::Release);
        }
    }

    // Invalidate any caching of the given range. After this call, overlapping reads of the range are allowed
    //
    // NOTE: Invalidating a cached region in subsections is permitted, as long as all subsections are invalidated
    pub(super) fn invalidate_cache(&self, offset: u64, len: usize) {
        let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
        let mut lock = self.read_cache[cache_slot].write().unwrap();
        if let Some(removed) = lock.remove(offset) {
            assert_eq!(len, removed.len());
            self.read_cache_bytes
                .fetch_sub(removed.len(), Ordering::AcqRel);
        }
    }
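    // Drop every cached read page and release its accounting in read_cache_bytes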
    pub(super) fn invalidate_cache_all(&self) {
        for cache_slot in 0..self.read_cache.len() {
            let mut lock = self.read_cache[cache_slot].write().unwrap();
            while let Some((_, removed)) = lock.pop_lowest_priority() {
                self.read_cache_bytes
                    .fetch_sub(removed.len(), Ordering::AcqRel);
            }
        }
    }

    // If overwrite is true, the page is initialized to zero
    pub(super) fn write(&self, offset: u64, len: usize, overwrite: bool) -> Result<WritablePage> {
        assert_eq!(0, offset % self.page_size);
        let mut lock = self.write_buffer.lock().unwrap();

        // TODO: allow hint that page is known to be dirty and will not be in the read cache
        let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
        let existing = {
            let mut lock = self.read_cache[cache_slot].write().unwrap();
            if let Some(removed) = lock.remove(offset) {
                assert_eq!(
                    len,
                    removed.len(),
                    "cache inconsistency {len} != {} for offset {offset}",
                    removed.len()
                );
                self.read_cache_bytes
                    .fetch_sub(removed.len(), Ordering::AcqRel);
                Some(removed)
            } else {
                None
            }
        };
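        // Reuse the buffer if this page is already in the write buffer. Otherwise,
        // evict lowest-priority dirty pages to the backend when over budget, then
        // seed the new buffer from the read-cache copy, zeroes (overwrite), or a
        // direct read.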
        let data = if let Some(removed) = lock.take_value(offset) {
            removed
        } else {
            let previous = self.write_buffer_bytes.fetch_add(len, Ordering::AcqRel);
            if previous + len > self.max_write_buffer_bytes {
                let mut removed_bytes = 0;
                while removed_bytes < len {
                    if let Some((offset, buffer)) = lock.pop_lowest_priority() {
                        let removed_len = buffer.len();
                        let result = self.file.write(offset, &buffer);
                        if result.is_err() {
                            lock.insert(offset, buffer);
                        }
                        result?;
                        self.write_buffer_bytes
                            .fetch_sub(removed_len, Ordering::Release);
                        #[cfg(feature = "cache_metrics")]
                        {
                            self.evictions.fetch_add(1, Ordering::Relaxed);
                        }
                        removed_bytes += removed_len;
                    } else {
                        break;
                    }
                }
            }
            let result = if let Some(data) = existing {
                data
            } else if overwrite {
                vec![0; len].into()
            } else {
                self.read_direct(offset, len)?.into()
            };
            lock.insert(offset, result);
            lock.take_value(offset).unwrap()
        };
        Ok(WritablePage {
            buffer: self.write_buffer.clone(),
            offset,
            data,
        })
    }
}

#[cfg(test)]
mod test {
    use crate::StorageBackend;
    use crate::backends::InMemoryBackend;
    use crate::tree_store::PageHint;
    use crate::tree_store::page_store::cached_file::PagedCachedFile;
    use std::sync::Arc;
    use std::sync::atomic::Ordering;
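    // Regression test: concurrent read() and invalidate_cache() of the same page
    // must not leak read_cache_bytes accounting.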
    #[test]
    fn cache_leak() {
        let backend = InMemoryBackend::new();
        backend.set_len(1024).unwrap();
        let cached_file = PagedCachedFile::new(Box::new(backend), 128, 1024, 128).unwrap();
        let cached_file = Arc::new(cached_file);

        let t1 = {
            let cached_file = cached_file.clone();
            std::thread::spawn(move || {
                for _ in 0..1000 {
                    cached_file.read(0, 128, PageHint::None).unwrap();
                    cached_file.invalidate_cache(0, 128);
                }
            })
        };
        let t2 = {
            let cached_file = cached_file.clone();
            std::thread::spawn(move || {
                for _ in 0..1000 {
                    cached_file.read(0, 128, PageHint::None).unwrap();
                    cached_file.invalidate_cache(0, 128);
                }
            })
        };

        t1.join().unwrap();
        t2.join().unwrap();
        cached_file.invalidate_cache(0, 128);
        assert_eq!(cached_file.read_cache_bytes.load(Ordering::Acquire), 0);
    }
}