h2/codec/
framed_write.rs

1use crate::codec::UserError;
2use crate::codec::UserError::*;
3use crate::frame::{self, Frame, FrameSize};
4use crate::hpack;
5
6use bytes::{Buf, BufMut, BytesMut};
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10use tokio_util::io::poll_write_buf;
11
12use std::io::{self, Cursor};
13
// Borrow the encoder's write buffer through a `limit` wrapper capped at one
// frame head plus the peer's max frame payload size.
//
// This is a macro rather than a method so that only the `buf` field is
// borrowed mutably; callers can still borrow other fields (such as `hpack`)
// while the limited buffer is alive.
macro_rules! limited_write_buf {
    ($encoder:expr) => {{
        let frame_limit = frame::HEADER_LEN + $encoder.max_frame_size();
        $encoder.buf.get_mut().limit(frame_limit)
    }};
}
21
/// Encodes HTTP/2 frames into an internal buffer and writes them to the
/// wrapped transport `T`. `B` is the payload buffer type of DATA frames.
#[derive(Debug)]
pub struct FramedWrite<T, B> {
    /// Upstream `AsyncWrite`
    inner: T,
    /// Set once `shutdown` has completed its final flush, so later calls to
    /// `shutdown` go straight to the transport's `poll_shutdown`.
    final_flush_done: bool,

    /// Frame encoder owning the write buffer and any partially written frame.
    encoder: Encoder<B>,
}
30
/// Serializes frames into a write buffer, and tracks a frame whose remaining
/// bytes must be produced after that buffer has been written out.
#[derive(Debug)]
struct Encoder<B> {
    /// HPACK encoder used for HEADERS and PUSH_PROMISE header blocks
    hpack: hpack::Encoder,

    /// Write buffer. The cursor position tracks how much of the buffer has
    /// already been written to the transport; the position is reset to 0 and
    /// the bytes cleared once everything pending has been written.
    ///
    /// TODO: Should this be a ring buffer?
    buf: Cursor<BytesMut>,

    /// Frame still in progress once `buf` drains: a large DATA frame whose
    /// payload is written directly from its own buffer, or a CONTINUATION
    /// carrying the rest of a header block.
    next: Option<Next<B>>,

    /// Most recent DATA frame whose payload has been consumed; handed back to
    /// the caller through `FramedWrite::take_last_data_frame`.
    last_data_frame: Option<frame::Data<B>>,

    /// Max frame size, this is specified by the peer
    max_frame_size: FrameSize,

    /// DATA payloads of at least this many bytes are chained (written from
    /// their own buffer) rather than copied into `buf`.
    chain_threshold: usize,

    /// Spare capacity `buf` must have before another frame is accepted
    /// (initialized to `chain_threshold + frame::HEADER_LEN`).
    min_buffer_capacity: usize,
}
56
/// Work left over after the current contents of the encoder's `buf` are
/// written.
#[derive(Debug)]
enum Next<B> {
    /// A DATA frame whose head (and possibly a prefix of its payload) is in
    /// `buf`; the rest of the payload is written straight from the frame.
    Data(frame::Data<B>),
    /// The remainder of a header block that did not fit in the size-limited
    /// buffer; it is encoded into `buf` once `buf` has been drained.
    Continuation(frame::Continuation),
}
62
/// Initialize the connection with this amount of write buffer.
///
/// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS
/// frame that big.
const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;

/// Chain payloads of at least this many bytes when vectored I/O is enabled.
/// The remote will never advertise a max frame size smaller than this (the
/// spec says the max frame size can't be less than 16kb, so not even close).
const CHAIN_THRESHOLD: usize = 256;

