cuprate_fixed_bytes/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use core::{
4    fmt::{Debug, Formatter},
5    ops::{Deref, Index},
6};
7
8use bytes::{BufMut, Bytes, BytesMut};
9
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Deserializer, Serialize};
12
/// Error produced when building a fixed-size byte type from input whose length does not fit.
#[cfg_attr(feature = "std", derive(thiserror::Error))]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub enum FixedByteError {
    /// The input's length is not acceptable for the target type.
    #[cfg_attr(
        feature = "std",
        error("Cannot create fix byte array, input has invalid length.")
    )]
    InvalidLength,
}

impl FixedByteError {
    /// Field label used by the hand-written [`Debug`] implementation.
    const fn field_name(&self) -> &'static str {
        match *self {
            Self::InvalidLength => "input",
        }
    }

    /// Field value (a human-readable description) used by the [`Debug`] implementation.
    const fn field_data(&self) -> &'static str {
        match *self {
            Self::InvalidLength => "Cannot create fix byte array, input has invalid length.",
        }
    }
}

impl Debug for FixedByteError {
    /// Renders as `FixedByteError { <field_name>: "<field_data>" }`.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let (name, description) = (self.field_name(), self.field_data());
        let mut builder = f.debug_struct("FixedByteError");
        builder.field(name, &description);
        builder.finish()
    }
}
45
46/// A fixed size byte slice.
47///
48/// Internally this is just a wrapper around [`Bytes`], with the constructors checking that the length is equal to `N`.
49/// This implements [`Deref`] with the target being `[u8; N]`.
50#[derive(Debug, Default, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
51#[cfg_attr(feature = "serde", derive(Serialize))]
52#[cfg_attr(feature = "serde", serde(transparent))]
53#[repr(transparent)]
54pub struct ByteArray<const N: usize>(Bytes);
55
#[cfg(feature = "serde")]
impl<'de, const N: usize> Deserialize<'de> for ByteArray<N> {
    /// Deserializes a byte sequence and accepts it only if it is exactly `N` bytes long;
    /// otherwise reports an `invalid_length` error whose expectation is `N`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = Bytes::deserialize(deserializer)?;
        match raw.len() {
            found if found == N => Ok(Self(raw)),
            found => {
                let expected = N.to_string();
                Err(<D::Error as serde::de::Error>::invalid_length(
                    found,
                    &expected.as_str(),
                ))
            }
        }
    }
}
74
75impl<const N: usize> ByteArray<N> {
76    pub fn take_bytes(self) -> Bytes {
77        self.0
78    }
79}
80
81impl<const N: usize> From<[u8; N]> for ByteArray<N> {
82    fn from(value: [u8; N]) -> Self {
83        Self(Bytes::copy_from_slice(&value))
84    }
85}
86
87impl<const N: usize> Deref for ByteArray<N> {
88    type Target = [u8; N];
89
90    fn deref(&self) -> &Self::Target {
91        self.0.deref().try_into().unwrap()
92    }
93}
94
95impl<const N: usize> TryFrom<Bytes> for ByteArray<N> {
96    type Error = FixedByteError;
97
98    fn try_from(value: Bytes) -> Result<Self, Self::Error> {
99        if value.len() != N {
100            return Err(FixedByteError::InvalidLength);
101        }
102        Ok(Self(value))
103    }
104}
105
106impl<const N: usize> TryFrom<Vec<u8>> for ByteArray<N> {
107    type Error = FixedByteError;
108
109    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
110        if value.len() != N {
111            return Err(FixedByteError::InvalidLength);
112        }
113        Ok(Self(Bytes::from(value)))
114    }
115}
116
117#[derive(Debug, Default, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
118#[cfg_attr(feature = "serde", derive(Serialize))]
119#[cfg_attr(feature = "serde", serde(transparent))]
120#[repr(transparent)]
121pub struct ByteArrayVec<const N: usize>(Bytes);
122
#[cfg(feature = "serde")]
impl<'de, const N: usize> Deserialize<'de> for ByteArrayVec<N> {
    /// Deserializes a byte sequence and accepts it only if its length is a multiple of `N`;
    /// otherwise reports an `invalid_length` error whose expectation is `N`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = Bytes::deserialize(deserializer)?;
        let found = raw.len();
        if found % N != 0 {
            let expected = N.to_string();
            return Err(<D::Error as serde::de::Error>::invalid_length(
                found,
                &expected.as_str(),
            ));
        }
        Ok(Self(raw))
    }
}
141
142impl<const N: usize> ByteArrayVec<N> {
143    pub const fn len(&self) -> usize {
144        self.0.len() / N
145    }
146
147    pub const fn is_empty(&self) -> bool {
148        self.len() == 0
149    }
150
151    pub fn take_bytes(self) -> Bytes {
152        self.0
153    }
154
155    /// Splits the byte array vec into two at the given index.
156    ///
157    /// Afterwards self contains elements [0, at), and the returned [`ByteArrayVec`] contains elements [at, len).
158    ///
159    /// This is an O(1) operation that just increases the reference count and sets a few indices.
160    ///
161    /// # Panics
162    /// Panics if at > len.
163    #[must_use]
164    pub fn split_off(&mut self, at: usize) -> Self {
165        Self(self.0.split_off(at * N))
166    }
167}
168
169impl<const N: usize> From<&ByteArrayVec<N>> for Vec<[u8; N]> {
170    fn from(value: &ByteArrayVec<N>) -> Self {
171        let mut out = Self::with_capacity(value.len());
172        for i in 0..value.len() {
173            out.push(value[i]);
174        }
175
176        out
177    }
178}
179
180impl<const N: usize> From<Vec<[u8; N]>> for ByteArrayVec<N> {
181    fn from(value: Vec<[u8; N]>) -> Self {
182        let mut bytes = BytesMut::with_capacity(N * value.len());
183        for i in value {
184            bytes.extend_from_slice(&i);
185        }
186
187        Self(bytes.freeze())
188    }
189}
190
191impl<const N: usize> TryFrom<Bytes> for ByteArrayVec<N> {
192    type Error = FixedByteError;
193
194    fn try_from(value: Bytes) -> Result<Self, Self::Error> {
195        if value.len() % N != 0 {
196            return Err(FixedByteError::InvalidLength);
197        }
198
199        Ok(Self(value))
200    }
201}
202
203impl<const N: usize> From<[u8; N]> for ByteArrayVec<N> {
204    fn from(value: [u8; N]) -> Self {
205        Self(Bytes::copy_from_slice(value.as_slice()))
206    }
207}
208
209impl<const N: usize, const LEN: usize> From<[[u8; N]; LEN]> for ByteArrayVec<N> {
210    fn from(value: [[u8; N]; LEN]) -> Self {
211        let mut bytes = BytesMut::with_capacity(N * LEN);
212
213        for val in value {
214            bytes.put_slice(val.as_slice());
215        }
216
217        Self(bytes.freeze())
218    }
219}
220
221impl<const N: usize> TryFrom<Vec<u8>> for ByteArrayVec<N> {
222    type Error = FixedByteError;
223
224    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
225        if value.len() % N != 0 {
226            return Err(FixedByteError::InvalidLength);
227        }
228
229        Ok(Self(Bytes::from(value)))
230    }
231}
232
233impl<const N: usize> Index<usize> for ByteArrayVec<N> {
234    type Output = [u8; N];
235
236    fn index(&self, index: usize) -> &Self::Output {
237        assert!(
238            (index + 1) * N <= self.0.len(),
239            "Index out of range, idx: {}, length: {}",
240            index,
241            self.len()
242        );
243
244        self.0[index * N..(index + 1) * N]
245            .as_ref()
246            .try_into()
247            .unwrap()
248    }
249}
250
#[cfg(test)]
mod tests {
    // Only the `serde`-gated tests use these; gating the import avoids an
    // unused-import warning when the feature is off.
    #[cfg(feature = "serde")]
    use serde_json::{from_str, to_string};

