1use std::error::Error as StdError;
2use std::fmt;
3use std::io;
4use std::task::{Context, Poll};
5
6use bytes::{BufMut, Bytes, BytesMut};
7use futures_util::ready;
8use http::{HeaderMap, HeaderName, HeaderValue};
9use http_body::Frame;
10
11use super::io::MemRead;
12use super::role::DEFAULT_MAX_HEADERS;
13use super::DecodedLength;
14
15use self::Kind::{Chunked, Eof, Length};
16
/// Upper bound on the number of chunk-extension bytes accepted.
///
/// The counter this is compared against (`extensions_cnt`) is created once per
/// chunked decoder and never reset between chunks, so the limit applies to the
/// whole body rather than to each chunk individually.
const CHUNKED_EXTENSIONS_LIMIT: u64 = 1024 * 16;

/// Default cap, in bytes, on the buffered trailer section, used when no
/// `h1_max_header_size` is configured for the chunked decoder.
const TRAILER_LIMIT: usize = 1024 * 16;
26
/// Decodes an HTTP/1 message body according to how its length is framed:
/// a fixed byte count, the chunked transfer coding, or "read until EOF".
///
/// Built with `Decoder::length`, `Decoder::chunked`, `Decoder::eof`, or
/// `Decoder::new` (which selects one of those from a `DecodedLength`).
#[derive(Clone, PartialEq)]
pub(crate) struct Decoder {
    // The active framing mode and its per-mode parsing state.
    kind: Kind,
}
35
/// Framing mode of a `Decoder` together with the state that mode needs.
#[derive(Debug, Clone, PartialEq)]
enum Kind {
    /// Body of a known, fixed length; holds the number of bytes still
    /// expected. A value of 0 means the body is complete.
    Length(u64),
    /// Body using the chunked transfer coding.
    Chunked {
        /// Current position in the chunked-body state machine.
        state: ChunkedState,
        /// The chunk size while its hex digits are parsed, then the number of
        /// bytes still to read from the current chunk.
        chunk_len: u64,
        /// Running count of chunk-extension bytes (never reset between chunks).
        extensions_cnt: u64,
        /// Raw trailer-section bytes; created when the first trailer byte arrives.
        trailers_buf: Option<BytesMut>,
        /// Number of complete trailer lines seen so far.
        trailers_cnt: usize,
        /// Trailer-line limit; `None` falls back to `DEFAULT_MAX_HEADERS`.
        h1_max_headers: Option<usize>,
        /// Trailer-section size limit in bytes; `None` falls back to `TRAILER_LIMIT`.
        h1_max_header_size: Option<usize>,
    },
    /// Body delimited by the end of the transport; the flag becomes `true`
    /// once a read returns no data.
    Eof(bool),
}
68
/// Position of the chunked-body parser within the wire format.
///
/// `ChunkedState::step` consumes input appropriate to the current state and
/// returns the state to move to next.
#[derive(Debug, PartialEq, Clone, Copy)]
enum ChunkedState {
    /// Expecting the first hex digit of a chunk-size line.
    Start,
    /// Reading further hex digits of the chunk size.
    Size,
    /// Skipping spaces/tabs that follow the chunk size.
    SizeLws,
    /// Skipping chunk-extension bytes after `;` until a CR.
    Extension,
    /// Expecting the LF that ends the chunk-size line.
    SizeLf,
    /// Reading chunk data.
    Body,
    /// Expecting the CR that follows chunk data.
    BodyCr,
    /// Expecting the LF that follows chunk data.
    BodyLf,
    /// Buffering the bytes of a trailer line until its CR.
    Trailer,
    /// Expecting the LF that ends a trailer line.
    TrailerLf,
    /// After the last-chunk line or a trailer line: CR begins the final CRLF,
    /// any other byte begins a trailer line.
    EndCr,
    /// Expecting the final LF of the body.
    EndLf,
    /// The body has been fully decoded.
    End,
}
85
86impl Decoder {
87 pub(crate) fn length(x: u64) -> Decoder {
90 Decoder {
91 kind: Kind::Length(x),
92 }
93 }
94
95 pub(crate) fn chunked(
96 h1_max_headers: Option<usize>,
97 h1_max_header_size: Option<usize>,
98 ) -> Decoder {
99 Decoder {
100 kind: Kind::Chunked {
101 state: ChunkedState::new(),
102 chunk_len: 0,
103 extensions_cnt: 0,
104 trailers_buf: None,
105 trailers_cnt: 0,
106 h1_max_headers,
107 h1_max_header_size,
108 },
109 }
110 }
111
112 pub(crate) fn eof() -> Decoder {
113 Decoder {
114 kind: Kind::Eof(false),
115 }
116 }
117
118 pub(super) fn new(
119 len: DecodedLength,
120 h1_max_headers: Option<usize>,
121 h1_max_header_size: Option<usize>,
122 ) -> Self {
123 match len {
124 DecodedLength::CHUNKED => Decoder::chunked(h1_max_headers, h1_max_header_size),
125 DecodedLength::CLOSE_DELIMITED => Decoder::eof(),
126 length => Decoder::length(length.danger_len()),
127 }
128 }
129
130 pub(crate) fn is_eof(&self) -> bool {
133 matches!(
134 self.kind,
135 Length(0)
136 | Chunked {
137 state: ChunkedState::End,
138 ..
139 }
140 | Eof(true)
141 )
142 }
143
144 pub(crate) fn decode<R: MemRead>(
145 &mut self,
146 cx: &mut Context<'_>,
147 body: &mut R,
148 ) -> Poll<Result<Frame<Bytes>, io::Error>> {
149 trace!("decode; state={:?}", self.kind);
150 match self.kind {
151 Length(ref mut remaining) => {
152 if *remaining == 0 {
153 Poll::Ready(Ok(Frame::data(Bytes::new())))
154 } else {
155 let to_read = *remaining as usize;
156 let buf = ready!(body.read_mem(cx, to_read))?;
157 let num = buf.as_ref().len() as u64;
158 if num > *remaining {
159 *remaining = 0;
160 } else if num == 0 {
161 return Poll::Ready(Err(io::Error::new(
162 io::ErrorKind::UnexpectedEof,
163 IncompleteBody,
164 )));
165 } else {
166 *remaining -= num;
167 }
168 Poll::Ready(Ok(Frame::data(buf)))
169 }
170 }
171 Chunked {
172 ref mut state,
173 ref mut chunk_len,
174 ref mut extensions_cnt,
175 ref mut trailers_buf,
176 ref mut trailers_cnt,
177 ref h1_max_headers,
178 ref h1_max_header_size,
179 } => {
180 let h1_max_headers = h1_max_headers.unwrap_or(DEFAULT_MAX_HEADERS);
181 let h1_max_header_size = h1_max_header_size.unwrap_or(TRAILER_LIMIT);
182 loop {
183 let mut buf = None;
184 *state = ready!(state.step(
186 cx,
187 body,
188 chunk_len,
189 extensions_cnt,
190 &mut buf,
191 trailers_buf,
192 trailers_cnt,
193 h1_max_headers,
194 h1_max_header_size
195 ))?;
196 if *state == ChunkedState::End {
197 trace!("end of chunked");
198
199 if trailers_buf.is_some() {
200 trace!("found possible trailers");
201
202 if *trailers_cnt >= h1_max_headers {
204 return Poll::Ready(Err(io::Error::new(
205 io::ErrorKind::InvalidData,
206 "chunk trailers count overflow",
207 )));
208 }
209 match decode_trailers(
210 &mut trailers_buf.take().expect("Trailer is None"),
211 *trailers_cnt,
212 ) {
213 Ok(headers) => {
214 return Poll::Ready(Ok(Frame::trailers(headers)));
215 }
216 Err(e) => {
217 return Poll::Ready(Err(e));
218 }
219 }
220 }
221
222 return Poll::Ready(Ok(Frame::data(Bytes::new())));
223 }
224 if let Some(buf) = buf {
225 return Poll::Ready(Ok(Frame::data(buf)));
226 }
227 }
228 }
229 Eof(ref mut is_eof) => {
230 if *is_eof {
231 Poll::Ready(Ok(Frame::data(Bytes::new())))
232 } else {
233 body.read_mem(cx, 8192).map_ok(|slice| {
237 *is_eof = slice.is_empty();
238 Frame::data(slice)
239 })
240 }
241 }
242 }
243 }
244
245 #[cfg(test)]
246 async fn decode_fut<R: MemRead>(&mut self, body: &mut R) -> Result<Frame<Bytes>, io::Error> {
247 futures_util::future::poll_fn(move |cx| self.decode(cx, body)).await
248 }
249}
250
251impl fmt::Debug for Decoder {
252 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253 fmt::Debug::fmt(&self.kind, f)
254 }
255}
256
// Reads exactly one byte from `$rdr`. From the enclosing function it returns
// `Poll::Pending` if the read is not ready, propagates a read error via `?`,
// and returns an `UnexpectedEof` error if the reader yields no data.
macro_rules! byte {
    ($rdr:ident, $cx:expr) => {{
        let one = ready!($rdr.read_mem($cx, 1))?;
        let first = one.first().copied();
        match first {
            Some(b) => b,
            None => {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "unexpected EOF during chunk size line",
                )))
            }
        }
    }};
}
268
// Unwraps the `Option` produced by a checked arithmetic step. On `None`
// (overflow) it returns an `InvalidData` error from the enclosing function.
macro_rules! or_overflow {
    ($e:expr) => {
        if let Some(checked) = $e {
            checked
        } else {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid chunk size: overflow",
            )));
        }
    };
}
280
// Appends `$byte` to the trailer buffer `$trailers_buf`, then returns an
// `InvalidData` error from the enclosing function if the buffer length has
// reached `$limit`.
//
// Note: `$trailers_buf` appears twice in the expansion, so the expression is
// evaluated twice; callers pass side-effect-free expressions.
macro_rules! put_u8 {
    ($trailers_buf:expr, $byte:expr, $limit:expr) => {
        $trailers_buf.put_u8($byte);

        if $trailers_buf.len() >= $limit {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "chunk trailers bytes over limit",
            )));
        }
    };
}
293
294impl ChunkedState {
295 fn new() -> ChunkedState {
296 ChunkedState::Start
297 }
298 fn step<R: MemRead>(
299 &self,
300 cx: &mut Context<'_>,
301 body: &mut R,
302 size: &mut u64,
303 extensions_cnt: &mut u64,
304 buf: &mut Option<Bytes>,
305 trailers_buf: &mut Option<BytesMut>,
306 trailers_cnt: &mut usize,
307 h1_max_headers: usize,
308 h1_max_header_size: usize,
309 ) -> Poll<Result<ChunkedState, io::Error>> {
310 use self::ChunkedState::*;
311 match *self {
312 Start => ChunkedState::read_start(cx, body, size),
313 Size => ChunkedState::read_size(cx, body, size),
314 SizeLws => ChunkedState::read_size_lws(cx, body),
315 Extension => ChunkedState::read_extension(cx, body, extensions_cnt),
316 SizeLf => ChunkedState::read_size_lf(cx, body, *size),
317 Body => ChunkedState::read_body(cx, body, size, buf),
318 BodyCr => ChunkedState::read_body_cr(cx, body),
319 BodyLf => ChunkedState::read_body_lf(cx, body),
320 Trailer => ChunkedState::read_trailer(cx, body, trailers_buf, h1_max_header_size),
321 TrailerLf => ChunkedState::read_trailer_lf(
322 cx,
323 body,
324 trailers_buf,
325 trailers_cnt,
326 h1_max_headers,
327 h1_max_header_size,
328 ),
329 EndCr => ChunkedState::read_end_cr(cx, body, trailers_buf, h1_max_header_size),
330 EndLf => ChunkedState::read_end_lf(cx, body, trailers_buf, h1_max_header_size),
331 End => Poll::Ready(Ok(ChunkedState::End)),
332 }
333 }
334
335 fn read_start<R: MemRead>(
336 cx: &mut Context<'_>,
337 rdr: &mut R,
338 size: &mut u64,
339 ) -> Poll<Result<ChunkedState, io::Error>> {
340 trace!("Read chunk start");
341
342 let radix = 16;
343 match byte!(rdr, cx) {
344 b @ b'0'..=b'9' => {
345 *size = or_overflow!(size.checked_mul(radix));
346 *size = or_overflow!(size.checked_add((b - b'0') as u64));
347 }
348 b @ b'a'..=b'f' => {
349 *size = or_overflow!(size.checked_mul(radix));
350 *size = or_overflow!(size.checked_add((b + 10 - b'a') as u64));
351 }
352 b @ b'A'..=b'F' => {
353 *size = or_overflow!(size.checked_mul(radix));
354 *size = or_overflow!(size.checked_add((b + 10 - b'A') as u64));
355 }
356 _ => {
357 return Poll::Ready(Err(io::Error::new(
358 io::ErrorKind::InvalidInput,
359 "Invalid chunk size line: missing size digit",
360 )));
361 }
362 }
363
364 Poll::Ready(Ok(ChunkedState::Size))
365 }
366
367 fn read_size<R: MemRead>(
368 cx: &mut Context<'_>,
369 rdr: &mut R,
370 size: &mut u64,
371 ) -> Poll<Result<ChunkedState, io::Error>> {
372 trace!("Read chunk hex size");
373
374 let radix = 16;
375 match byte!(rdr, cx) {
376 b @ b'0'..=b'9' => {
377 *size = or_overflow!(size.checked_mul(radix));
378 *size = or_overflow!(size.checked_add((b - b'0') as u64));
379 }
380 b @ b'a'..=b'f' => {
381 *size = or_overflow!(size.checked_mul(radix));
382 *size = or_overflow!(size.checked_add((b + 10 - b'a') as u64));
383 }
384 b @ b'A'..=b'F' => {
385 *size = or_overflow!(size.checked_mul(radix));
386 *size = or_overflow!(size.checked_add((b + 10 - b'A') as u64));
387 }
388 b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)),
389 b';' => return Poll::Ready(Ok(ChunkedState::Extension)),
390 b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)),
391 _ => {
392 return Poll::Ready(Err(io::Error::new(
393 io::ErrorKind::InvalidInput,
394 "Invalid chunk size line: Invalid Size",
395 )));
396 }
397 }
398 Poll::Ready(Ok(ChunkedState::Size))
399 }
400 fn read_size_lws<R: MemRead>(
401 cx: &mut Context<'_>,
402 rdr: &mut R,
403 ) -> Poll<Result<ChunkedState, io::Error>> {
404 trace!("read_size_lws");
405 match byte!(rdr, cx) {
406 b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)),
408 b';' => Poll::Ready(Ok(ChunkedState::Extension)),
409 b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
410 _ => Poll::Ready(Err(io::Error::new(
411 io::ErrorKind::InvalidInput,
412 "Invalid chunk size linear white space",
413 ))),
414 }
415 }
416 fn read_extension<R: MemRead>(
417 cx: &mut Context<'_>,
418 rdr: &mut R,
419 extensions_cnt: &mut u64,
420 ) -> Poll<Result<ChunkedState, io::Error>> {
421 trace!("read_extension");
422 match byte!(rdr, cx) {
429 b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
430 b'\n' => Poll::Ready(Err(io::Error::new(
431 io::ErrorKind::InvalidData,
432 "invalid chunk extension contains newline",
433 ))),
434 _ => {
435 *extensions_cnt += 1;
436 if *extensions_cnt >= CHUNKED_EXTENSIONS_LIMIT {
437 Poll::Ready(Err(io::Error::new(
438 io::ErrorKind::InvalidData,
439 "chunk extensions over limit",
440 )))
441 } else {
442 Poll::Ready(Ok(ChunkedState::Extension))
443 }
444 } }
446 }
447 fn read_size_lf<R: MemRead>(
448 cx: &mut Context<'_>,
449 rdr: &mut R,
450 size: u64,
451 ) -> Poll<Result<ChunkedState, io::Error>> {
452 trace!("Chunk size is {:?}", size);
453 match byte!(rdr, cx) {
454 b'\n' => {
455 if size == 0 {
456 Poll::Ready(Ok(ChunkedState::EndCr))
457 } else {
458 debug!("incoming chunked header: {0:#X} ({0} bytes)", size);
459 Poll::Ready(Ok(ChunkedState::Body))
460 }
461 }
462 _ => Poll::Ready(Err(io::Error::new(
463 io::ErrorKind::InvalidInput,
464 "Invalid chunk size LF",
465 ))),
466 }
467 }
468
469 fn read_body<R: MemRead>(
470 cx: &mut Context<'_>,
471 rdr: &mut R,
472 rem: &mut u64,
473 buf: &mut Option<Bytes>,
474 ) -> Poll<Result<ChunkedState, io::Error>> {
475 trace!("Chunked read, remaining={:?}", rem);
476
477 let rem_cap = match *rem {
479 r if r > usize::MAX as u64 => usize::MAX,
480 r => r as usize,
481 };
482
483 let to_read = rem_cap;
484 let slice = ready!(rdr.read_mem(cx, to_read))?;
485 let count = slice.len();
486
487 if count == 0 {
488 *rem = 0;
489 return Poll::Ready(Err(io::Error::new(
490 io::ErrorKind::UnexpectedEof,
491 IncompleteBody,
492 )));
493 }
494 *buf = Some(slice);
495 *rem -= count as u64;
496
497 if *rem > 0 {
498 Poll::Ready(Ok(ChunkedState::Body))
499 } else {
500 Poll::Ready(Ok(ChunkedState::BodyCr))
501 }
502 }
503 fn read_body_cr<R: MemRead>(
504 cx: &mut Context<'_>,
505 rdr: &mut R,
506 ) -> Poll<Result<ChunkedState, io::Error>> {
507 match byte!(rdr, cx) {
508 b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)),
509 _ => Poll::Ready(Err(io::Error::new(
510 io::ErrorKind::InvalidInput,
511 "Invalid chunk body CR",
512 ))),
513 }
514 }
515 fn read_body_lf<R: MemRead>(
516 cx: &mut Context<'_>,
517 rdr: &mut R,
518 ) -> Poll<Result<ChunkedState, io::Error>> {
519 match byte!(rdr, cx) {
520 b'\n' => Poll::Ready(Ok(ChunkedState::Start)),
521 _ => Poll::Ready(Err(io::Error::new(
522 io::ErrorKind::InvalidInput,
523 "Invalid chunk body LF",
524 ))),
525 }
526 }
527
528 fn read_trailer<R: MemRead>(
529 cx: &mut Context<'_>,
530 rdr: &mut R,
531 trailers_buf: &mut Option<BytesMut>,
532 h1_max_header_size: usize,
533 ) -> Poll<Result<ChunkedState, io::Error>> {
534 trace!("read_trailer");
535 let byte = byte!(rdr, cx);
536
537 put_u8!(
538 trailers_buf.as_mut().expect("trailers_buf is None"),
539 byte,
540 h1_max_header_size
541 );
542
543 match byte {
544 b'\r' => Poll::Ready(Ok(ChunkedState::TrailerLf)),
545 _ => Poll::Ready(Ok(ChunkedState::Trailer)),
546 }
547 }
548
549 fn read_trailer_lf<R: MemRead>(
550 cx: &mut Context<'_>,
551 rdr: &mut R,
552 trailers_buf: &mut Option<BytesMut>,
553 trailers_cnt: &mut usize,
554 h1_max_headers: usize,
555 h1_max_header_size: usize,
556 ) -> Poll<Result<ChunkedState, io::Error>> {
557 let byte = byte!(rdr, cx);
558 match byte {
559 b'\n' => {
560 if *trailers_cnt >= h1_max_headers {
561 return Poll::Ready(Err(io::Error::new(
562 io::ErrorKind::InvalidData,
563 "chunk trailers count overflow",
564 )));
565 }
566 *trailers_cnt += 1;
567
568 put_u8!(
569 trailers_buf.as_mut().expect("trailers_buf is None"),
570 byte,
571 h1_max_header_size
572 );
573
574 Poll::Ready(Ok(ChunkedState::EndCr))
575 }
576 _ => Poll::Ready(Err(io::Error::new(
577 io::ErrorKind::InvalidInput,
578 "Invalid trailer end LF",
579 ))),
580 }
581 }
582
583 fn read_end_cr<R: MemRead>(
584 cx: &mut Context<'_>,
585 rdr: &mut R,
586 trailers_buf: &mut Option<BytesMut>,
587 h1_max_header_size: usize,
588 ) -> Poll<Result<ChunkedState, io::Error>> {
589 let byte = byte!(rdr, cx);
590 match byte {
591 b'\r' => {
592 if let Some(trailers_buf) = trailers_buf {
593 put_u8!(trailers_buf, byte, h1_max_header_size);
594 }
595 Poll::Ready(Ok(ChunkedState::EndLf))
596 }
597 byte => {
598 match trailers_buf {
599 None => {
600 let mut buf = BytesMut::with_capacity(64);
602 buf.put_u8(byte);
603 *trailers_buf = Some(buf);
604 }
605 Some(ref mut trailers_buf) => {
606 put_u8!(trailers_buf, byte, h1_max_header_size);
607 }
608 }
609
610 Poll::Ready(Ok(ChunkedState::Trailer))
611 }
612 }
613 }
614 fn read_end_lf<R: MemRead>(
615 cx: &mut Context<'_>,
616 rdr: &mut R,
617 trailers_buf: &mut Option<BytesMut>,
618 h1_max_header_size: usize,
619 ) -> Poll<Result<ChunkedState, io::Error>> {
620 let byte = byte!(rdr, cx);
621 match byte {
622 b'\n' => {
623 if let Some(trailers_buf) = trailers_buf {
624 put_u8!(trailers_buf, byte, h1_max_header_size);
625 }
626 Poll::Ready(Ok(ChunkedState::End))
627 }
628 _ => Poll::Ready(Err(io::Error::new(
629 io::ErrorKind::InvalidInput,
630 "Invalid chunk end LF",
631 ))),
632 }
633 }
634}
635
636fn decode_trailers(buf: &mut BytesMut, count: usize) -> Result<HeaderMap, io::Error> {
638 let mut trailers = HeaderMap::new();
639 let mut headers = vec![httparse::EMPTY_HEADER; count];
640 let res = httparse::parse_headers(buf, &mut headers);
641 match res {
642 Ok(httparse::Status::Complete((_, headers))) => {
643 for header in headers.iter() {
644 use std::convert::TryFrom;
645 let name = match HeaderName::try_from(header.name) {
646 Ok(name) => name,
647 Err(_) => {
648 return Err(io::Error::new(
649 io::ErrorKind::InvalidInput,
650 format!("Invalid header name: {:?}", &header),
651 ));
652 }
653 };
654
655 let value = match HeaderValue::from_bytes(header.value) {
656 Ok(value) => value,
657 Err(_) => {
658 return Err(io::Error::new(
659 io::ErrorKind::InvalidInput,
660 format!("Invalid header value: {:?}", &header),
661 ));
662 }
663 };
664
665 trailers.insert(name, value);
666 }
667
668 Ok(trailers)
669 }
670 Ok(httparse::Status::Partial) => Err(io::Error::new(
671 io::ErrorKind::InvalidInput,
672 "Partial header",
673 )),
674 Err(e) => Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
675 }
676}
677
/// Error payload used (as an `UnexpectedEof` `io::Error`) when a read returns
/// no data before the expected body length has been reached.
#[derive(Debug)]
struct IncompleteBody;

