http_body_util/
collected.rs

1use std::{
2    convert::Infallible,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use bytes::{Buf, Bytes};
8use http::HeaderMap;
9use http_body::{Body, Frame};
10
11use crate::util::BufList;
12
13/// A collected body produced by [`BodyExt::collect`] which collects all the DATA frames
14/// and trailers.
15///
16/// [`BodyExt::collect`]: crate::BodyExt::collect
17#[derive(Debug)]
18pub struct Collected<B> {
19    bufs: BufList<B>,
20    trailers: Option<HeaderMap>,
21}
22
23impl<B: Buf> Collected<B> {
24    /// If there is a trailers frame buffered, returns a reference to it.
25    ///
26    /// Returns `None` if the body contained no trailers.
27    pub fn trailers(&self) -> Option<&HeaderMap> {
28        self.trailers.as_ref()
29    }
30
31    /// Aggregate this buffered into a [`Buf`].
32    pub fn aggregate(self) -> impl Buf {
33        self.bufs
34    }
35
36    /// Convert this body into a [`Bytes`].
37    pub fn to_bytes(mut self) -> Bytes {
38        self.bufs.copy_to_bytes(self.bufs.remaining())
39    }
40
41    pub(crate) fn push_frame(&mut self, frame: Frame<B>) {
42        let frame = match frame.into_data() {
43            Ok(data) => {
44                // Only push this frame if it has some data in it, to avoid crashing on
45                // `BufList::push`.
46                if data.has_remaining() {
47                    self.bufs.push(data);
48                }
49                return;
50            }
51            Err(frame) => frame,
52        };
53
54        if let Ok(trailers) = frame.into_trailers() {
55            if let Some(current) = &mut self.trailers {
56                current.extend(trailers);
57            } else {
58                self.trailers = Some(trailers);
59            }
60        };
61    }
62}
63
64impl<B: Buf> Body for Collected<B> {
65    type Data = B;
66    type Error = Infallible;
67
68    fn poll_frame(
69        mut self: Pin<&mut Self>,
70        _: &mut Context<'_>,
71    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
72        let frame = if let Some(data) = self.bufs.pop() {
73            Frame::data(data)
74        } else if let Some(trailers) = self.trailers.take() {
75            Frame::trailers(trailers)
76        } else {
77            return Poll::Ready(None);
78        };
79
80        Poll::Ready(Some(Ok(frame)))
81    }
82}
83
84impl<B> Default for Collected<B> {
85    fn default() -> Self {
86        Self {
87            bufs: BufList::default(),
88            trailers: None,
89        }
90    }
91}
92
93impl<B> Unpin for Collected<B> {}
94
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use futures_util::stream;

    use crate::{BodyExt, Full, StreamBody};

    use super::*;

    #[tokio::test]
    async fn full_body() {
        let body = Full::new(&b"hello"[..]);

        let buffered = body.collect().await.unwrap();

        assert_eq!(&buffered.to_bytes()[..], &b"hello"[..]);
    }

    #[tokio::test]
    async fn segmented_body() {
        let bufs = [&b"hello"[..], &b"world"[..], &b"!"[..]];
        let body = StreamBody::new(stream::iter(bufs.map(Frame::data).map(Ok::<_, Infallible>)));

        let buffered = body.collect().await.unwrap();

        assert_eq!(&buffered.to_bytes()[..], b"helloworld!");
    }

    #[tokio::test]
    async fn delayed_segments() {
        let one = stream::once(async { Ok::<_, Infallible>(Frame::data(&b"hello "[..])) });
        let two = stream::once(async {
            // a yield just so its not ready immediately
            tokio::task::yield_now().await;
            Ok::<_, Infallible>(Frame::data(&b"world!"[..]))
        });
        let stream = futures_util::StreamExt::chain(one, two);

        let body = StreamBody::new(stream);

        let buffered = body.collect().await.unwrap();

        assert_eq!(&buffered.to_bytes()[..], b"hello world!");
    }

    #[tokio::test]
    async fn trailers() {
        let mut trailers = HeaderMap::new();
        trailers.insert("this", "a trailer".try_into().unwrap());
        let bufs = [
            Frame::data(&b"hello"[..]),
            Frame::data(&b"world!"[..]),
            Frame::trailers(trailers.clone()),
        ];

        let body = StreamBody::new(stream::iter(bufs.map(Ok::<_, Infallible>)));

        let buffered = body.collect().await.unwrap();

        assert_eq!(&trailers, buffered.trailers().unwrap());

        assert_eq!(&buffered.to_bytes()[..], b"helloworld!");
    }

    /// Several trailers frames are merged into one map. For a header name that
    /// appears in more than one frame, the later frame's value replaces the
    /// earlier one.
    #[tokio::test]
    async fn multiple_trailers_are_merged() {
        let mut first = HeaderMap::new();
        first.insert("this", "a trailer".try_into().unwrap());
        first.insert("that", "kept".try_into().unwrap());
        let mut second = HeaderMap::new();
        second.insert("this", "replaced".try_into().unwrap());

        let frames = [
            Frame::data(&b"hello"[..]),
            Frame::trailers(first),
            Frame::trailers(second),
        ];
        let body = StreamBody::new(stream::iter(frames.map(Ok::<_, Infallible>)));

        let buffered = body.collect().await.unwrap();
        let merged = buffered.trailers().unwrap();

        assert_eq!(merged["this"], "replaced");
        assert_eq!(merged["that"], "kept");
        assert_eq!(merged.len(), 2);
    }

    /// A `Collected` is itself a `Body`. Collecting it again must reproduce the
    /// same data and trailers.
    #[tokio::test]
    async fn collected_replays_as_body() {
        let mut trailers = HeaderMap::new();
        trailers.insert("this", "a trailer".try_into().unwrap());
        let frames = [
            Frame::data(&b"hello"[..]),
            Frame::data(&b"world!"[..]),
            Frame::trailers(trailers.clone()),
        ];
        let body = StreamBody::new(stream::iter(frames.map(Ok::<_, Infallible>)));

        let first = body.collect().await.unwrap();
        let second = first.collect().await.unwrap();

        assert_eq!(second.trailers(), Some(&trailers));
        assert_eq!(&second.to_bytes()[..], b"helloworld!");
    }

    /// Test for issue [#88](https://github.com/hyperium/http-body/issues/88).
    #[tokio::test]
    async fn empty_frame() {
        let bufs: [&[u8]; 1] = [&[]];

        let body = StreamBody::new(stream::iter(bufs.map(Frame::data).map(Ok::<_, Infallible>)));
        let buffered = body.collect().await.unwrap();

        assert_eq!(buffered.to_bytes().len(), 0);
    }
}