1use std::{fmt::Debug, marker::PhantomData};
19
20use bytes::{Buf, BufMut, BytesMut};
21use tokio_util::codec::{Decoder, Encoder};
22
23use cuprate_helper::cast::u64_to_usize;
24
25use crate::{
26 header::{Flags, HEADER_SIZE},
27 message::{make_dummy_message, LevinMessage},
28 Bucket, BucketBuilder, BucketError, BucketHead, LevinBody, LevinCommand, MessageType, Protocol,
29};
30
31#[derive(Debug, Clone)]
32pub enum LevinBucketState<C> {
33 WaitingForHeader,
35 WaitingForBody(BucketHead<C>),
37}
38
39#[derive(Debug, Clone)]
42pub struct LevinBucketCodec<C> {
43 state: LevinBucketState<C>,
44 protocol: Protocol,
45 handshake_message_seen: bool,
46}
47
48impl<C> Default for LevinBucketCodec<C> {
49 fn default() -> Self {
50 Self {
51 state: LevinBucketState::WaitingForHeader,
52 protocol: Protocol::default(),
53 handshake_message_seen: false,
54 }
55 }
56}
57
58impl<C> LevinBucketCodec<C> {
59 pub const fn new(protocol: Protocol) -> Self {
60 Self {
61 state: LevinBucketState::WaitingForHeader,
62 protocol,
63 handshake_message_seen: false,
64 }
65 }
66}
67
68impl<C: LevinCommand + Debug> Decoder for LevinBucketCodec<C> {
69 type Item = Bucket<C>;
70 type Error = BucketError;
71 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
72 loop {
73 match &self.state {
74 LevinBucketState::WaitingForHeader => {
75 if src.len() < HEADER_SIZE {
76 return Ok(None);
77 };
78
79 let head = BucketHead::<C>::from_bytes(src);
80
81 #[cfg(feature = "tracing")]
82 tracing::trace!(
83 "Received new bucket header, command: {:?}, waiting for body, body len: {}",
84 head.command,
85 head.size
86 );
87
88 if head.size > self.protocol.max_packet_size
89 || head.size > head.command.bucket_size_limit()
90 {
91 #[cfg(feature = "tracing")]
92 tracing::debug!("Peer sent message which is too large.");
93
94 return Err(BucketError::BucketExceededMaxSize);
95 }
96
97 if !self.handshake_message_seen {
98 if head.size > self.protocol.max_packet_size_before_handshake {
99 #[cfg(feature = "tracing")]
100 tracing::debug!("Peer sent message which is too large.");
101
102 return Err(BucketError::BucketExceededMaxSize);
103 }
104
105 if head.command.is_handshake() {
106 #[cfg(feature = "tracing")]
107 tracing::debug!(
108 "Peer handshake message seen, increasing bucket size limit."
109 );
110
111 self.handshake_message_seen = true;
112 }
113 }
114
115 drop(std::mem::replace(
116 &mut self.state,
117 LevinBucketState::WaitingForBody(head),
118 ));
119 }
120 LevinBucketState::WaitingForBody(head) => {
121 let body_len = u64_to_usize(head.size);
122 if src.len() < body_len {
123 src.reserve(body_len - src.len());
124 return Ok(None);
125 }
126
127 let LevinBucketState::WaitingForBody(header) =
128 std::mem::replace(&mut self.state, LevinBucketState::WaitingForHeader)
129 else {
130 unreachable!()
131 };
132
133 #[cfg(feature = "tracing")]
134 tracing::trace!("Received full bucket for command: {:?}", header.command);
135
136 return Ok(Some(Bucket {
137 header,
138 body: src.copy_to_bytes(body_len),
139 }));
140 }
141 }
142 }
143 }
144}
145
146impl<C: LevinCommand> Encoder<Bucket<C>> for LevinBucketCodec<C> {
147 type Error = BucketError;
148 fn encode(&mut self, item: Bucket<C>, dst: &mut BytesMut) -> Result<(), Self::Error> {
149 if let Some(additional) = (HEADER_SIZE + item.body.len()).checked_sub(dst.capacity()) {
150 dst.reserve(additional);
151 }
152
153 item.header.write_bytes_into(dst);
154 dst.put_slice(&item.body);
155 Ok(())
156 }
157}
158
/// The phase of message decoding a [`LevinMessageCodec`] is currently in.
#[derive(Clone, Debug, Default)]
enum MessageState {
    /// Waiting for the next bucket; no fragmented message is in progress.
    #[default]
    WaitingForBucket,
    /// A fragmented message has been started. The vector holds the bodies of
    /// all fragments received so far, concatenated in arrival order.
    WaitingForRestOfFragment(Vec<u8>),
}
170
171#[derive(Debug, Clone)]
174pub struct LevinMessageCodec<T: LevinBody> {
175 message_ty: PhantomData<T>,
176 bucket_codec: LevinBucketCodec<T::Command>,
177 state: MessageState,
178}
179
180impl<T: LevinBody> Default for LevinMessageCodec<T> {
181 fn default() -> Self {
182 Self {
183 message_ty: Default::default(),
184 bucket_codec: Default::default(),
185 state: Default::default(),
186 }
187 }
188}
189
190impl<T: LevinBody> Decoder for LevinMessageCodec<T> {
191 type Item = T;
192 type Error = BucketError;
193 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
194 loop {
195 match &mut self.state {
196 MessageState::WaitingForBucket => {
197 let Some(mut bucket) = self.bucket_codec.decode(src)? else {
198 return Ok(None);
199 };
200
201 let flags = &bucket.header.flags;
202
203 if flags.contains(Flags::DUMMY) {
204 #[cfg(feature = "tracing")]
207 tracing::trace!("Received DUMMY bucket from peer, ignoring.");
208 continue;
210 };
211
212 if flags.contains(Flags::END_FRAGMENT) {
213 return Err(BucketError::InvalidHeaderFlags(
214 "Flag end fragment received before a start fragment",
215 ));
216 };
217
218 if flags.contains(Flags::START_FRAGMENT) {
219 #[cfg(feature = "tracing")]
223 tracing::debug!("Bucket is a fragment, waiting for rest of message.");
224
225 self.state = MessageState::WaitingForRestOfFragment(bucket.body.to_vec());
226
227 continue;
228 }
229
230 let message_type = MessageType::from_flags_and_have_to_return(
233 bucket.header.flags,
234 bucket.header.have_to_return_data,
235 )?;
236
237 return Ok(Some(T::decode_message(
238 &mut bucket.body,
239 message_type,
240 bucket.header.command,
241 )?));
242 }
243 MessageState::WaitingForRestOfFragment(bytes) => {
244 let Some(bucket) = self.bucket_codec.decode(src)? else {
245 return Ok(None);
246 };
247
248 let flags = &bucket.header.flags;
249
250 if flags.contains(Flags::DUMMY) {
251 #[cfg(feature = "tracing")]
254 tracing::trace!("Received DUMMY bucket from peer, ignoring.");
255 continue;
257 };
258
259 let max_size = u64_to_usize(if self.bucket_codec.handshake_message_seen {
260 self.bucket_codec.protocol.max_packet_size
261 } else {
262 self.bucket_codec.protocol.max_packet_size_before_handshake
263 });
264
265 if bytes.len().saturating_add(bucket.body.len()) > max_size {
266 return Err(BucketError::InvalidFragmentedMessage(
267 "Fragmented message exceeded maximum size",
268 ));
269 }
270
271 #[cfg(feature = "tracing")]
272 tracing::trace!("Received another bucket fragment.");
273
274 bytes.extend_from_slice(bucket.body.as_ref());
275
276 if flags.contains(Flags::END_FRAGMENT) {
277 drop(bucket);
279
280 let MessageState::WaitingForRestOfFragment(bytes) =
281 std::mem::replace(&mut self.state, MessageState::WaitingForBucket)
282 else {
283 unreachable!();
284 };
285
286 if bytes.len() < HEADER_SIZE {
288 return Err(BucketError::InvalidFragmentedMessage(
289 "Fragmented message is not large enough to build a bucket.",
290 ));
291 }
292
293 let mut header_bytes = BytesMut::from(&bytes[0..HEADER_SIZE]);
294
295 let header = BucketHead::<T::Command>::from_bytes(&mut header_bytes);
296
297 if header.size > header.command.bucket_size_limit() {
298 return Err(BucketError::BucketExceededMaxSize);
299 }
300
301 if bytes.len().saturating_sub(HEADER_SIZE) < u64_to_usize(header.size) {
303 return Err(BucketError::InvalidFragmentedMessage(
304 "Fragmented message does not have enough bytes to fill bucket body",
305 ));
306 }
307
308 #[cfg(feature = "tracing")]
309 tracing::debug!(
310 "Received final fragment, combined message command: {:?}.",
311 header.command
312 );
313
314 let message_type = MessageType::from_flags_and_have_to_return(
315 header.flags,
316 header.have_to_return_data,
317 )?;
318
319 if header.command.is_handshake() {
320 #[cfg(feature = "tracing")]
321 tracing::debug!(
322 "Peer handshake message seen, increasing bucket size limit."
323 );
324
325 self.bucket_codec.handshake_message_seen = true;
326 }
327
328 return Ok(Some(T::decode_message(
329 &mut &bytes[HEADER_SIZE..],
330 message_type,
331 header.command,
332 )?));
333 }
334 }
335 }
336 }
337 }
338}
339
340impl<T: LevinBody> Encoder<LevinMessage<T>> for LevinMessageCodec<T> {
341 type Error = BucketError;
342 fn encode(&mut self, item: LevinMessage<T>, dst: &mut BytesMut) -> Result<(), Self::Error> {
343 match item {
344 LevinMessage::Body(body) => {
345 let mut bucket_builder = BucketBuilder::new(&self.bucket_codec.protocol);
346 body.encode(&mut bucket_builder)?;
347 let bucket = bucket_builder.finish();
348 self.bucket_codec.encode(bucket, dst)
349 }
350 LevinMessage::Bucket(bucket) => self.bucket_codec.encode(bucket, dst),
351 LevinMessage::Dummy(size) => {
352 let bucket = make_dummy_message(&self.bucket_codec.protocol, size);
353 self.bucket_codec.encode(bucket, dst)
354 }
355 }
356 }
357}