// hyper/proto/h1/encode.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::io::IoSlice;
4
5use bytes::buf::{Chain, Take};
6use bytes::{Buf, Bytes};
7use http::{
8    header::{
9        AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
10        CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
11    },
12    HeaderMap, HeaderName, HeaderValue,
13};
14
15use super::io::WriteBuf;
16use super::role::{write_headers, write_headers_title_case};
17
/// A `'static` byte slice, used for the fixed framing literals this encoder
/// emits (`\r\n`, `0\r\n\r\n`, `0\r\n`, ...).
type StaticBuf = &'static [u8];
19
/// Encoders to handle different Transfer-Encodings.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct Encoder {
    /// How the body is framed (chunked, Content-Length, or close-delimited).
    kind: Kind,
    /// Set via `set_last`; when `true`, `encode_and_end` reports `false`.
    is_last: bool,
}
26
/// A `Buf` produced by an [`Encoder`]: a piece of body data together with the
/// framing bytes (if any) that must be written around it.
#[derive(Debug)]
pub(crate) struct EncodedBuf<B> {
    kind: BufKind<B>,
}
31
/// Error returned by `Encoder::end` when a Content-Length body is ended early;
/// holds the number of body bytes that were still expected.
#[derive(Debug)]
pub(crate) struct NotEof(u64);
34
#[derive(Debug, PartialEq, Clone)]
enum Kind {
    /// An Encoder for when Transfer-Encoding includes `chunked`.
    ///
    /// `Some` holds the values attached by
    /// `Encoder::into_chunked_with_trailing_fields`, each parsed as a
    /// comma-separated list of trailer field names that may be sent.
    Chunked(Option<Vec<HeaderValue>>),
    /// An Encoder for when Content-Length is set.
    ///
    /// Enforces that the body is not longer than the Content-Length header:
    /// `encode` truncates any bytes past the limit. The value is the number of
    /// body bytes still to be written.
    Length(u64),
    /// An Encoder for when neither Content-Length nor Chunked encoding is set.
    ///
    /// This is mostly only used with HTTP/1.0 bodies that have no declared
    /// length. This kind requires the connection to be closed when the body is
    /// finished.
    #[cfg(feature = "server")]
    CloseDelimited,
}
50
/// The concrete byte layout behind an [`EncodedBuf`].
#[derive(Debug)]
enum BufKind<B> {
    /// Body bytes written as-is (Content-Length or close-delimited).
    Exact(B),
    /// Body bytes truncated to the remaining Content-Length.
    Limited(Take<B>),
    /// `<HEX-LEN>\r\n`, the chunk data, then a static tail (`\r\n` from
    /// `encode`; `encode_and_end` builds the same shape with a tail that also
    /// carries the terminating `0\r\n\r\n`).
    Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
    /// The last-chunk `0\r\n\r\n` with no trailer fields.
    ChunkedEnd(StaticBuf),
    /// `0\r\n`, the serialized trailer fields, then the final `\r\n`.
    Trailers(Chain<Chain<StaticBuf, Bytes>, StaticBuf>),
}
59
60impl Encoder {
61    fn new(kind: Kind) -> Encoder {
62        Encoder {
63            kind,
64            is_last: false,
65        }
66    }
67    pub(crate) fn chunked() -> Encoder {
68        Encoder::new(Kind::Chunked(None))
69    }
70
71    pub(crate) fn length(len: u64) -> Encoder {
72        Encoder::new(Kind::Length(len))
73    }
74
75    #[cfg(feature = "server")]
76    pub(crate) fn close_delimited() -> Encoder {
77        Encoder::new(Kind::CloseDelimited)
78    }
79
80    pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec<HeaderValue>) -> Encoder {
81        match self.kind {
82            Kind::Chunked(_) => Encoder {
83                kind: Kind::Chunked(Some(trailers)),
84                is_last: self.is_last,
85            },
86            _ => self,
87        }
88    }
89
90    pub(crate) fn is_eof(&self) -> bool {
91        matches!(self.kind, Kind::Length(0))
92    }
93
94    #[cfg(feature = "server")]
95    pub(crate) fn set_last(mut self, is_last: bool) -> Self {
96        self.is_last = is_last;
97        self
98    }
99
100    pub(crate) fn is_last(&self) -> bool {
101        self.is_last
102    }
103
104    pub(crate) fn is_close_delimited(&self) -> bool {
105        match self.kind {
106            #[cfg(feature = "server")]
107            Kind::CloseDelimited => true,
108            _ => false,
109        }
110    }
111
112    pub(crate) fn is_chunked(&self) -> bool {
113        matches!(self.kind, Kind::Chunked(_))
114    }
115
116    pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
117        match self.kind {
118            Kind::Length(0) => Ok(None),
119            Kind::Chunked(_) => Ok(Some(EncodedBuf {
120                kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
121            })),
122            #[cfg(feature = "server")]
123            Kind::CloseDelimited => Ok(None),
124            Kind::Length(n) => Err(NotEof(n)),
125        }
126    }
127
128    pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B>
129    where
130        B: Buf,
131    {
132        let len = msg.remaining();
133        debug_assert!(len > 0, "encode() called with empty buf");
134
135        let kind = match self.kind {
136            Kind::Chunked(_) => {
137                trace!("encoding chunked {}B", len);
138                let buf = ChunkSize::new(len)
139                    .chain(msg)
140                    .chain(b"\r\n" as &'static [u8]);
141                BufKind::Chunked(buf)
142            }
143            Kind::Length(ref mut remaining) => {
144                trace!("sized write, len = {}", len);
145                if len as u64 > *remaining {
146                    let limit = *remaining as usize;
147                    *remaining = 0;
148                    BufKind::Limited(msg.take(limit))
149                } else {
150                    *remaining -= len as u64;
151                    BufKind::Exact(msg)
152                }
153            }
154            #[cfg(feature = "server")]
155            Kind::CloseDelimited => {
156                trace!("close delimited write {}B", len);
157                BufKind::Exact(msg)
158            }
159        };
160        EncodedBuf { kind }
161    }
162
163    pub(crate) fn encode_trailers<B>(
164        &self,
165        trailers: HeaderMap,
166        title_case_headers: bool,
167    ) -> Option<EncodedBuf<B>> {
168        trace!("encoding trailers");
169        match &self.kind {
170            Kind::Chunked(Some(allowed_trailer_fields)) => {
171                let allowed_trailer_field_map = allowed_trailer_field_map(allowed_trailer_fields);
172
173                let mut cur_name = None;
174                let mut allowed_trailers = HeaderMap::new();
175
176                for (opt_name, value) in trailers {
177                    if let Some(n) = opt_name {
178                        cur_name = Some(n);
179                    }
180                    let name = cur_name.as_ref().expect("current header name");
181
182                    if allowed_trailer_field_map.contains_key(name.as_str()) {
183                        if is_valid_trailer_field(name) {
184                            allowed_trailers.insert(name, value);
185                        } else {
186                            debug!("trailer field is not valid: {}", &name);
187                        }
188                    } else {
189                        debug!("trailer header name not found in trailer header: {}", &name);
190                    }
191                }
192
193                let mut buf = Vec::new();
194                if title_case_headers {
195                    write_headers_title_case(&allowed_trailers, &mut buf);
196                } else {
197                    write_headers(&allowed_trailers, &mut buf);
198                }
199
200                if buf.is_empty() {
201                    return None;
202                }
203
204                Some(EncodedBuf {
205                    kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
206                })
207            }
208            Kind::Chunked(None) => {
209                debug!("attempted to encode trailers, but the trailer header is not set");
210                None
211            }
212            _ => {
213                debug!("attempted to encode trailers for non-chunked response");
214                None
215            }
216        }
217    }
218
219    pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
220    where
221        B: Buf,
222    {
223        let len = msg.remaining();
224        debug_assert!(len > 0, "encode() called with empty buf");
225
226        match self.kind {
227            Kind::Chunked(_) => {
228                trace!("encoding chunked {}B", len);
229                let buf = ChunkSize::new(len)
230                    .chain(msg)
231                    .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
232                dst.buffer(buf);
233                !self.is_last
234            }
235            Kind::Length(remaining) => {
236                use std::cmp::Ordering;
237
238                trace!("sized write, len = {}", len);
239                match (len as u64).cmp(&remaining) {
240                    Ordering::Equal => {
241                        dst.buffer(msg);
242                        !self.is_last
243                    }
244                    Ordering::Greater => {
245                        dst.buffer(msg.take(remaining as usize));
246                        !self.is_last
247                    }
248                    Ordering::Less => {
249                        dst.buffer(msg);
250                        false
251                    }
252                }
253            }
254            #[cfg(feature = "server")]
255            Kind::CloseDelimited => {
256                trace!("close delimited write {}B", len);
257                dst.buffer(msg);
258                false
259            }
260        }
261    }
262}
263
264fn is_valid_trailer_field(name: &HeaderName) -> bool {
265    !matches!(
266        *name,
267        AUTHORIZATION
268            | CACHE_CONTROL
269            | CONTENT_ENCODING
270            | CONTENT_LENGTH
271            | CONTENT_RANGE
272            | CONTENT_TYPE
273            | HOST
274            | MAX_FORWARDS
275            | SET_COOKIE
276            | TRAILER
277            | TRANSFER_ENCODING
278            | TE
279    )
280}
281
282fn allowed_trailer_field_map(allowed_trailer_fields: &Vec<HeaderValue>) -> HashMap<String, ()> {
283    let mut trailer_map = HashMap::new();
284
285    for header_value in allowed_trailer_fields {
286        if let Ok(header_str) = header_value.to_str() {
287            let items: Vec<&str> = header_str.split(',').map(|item| item.trim()).collect();
288
289            for item in items {
290                trailer_map.entry(item.to_string()).or_insert(());
291            }
292        }
293    }
294
295    trailer_map
296}
297
298impl<B> Buf for EncodedBuf<B>
299where
300    B: Buf,
301{
302    #[inline]
303    fn remaining(&self) -> usize {
304        match self.kind {
305            BufKind::Exact(ref b) => b.remaining(),
306            BufKind::Limited(ref b) => b.remaining(),
307            BufKind::Chunked(ref b) => b.remaining(),
308            BufKind::ChunkedEnd(ref b) => b.remaining(),
309            BufKind::Trailers(ref b) => b.remaining(),
310        }
311    }
312
313    #[inline]
314    fn chunk(&self) -> &[u8] {
315        match self.kind {
316            BufKind::Exact(ref b) => b.chunk(),
317            BufKind::Limited(ref b) => b.chunk(),
318            BufKind::Chunked(ref b) => b.chunk(),
319            BufKind::ChunkedEnd(ref b) => b.chunk(),
320            BufKind::Trailers(ref b) => b.chunk(),
321        }
322    }
323
324    #[inline]
325    fn advance(&mut self, cnt: usize) {
326        match self.kind {
327            BufKind::Exact(ref mut b) => b.advance(cnt),
328            BufKind::Limited(ref mut b) => b.advance(cnt),
329            BufKind::Chunked(ref mut b) => b.advance(cnt),
330            BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
331            BufKind::Trailers(ref mut b) => b.advance(cnt),
332        }
333    }
334
335    #[inline]
336    fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
337        match self.kind {
338            BufKind::Exact(ref b) => b.chunks_vectored(dst),
339            BufKind::Limited(ref b) => b.chunks_vectored(dst),
340            BufKind::Chunked(ref b) => b.chunks_vectored(dst),
341            BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
342            BufKind::Trailers(ref b) => b.chunks_vectored(dst),
343        }
344    }
345}
346
/// Size of a `usize` in bytes on the target platform.
///
/// Derived from `size_of` rather than per-`target_pointer_width` constants, so
/// every pointer width compiles, not only 32- and 64-bit targets.
const USIZE_BYTES: usize = std::mem::size_of::<usize>();

