1use std::{
2 error::Error as StdError,
3 future::Future,
4 marker::Unpin,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use crate::rt::{Read, Write};
10use bytes::{Buf, Bytes};
11use futures_util::ready;
12use http::Request;
13
14use super::{Http1Transaction, Wants};
15use crate::body::{Body, DecodedLength, Incoming as IncomingBody};
16#[cfg(feature = "client")]
17use crate::client::dispatch::TrySendError;
18use crate::common::task;
19use crate::proto::{BodyLength, Conn, Dispatched, MessageHead, RequestHead};
20use crate::upgrade::OnUpgrade;
21
22pub(crate) struct Dispatcher<D, Bs: Body, I, T> {
23 conn: Conn<I, Bs::Data, T>,
24 dispatch: D,
25 body_tx: Option<crate::body::Sender>,
26 body_rx: Pin<Box<Option<Bs>>>,
27 is_closing: bool,
28}
29
30pub(crate) trait Dispatch {
31 type PollItem;
32 type PollBody;
33 type PollError;
34 type RecvItem;
35 fn poll_msg(
36 self: Pin<&mut Self>,
37 cx: &mut Context<'_>,
38 ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>>;
39 fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, IncomingBody)>)
40 -> crate::Result<()>;
41 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ()>>;
42 fn should_poll(&self) -> bool;
43}
44
45cfg_server! {
46 use crate::service::HttpService;
47
48 pub(crate) struct Server<S: HttpService<B>, B> {
49 in_flight: Pin<Box<Option<S::Future>>>,
50 pub(crate) service: S,
51 }
52}
53
54cfg_client! {
55 pin_project_lite::pin_project! {
56 pub(crate) struct Client<B> {
57 callback: Option<crate::client::dispatch::Callback<Request<B>, http::Response<IncomingBody>>>,
58 #[pin]
59 rx: ClientRx<B>,
60 rx_closed: bool,
61 }
62 }
63
64 type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, http::Response<IncomingBody>>;
65}
66
67impl<D, Bs, I, T> Dispatcher<D, Bs, I, T>
68where
69 D: Dispatch<
70 PollItem = MessageHead<T::Outgoing>,
71 PollBody = Bs,
72 RecvItem = MessageHead<T::Incoming>,
73 > + Unpin,
74 D::PollError: Into<Box<dyn StdError + Send + Sync>>,
75 I: Read + Write + Unpin,
76 T: Http1Transaction + Unpin,
77 Bs: Body + 'static,
78 Bs::Error: Into<Box<dyn StdError + Send + Sync>>,
79{
80 pub(crate) fn new(dispatch: D, conn: Conn<I, Bs::Data, T>) -> Self {
81 Dispatcher {
82 conn,
83 dispatch,
84 body_tx: None,
85 body_rx: Box::pin(None),
86 is_closing: false,
87 }
88 }
89
90 #[cfg(feature = "server")]
91 pub(crate) fn disable_keep_alive(&mut self) {
92 self.conn.disable_keep_alive();
93
94 if self.conn.is_write_closed() || self.conn.has_initial_read_write_state() {
98 self.close();
99 }
100 }
101
102 pub(crate) fn into_inner(self) -> (I, Bytes, D) {
103 let (io, buf) = self.conn.into_inner();
104 (io, buf, self.dispatch)
105 }
106
107 pub(crate) fn poll_without_shutdown(
113 &mut self,
114 cx: &mut Context<'_>,
115 ) -> Poll<crate::Result<()>> {
116 Pin::new(self).poll_catch(cx, false).map_ok(|ds| {
117 if let Dispatched::Upgrade(pending) = ds {
118 pending.manual();
119 }
120 })
121 }
122
123 fn poll_catch(
124 &mut self,
125 cx: &mut Context<'_>,
126 should_shutdown: bool,
127 ) -> Poll<crate::Result<Dispatched>> {
128 Poll::Ready(ready!(self.poll_inner(cx, should_shutdown)).or_else(|e| {
129 if let Some(mut body) = self.body_tx.take() {
131 body.send_error(crate::Error::new_body("connection error"));
132 }
133 self.dispatch.recv_msg(Err(e))?;
138 Ok(Dispatched::Shutdown)
139 }))
140 }
141
142 fn poll_inner(
143 &mut self,
144 cx: &mut Context<'_>,
145 should_shutdown: bool,
146 ) -> Poll<crate::Result<Dispatched>> {
147 T::update_date();
148
149 ready!(self.poll_loop(cx))?;
150
151 if self.is_done() {
152 if let Some(pending) = self.conn.pending_upgrade() {
153 self.conn.take_error()?;
154 return Poll::Ready(Ok(Dispatched::Upgrade(pending)));
155 } else if should_shutdown {
156 ready!(self.conn.poll_shutdown(cx)).map_err(crate::Error::new_shutdown)?;
157 }
158 self.conn.take_error()?;
159 Poll::Ready(Ok(Dispatched::Shutdown))
160 } else {
161 Poll::Pending
162 }
163 }
164
165 fn poll_loop(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
166 for _ in 0..16 {
172 let _ = self.poll_read(cx)?;
173 let _ = self.poll_write(cx)?;
174 let _ = self.poll_flush(cx)?;
175
176 if !self.conn.wants_read_again() {
185 return Poll::Ready(Ok(()));
187 }
188 }
189
190 trace!("poll_loop yielding (self = {:p})", self);
191
192 task::yield_now(cx).map(|never| match never {})
193 }
194
195 fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
196 loop {
197 if self.is_closing {
198 return Poll::Ready(Ok(()));
199 } else if self.conn.can_read_head() {
200 ready!(self.poll_read_head(cx))?;
201 } else if let Some(mut body) = self.body_tx.take() {
202 if self.conn.can_read_body() {
203 match body.poll_ready(cx) {
204 Poll::Ready(Ok(())) => (),
205 Poll::Pending => {
206 self.body_tx = Some(body);
207 return Poll::Pending;
208 }
209 Poll::Ready(Err(_canceled)) => {
210 trace!("body receiver dropped before eof, draining or closing");
213 self.conn.poll_drain_or_close_read(cx);
214 continue;
215 }
216 }
217 match self.conn.poll_read_body(cx) {
218 Poll::Ready(Some(Ok(frame))) => {
219 if frame.is_data() {
220 let chunk = frame.into_data().unwrap_or_else(|_| unreachable!());
221 match body.try_send_data(chunk) {
222 Ok(()) => {
223 self.body_tx = Some(body);
224 }
225 Err(_canceled) => {
226 if self.conn.can_read_body() {
227 trace!("body receiver dropped before eof, closing");
228 self.conn.close_read();
229 }
230 }
231 }
232 } else if frame.is_trailers() {
233 let trailers =
234 frame.into_trailers().unwrap_or_else(|_| unreachable!());
235 match body.try_send_trailers(trailers) {
236 Ok(()) => {
237 self.body_tx = Some(body);
238 }
239 Err(_canceled) => {
240 if self.conn.can_read_body() {
241 trace!("body receiver dropped before eof, closing");
242 self.conn.close_read();
243 }
244 }
245 }
246 } else {
247 error!("unexpected frame");
249 }
250 }
251 Poll::Ready(None) => {
252 }
254 Poll::Pending => {
255 self.body_tx = Some(body);
256 return Poll::Pending;
257 }
258 Poll::Ready(Some(Err(e))) => {
259 body.send_error(crate::Error::new_body(e));
260 }
261 }
262 } else {
263 }
265 } else {
266 return self.conn.poll_read_keep_alive(cx);
267 }
268 }
269 }
270
271 fn poll_read_head(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
272 match ready!(self.dispatch.poll_ready(cx)) {
274 Ok(()) => (),
275 Err(()) => {
276 trace!("dispatch no longer receiving messages");
277 self.close();
278 return Poll::Ready(Ok(()));
279 }
280 }
281
282 match ready!(self.conn.poll_read_head(cx)) {
284 Some(Ok((mut head, body_len, wants))) => {
285 let body = match body_len {
286 DecodedLength::ZERO => IncomingBody::empty(),
287 other => {
288 let (tx, rx) =
289 IncomingBody::new_channel(other, wants.contains(Wants::EXPECT));
290 self.body_tx = Some(tx);
291 rx
292 }
293 };
294 if wants.contains(Wants::UPGRADE) {
295 let upgrade = self.conn.on_upgrade();
296 debug_assert!(!upgrade.is_none(), "empty upgrade");
297 debug_assert!(
298 head.extensions.get::<OnUpgrade>().is_none(),
299 "OnUpgrade already set"
300 );
301 head.extensions.insert(upgrade);
302 }
303 self.dispatch.recv_msg(Ok((head, body)))?;
304 Poll::Ready(Ok(()))
305 }
306 Some(Err(err)) => {
307 debug!("read_head error: {}", err);
308 self.dispatch.recv_msg(Err(err))?;
309 self.close();
313 Poll::Ready(Ok(()))
314 }
315 None => {
316 debug_assert!(self.conn.is_read_closed());
320 if self.conn.is_write_closed() {
321 self.close();
322 }
323 Poll::Ready(Ok(()))
324 }
325 }
326 }
327
328 fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
329 loop {
330 if self.is_closing {
331 return Poll::Ready(Ok(()));
332 } else if self.body_rx.is_none()
333 && self.conn.can_write_head()
334 && self.dispatch.should_poll()
335 {
336 if let Some(msg) = ready!(Pin::new(&mut self.dispatch).poll_msg(cx)) {
337 let (head, body) = msg.map_err(crate::Error::new_user_service)?;
338
339 let body_type = if body.is_end_stream() {
340 self.body_rx.set(None);
341 None
342 } else {
343 let btype = body
344 .size_hint()
345 .exact()
346 .map(BodyLength::Known)
347 .or(Some(BodyLength::Unknown));
348 self.body_rx.set(Some(body));
349 btype
350 };
351 self.conn.write_head(head, body_type);
352 } else {
353 self.close();
354 return Poll::Ready(Ok(()));
355 }
356 } else if !self.conn.can_buffer_body() {
357 ready!(self.poll_flush(cx))?;
358 } else {
359 if let (Some(mut body), clear_body) =
361 OptGuard::new(self.body_rx.as_mut()).guard_mut()
362 {
363 debug_assert!(!*clear_body, "opt guard defaults to keeping body");
364 if !self.conn.can_write_body() {
365 trace!(
366 "no more write body allowed, user body is_end_stream = {}",
367 body.is_end_stream(),
368 );
369 *clear_body = true;
370 continue;
371 }
372
373 let item = ready!(body.as_mut().poll_frame(cx));
374 if let Some(item) = item {
375 let frame = item.map_err(|e| {
376 *clear_body = true;
377 crate::Error::new_user_body(e)
378 })?;
379
380 if frame.is_data() {
381 let chunk = frame.into_data().unwrap_or_else(|_| unreachable!());
382 let eos = body.is_end_stream();
383 if eos {
384 *clear_body = true;
385 if chunk.remaining() == 0 {
386 trace!("discarding empty chunk");
387 self.conn.end_body()?;
388 } else {
389 self.conn.write_body_and_end(chunk);
390 }
391 } else {
392 if chunk.remaining() == 0 {
393 trace!("discarding empty chunk");
394 continue;
395 }
396 self.conn.write_body(chunk);
397 }
398 } else if frame.is_trailers() {
399 *clear_body = true;
400 self.conn.write_trailers(
401 frame.into_trailers().unwrap_or_else(|_| unreachable!()),
402 );
403 } else {
404 trace!("discarding unknown frame");
405 continue;
406 }
407 } else {
408 *clear_body = true;
409 self.conn.end_body()?;
410 }
411 } else {
412 if self.conn.can_write_body() {
414 self.conn.end_body()?;
415 } else {
416 return Poll::Pending;
417 }
418 }
419 }
420 }
421 }
422
423 fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
424 self.conn.poll_flush(cx).map_err(|err| {
425 debug!("error writing: {}", err);
426 crate::Error::new_body_write(err)
427 })
428 }
429
430 fn close(&mut self) {
431 self.is_closing = true;
432 self.conn.close_read();
433 self.conn.close_write();
434 }
435
436 fn is_done(&self) -> bool {
437 if self.is_closing {
438 return true;
439 }
440
441 let read_done = self.conn.is_read_closed();
442
443 if !T::should_read_first() && read_done {
444 true
446 } else {
447 let write_done = self.conn.is_write_closed()
448 || (!self.dispatch.should_poll() && self.body_rx.is_none());
449 read_done && write_done
450 }
451 }
452}
453
454impl<D, Bs, I, T> Future for Dispatcher<D, Bs, I, T>
455where
456 D: Dispatch<
457 PollItem = MessageHead<T::Outgoing>,
458 PollBody = Bs,
459 RecvItem = MessageHead<T::Incoming>,
460 > + Unpin,
461 D::PollError: Into<Box<dyn StdError + Send + Sync>>,
462 I: Read + Write + Unpin,
463 T: Http1Transaction + Unpin,
464 Bs: Body + 'static,
465 Bs::Error: Into<Box<dyn StdError + Send + Sync>>,
466{
467 type Output = crate::Result<Dispatched>;
468
469 #[inline]
470 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
471 self.poll_catch(cx, true)
472 }
473}
474
/// Scope guard around a pinned `Option<T>`.
///
/// While alive it lends out the pinned value together with a flag; if the
/// flag has been set to `true` by the time the guard is dropped, the option
/// is reset to `None`.
struct OptGuard<'a, T> {
    /// The guarded slot.
    slot: Pin<&'a mut Option<T>>,
    /// Whether `slot` is cleared on drop. Starts out `false`.
    clear_on_drop: bool,
}