    use super::*;

    #[test]
    fn byte_array_vec_len() {
        let bytes = vec![0; 32 * 100];
        let bytes = ByteArrayVec::<32>::try_from(Bytes::from(bytes)).unwrap();

        assert_eq!(bytes.len(), 100);
        let _ = bytes[99];
    }

    /// Out-of-range indexing on [`ByteArrayVec`] panics with a descriptive message.
    #[test]
    #[should_panic(expected = "Index out of range, idx: 2, length: 2")]
    fn byte_array_vec_index_out_of_range() {
        let v = ByteArrayVec::<2>::from([[1_u8, 2], [3, 4]]);
        let _ = v[2];
    }

    /// `split_off` partitions elements at the requested element index.
    #[test]
    fn byte_array_vec_split_off() {
        let mut left = ByteArrayVec::<2>::from([[1_u8, 2], [3, 4], [5, 6]]);
        let right = left.split_off(1);

        assert_eq!(left.len(), 1);
        assert_eq!(right.len(), 2);
        assert_eq!(left[0], [1, 2]);
        assert_eq!(right[0], [3, 4]);
        assert_eq!(right[1], [5, 6]);
    }

    /// Converting `Vec<[u8; N]>` -> [`ByteArrayVec`] -> `Vec<[u8; N]>` is lossless.
    #[test]
    fn byte_array_vec_vec_round_trip() {
        let arrays = vec![[1_u8, 2, 3], [4, 5, 6]];
        let v = ByteArrayVec::<3>::from(arrays.clone());
        assert_eq!(Vec::<[u8; 3]>::from(&v), arrays);
    }