/// Maximum number of hex digits needed to print any `usize`: each byte
/// becomes two hex digits.
const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2;
355
/// Fixed-size buffer holding the `<HEX-LEN>\r\n` line that starts a chunk.
#[derive(Clone, Copy)]
struct ChunkSize {
    // Room for the hex digits of any `usize` plus the trailing `\r\n`.
    bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2],
    // Read cursor advanced by `Buf::advance`; `advance` keeps `pos <= len`.
    pos: u8,
    // Number of bytes written into `bytes` so far.
    len: u8,
}
362
363impl ChunkSize {
364    fn new(len: usize) -> ChunkSize {
365        use std::fmt::Write;
366        let mut size = ChunkSize {
367            bytes: [0; CHUNK_SIZE_MAX_BYTES + 2],
368            pos: 0,
369            len: 0,
370        };
371        write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize");
372        size
373    }
374}
375
376impl Buf for ChunkSize {
377    #[inline]
378    fn remaining(&self) -> usize {
379        (self.len - self.pos).into()
380    }
381
382    #[inline]
383    fn chunk(&self) -> &[u8] {
384        &self.bytes[self.pos.into()..self.len.into()]
385    }
386
387    #[inline]
388    fn advance(&mut self, cnt: usize) {
389        assert!(cnt <= self.remaining());
390        self.pos += cnt as u8; // just asserted cnt fits in u8
391    }
392}
393
394impl fmt::Debug for ChunkSize {
395    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
396        f.debug_struct("ChunkSize")
397            .field("bytes", &&self.bytes[..self.len.into()])
398            .field("pos", &self.pos)
399            .finish()
400    }
401}
402
403impl fmt::Write for ChunkSize {
404    fn write_str(&mut self, num: &str) -> fmt::Result {
405        use std::io::Write;
406        (&mut self.bytes[self.len.into()..])
407            .write_all(num.as_bytes())
408            .expect("&mut [u8].write() cannot error");
409        self.len += num.len() as u8; // safe because bytes is never bigger than 256
410        Ok(())
411    }
412}
413
414impl<B: Buf> From<B> for EncodedBuf<B> {
415    fn from(buf: B) -> Self {
416        EncodedBuf {
417            kind: BufKind::Exact(buf),
418        }
419    }
420}
421
422impl<B: Buf> From<Take<B>> for EncodedBuf<B> {
423    fn from(buf: Take<B>) -> Self {
424        EncodedBuf {
425            kind: BufKind::Limited(buf),
426        }
427    }
428}
429
430impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> {
431    fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self {
432        EncodedBuf {
433            kind: BufKind::Chunked(buf),
434        }
435    }
436}
437
438impl fmt::Display for NotEof {
439    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
440        write!(f, "early end, expected {} more bytes", self.0)
441    }
442}
443
444impl std::error::Error for NotEof {}
445
// Unit tests for `Encoder` body framing and trailer filtering.
#[cfg(test)]
mod tests {
    use bytes::BufMut;
    use http::{
        header::{
            AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
            CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
        },
        HeaderMap, HeaderName, HeaderValue,
    };

