http_body_util/combinators/
with_trailers.rs1use std::{
2 future::Future,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use futures_util::ready;
8use http::HeaderMap;
9use http_body::{Body, Frame};
10use pin_project_lite::pin_project;
11
12pin_project! {
13 pub struct WithTrailers<T, F> {
17 #[pin]
18 state: State<T, F>,
19 }
20}
21
22impl<T, F> WithTrailers<T, F> {
23 pub(crate) fn new(body: T, trailers: F) -> Self {
24 Self {
25 state: State::PollBody {
26 body,
27 trailers: Some(trailers),
28 },
29 }
30 }
31}
32
33pin_project! {
34 #[project = StateProj]
35 enum State<T, F> {
36 PollBody {
37 #[pin]
38 body: T,
39 trailers: Option<F>,
40 },
41 PollTrailers {
42 #[pin]
43 trailers: F,
44 prev_trailers: Option<HeaderMap>,
45 },
46 Done,
47 }
48}
49
50impl<T, F> Body for WithTrailers<T, F>
51where
52 T: Body,
53 F: Future<Output = Option<Result<HeaderMap, T::Error>>>,
54{
55 type Data = T::Data;
56 type Error = T::Error;
57
58 fn poll_frame(
59 mut self: Pin<&mut Self>,
60 cx: &mut Context<'_>,
61 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
62 loop {
63 let mut this = self.as_mut().project();
64
65 match this.state.as_mut().project() {
66 StateProj::PollBody { body, trailers } => match ready!(body.poll_frame(cx)?) {
67 Some(frame) => match frame.into_trailers() {
68 Ok(prev_trailers) => {
69 let trailers = trailers.take().unwrap();
70 this.state.set(State::PollTrailers {
71 trailers,
72 prev_trailers: Some(prev_trailers),
73 });
74 }
75 Err(frame) => {
76 return Poll::Ready(Some(Ok(frame)));
77 }
78 },
79 None => {
80 let trailers = trailers.take().unwrap();
81 this.state.set(State::PollTrailers {
82 trailers,
83 prev_trailers: None,
84 });
85 }
86 },
87 StateProj::PollTrailers {
88 trailers,
89 prev_trailers,
90 } => {
91 let trailers = ready!(trailers.poll(cx)?);
92 match (trailers, prev_trailers.take()) {
93 (None, None) => return Poll::Ready(None),
94 (None, Some(trailers)) | (Some(trailers), None) => {
95 this.state.set(State::Done);
96 return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
97 }
98 (Some(new_trailers), Some(mut prev_trailers)) => {
99 prev_trailers.extend(new_trailers);
100 this.state.set(State::Done);
101 return Poll::Ready(Some(Ok(Frame::trailers(prev_trailers))));
102 }
103 }
104 }
105 StateProj::Done => {
106 return Poll::Ready(None);
107 }
108 }
109 }
110 }
111
112 #[inline]
113 fn size_hint(&self) -> http_body::SizeHint {
114 match &self.state {
115 State::PollBody { body, .. } => body.size_hint(),
116 State::PollTrailers { .. } | State::Done => Default::default(),
117 }
118 }
119}
120
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use bytes::Bytes;
    use http::{HeaderName, HeaderValue};

    use crate::{BodyExt, Empty, Full};

    use super::*;

    // None of these tests await anything: they drive `poll_frame` by hand with
    // a no-op waker, so no async runtime is needed.

    #[test]
    fn works() {
        let mut trailers = HeaderMap::new();
        trailers.insert(
            HeaderName::from_static("foo"),
            HeaderValue::from_static("bar"),
        );

        let body =
            Full::<Bytes>::from("hello").with_trailers(std::future::ready(Some(
                Ok::<_, Infallible>(trailers.clone()),
            )));

        futures_util::pin_mut!(body);
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let data = unwrap_ready(body.as_mut().poll_frame(&mut cx))
            .unwrap()
            .unwrap()
            .into_data()
            .unwrap();
        assert_eq!(data, "hello");

        let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx))
            .unwrap()
            .unwrap()
            .into_trailers()
            .unwrap();
        assert_eq!(body_trailers, trailers);

        assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none());
    }

    #[test]
    fn merges_trailers() {
        let mut trailers_1 = HeaderMap::new();
        trailers_1.insert(
            HeaderName::from_static("foo"),
            HeaderValue::from_static("bar"),
        );

        let mut trailers_2 = HeaderMap::new();
        trailers_2.insert(
            HeaderName::from_static("baz"),
            HeaderValue::from_static("qux"),
        );

        let body = Empty::<Bytes>::new()
            .with_trailers(std::future::ready(Some(Ok::<_, Infallible>(
                trailers_1.clone(),
            ))))
            .with_trailers(std::future::ready(Some(Ok::<_, Infallible>(
                trailers_2.clone(),
            ))));

        futures_util::pin_mut!(body);
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx))
            .unwrap()
            .unwrap()
            .into_trailers()
            .unwrap();

        let mut all_trailers = HeaderMap::new();
        all_trailers.extend(trailers_1);
        all_trailers.extend(trailers_2);
        assert_eq!(body_trailers, all_trailers);

        assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none());
    }

    #[test]
    fn ends_without_trailers() {
        let body = Empty::<Bytes>::new()
            .with_trailers(std::future::ready(None::<Result<HeaderMap, Infallible>>));

        futures_util::pin_mut!(body);
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none());
    }

    #[test]
    fn propagates_trailers_future_error() {
        let body = Empty::<Bytes>::new()
            .map_err(|never: Infallible| -> &'static str { match never {} })
            .with_trailers(std::future::ready(Some(Err::<HeaderMap, _>("boom"))));

        futures_util::pin_mut!(body);
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        match unwrap_ready(body.as_mut().poll_frame(&mut cx)) {
            Some(Err(err)) => assert_eq!(err, "boom"),
            _ => panic!("expected the trailers future's error"),
        }
    }

    fn unwrap_ready<T>(poll: Poll<T>) -> T {
        match poll {
            Poll::Ready(t) => t,
            Poll::Pending => panic!("pending"),
        }
    }
}