tokio_rustls/
lib.rs

1//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/rustls/rustls).
2//!
3//! # Why do I need to call `poll_flush`?
4//!
5//! Most TLS implementations will have an internal buffer to improve throughput,
6//! and rustls is no exception.
7//!
//! When we write data to `TlsStream`, we always write it to the rustls buffer first,
//! then take the encrypted data packets out of rustls and write them to the data channel (like `TcpStream`).
//! When the data channel is pending, some data may remain in the rustls buffer.
11//!
//! To keep things simple and correct, `tokio-rustls`'s [TlsStream] behaves like `BufWriter`.
//! For `TlsStream<TcpStream>`, this means that data written by `poll_write` is not guaranteed to be written to `TcpStream`.
//! You must call `poll_flush` to ensure that it is written to `TcpStream`.
15//!
16//! You should call `poll_flush` at the appropriate time,
17//! such as when a period of `poll_write` write is complete and there is no more data to write.
18//!
19//! ## Why don't we write during `poll_read`?
20//!
//! We did this in the early days of `tokio-rustls`, but it caused some bugs.
//! Those bugs can be worked around, but the workarounds degrade performance (reverse false wakeups).
//!
//! Writing during reads would also prevent us from implementing full duplex in the future.
25//!
26//! see <https://github.com/tokio-rs/tls/issues/40>
27//!
28//! ## Why can't we handle it like `native-tls`?
29//!
//! When the data channel returns pending, `native-tls` will falsely report the number of bytes it consumes.
//! This means that if data written by `poll_write` is not actually written to the data channel, it will not return `Ready`,
//! thus avoiding the need to call `poll_flush`.
//!
//! However, this does not conform to the convention of the `AsyncWrite` trait:
//! if you give inconsistent data in two `poll_write` calls, it may cause unexpected behavior.
36//!
37//! see <https://github.com/tokio-rs/tls/issues/41>
38
39use std::future::Future;
40use std::io;
41#[cfg(unix)]
42use std::os::unix::io::{AsRawFd, RawFd};
43#[cfg(windows)]
44use std::os::windows::io::{AsRawSocket, RawSocket};
45use std::pin::Pin;
46use std::sync::Arc;
47use std::task::{Context, Poll};
48
49pub use rustls;
50
51use rustls::pki_types::ServerName;
52use rustls::server::AcceptedAlert;
53use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
54use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
55
/// Evaluates a `Poll` expression: yields the inner value when it is `Ready`,
/// otherwise returns `Poll::Pending` from the enclosing function.
macro_rules! ready {
    ( $poll:expr ) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
64
65pub mod client;
66mod common;
67use common::{MidHandshake, TlsState};
68pub mod server;
69
70/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
71#[derive(Clone)]
72pub struct TlsConnector {
73    inner: Arc<ClientConfig>,
74    #[cfg(feature = "early-data")]
75    early_data: bool,
76}
77
78/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
79#[derive(Clone)]
80pub struct TlsAcceptor {
81    inner: Arc<ServerConfig>,
82}
83
84impl From<Arc<ClientConfig>> for TlsConnector {
85    fn from(inner: Arc<ClientConfig>) -> TlsConnector {
86        TlsConnector {
87            inner,
88            #[cfg(feature = "early-data")]
89            early_data: false,
90        }
91    }
92}
93
94impl From<Arc<ServerConfig>> for TlsAcceptor {
95    fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
96        TlsAcceptor { inner }
97    }
98}
99
100impl TlsConnector {
101    /// Enable 0-RTT.
102    ///
103    /// If you want to use 0-RTT,
104    /// You must also set `ClientConfig.enable_early_data` to `true`.
105    #[cfg(feature = "early-data")]
106    pub fn early_data(mut self, flag: bool) -> TlsConnector {
107        self.early_data = flag;
108        self
109    }
110
111    #[inline]
112    pub fn connect<IO>(&self, domain: ServerName<'static>, stream: IO) -> Connect<IO>
113    where
114        IO: AsyncRead + AsyncWrite + Unpin,
115    {
116        self.connect_with(domain, stream, |_| ())
117    }
118
119    pub fn connect_with<IO, F>(&self, domain: ServerName<'static>, stream: IO, f: F) -> Connect<IO>
120    where
121        IO: AsyncRead + AsyncWrite + Unpin,
122        F: FnOnce(&mut ClientConnection),
123    {
124        let mut session = match ClientConnection::new(self.inner.clone(), domain) {
125            Ok(session) => session,
126            Err(error) => {
127                return Connect(MidHandshake::Error {
128                    io: stream,
129                    // TODO(eliza): should this really return an `io::Error`?
130                    // Probably not...
131                    error: io::Error::new(io::ErrorKind::Other, error),
132                });
133            }
134        };
135        f(&mut session);
136
137        Connect(MidHandshake::Handshaking(client::TlsStream {
138            io: stream,
139
140            #[cfg(not(feature = "early-data"))]
141            state: TlsState::Stream,
142
143            #[cfg(feature = "early-data")]
144            state: if self.early_data && session.early_data().is_some() {
145                TlsState::EarlyData(0, Vec::new())
146            } else {
147                TlsState::Stream
148            },
149
150            #[cfg(feature = "early-data")]
151            early_waker: None,
152
153            session,
154        }))
155    }
156}
157
158impl TlsAcceptor {
159    #[inline]
160    pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
161    where
162        IO: AsyncRead + AsyncWrite + Unpin,
163    {
164        self.accept_with(stream, |_| ())
165    }
166
167    pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
168    where
169        IO: AsyncRead + AsyncWrite + Unpin,
170        F: FnOnce(&mut ServerConnection),
171    {
172        let mut session = match ServerConnection::new(self.inner.clone()) {
173            Ok(session) => session,
174            Err(error) => {
175                return Accept(MidHandshake::Error {
176                    io: stream,
177                    // TODO(eliza): should this really return an `io::Error`?
178                    // Probably not...
179                    error: io::Error::new(io::ErrorKind::Other, error),
180                });
181            }
182        };
183        f(&mut session);
184
185        Accept(MidHandshake::Handshaking(server::TlsStream {
186            session,
187            io: stream,
188            state: TlsState::Stream,
189        }))
190    }
191}
192
193pub struct LazyConfigAcceptor<IO> {
194    acceptor: rustls::server::Acceptor,
195    io: Option<IO>,
196    alert: Option<(rustls::Error, AcceptedAlert)>,
197}
198
199impl<IO> LazyConfigAcceptor<IO>
200where
201    IO: AsyncRead + AsyncWrite + Unpin,
202{
203    #[inline]
204    pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
205        Self {
206            acceptor,
207            io: Some(io),
208            alert: None,
209        }
210    }
211
212    /// Takes back the client connection. Will return `None` if called more than once or if the
213    /// connection has been accepted.
214    ///
215    /// # Example
216    ///
217    /// ```no_run
218    /// # fn choose_server_config(
219    /// #     _: rustls::server::ClientHello,
220    /// # ) -> std::sync::Arc<rustls::ServerConfig> {
221    /// #     unimplemented!();
222    /// # }
223    /// # #[allow(unused_variables)]
224    /// # async fn listen() {
225    /// use tokio::io::AsyncWriteExt;
226    /// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap();
227    /// let (stream, _) = listener.accept().await.unwrap();
228    ///
229    /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
230    /// tokio::pin!(acceptor);
231    ///
232    /// match acceptor.as_mut().await {
233    ///     Ok(start) => {
234    ///         let clientHello = start.client_hello();
235    ///         let config = choose_server_config(clientHello);
236    ///         let stream = start.into_stream(config).await.unwrap();
237    ///         // Proceed with handling the ServerConnection...
238    ///     }
239    ///     Err(err) => {
240    ///         if let Some(mut stream) = acceptor.take_io() {
241    ///             stream
242    ///                 .write_all(
243    ///                     format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err)
244    ///                         .as_bytes()
245    ///                 )
246    ///                 .await
247    ///                 .unwrap();
248    ///         }
249    ///     }
250    /// }
251    /// # }
252    /// ```
253    pub fn take_io(&mut self) -> Option<IO> {
254        self.io.take()
255    }
256}
257
258impl<IO> Future for LazyConfigAcceptor<IO>
259where
260    IO: AsyncRead + AsyncWrite + Unpin,
261{
262    type Output = Result<StartHandshake<IO>, io::Error>;
263
264    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265        let this = self.get_mut();
266        loop {
267            let io = match this.io.as_mut() {
268                Some(io) => io,
269                None => {
270                    return Poll::Ready(Err(io::Error::new(
271                        io::ErrorKind::Other,
272                        "acceptor cannot be polled after acceptance",
273                    )))
274                }
275            };
276
277            if let Some((err, mut alert)) = this.alert.take() {
278                match alert.write(&mut common::SyncWriteAdapter { io, cx }) {
279                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
280                        this.alert = Some((err, alert));
281                        return Poll::Pending;
282                    }
283                    Ok(0) | Err(_) => {
284                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
285                    }
286                    Ok(_) => {
287                        this.alert = Some((err, alert));
288                        continue;
289                    }
290                };
291            }
292
293            let mut reader = common::SyncReadAdapter { io, cx };
294            match this.acceptor.read_tls(&mut reader) {
295                Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
296                Ok(_) => {}
297                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
298                Err(e) => return Err(e).into(),
299            }
300
301            match this.acceptor.accept() {
302                Ok(Some(accepted)) => {
303                    let io = this.io.take().unwrap();
304                    return Poll::Ready(Ok(StartHandshake { accepted, io }));
305                }
306                Ok(None) => {}
307                Err((err, alert)) => {
308                    this.alert = Some((err, alert));
309                }
310            }
311        }
312    }
313}
314
315pub struct StartHandshake<IO> {
316    accepted: rustls::server::Accepted,
317    io: IO,
318}
319
320impl<IO> StartHandshake<IO>
321where
322    IO: AsyncRead + AsyncWrite + Unpin,
323{
324    pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
325        self.accepted.client_hello()
326    }
327
328    pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
329        self.into_stream_with(config, |_| ())
330    }
331
332    pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
333    where
334        F: FnOnce(&mut ServerConnection),
335    {
336        let mut conn = match self.accepted.into_connection(config) {
337            Ok(conn) => conn,
338            Err((error, alert)) => {
339                return Accept(MidHandshake::SendAlert {
340                    io: self.io,
341                    alert,
342                    // TODO(eliza): should this really return an `io::Error`?
343                    // Probably not...
344                    error: io::Error::new(io::ErrorKind::InvalidData, error),
345                });
346            }
347        };
348        f(&mut conn);
349
350        Accept(MidHandshake::Handshaking(server::TlsStream {
351            session: conn,
352            io: self.io,
353            state: TlsState::Stream,
354        }))
355    }
356}
357
358/// Future returned from `TlsConnector::connect` which will resolve
359/// once the connection handshake has finished.
360pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
361
362/// Future returned from `TlsAcceptor::accept` which will resolve
363/// once the accept handshake has finished.
364pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
365
366/// Like [Connect], but returns `IO` on failure.
367pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);
368
369/// Like [Accept], but returns `IO` on failure.
370pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
371
372impl<IO> Connect<IO> {
373    #[inline]
374    pub fn into_fallible(self) -> FallibleConnect<IO> {
375        FallibleConnect(self.0)
376    }
377
378    pub fn get_ref(&self) -> Option<&IO> {
379        match &self.0 {
380            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
381            MidHandshake::SendAlert { io, .. } => Some(io),
382            MidHandshake::Error { io, .. } => Some(io),
383            MidHandshake::End => None,
384        }
385    }
386
387    pub fn get_mut(&mut self) -> Option<&mut IO> {
388        match &mut self.0 {
389            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
390            MidHandshake::SendAlert { io, .. } => Some(io),
391            MidHandshake::Error { io, .. } => Some(io),
392            MidHandshake::End => None,
393        }
394    }
395}
396
397impl<IO> Accept<IO> {
398    #[inline]
399    pub fn into_fallible(self) -> FallibleAccept<IO> {
400        FallibleAccept(self.0)
401    }
402
403    pub fn get_ref(&self) -> Option<&IO> {
404        match &self.0 {
405            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
406            MidHandshake::SendAlert { io, .. } => Some(io),
407            MidHandshake::Error { io, .. } => Some(io),
408            MidHandshake::End => None,
409        }
410    }
411
412    pub fn get_mut(&mut self) -> Option<&mut IO> {
413        match &mut self.0 {
414            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
415            MidHandshake::SendAlert { io, .. } => Some(io),
416            MidHandshake::Error { io, .. } => Some(io),
417            MidHandshake::End => None,
418        }
419    }
420}
421
422impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
423    type Output = io::Result<client::TlsStream<IO>>;
424
425    #[inline]
426    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
427        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
428    }
429}
430
431impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
432    type Output = io::Result<server::TlsStream<IO>>;
433
434    #[inline]
435    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
436        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
437    }
438}
439
440impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
441    type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
442
443    #[inline]
444    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
445        Pin::new(&mut self.0).poll(cx)
446    }
447}
448
449impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
450    type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
451
452    #[inline]
453    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
454        Pin::new(&mut self.0).poll(cx)
455    }
456}
457
458/// Unified TLS stream type
459///
460/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use
461/// a single type to keep both client- and server-initiated TLS-encrypted connections.
462#[allow(clippy::large_enum_variant)] // https://github.com/rust-lang/rust-clippy/issues/9798
463#[derive(Debug)]
464pub enum TlsStream<T> {
465    Client(client::TlsStream<T>),
466    Server(server::TlsStream<T>),
467}
468
469impl<T> TlsStream<T> {
470    pub fn get_ref(&self) -> (&T, &CommonState) {
471        use TlsStream::*;
472        match self {
473            Client(io) => {
474                let (io, session) = io.get_ref();
475                (io, session)
476            }
477            Server(io) => {
478                let (io, session) = io.get_ref();
479                (io, session)
480            }
481        }
482    }
483
484    pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
485        use TlsStream::*;
486        match self {
487            Client(io) => {
488                let (io, session) = io.get_mut();
489                (io, &mut *session)
490            }
491            Server(io) => {
492                let (io, session) = io.get_mut();
493                (io, &mut *session)
494            }
495        }
496    }
497}
498
499impl<T> From<client::TlsStream<T>> for TlsStream<T> {
500    fn from(s: client::TlsStream<T>) -> Self {
501        Self::Client(s)
502    }
503}
504
505impl<T> From<server::TlsStream<T>> for TlsStream<T> {
506    fn from(s: server::TlsStream<T>) -> Self {
507        Self::Server(s)
508    }
509}
510
511#[cfg(unix)]
512impl<S> AsRawFd for TlsStream<S>
513where
514    S: AsRawFd,
515{
516    fn as_raw_fd(&self) -> RawFd {
517        self.get_ref().0.as_raw_fd()
518    }
519}
520
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket of the underlying I/O object.
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _) = self.get_ref();
        io.as_raw_socket()
    }
}
530
531impl<T> AsyncRead for TlsStream<T>
532where
533    T: AsyncRead + AsyncWrite + Unpin,
534{
535    #[inline]
536    fn poll_read(
537        self: Pin<&mut Self>,
538        cx: &mut Context<'_>,
539        buf: &mut ReadBuf<'_>,
540    ) -> Poll<io::Result<()>> {
541        match self.get_mut() {
542            TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
543            TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
544        }
545    }
546}
547
548impl<T> AsyncWrite for TlsStream<T>
549where
550    T: AsyncRead + AsyncWrite + Unpin,
551{
552    #[inline]
553    fn poll_write(
554        self: Pin<&mut Self>,
555        cx: &mut Context<'_>,
556        buf: &[u8],
557    ) -> Poll<io::Result<usize>> {
558        match self.get_mut() {
559            TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
560            TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
561        }
562    }
563
564    #[inline]
565    fn poll_write_vectored(
566        self: Pin<&mut Self>,
567        cx: &mut Context<'_>,
568        bufs: &[io::IoSlice<'_>],
569    ) -> Poll<io::Result<usize>> {
570        match self.get_mut() {
571            TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
572            TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
573        }
574    }
575
576    #[inline]
577    fn is_write_vectored(&self) -> bool {
578        match self {
579            TlsStream::Client(x) => x.is_write_vectored(),
580            TlsStream::Server(x) => x.is_write_vectored(),
581        }
582    }
583
584    #[inline]
585    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
586        match self.get_mut() {
587            TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
588            TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
589        }
590    }
591
592    #[inline]
593    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
594        match self.get_mut() {
595            TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
596            TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
597        }
598    }
599}