h2/codec/framed_write.rs
use crate::codec::UserError;
use crate::codec::UserError::*;
use crate::frame::{self, Frame, FrameSize};
use crate::hpack;

use bytes::{Buf, BufMut, BytesMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::io::poll_write_buf;

use std::io::{self, Cursor};

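// Caps the write buffer at `max_frame_size + HEADER_LEN` bytes. A macro
// rather than a method so that other fields of `self` (e.g. `hpack`, `next`)
// stay borrowable while the limited buffer is alive.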
macro_rules! limited_write_buf {
    ($self:expr) => {{
        let limit = $self.max_frame_size() + frame::HEADER_LEN;
        $self.buf.get_mut().limit(limit)
    }};
}

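/// Writes HTTP/2 frames to an underlying `AsyncWrite`, buffering the encoded
/// bytes until they can be flushed.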
#[derive(Debug)]
pub struct FramedWrite<T, B> {
    /// Upstream `AsyncWrite`.
    inner: T,

    /// Whether the final flush before shutdown has completed.
    final_flush_done: bool,

    /// Frame encoder and its write buffer.
    encoder: Encoder<B>,
}

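/// Encodes frames into the write buffer and tracks any frame that must be
/// continued on a later flush.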
#[derive(Debug)]
struct Encoder<B> {
    /// HPACK encoder for HEADERS and PUSH_PROMISE frames.
    hpack: hpack::Encoder,

    /// Write buffer holding encoded frames.
    buf: Cursor<BytesMut>,

    /// Next frame to finish writing once the buffer has been flushed.
    next: Option<Next<B>>,

    /// Last DATA frame that was fully written.
    last_data_frame: Option<frame::Data<B>>,

    /// Max frame size, as specified by the peer.
    max_frame_size: FrameSize,

    /// DATA payloads at or above this size are chained rather than copied
    /// into the write buffer.
    chain_threshold: usize,

    /// Minimum free buffer space required to accept another frame.
    min_buffer_capacity: usize,
}

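/// Work carried over to the next flush: a DATA payload to chain onto the
/// buffer, or a CONTINUATION frame still to be encoded.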
#[derive(Debug)]
enum Next<B> {
    Data(frame::Data<B>),
    Continuation(frame::Continuation),
}

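/// Initial write buffer capacity. The minimum SETTINGS_MAX_FRAME_SIZE is
/// 16 KiB, so start with room for a frame of that size.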
const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;

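/// DATA payloads at least this large are chained onto the write buffer at
/// flush time instead of being copied into it; used when the writer supports
/// vectored I/O.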
const CHAIN_THRESHOLD: usize = 256;

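/// Chain threshold used when the writer does not support vectored I/O; the
/// larger value avoids issuing many small writes.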
const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024;

impl<T, B> FramedWrite<T, B>
where
    T: AsyncWrite + Unpin,
    B: Buf,
{
    pub fn new(inner: T) -> FramedWrite<T, B> {
        let chain_threshold = if inner.is_write_vectored() {
            CHAIN_THRESHOLD
        } else {
            CHAIN_THRESHOLD_WITHOUT_VECTORED_IO
        };
        FramedWrite {
            inner,
            final_flush_done: false,
            encoder: Encoder {
                hpack: hpack::Encoder::default(),
                buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
                next: None,
                last_data_frame: None,
                max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
                chain_threshold,
                min_buffer_capacity: chain_threshold + frame::HEADER_LEN,
            },
        }
    }

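    /// Returns `Ready` once the encoder can accept another frame, flushing
    /// buffered bytes to `T` if needed to free capacity.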
    pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        if !self.encoder.has_capacity() {
            // Try flushing to free up buffer capacity.
            ready!(self.flush(cx))?;

            if !self.encoder.has_capacity() {
                return Poll::Pending;
            }
        }

        Poll::Ready(Ok(()))
    }

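    /// Buffers a frame for sending.
    ///
    /// `poll_ready` must have returned `Ready` first; the encoder asserts
    /// that it has capacity for the frame.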
    pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
        self.encoder.buffer(item)
    }

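    /// Flushes buffered frames, including any chained DATA payload, to the
    /// underlying writer.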
    pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        let span = tracing::trace_span!("FramedWrite::flush");
        let _e = span.enter();

        loop {
            while !self.encoder.is_empty() {
                match self.encoder.next {
                    Some(Next::Data(ref mut frame)) => {
                        tracing::trace!(queued_data_frame = true);
                        // Write the buffered bytes chained with the queued
                        // DATA payload.
                        let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut());
                        ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))?
                    }
                    _ => {
                        tracing::trace!(queued_data_frame = false);
                        ready!(poll_write_buf(
                            Pin::new(&mut self.inner),
                            cx,
                            &mut self.encoder.buf
                        ))?
                    }
                };
            }

            match self.encoder.unset_frame() {
                ControlFlow::Continue => (),
                ControlFlow::Break => break,
            }
        }

        tracing::trace!("flushing buffer");
        // Flush the upstream writer.
        ready!(Pin::new(&mut self.inner).poll_flush(cx))?;

        Poll::Ready(Ok(()))
    }

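    /// Closes the codec: performs one final flush, then shuts down the
    /// underlying writer.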
    pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        if !self.final_flush_done {
            ready!(self.flush(cx))?;
            self.final_flush_done = true;
        }
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}

#[must_use]
enum ControlFlow {
    Continue,
    Break,
}

impl<B> Encoder<B>
where
    B: Buf,
{
    fn unset_frame(&mut self) -> ControlFlow {
        // Clear the internal write buffer.
        self.buf.set_position(0);
        self.buf.get_mut().clear();

        match self.next.take() {
            Some(Next::Data(frame)) => {
                // The chained payload has been written; keep the frame so the
                // caller can reclaim it.
                self.last_data_frame = Some(frame);
                debug_assert!(self.is_empty());
                ControlFlow::Break
            }
            Some(Next::Continuation(frame)) => {
                // Encode the next CONTINUATION chunk into the cleared buffer;
                // anything that still does not fit stays queued.
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = frame.encode(&mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
                ControlFlow::Continue
            }
            None => ControlFlow::Break,
        }
    }

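    /// Encodes `item` into the write buffer, queuing a chained DATA payload
    /// or a CONTINUATION frame in `next` when it cannot be written in one go.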
    fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
        // The caller must have checked capacity via `poll_ready`.
        assert!(self.has_capacity());
        let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item);
        let _e = span.enter();

        tracing::debug!(frame = ?item, "send");

        match item {
            Frame::Data(mut v) => {
                // The payload must not exceed the peer's max frame size.
                let len = v.payload().remaining();

                if len > self.max_frame_size() {
                    return Err(PayloadTooBig);
                }

                if len >= self.chain_threshold {
                    let head = v.head();

                    // Encode the frame head to the buffer.
                    head.encode(len, self.buf.get_mut());

                    // Copy enough of the payload into the buffer to reach the
                    // chain threshold, avoiding a tiny first write.
                    if self.buf.get_ref().remaining() < self.chain_threshold {
                        let extra_bytes = self.chain_threshold - self.buf.remaining();
                        self.buf.get_mut().put(v.payload_mut().take(extra_bytes));
                    }

                    // Chain the rest of the payload at flush time.
                    self.next = Some(Next::Data(v));
                } else {
                    // Small payloads are copied directly into the buffer.
                    v.encode_chunk(self.buf.get_mut());

                    // The chunk must have been fully encoded.
                    assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");

                    self.last_data_frame = Some(v);
                }
            }
            Frame::Headers(v) => {
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
            }
            Frame::PushPromise(v) => {
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
            }
            Frame::Settings(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded settings");
            }
            Frame::GoAway(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded go_away");
            }
            Frame::Ping(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded ping");
            }
            Frame::WindowUpdate(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded window_update");
            }

            Frame::Priority(_) => {
                // Sending PRIORITY frames is not implemented.
                unimplemented!();
            }
            Frame::Reset(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded reset");
            }
        }

        Ok(())
    }

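    /// Returns `true` when no frame is queued in `next` and the buffer has at
    /// least `min_buffer_capacity` bytes of free space.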
    fn has_capacity(&self) -> bool {
        self.next.is_none()
            && (self.buf.get_ref().capacity() - self.buf.get_ref().len()
                >= self.min_buffer_capacity)
    }

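    /// Returns `true` once the current contents have been written out: the
    /// chained DATA payload is exhausted, or, with no DATA frame queued, the
    /// write buffer is drained.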
    fn is_empty(&self) -> bool {
        match self.next {
            Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
            _ => !self.buf.has_remaining(),
        }
    }
}

impl<B> Encoder<B> {
    fn max_frame_size(&self) -> usize {
        self.max_frame_size as usize
    }
}

impl<T, B> FramedWrite<T, B> {
    /// Returns the max frame size that can be sent to the peer.
    pub fn max_frame_size(&self) -> usize {
        self.encoder.max_frame_size()
    }

    /// Sets the peer's max frame size.
    pub fn set_max_frame_size(&mut self, val: usize) {
        assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
        self.encoder.max_frame_size = val as FrameSize;
    }

    /// Sets the peer's header table size, resizing the HPACK encoder.
    pub fn set_header_table_size(&mut self, val: usize) {
        self.encoder.hpack.update_max_size(val);
    }

    /// Retrieves the last DATA frame that was fully written.
    pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
        self.encoder.last_data_frame.take()
    }

    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner
    }
}

impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf,
    ) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }
}

// `B` is never pinned, so `FramedWrite` is `Unpin` whenever `T` is.
impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}

#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        pub fn get_ref(&self) -> &T {
            &self.inner
        }
    }
}