// tokio_util/codec/framed_impl.rs

1use crate::codec::decoder::Decoder;
2use crate::codec::encoder::Encoder;
3
4use futures_core::Stream;
5use tokio::io::{AsyncRead, AsyncWrite};
6
7use bytes::BytesMut;
8use futures_sink::Sink;
9use pin_project_lite::pin_project;
10use std::borrow::{Borrow, BorrowMut};
11use std::io;
12use std::pin::Pin;
13use std::task::{ready, Context, Poll};
14
pin_project! {
    /// Core of a framed transport: an I/O object, a codec, and buffer state.
    ///
    /// `State` decides which directions are available. With suitable `T` and `U`,
    /// this type implements `Stream` when `State: BorrowMut<ReadFrame>` and
    /// `Sink` when `State: BorrowMut<WriteFrame>`. `RWFrames` satisfies both bounds.
    #[derive(Debug)]
    pub(crate) struct FramedImpl<T, U, State> {
        // The underlying I/O object. It is structurally pinned so it can be
        // polled through `Pin<&mut T>`.
        #[pin]
        pub(crate) inner: T,
        // Read and/or write buffers together with their bookkeeping flags.
        pub(crate) state: State,
        // Decoder and/or encoder applied to the buffered bytes.
        pub(crate) codec: U,
    }
}
24
/// Size in bytes that fresh read/write buffers are allocated with.
/// It is also the default write backpressure boundary (8 KiB).
const INITIAL_CAPACITY: usize = 8_192;
26
/// Read-side state consumed by `FramedImpl`'s `Stream` implementation.
#[derive(Debug)]
pub(crate) struct ReadFrame {
    /// Set once a read from the underlying I/O object returned 0 bytes.
    /// Cleared again when a later read returns data.
    pub(crate) eof: bool,
    /// Whether `buffer` might contain bytes the codec can turn into a frame,
    /// i.e. whether the decoder should run before the next read.
    pub(crate) is_readable: bool,
    /// Bytes read from the I/O object that have not been decoded yet.
    pub(crate) buffer: BytesMut,
    /// Set when a read, `decode`, or `decode_eof` fails. The next poll then
    /// yields `None` and clears this flag.
    pub(crate) has_errored: bool,
}
34
35pub(crate) struct WriteFrame {
36    pub(crate) buffer: BytesMut,
37    pub(crate) backpressure_boundary: usize,
38}
39
40#[derive(Default)]
41pub(crate) struct RWFrames {
42    pub(crate) read: ReadFrame,
43    pub(crate) write: WriteFrame,
44}
45
46impl Default for ReadFrame {
47    fn default() -> Self {
48        Self {
49            eof: false,
50            is_readable: false,
51            buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
52            has_errored: false,
53        }
54    }
55}
56
57impl Default for WriteFrame {
58    fn default() -> Self {
59        Self {
60            buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
61            backpressure_boundary: INITIAL_CAPACITY,
62        }
63    }
64}
65
66impl From<BytesMut> for ReadFrame {
67    fn from(mut buffer: BytesMut) -> Self {
68        let size = buffer.capacity();
69        if size < INITIAL_CAPACITY {
70            buffer.reserve(INITIAL_CAPACITY - size);
71        }
72
73        Self {
74            buffer,
75            is_readable: size > 0,
76            eof: false,
77            has_errored: false,
78        }
79    }
80}
81
82impl From<BytesMut> for WriteFrame {
83    fn from(mut buffer: BytesMut) -> Self {
84        let size = buffer.capacity();
85        if size < INITIAL_CAPACITY {
86            buffer.reserve(INITIAL_CAPACITY - size);
87        }
88
89        Self {
90            buffer,
91            backpressure_boundary: INITIAL_CAPACITY,
92        }
93    }
94}
95
96impl Borrow<ReadFrame> for RWFrames {
97    fn borrow(&self) -> &ReadFrame {
98        &self.read
99    }
100}
101impl BorrowMut<ReadFrame> for RWFrames {
102    fn borrow_mut(&mut self) -> &mut ReadFrame {
103        &mut self.read
104    }
105}
106impl Borrow<WriteFrame> for RWFrames {
107    fn borrow(&self) -> &WriteFrame {
108        &self.write
109    }
110}
111impl BorrowMut<WriteFrame> for RWFrames {
112    fn borrow_mut(&mut self) -> &mut WriteFrame {
113        &mut self.write
114    }
115}
impl<T, U, R> Stream for FramedImpl<T, U, R>
where
    T: AsyncRead,
    U: Decoder,
    R: BorrowMut<ReadFrame>,
{
    type Item = Result<U::Item, U::Error>;

    /// Reads bytes from `inner` into the read buffer and decodes them into
    /// frames with `codec`.
    ///
    /// - Yields `Some(Ok(frame))` for every decoded frame.
    /// - Yields `Some(Err(e))` when the read, `decode`, or `decode_eof` fails.
    /// - Returns `Pending` while the read is not ready.
    /// - Yields `None` when `decode_eof` produces no further frame after an EOF,
    ///   or when a read returns 0 bytes while already at EOF.
    ///
    /// After an error has been yielded, the next call returns `None` once and
    /// clears the error flag, so the stream may be polled again afterwards.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;

        let mut pinned = self.project();
        let state: &mut ReadFrame = pinned.state.borrow_mut();
        // The following loop implements a state machine with each state corresponding
        // to a combination of the `is_readable` and `eof` flags. States persist across
        // loop entries and most state transitions occur with a return.
        //
        // A default `ReadFrame` starts in `reading`. A `ReadFrame` constructed with
        // `is_readable` already set starts in `framing` instead.
        //
        // | state   | eof   | is_readable | has_errored |
        // |---------|-------|-------------|-------------|
        // | reading | false | false       | false       |
        // | framing | false | true        | false       |
        // | pausing | true  | true        | false       |
        // | paused  | true  | false       | false       |
        // | errored | <any> | <any>       | true        |
        //                                                       `decode_eof` returns Err
        //                                          ┌────────────────────────────────────────────────────────┐
        //                   `decode_eof` returns   │                                                        │
        //                             `Ok(Some)`   │                                                        │
        //                                 ┌─────┐  │     `decode_eof` returns               After returning │
        //                Read 0 bytes     ├─────▼──┴┐    `Ok(None)`          ┌────────┐ ◄───┐ `None`    ┌───▼─────┐
        //               ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐   └───────────┤ Errored │
        //               │                 └─────────┘                        └─┬──▲───┘ │               └───▲───▲─┘
        // Pending read  │                                                      │  │     │                   │   │
        //     ┌──────┐  │            `decode` returns `Some`                   │  └─────┘                   │   │
        //     │      │  │                   ┌──────┐                           │  Pending                   │   │
        //     │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐     read n>0 bytes      │  read                      │   │
        //     └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘                            │   │
        //       └──┬─▲────┘                └─────┬──┬┘                                                      │   │
        //          │ │                           │  │                 `decode` returns Err                  │   │
        //          │ └──`decode` returns `None`──┘  └───────────────────────────────────────────────────────┘   │
        //          │                             read returns Err                                               │
        //          └────────────────────────────────────────────────────────────────────────────────────────────┘
        loop {
            // Return `None` if we have encountered an error, either from the decoder or
            // from reading the underlying I/O object.
            // See: https://github.com/tokio-rs/tokio/issues/3976
            if state.has_errored {
                // errored -> paused (when `eof` is set) or errored -> reading (otherwise):
                // only `is_readable` and `has_errored` are cleared; `eof` is left as is.
                trace!("Returning None and setting paused");
                state.is_readable = false;
                state.has_errored = false;
                return Poll::Ready(None);
            }

            // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
            // i.e. it _might_ contain data consumable as a frame or closing frame.
            // Both signal that there is no such data by returning `None`.
            //
            // If `decode` couldn't read a frame and the upstream source has returned eof,
            // `decode_eof` will attempt to decode the remaining bytes as closing frames.
            //
            // If the underlying AsyncRead is resumable, we may continue after an EOF,
            // but must finish emitting all of its associated `decode_eof` frames.
            // Furthermore, we don't want to emit any `decode_eof` frames on retried
            // reads after an EOF unless we've actually read more data.
            if state.is_readable {
                // pausing or framing
                if state.eof {
                    // pausing
                    // On failure, flag the error (so the next poll yields `None`) and hand
                    // the error to the caller through `?`.
                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
                        trace!("Got an error, going to errored state");
                        state.has_errored = true;
                        err
                    })?;
                    if frame.is_none() {
                        state.is_readable = false; // prepare pausing -> paused
                    }
                    // implicit pausing -> pausing or pausing -> paused
                    return Poll::Ready(frame.map(Ok));
                }

                // framing
                trace!("attempting to decode a frame");

                // Same error handling as for `decode_eof` above.
                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    op
                })? {
                    trace!("frame decoded from buffer");
                    // implicit framing -> framing
                    return Poll::Ready(Some(Ok(frame)));
                }

                // framing -> reading
                state.is_readable = false;
            }
            // reading or paused
            // If we can't build a frame yet, try to read more data and try again.
            // Make sure we've got room for at least one byte to read to ensure
            // that we don't get a spurious 0 that looks like EOF.
            state.buffer.reserve(1);
            // A read error is flagged like a decode error and returned through `?`.
            // The match scrutinee contains a closure with a block body, which clippy's
            // `blocks_in_conditions` lint would otherwise flag.
            #[allow(clippy::blocks_in_conditions)]
            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
                |err| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    err
                },
            )? {
                Poll::Ready(ct) => ct,
                // implicit reading -> reading or implicit paused -> paused
                Poll::Pending => return Poll::Pending,
            };
            if bytect == 0 {
                if state.eof {
                    // We're already at an EOF, and since we've reached this path
                    // we're also not readable. This implies that we've already finished
                    // our `decode_eof` handling, so we can simply return `None`.
                    // implicit paused -> paused
                    return Poll::Ready(None);
                }
                // prepare reading -> pausing (`is_readable` is set just below)
                state.eof = true;
            } else {
                // prepare paused -> framing or noop reading -> framing
                state.eof = false;
            }

            // paused -> framing or reading -> framing or reading -> pausing
            state.is_readable = true;
        }
    }
}
251
252impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W>
253where
254    T: AsyncWrite,
255    U: Encoder<I>,
256    U::Error: From<io::Error>,
257    W: BorrowMut<WriteFrame>,
258{
259    type Error = U::Error;
260
261    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262        if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary {
263            self.as_mut().poll_flush(cx)
264        } else {
265            Poll::Ready(Ok(()))
266        }
267    }
268
269    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
270        let pinned = self.project();
271        pinned
272            .codec
273            .encode(item, &mut pinned.state.borrow_mut().buffer)?;
274        Ok(())
275    }
276
277    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
278        use crate::util::poll_write_buf;
279        trace!("flushing framed transport");
280        let mut pinned = self.project();
281
282        while !pinned.state.borrow_mut().buffer.is_empty() {
283            let WriteFrame { buffer, .. } = pinned.state.borrow_mut();
284            trace!(remaining = buffer.len(), "writing;");
285
286            let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?;
287
288            if n == 0 {
289                return Poll::Ready(Err(io::Error::new(
290                    io::ErrorKind::WriteZero,
291                    "failed to \
292                     write frame to transport",
293                )
294                .into()));
295            }
296        }
297
298        // Try flushing the underlying IO
299        ready!(pinned.inner.poll_flush(cx))?;
300
301        trace!("framed transport flushed");
302        Poll::Ready(Ok(()))
303    }
304
305    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
306        ready!(self.as_mut().poll_flush(cx))?;
307        ready!(self.project().inner.poll_shutdown(cx))?;
308
309        Poll::Ready(Ok(()))
310    }
311}