// http_body_util/combinators/with_trailers.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use futures_util::ready;
8use http::HeaderMap;
9use http_body::{Body, Frame};
10use pin_project_lite::pin_project;
11
12pin_project! {
13    /// Adds trailers to a body.
14    ///
15    /// See [`BodyExt::with_trailers`] for more details.
16    pub struct WithTrailers<T, F> {
17        #[pin]
18        state: State<T, F>,
19    }
20}
21
22impl<T, F> WithTrailers<T, F> {
23    pub(crate) fn new(body: T, trailers: F) -> Self {
24        Self {
25            state: State::PollBody {
26                body,
27                trailers: Some(trailers),
28            },
29        }
30    }
31}
32
33pin_project! {
34    #[project = StateProj]
35    enum State<T, F> {
36        PollBody {
37            #[pin]
38            body: T,
39            trailers: Option<F>,
40        },
41        PollTrailers {
42            #[pin]
43            trailers: F,
44            prev_trailers: Option<HeaderMap>,
45        },
46        Done,
47    }
48}
49
50impl<T, F> Body for WithTrailers<T, F>
51where
52    T: Body,
53    F: Future<Output = Option<Result<HeaderMap, T::Error>>>,
54{
55    type Data = T::Data;
56    type Error = T::Error;
57
58    fn poll_frame(
59        mut self: Pin<&mut Self>,
60        cx: &mut Context<'_>,
61    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
62        loop {
63            let mut this = self.as_mut().project();
64
65            match this.state.as_mut().project() {
66                StateProj::PollBody { body, trailers } => match ready!(body.poll_frame(cx)?) {
67                    Some(frame) => match frame.into_trailers() {
68                        Ok(prev_trailers) => {
69                            let trailers = trailers.take().unwrap();
70                            this.state.set(State::PollTrailers {
71                                trailers,
72                                prev_trailers: Some(prev_trailers),
73                            });
74                        }
75                        Err(frame) => {
76                            return Poll::Ready(Some(Ok(frame)));
77                        }
78                    },
79                    None => {
80                        let trailers = trailers.take().unwrap();
81                        this.state.set(State::PollTrailers {
82                            trailers,
83                            prev_trailers: None,
84                        });
85                    }
86                },
87                StateProj::PollTrailers {
88                    trailers,
89                    prev_trailers,
90                } => {
91                    let trailers = ready!(trailers.poll(cx)?);
92                    match (trailers, prev_trailers.take()) {
93                        (None, None) => return Poll::Ready(None),
94                        (None, Some(trailers)) | (Some(trailers), None) => {
95                            this.state.set(State::Done);
96                            return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
97                        }
98                        (Some(new_trailers), Some(mut prev_trailers)) => {
99                            prev_trailers.extend(new_trailers);
100                            this.state.set(State::Done);
101                            return Poll::Ready(Some(Ok(Frame::trailers(prev_trailers))));
102                        }
103                    }
104                }
105                StateProj::Done => {
106                    return Poll::Ready(None);
107                }
108            }
109        }
110    }
111
112    #[inline]
113    fn size_hint(&self) -> http_body::SizeHint {
114        match &self.state {
115            State::PollBody { body, .. } => body.size_hint(),
116            State::PollTrailers { .. } | State::Done => Default::default(),
117        }
118    }
119}
120
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use bytes::Bytes;
    use http::{HeaderName, HeaderValue};

    use crate::{BodyExt, Empty, Full};

    // Glob import supplies `Context`, `Poll` and `HeaderMap` from the parent module.
    #[allow(unused_imports)]
    use super::*;

    /// Builds a single-entry header map used as a trailers fixture.
    fn header_map(name: &'static str, value: &'static str) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert(
            HeaderName::from_static(name),
            HeaderValue::from_static(value),
        );
        map
    }

    // These tests poll manually with a noop waker and never `.await`, so they do
    // not need an async runtime.
    #[test]
    fn works() {
        let trailers = header_map("foo", "bar");

        let body = Full::<Bytes>::from("hello").with_trailers(std::future::ready(Some(
            Ok::<_, Infallible>(trailers.clone()),
        )));

        futures_util::pin_mut!(body);
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let data = unwrap_ready(body.as_mut().poll_frame(&mut cx))
            .unwrap()
            .unwrap()
            .into_data()
            .unwrap();
        assert_eq!(data, "hello");

        let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx))
            .unwrap()
            .unwrap()
            .into_trailers()
            .unwrap();
        assert_eq!(body_trailers, trailers);

        assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none());
    }

    #[test]
    fn merges_trailers() {
        let trailers_1 = header_map("foo", "bar");
        let trailers_2 = header_map("baz", "qux");

        let body = Empty::<Bytes>::new()
            .with_trailers(std::future::ready(Some(Ok::<_, Infallible>(
                trailers_1.clone(),
            ))))
            .with_trailers(std::future::ready(Some(Ok::<_, Infallible>(
                trailers_2.clone(),
            ))));

        futures_util::pin_mut!(body);
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx))
            .unwrap()
            .unwrap()
            .into_trailers()
            .unwrap();

        let mut all_trailers = HeaderMap::new();
        all_trailers.extend(trailers_1);
        all_trailers.extend(trailers_2);
        assert_eq!(body_trailers, all_trailers);

        assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none());
    }

    #[test]
    fn no_trailers_ends_after_data() {
        let body = Full::<Bytes>::from("hello")
            .with_trailers(std::future::ready(None::<Result<HeaderMap, Infallible>>));

        futures_util::pin_mut!(body);
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let data = unwrap_ready(body.as_mut().poll_frame(&mut cx))
            .unwrap()
            .unwrap()
            .into_data()
            .unwrap();
        assert_eq!(data, "hello");

        assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none());
    }

    #[test]
    fn keeps_inner_trailers_when_future_yields_none() {
        let trailers = header_map("foo", "bar");

        let body = Empty::<Bytes>::new()
            .with_trailers(std::future::ready(Some(Ok::<_, Infallible>(
                trailers.clone(),
            ))))
            .with_trailers(std::future::ready(None::<Result<HeaderMap, Infallible>>));

        futures_util::pin_mut!(body);
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx))
            .unwrap()
            .unwrap()
            .into_trailers()
            .unwrap();
        assert_eq!(body_trailers, trailers);

        assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none());
    }

    /// Extracts the value from a `Poll`, panicking if it is still pending.
    fn unwrap_ready<T>(poll: Poll<T>) -> T {
        match poll {
            Poll::Ready(t) => t,
            Poll::Pending => panic!("pending"),
        }
    }
}