    /// Fallible constructors reject inputs with the wrong length.
    #[test]
    fn try_from_rejects_wrong_length() {
        assert!(ByteArray::<4>::try_from(vec![0_u8; 3]).is_err());
        assert!(ByteArray::<4>::try_from(vec![0_u8; 4]).is_ok());
        assert!(ByteArrayVec::<4>::try_from(vec![0_u8; 6]).is_err());
        assert!(ByteArrayVec::<4>::try_from(vec![0_u8; 8]).is_ok());
    }

    /// Tests that `serde` works on [`ByteArray`].
    #[test]
    #[cfg(feature = "serde")]
    fn byte_array_serde() {
        let b = ByteArray::from([1, 0, 0, 0, 1]);
        let string = to_string(&b).unwrap();
        assert_eq!(string, "[1,0,0,0,1]");
        let b2 = from_str::<ByteArray<5>>(&string).unwrap();
        assert_eq!(b, b2);
    }

    /// Tests that `serde` works on [`ByteArrayVec`].
    #[test]
    #[cfg(feature = "serde")]
    fn byte_array_vec_serde() {
        let b = ByteArrayVec::from([1, 0, 0, 0, 1]);
        let string = to_string(&b).unwrap();
        assert_eq!(string, "[1,0,0,0,1]");
        let b2 = from_str::<ByteArrayVec<5>>(&string).unwrap();
        assert_eq!(b, b2);
    }

    /// Tests that bad input `serde` fails on [`ByteArray`].
    #[test]
    #[cfg(feature = "serde")]
    #[should_panic(
        expected = r#"called `Result::unwrap()` on an `Err` value: Error("invalid length 4, expected 5", line: 0, column: 0)"#
    )]
    fn byte_array_bad_deserialize() {
        from_str::<ByteArray<5>>("[1,0,0,0]").unwrap();
    }

    /// Tests that bad input `serde` fails on [`ByteArrayVec`].
    #[test]
    #[cfg(feature = "serde")]
    #[should_panic(
        expected = r#"called `Result::unwrap()` on an `Err` value: Error("invalid length 4, expected 5", line: 0, column: 0)"#
    )]
    fn byte_array_vec_bad_deserialize() {
        from_str::<ByteArrayVec<5>>("[1,0,0,0]").unwrap();
    }
}