1use std::collections::HashMap;
2use std::fmt;
3use std::io::IoSlice;
4
5use bytes::buf::{Chain, Take};
6use bytes::{Buf, Bytes};
7use http::{
8 header::{
9 AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
10 CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
11 },
12 HeaderMap, HeaderName, HeaderValue,
13};
14
15use super::io::WriteBuf;
16use super::role::{write_headers, write_headers_title_case};
17
/// A `'static` byte string, used for the fixed framing bytes (CRLFs and
/// last-chunk markers) chained around body data.
type StaticBuf = &'static [u8];
19
/// Encodes an outgoing message body according to its framing (see `Kind`).
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct Encoder {
    // Body framing; for fixed-length bodies it also tracks the bytes still expected.
    kind: Kind,
    // Set via `set_last`; when true, `encode_and_end` always reports `false`.
    // NOTE(review): presumably marks the last message on the connection — confirm with callers.
    is_last: bool,
}
26
/// A body buffer produced by `Encoder`: the payload plus any framing bytes.
///
/// It implements `Buf` by forwarding to whichever `BufKind` it wraps.
#[derive(Debug)]
pub(crate) struct EncodedBuf<B> {
    kind: BufKind<B>,
}
31
/// Error returned by `Encoder::end` when a fixed-length body is ended early.
///
/// The wrapped value is the number of body bytes that were still expected.
#[derive(Debug)]
pub(crate) struct NotEof(u64);
34
/// How the message body is delimited on the wire.
#[derive(Debug, PartialEq, Clone)]
enum Kind {
    /// Chunked transfer coding. `Some` holds the values of the message's
    /// `Trailer` header, i.e. the field names allowed in the trailer section.
    Chunked(Option<Vec<HeaderValue>>),
    /// Fixed length: the number of body bytes still to be written.
    Length(u64),
    /// Body written without any framing (server-only).
    /// NOTE(review): presumably its end is signalled by closing the connection.
    #[cfg(feature = "server")]
    CloseDelimited,
}
50
/// The concrete buffer shapes an `EncodedBuf` can wrap.
#[derive(Debug)]
enum BufKind<B> {
    /// Payload passed through unchanged.
    Exact(B),
    /// Payload truncated to the bytes a fixed-length body still allows.
    Limited(Take<B>),
    /// Chunk-size line, then the payload, then a static tail (the chunk's
    /// closing CRLF, possibly followed by the last chunk).
    Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
    /// The bare last chunk `0\r\n\r\n`.
    ChunkedEnd(StaticBuf),
    /// Last chunk `0\r\n`, then the serialized trailer fields, then the final CRLF.
    Trailers(Chain<Chain<StaticBuf, Bytes>, StaticBuf>),
}
59
60impl Encoder {
61 fn new(kind: Kind) -> Encoder {
62 Encoder {
63 kind,
64 is_last: false,
65 }
66 }
67 pub(crate) fn chunked() -> Encoder {
68 Encoder::new(Kind::Chunked(None))
69 }
70
71 pub(crate) fn length(len: u64) -> Encoder {
72 Encoder::new(Kind::Length(len))
73 }
74
75 #[cfg(feature = "server")]
76 pub(crate) fn close_delimited() -> Encoder {
77 Encoder::new(Kind::CloseDelimited)
78 }
79
80 pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec<HeaderValue>) -> Encoder {
81 match self.kind {
82 Kind::Chunked(_) => Encoder {
83 kind: Kind::Chunked(Some(trailers)),
84 is_last: self.is_last,
85 },
86 _ => self,
87 }
88 }
89
90 pub(crate) fn is_eof(&self) -> bool {
91 matches!(self.kind, Kind::Length(0))
92 }
93
94 #[cfg(feature = "server")]
95 pub(crate) fn set_last(mut self, is_last: bool) -> Self {
96 self.is_last = is_last;
97 self
98 }
99
100 pub(crate) fn is_last(&self) -> bool {
101 self.is_last
102 }
103
104 pub(crate) fn is_close_delimited(&self) -> bool {
105 match self.kind {
106 #[cfg(feature = "server")]
107 Kind::CloseDelimited => true,
108 _ => false,
109 }
110 }
111
112 pub(crate) fn is_chunked(&self) -> bool {
113 matches!(self.kind, Kind::Chunked(_))
114 }
115
116 pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
117 match self.kind {
118 Kind::Length(0) => Ok(None),
119 Kind::Chunked(_) => Ok(Some(EncodedBuf {
120 kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
121 })),
122 #[cfg(feature = "server")]
123 Kind::CloseDelimited => Ok(None),
124 Kind::Length(n) => Err(NotEof(n)),
125 }
126 }
127
128 pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B>
129 where
130 B: Buf,
131 {
132 let len = msg.remaining();
133 debug_assert!(len > 0, "encode() called with empty buf");
134
135 let kind = match self.kind {
136 Kind::Chunked(_) => {
137 trace!("encoding chunked {}B", len);
138 let buf = ChunkSize::new(len)
139 .chain(msg)
140 .chain(b"\r\n" as &'static [u8]);
141 BufKind::Chunked(buf)
142 }
143 Kind::Length(ref mut remaining) => {
144 trace!("sized write, len = {}", len);
145 if len as u64 > *remaining {
146 let limit = *remaining as usize;
147 *remaining = 0;
148 BufKind::Limited(msg.take(limit))
149 } else {
150 *remaining -= len as u64;
151 BufKind::Exact(msg)
152 }
153 }
154 #[cfg(feature = "server")]
155 Kind::CloseDelimited => {
156 trace!("close delimited write {}B", len);
157 BufKind::Exact(msg)
158 }
159 };
160 EncodedBuf { kind }
161 }
162
163 pub(crate) fn encode_trailers<B>(
164 &self,
165 trailers: HeaderMap,
166 title_case_headers: bool,
167 ) -> Option<EncodedBuf<B>> {
168 trace!("encoding trailers");
169 match &self.kind {
170 Kind::Chunked(Some(allowed_trailer_fields)) => {
171 let allowed_trailer_field_map = allowed_trailer_field_map(allowed_trailer_fields);
172
173 let mut cur_name = None;
174 let mut allowed_trailers = HeaderMap::new();
175
176 for (opt_name, value) in trailers {
177 if let Some(n) = opt_name {
178 cur_name = Some(n);
179 }
180 let name = cur_name.as_ref().expect("current header name");
181
182 if allowed_trailer_field_map.contains_key(name.as_str()) {
183 if is_valid_trailer_field(name) {
184 allowed_trailers.insert(name, value);
185 } else {
186 debug!("trailer field is not valid: {}", &name);
187 }
188 } else {
189 debug!("trailer header name not found in trailer header: {}", &name);
190 }
191 }
192
193 let mut buf = Vec::new();
194 if title_case_headers {
195 write_headers_title_case(&allowed_trailers, &mut buf);
196 } else {
197 write_headers(&allowed_trailers, &mut buf);
198 }
199
200 if buf.is_empty() {
201 return None;
202 }
203
204 Some(EncodedBuf {
205 kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
206 })
207 }
208 Kind::Chunked(None) => {
209 debug!("attempted to encode trailers, but the trailer header is not set");
210 None
211 }
212 _ => {
213 debug!("attempted to encode trailers for non-chunked response");
214 None
215 }
216 }
217 }
218
219 pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
220 where
221 B: Buf,
222 {
223 let len = msg.remaining();
224 debug_assert!(len > 0, "encode() called with empty buf");
225
226 match self.kind {
227 Kind::Chunked(_) => {
228 trace!("encoding chunked {}B", len);
229 let buf = ChunkSize::new(len)
230 .chain(msg)
231 .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
232 dst.buffer(buf);
233 !self.is_last
234 }
235 Kind::Length(remaining) => {
236 use std::cmp::Ordering;
237
238 trace!("sized write, len = {}", len);
239 match (len as u64).cmp(&remaining) {
240 Ordering::Equal => {
241 dst.buffer(msg);
242 !self.is_last
243 }
244 Ordering::Greater => {
245 dst.buffer(msg.take(remaining as usize));
246 !self.is_last
247 }
248 Ordering::Less => {
249 dst.buffer(msg);
250 false
251 }
252 }
253 }
254 #[cfg(feature = "server")]
255 Kind::CloseDelimited => {
256 trace!("close delimited write {}B", len);
257 dst.buffer(msg);
258 false
259 }
260 }
261 }
262}
263
264fn is_valid_trailer_field(name: &HeaderName) -> bool {
265 !matches!(
266 *name,
267 AUTHORIZATION
268 | CACHE_CONTROL
269 | CONTENT_ENCODING
270 | CONTENT_LENGTH
271 | CONTENT_RANGE
272 | CONTENT_TYPE
273 | HOST
274 | MAX_FORWARDS
275 | SET_COOKIE
276 | TRAILER
277 | TRANSFER_ENCODING
278 | TE
279 )
280}
281
282fn allowed_trailer_field_map(allowed_trailer_fields: &Vec<HeaderValue>) -> HashMap<String, ()> {
283 let mut trailer_map = HashMap::new();
284
285 for header_value in allowed_trailer_fields {
286 if let Ok(header_str) = header_value.to_str() {
287 let items: Vec<&str> = header_str.split(',').map(|item| item.trim()).collect();
288
289 for item in items {
290 trailer_map.entry(item.to_string()).or_insert(());
291 }
292 }
293 }
294
295 trailer_map
296}
297
298impl<B> Buf for EncodedBuf<B>
299where
300 B: Buf,
301{
302 #[inline]
303 fn remaining(&self) -> usize {
304 match self.kind {
305 BufKind::Exact(ref b) => b.remaining(),
306 BufKind::Limited(ref b) => b.remaining(),
307 BufKind::Chunked(ref b) => b.remaining(),
308 BufKind::ChunkedEnd(ref b) => b.remaining(),
309 BufKind::Trailers(ref b) => b.remaining(),
310 }
311 }
312
313 #[inline]
314 fn chunk(&self) -> &[u8] {
315 match self.kind {
316 BufKind::Exact(ref b) => b.chunk(),
317 BufKind::Limited(ref b) => b.chunk(),
318 BufKind::Chunked(ref b) => b.chunk(),
319 BufKind::ChunkedEnd(ref b) => b.chunk(),
320 BufKind::Trailers(ref b) => b.chunk(),
321 }
322 }
323
324 #[inline]
325 fn advance(&mut self, cnt: usize) {
326 match self.kind {
327 BufKind::Exact(ref mut b) => b.advance(cnt),
328 BufKind::Limited(ref mut b) => b.advance(cnt),
329 BufKind::Chunked(ref mut b) => b.advance(cnt),
330 BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
331 BufKind::Trailers(ref mut b) => b.advance(cnt),
332 }
333 }
334
335 #[inline]
336 fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
337 match self.kind {
338 BufKind::Exact(ref b) => b.chunks_vectored(dst),
339 BufKind::Limited(ref b) => b.chunks_vectored(dst),
340 BufKind::Chunked(ref b) => b.chunks_vectored(dst),
341 BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
342 BufKind::Trailers(ref b) => b.chunks_vectored(dst),
343 }
344 }
345}
346
/// Size of a `usize` in bytes on the target platform.
///
/// Derived from the type itself rather than per-`target_pointer_width` `cfg`s,
/// so it is correct on every supported pointer width.
const USIZE_BYTES: usize = std::mem::size_of::<usize>();