/// Chain payloads of at least this many bytes when vectored I/O is **not**
/// enabled. A larger value in this scenario reduces the amount of small,
/// fragmented data being sent, and thereby improves throughput.
const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024;
78
79// TODO: Make generic
80impl<T, B> FramedWrite<T, B>
81where
82    T: AsyncWrite + Unpin,
83    B: Buf,
84{
85    pub fn new(inner: T) -> FramedWrite<T, B> {
86        let chain_threshold = if inner.is_write_vectored() {
87            CHAIN_THRESHOLD
88        } else {
89            CHAIN_THRESHOLD_WITHOUT_VECTORED_IO
90        };
91        FramedWrite {
92            inner,
93            final_flush_done: false,
94            encoder: Encoder {
95                hpack: hpack::Encoder::default(),
96                buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
97                next: None,
98                last_data_frame: None,
99                max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
100                chain_threshold,
101                min_buffer_capacity: chain_threshold + frame::HEADER_LEN,
102            },
103        }
104    }
105
106    /// Returns `Ready` when `send` is able to accept a frame
107    ///
108    /// Calling this function may result in the current contents of the buffer
109    /// to be flushed to `T`.
110    pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
111        if !self.encoder.has_capacity() {
112            // Try flushing
113            ready!(self.flush(cx))?;
114
115            if !self.encoder.has_capacity() {
116                return Poll::Pending;
117            }
118        }
119
120        Poll::Ready(Ok(()))
121    }
122
123    /// Buffer a frame.
124    ///
125    /// `poll_ready` must be called first to ensure that a frame may be
126    /// accepted.
127    pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
128        self.encoder.buffer(item)
129    }
130
131    /// Flush buffered data to the wire
132    pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
133        let span = tracing::trace_span!("FramedWrite::flush");
134        let _e = span.enter();
135
136        loop {
137            while !self.encoder.is_empty() {
138                match self.encoder.next {
139                    Some(Next::Data(ref mut frame)) => {
140                        tracing::trace!(queued_data_frame = true);
141                        let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut());
142                        ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))?
143                    }
144                    _ => {
145                        tracing::trace!(queued_data_frame = false);
146                        ready!(poll_write_buf(
147                            Pin::new(&mut self.inner),
148                            cx,
149                            &mut self.encoder.buf
150                        ))?
151                    }
152                };
153            }
154
155            match self.encoder.unset_frame() {
156                ControlFlow::Continue => (),
157                ControlFlow::Break => break,
158            }
159        }
160
161        tracing::trace!("flushing buffer");
162        // Flush the upstream
163        ready!(Pin::new(&mut self.inner).poll_flush(cx))?;
164
165        Poll::Ready(Ok(()))
166    }
167
168    /// Close the codec
169    pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
170        if !self.final_flush_done {
171            ready!(self.flush(cx))?;
172            self.final_flush_done = true;
173        }
174        Pin::new(&mut self.inner).poll_shutdown(cx)
175    }
176}
177
/// Outcome of `Encoder::unset_frame`: whether `flush` must loop again to write
/// a freshly encoded CONTINUATION frame, or has nothing left to write.
#[must_use]
enum ControlFlow {
    /// Another frame was encoded into the buffer; keep writing.
    Continue,
    /// Nothing further to write.
    Break,
}
183
184impl<B> Encoder<B>
185where
186    B: Buf,
187{
188    fn unset_frame(&mut self) -> ControlFlow {
189        // Clear internal buffer
190        self.buf.set_position(0);
191        self.buf.get_mut().clear();
192
193        // The data frame has been written, so unset it
194        match self.next.take() {
195            Some(Next::Data(frame)) => {
196                self.last_data_frame = Some(frame);
197                debug_assert!(self.is_empty());
198                ControlFlow::Break
199            }
200            Some(Next::Continuation(frame)) => {
201                // Buffer the continuation frame, then try to write again
202                let mut buf = limited_write_buf!(self);
203                if let Some(continuation) = frame.encode(&mut buf) {
204                    self.next = Some(Next::Continuation(continuation));
205                }
206                ControlFlow::Continue
207            }
208            None => ControlFlow::Break,
209        }
210    }
211
212    fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
213        // Ensure that we have enough capacity to accept the write.
214        assert!(self.has_capacity());
215        let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item);
216        let _e = span.enter();
217
218        tracing::debug!(frame = ?item, "send");
219
220        match item {
221            Frame::Data(mut v) => {
222                // Ensure that the payload is not greater than the max frame.
223                let len = v.payload().remaining();
224
225                if len > self.max_frame_size() {
226                    return Err(PayloadTooBig);
227                }
228
229                if len >= self.chain_threshold {
230                    let head = v.head();
231
232                    // Encode the frame head to the buffer
233                    head.encode(len, self.buf.get_mut());
234
235                    if self.buf.get_ref().remaining() < self.chain_threshold {
236                        let extra_bytes = self.chain_threshold - self.buf.remaining();
237                        self.buf.get_mut().put(v.payload_mut().take(extra_bytes));
238                    }
239
240                    // Save the data frame
241                    self.next = Some(Next::Data(v));
242                } else {
243                    v.encode_chunk(self.buf.get_mut());
244
245                    // The chunk has been fully encoded, so there is no need to
246                    // keep it around
247                    assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");
248
249                    // Save off the last frame...
250                    self.last_data_frame = Some(v);
251                }
252            }
253            Frame::Headers(v) => {
254                let mut buf = limited_write_buf!(self);
255                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
256                    self.next = Some(Next::Continuation(continuation));
257                }
258            }
259            Frame::PushPromise(v) => {
260                let mut buf = limited_write_buf!(self);
261                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
262                    self.next = Some(Next::Continuation(continuation));
263                }
264            }
265            Frame::Settings(v) => {
266                v.encode(self.buf.get_mut());
267                tracing::trace!(rem = self.buf.remaining(), "encoded settings");
268            }
269            Frame::GoAway(v) => {
270                v.encode(self.buf.get_mut());
271                tracing::trace!(rem = self.buf.remaining(), "encoded go_away");
272            }
273            Frame::Ping(v) => {
274                v.encode(self.buf.get_mut());
275                tracing::trace!(rem = self.buf.remaining(), "encoded ping");
276            }
277            Frame::WindowUpdate(v) => {
278                v.encode(self.buf.get_mut());
279                tracing::trace!(rem = self.buf.remaining(), "encoded window_update");
280            }
281
282            Frame::Priority(_) => {
283                /*
284                v.encode(self.buf.get_mut());
285                tracing::trace!("encoded priority; rem={:?}", self.buf.remaining());
286                */
287                unimplemented!();
288            }
289            Frame::Reset(v) => {
290                v.encode(self.buf.get_mut());
291                tracing::trace!(rem = self.buf.remaining(), "encoded reset");
292            }
293        }
294
295        Ok(())
296    }
297
298    fn has_capacity(&self) -> bool {
299        self.next.is_none()
300            && (self.buf.get_ref().capacity() - self.buf.get_ref().len()
301                >= self.min_buffer_capacity)
302    }
303
304    fn is_empty(&self) -> bool {
305        match self.next {
306            Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
307            _ => !self.buf.has_remaining(),
308        }
309    }
310}
311
312impl<B> Encoder<B> {
313    fn max_frame_size(&self) -> usize {
314        self.max_frame_size as usize
315    }
316}
317
318impl<T, B> FramedWrite<T, B> {
319    /// Returns the max frame size that can be sent
320    pub fn max_frame_size(&self) -> usize {
321        self.encoder.max_frame_size()
322    }
323
324    /// Set the peer's max frame size.
325    pub fn set_max_frame_size(&mut self, val: usize) {
326        assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
327        self.encoder.max_frame_size = val as FrameSize;
328    }
329
330    /// Set the peer's header table size.
331    pub fn set_header_table_size(&mut self, val: usize) {
332        self.encoder.hpack.update_max_size(val);
333    }
334
335    /// Retrieve the last data frame that has been sent
336    pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
337        self.encoder.last_data_frame.take()
338    }
339
340    pub fn get_mut(&mut self) -> &mut T {
341        &mut self.inner
342    }
343}
344
345impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> {
346    fn poll_read(
347        mut self: Pin<&mut Self>,
348        cx: &mut Context<'_>,
349        buf: &mut ReadBuf,
350    ) -> Poll<io::Result<()>> {
351        Pin::new(&mut self.inner).poll_read(cx, buf)
352    }
353}
354
// We never project the Pin to `B`: the transport is only ever pinned afresh
// via `Pin::new(&mut self.inner)`, which needs `T: Unpin`. So `FramedWrite`
// is `Unpin` whenever `T` is, regardless of `B`.
impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}
357
#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        /// Shared access to the wrapped transport.
        pub fn get_ref(&self) -> &T {
            let Self { inner, .. } = self;
            inner
        }
    }
}