1#![doc = include_str!("../README.md")]
2
3use core::{
4 fmt::{Debug, Formatter},
5 ops::{Deref, Index},
6};
7
8use bytes::{BufMut, Bytes, BytesMut};
9
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Deserializer, Serialize};
12
/// Error produced when a byte buffer cannot be turned into one of this
/// crate's fixed-size containers.
#[cfg_attr(feature = "std", derive(thiserror::Error))]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum FixedByteError {
    /// The input's length is not acceptable for the requested container.
    #[cfg_attr(
        feature = "std",
        error("Cannot create fix byte array, input has invalid length.")
    )]
    InvalidLength,
}
23
24impl FixedByteError {
25 const fn field_name(&self) -> &'static str {
26 match self {
27 Self::InvalidLength => "input",
28 }
29 }
30
31 const fn field_data(&self) -> &'static str {
32 match self {
33 Self::InvalidLength => "Cannot create fix byte array, input has invalid length.",
34 }
35 }
36}
37
38impl Debug for FixedByteError {
39 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
40 f.debug_struct("FixedByteError")
41 .field(self.field_name(), &self.field_data())
42 .finish()
43 }
44}
45
46#[derive(Debug, Default, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
51#[cfg_attr(feature = "serde", derive(Serialize))]
52#[cfg_attr(feature = "serde", serde(transparent))]
53#[repr(transparent)]
54pub struct ByteArray<const N: usize>(Bytes);
55
#[cfg(feature = "serde")]
impl<'de, const N: usize> Deserialize<'de> for ByteArray<N> {
    /// Deserializes the underlying bytes and accepts them only when exactly
    /// `N` bytes were read; otherwise reports an invalid-length error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = Bytes::deserialize(deserializer)?;
        let found = raw.len();

        if found != N {
            return Err(serde::de::Error::invalid_length(
                found,
                &N.to_string().as_str(),
            ));
        }

        Ok(Self(raw))
    }
}
74
75impl<const N: usize> ByteArray<N> {
76 pub fn take_bytes(self) -> Bytes {
77 self.0
78 }
79}
80
81impl<const N: usize> From<[u8; N]> for ByteArray<N> {
82 fn from(value: [u8; N]) -> Self {
83 Self(Bytes::copy_from_slice(&value))
84 }
85}
86
87impl<const N: usize> Deref for ByteArray<N> {
88 type Target = [u8; N];
89
90 fn deref(&self) -> &Self::Target {
91 self.0.deref().try_into().unwrap()
92 }
93}
94
95impl<const N: usize> TryFrom<Bytes> for ByteArray<N> {
96 type Error = FixedByteError;
97
98 fn try_from(value: Bytes) -> Result<Self, Self::Error> {
99 if value.len() != N {
100 return Err(FixedByteError::InvalidLength);
101 }
102 Ok(Self(value))
103 }
104}
105
106impl<const N: usize> TryFrom<Vec<u8>> for ByteArray<N> {
107 type Error = FixedByteError;
108
109 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
110 if value.len() != N {
111 return Err(FixedByteError::InvalidLength);
112 }
113 Ok(Self(Bytes::from(value)))
114 }
115}
116
117#[derive(Debug, Default, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
118#[cfg_attr(feature = "serde", derive(Serialize))]
119#[cfg_attr(feature = "serde", serde(transparent))]
120#[repr(transparent)]
121pub struct ByteArrayVec<const N: usize>(Bytes);
122
#[cfg(feature = "serde")]
impl<'de, const N: usize> Deserialize<'de> for ByteArrayVec<N> {
    /// Deserializes the underlying bytes and accepts them when their length
    /// is a multiple of `N`; otherwise reports an invalid-length error whose
    /// "expected" part is `N`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = Bytes::deserialize(deserializer)?;
        let found = raw.len();

        if found % N != 0 {
            return Err(serde::de::Error::invalid_length(
                found,
                &N.to_string().as_str(),
            ));
        }

