// cuprate_database/backend/heed/storable.rs

//! `cuprate_database::Storable` <-> `heed` (de)serialization trait compatibility layer.

//---------------------------------------------------------------------------------------------------- Use
use std::{borrow::Cow, cmp::Ordering, marker::PhantomData};

use heed::{BoxedError, BytesDecode, BytesEncode};

use crate::{storable::Storable, Key};

10//---------------------------------------------------------------------------------------------------- StorableHeed
11/// The glue struct that implements `heed`'s (de)serialization
12/// traits on any type that implements `cuprate_database::Storable`.
13///
14/// Never actually gets constructed, just used for trait bound translations.
15pub(super) struct StorableHeed<T>(PhantomData<T>)
16where
17    T: Storable + ?Sized;
18
19//---------------------------------------------------------------------------------------------------- Key
20// If `Key` is also implemented, this can act as the comparison function.
21impl<T> heed::Comparator for StorableHeed<T>
22where
23    T: Key,
24{
25    #[inline]
26    fn compare(a: &[u8], b: &[u8]) -> Ordering {
27        <T as Key>::KEY_COMPARE.as_compare_fn::<T>()(a, b)
28    }
29}
30
31//---------------------------------------------------------------------------------------------------- BytesDecode/Encode
32impl<'a, T> BytesDecode<'a> for StorableHeed<T>
33where
34    T: Storable + 'static,
35{
36    type DItem = T;
37
38    #[inline]
39    /// This function is infallible (will always return `Ok`).
40    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
41        Ok(T::from_bytes(bytes))
42    }
43}
44
45impl<'a, T> BytesEncode<'a> for StorableHeed<T>
46where
47    T: Storable + ?Sized + 'a,
48{
49    type EItem = T;
50
51    #[inline]
52    /// This function is infallible (will always return `Ok`).
53    fn bytes_encode(item: &'a Self::EItem) -> Result<Cow<'a, [u8]>, BoxedError> {
54        Ok(Cow::Borrowed(item.as_bytes()))
55    }
56}
57
//---------------------------------------------------------------------------------------------------- Tests
#[cfg(test)]
mod test {
    use std::fmt::Debug;

    use super::*;
    use crate::{StorableBytes, StorableVec};

    // Each `#[test]` function has a `test()` to:
    // - log
    // - simplify trait bounds
    // - make sure the right function is being called

    #[test]
    /// Assert key comparison behavior is correct.
    fn compare() {
        fn test<T>(left: T, right: T, expected: Ordering)
        where
            T: Key + Ord + 'static,
        {
            println!("left: {left:?}, right: {right:?}, expected: {expected:?}");
            assert_eq!(
                <StorableHeed::<T> as heed::Comparator>::compare(
                    &<StorableHeed::<T> as BytesEncode>::bytes_encode(&left).unwrap(),
                    &<StorableHeed::<T> as BytesEncode>::bytes_encode(&right).unwrap()
                ),
                expected
            );
        }

        // Value comparison
        test::<u8>(0, 255, Ordering::Less);
        test::<u16>(0, 256, Ordering::Less);
        test::<u32>(0, 256, Ordering::Less);
        test::<u64>(0, 256, Ordering::Less);
        test::<u128>(0, 256, Ordering::Less);
        test::<usize>(0, 256, Ordering::Less);
        test::<i8>(-1, 2, Ordering::Less);
        test::<i16>(-1, 2, Ordering::Less);
        test::<i32>(-1, 2, Ordering::Less);
        test::<i64>(-1, 2, Ordering::Less);
        test::<i128>(-1, 2, Ordering::Less);
        test::<isize>(-1, 2, Ordering::Less);

        // Value comparison must also hold in the reverse direction and for equal
        // values; `2 > -1` would fail under a plain little-endian byte compare.
        test::<u64>(256, 0, Ordering::Greater);
        test::<i64>(2, -1, Ordering::Greater);
        test::<i64>(-2, -1, Ordering::Less);
        test::<u32>(7, 7, Ordering::Equal);

        // Byte comparison
        test::<[u8; 2]>([1, 1], [1, 0], Ordering::Greater);
        test::<[u8; 3]>([1, 2, 3], [1, 2, 3], Ordering::Equal);
        // Lexicographic: the first byte decides (a little-endian number
        // compare would order these the other way around).
        test::<[u8; 2]>([0, 255], [1, 0], Ordering::Less);
    }

    #[test]
    /// Assert `BytesEncode::bytes_encode` is accurate.
    fn bytes_encode() {
        fn test<T>(t: &T, expected: &[u8])
        where
            T: Storable + ?Sized,
        {
            println!("t: {t:?}, expected: {expected:?}");
            assert_eq!(
                <StorableHeed::<T> as BytesEncode>::bytes_encode(t).unwrap(),
                expected
            );
        }

        test::<()>(&(), &[]);
        test::<u8>(&0, &[0]);
        test::<u16>(&1, &[1, 0]);
        test::<u32>(&2, &[2, 0, 0, 0]);
        test::<u64>(&3, &[3, 0, 0, 0, 0, 0, 0, 0]);
        test::<i8>(&-1, &[255]);
        test::<i16>(&-2, &[254, 255]);
        test::<i32>(&-3, &[253, 255, 255, 255]);
        test::<i64>(&-4, &[252, 255, 255, 255, 255, 255, 255, 255]);
        test::<StorableVec<u8>>(&StorableVec(vec![1, 2]), &[1, 2]);
        test::<StorableBytes>(&StorableBytes(bytes::Bytes::from_static(&[1, 2])), &[1, 2]);
        test::<[u8; 0]>(&[], &[]);
        test::<[u8; 1]>(&[255], &[255]);
        test::<[u8; 2]>(&[111, 0], &[111, 0]);
        test::<[u8; 3]>(&[1, 0, 1], &[1, 0, 1]);
    }

    #[test]
    /// Assert `BytesDecode::bytes_decode` is accurate.
    fn bytes_decode() {
        // Only `PartialEq` + `Debug` are needed for `assert_eq!`; the
        // decoded value is already owned, so no `ToOwned` bound is required.
        fn test<T>(bytes: &[u8], expected: &T)
        where
            T: Storable + PartialEq + Debug + 'static,
        {
            println!("bytes: {bytes:?}, expected: {expected:?}");
            assert_eq!(
                &<StorableHeed::<T> as BytesDecode>::bytes_decode(bytes).unwrap(),
                expected
            );
        }

        test::<()>([].as_slice(), &());
        test::<u8>([0].as_slice(), &0);
        test::<u16>([1, 0].as_slice(), &1);
        test::<u32>([2, 0, 0, 0].as_slice(), &2);
        test::<u64>([3, 0, 0, 0, 0, 0, 0, 0].as_slice(), &3);
        test::<i8>([255].as_slice(), &-1);
        test::<i16>([254, 255].as_slice(), &-2);
        test::<i32>([253, 255, 255, 255].as_slice(), &-3);
        test::<i64>([252, 255, 255, 255, 255, 255, 255, 255].as_slice(), &-4);
        test::<StorableVec<u8>>(&[1, 2], &StorableVec(vec![1, 2]));
        test::<StorableBytes>(&[1, 2], &StorableBytes(bytes::Bytes::from_static(&[1, 2])));
        test::<[u8; 0]>([].as_slice(), &[]);
        test::<[u8; 1]>([255].as_slice(), &[255]);
        test::<[u8; 2]>([111, 0].as_slice(), &[111, 0]);
        test::<[u8; 3]>([1, 0, 1].as_slice(), &[1, 0, 1]);
    }
}