// hyper/proto/h1/decode.rs

1use std::error::Error as StdError;
2use std::fmt;
3use std::io;
4use std::task::{Context, Poll};
5
6use bytes::{BufMut, Bytes, BytesMut};
7use futures_util::ready;
8use http::{HeaderMap, HeaderName, HeaderValue};
9use http_body::Frame;
10
11use super::io::MemRead;
12use super::role::DEFAULT_MAX_HEADERS;
13use super::DecodedLength;
14
15use self::Kind::{Chunked, Eof, Length};
16
/// Maximum number of bytes allowed in chunked extensions.
///
/// This limit is currently applied to the entire body, not per chunk.
const CHUNKED_EXTENSIONS_LIMIT: u64 = 1024 * 16;
21
/// Default maximum number of bytes allowed for all trailer fields combined.
///
/// Used by the chunked decoder when no `h1_max_header_size` is configured.
const TRAILER_LIMIT: usize = 1024 * 16;
26
27/// Decoders to handle different Transfer-Encodings.
28///
29/// If a message body does not include a Transfer-Encoding, it *should*
30/// include a Content-Length header.
31#[derive(Clone, PartialEq)]
32pub(crate) struct Decoder {
33    kind: Kind,
34}
35
36#[derive(Debug, Clone, PartialEq)]
37enum Kind {
38    /// A Reader used when a Content-Length header is passed with a positive integer.
39    Length(u64),
40    /// A Reader used when Transfer-Encoding is `chunked`.
41    Chunked {
42        state: ChunkedState,
43        chunk_len: u64,
44        extensions_cnt: u64,
45        trailers_buf: Option<BytesMut>,
46        trailers_cnt: usize,
47        h1_max_headers: Option<usize>,
48        h1_max_header_size: Option<usize>,
49    },
50    /// A Reader used for responses that don't indicate a length or chunked.
51    ///
52    /// The bool tracks when EOF is seen on the transport.
53    ///
54    /// Note: This should only used for `Response`s. It is illegal for a
55    /// `Request` to be made with both `Content-Length` and
56    /// `Transfer-Encoding: chunked` missing, as explained from the spec:
57    ///
58    /// > If a Transfer-Encoding header field is present in a response and
59    /// > the chunked transfer coding is not the final encoding, the
60    /// > message body length is determined by reading the connection until
61    /// > it is closed by the server.  If a Transfer-Encoding header field
62    /// > is present in a request and the chunked transfer coding is not
63    /// > the final encoding, the message body length cannot be determined
64    /// > reliably; the server MUST respond with the 400 (Bad Request)
65    /// > status code and then close the connection.
66    Eof(bool),
67}
68
/// States of the chunked transfer-coding parser.
///
/// Each call to `ChunkedState::step` consumes input for the current state
/// and yields the next one.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ChunkedState {
    /// Expecting the first hex digit of a chunk-size line.
    Start,
    /// Reading further hex digits of the chunk size.
    Size,
    /// Skipping spaces or tabs that follow the chunk size.
    SizeLws,
    /// Skipping chunk-extension bytes up to the CR.
    Extension,
    /// Expecting the LF that ends the chunk-size line.
    SizeLf,
    /// Reading chunk data.
    Body,
    /// Expecting the CR after chunk data.
    BodyCr,
    /// Expecting the LF after chunk data.
    BodyLf,
    /// Buffering a trailer field line up to its CR.
    Trailer,
    /// Expecting the LF that ends a trailer field line.
    TrailerLf,
    /// After the last chunk or a trailer line: a CR heads toward the end,
    /// any other byte starts a trailer field.
    EndCr,
    /// Expecting the final LF of the body.
    EndLf,
    /// The whole chunked body has been consumed.
    End,
}
85
86impl Decoder {
87    // constructors
88
89    pub(crate) fn length(x: u64) -> Decoder {
90        Decoder {
91            kind: Kind::Length(x),
92        }
93    }
94
95    pub(crate) fn chunked(
96        h1_max_headers: Option<usize>,
97        h1_max_header_size: Option<usize>,
98    ) -> Decoder {
99        Decoder {
100            kind: Kind::Chunked {
101                state: ChunkedState::new(),
102                chunk_len: 0,
103                extensions_cnt: 0,
104                trailers_buf: None,
105                trailers_cnt: 0,
106                h1_max_headers,
107                h1_max_header_size,
108            },
109        }
110    }
111
112    pub(crate) fn eof() -> Decoder {
113        Decoder {
114            kind: Kind::Eof(false),
115        }
116    }
117
118    pub(super) fn new(
119        len: DecodedLength,
120        h1_max_headers: Option<usize>,
121        h1_max_header_size: Option<usize>,
122    ) -> Self {
123        match len {
124            DecodedLength::CHUNKED => Decoder::chunked(h1_max_headers, h1_max_header_size),
125            DecodedLength::CLOSE_DELIMITED => Decoder::eof(),
126            length => Decoder::length(length.danger_len()),
127        }
128    }
129
130    // methods
131
132    pub(crate) fn is_eof(&self) -> bool {
133        matches!(
134            self.kind,
135            Length(0)
136                | Chunked {
137                    state: ChunkedState::End,
138                    ..
139                }
140                | Eof(true)
141        )
142    }
143
144    pub(crate) fn decode<R: MemRead>(
145        &mut self,
146        cx: &mut Context<'_>,
147        body: &mut R,
148    ) -> Poll<Result<Frame<Bytes>, io::Error>> {
149        trace!("decode; state={:?}", self.kind);
150        match self.kind {
151            Length(ref mut remaining) => {
152                if *remaining == 0 {
153                    Poll::Ready(Ok(Frame::data(Bytes::new())))
154                } else {
155                    let to_read = *remaining as usize;
156                    let buf = ready!(body.read_mem(cx, to_read))?;
157                    let num = buf.as_ref().len() as u64;
158                    if num > *remaining {
159                        *remaining = 0;
160                    } else if num == 0 {
161                        return Poll::Ready(Err(io::Error::new(
162                            io::ErrorKind::UnexpectedEof,
163                            IncompleteBody,
164                        )));
165                    } else {
166                        *remaining -= num;
167                    }
168                    Poll::Ready(Ok(Frame::data(buf)))
169                }
170            }
171            Chunked {
172                ref mut state,
173                ref mut chunk_len,
174                ref mut extensions_cnt,
175                ref mut trailers_buf,
176                ref mut trailers_cnt,
177                ref h1_max_headers,
178                ref h1_max_header_size,
179            } => {
180                let h1_max_headers = h1_max_headers.unwrap_or(DEFAULT_MAX_HEADERS);
181                let h1_max_header_size = h1_max_header_size.unwrap_or(TRAILER_LIMIT);
182                loop {
183                    let mut buf = None;
184                    // advances the chunked state
185                    *state = ready!(state.step(
186                        cx,
187                        body,
188                        chunk_len,
189                        extensions_cnt,
190                        &mut buf,
191                        trailers_buf,
192                        trailers_cnt,
193                        h1_max_headers,
194                        h1_max_header_size
195                    ))?;
196                    if *state == ChunkedState::End {
197                        trace!("end of chunked");
198
199                        if trailers_buf.is_some() {
200                            trace!("found possible trailers");
201
202                            // decoder enforces that trailers count will not exceed h1_max_headers
203                            if *trailers_cnt >= h1_max_headers {
204                                return Poll::Ready(Err(io::Error::new(
205                                    io::ErrorKind::InvalidData,
206                                    "chunk trailers count overflow",
207                                )));
208                            }
209                            match decode_trailers(
210                                &mut trailers_buf.take().expect("Trailer is None"),
211                                *trailers_cnt,
212                            ) {
213                                Ok(headers) => {
214                                    return Poll::Ready(Ok(Frame::trailers(headers)));
215                                }
216                                Err(e) => {
217                                    return Poll::Ready(Err(e));
218                                }
219                            }
220                        }
221
222                        return Poll::Ready(Ok(Frame::data(Bytes::new())));
223                    }
224                    if let Some(buf) = buf {
225                        return Poll::Ready(Ok(Frame::data(buf)));
226                    }
227                }
228            }
229            Eof(ref mut is_eof) => {
230                if *is_eof {
231                    Poll::Ready(Ok(Frame::data(Bytes::new())))
232                } else {
233                    // 8192 chosen because its about 2 packets, there probably
234                    // won't be that much available, so don't have MemReaders
235                    // allocate buffers to big
236                    body.read_mem(cx, 8192).map_ok(|slice| {
237                        *is_eof = slice.is_empty();
238                        Frame::data(slice)
239                    })
240                }
241            }
242        }
243    }
244
245    #[cfg(test)]
246    async fn decode_fut<R: MemRead>(&mut self, body: &mut R) -> Result<Frame<Bytes>, io::Error> {
247        futures_util::future::poll_fn(move |cx| self.decode(cx, body)).await
248    }
249}
250
251impl fmt::Debug for Decoder {
252    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253        fmt::Debug::fmt(&self.kind, f)
254    }
255}
256
// Reads a single byte from `$rdr` (an identifier bound to a reader with a
// `read_mem(cx, len)` method).
// - If the read is pending, returns `Poll::Pending` from the enclosing
//   function (via `ready!`).
// - Read errors are propagated with `?`.
// - If the reader yields no bytes, the enclosing function returns an
//   `UnexpectedEof` error.
macro_rules! byte {
    ($rdr:ident, $cx:expr) => {{
        let buf = ready!($rdr.read_mem($cx, 1))?;
        if buf.is_empty() {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "unexpected EOF during chunk size line",
            )));
        }
        buf[0]
    }};
}
268
// Unwraps the `Option` produced by a checked arithmetic operation.
// On `None` (overflow), the enclosing function returns an `InvalidData`
// "invalid chunk size: overflow" error.
macro_rules! or_overflow {
    ($e:expr) => {
        if let Some(val) = $e {
            val
        } else {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid chunk size: overflow",
            )));
        }
    };
}
280
// Appends `$byte` to the trailer buffer. Once the buffer length reaches
// `$limit`, the enclosing function returns an `InvalidData`
// "chunk trailers bytes over limit" error. The buffer therefore never holds
// `$limit` bytes on the success path.
// `$trailers_buf` is evaluated twice, so pass a cheap place expression.
macro_rules! put_u8 {
    ($trailers_buf:expr, $byte:expr, $limit:expr) => {
        $trailers_buf.put_u8($byte);
        if $limit <= $trailers_buf.len() {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "chunk trailers bytes over limit",
            )));
        }
    };
}
293
294impl ChunkedState {
295    fn new() -> ChunkedState {
296        ChunkedState::Start
297    }
298    fn step<R: MemRead>(
299        &self,
300        cx: &mut Context<'_>,
301        body: &mut R,
302        size: &mut u64,
303        extensions_cnt: &mut u64,
304        buf: &mut Option<Bytes>,
305        trailers_buf: &mut Option<BytesMut>,
306        trailers_cnt: &mut usize,
307        h1_max_headers: usize,
308        h1_max_header_size: usize,
309    ) -> Poll<Result<ChunkedState, io::Error>> {
310        use self::ChunkedState::*;
311        match *self {
312            Start => ChunkedState::read_start(cx, body, size),
313            Size => ChunkedState::read_size(cx, body, size),
314            SizeLws => ChunkedState::read_size_lws(cx, body),
315            Extension => ChunkedState::read_extension(cx, body, extensions_cnt),
316            SizeLf => ChunkedState::read_size_lf(cx, body, *size),
317            Body => ChunkedState::read_body(cx, body, size, buf),
318            BodyCr => ChunkedState::read_body_cr(cx, body),
319            BodyLf => ChunkedState::read_body_lf(cx, body),
320            Trailer => ChunkedState::read_trailer(cx, body, trailers_buf, h1_max_header_size),
321            TrailerLf => ChunkedState::read_trailer_lf(
322                cx,
323                body,
324                trailers_buf,
325                trailers_cnt,
326                h1_max_headers,
327                h1_max_header_size,
328            ),
329            EndCr => ChunkedState::read_end_cr(cx, body, trailers_buf, h1_max_header_size),
330            EndLf => ChunkedState::read_end_lf(cx, body, trailers_buf, h1_max_header_size),
331            End => Poll::Ready(Ok(ChunkedState::End)),
332        }
333    }
334
335    fn read_start<R: MemRead>(
336        cx: &mut Context<'_>,
337        rdr: &mut R,
338        size: &mut u64,
339    ) -> Poll<Result<ChunkedState, io::Error>> {
340        trace!("Read chunk start");
341
342        let radix = 16;
343        match byte!(rdr, cx) {
344            b @ b'0'..=b'9' => {
345                *size = or_overflow!(size.checked_mul(radix));
346                *size = or_overflow!(size.checked_add((b - b'0') as u64));
347            }
348            b @ b'a'..=b'f' => {
349                *size = or_overflow!(size.checked_mul(radix));
350                *size = or_overflow!(size.checked_add((b + 10 - b'a') as u64));
351            }
352            b @ b'A'..=b'F' => {
353                *size = or_overflow!(size.checked_mul(radix));
354                *size = or_overflow!(size.checked_add((b + 10 - b'A') as u64));
355            }
356            _ => {
357                return Poll::Ready(Err(io::Error::new(
358                    io::ErrorKind::InvalidInput,
359                    "Invalid chunk size line: missing size digit",
360                )));
361            }
362        }
363
364        Poll::Ready(Ok(ChunkedState::Size))
365    }
366
367    fn read_size<R: MemRead>(
368        cx: &mut Context<'_>,
369        rdr: &mut R,
370        size: &mut u64,
371    ) -> Poll<Result<ChunkedState, io::Error>> {
372        trace!("Read chunk hex size");
373
374        let radix = 16;
375        match byte!(rdr, cx) {
376            b @ b'0'..=b'9' => {
377                *size = or_overflow!(size.checked_mul(radix));
378                *size = or_overflow!(size.checked_add((b - b'0') as u64));
379            }
380            b @ b'a'..=b'f' => {
381                *size = or_overflow!(size.checked_mul(radix));
382                *size = or_overflow!(size.checked_add((b + 10 - b'a') as u64));
383            }
384            b @ b'A'..=b'F' => {
385                *size = or_overflow!(size.checked_mul(radix));
386                *size = or_overflow!(size.checked_add((b + 10 - b'A') as u64));
387            }
388            b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)),
389            b';' => return Poll::Ready(Ok(ChunkedState::Extension)),
390            b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)),
391            _ => {
392                return Poll::Ready(Err(io::Error::new(
393                    io::ErrorKind::InvalidInput,
394                    "Invalid chunk size line: Invalid Size",
395                )));
396            }
397        }
398        Poll::Ready(Ok(ChunkedState::Size))
399    }
400    fn read_size_lws<R: MemRead>(
401        cx: &mut Context<'_>,
402        rdr: &mut R,
403    ) -> Poll<Result<ChunkedState, io::Error>> {
404        trace!("read_size_lws");
405        match byte!(rdr, cx) {
406            // LWS can follow the chunk size, but no more digits can come
407            b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)),
408            b';' => Poll::Ready(Ok(ChunkedState::Extension)),
409            b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
410            _ => Poll::Ready(Err(io::Error::new(
411                io::ErrorKind::InvalidInput,
412                "Invalid chunk size linear white space",
413            ))),
414        }
415    }
416    fn read_extension<R: MemRead>(
417        cx: &mut Context<'_>,
418        rdr: &mut R,
419        extensions_cnt: &mut u64,
420    ) -> Poll<Result<ChunkedState, io::Error>> {
421        trace!("read_extension");
422        // We don't care about extensions really at all. Just ignore them.
423        // They "end" at the next CRLF.
424        //
425        // However, some implementations may not check for the CR, so to save
426        // them from themselves, we reject extensions containing plain LF as
427        // well.
428        match byte!(rdr, cx) {
429            b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
430            b'\n' => Poll::Ready(Err(io::Error::new(
431                io::ErrorKind::InvalidData,
432                "invalid chunk extension contains newline",
433            ))),
434            _ => {
435                *extensions_cnt += 1;
436                if *extensions_cnt >= CHUNKED_EXTENSIONS_LIMIT {
437                    Poll::Ready(Err(io::Error::new(
438                        io::ErrorKind::InvalidData,
439                        "chunk extensions over limit",
440                    )))
441                } else {
442                    Poll::Ready(Ok(ChunkedState::Extension))
443                }
444            } // no supported extensions
445        }
446    }
447    fn read_size_lf<R: MemRead>(
448        cx: &mut Context<'_>,
449        rdr: &mut R,
450        size: u64,
451    ) -> Poll<Result<ChunkedState, io::Error>> {
452        trace!("Chunk size is {:?}", size);
453        match byte!(rdr, cx) {
454            b'\n' => {
455                if size == 0 {
456                    Poll::Ready(Ok(ChunkedState::EndCr))
457                } else {
458                    debug!("incoming chunked header: {0:#X} ({0} bytes)", size);
459                    Poll::Ready(Ok(ChunkedState::Body))
460                }
461            }
462            _ => Poll::Ready(Err(io::Error::new(
463                io::ErrorKind::InvalidInput,
464                "Invalid chunk size LF",
465            ))),
466        }
467    }
468
469    fn read_body<R: MemRead>(
470        cx: &mut Context<'_>,
471        rdr: &mut R,
472        rem: &mut u64,
473        buf: &mut Option<Bytes>,
474    ) -> Poll<Result<ChunkedState, io::Error>> {
475        trace!("Chunked read, remaining={:?}", rem);
476
477        // cap remaining bytes at the max capacity of usize
478        let rem_cap = match *rem {
479            r if r > usize::MAX as u64 => usize::MAX,
480            r => r as usize,
481        };
482
483        let to_read = rem_cap;
484        let slice = ready!(rdr.read_mem(cx, to_read))?;
485        let count = slice.len();
486
487        if count == 0 {
488            *rem = 0;
489            return Poll::Ready(Err(io::Error::new(
490                io::ErrorKind::UnexpectedEof,
491                IncompleteBody,
492            )));
493        }
494        *buf = Some(slice);
495        *rem -= count as u64;
496
497        if *rem > 0 {
498            Poll::Ready(Ok(ChunkedState::Body))
499        } else {
500            Poll::Ready(Ok(ChunkedState::BodyCr))
501        }
502    }
503    fn read_body_cr<R: MemRead>(
504        cx: &mut Context<'_>,
505        rdr: &mut R,
506    ) -> Poll<Result<ChunkedState, io::Error>> {
507        match byte!(rdr, cx) {
508            b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)),
509            _ => Poll::Ready(Err(io::Error::new(
510                io::ErrorKind::InvalidInput,
511                "Invalid chunk body CR",
512            ))),
513        }
514    }
515    fn read_body_lf<R: MemRead>(
516        cx: &mut Context<'_>,
517        rdr: &mut R,
518    ) -> Poll<Result<ChunkedState, io::Error>> {
519        match byte!(rdr, cx) {
520            b'\n' => Poll::Ready(Ok(ChunkedState::Start)),
521            _ => Poll::Ready(Err(io::Error::new(
522                io::ErrorKind::InvalidInput,
523                "Invalid chunk body LF",
524            ))),
525        }
526    }
527
528    fn read_trailer<R: MemRead>(
529        cx: &mut Context<'_>,
530        rdr: &mut R,
531        trailers_buf: &mut Option<BytesMut>,
532        h1_max_header_size: usize,
533    ) -> Poll<Result<ChunkedState, io::Error>> {
534        trace!("read_trailer");
535        let byte = byte!(rdr, cx);
536
537        put_u8!(
538            trailers_buf.as_mut().expect("trailers_buf is None"),
539            byte,
540            h1_max_header_size
541        );
542
543        match byte {
544            b'\r' => Poll::Ready(Ok(ChunkedState::TrailerLf)),
545            _ => Poll::Ready(Ok(ChunkedState::Trailer)),
546        }
547    }
548
549    fn read_trailer_lf<R: MemRead>(
550        cx: &mut Context<'_>,
551        rdr: &mut R,
552        trailers_buf: &mut Option<BytesMut>,
553        trailers_cnt: &mut usize,
554        h1_max_headers: usize,
555        h1_max_header_size: usize,
556    ) -> Poll<Result<ChunkedState, io::Error>> {
557        let byte = byte!(rdr, cx);
558        match byte {
559            b'\n' => {
560                if *trailers_cnt >= h1_max_headers {
561                    return Poll::Ready(Err(io::Error::new(
562                        io::ErrorKind::InvalidData,
563                        "chunk trailers count overflow",
564                    )));
565                }
566                *trailers_cnt += 1;
567
568                put_u8!(
569                    trailers_buf.as_mut().expect("trailers_buf is None"),
570                    byte,
571                    h1_max_header_size
572                );
573
574                Poll::Ready(Ok(ChunkedState::EndCr))
575            }
576            _ => Poll::Ready(Err(io::Error::new(
577                io::ErrorKind::InvalidInput,
578                "Invalid trailer end LF",
579            ))),
580        }
581    }
582
583    fn read_end_cr<R: MemRead>(
584        cx: &mut Context<'_>,
585        rdr: &mut R,
586        trailers_buf: &mut Option<BytesMut>,
587        h1_max_header_size: usize,
588    ) -> Poll<Result<ChunkedState, io::Error>> {
589        let byte = byte!(rdr, cx);
590        match byte {
591            b'\r' => {
592                if let Some(trailers_buf) = trailers_buf {
593                    put_u8!(trailers_buf, byte, h1_max_header_size);
594                }
595                Poll::Ready(Ok(ChunkedState::EndLf))
596            }
597            byte => {
598                match trailers_buf {
599                    None => {
600                        // 64 will fit a single Expires header without reallocating
601                        let mut buf = BytesMut::with_capacity(64);
602                        buf.put_u8(byte);
603                        *trailers_buf = Some(buf);
604                    }
605                    Some(ref mut trailers_buf) => {
606                        put_u8!(trailers_buf, byte, h1_max_header_size);
607                    }
608                }
609
610                Poll::Ready(Ok(ChunkedState::Trailer))
611            }
612        }
613    }
614    fn read_end_lf<R: MemRead>(
615        cx: &mut Context<'_>,
616        rdr: &mut R,
617        trailers_buf: &mut Option<BytesMut>,
618        h1_max_header_size: usize,
619    ) -> Poll<Result<ChunkedState, io::Error>> {
620        let byte = byte!(rdr, cx);
621        match byte {
622            b'\n' => {
623                if let Some(trailers_buf) = trailers_buf {
624                    put_u8!(trailers_buf, byte, h1_max_header_size);
625                }
626                Poll::Ready(Ok(ChunkedState::End))
627            }
628            _ => Poll::Ready(Err(io::Error::new(
629                io::ErrorKind::InvalidInput,
630                "Invalid chunk end LF",
631            ))),
632        }
633    }
634}
635
636// TODO: disallow Transfer-Encoding, Content-Length, Trailer, etc in trailers ??
637fn decode_trailers(buf: &mut BytesMut, count: usize) -> Result<HeaderMap, io::Error> {
638    let mut trailers = HeaderMap::new();
639    let mut headers = vec![httparse::EMPTY_HEADER; count];
640    let res = httparse::parse_headers(buf, &mut headers);
641    match res {
642        Ok(httparse::Status::Complete((_, headers))) => {
643            for header in headers.iter() {
644                use std::convert::TryFrom;
645                let name = match HeaderName::try_from(header.name) {
646                    Ok(name) => name,
647                    Err(_) => {
648                        return Err(io::Error::new(
649                            io::ErrorKind::InvalidInput,
650                            format!("Invalid header name: {:?}", &header),
651                        ));
652                    }
653                };
654
655                let value = match HeaderValue::from_bytes(header.value) {
656                    Ok(value) => value,
657                    Err(_) => {
658                        return Err(io::Error::new(
659                            io::ErrorKind::InvalidInput,
660                            format!("Invalid header value: {:?}", &header),
661                        ));
662                    }
663                };
664
665                trailers.insert(name, value);
666            }
667
668            Ok(trailers)
669        }
670        Ok(httparse::Status::Partial) => Err(io::Error::new(
671            io::ErrorKind::InvalidInput,
672            "Partial header",
673        )),
674        Err(e) => Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
675    }
676}
677
/// Error payload for a transport that hits EOF before the expected body
/// bytes (the remaining `Content-Length`, or the rest of the current chunk)
/// have arrived.
#[derive(Debug)]
struct IncompleteBody;

