1use std::io;
2#[cfg(unix)]
3use std::os::unix::io::{AsRawFd, RawFd};
4#[cfg(windows)]
5use std::os::windows::io::{AsRawSocket, RawSocket};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use rustls::ServerConnection;
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11
12use crate::common::{IoSession, Stream, TlsState};
13
/// Server-side TLS stream: an underlying transport `IO` paired with a rustls
/// [`ServerConnection`] and the read/write shutdown state of the stream.
#[derive(Debug)]
pub struct TlsStream<IO> {
    /// The underlying transport that ciphertext is read from and written to.
    pub(crate) io: IO,
    /// The rustls server session that encrypts/decrypts the traffic.
    pub(crate) session: ServerConnection,
    /// Tracks which halves (read/write) of the stream have been shut down.
    pub(crate) state: TlsState,
}
22
23impl<IO> TlsStream<IO> {
24 #[inline]
25 pub fn get_ref(&self) -> (&IO, &ServerConnection) {
26 (&self.io, &self.session)
27 }
28
29 #[inline]
30 pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
31 (&mut self.io, &mut self.session)
32 }
33
34 #[inline]
35 pub fn into_inner(self) -> (IO, ServerConnection) {
36 (self.io, self.session)
37 }
38}
39
40impl<IO> IoSession for TlsStream<IO> {
41 type Io = IO;
42 type Session = ServerConnection;
43
44 #[inline]
45 fn skip_handshake(&self) -> bool {
46 false
47 }
48
49 #[inline]
50 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
51 (&mut self.state, &mut self.io, &mut self.session)
52 }
53
54 #[inline]
55 fn into_io(self) -> Self::Io {
56 self.io
57 }
58}
59
60impl<IO> AsyncRead for TlsStream<IO>
61where
62 IO: AsyncRead + AsyncWrite + Unpin,
63{
64 fn poll_read(
65 self: Pin<&mut Self>,
66 cx: &mut Context<'_>,
67 buf: &mut ReadBuf<'_>,
68 ) -> Poll<io::Result<()>> {
69 let this = self.get_mut();
70 let mut stream =
71 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
72
73 match &this.state {
74 TlsState::Stream | TlsState::WriteShutdown => {
75 let prev = buf.remaining();
76
77 match stream.as_mut_pin().poll_read(cx, buf) {
78 Poll::Ready(Ok(())) => {
79 if prev == buf.remaining() || stream.eof {
80 this.state.shutdown_read();
81 }
82
83 Poll::Ready(Ok(()))
84 }
85 Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
86 this.state.shutdown_read();
87 Poll::Ready(Err(err))
88 }
89 output => output,
90 }
91 }
92 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
93 #[cfg(feature = "early-data")]
94 s => unreachable!("server TLS can not hit this state: {:?}", s),
95 }
96 }
97}
98
99impl<IO> AsyncWrite for TlsStream<IO>
100where
101 IO: AsyncRead + AsyncWrite + Unpin,
102{
103 fn poll_write(
106 self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 buf: &[u8],
109 ) -> Poll<io::Result<usize>> {
110 let this = self.get_mut();
111 let mut stream =
112 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
113 stream.as_mut_pin().poll_write(cx, buf)
114 }
115
116 fn poll_write_vectored(
119 self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 bufs: &[io::IoSlice<'_>],
122 ) -> Poll<io::Result<usize>> {
123 let this = self.get_mut();
124 let mut stream =
125 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
126 stream.as_mut_pin().poll_write_vectored(cx, bufs)
127 }
128
129 #[inline]
130 fn is_write_vectored(&self) -> bool {
131 true
132 }
133
134 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
135 let this = self.get_mut();
136 let mut stream =
137 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
138 stream.as_mut_pin().poll_flush(cx)
139 }
140
141 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
142 if self.state.writeable() {
143 self.session.send_close_notify();
144 self.state.shutdown_write();
145 }
146
147 let this = self.get_mut();
148 let mut stream =
149 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
150 stream.as_mut_pin().poll_shutdown(cx)
151 }
152}
153
154#[cfg(unix)]
155impl<IO> AsRawFd for TlsStream<IO>
156where
157 IO: AsRawFd,
158{
159 fn as_raw_fd(&self) -> RawFd {
160 self.get_ref().0.as_raw_fd()
161 }
162}
163
/// Exposes the raw socket handle of the underlying transport.
#[cfg(windows)]
impl<IO> AsRawSocket for TlsStream<IO>
where
    IO: AsRawSocket,
{
    fn as_raw_socket(&self) -> RawSocket {
        self.io.as_raw_socket()
    }
}