    use super::super::io::Cursor;
    use super::Encoder;

    // Each chunk is prefixed with its upper-case hex length and followed by
    // CRLF; `end` appends the last-chunk `0\r\n\r\n`.
    #[test]
    fn chunked() {
        let mut encoder = Encoder::chunked();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);
        assert_eq!(dst, b"7\r\nfoo bar\r\n");

        let msg2 = b"baz quux herp".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");

        let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap();
        dst.put(end);

        assert_eq!(
            dst,
            b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref()
        );
    }

    // `end` fails while the declared length is unmet; bytes past the length
    // are truncated (only the `b` of `baz` fits in the last slot).
    #[test]
    fn length() {
        let max_len = 8;
        let mut encoder = Encoder::length(max_len as u64);
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap_err();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst.len(), max_len);
        assert_eq!(dst, b"foo barb");
        assert!(encoder.is_eof());
        assert!(encoder.end::<()>().unwrap().is_none());
    }

    // Close-delimited bodies pass bytes through unchanged, can be ended at any
    // point, and never report EOF.
    #[cfg(feature = "server")]
    #[test]
    fn eof() {
        let mut encoder = Encoder::close_delimited();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"foo barbaz");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();
    }

    // Only fields named in the permitted trailer list are emitted.
    #[test]
    fn chunked_with_valid_trailers() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderValue::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![
            (
                HeaderName::from_static("chunky-trailer"),
                HeaderValue::from_static("header data"),
            ),
            (
                HeaderName::from_static("should-not-be-included"),
                HeaderValue::from_static("oops"),
            ),
        ]);

        let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n");
    }

    // Several permitted fields, each announced in its own header value.
    #[test]
    fn chunked_with_multiple_trailer_headers() {
        let encoder = Encoder::chunked();
        let trailers = vec![
            HeaderValue::from_static("chunky-trailer"),
            HeaderValue::from_static("chunky-trailer-2"),
        ];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![
            (
                HeaderName::from_static("chunky-trailer"),
                HeaderValue::from_static("header data"),
            ),
            (
                HeaderName::from_static("chunky-trailer-2"),
                HeaderValue::from_static("more header data"),
            ),
        ]);

        let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(
            dst,
            b"0\r\nchunky-trailer: header data\r\nchunky-trailer-2: more header data\r\n\r\n"
        );
    }

    // Without a permitted list, or with an empty one, no trailers are encoded.
    #[test]
    fn chunked_with_no_trailer_header() {
        let encoder = Encoder::chunked();

        let headers = HeaderMap::from_iter(vec![(
            HeaderName::from_static("chunky-trailer"),
            HeaderValue::from_static("header data"),
        )]);

        assert!(encoder
            .encode_trailers::<&[u8]>(headers.clone(), false)
            .is_none());

        let trailers = vec![];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        assert!(encoder.encode_trailers::<&[u8]>(headers, false).is_none());
    }

    // Forbidden trailer fields are dropped even when they are announced.
    #[test]
    fn chunked_with_invalid_trailers() {
        let encoder = Encoder::chunked();

        let trailers = format!(
            "{},{},{},{},{},{},{},{},{},{},{},{}",
            AUTHORIZATION,
            CACHE_CONTROL,
            CONTENT_ENCODING,
            CONTENT_LENGTH,
            CONTENT_RANGE,
            CONTENT_TYPE,
            HOST,
            MAX_FORWARDS,
            SET_COOKIE,
            TRAILER,
            TRANSFER_ENCODING,
            TE,
        );
        let trailers = vec![HeaderValue::from_str(&trailers).unwrap()];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_static("header data"));
        headers.insert(CACHE_CONTROL, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_ENCODING, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_RANGE, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("header data"));
        headers.insert(HOST, HeaderValue::from_static("header data"));
        headers.insert(MAX_FORWARDS, HeaderValue::from_static("header data"));
        headers.insert(SET_COOKIE, HeaderValue::from_static("header data"));
        headers.insert(TRAILER, HeaderValue::from_static("header data"));
        headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("header data"));
        headers.insert(TE, HeaderValue::from_static("header data"));

        assert!(encoder.encode_trailers::<&[u8]>(headers, true).is_none());
    }

    // `title_case_headers = true` writes title-cased field names.
    #[test]
    fn chunked_with_title_case_headers() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderValue::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![(
            HeaderName::from_static("chunky-trailer"),
            HeaderValue::from_static("header data"),
        )]);
        let buf1 = encoder.encode_trailers::<&[u8]>(headers, true).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(dst, b"0\r\nChunky-Trailer: header data\r\n\r\n");
    }
}