tokio_rustls/
server.rs

1use std::io;
2#[cfg(unix)]
3use std::os::unix::io::{AsRawFd, RawFd};
4#[cfg(windows)]
5use std::os::windows::io::{AsRawSocket, RawSocket};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use rustls::ServerConnection;
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11
12use crate::common::{IoSession, Stream, TlsState};
13
14/// A wrapper around an underlying raw stream which implements the TLS or SSL
15/// protocol.
16#[derive(Debug)]
17pub struct TlsStream<IO> {
18    pub(crate) io: IO,
19    pub(crate) session: ServerConnection,
20    pub(crate) state: TlsState,
21}
22
23impl<IO> TlsStream<IO> {
24    #[inline]
25    pub fn get_ref(&self) -> (&IO, &ServerConnection) {
26        (&self.io, &self.session)
27    }
28
29    #[inline]
30    pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
31        (&mut self.io, &mut self.session)
32    }
33
34    #[inline]
35    pub fn into_inner(self) -> (IO, ServerConnection) {
36        (self.io, self.session)
37    }
38}
39
40impl<IO> IoSession for TlsStream<IO> {
41    type Io = IO;
42    type Session = ServerConnection;
43
44    #[inline]
45    fn skip_handshake(&self) -> bool {
46        false
47    }
48
49    #[inline]
50    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
51        (&mut self.state, &mut self.io, &mut self.session)
52    }
53
54    #[inline]
55    fn into_io(self) -> Self::Io {
56        self.io
57    }
58}
59
60impl<IO> AsyncRead for TlsStream<IO>
61where
62    IO: AsyncRead + AsyncWrite + Unpin,
63{
64    fn poll_read(
65        self: Pin<&mut Self>,
66        cx: &mut Context<'_>,
67        buf: &mut ReadBuf<'_>,
68    ) -> Poll<io::Result<()>> {
69        let this = self.get_mut();
70        let mut stream =
71            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
72
73        match &this.state {
74            TlsState::Stream | TlsState::WriteShutdown => {
75                let prev = buf.remaining();
76
77                match stream.as_mut_pin().poll_read(cx, buf) {
78                    Poll::Ready(Ok(())) => {
79                        if prev == buf.remaining() || stream.eof {
80                            this.state.shutdown_read();
81                        }
82
83                        Poll::Ready(Ok(()))
84                    }
85                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
86                        this.state.shutdown_read();
87                        Poll::Ready(Err(err))
88                    }
89                    output => output,
90                }
91            }
92            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
93            #[cfg(feature = "early-data")]
94            s => unreachable!("server TLS can not hit this state: {:?}", s),
95        }
96    }
97}
98
99impl<IO> AsyncWrite for TlsStream<IO>
100where
101    IO: AsyncRead + AsyncWrite + Unpin,
102{
103    /// Note: that it does not guarantee the final data to be sent.
104    /// To be cautious, you must manually call `flush`.
105    fn poll_write(
106        self: Pin<&mut Self>,
107        cx: &mut Context<'_>,
108        buf: &[u8],
109    ) -> Poll<io::Result<usize>> {
110        let this = self.get_mut();
111        let mut stream =
112            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
113        stream.as_mut_pin().poll_write(cx, buf)
114    }
115
116    /// Note: that it does not guarantee the final data to be sent.
117    /// To be cautious, you must manually call `flush`.
118    fn poll_write_vectored(
119        self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121        bufs: &[io::IoSlice<'_>],
122    ) -> Poll<io::Result<usize>> {
123        let this = self.get_mut();
124        let mut stream =
125            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
126        stream.as_mut_pin().poll_write_vectored(cx, bufs)
127    }
128
129    #[inline]
130    fn is_write_vectored(&self) -> bool {
131        true
132    }
133
134    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
135        let this = self.get_mut();
136        let mut stream =
137            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
138        stream.as_mut_pin().poll_flush(cx)
139    }
140
141    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
142        if self.state.writeable() {
143            self.session.send_close_notify();
144            self.state.shutdown_write();
145        }
146
147        let this = self.get_mut();
148        let mut stream =
149            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
150        stream.as_mut_pin().poll_shutdown(cx)
151    }
152}
153
154#[cfg(unix)]
155impl<IO> AsRawFd for TlsStream<IO>
156where
157    IO: AsRawFd,
158{
159    fn as_raw_fd(&self) -> RawFd {
160        self.get_ref().0.as_raw_fd()
161    }
162}
163
#[cfg(windows)]
impl<IO: AsRawSocket> AsRawSocket for TlsStream<IO> {
    /// Returns the raw socket of the wrapped IO object.
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _session) = self.get_ref();
        io.as_raw_socket()
    }
}