impl fmt::Display for IncompleteBody {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("end of file before message length reached")
    }
}

impl StdError for IncompleteBody {}
688
689#[cfg(test)]
690mod tests {
691    use super::*;
692    use crate::rt::{Read, ReadBuf};
693    use std::pin::Pin;
694    use std::time::Duration;
695
696    impl MemRead for &[u8] {
697        fn read_mem(&mut self, _: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
698            let n = std::cmp::min(len, self.len());
699            if n > 0 {
700                let (a, b) = self.split_at(n);
701                let buf = Bytes::copy_from_slice(a);
702                *self = b;
703                Poll::Ready(Ok(buf))
704            } else {
705                Poll::Ready(Ok(Bytes::new()))
706            }
707        }
708    }
709
710    impl MemRead for &mut (dyn Read + Unpin) {
711        fn read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
712            let mut v = vec![0; len];
713            let mut buf = ReadBuf::new(&mut v);
714            ready!(Pin::new(self).poll_read(cx, buf.unfilled())?);
715            Poll::Ready(Ok(Bytes::copy_from_slice(buf.filled())))
716        }
717    }
718
719    impl MemRead for Bytes {
720        fn read_mem(&mut self, _: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
721            let n = std::cmp::min(len, self.len());
722            let ret = self.split_to(n);
723            Poll::Ready(Ok(ret))
724        }
725    }
726
727    /*
728    use std::io;
729    use std::io::Write;
730    use super::Decoder;
731    use super::ChunkedState;
732    use futures::{Async, Poll};
733    use bytes::{BytesMut, Bytes};
734    use crate::mock::AsyncIo;
735    */
736
737    #[cfg(not(miri))]
738    #[tokio::test]
739    async fn test_read_chunk_size() {
740        use std::io::ErrorKind::{InvalidData, InvalidInput, UnexpectedEof};
741
742        async fn read(s: &str) -> u64 {
743            let mut state = ChunkedState::new();
744            let rdr = &mut s.as_bytes();
745            let mut size = 0;
746            let mut ext_cnt = 0;
747            let mut trailers_cnt = 0;
748            loop {
749                let result = futures_util::future::poll_fn(|cx| {
750                    state.step(
751                        cx,
752                        rdr,
753                        &mut size,
754                        &mut ext_cnt,
755                        &mut None,
756                        &mut None,
757                        &mut trailers_cnt,
758                        DEFAULT_MAX_HEADERS,
759                        TRAILER_LIMIT,
760                    )
761                })
762                .await;
763                let desc = format!("read_size failed for {:?}", s);
764                state = result.expect(&desc);
765                if state == ChunkedState::Body || state == ChunkedState::EndCr {
766                    break;
767                }
768            }
769            size
770        }
771
772        async fn read_err(s: &str, expected_err: io::ErrorKind) {
773            let mut state = ChunkedState::new();
774            let rdr = &mut s.as_bytes();
775            let mut size = 0;
776            let mut ext_cnt = 0;
777            let mut trailers_cnt = 0;
778            loop {
779                let result = futures_util::future::poll_fn(|cx| {
780                    state.step(
781                        cx,
782                        rdr,
783                        &mut size,
784                        &mut ext_cnt,
785                        &mut None,
786                        &mut None,
787                        &mut trailers_cnt,
788                        DEFAULT_MAX_HEADERS,
789                        TRAILER_LIMIT,
790                    )
791                })
792                .await;
793                state = match result {
794                    Ok(s) => s,
795                    Err(e) => {
796                        assert!(
797                            expected_err == e.kind(),
798                            "Reading {:?}, expected {:?}, but got {:?}",
799                            s,
800                            expected_err,
801                            e.kind()
802                        );
803                        return;
804                    }
805                };
806                if state == ChunkedState::Body || state == ChunkedState::End {
807                    panic!("Was Ok. Expected Err for {:?}", s);
808                }
809            }
810        }
811
812        assert_eq!(1, read("1\r\n").await);
813        assert_eq!(1, read("01\r\n").await);
814        assert_eq!(0, read("0\r\n").await);
815        assert_eq!(0, read("00\r\n").await);
816        assert_eq!(10, read("A\r\n").await);
817        assert_eq!(10, read("a\r\n").await);
818        assert_eq!(255, read("Ff\r\n").await);
819        assert_eq!(255, read("Ff   \r\n").await);
820        // Missing LF or CRLF
821        read_err("F\rF", InvalidInput).await;
822        read_err("F", UnexpectedEof).await;
823        // Missing digit
824        read_err("\r\n\r\n", InvalidInput).await;
825        read_err("\r\n", InvalidInput).await;
826        // Invalid hex digit
827        read_err("X\r\n", InvalidInput).await;
828        read_err("1X\r\n", InvalidInput).await;
829        read_err("-\r\n", InvalidInput).await;
830        read_err("-1\r\n", InvalidInput).await;
831        // Acceptable (if not fully valid) extensions do not influence the size
832        assert_eq!(1, read("1;extension\r\n").await);
833        assert_eq!(10, read("a;ext name=value\r\n").await);
834        assert_eq!(1, read("1;extension;extension2\r\n").await);
835        assert_eq!(1, read("1;;;  ;\r\n").await);
836        assert_eq!(2, read("2; extension...\r\n").await);
837        assert_eq!(3, read("3   ; extension=123\r\n").await);
838        assert_eq!(3, read("3   ;\r\n").await);
839        assert_eq!(3, read("3   ;   \r\n").await);
840        // Invalid extensions cause an error
841        read_err("1 invalid extension\r\n", InvalidInput).await;
842        read_err("1 A\r\n", InvalidInput).await;
843        read_err("1;no CRLF", UnexpectedEof).await;
844        read_err("1;reject\nnewlines\r\n", InvalidData).await;
845        // Overflow
846        read_err("f0000000000000003\r\n", InvalidData).await;
847    }
848
849    #[cfg(not(miri))]
850    #[tokio::test]
851    async fn test_read_sized_early_eof() {
852        let mut bytes = &b"foo bar"[..];
853        let mut decoder = Decoder::length(10);
854        assert_eq!(
855            decoder
856                .decode_fut(&mut bytes)
857                .await
858                .unwrap()
859                .data_ref()
860                .unwrap()
861                .len(),
862            7
863        );
864        let e = decoder.decode_fut(&mut bytes).await.unwrap_err();
865        assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof);
866    }
867
868    #[cfg(not(miri))]
869    #[tokio::test]
870    async fn test_read_chunked_early_eof() {
871        let mut bytes = &b"\
872            9\r\n\
873            foo bar\
874        "[..];
875        let mut decoder = Decoder::chunked(None, None);
876        assert_eq!(
877            decoder
878                .decode_fut(&mut bytes)
879                .await
880                .unwrap()
881                .data_ref()
882                .unwrap()
883                .len(),
884            7
885        );
886        let e = decoder.decode_fut(&mut bytes).await.unwrap_err();
887        assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof);
888    }
889
890    #[cfg(not(miri))]
891    #[tokio::test]
892    async fn test_read_chunked_single_read() {
893        let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..];
894        let buf = Decoder::chunked(None, None)
895            .decode_fut(&mut mock_buf)
896            .await
897            .expect("decode")
898            .into_data()
899            .expect("unknown frame type");
900        assert_eq!(16, buf.len());
901        let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
902        assert_eq!("1234567890abcdef", &result);
903    }
904
905    #[tokio::test]
906    async fn test_read_chunked_with_missing_zero_digit() {
907        // After reading a valid chunk, the ending is missing a zero.
908        let mut mock_buf = &b"1\r\nZ\r\n\r\n\r\n"[..];
909        let mut decoder = Decoder::chunked(None, None);
910        let buf = decoder
911            .decode_fut(&mut mock_buf)
912            .await
913            .expect("decode")
914            .into_data()
915            .expect("unknown frame type");
916        assert_eq!("Z", buf);
917
918        let err = decoder
919            .decode_fut(&mut mock_buf)
920            .await
921            .expect_err("decode 2");
922        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
923    }
924
925    #[tokio::test]
926    async fn test_read_chunked_extensions_over_limit() {
927        // construct a chunked body where each individual chunked extension
928        // is totally fine, but combined is over the limit.
929        let per_chunk = super::CHUNKED_EXTENSIONS_LIMIT * 2 / 3;
930        let mut scratch = vec![];
931        for _ in 0..2 {
932            scratch.extend(b"1;");
933            scratch.extend(b"x".repeat(per_chunk as usize));
934            scratch.extend(b"\r\nA\r\n");
935        }
936        scratch.extend(b"0\r\n\r\n");
937        let mut mock_buf = Bytes::from(scratch);
938
939        let mut decoder = Decoder::chunked(None, None);
940        let buf1 = decoder
941            .decode_fut(&mut mock_buf)
942            .await
943            .expect("decode1")
944            .into_data()
945            .expect("unknown frame type");
946        assert_eq!(&buf1[..], b"A");
947
948        let err = decoder
949            .decode_fut(&mut mock_buf)
950            .await
951            .expect_err("decode2");
952        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
953        assert_eq!(err.to_string(), "chunk extensions over limit");
954    }
955
956    #[cfg(not(miri))]
957    #[tokio::test]
958    async fn test_read_chunked_trailer_with_missing_lf() {
959        let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\nbad\r\r\n"[..];
960        let mut decoder = Decoder::chunked(None, None);
961        decoder.decode_fut(&mut mock_buf).await.expect("decode");
962        let e = decoder.decode_fut(&mut mock_buf).await.unwrap_err();
963        assert_eq!(e.kind(), io::ErrorKind::InvalidInput);
964    }
965
966    #[cfg(not(miri))]
967    #[tokio::test]
968    async fn test_read_chunked_after_eof() {
969        let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n\r\n"[..];
970        let mut decoder = Decoder::chunked(None, None);
971
972        // normal read
973        let buf = decoder
974            .decode_fut(&mut mock_buf)
975            .await
976            .unwrap()
977            .into_data()
978            .expect("unknown frame type");
979        assert_eq!(16, buf.len());
980        let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
981        assert_eq!("1234567890abcdef", &result);
982
983        // eof read
984        let buf = decoder
985            .decode_fut(&mut mock_buf)
986            .await
987            .expect("decode")
988            .into_data()
989            .expect("unknown frame type");
990        assert_eq!(0, buf.len());
991
992        // ensure read after eof also returns eof
993        let buf = decoder
994            .decode_fut(&mut mock_buf)
995            .await
996            .expect("decode")
997            .into_data()
998            .expect("unknown frame type");
999        assert_eq!(0, buf.len());
1000    }
1001
1002    // perform an async read using a custom buffer size and causing a blocking
1003    // read at the specified byte
1004    async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String {
1005        let mut outs = Vec::new();
1006
1007        let mut ins = crate::common::io::Compat::new(if block_at == 0 {
1008            tokio_test::io::Builder::new()
1009                .wait(Duration::from_millis(10))
1010                .read(content)
1011                .build()
1012        } else {
1013            tokio_test::io::Builder::new()
1014                .read(&content[..block_at])
1015                .wait(Duration::from_millis(10))
1016                .read(&content[block_at..])
1017                .build()
1018        });
1019
1020        let mut ins = &mut ins as &mut (dyn Read + Unpin);
1021
1022        loop {
1023            let buf = decoder
1024                .decode_fut(&mut ins)
1025                .await
1026                .expect("unexpected decode error")
1027                .into_data()
1028                .expect("unexpected frame type");
1029            if buf.is_empty() {
1030                break; // eof
1031            }
1032            outs.extend(buf.as_ref());
1033        }
1034
1035        String::from_utf8(outs).expect("decode String")
1036    }
1037
1038    // iterate over the different ways that this async read could go.
1039    // tests blocking a read at each byte along the content - The shotgun approach
1040    async fn all_async_cases(content: &str, expected: &str, decoder: Decoder) {
1041        let content_len = content.len();
1042        for block_at in 0..content_len {
1043            let actual = read_async(decoder.clone(), content.as_bytes(), block_at).await;
1044            assert_eq!(expected, &actual) //, "Failed async. Blocking at {}", block_at);
1045        }
1046    }
1047
1048    #[cfg(not(miri))]
1049    #[tokio::test]
1050    async fn test_read_length_async() {
1051        let content = "foobar";
1052        all_async_cases(content, content, Decoder::length(content.len() as u64)).await;
1053    }
1054
1055    #[cfg(not(miri))]
1056    #[tokio::test]
1057    async fn test_read_chunked_async() {
1058        let content = "3\r\nfoo\r\n3\r\nbar\r\n0\r\n\r\n";
1059        let expected = "foobar";
1060        all_async_cases(content, expected, Decoder::chunked(None, None)).await;
1061    }
1062
1063    #[cfg(not(miri))]
1064    #[tokio::test]
1065    async fn test_read_eof_async() {
1066        let content = "foobar";
1067        all_async_cases(content, content, Decoder::eof()).await;
1068    }
1069
1070    #[cfg(all(feature = "nightly", not(miri)))]
1071    #[bench]
1072    fn bench_decode_chunked_1kb(b: &mut test::Bencher) {
1073        let rt = new_runtime();
1074
1075        const LEN: usize = 1024;
1076        let mut vec = Vec::new();
1077        vec.extend(format!("{:x}\r\n", LEN).as_bytes());
1078        vec.extend(&[0; LEN][..]);
1079        vec.extend(b"\r\n");
1080        let content = Bytes::from(vec);
1081
1082        b.bytes = LEN as u64;
1083
1084        b.iter(|| {
1085            let mut decoder = Decoder::chunked(None, None);
1086            rt.block_on(async {
1087                let mut raw = content.clone();
1088                let chunk = decoder
1089                    .decode_fut(&mut raw)
1090                    .await
1091                    .unwrap()
1092                    .into_data()
1093                    .unwrap();
1094                assert_eq!(chunk.len(), LEN);
1095            });
1096        });
1097    }
1098
1099    #[cfg(all(feature = "nightly", not(miri)))]
1100    #[bench]
1101    fn bench_decode_length_1kb(b: &mut test::Bencher) {
1102        let rt = new_runtime();
1103
1104        const LEN: usize = 1024;
1105        let content = Bytes::from(&[0; LEN][..]);
1106        b.bytes = LEN as u64;
1107
1108        b.iter(|| {
1109            let mut decoder = Decoder::length(LEN as u64);
1110            rt.block_on(async {
1111                let mut raw = content.clone();
1112                let chunk = decoder
1113                    .decode_fut(&mut raw)
1114                    .await
1115                    .unwrap()
1116                    .into_data()
1117                    .unwrap();
1118                assert_eq!(chunk.len(), LEN);
1119            });
1120        });
1121    }
1122
1123    #[cfg(feature = "nightly")]
1124    fn new_runtime() -> tokio::runtime::Runtime {
1125        tokio::runtime::Builder::new_current_thread()
1126            .enable_all()
1127            .build()
1128            .expect("rt build")
1129    }
1130
1131    #[test]
1132    fn test_decode_trailers() {
1133        let mut buf = BytesMut::new();
1134        buf.extend_from_slice(
1135            b"Expires: Wed, 21 Oct 2015 07:28:00 GMT\r\nX-Stream-Error: failed to decode\r\n\r\n",
1136        );
1137        let headers = decode_trailers(&mut buf, 2).expect("decode_trailers");
1138        assert_eq!(headers.len(), 2);
1139        assert_eq!(
1140            headers.get("Expires").unwrap(),
1141            "Wed, 21 Oct 2015 07:28:00 GMT"
1142        );
1143        assert_eq!(headers.get("X-Stream-Error").unwrap(), "failed to decode");
1144    }
1145
1146    #[tokio::test]
1147    async fn test_trailer_max_headers_enforced() {
1148        let h1_max_headers = 10;
1149        let mut scratch = vec![];
1150        scratch.extend(b"10\r\n1234567890abcdef\r\n0\r\n");
1151        for i in 0..h1_max_headers {
1152            scratch.extend(format!("trailer{}: {}\r\n", i, i).as_bytes());
1153        }
1154        scratch.extend(b"\r\n");
1155        let mut mock_buf = Bytes::from(scratch);
1156
1157        let mut decoder = Decoder::chunked(Some(h1_max_headers), None);
1158
1159        // ready chunked body
1160        let buf = decoder
1161            .decode_fut(&mut mock_buf)
1162            .await
1163            .unwrap()
1164            .into_data()
1165            .expect("unknown frame type");
1166        assert_eq!(16, buf.len());
1167
1168        // eof read
1169        let err = decoder
1170            .decode_fut(&mut mock_buf)
1171            .await
1172            .expect_err("trailer fields over limit");
1173        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
1174    }
1175
1176    #[tokio::test]
1177    async fn test_trailer_max_header_size_huge_trailer() {
1178        let max_header_size = 1024;
1179        let mut scratch = vec![];
1180        scratch.extend(b"10\r\n1234567890abcdef\r\n0\r\n");
1181        scratch.extend(format!("huge_trailer: {}\r\n", "x".repeat(max_header_size)).as_bytes());
1182        scratch.extend(b"\r\n");
1183        let mut mock_buf = Bytes::from(scratch);
1184
1185        let mut decoder = Decoder::chunked(None, Some(max_header_size));
1186
1187        // ready chunked body
1188        let buf = decoder
1189            .decode_fut(&mut mock_buf)
1190            .await
1191            .unwrap()
1192            .into_data()
1193            .expect("unknown frame type");
1194        assert_eq!(16, buf.len());
1195
1196        // eof read
1197        let err = decoder
1198            .decode_fut(&mut mock_buf)
1199            .await
1200            .expect_err("trailers over limit");
1201        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
1202    }
1203
1204    #[tokio::test]
1205    async fn test_trailer_max_header_size_many_small_trailers() {
1206        let max_headers = 10;
1207        let header_size = 64;
1208        let mut scratch = vec![];
1209        scratch.extend(b"10\r\n1234567890abcdef\r\n0\r\n");
1210
1211        for i in 0..max_headers {
1212            scratch.extend(format!("trailer{}: {}\r\n", i, "x".repeat(header_size)).as_bytes());
1213        }
1214
1215        scratch.extend(b"\r\n");
1216        let mut mock_buf = Bytes::from(scratch);
1217
1218        let mut decoder = Decoder::chunked(None, Some(max_headers * header_size));
1219
1220        // ready chunked body
1221        let buf = decoder
1222            .decode_fut(&mut mock_buf)
1223            .await
1224            .unwrap()
1225            .into_data()
1226            .expect("unknown frame type");
1227        assert_eq!(16, buf.len());
1228
1229        // eof read
1230        let err = decoder
1231            .decode_fut(&mut mock_buf)
1232            .await
1233            .expect_err("trailers over limit");
1234        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
1235    }
1236}