use crate::transaction_tracker::{SavepointId, TransactionId, TransactionTracker};
use crate::tree_store::{
AllPageNumbersBtreeIter, BtreeHeader, BtreeRangeIter, FreedPageList, FreedTableKey,
InternalTableDefinition, PageHint, PageNumber, RawBtree, SerializedSavepoint, TableTreeMut,
TableType, TransactionalMemory, PAGE_SIZE,
};
use crate::types::{Key, Value};
use crate::{CompactionError, DatabaseError, Error, ReadOnlyTable, SavepointError, StorageError};
use crate::{ReadTransaction, Result, WriteTransaction};
use std::fmt::{Debug, Display, Formatter};
use std::fs::{File, OpenOptions};
use std::io::ErrorKind;
use std::marker::PhantomData;
use std::ops::RangeFull;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::{io, thread};
use crate::error::TransactionError;
use crate::sealed::Sealed;
use crate::transactions::{
AllocatorStateKey, AllocatorStateTree, ALLOCATOR_STATE_TABLE_NAME, SAVEPOINT_TABLE,
};
use crate::tree_store::file_backend::FileBackend;
#[cfg(feature = "logging")]
use log::{debug, info, warn};
/// Abstraction over the byte storage a database lives in.
///
/// Implementors must be `Debug`, `Send`, `Sync` and `'static`. Every operation takes `&self`
/// and reports failure as an [`io::Error`].
// `len` deliberately has no `is_empty` companion, hence the lint exception.
#[allow(clippy::len_without_is_empty)]
pub trait StorageBackend: Debug + Send + Sync + 'static {
    /// Returns the current length of the storage.
    fn len(&self) -> io::Result<u64>;

    /// Reads `len` bytes starting at `offset` and returns them as an owned buffer.
    fn read(&self, offset: u64, len: usize) -> io::Result<Vec<u8>>;

    /// Sets the length of the storage to `len`.
    fn set_len(&self, len: u64) -> io::Result<()>;

    /// Synchronizes written data with the underlying storage.
    ///
    /// NOTE(review): `eventual` presumably permits a weaker, deferred form of durability;
    /// confirm the exact contract with the concrete backends.
    fn sync_data(&self, eventual: bool) -> io::Result<()>;

    /// Writes `data` starting at `offset`.
    fn write(&self, offset: u64, data: &[u8]) -> io::Result<()>;
}
/// A handle that identifies a table by name.
///
/// Requires the crate's `Sealed` supertrait (presumably to prevent implementations outside
/// this crate — confirm against `crate::sealed`).
pub trait TableHandle: Sealed {
/// Returns the name of the table this handle refers to.
fn name(&self) -> &str;
}
/// A table handle that carries only the table's name, without key or value type information.
#[derive(Clone)]
pub struct UntypedTableHandle {
    // Owned name of the referenced table.
    name: String,
}

