cuprate_levin/
codec.rs

1// Rust Levin Library
2// Written in 2023 by
3//   Cuprate Contributors
4//
5// Permission is hereby granted, free of charge, to any person obtaining a copy
6// of this software and associated documentation files (the "Software"), to deal
7// in the Software without restriction, including without limitation the rights
8// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9// copies of the Software, and to permit persons to whom the Software is
10// furnished to do so, subject to the following conditions:
11//
12// The above copyright notice and this permission notice shall be included in all
13// copies or substantial portions of the Software.
14//
15
//! Tokio codecs for raw levin buckets and for the levin messages carried inside them.
17
18use std::{fmt::Debug, marker::PhantomData};
19
20use bytes::{Buf, BufMut, BytesMut};
21use tokio_util::codec::{Decoder, Encoder};
22
23use cuprate_helper::cast::u64_to_usize;
24
25use crate::{
26    header::{Flags, HEADER_SIZE},
27    message::{make_dummy_message, LevinMessage},
28    Bucket, BucketBuilder, BucketError, BucketHead, LevinBody, LevinCommand, MessageType, Protocol,
29};
30
31#[derive(Debug, Clone)]
32pub enum LevinBucketState<C> {
33    /// Waiting for the peer to send a header.
34    WaitingForHeader,
35    /// Waiting for a peer to send a body.
36    WaitingForBody(BucketHead<C>),
37}
38
39/// The levin tokio-codec for decoding and encoding raw levin buckets
40///
41#[derive(Debug, Clone)]
42pub struct LevinBucketCodec<C> {
43    state: LevinBucketState<C>,
44    protocol: Protocol,
45    handshake_message_seen: bool,
46}
47
48impl<C> Default for LevinBucketCodec<C> {
49    fn default() -> Self {
50        Self {
51            state: LevinBucketState::WaitingForHeader,
52            protocol: Protocol::default(),
53            handshake_message_seen: false,
54        }
55    }
56}
57
58impl<C> LevinBucketCodec<C> {
59    pub const fn new(protocol: Protocol) -> Self {
60        Self {
61            state: LevinBucketState::WaitingForHeader,
62            protocol,
63            handshake_message_seen: false,
64        }
65    }
66}
67
68impl<C: LevinCommand + Debug> Decoder for LevinBucketCodec<C> {
69    type Item = Bucket<C>;
70    type Error = BucketError;
71    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
72        loop {
73            match &self.state {
74                LevinBucketState::WaitingForHeader => {
75                    if src.len() < HEADER_SIZE {
76                        return Ok(None);
77                    };
78
79                    let head = BucketHead::<C>::from_bytes(src);
80
81                    #[cfg(feature = "tracing")]
82                    tracing::trace!(
83                        "Received new bucket header, command: {:?}, waiting for body, body len: {}",
84                        head.command,
85                        head.size
86                    );
87
88                    if head.size > self.protocol.max_packet_size
89                        || head.size > head.command.bucket_size_limit()
90                    {
91                        #[cfg(feature = "tracing")]
92                        tracing::debug!("Peer sent message which is too large.");
93
94                        return Err(BucketError::BucketExceededMaxSize);
95                    }
96
97                    if !self.handshake_message_seen {
98                        if head.size > self.protocol.max_packet_size_before_handshake {
99                            #[cfg(feature = "tracing")]
100                            tracing::debug!("Peer sent message which is too large.");
101
102                            return Err(BucketError::BucketExceededMaxSize);
103                        }
104
105                        if head.command.is_handshake() {
106                            #[cfg(feature = "tracing")]
107                            tracing::debug!(
108                                "Peer handshake message seen, increasing bucket size limit."
109                            );
110
111                            self.handshake_message_seen = true;
112                        }
113                    }
114
115                    drop(std::mem::replace(
116                        &mut self.state,
117                        LevinBucketState::WaitingForBody(head),
118                    ));
119                }
120                LevinBucketState::WaitingForBody(head) => {
121                    let body_len = u64_to_usize(head.size);
122                    if src.len() < body_len {
123                        src.reserve(body_len - src.len());
124                        return Ok(None);
125                    }
126
127                    let LevinBucketState::WaitingForBody(header) =
128                        std::mem::replace(&mut self.state, LevinBucketState::WaitingForHeader)
129                    else {
130                        unreachable!()
131                    };
132
133                    #[cfg(feature = "tracing")]
134                    tracing::trace!("Received full bucket for command: {:?}", header.command);
135
136                    return Ok(Some(Bucket {
137                        header,
138                        body: src.copy_to_bytes(body_len),
139                    }));
140                }
141            }
142        }
143    }
144}
145
146impl<C: LevinCommand> Encoder<Bucket<C>> for LevinBucketCodec<C> {
147    type Error = BucketError;
148    fn encode(&mut self, item: Bucket<C>, dst: &mut BytesMut) -> Result<(), Self::Error> {
149        if let Some(additional) = (HEADER_SIZE + item.body.len()).checked_sub(dst.capacity()) {
150            dst.reserve(additional);
151        }
152
153        item.header.write_bytes_into(dst);
154        dst.put_slice(&item.body);
155        Ok(())
156    }
157}
158
/// Decoding state of [`LevinMessageCodec`] with respect to fragmented messages.
#[derive(Debug, Clone, Default)]
enum MessageState {
    /// Not inside a fragmented message; the next bucket is handled on its own.
    #[default]
    WaitingForBucket,
    /// A fragmented message has started and its bytes so far are buffered here.
    ///
    /// The buffer is an owned `Vec<u8>` rather than [`Bytes`](bytes::Bytes). A small
    /// [`Bytes`](bytes::Bytes) may still keep a much larger allocation alive, which is unsafe
    /// to hold for as long as a peer takes to finish a message. Copying into a `Vec<u8>` rules
    /// that out.
    WaitingForRestOfFragment(Vec<u8>),
}
170
171/// A tokio-codec for levin messages or in other words the decoded body
172/// of a levin bucket.
173#[derive(Debug, Clone)]
174pub struct LevinMessageCodec<T: LevinBody> {
175    message_ty: PhantomData<T>,
176    bucket_codec: LevinBucketCodec<T::Command>,
177    state: MessageState,
178}
179
180impl<T: LevinBody> Default for LevinMessageCodec<T> {
181    fn default() -> Self {
182        Self {
183            message_ty: Default::default(),
184            bucket_codec: Default::default(),
185            state: Default::default(),
186        }
187    }
188}
189
190impl<T: LevinBody> Decoder for LevinMessageCodec<T> {
191    type Item = T;
192    type Error = BucketError;
193    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
194        loop {
195            match &mut self.state {
196                MessageState::WaitingForBucket => {
197                    let Some(mut bucket) = self.bucket_codec.decode(src)? else {
198                        return Ok(None);
199                    };
200
201                    let flags = &bucket.header.flags;
202
203                    if flags.contains(Flags::DUMMY) {
204                        // Dummy message
205
206                        #[cfg(feature = "tracing")]
207                        tracing::trace!("Received DUMMY bucket from peer, ignoring.");
208                        // We may have another bucket in `src`.
209                        continue;
210                    };
211
212                    if flags.contains(Flags::END_FRAGMENT) {
213                        return Err(BucketError::InvalidHeaderFlags(
214                            "Flag end fragment received before a start fragment",
215                        ));
216                    };
217
218                    if flags.contains(Flags::START_FRAGMENT) {
219                        // monerod does not require a start flag before starting a fragmented message,
220                        // but will always produce one, so it is ok for us to require one.
221
222                        #[cfg(feature = "tracing")]
223                        tracing::debug!("Bucket is a fragment, waiting for rest of message.");
224
225                        self.state = MessageState::WaitingForRestOfFragment(bucket.body.to_vec());
226
227                        continue;
228                    }
229
230                    // Normal, non fragmented bucket
231
232                    let message_type = MessageType::from_flags_and_have_to_return(
233                        bucket.header.flags,
234                        bucket.header.have_to_return_data,
235                    )?;
236
237                    return Ok(Some(T::decode_message(
238                        &mut bucket.body,
239                        message_type,
240                        bucket.header.command,
241                    )?));
242                }
243                MessageState::WaitingForRestOfFragment(bytes) => {
244                    let Some(bucket) = self.bucket_codec.decode(src)? else {
245                        return Ok(None);
246                    };
247
248                    let flags = &bucket.header.flags;
249
250                    if flags.contains(Flags::DUMMY) {
251                        // Dummy message
252
253                        #[cfg(feature = "tracing")]
254                        tracing::trace!("Received DUMMY bucket from peer, ignoring.");
255                        // We may have another bucket in `src`.
256                        continue;
257                    };
258
259                    let max_size = u64_to_usize(if self.bucket_codec.handshake_message_seen {
260                        self.bucket_codec.protocol.max_packet_size
261                    } else {
262                        self.bucket_codec.protocol.max_packet_size_before_handshake
263                    });
264
265                    if bytes.len().saturating_add(bucket.body.len()) > max_size {
266                        return Err(BucketError::InvalidFragmentedMessage(
267                            "Fragmented message exceeded maximum size",
268                        ));
269                    }
270
271                    #[cfg(feature = "tracing")]
272                    tracing::trace!("Received another bucket fragment.");
273
274                    bytes.extend_from_slice(bucket.body.as_ref());
275
276                    if flags.contains(Flags::END_FRAGMENT) {
277                        // make sure we only look at the internal bucket and don't use this.
278                        drop(bucket);
279
280                        let MessageState::WaitingForRestOfFragment(bytes) =
281                            std::mem::replace(&mut self.state, MessageState::WaitingForBucket)
282                        else {
283                            unreachable!();
284                        };
285
286                        // Check there are enough bytes in the fragment to build a header.
287                        if bytes.len() < HEADER_SIZE {
288                            return Err(BucketError::InvalidFragmentedMessage(
289                                "Fragmented message is not large enough to build a bucket.",
290                            ));
291                        }
292
293                        let mut header_bytes = BytesMut::from(&bytes[0..HEADER_SIZE]);
294
295                        let header = BucketHead::<T::Command>::from_bytes(&mut header_bytes);
296
297                        if header.size > header.command.bucket_size_limit() {
298                            return Err(BucketError::BucketExceededMaxSize);
299                        }
300
301                        // Check the fragmented message contains enough bytes to build the message.
302                        if bytes.len().saturating_sub(HEADER_SIZE) < u64_to_usize(header.size) {
303                            return Err(BucketError::InvalidFragmentedMessage(
304                                "Fragmented message does not have enough bytes to fill bucket body",
305                            ));
306                        }
307
308                        #[cfg(feature = "tracing")]
309                        tracing::debug!(
310                            "Received final fragment, combined message command: {:?}.",
311                            header.command
312                        );
313
314                        let message_type = MessageType::from_flags_and_have_to_return(
315                            header.flags,
316                            header.have_to_return_data,
317                        )?;
318
319                        if header.command.is_handshake() {
320                            #[cfg(feature = "tracing")]
321                            tracing::debug!(
322                                "Peer handshake message seen, increasing bucket size limit."
323                            );
324
325                            self.bucket_codec.handshake_message_seen = true;
326                        }
327
328                        return Ok(Some(T::decode_message(
329                            &mut &bytes[HEADER_SIZE..],
330                            message_type,
331                            header.command,
332                        )?));
333                    }
334                }
335            }
336        }
337    }
338}
339
340impl<T: LevinBody> Encoder<LevinMessage<T>> for LevinMessageCodec<T> {
341    type Error = BucketError;
342    fn encode(&mut self, item: LevinMessage<T>, dst: &mut BytesMut) -> Result<(), Self::Error> {
343        match item {
344            LevinMessage::Body(body) => {
345                let mut bucket_builder = BucketBuilder::new(&self.bucket_codec.protocol);
346                body.encode(&mut bucket_builder)?;
347                let bucket = bucket_builder.finish();
348                self.bucket_codec.encode(bucket, dst)
349            }
350            LevinMessage::Bucket(bucket) => self.bucket_codec.encode(bucket, dst),
351            LevinMessage::Dummy(size) => {
352                let bucket = make_dummy_message(&self.bucket_codec.protocol, size);
353                self.bucket_codec.encode(bucket, dst)
354            }
355        }
356    }
357}