tokio_rustls/
client.rs

1use std::io;
2#[cfg(unix)]
3use std::os::unix::io::{AsRawFd, RawFd};
4#[cfg(windows)]
5use std::os::windows::io::{AsRawSocket, RawSocket};
6use std::pin::Pin;
7#[cfg(feature = "early-data")]
8use std::task::Waker;
9use std::task::{Context, Poll};
10
11use rustls::ClientConnection;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13
14use crate::common::{IoSession, Stream, TlsState};
15
16/// A wrapper around an underlying raw stream which implements the TLS or SSL
17/// protocol.
18#[derive(Debug)]
19pub struct TlsStream<IO> {
20    pub(crate) io: IO,
21    pub(crate) session: ClientConnection,
22    pub(crate) state: TlsState,
23
24    #[cfg(feature = "early-data")]
25    pub(crate) early_waker: Option<Waker>,
26}
27
28impl<IO> TlsStream<IO> {
29    #[inline]
30    pub fn get_ref(&self) -> (&IO, &ClientConnection) {
31        (&self.io, &self.session)
32    }
33
34    #[inline]
35    pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
36        (&mut self.io, &mut self.session)
37    }
38
39    #[inline]
40    pub fn into_inner(self) -> (IO, ClientConnection) {
41        (self.io, self.session)
42    }
43}
44
45#[cfg(unix)]
46impl<S> AsRawFd for TlsStream<S>
47where
48    S: AsRawFd,
49{
50    fn as_raw_fd(&self) -> RawFd {
51        self.get_ref().0.as_raw_fd()
52    }
53}
54
#[cfg(windows)]
impl<S: AsRawSocket> AsRawSocket for TlsStream<S> {
    /// Returns the raw socket of the wrapped I/O object.
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _session) = self.get_ref();
        io.as_raw_socket()
    }
}
64
65impl<IO> IoSession for TlsStream<IO> {
66    type Io = IO;
67    type Session = ClientConnection;
68
69    #[inline]
70    fn skip_handshake(&self) -> bool {
71        self.state.is_early_data()
72    }
73
74    #[inline]
75    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
76        (&mut self.state, &mut self.io, &mut self.session)
77    }
78
79    #[inline]
80    fn into_io(self) -> Self::Io {
81        self.io
82    }
83}
84
85impl<IO> AsyncRead for TlsStream<IO>
86where
87    IO: AsyncRead + AsyncWrite + Unpin,
88{
89    fn poll_read(
90        self: Pin<&mut Self>,
91        cx: &mut Context<'_>,
92        buf: &mut ReadBuf<'_>,
93    ) -> Poll<io::Result<()>> {
94        match self.state {
95            #[cfg(feature = "early-data")]
96            TlsState::EarlyData(..) => {
97                let this = self.get_mut();
98
99                // In the EarlyData state, we have not really established a Tls connection.
100                // Before writing data through `AsyncWrite` and completing the tls handshake,
101                // we ignore read readiness and return to pending.
102                //
103                // In order to avoid event loss,
104                // we need to register a waker and wake it up after tls is connected.
105                if this
106                    .early_waker
107                    .as_ref()
108                    .filter(|waker| cx.waker().will_wake(waker))
109                    .is_none()
110                {
111                    this.early_waker = Some(cx.waker().clone());
112                }
113
114                Poll::Pending
115            }
116            TlsState::Stream | TlsState::WriteShutdown => {
117                let this = self.get_mut();
118                let mut stream =
119                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
120                let prev = buf.remaining();
121
122                match stream.as_mut_pin().poll_read(cx, buf) {
123                    Poll::Ready(Ok(())) => {
124                        if prev == buf.remaining() || stream.eof {
125                            this.state.shutdown_read();
126                        }
127
128                        Poll::Ready(Ok(()))
129                    }
130                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
131                        this.state.shutdown_read();
132                        Poll::Ready(Err(err))
133                    }
134                    output => output,
135                }
136            }
137            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
138        }
139    }
140}
141
142impl<IO> AsyncWrite for TlsStream<IO>
143where
144    IO: AsyncRead + AsyncWrite + Unpin,
145{
146    /// Note: that it does not guarantee the final data to be sent.
147    /// To be cautious, you must manually call `flush`.
148    fn poll_write(
149        self: Pin<&mut Self>,
150        cx: &mut Context<'_>,
151        buf: &[u8],
152    ) -> Poll<io::Result<usize>> {
153        let this = self.get_mut();
154        let mut stream =
155            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
156
157        #[cfg(feature = "early-data")]
158        {
159            let bufs = [io::IoSlice::new(buf)];
160            let written = ready!(poll_handle_early_data(
161                &mut this.state,
162                &mut stream,
163                &mut this.early_waker,
164                cx,
165                &bufs
166            ))?;
167            if written != 0 {
168                return Poll::Ready(Ok(written));
169            }
170        }
171
172        stream.as_mut_pin().poll_write(cx, buf)
173    }
174
175    /// Note: that it does not guarantee the final data to be sent.
176    /// To be cautious, you must manually call `flush`.
177    fn poll_write_vectored(
178        self: Pin<&mut Self>,
179        cx: &mut Context<'_>,
180        bufs: &[io::IoSlice<'_>],
181    ) -> Poll<io::Result<usize>> {
182        let this = self.get_mut();
183        let mut stream =
184            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
185
186        #[cfg(feature = "early-data")]
187        {
188            let written = ready!(poll_handle_early_data(
189                &mut this.state,
190                &mut stream,
191                &mut this.early_waker,
192                cx,
193                bufs
194            ))?;
195            if written != 0 {
196                return Poll::Ready(Ok(written));
197            }
198        }
199
200        stream.as_mut_pin().poll_write_vectored(cx, bufs)
201    }
202
203    #[inline]
204    fn is_write_vectored(&self) -> bool {
205        true
206    }
207
208    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
209        let this = self.get_mut();
210        let mut stream =
211            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
212
213        #[cfg(feature = "early-data")]
214        ready!(poll_handle_early_data(
215            &mut this.state,
216            &mut stream,
217            &mut this.early_waker,
218            cx,
219            &[]
220        ))?;
221
222        stream.as_mut_pin().poll_flush(cx)
223    }
224
225    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
226        #[cfg(feature = "early-data")]
227        {
228            // complete handshake
229            if matches!(self.state, TlsState::EarlyData(..)) {
230                ready!(self.as_mut().poll_flush(cx))?;
231            }
232        }
233
234        if self.state.writeable() {
235            self.session.send_close_notify();
236            self.state.shutdown_write();
237        }
238
239        let this = self.get_mut();
240        let mut stream =
241            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
242        stream.as_mut_pin().poll_shutdown(cx)
243    }
244}
245
/// Drives the early-data phase of a client stream.
///
/// When `state` is not `TlsState::EarlyData`, this does nothing and returns
/// `Ready(Ok(0))`. Otherwise it:
///
/// 1. tries to write `bufs` as early data, recording every accepted byte in
///    the state's buffer; if anything was accepted, that count is returned;
/// 2. drives the handshake until the session stops handshaking;
/// 3. if the session reports that early data was not accepted, re-sends the
///    recorded bytes as regular writes, resuming from the stored position;
/// 4. switches `state` to `TlsState::Stream` and wakes any waker stored in
///    `early_waker`.
///
/// Returns the number of bytes of `bufs` consumed as early data; `0` means
/// the caller still has to write its payload.
#[cfg(feature = "early-data")]
fn poll_handle_early_data<IO>(
    state: &mut TlsState,
    stream: &mut Stream<IO, ClientConnection>,
    early_waker: &mut Option<Waker>,
    cx: &mut Context<'_>,
    bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    use std::io::Write;

    // Nothing to do unless we are still in the early-data state.
    let (resend_pos, recorded) = match state {
        TlsState::EarlyData(pos, data) => (pos, data),
        _ => return Poll::Ready(Ok(0)),
    };

    // Write early data.
    if let Some(mut early_writer) = stream.session.early_data() {
        let mut total = 0;

        for chunk in bufs.iter().filter(|chunk| !chunk.is_empty()) {
            let accepted = early_writer.write(chunk)?;
            if accepted == 0 {
                break;
            }

            total += accepted;
            // Keep a copy so the bytes can be re-sent if early data is rejected.
            recorded.extend_from_slice(&chunk[..accepted]);

            // A partial write ends this round.
            if accepted < chunk.len() {
                break;
            }
        }

        if total != 0 {
            return Poll::Ready(Ok(total));
        }
    }

    // Complete the handshake.
    while stream.session.is_handshaking() {
        ready!(stream.handshake(cx))?;
    }

    // Fallback: early data was not accepted, so send the recorded bytes as
    // ordinary writes.
    if !stream.session.is_early_data_accepted() {
        while *resend_pos < recorded.len() {
            let sent = ready!(stream
                .as_mut_pin()
                .poll_write(cx, &recorded[*resend_pos..]))?;
            *resend_pos += sent;
        }
    }

    // Early-data phase is over.
    *state = TlsState::Stream;

    if let Some(parked) = early_waker.take() {
        parked.wake();
    }

    Poll::Ready(Ok(0))
}