/// Maximum number of hex digits needed to print any `usize` chunk length
/// (two hex digits per byte).
const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2;
355
/// Inline buffer holding a chunk-size line: the length in uppercase hex
/// followed by `\r\n`.
#[derive(Clone, Copy)]
struct ChunkSize {
    // Room for up to `CHUNK_SIZE_MAX_BYTES` hex digits plus the CRLF.
    bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2],
    // Read cursor: bytes before `pos` have been consumed via `Buf::advance`.
    pos: u8,
    // Number of bytes written into `bytes` so far.
    len: u8,
}
362
363impl ChunkSize {
364 fn new(len: usize) -> ChunkSize {
365 use std::fmt::Write;
366 let mut size = ChunkSize {
367 bytes: [0; CHUNK_SIZE_MAX_BYTES + 2],
368 pos: 0,
369 len: 0,
370 };
371 write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize");
372 size
373 }
374}
375
376impl Buf for ChunkSize {
377 #[inline]
378 fn remaining(&self) -> usize {
379 (self.len - self.pos).into()
380 }
381
382 #[inline]
383 fn chunk(&self) -> &[u8] {
384 &self.bytes[self.pos.into()..self.len.into()]
385 }
386
387 #[inline]
388 fn advance(&mut self, cnt: usize) {
389 assert!(cnt <= self.remaining());
390 self.pos += cnt as u8; }
392}
393
394impl fmt::Debug for ChunkSize {
395 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
396 f.debug_struct("ChunkSize")
397 .field("bytes", &&self.bytes[..self.len.into()])
398 .field("pos", &self.pos)
399 .finish()
400 }
401}
402
403impl fmt::Write for ChunkSize {
404 fn write_str(&mut self, num: &str) -> fmt::Result {
405 use std::io::Write;
406 (&mut self.bytes[self.len.into()..])
407 .write_all(num.as_bytes())
408 .expect("&mut [u8].write() cannot error");
409 self.len += num.len() as u8; Ok(())
411 }
412}
413
414impl<B: Buf> From<B> for EncodedBuf<B> {
415 fn from(buf: B) -> Self {
416 EncodedBuf {
417 kind: BufKind::Exact(buf),
418 }
419 }
420}
421
422impl<B: Buf> From<Take<B>> for EncodedBuf<B> {
423 fn from(buf: Take<B>) -> Self {
424 EncodedBuf {
425 kind: BufKind::Limited(buf),
426 }
427 }
428}
429
430impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> {
431 fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self {
432 EncodedBuf {
433 kind: BufKind::Chunked(buf),
434 }
435 }
436}
437
438impl fmt::Display for NotEof {
439 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
440 write!(f, "early end, expected {} more bytes", self.0)
441 }
442}
443
/// `NotEof` is a leaf error: it keeps the default `source()`, which returns `None`.
impl std::error::Error for NotEof {}
445
/// Unit tests for `Encoder` framing and trailer encoding.
#[cfg(test)]
mod tests {
    use bytes::BufMut;
    use http::{
        header::{
            AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
            CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
        },
        HeaderMap, HeaderName, HeaderValue,
    };