impl UntypedTableHandle {
    /// Wraps an owned table `name` in a handle.
    pub(crate) fn new(name: String) -> Self {
        UntypedTableHandle { name }
    }
}
impl TableHandle for UntypedTableHandle {
    /// Returns the stored table name.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
// Marker impl required by the `TableHandle: Sealed` supertrait bound.
impl Sealed for UntypedTableHandle {}
/// A handle that identifies a multimap table by name.
///
/// Requires the crate's `Sealed` supertrait (presumably to prevent implementations outside
/// this crate — confirm against `crate::sealed`).
pub trait MultimapTableHandle: Sealed {
/// Returns the name of the multimap table this handle refers to.
fn name(&self) -> &str;
}
/// A multimap table handle that carries only the table's name, without key or value types.
#[derive(Clone)]
pub struct UntypedMultimapTableHandle {
    // Owned name of the referenced multimap table.
    name: String,
}

impl UntypedMultimapTableHandle {
    /// Wraps an owned multimap table `name` in a handle.
    pub(crate) fn new(name: String) -> Self {
        UntypedMultimapTableHandle { name }
    }
}
impl MultimapTableHandle for UntypedMultimapTableHandle {
    /// Returns the stored multimap table name.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
// Marker impl required by the `MultimapTableHandle: Sealed` supertrait bound.
impl Sealed for UntypedMultimapTableHandle {}
/// Names a table and fixes its key type `K` and value type `V` at compile time.
///
/// Only the name is stored; the type parameters are carried through zero-sized
/// `PhantomData` markers.
pub struct TableDefinition<'a, K: Key + 'static, V: Value + 'static> {
// Borrowed table name; checked to be non-empty in `new`.
name: &'a str,
// Type-only markers for the key and value types; no data is stored for them.
_key_type: PhantomData<K>,
_value_type: PhantomData<V>,
}
impl<'a, K: Key + 'static, V: Value + 'static> TableDefinition<'a, K, V> {
pub const fn new(name: &'a str) -> Self {
assert!(!name.is_empty());
Self {
name,
_key_type: PhantomData,
_value_type: PhantomData,
}
}
}
impl<'a, K: Key + 'static, V: Value + 'static> TableHandle for TableDefinition<'a, K, V> {
/// Returns the table name supplied to `TableDefinition::new`.
fn name(&self) -> &str {
self.name
}
}
// Marker impl required by the `TableHandle: Sealed` supertrait bound.
impl<K: Key, V: Value> Sealed for TableDefinition<'_, K, V> {}
// Written by hand rather than derived: `#[derive(Clone)]` would add `K: Clone` and
// `V: Clone` bounds, which are unnecessary because only `PhantomData` markers are stored.
impl<'a, K: Key + 'static, V: Value + 'static> Clone for TableDefinition<'a, K, V> {
/// Returns a bitwise copy; the type is `Copy`.
fn clone(&self) -> Self {
*self
}
}
// The definition holds only a `&str` and `PhantomData` markers, so it is `Copy` for any `K`/`V`.
impl<'a, K: Key + 'static, V: Value + 'static> Copy for TableDefinition<'a, K, V> {}
impl<'a, K: Key + 'static, V: Value + 'static> Display for TableDefinition<'a, K, V> {
    /// Formats the definition as `name<KeyType, ValueType>`, using the names reported by
    /// `K::type_name()` and `V::type_name()`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let key_type = K::type_name();
        let value_type = V::type_name();
        write!(
            f,
            "{}<{}, {}>",
            self.name,
            key_type.name(),
            value_type.name()
        )
    }
}
/// Names a multimap table and fixes its key type `K` and value type `V` at compile time.
///
/// Unlike `TableDefinition`, the value type is bounded by `Key` rather than `Value`.
/// Only the name is stored; the type parameters are carried through `PhantomData` markers.
pub struct MultimapTableDefinition<'a, K: Key + 'static, V: Key + 'static> {
// Borrowed table name; checked to be non-empty in `new`.
name: &'a str,
// Type-only markers for the key and value types; no data is stored for them.
_key_type: PhantomData<K>,
_value_type: PhantomData<V>,
}
impl<'a, K: Key + 'static, V: Key + 'static> MultimapTableDefinition<'a, K, V> {
pub const fn new(name: &'a str) -> Self {
assert!(!name.is_empty());
Self {
name,
_key_type: PhantomData,
_value_type: PhantomData,
}
}
}
impl<'a, K: Key + 'static, V: Key + 'static> MultimapTableHandle
for MultimapTableDefinition<'a, K, V>
{
/// Returns the table name supplied to `MultimapTableDefinition::new`.
fn name(&self) -> &str {
self.name
}
}
// Marker impl required by the `MultimapTableHandle: Sealed` supertrait bound.
impl<K: Key, V: Key> Sealed for MultimapTableDefinition<'_, K, V> {}
// Written by hand rather than derived: `#[derive(Clone)]` would add `K: Clone` and
// `V: Clone` bounds, which are unnecessary because only `PhantomData` markers are stored.
impl<'a, K: Key + 'static, V: Key + 'static> Clone for MultimapTableDefinition<'a, K, V> {
/// Returns a bitwise copy; the type is `Copy`.
fn clone(&self) -> Self {
*self
}
}
// The definition holds only a `&str` and `PhantomData` markers, so it is `Copy` for any `K`/`V`.
impl<'a, K: Key + 'static, V: Key + 'static> Copy for MultimapTableDefinition<'a, K, V> {}
impl<'a, K: Key + 'static, V: Key + 'static> Display for MultimapTableDefinition<'a, K, V> {
    /// Formats the definition as `name<KeyType, ValueType>`, using the names reported by
    /// `K::type_name()` and `V::type_name()`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let key_type = K::type_name();
        let value_type = V::type_name();
        write!(
            f,
            "{}<{}, {}>",
            self.name,
            key_type.name(),
            value_type.name()
        )
    }
}
/// Guard tying a transaction id to the tracker it was registered with.
///
/// When dropped, the guard notifies the tracker that the transaction has ended, unless it
/// is a "fake" guard (no tracker) or its id was taken by `leak`.
pub(crate) struct TransactionGuard {
// `None` for guards created by `fake()`; such guards do nothing on drop.
transaction_tracker: Option<Arc<TransactionTracker>>,
// `None` for fake guards, and after `leak()` has taken the id.
transaction_id: Option<TransactionId>,
// Selects which tracker method is called on drop: write vs. read deregistration.
write_transaction: bool,
}
impl TransactionGuard {
pub(crate) fn new_read(
transaction_id: TransactionId,
tracker: Arc<TransactionTracker>,
) -> Self {
Self {
transaction_tracker: Some(tracker),
transaction_id: Some(transaction_id),
write_transaction: false,
}
}
pub(crate) fn new_write(
transaction_id: TransactionId,
tracker: Arc<TransactionTracker>,
) -> Self {
Self {
transaction_tracker: Some(tracker),
transaction_id: Some(transaction_id),
write_transaction: true,
}
}
pub(crate) fn fake() -> Self {
Self {
transaction_tracker: None,
transaction_id: None,
write_transaction: false,
}
}
pub(crate) fn id(&self) -> TransactionId {
self.transaction_id.unwrap()
}
pub(crate) fn leak(mut self) -> TransactionId {
self.transaction_id.take().unwrap()
}
}
impl Drop for TransactionGuard {
    /// Tells the tracker that the guarded transaction has ended.
    ///
    /// Nothing happens for fake guards (no tracker) or leaked guards (no id).
    fn drop(&mut self) {
        let Some(tracker) = self.transaction_tracker.as_ref() else {
            return;
        };
        let Some(transaction_id) = self.transaction_id else {
            return;
        };
        if self.write_transaction {
            tracker.end_write_transaction(transaction_id);
        } else {
            tracker.deallocate_read_transaction(transaction_id);
        }
    }
}
/// An open database.
///
/// Transactions are started with `begin_read` and `begin_write`; instances are created
/// through `create`, `open`, or a `Builder`.
pub struct Database {
// Shared handle to the storage layer; cloned into every transaction.
mem: Arc<TransactionalMemory>,
// Tracks live read/write transactions and savepoints for this database.
transaction_tracker: Arc<TransactionTracker>,
}
impl Database {
pub fn create(path: impl AsRef<Path>) -> Result<Database, DatabaseError> {
Self::builder().create(path)
}
pub fn open(path: impl AsRef<Path>) -> Result<Database, DatabaseError> {
Self::builder().open(path)
}
/// Returns another shared handle to the database's storage layer.
pub(crate) fn get_memory(&self) -> Arc<TransactionalMemory> {
    Arc::clone(&self.mem)
}
/// Verifies the checksums of the data tree, the system tree and (if present) the freed tree.
///
/// Returns `Ok(true)` only when all of them verify. The trees are checked in that order
/// and the first failure short-circuits with `Ok(false)`.
///
/// # Panics
///
/// Panics if verification recorded any freed pages, since it must not modify the trees.
pub(crate) fn verify_primary_checksums(mem: Arc<TransactionalMemory>) -> Result<bool> {
    // Shared scratch list handed to both table trees; it must stay empty.
    let fake_freed_pages = Arc::new(Mutex::new(vec![]));

    let data_tree = TableTreeMut::new(
        mem.get_data_root(),
        Arc::new(TransactionGuard::fake()),
        mem.clone(),
        fake_freed_pages.clone(),
    );
    if !data_tree.verify_checksums()? {
        return Ok(false);
    }

    let system_tree = TableTreeMut::new(
        mem.get_system_root(),
        Arc::new(TransactionGuard::fake()),
        mem.clone(),
        fake_freed_pages.clone(),
    );
    if !system_tree.verify_checksums()? {
        return Ok(false);
    }
    assert!(fake_freed_pages.lock().unwrap().is_empty());

    // A missing freed tree counts as valid.
    let freed_tree_valid = match mem.get_freed_root() {
        Some(header) => RawBtree::new(
            Some(header),
            FreedTableKey::fixed_width(),
            FreedPageList::fixed_width(),
            mem.clone(),
        )
        .verify_checksum()?,
        None => true,
    };
    Ok(freed_tree_valid)
}
/// Re-checks the database and rebuilds the allocator state from the on-disk trees.
///
/// Returns `Ok(true)` if nothing had to change. Returns `Ok(false)` if the cache reload
/// reported an unclean state, or if the repair produced different roots or a different
/// allocator hash; in that case the repaired roots are committed as a new transaction.
///
/// # Errors
///
/// Returns a storage error if reloading, repairing or committing fails.
///
/// # Panics
///
/// Panics if any other reference to the underlying `TransactionalMemory` is alive
/// (`Arc::get_mut` returns `None`), for example while a transaction still holds one.
pub fn check_integrity(&mut self) -> Result<bool, DatabaseError> {
// Snapshot the allocator hash so a change caused by the repair can be detected below.
let allocator_hash = self.mem.allocator_hash();
// Needs exclusive access to the memory: this unwrap is the panic documented above.
let mut was_clean = Arc::get_mut(&mut self.mem)
.unwrap()
.clear_cache_and_reload()?;
let old_roots = [
self.mem.get_data_root(),
self.mem.get_system_root(),
self.mem.get_freed_root(),
];
// The no-op callback never aborts, so `do_repair` can only fail with a storage error
// here; every other `DatabaseError` variant is unreachable.
let new_roots = Self::do_repair(&mut self.mem, &|_| {}).map_err(|err| match err {
DatabaseError::Storage(storage_err) => storage_err,
_ => unreachable!(),
})?;
if old_roots != new_roots || allocator_hash != self.mem.allocator_hash() {
was_clean = false;
}
if !was_clean {
let next_transaction_id = self.mem.get_last_committed_transaction_id()?.next();
let [data_root, system_root, freed_root] = new_roots;
// NOTE(review): the two trailing booleans are flags defined by
// `TransactionalMemory::commit`, which is not visible in this file.
self.mem.commit(
data_root,
system_root,
freed_root,
next_transaction_id,
false,
true,
)?;
}
self.mem.begin_writable()?;
Ok(was_clean)
}
/// Compacts the database file.
///
/// Returns `Ok(true)` if at least one compaction pass made progress, `Ok(false)` otherwise.
///
/// # Errors
///
/// * `CompactionError::TransactionInProgress` if a read transaction is still live.
/// * `CompactionError::PersistentSavepointExists` if any persistent savepoint exists.
/// * `CompactionError::EphemeralSavepointExists` if the tracker reports any savepoint.
/// * A storage error if any of the internal transactions fails.
///
/// # Panics
///
/// Panics if a freed tree root is still present after the flushing commits.
pub fn compact(&mut self) -> Result<bool, CompactionError> {
    let reader_is_live = self
        .transaction_tracker
        .oldest_live_read_transaction()
        .is_some();
    if reader_is_live {
        return Err(CompactionError::TransactionInProgress);
    }

    // The first write transaction doubles as the savepoint check.
    let mut txn = self.begin_write().map_err(|e| e.into_storage_error())?;
    if txn.list_persistent_savepoints()?.next().is_some() {
        return Err(CompactionError::PersistentSavepointExists);
    }
    if self.transaction_tracker.any_savepoint_exists() {
        return Err(CompactionError::EphemeralSavepointExists);
    }
    txn.set_two_phase_commit(true);
    txn.commit().map_err(|e| e.into_storage_error())?;

    // A second empty commit; afterwards the freed tree is expected to be gone.
    let mut txn = self.begin_write().map_err(|e| e.into_storage_error())?;
    txn.set_two_phase_commit(true);
    txn.commit().map_err(|e| e.into_storage_error())?;
    assert!(self.mem.get_freed_root().is_none());

    let mut compacted = false;
    loop {
        // One compaction pass: commit only if pages were actually moved.
        let mut txn = self.begin_write().map_err(|e| e.into_storage_error())?;
        let progress = txn.compact_pages()?;
        if progress {
            txn.commit().map_err(|e| e.into_storage_error())?;
        } else {
            txn.abort()?;
        }

        // Follow every pass with an empty commit so the freed tree is emptied again.
        let mut txn = self.begin_write().map_err(|e| e.into_storage_error())?;
        txn.set_two_phase_commit(true);
        txn.commit().map_err(|e| e.into_storage_error())?;
        assert!(self.mem.get_freed_root().is_none());

        if !progress {
            break;
        }
        compacted = true;
    }
    Ok(compacted)
}
/// Asserts that every page reachable from a persistent savepoint's user root is marked
/// as allocated.
///
/// Does nothing when the system tree has no savepoint table.
///
/// # Errors
///
/// Returns a storage error if the savepoint table cannot be read;
/// a table-level error is reported as corruption.
///
/// # Panics
///
/// Panics (via the recursive check) if a referenced page is not allocated.
fn check_repaired_persistent_savepoints(
    system_root: Option<BtreeHeader>,
    mem: Arc<TransactionalMemory>,
) -> Result {
    let system_tables = TableTreeMut::new(
        system_root,
        Arc::new(TransactionGuard::fake()),
        mem.clone(),
        Arc::new(Mutex::new(vec![])),
    );
    // Savepoints must be attached to a tracker; a throwaway one is sufficient here.
    let fake_transaction_tracker = Arc::new(TransactionTracker::new(TransactionId::new(0)));

    let lookup = system_tables
        .get_table::<SavepointId, SerializedSavepoint>(
            SAVEPOINT_TABLE.name(),
            TableType::Normal,
        )
        .map_err(|e| {
            e.into_storage_error_or_corrupted("Persistent savepoint table corrupted")
        })?;
    let Some(savepoint_table_def) = lookup else {
        return Ok(());
    };
    // The lookup asked for `TableType::Normal`, so no other variant can come back.
    let InternalTableDefinition::Normal {
        table_root: savepoint_table_root,
        ..
    } = savepoint_table_def
    else {
        unreachable!()
    };

    let savepoint_table: ReadOnlyTable<SavepointId, SerializedSavepoint> = ReadOnlyTable::new(
        "internal savepoint table".to_string(),
        savepoint_table_root,
        PageHint::None,
        Arc::new(TransactionGuard::fake()),
        mem.clone(),
    )?;
    for result in savepoint_table.range::<SavepointId>(..)? {
        let (_, savepoint_data) = result?;
        let savepoint = savepoint_data
            .value()
            .to_savepoint(fake_transaction_tracker.clone());
        if let Some(header) = savepoint.get_user_root() {
            Self::check_pages_allocated_recursive(header.root, mem.clone())?;
        }
    }
    Ok(())
}
/// Marks as allocated both the pages that make up the freed tree and every page number
/// stored in its entries.
///
/// # Errors
///
/// Returns a storage error if the tree cannot be walked or read.
fn mark_freed_tree(freed_root: Option<BtreeHeader>, mem: Arc<TransactionalMemory>) -> Result {
    // First the pages of the tree structure itself...
    if let Some(header) = freed_root {
        let tree_pages = AllPageNumbersBtreeIter::new(
            header.root,
            FreedTableKey::fixed_width(),
            FreedPageList::fixed_width(),
            mem.clone(),
        )?;
        for tree_page in tree_pages {
            mem.mark_page_allocated(tree_page?);
        }
    }

    // ...then every page number listed inside the tree's values.
    let freed_table: ReadOnlyTable<FreedTableKey, FreedPageList<'static>> = ReadOnlyTable::new(
        "internal freed table".to_string(),
        freed_root,
        PageHint::None,
        Arc::new(TransactionGuard::fake()),
        mem.clone(),
    )?;
    for entry in freed_table.range::<FreedTableKey>(..)? {
        let (_, list_guard) = entry?;
        let page_list = list_guard.value();
        for index in 0..page_list.len() {
            mem.mark_page_allocated(page_list.get(index));
        }
    }
    Ok(())
}
/// Asserts that every page of the table tree rooted at `root`, and every page of each
/// table defined in it, is marked as allocated.
///
/// # Errors
///
/// Returns a storage error if a tree cannot be walked.
///
/// # Panics
///
/// Panics on the first page that is not allocated.
fn check_pages_allocated_recursive(root: PageNumber, mem: Arc<TransactionalMemory>) -> Result {
    // Pages of the master (table-of-tables) tree.
    for master_page in AllPageNumbersBtreeIter::new(root, None, None, mem.clone())? {
        let page_number = master_page?;
        assert!(mem.is_allocated(page_number));
    }

    // Pages of every table listed in the master tree.
    let tables: BtreeRangeIter<&str, InternalTableDefinition> =
        BtreeRangeIter::new::<RangeFull, &str>(&(..), Some(root), mem.clone())?;
    for table in tables {
        let definition = table?.value();
        definition.visit_all_pages(mem.clone(), |path| {
            assert!(mem.is_allocated(path.page_number()));
            Ok(())
        })?;
    }
    Ok(())
}
/// Marks as allocated every page of the table tree rooted at `root` and every page of
/// each table defined in it.
///
/// # Errors
///
/// Returns a storage error if a tree cannot be walked.
fn mark_tables_recursive(root: PageNumber, mem: Arc<TransactionalMemory>) -> Result {
    // Pages of the master (table-of-tables) tree.
    for master_page in AllPageNumbersBtreeIter::new(root, None, None, mem.clone())? {
        mem.mark_page_allocated(master_page?);
    }

    // Pages of every table listed in the master tree.
    let tables: BtreeRangeIter<&str, InternalTableDefinition> =
        BtreeRangeIter::new::<RangeFull, &str>(&(..), Some(root), mem.clone())?;
    for table in tables {
        let definition = table?.value();
        definition.visit_all_pages(mem.clone(), |path| {
            mem.mark_page_allocated(path.page_number());
            Ok(())
        })?;
    }
    Ok(())
}
/// Rebuilds the allocator state by walking every tree reachable from the current roots.
///
/// Steps, in order:
/// 1. Verify the primary checksums. If they fail, roll back via
///    `repair_primary_corrupted` and verify again.
/// 2. Enter repair mode and mark all pages of the data tree and the freed tree.
/// 3. Mark all pages of the system tree, then leave repair mode.
///
/// `repair_callback` is invoked before the rollback (progress `0.3`), before marking
/// starts (`0.6`) and before the system tree is marked (`0.9`). Each invocation may abort.
///
/// Returns the `[data, system, freed]` roots that were repaired against.
///
/// # Errors
///
/// * `DatabaseError::RepairAborted` if the callback aborted the session.
/// * `StorageError::Corrupted` if the primary is corrupted although two-phase commit was
///   used, or if the checksums still fail after the rollback.
/// * Any storage error raised while walking the trees.
fn do_repair(
    mem: &mut Arc<TransactionalMemory>,
    repair_callback: &(dyn Fn(&mut RepairSession) + 'static),
) -> Result<[Option<BtreeHeader>; 3], DatabaseError> {
    if !Self::verify_primary_checksums(mem.clone())? {
        if mem.used_two_phase_commit() {
            return Err(DatabaseError::Storage(StorageError::Corrupted(
                "Primary is corrupted despite 2-phase commit".to_string(),
            )));
        }

        let mut handle = RepairSession::new(0.3);
        repair_callback(&mut handle);
        if handle.aborted() {
            return Err(DatabaseError::RepairAborted);
        }

        mem.repair_primary_corrupted();
        // The failed verification above read pages through the cache; drop them so the
        // second verification sees the rolled-back state.
        mem.clear_read_cache();
        if !Self::verify_primary_checksums(mem.clone())? {
            return Err(DatabaseError::Storage(StorageError::Corrupted(
                "Failed to repair database. All roots are corrupted".to_string(),
            )));
        }
    }

    let mut handle = RepairSession::new(0.6);
    repair_callback(&mut handle);
    if handle.aborted() {
        return Err(DatabaseError::RepairAborted);
    }

    mem.begin_repair()?;

    let data_root = mem.get_data_root();
    if let Some(header) = data_root {
        Self::mark_tables_recursive(header.root, mem.clone())?;
    }

    // `mark_freed_tree` opens the freed table itself, so no separate table handle is
    // constructed here.
    let freed_root = mem.get_freed_root();
    Self::mark_freed_tree(freed_root, mem.clone())?;

    let mut handle = RepairSession::new(0.9);
    repair_callback(&mut handle);
    if handle.aborted() {
        return Err(DatabaseError::RepairAborted);
    }

    let system_root = mem.get_system_root();
    if let Some(header) = system_root {
        Self::mark_tables_recursive(header.root, mem.clone())?;
    }
    // Debug builds additionally assert that pages referenced by persistent savepoints
    // ended up allocated.
    #[cfg(debug_assertions)]
    {
        Self::check_repaired_persistent_savepoints(system_root, mem.clone())?;
    }

    mem.end_repair()?;
    mem.clear_read_cache();

    Ok([data_root, system_root, freed_root])
}
/// Opens a database on top of `file`, repairing it first if it needs repair.
///
/// When repair is needed and a valid allocator state table is found, that state is loaded
/// and the full repair is skipped. Otherwise `repair_callback` is invoked with progress
/// `0.0`, a full repair runs, and its roots are committed. Finally the savepoint counter
/// and all persistent savepoints are restored into the transaction tracker.
///
/// # Errors
///
/// * `DatabaseError::RepairAborted` if the callback aborted the repair.
/// * A storage error if opening, repairing, or reading the savepoints fails.
fn new(
    file: Box<dyn StorageBackend>,
    page_size: usize,
    region_size: Option<u64>,
    read_cache_size_bytes: usize,
    write_cache_size_bytes: usize,
    repair_callback: &(dyn Fn(&mut RepairSession) + 'static),
) -> Result<Self, DatabaseError> {
    #[cfg(feature = "logging")]
    let file_path = format!("{:?}", &file);
    #[cfg(feature = "logging")]
    info!("Opening database {:?}", &file_path);

    let mut mem = Arc::new(TransactionalMemory::new(
        file,
        page_size,
        region_size,
        read_cache_size_bytes,
        write_cache_size_bytes,
    )?);

    if mem.needs_repair()? {
        match Self::get_allocator_state_table(&mem)? {
            Some(tree) => {
                #[cfg(feature = "logging")]
                info!("Found valid allocator state, full repair not needed");
                mem.load_allocator_state(&tree)?;
            }
            None => {
                #[cfg(feature = "logging")]
                warn!("Database {:?} not shutdown cleanly. Repairing", &file_path);
                // Give the callback a chance to abort before any work is done.
                let mut session = RepairSession::new(0.0);
                repair_callback(&mut session);
                if session.aborted() {
                    return Err(DatabaseError::RepairAborted);
                }
                let [data_root, system_root, freed_root] =
                    Self::do_repair(&mut mem, repair_callback)?;
                let next_transaction_id = mem.get_last_committed_transaction_id()?.next();
                mem.commit(
                    data_root,
                    system_root,
                    freed_root,
                    next_transaction_id,
                    false,
                    true,
                )?;
            }
        }
    }

    mem.begin_writable()?;
    let next_transaction_id = mem.get_last_committed_transaction_id()?.next();

    let db = Database {
        mem,
        transaction_tracker: Arc::new(TransactionTracker::new(next_transaction_id)),
    };

    // Restore the tracker's savepoint state; the transaction is used only for reading
    // and is aborted afterwards.
    let txn = db.begin_write().map_err(|e| e.into_storage_error())?;
    if let Some(next_id) = txn.next_persistent_savepoint_id()? {
        db.transaction_tracker
            .restore_savepoint_counter_state(next_id);
    }
    for id in txn.list_persistent_savepoints()? {
        let savepoint = match txn.get_persistent_savepoint(id) {
            Ok(savepoint) => savepoint,
            // The id was just listed, so it must exist.
            Err(SavepointError::InvalidSavepoint) => unreachable!(),
            Err(SavepointError::Storage(storage)) => return Err(storage.into()),
        };
        db.transaction_tracker
            .register_persistent_savepoint(&savepoint);
    }
    txn.abort()?;

    Ok(db)
}
/// Looks up the allocator state table in the system tree and returns it if it is valid.
///
/// Returns `Ok(None)` when the last commit did not use two-phase commit, when the table
/// does not exist, or when `is_valid_allocator_state` rejects it.
///
/// # Errors
///
/// Returns a storage error if the lookup or validation fails;
/// a table-level error is reported as corruption.
fn get_allocator_state_table(
    mem: &Arc<TransactionalMemory>,
) -> Result<Option<AllocatorStateTree>> {
    if !mem.used_two_phase_commit() {
        return Ok(None);
    }

    let fake_freed_pages = Arc::new(Mutex::new(vec![]));
    let system_tables = TableTreeMut::new(
        mem.get_system_root(),
        Arc::new(TransactionGuard::fake()),
        mem.clone(),
        fake_freed_pages.clone(),
    );
    let definition = system_tables
        .get_table::<AllocatorStateKey, &[u8]>(ALLOCATOR_STATE_TABLE_NAME, TableType::Normal)
        .map_err(|e| e.into_storage_error_or_corrupted("Unexpected TableError"))?;
    let table_root = match definition {
        None => return Ok(None),
        Some(InternalTableDefinition::Normal { table_root, .. }) => table_root,
        // The lookup asked for `TableType::Normal`, so no other variant can come back.
        Some(_) => unreachable!(),
    };

    let tree = AllocatorStateTree::new(
        table_root,
        Arc::new(TransactionGuard::fake()),
        mem.clone(),
        fake_freed_pages,
    );
    if mem.is_valid_allocator_state(&tree)? {
        Ok(Some(tree))
    } else {
        Ok(None)
    }
}
fn allocate_read_transaction(&self) -> Result<TransactionGuard> {
let id = self
.transaction_tracker
.register_read_transaction(&self.mem)?;
Ok(TransactionGuard::new_read(
id,
self.transaction_tracker.clone(),
))
}
/// Returns a `Builder` with default settings, for configuring how a database is opened.
pub fn builder() -> Builder {
Builder::new()
}
/// Begins a write transaction.
///
/// # Errors
///
/// Fails if the storage layer reports an earlier I/O error (`check_io_errors`), or if the
/// transaction itself cannot be created.
pub fn begin_write(&self) -> Result<WriteTransaction, TransactionError> {
    self.mem.check_io_errors()?;
    let transaction_id = self.transaction_tracker.start_write_transaction();
    let guard = TransactionGuard::new_write(transaction_id, self.transaction_tracker.clone());
    WriteTransaction::new(guard, self.transaction_tracker.clone(), self.mem.clone())
        .map_err(Into::into)
}
pub fn begin_read(&self) -> Result<ReadTransaction, TransactionError> {
let guard = self.allocate_read_transaction()?;
#[cfg(feature = "logging")]
debug!("Beginning read transaction id={:?}", guard.id());
ReadTransaction::new(self.get_memory(), guard)
}
/// Makes sure a valid allocator state table exists, writing one if necessary.
///
/// When the table is missing or invalid, an otherwise empty write transaction with
/// quick-repair enabled is committed.
///
/// # Errors
///
/// Returns an error if the lookup, the transaction start, or the commit fails.
fn ensure_allocator_state_table(&self) -> Result<(), Error> {
    let needs_write = Self::get_allocator_state_table(&self.mem)?.is_none();
    if needs_write {
        #[cfg(feature = "logging")]
        debug!("Writing allocator state table");
        let mut tx = self.begin_write()?;
        tx.set_quick_repair(true);
        tx.commit()?;
    }
    Ok(())
}
}
impl Drop for Database {
    /// Tries to persist the allocator state when the database is closed.
    ///
    /// Skipped entirely while the thread is panicking. A failure is only logged (when the
    /// `logging` feature is enabled), never propagated.
    fn drop(&mut self) {
        if !thread::panicking() && self.ensure_allocator_state_table().is_err() {
            #[cfg(feature = "logging")]
            warn!("Failed to write allocator state table. Repair may be required at restart.");
        }
    }
}
/// State passed to the repair callback: a progress estimate, plus a flag the callback can
/// set to abort the repair.
pub struct RepairSession {
    // Progress value supplied by the caller of `new`.
    progress: f64,
    // Set once `abort` has been called.
    aborted: bool,
}

impl RepairSession {
    /// Creates a session that reports `progress` and has not been aborted.
    pub(crate) fn new(progress: f64) -> Self {
        RepairSession {
            aborted: false,
            progress,
        }
    }

    /// Returns `true` once `abort` has been called on this session.
    pub(crate) fn aborted(&self) -> bool {
        self.aborted
    }

    /// Requests that the repair be aborted.
    pub fn abort(&mut self) {
        self.aborted = true;
    }

    /// Returns the progress value this session was created with.
    pub fn progress(&self) -> f64 {
        self.progress
    }
}
/// Configuration for opening or creating a `Database`.
pub struct Builder {
// Page size handed to `Database::new`; defaults to `PAGE_SIZE`.
page_size: usize,
// Optional region size; only settable in test/fuzzing builds.
region_size: Option<u64>,
// Cache budgets derived from `set_cache_size` (roughly 90% read / 10% write).
read_cache_size_bytes: usize,
write_cache_size_bytes: usize,
// Invoked during repair; defaults to a no-op.
repair_callback: Box<dyn Fn(&mut RepairSession)>,
}
impl Builder {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
let mut result = Self {
page_size: PAGE_SIZE,
region_size: None,
read_cache_size_bytes: 0,
write_cache_size_bytes: 0,
repair_callback: Box::new(|_| {}),
};
result.set_cache_size(1024 * 1024 * 1024);
result
}
pub fn set_repair_callback(
&mut self,
callback: impl Fn(&mut RepairSession) + 'static,
) -> &mut Self {
self.repair_callback = Box::new(callback);
self
}
#[cfg(any(fuzzing, test))]
pub fn set_page_size(&mut self, size: usize) -> &mut Self {
assert!(size.is_power_of_two());
self.page_size = std::cmp::max(size, 512);
self
}
pub fn set_cache_size(&mut self, bytes: usize) -> &mut Self {
self.read_cache_size_bytes = bytes / 10 * 9;
self.write_cache_size_bytes = bytes / 10;
self
}
#[cfg(any(test, fuzzing))]
pub fn set_region_size(&mut self, size: u64) -> &mut Self {
assert!(size.is_power_of_two());
self.region_size = Some(size);
self
}
pub fn create(&self, path: impl AsRef<Path>) -> Result<Database, DatabaseError> {
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(path)?;
Database::new(
Box::new(FileBackend::new(file)?),
self.page_size,
self.region_size,
self.read_cache_size_bytes,
self.write_cache_size_bytes,
&self.repair_callback,
)
}
pub fn open(&self, path: impl AsRef<Path>) -> Result<Database, DatabaseError> {
let file = OpenOptions::new().read(true).write(true).open(path)?;
if file.metadata()?.len() == 0 {
return Err(StorageError::Io(ErrorKind::InvalidData.into()).into());
}
Database::new(
Box::new(FileBackend::new(file)?),
self.page_size,
None,
self.read_cache_size_bytes,
self.write_cache_size_bytes,
&self.repair_callback,
)
}
pub fn create_file(&self, file: File) -> Result<Database, DatabaseError> {
Database::new(
Box::new(FileBackend::new(file)?),
self.page_size,
self.region_size,
self.read_cache_size_bytes,
self.write_cache_size_bytes,
&self.repair_callback,
)
}
pub fn create_with_backend(
&self,
backend: impl StorageBackend,
) -> Result<Database, DatabaseError> {
Database::new(
Box::new(backend),
self.page_size,
self.region_size,
self.read_cache_size_bytes,
self.write_cache_size_bytes,
&self.repair_callback,
)
}
}
impl std::fmt::Debug for Database {
    /// Writes just the type name; no fields are exposed.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("Database")
    }
}
#[cfg(test)]
mod test {
use crate::backends::FileBackend;
use crate::{
CommitError, Database, DatabaseError, Durability, ReadableTable, StorageBackend,
StorageError, TableDefinition, TransactionError,
};
use std::fs::File;
use std::io::{ErrorKind, Read, Seek, SeekFrom};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Test backend that wraps a `FileBackend` and starts failing once a shared countdown
/// reaches zero.
#[derive(Debug)]
struct FailingBackend {
    inner: FileBackend,
    // Shared so that a test can reset or exhaust the countdown from outside.
    countdown: Arc<AtomicU64>,
}

impl FailingBackend {
    /// Wraps `backend` with an initial budget of `countdown`.
    fn new(backend: FileBackend, countdown: u64) -> Self {
        let countdown = Arc::new(AtomicU64::new(countdown));
        FailingBackend {
            inner: backend,
            countdown,
        }
    }