impl fmt::Display for IncompleteBody {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("end of file before message length reached")
    }
}

// Lets the type be boxed into an `io::Error` via `io::Error::new`.
impl StdError for IncompleteBody {}
688
// Unit tests for the body decoders. They drive `Decoder` / `ChunkedState`
// against in-memory readers (`&[u8]`, `Bytes`) and, for the async cases, a
// `tokio_test` mock transport that pauses partway through the input.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rt::{Read, ReadBuf};
    use std::pin::Pin;
    use std::time::Duration;

    // A byte slice as an always-ready reader: hands out up to `len` bytes and
    // advances itself, returning an empty `Bytes` once exhausted.
    impl MemRead for &[u8] {
        fn read_mem(&mut self, _: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
            let n = std::cmp::min(len, self.len());
            if n > 0 {
                let (a, b) = self.split_at(n);
                let buf = Bytes::copy_from_slice(a);
                *self = b;
                Poll::Ready(Ok(buf))
            } else {
                Poll::Ready(Ok(Bytes::new()))
            }
        }
    }

    // Adapts a runtime `Read` (e.g. the tokio_test mock) to `MemRead` by
    // reading into a freshly allocated `len`-byte buffer on every call.
    impl MemRead for &mut (dyn Read + Unpin) {
        fn read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
            let mut v = vec![0; len];
            let mut buf = ReadBuf::new(&mut v);
            ready!(Pin::new(self).poll_read(cx, buf.unfilled())?);
            Poll::Ready(Ok(Bytes::copy_from_slice(buf.filled())))
        }
    }

    // `Bytes` as a reader: splits up to `len` bytes off the front.
    impl MemRead for Bytes {
        fn read_mem(&mut self, _: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
            let n = std::cmp::min(len, self.len());
            let ret = self.split_to(n);
            Poll::Ready(Ok(ret))
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_chunk_size() {
        use std::io::ErrorKind::{InvalidData, InvalidInput, UnexpectedEof};

        // Steps the state machine over `s` until the size line is consumed
        // (next state `Body`, or `EndCr` for a zero-size chunk) and returns
        // the parsed chunk size.
        async fn read(s: &str) -> u64 {
            let mut state = ChunkedState::new();
            let rdr = &mut s.as_bytes();
            let mut size = 0;
            let mut ext_cnt = 0;
            let mut trailers_cnt = 0;
            loop {
                let result = futures_util::future::poll_fn(|cx| {
                    state.step(
                        cx,
                        rdr,
                        &mut size,
                        &mut ext_cnt,
                        &mut None,
                        &mut None,
                        &mut trailers_cnt,
                        DEFAULT_MAX_HEADERS,
                        TRAILER_LIMIT,
                    )
                })
                .await;
                let desc = format!("read_size failed for {:?}", s);
                state = result.expect(&desc);
                if state == ChunkedState::Body || state == ChunkedState::EndCr {
                    break;
                }
            }
            size
        }

        // Steps the state machine over `s` and asserts that it fails with
        // `expected_err` before reaching `Body` or `End`.
        async fn read_err(s: &str, expected_err: io::ErrorKind) {
            let mut state = ChunkedState::new();
            let rdr = &mut s.as_bytes();
            let mut size = 0;
            let mut ext_cnt = 0;
            let mut trailers_cnt = 0;
            loop {
                let result = futures_util::future::poll_fn(|cx| {
                    state.step(
                        cx,
                        rdr,
                        &mut size,
                        &mut ext_cnt,
                        &mut None,
                        &mut None,
                        &mut trailers_cnt,
                        DEFAULT_MAX_HEADERS,
                        TRAILER_LIMIT,
                    )
                })
                .await;
                state = match result {
                    Ok(s) => s,
                    Err(e) => {
                        assert!(
                            expected_err == e.kind(),
                            "Reading {:?}, expected {:?}, but got {:?}",
                            s,
                            expected_err,
                            e.kind()
                        );
                        return;
                    }
                };
                if state == ChunkedState::Body || state == ChunkedState::End {
                    panic!("Was Ok. Expected Err for {:?}", s);
                }
            }
        }

        // Valid sizes, in either hex case, with optional trailing whitespace.
        assert_eq!(1, read("1\r\n").await);
        assert_eq!(1, read("01\r\n").await);
        assert_eq!(0, read("0\r\n").await);
        assert_eq!(0, read("00\r\n").await);
        assert_eq!(10, read("A\r\n").await);
        assert_eq!(10, read("a\r\n").await);
        assert_eq!(255, read("Ff\r\n").await);
        assert_eq!(255, read("Ff \r\n").await);
        // CR not followed by LF, or input ending mid-line.
        read_err("F\rF", InvalidInput).await;
        read_err("F", UnexpectedEof).await;
        // Missing size digit.
        read_err("\r\n\r\n", InvalidInput).await;
        read_err("\r\n", InvalidInput).await;
        // Non-hex characters.
        read_err("X\r\n", InvalidInput).await;
        read_err("1X\r\n", InvalidInput).await;
        read_err("-\r\n", InvalidInput).await;
        read_err("-1\r\n", InvalidInput).await;
        // Extensions are skipped and do not affect the size.
        assert_eq!(1, read("1;extension\r\n").await);
        assert_eq!(10, read("a;ext name=value\r\n").await);
        assert_eq!(1, read("1;extension;extension2\r\n").await);
        assert_eq!(1, read("1;;; ;\r\n").await);
        assert_eq!(2, read("2; extension...\r\n").await);
        assert_eq!(3, read("3 ; extension=123\r\n").await);
        assert_eq!(3, read("3 ;\r\n").await);
        assert_eq!(3, read("3 ; \r\n").await);
        // Text after whitespace without `;`, unterminated extensions, and bare
        // LF inside an extension are rejected.
        read_err("1 invalid extension\r\n", InvalidInput).await;
        read_err("1 A\r\n", InvalidInput).await;
        read_err("1;no CRLF", UnexpectedEof).await;
        read_err("1;reject\nnewlines\r\n", InvalidData).await;
        // Size overflowing u64.
        read_err("f0000000000000003\r\n", InvalidData).await;
    }

    // A fixed-length body that ends early yields what was read, then errors.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_sized_early_eof() {
        let mut bytes = &b"foo bar"[..];
        let mut decoder = Decoder::length(10);
        assert_eq!(
            decoder
                .decode_fut(&mut bytes)
                .await
                .unwrap()
                .data_ref()
                .unwrap()
                .len(),
            7
        );
        let e = decoder.decode_fut(&mut bytes).await.unwrap_err();
        assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof);
    }

    // A chunk that is shorter than its declared size yields its bytes, then errors.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_chunked_early_eof() {
        let mut bytes = &b"\
            9\r\n\
            foo bar\
        "[..];
        let mut decoder = Decoder::chunked(None, None);
        assert_eq!(
            decoder
                .decode_fut(&mut bytes)
                .await
                .unwrap()
                .data_ref()
                .unwrap()
                .len(),
            7
        );
        let e = decoder.decode_fut(&mut bytes).await.unwrap_err();
        assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_chunked_single_read() {
        let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..];
        let buf = Decoder::chunked(None, None)
            .decode_fut(&mut mock_buf)
            .await
            .expect("decode")
            .into_data()
            .expect("unknown frame type");
        assert_eq!(16, buf.len());
        let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
        assert_eq!("1234567890abcdef", &result);
    }

    // After the first chunk the next size line is empty (no `0`), which is
    // rejected as a missing size digit.
    #[tokio::test]
    async fn test_read_chunked_with_missing_zero_digit() {
        let mut mock_buf = &b"1\r\nZ\r\n\r\n\r\n"[..];
        let mut decoder = Decoder::chunked(None, None);
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect("decode")
            .into_data()
            .expect("unknown frame type");
        assert_eq!("Z", buf);

        let err = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect_err("decode 2");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    // Each chunk carries 2/3 of the extension limit: the first chunk passes,
    // the second trips the limit because the count is cumulative per body.
    #[tokio::test]
    async fn test_read_chunked_extensions_over_limit() {
        let per_chunk = super::CHUNKED_EXTENSIONS_LIMIT * 2 / 3;
        let mut scratch = vec![];
        for _ in 0..2 {
            scratch.extend(b"1;");
            scratch.extend(b"x".repeat(per_chunk as usize));
            scratch.extend(b"\r\nA\r\n");
        }
        scratch.extend(b"0\r\n\r\n");
        let mut mock_buf = Bytes::from(scratch);

        let mut decoder = Decoder::chunked(None, None);
        let buf1 = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect("decode1")
            .into_data()
            .expect("unknown frame type");
        assert_eq!(&buf1[..], b"A");

        let err = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect_err("decode2");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert_eq!(err.to_string(), "chunk extensions over limit");
    }

    // A trailer line whose CR is followed by another CR instead of LF.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_chunked_trailer_with_missing_lf() {
        let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\nbad\r\r\n"[..];
        let mut decoder = Decoder::chunked(None, None);
        decoder.decode_fut(&mut mock_buf).await.expect("decode");
        let e = decoder.decode_fut(&mut mock_buf).await.unwrap_err();
        assert_eq!(e.kind(), io::ErrorKind::InvalidInput);
    }

    // Once the body is complete, further decodes keep returning empty data.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_chunked_after_eof() {
        let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n\r\n"[..];
        let mut decoder = Decoder::chunked(None, None);

        // normal read
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .unwrap()
            .into_data()
            .expect("unknown frame type");
        assert_eq!(16, buf.len());
        let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
        assert_eq!("1234567890abcdef", &result);

        // eof read
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect("decode")
            .into_data()
            .expect("unknown frame type");
        assert_eq!(0, buf.len());

        // ensure read after eof also returns eof
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect("decode")
            .into_data()
            .expect("unknown frame type");
        assert_eq!(0, buf.len());
    }

    // Decodes `content` through a mock transport that delivers it in one read
    // after a 10ms wait (`block_at == 0`) or split at `block_at` with a 10ms
    // pause between the two reads; returns the concatenated body.
    async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String {
        let mut outs = Vec::new();

        let mut ins = crate::common::io::Compat::new(if block_at == 0 {
            tokio_test::io::Builder::new()
                .wait(Duration::from_millis(10))
                .read(content)
                .build()
        } else {
            tokio_test::io::Builder::new()
                .read(&content[..block_at])
                .wait(Duration::from_millis(10))
                .read(&content[block_at..])
                .build()
        });

        let mut ins = &mut ins as &mut (dyn Read + Unpin);

        loop {
            let buf = decoder
                .decode_fut(&mut ins)
                .await
                .expect("unexpected decode error")
                .into_data()
                .expect("unexpected frame type");
            if buf.is_empty() {
                break; // eof
            }
            outs.extend(buf.as_ref());
        }

        String::from_utf8(outs).expect("decode String")
    }

    // Runs `read_async` for every split point so the decoder is exercised
    // with a pending read at each byte offset.
    async fn all_async_cases(content: &str, expected: &str, decoder: Decoder) {
        let content_len = content.len();
        for block_at in 0..content_len {
            let actual = read_async(decoder.clone(), content.as_bytes(), block_at).await;
            assert_eq!(expected, &actual)
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_length_async() {
        let content = "foobar";
        all_async_cases(content, content, Decoder::length(content.len() as u64)).await;
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_chunked_async() {
        let content = "3\r\nfoo\r\n3\r\nbar\r\n0\r\n\r\n";
        let expected = "foobar";
        all_async_cases(content, expected, Decoder::chunked(None, None)).await;
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_eof_async() {
        let content = "foobar";
        all_async_cases(content, content, Decoder::eof()).await;
    }

    #[cfg(all(feature = "nightly", not(miri)))]
    #[bench]
    fn bench_decode_chunked_1kb(b: &mut test::Bencher) {
        let rt = new_runtime();

        const LEN: usize = 1024;
        let mut vec = Vec::new();
        vec.extend(format!("{:x}\r\n", LEN).as_bytes());
        vec.extend(&[0; LEN][..]);
        vec.extend(b"\r\n");
        let content = Bytes::from(vec);

        b.bytes = LEN as u64;

        b.iter(|| {
            let mut decoder = Decoder::chunked(None, None);
            rt.block_on(async {
                let mut raw = content.clone();
                let chunk = decoder
                    .decode_fut(&mut raw)
                    .await
                    .unwrap()
                    .into_data()
                    .unwrap();
                assert_eq!(chunk.len(), LEN);
            });
        });
    }

    #[cfg(all(feature = "nightly", not(miri)))]
    #[bench]
    fn bench_decode_length_1kb(b: &mut test::Bencher) {
        let rt = new_runtime();

        const LEN: usize = 1024;
        let content = Bytes::from(&[0; LEN][..]);
        b.bytes = LEN as u64;

        b.iter(|| {
            let mut decoder = Decoder::length(LEN as u64);
            rt.block_on(async {
                let mut raw = content.clone();
                let chunk = decoder
                    .decode_fut(&mut raw)
                    .await
                    .unwrap()
                    .into_data()
                    .unwrap();
                assert_eq!(chunk.len(), LEN);
            });
        });
    }

    // Single-threaded runtime used by the benchmarks above.
    #[cfg(feature = "nightly")]
    fn new_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("rt build")
    }

    #[test]
    fn test_decode_trailers() {
        let mut buf = BytesMut::new();
        buf.extend_from_slice(
            b"Expires: Wed, 21 Oct 2015 07:28:00 GMT\r\nX-Stream-Error: failed to decode\r\n\r\n",
        );
        let headers = decode_trailers(&mut buf, 2).expect("decode_trailers");
        assert_eq!(headers.len(), 2);
        assert_eq!(
            headers.get("Expires").unwrap(),
            "Wed, 21 Oct 2015 07:28:00 GMT"
        );
        assert_eq!(headers.get("X-Stream-Error").unwrap(), "failed to decode");
    }

    // Exactly `h1_max_headers` trailer lines are rejected.
    #[tokio::test]
    async fn test_trailer_max_headers_enforced() {
        let h1_max_headers = 10;
        let mut scratch = vec![];
        scratch.extend(b"10\r\n1234567890abcdef\r\n0\r\n");
        for i in 0..h1_max_headers {
            scratch.extend(format!("trailer{}: {}\r\n", i, i).as_bytes());
        }
        scratch.extend(b"\r\n");
        let mut mock_buf = Bytes::from(scratch);

        let mut decoder = Decoder::chunked(Some(h1_max_headers), None);

        // ready chunked body
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .unwrap()
            .into_data()
            .expect("unknown frame type");
        assert_eq!(16, buf.len());

        // eof read
        let err = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect_err("trailer fields over limit");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    // A single trailer larger than `h1_max_header_size` is rejected.
    #[tokio::test]
    async fn test_trailer_max_header_size_huge_trailer() {
        let max_header_size = 1024;
        let mut scratch = vec![];
        scratch.extend(b"10\r\n1234567890abcdef\r\n0\r\n");
        scratch.extend(format!("huge_trailer: {}\r\n", "x".repeat(max_header_size)).as_bytes());
        scratch.extend(b"\r\n");
        let mut mock_buf = Bytes::from(scratch);

        let mut decoder = Decoder::chunked(None, Some(max_header_size));

        // ready chunked body
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .unwrap()
            .into_data()
            .expect("unknown frame type");
        assert_eq!(16, buf.len());

        // eof read
        let err = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect_err("trailers over limit");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    // Many small trailers whose combined size exceeds `h1_max_header_size`
    // are rejected.
    #[tokio::test]
    async fn test_trailer_max_header_size_many_small_trailers() {
        let max_headers = 10;
        let header_size = 64;
        let mut scratch = vec![];
        scratch.extend(b"10\r\n1234567890abcdef\r\n0\r\n");

        for i in 0..max_headers {
            scratch.extend(format!("trailer{}: {}\r\n", i, "x".repeat(header_size)).as_bytes());
        }

        scratch.extend(b"\r\n");
        let mut mock_buf = Bytes::from(scratch);

        let mut decoder = Decoder::chunked(None, Some(max_headers * header_size));

        // ready chunked body
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .unwrap()
            .into_data()
            .expect("unknown frame type");
        assert_eq!(16, buf.len());

        // eof read
        let err = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect_err("trailers over limit");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}