impl<'a, T> OptGuard<'a, T> {
    /// Wraps `pin` in a guard that keeps the value unless told otherwise.
    fn new(pin: Pin<&'a mut Option<T>>) -> Self {
        OptGuard {
            slot: pin,
            clear_on_drop: false,
        }
    }

    /// Returns the pinned value (if the slot holds one) and a handle to the
    /// clear-on-drop flag.
    fn guard_mut(&mut self) -> (Option<Pin<&mut T>>, &mut bool) {
        let value = self.slot.as_mut().as_pin_mut();
        let clear = &mut self.clear_on_drop;
        (value, clear)
    }
}

impl<T> Drop for OptGuard<'_, T> {
    fn drop(&mut self) {
        if !self.clear_on_drop {
            return;
        }
        self.slot.set(None);
    }
}
498
499cfg_server! {
502 impl<S, B> Server<S, B>
503 where
504 S: HttpService<B>,
505 {
506 pub(crate) fn new(service: S) -> Server<S, B> {
507 Server {
508 in_flight: Box::pin(None),
509 service,
510 }
511 }
512
513 pub(crate) fn into_service(self) -> S {
514 self.service
515 }
516 }
517
518 impl<S: HttpService<B>, B> Unpin for Server<S, B> {}
520
521 impl<S, Bs> Dispatch for Server<S, IncomingBody>
522 where
523 S: HttpService<IncomingBody, ResBody = Bs>,
524 S::Error: Into<Box<dyn StdError + Send + Sync>>,
525 Bs: Body,
526 {
527 type PollItem = MessageHead<http::StatusCode>;
528 type PollBody = Bs;
529 type PollError = S::Error;
530 type RecvItem = RequestHead;
531
532 fn poll_msg(
533 mut self: Pin<&mut Self>,
534 cx: &mut Context<'_>,
535 ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>> {
536 let mut this = self.as_mut();
537 let ret = if let Some(ref mut fut) = this.in_flight.as_mut().as_pin_mut() {
538 let resp = ready!(fut.as_mut().poll(cx)?);
539 let (parts, body) = resp.into_parts();
540 let head = MessageHead {
541 version: parts.version,
542 subject: parts.status,
543 headers: parts.headers,
544 extensions: parts.extensions,
545 };
546 Poll::Ready(Some(Ok((head, body))))
547 } else {
548 unreachable!("poll_msg shouldn't be called if no inflight");
549 };
550
551 this.in_flight.set(None);
553 ret
554 }
555
556 fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, IncomingBody)>) -> crate::Result<()> {
557 let (msg, body) = msg?;
558 let mut req = Request::new(body);
559 *req.method_mut() = msg.subject.0;
560 *req.uri_mut() = msg.subject.1;
561 *req.headers_mut() = msg.headers;
562 *req.version_mut() = msg.version;
563 *req.extensions_mut() = msg.extensions;
564 let fut = self.service.call(req);
565 self.in_flight.set(Some(fut));
566 Ok(())
567 }
568
569 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
570 if self.in_flight.is_some() {
571 Poll::Pending
572 } else {
573 Poll::Ready(Ok(()))
574 }
575 }
576
577 fn should_poll(&self) -> bool {
578 self.in_flight.is_some()
579 }
580 }
581}
582
583cfg_client! {
586 use std::convert::Infallible;
587
588 impl<B> Client<B> {
589 pub(crate) fn new(rx: ClientRx<B>) -> Client<B> {
590 Client {
591 callback: None,
592 rx,
593 rx_closed: false,
594 }
595 }
596 }
597
598 impl<B> Dispatch for Client<B>
599 where
600 B: Body,
601 {
602 type PollItem = RequestHead;
603 type PollBody = B;
604 type PollError = Infallible;
605 type RecvItem = crate::proto::ResponseHead;
606
607 fn poll_msg(
608 mut self: Pin<&mut Self>,
609 cx: &mut Context<'_>,
610 ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Infallible>>> {
611 let mut this = self.as_mut();
612 debug_assert!(!this.rx_closed);
613 match this.rx.poll_recv(cx) {
614 Poll::Ready(Some((req, mut cb))) => {
615 match cb.poll_canceled(cx) {
617 Poll::Ready(()) => {
618 trace!("request canceled");
619 Poll::Ready(None)
620 }
621 Poll::Pending => {
622 let (parts, body) = req.into_parts();
623 let head = RequestHead {
624 version: parts.version,
625 subject: crate::proto::RequestLine(parts.method, parts.uri),
626 headers: parts.headers,
627 extensions: parts.extensions,
628 };
629 this.callback = Some(cb);
630 Poll::Ready(Some(Ok((head, body))))
631 }
632 }
633 }
634 Poll::Ready(None) => {
635 trace!("client tx closed");
637 this.rx_closed = true;
638 Poll::Ready(None)
639 }
640 Poll::Pending => Poll::Pending,
641 }
642 }
643
644 fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, IncomingBody)>) -> crate::Result<()> {
645 match msg {
646 Ok((msg, body)) => {
647 if let Some(cb) = self.callback.take() {
648 let res = msg.into_response(body);
649 cb.send(Ok(res));
650 Ok(())
651 } else {
652 Err(crate::Error::new_unexpected_message())
656 }
657 }
658 Err(err) => {
659 if let Some(cb) = self.callback.take() {
660 cb.send(Err(TrySendError {
661 error: err,
662 message: None,
663 }));
664 Ok(())
665 } else if !self.rx_closed {
666 self.rx.close();
667 if let Some((req, cb)) = self.rx.try_recv() {
668 trace!("canceling queued request with connection error: {}", err);
669 cb.send(Err(TrySendError {
672 error: crate::Error::new_canceled().with(err),
673 message: Some(req),
674 }));
675 Ok(())
676 } else {
677 Err(err)
678 }
679 } else {
680 Err(err)
681 }
682 }
683 }
684 }
685
686 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
687 match self.callback {
688 Some(ref mut cb) => match cb.poll_canceled(cx) {
689 Poll::Ready(()) => {
690 trace!("callback receiver has dropped");
691 Poll::Ready(Err(()))
692 }
693 Poll::Pending => Poll::Ready(Ok(())),
694 },
695 None => Poll::Ready(Err(())),
696 }
697 }
698
699 fn should_poll(&self) -> bool {
700 self.callback.is_none()
701 }
702 }
703}
704
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::io::Compat;
    use crate::proto::h1::ClientTransaction;
    use std::time::Duration;

    /// Bytes that arrive before any request was written must end the
    /// dispatcher, and a request queued in the meantime must come back as a
    /// "canceled" error that still carries the request.
    #[test]
    fn client_read_bytes_before_writing_request() {
        let _ = pretty_env_logger::try_init();

        tokio_test::task::spawn(()).enter(|cx, _| {
            let (mock_io, mut io_handle) = tokio_test::io::Builder::new().build_with_handle();

            let (mut tx, rx) = crate::client::dispatch::channel();
            let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(Compat::new(mock_io));
            let mut dispatcher = Dispatcher::new(Client::new(rx), conn);

            // Nothing to read and nothing to write yet.
            let first_poll = Pin::new(&mut dispatcher).poll(cx);
            assert!(first_poll.is_pending());

            // A response shows up although no request has been sent ...
            io_handle.read(b"HTTP/1.1 200 OK\r\n\r\n");

            // ... and a request gets queued before the dispatcher runs again.
            let request = crate::Request::new(IncomingBody::empty());
            let mut res_rx = tx.try_send(request).unwrap();

            tokio_test::assert_ready_ok!(Pin::new(&mut dispatcher).poll(cx));
            let err = tokio_test::assert_ready_ok!(Pin::new(&mut res_rx).poll(cx))
                .expect_err("callback should send error");

            let canceled_with_request = err.error.is_canceled() && err.message.is_some();
            assert!(canceled_with_request, "expected Canceled, got {:?}", err);
        });
    }

    /// After a response was received while the request body is still being
    /// flushed, the sender must not report itself ready for another request.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn client_flushing_is_not_ready_for_next_request() {
        let _ = pretty_env_logger::try_init();

        let (mock_io, _io_handle) = tokio_test::io::Builder::new()
            .write(b"POST / HTTP/1.1\r\ncontent-length: 4\r\n\r\n")
            .read(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n")
            .wait(Duration::from_secs(2))
            .build_with_handle();

        let (mut tx, rx) = crate::client::dispatch::channel();
        let mut conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(Compat::new(mock_io));
        conn.set_write_strategy_queue();

        let dispatcher = Dispatcher::new(Client::new(rx), conn);
        let _dispatcher = tokio::spawn(async move { dispatcher.await });

        // A four byte body whose data is already queued in the channel.
        let body = {
            let (mut body_tx, body) = IncomingBody::new_channel(DecodedLength::new(4), false);
            body_tx.try_send_data("reee".into()).unwrap();
            body
        };

        let req = crate::Request::builder().method("POST").body(body).unwrap();

        let res = tx.try_send(req).unwrap().await.expect("response");
        drop(res);

        assert!(!tx.is_ready());
    }

    /// A request body that only produces an empty chunk must leave the
    /// dispatcher pending.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn body_empty_chunks_ignored() {
        let _ = pretty_env_logger::try_init();

        let mock_io = tokio_test::io::Builder::new()
            .wait(Duration::from_secs(5))
            .build();

        let (mut tx, rx) = crate::client::dispatch::channel();
        let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(Compat::new(mock_io));
        let mut dispatcher = tokio_test::task::spawn(Dispatcher::new(Client::new(rx), conn));

        // Idle connection: nothing to do yet.
        assert!(dispatcher.poll().is_pending());

        let body = {
            let (mut body_tx, body) = IncomingBody::channel();
            body_tx.try_send_data("".into()).unwrap();
            body
        };

        let _res_rx = tx.try_send(crate::Request::new(body)).unwrap();

        assert!(dispatcher.poll().is_pending());
    }
}