    /// Fails with `ErrorKind::Other` if the countdown is exhausted; does not consume it.
    fn check_countdown(&self) -> Result<(), std::io::Error> {
        match self.countdown.load(Ordering::SeqCst) {
            0 => Err(std::io::Error::from(ErrorKind::Other)),
            _ => Ok(()),
        }
    }

    /// Consumes one unit of the countdown, failing with `ErrorKind::Other` if none is left.
    fn decrement_countdown(&self) -> Result<(), std::io::Error> {
        self.countdown
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| {
                remaining.checked_sub(1)
            })
            .map(|_| ())
            .map_err(|_| std::io::Error::from(ErrorKind::Other))
    }
}
// Delegates to the wrapped backend. Only `write` consumes the countdown; `read` and
// `sync_data` fail once it is exhausted; `len` and `set_len` are never made to fail.
impl StorageBackend for FailingBackend {
fn len(&self) -> Result<u64, std::io::Error> {
self.inner.len()
}
fn read(&self, offset: u64, len: usize) -> Result<Vec<u8>, std::io::Error> {
self.check_countdown()?;
self.inner.read(offset, len)
}
fn set_len(&self, len: u64) -> Result<(), std::io::Error> {
self.inner.set_len(len)
}
fn sync_data(&self, eventual: bool) -> Result<(), std::io::Error> {
self.check_countdown()?;
self.inner.sync_data(eventual)
}
fn write(&self, offset: u64, data: &[u8]) -> Result<(), std::io::Error> {
self.decrement_countdown()?;
self.inner.write(offset, data)
}
}
// Regression test: a commit fails because the backend runs out of its write budget, and
// the file must still be openable afterwards through a regular `create`.
#[test]
fn crash_regression4() {
let tmpfile = crate::create_tempfile();
let (file, path) = tmpfile.into_parts();
// The backend allows 23 writes before every further write fails.
let backend = FailingBackend::new(FileBackend::new(file).unwrap(), 23);
let db = Database::builder()
.set_cache_size(12686)
.set_page_size(8 * 1024)
.set_region_size(32 * 4096)
.create_with_backend(backend)
.unwrap();
let table_def: TableDefinition<u64, &[u8]> = TableDefinition::new("x");
let tx = db.begin_write().unwrap();
let _savepoint = tx.ephemeral_savepoint().unwrap();
let _persistent_savepoint = tx.persistent_savepoint().unwrap();
tx.commit().unwrap();
let tx = db.begin_write().unwrap();
{
let mut table = tx.open_table(table_def).unwrap();
let _ = table.insert_reserve(118821, 360).unwrap();
}
// This commit is expected to hit the exhausted write budget.
let result = tx.commit();
assert!(result.is_err());
drop(db);
// Re-opening with a healthy file backend must succeed.
Database::builder()
.set_cache_size(1024 * 1024)
.set_page_size(8 * 1024)
.set_region_size(32 * 4096)
.create(&path)
.unwrap();
}
// After a commit fails with an I/O error, `begin_write` must keep failing with
// `PreviousIo`, and the file header must carry a flag once the database is dropped.
#[test]
fn transient_io_error() {
let tmpfile = crate::create_tempfile();
let (file, path) = tmpfile.into_parts();
// Start with an effectively unlimited budget; the shared counter is adjusted below.
let backend = FailingBackend::new(FileBackend::new(file).unwrap(), u64::MAX);
let countdown = backend.countdown.clone();
let db = Database::builder()
.set_cache_size(0)
.create_with_backend(backend)
.unwrap();
let table_def: TableDefinition<u64, u64> = TableDefinition::new("x");
let tx = db.begin_write().unwrap();
{
let mut table = tx.open_table(table_def).unwrap();
table.insert(0, 0).unwrap();
}
tx.commit().unwrap();
let tx = db.begin_write().unwrap();
{
let mut table = tx.open_table(table_def).unwrap();
table.insert(0, 1).unwrap();
}
tx.commit().unwrap();
let tx = db.begin_write().unwrap();
// Exhaust the budget so the backend fails during the commit.
countdown.store(0, Ordering::SeqCst);
let result = tx.commit().err().unwrap();
assert!(matches!(result, CommitError::Storage(StorageError::Io(_))));
let result = db.begin_write().err().unwrap();
assert!(matches!(
result,
TransactionError::Storage(StorageError::PreviousIo)
));
// Restore the budget so that only the earlier error influences shutdown.
countdown.store(u64::MAX, Ordering::SeqCst);
drop(db);
// Inspect the byte at offset 9 of the file and require its bit 0x2 to be set.
// NOTE(review): presumably this is the header's "recovery required" flag — confirm
// against the file format definition.
let mut file = File::open(&path).unwrap();
file.seek(SeekFrom::Start(9)).unwrap();
let mut god_byte = vec![0u8];
assert_eq!(file.read(&mut god_byte).unwrap(), 1);
assert_ne!(god_byte[0] & 2, 0);
}
// Opening a table and committing must work with the minimum 512-byte page size.
#[test]
fn small_pages() {
let tmpfile = crate::create_tempfile();
let db = Database::builder()
.set_page_size(512)
.create(tmpfile.path())
.unwrap();
let table_definition: TableDefinition<u64, &[u8]> = TableDefinition::new("x");
let txn = db.begin_write().unwrap();
{
txn.open_table(table_definition).unwrap();
}
txn.commit().unwrap();
}
// Exercises a fixed sequence of ephemeral savepoint creation, restore and drop across
// two-phase-commit transactions on 512-byte pages. The exact order of operations is the
// point of the test (presumably a reduced fuzzer case), so it must not be reordered.
#[test]
fn small_pages2() {
let tmpfile = crate::create_tempfile();
let db = Database::builder()
.set_page_size(512)
.create(tmpfile.path())
.unwrap();
let table_def: TableDefinition<u64, &[u8]> = TableDefinition::new("x");
let mut tx = db.begin_write().unwrap();
tx.set_two_phase_commit(true);
let savepoint0 = tx.ephemeral_savepoint().unwrap();
{
tx.open_table(table_def).unwrap();
}
tx.commit().unwrap();
let mut tx = db.begin_write().unwrap();
tx.set_two_phase_commit(true);
let savepoint1 = tx.ephemeral_savepoint().unwrap();
tx.restore_savepoint(&savepoint0).unwrap();
tx.set_durability(Durability::None);
{
let mut t = tx.open_table(table_def).unwrap();
t.insert_reserve(&660503, 489).unwrap().as_mut().fill(0xFF);
assert!(t.remove(&291295).unwrap().is_none());
}
tx.commit().unwrap();
let mut tx = db.begin_write().unwrap();
tx.set_two_phase_commit(true);
tx.restore_savepoint(&savepoint0).unwrap();
{
tx.open_table(table_def).unwrap();
}
tx.commit().unwrap();
let mut tx = db.begin_write().unwrap();
tx.set_two_phase_commit(true);
let savepoint2 = tx.ephemeral_savepoint().unwrap();
drop(savepoint0);
tx.restore_savepoint(&savepoint2).unwrap();
{
let mut t = tx.open_table(table_def).unwrap();
assert!(t.get(&2059).unwrap().is_none());
assert!(t.remove(&145227).unwrap().is_none());
assert!(t.remove(&145227).unwrap().is_none());
}
tx.commit().unwrap();
let mut tx = db.begin_write().unwrap();
tx.set_two_phase_commit(true);
let savepoint3 = tx.ephemeral_savepoint().unwrap();
drop(savepoint1);
tx.restore_savepoint(&savepoint3).unwrap();
{
tx.open_table(table_def).unwrap();
}
tx.commit().unwrap();
let mut tx = db.begin_write().unwrap();
tx.set_two_phase_commit(true);
let savepoint4 = tx.ephemeral_savepoint().unwrap();
drop(savepoint2);
tx.restore_savepoint(&savepoint3).unwrap();
tx.set_durability(Durability::None);
{
let mut t = tx.open_table(table_def).unwrap();
assert!(t.remove(&207936).unwrap().is_none());
}
// This transaction is aborted, not committed.
tx.abort().unwrap();
let mut tx = db.begin_write().unwrap();
tx.set_two_phase_commit(true);
let savepoint5 = tx.ephemeral_savepoint().unwrap();
drop(savepoint3);
// savepoint4 was created inside the aborted transaction above; restoring it must fail.
assert!(tx.restore_savepoint(&savepoint4).is_err());
{
tx.open_table(table_def).unwrap();
}
tx.commit().unwrap();
let mut tx = db.begin_write().unwrap();
tx.set_two_phase_commit(true);
tx.restore_savepoint(&savepoint5).unwrap();
tx.set_durability(Durability::None);
{
tx.open_table(table_def).unwrap();
}
tx.commit().unwrap();
}
// Two aborted non-durable transactions with savepoints on 1024-byte pages; the second
// inserts a value (2008 bytes) larger than a page. The test passes if nothing panics.
#[test]
fn small_pages3() {
let tmpfile = crate::create_tempfile();
let db = Database::builder()
.set_page_size(1024)
.create(tmpfile.path())
.unwrap();
let table_def: TableDefinition<u64, &[u8]> = TableDefinition::new("x");
let mut tx = db.begin_write().unwrap();
let _savepoint0 = tx.ephemeral_savepoint().unwrap();
tx.set_durability(Durability::None);
{
let mut t = tx.open_table(table_def).unwrap();
let value = vec![0; 306];
t.insert(&539717, value.as_slice()).unwrap();
}
tx.abort().unwrap();
let mut tx = db.begin_write().unwrap();
let savepoint1 = tx.ephemeral_savepoint().unwrap();
// Restores a savepoint taken in this same transaction.
tx.restore_savepoint(&savepoint1).unwrap();
tx.set_durability(Durability::None);
{
let mut t = tx.open_table(table_def).unwrap();
let value = vec![0; 2008];
t.insert(&784384, value.as_slice()).unwrap();
}
tx.abort().unwrap();
}
// Inserts values larger than the 1024-byte page size and then aborts, twice in a row.
// The test passes if nothing panics.
#[test]
fn small_pages4() {
let tmpfile = crate::create_tempfile();
let db = Database::builder()
.set_cache_size(1024 * 1024)
.set_page_size(1024)
.create(tmpfile.path())
.unwrap();
let table_def: TableDefinition<u64, &[u8]> = TableDefinition::new("x");
let tx = db.begin_write().unwrap();
{
tx.open_table(table_def).unwrap();
}
tx.commit().unwrap();
let tx = db.begin_write().unwrap();
{
let mut t = tx.open_table(table_def).unwrap();
assert!(t.get(&131072).unwrap().is_none());
let value = vec![0xFF; 1130];
t.insert(&42394, value.as_slice()).unwrap();
t.insert_reserve(&744037, 3645).unwrap().as_mut().fill(0xFF);
assert!(t.get(&0).unwrap().is_none());
}
tx.abort().unwrap();
let tx = db.begin_write().unwrap();
{
let mut t = tx.open_table(table_def).unwrap();
t.insert_reserve(&118749, 734).unwrap().as_mut().fill(0xFF);
}
tx.abort().unwrap();
}
// The file must become smaller after a large amount of data is inserted and then removed.
#[test]
fn dynamic_shrink() {
let tmpfile = crate::create_tempfile();
let table_definition: TableDefinition<u64, &[u8]> = TableDefinition::new("x");
let big_value = vec![0u8; 1024];
let db = Database::builder()
.set_region_size(1024 * 1024)
.create(tmpfile.path())
.unwrap();
// Insert 2048 values of 1 KiB each.
let txn = db.begin_write().unwrap();
{
let mut table = txn.open_table(table_definition).unwrap();
for i in 0..2048 {
table.insert(&i, big_value.as_slice()).unwrap();
}
}
txn.commit().unwrap();
// File size with all the data present; the baseline for the final comparison.
let file_size = tmpfile.as_file().metadata().unwrap().len();
let txn = db.begin_write().unwrap();
{
let mut table = txn.open_table(table_definition).unwrap();
for i in 0..2048 {
table.remove(&i).unwrap();
}
}
txn.commit().unwrap();
// Several further small commits follow the removal before the size is measured again
// (presumably so that the freed pages are actually released — confirm).
let txn = db.begin_write().unwrap();
{
let mut table = txn.open_table(table_definition).unwrap();
table.insert(0, [].as_slice()).unwrap();
}
txn.commit().unwrap();
let txn = db.begin_write().unwrap();
{
let mut table = txn.open_table(table_definition).unwrap();
table.remove(0).unwrap();
}
txn.commit().unwrap();
let txn = db.begin_write().unwrap();
txn.commit().unwrap();
let final_file_size = tmpfile.as_file().metadata().unwrap().len();
assert!(final_file_size < file_size);
}
// `create_file` must accept a freshly created temporary file.
#[test]
fn create_new_db_in_empty_file() {
let tmpfile = crate::create_tempfile();
let _db = Database::builder()
.create_file(tmpfile.into_file())
.unwrap();
}
// `open` on a path that does not exist must fail with an I/O `NotFound` error.
#[test]
fn open_missing_file() {
    let tmpfile = crate::create_tempfile();

    // A sibling path of the temp file that was never created.
    let err = Database::builder()
        .open(tmpfile.path().with_extension("missing"))
        .unwrap_err();

    match err {
        DatabaseError::Storage(StorageError::Io(err)) if err.kind() == ErrorKind::NotFound => {}
        err => panic!("Unexpected error for missing file: {err}"),
    }
}
// `open` on an existing but empty file must fail with an I/O `InvalidData` error.
#[test]
fn open_empty_file() {
let tmpfile = crate::create_tempfile();
let err = Database::builder().open(tmpfile.path()).unwrap_err();
match err {
DatabaseError::Storage(StorageError::Io(err))
if err.kind() == ErrorKind::InvalidData => {}
err => panic!("Unexpected error for empty file: {err}"),
}
}
}