        Ok(Self(raw))
    }
}
141
142impl<const N: usize> ByteArrayVec<N> {
143 pub const fn len(&self) -> usize {
144 self.0.len() / N
145 }
146
147 pub const fn is_empty(&self) -> bool {
148 self.len() == 0
149 }
150
151 pub fn take_bytes(self) -> Bytes {
152 self.0
153 }
154
155 #[must_use]
164 pub fn split_off(&mut self, at: usize) -> Self {
165 Self(self.0.split_off(at * N))
166 }
167}
168
169impl<const N: usize> From<&ByteArrayVec<N>> for Vec<[u8; N]> {
170 fn from(value: &ByteArrayVec<N>) -> Self {
171 let mut out = Self::with_capacity(value.len());
172 for i in 0..value.len() {
173 out.push(value[i]);
174 }
175
176 out
177 }
178}
179
180impl<const N: usize> From<Vec<[u8; N]>> for ByteArrayVec<N> {
181 fn from(value: Vec<[u8; N]>) -> Self {
182 let mut bytes = BytesMut::with_capacity(N * value.len());
183 for i in value {
184 bytes.extend_from_slice(&i);
185 }
186
187 Self(bytes.freeze())
188 }
189}
190
191impl<const N: usize> TryFrom<Bytes> for ByteArrayVec<N> {
192 type Error = FixedByteError;
193
194 fn try_from(value: Bytes) -> Result<Self, Self::Error> {
195 if value.len() % N != 0 {
196 return Err(FixedByteError::InvalidLength);
197 }
198
199 Ok(Self(value))
200 }
201}
202
203impl<const N: usize> From<[u8; N]> for ByteArrayVec<N> {
204 fn from(value: [u8; N]) -> Self {
205 Self(Bytes::copy_from_slice(value.as_slice()))
206 }
207}
208
209impl<const N: usize, const LEN: usize> From<[[u8; N]; LEN]> for ByteArrayVec<N> {
210 fn from(value: [[u8; N]; LEN]) -> Self {
211 let mut bytes = BytesMut::with_capacity(N * LEN);
212
213 for val in value {
214 bytes.put_slice(val.as_slice());
215 }
216
217 Self(bytes.freeze())
218 }
219}
220
221impl<const N: usize> TryFrom<Vec<u8>> for ByteArrayVec<N> {
222 type Error = FixedByteError;
223
224 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
225 if value.len() % N != 0 {
226 return Err(FixedByteError::InvalidLength);
227 }
228
229 Ok(Self(Bytes::from(value)))
230 }
231}
232
233impl<const N: usize> Index<usize> for ByteArrayVec<N> {
234 type Output = [u8; N];
235
236 fn index(&self, index: usize) -> &Self::Output {
237 assert!(
238 (index + 1) * N <= self.0.len(),
239 "Index out of range, idx: {}, length: {}",
240 index,
241 self.len()
242 );
243
244 self.0[index * N..(index + 1) * N]
245 .as_ref()
246 .try_into()
247 .unwrap()
248 }
249}
250
#[cfg(test)]
mod tests {
    // Only the `serde`-gated tests use these helpers; gating the import avoids
    // an unused-import warning when the feature is disabled.
    #[cfg(feature = "serde")]
    use serde_json::{from_str, to_string};

    use super::*;

    /// A buffer of 100 * 32 bytes holds 100 elements, the last at index 99.
    #[test]
    fn byte_array_vec_len() {
        let bytes = vec![0; 32 * 100];
        let bytes = ByteArrayVec::<32>::try_from(Bytes::from(bytes)).unwrap();

        assert_eq!(bytes.len(), 100);
        let _ = bytes[99];
    }

    /// Indexing one past the last element must panic.
    #[test]
    #[should_panic(expected = "Index out of range")]
    fn byte_array_vec_index_out_of_range() {
        let bytes = vec![0; 32 * 100];
        let bytes = ByteArrayVec::<32>::try_from(Bytes::from(bytes)).unwrap();

        let _ = bytes[100];
    }

    /// `ByteArray` round-trips through JSON as a plain sequence of bytes.
    #[test]
    #[cfg(feature = "serde")]
    fn byte_array_serde() {
        let b = ByteArray::from([1, 0, 0, 0, 1]);
        let string = to_string(&b).unwrap();
        assert_eq!(string, "[1,0,0,0,1]");
        let b2 = from_str::<ByteArray<5>>(&string).unwrap();
        assert_eq!(b, b2);
    }

    /// `ByteArrayVec` round-trips through JSON as a plain sequence of bytes.
    #[test]
    #[cfg(feature = "serde")]
    fn byte_array_vec_serde() {
        let b = ByteArrayVec::from([1, 0, 0, 0, 1]);
        let string = to_string(&b).unwrap();
        assert_eq!(string, "[1,0,0,0,1]");
        let b2 = from_str::<ByteArrayVec<5>>(&string).unwrap();
        assert_eq!(b, b2);
    }

    /// Deserializing 4 bytes into a `ByteArray<5>` is rejected.
    #[test]
    #[cfg(feature = "serde")]
    #[should_panic(
        expected = r#"called `Result::unwrap()` on an `Err` value: Error("invalid length 4, expected 5", line: 0, column: 0)"#
    )]
    fn byte_array_bad_deserialize() {
        from_str::<ByteArray<5>>("[1,0,0,0]").unwrap();
    }

    /// Deserializing 4 bytes into a `ByteArrayVec<5>` is rejected, since 4 is
    /// not a multiple of 5.
    #[test]
    #[cfg(feature = "serde")]
    #[should_panic(
        expected = r#"called `Result::unwrap()` on an `Err` value: Error("invalid length 4, expected 5", line: 0, column: 0)"#
    )]
    fn byte_array_vec_bad_deserialize() {
        from_str::<ByteArrayVec<5>>("[1,0,0,0]").unwrap();
    }
}