    use super::super::io::Cursor;
    use super::Encoder;

    // Each chunk is framed as `<hex len>\r\n<data>\r\n` (13 bytes -> "D"), and
    // `end()` appends the last chunk `0\r\n\r\n`.
    #[test]
    fn chunked() {
        let mut encoder = Encoder::chunked();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);
        assert_eq!(dst, b"7\r\nfoo bar\r\n");

        let msg2 = b"baz quux herp".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");

        let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap();
        dst.put(end);

        assert_eq!(
            dst,
            b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref()
        );
    }

    // A fixed-length body truncates writes past the declared length, and
    // `end()` errors until every declared byte has been written.
    #[test]
    fn length() {
        let max_len = 8;
        let mut encoder = Encoder::length(max_len as u64);
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap_err();

        // Only 1 of these 3 bytes fits in the remaining length.
        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst.len(), max_len);
        assert_eq!(dst, b"foo barb");
        assert!(encoder.is_eof());
        assert!(encoder.end::<()>().unwrap().is_none());
    }

    // Close-delimited bodies pass data through unframed, never report EOF,
    // and may be "ended" at any point.
    #[cfg(feature = "server")]
    #[test]
    fn eof() {
        let mut encoder = Encoder::close_delimited();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"foo barbaz");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();
    }

    // Only trailer fields declared via the `Trailer` header are emitted.
    #[test]
    fn chunked_with_valid_trailers() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderValue::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![
            (
                HeaderName::from_static("chunky-trailer"),
                HeaderValue::from_static("header data"),
            ),
            (
                HeaderName::from_static("should-not-be-included"),
                HeaderValue::from_static("oops"),
            ),
        ]);

        let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n");
    }

    // Declarations may be spread across several `Trailer` header values.
    #[test]
    fn chunked_with_multiple_trailer_headers() {
        let encoder = Encoder::chunked();
        let trailers = vec![
            HeaderValue::from_static("chunky-trailer"),
            HeaderValue::from_static("chunky-trailer-2"),
        ];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![
            (
                HeaderName::from_static("chunky-trailer"),
                HeaderValue::from_static("header data"),
            ),
            (
                HeaderName::from_static("chunky-trailer-2"),
                HeaderValue::from_static("more header data"),
            ),
        ]);

        let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(
            dst,
            b"0\r\nchunky-trailer: header data\r\nchunky-trailer-2: more header data\r\n\r\n"
        );
    }

    // No trailers are written without a declaration, or with an empty one.
    #[test]
    fn chunked_with_no_trailer_header() {
        let encoder = Encoder::chunked();

        let headers = HeaderMap::from_iter(vec![(
            HeaderName::from_static("chunky-trailer"),
            HeaderValue::from_static("header data"),
        )]);

        assert!(encoder
            .encode_trailers::<&[u8]>(headers.clone(), false)
            .is_none());

        let trailers = vec![];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        assert!(encoder.encode_trailers::<&[u8]>(headers, false).is_none());
    }

    // Denylisted fields are dropped even when explicitly declared.
    #[test]
    fn chunked_with_invalid_trailers() {
        let encoder = Encoder::chunked();

        let trailers = format!(
            "{},{},{},{},{},{},{},{},{},{},{},{}",
            AUTHORIZATION,
            CACHE_CONTROL,
            CONTENT_ENCODING,
            CONTENT_LENGTH,
            CONTENT_RANGE,
            CONTENT_TYPE,
            HOST,
            MAX_FORWARDS,
            SET_COOKIE,
            TRAILER,
            TRANSFER_ENCODING,
            TE,
        );
        let trailers = vec![HeaderValue::from_str(&trailers).unwrap()];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_static("header data"));
        headers.insert(CACHE_CONTROL, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_ENCODING, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_RANGE, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("header data"));
        headers.insert(HOST, HeaderValue::from_static("header data"));
        headers.insert(MAX_FORWARDS, HeaderValue::from_static("header data"));
        headers.insert(SET_COOKIE, HeaderValue::from_static("header data"));
        headers.insert(TRAILER, HeaderValue::from_static("header data"));
        headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("header data"));
        headers.insert(TE, HeaderValue::from_static("header data"));

        assert!(encoder.encode_trailers::<&[u8]>(headers, true).is_none());
    }

    // `title_case_headers = true` title-cases the emitted field names.
    #[test]
    fn chunked_with_title_case_headers() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderValue::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![(
            HeaderName::from_static("chunky-trailer"),
            HeaderValue::from_static("header data"),
        )]);
        let buf1 = encoder.encode_trailers::<&[u8]>(headers, true).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(dst, b"0\r\nChunky-Trailer: header data\r\n\r\n");
    }
}