// hyper_rustls/stream.rs

// Copied from hyperium/hyper-tls#62e3376/src/stream.rs
2use std::fmt;
3use std::io;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use hyper::rt;
8use hyper_util::client::legacy::connect::{Connected, Connection};
9
10use hyper_util::rt::TokioIo;
11use tokio_rustls::client::TlsStream;
12
13/// A stream that might be protected with TLS.
14#[allow(clippy::large_enum_variant)]
15pub enum MaybeHttpsStream<T> {
16    /// A stream over plain text.
17    Http(T),
18    /// A stream protected with TLS.
19    Https(TokioIo<TlsStream<TokioIo<T>>>),
20}
21
22impl<T: rt::Read + rt::Write + Connection + Unpin> Connection for MaybeHttpsStream<T> {
23    fn connected(&self) -> Connected {
24        match self {
25            Self::Http(s) => s.connected(),
26            Self::Https(s) => {
27                let (tcp, tls) = s.inner().get_ref();
28                if tls.alpn_protocol() == Some(b"h2") {
29                    tcp.inner().connected().negotiated_h2()
30                } else {
31                    tcp.inner().connected()
32                }
33            }
34        }
35    }
36}
37
38impl<T: fmt::Debug> fmt::Debug for MaybeHttpsStream<T> {
39    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
40        match *self {
41            Self::Http(..) => f.pad("Http(..)"),
42            Self::Https(..) => f.pad("Https(..)"),
43        }
44    }
45}
46
47impl<T> From<T> for MaybeHttpsStream<T> {
48    fn from(inner: T) -> Self {
49        Self::Http(inner)
50    }
51}
52
53impl<T> From<TlsStream<TokioIo<T>>> for MaybeHttpsStream<T> {
54    fn from(inner: TlsStream<TokioIo<T>>) -> Self {
55        Self::Https(TokioIo::new(inner))
56    }
57}
58
59impl<T: rt::Read + rt::Write + Unpin> rt::Read for MaybeHttpsStream<T> {
60    #[inline]
61    fn poll_read(
62        self: Pin<&mut Self>,
63        cx: &mut Context,
64        buf: rt::ReadBufCursor<'_>,
65    ) -> Poll<Result<(), io::Error>> {
66        match Pin::get_mut(self) {
67            Self::Http(s) => Pin::new(s).poll_read(cx, buf),
68            Self::Https(s) => Pin::new(s).poll_read(cx, buf),
69        }
70    }
71}
72
73impl<T: rt::Write + rt::Read + Unpin> rt::Write for MaybeHttpsStream<T> {
74    #[inline]
75    fn poll_write(
76        self: Pin<&mut Self>,
77        cx: &mut Context<'_>,
78        buf: &[u8],
79    ) -> Poll<Result<usize, io::Error>> {
80        match Pin::get_mut(self) {
81            Self::Http(s) => Pin::new(s).poll_write(cx, buf),
82            Self::Https(s) => Pin::new(s).poll_write(cx, buf),
83        }
84    }
85
86    #[inline]
87    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
88        match Pin::get_mut(self) {
89            Self::Http(s) => Pin::new(s).poll_flush(cx),
90            Self::Https(s) => Pin::new(s).poll_flush(cx),
91        }
92    }
93
94    #[inline]
95    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
96        match Pin::get_mut(self) {
97            Self::Http(s) => Pin::new(s).poll_shutdown(cx),
98            Self::Https(s) => Pin::new(s).poll_shutdown(cx),
99        }
100    }
101
102    #[inline]
103    fn is_write_vectored(&self) -> bool {
104        match self {
105            Self::Http(s) => s.is_write_vectored(),
106            Self::Https(s) => s.is_write_vectored(),
107        }
108    }
109
110    #[inline]
111    fn poll_write_vectored(
112        self: Pin<&mut Self>,
113        cx: &mut Context<'_>,
114        bufs: &[io::IoSlice<'_>],
115    ) -> Poll<Result<usize, io::Error>> {
116        match Pin::get_mut(self) {
117            Self::Http(s) => Pin::new(s).poll_write_vectored(cx, bufs),
118            Self::Https(s) => Pin::new(s).poll_write_vectored(cx, bufs),
119        }
120    }
121}