1use std::future::Future;
40use std::io;
41#[cfg(unix)]
42use std::os::unix::io::{AsRawFd, RawFd};
43#[cfg(windows)]
44use std::os::windows::io::{AsRawSocket, RawSocket};
45use std::pin::Pin;
46use std::sync::Arc;
47use std::task::{Context, Poll};
48
49pub use rustls;
50
51use rustls::pki_types::ServerName;
52use rustls::server::AcceptedAlert;
53use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
54use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
55
// Evaluates a `Poll<T>` expression and yields the `T` inside `Poll::Ready`.
// If the expression is `Poll::Pending`, it returns `Poll::Pending` from the
// enclosing function instead.
//
// The macro is defined before the `mod` declarations in this file, so textual
// macro scoping makes it usable inside those submodules too. Paths are fully
// qualified so call sites do not need to import `Poll`.
macro_rules! ready {
    ( $poll:expr ) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
64
65pub mod client;
66mod common;
67use common::{MidHandshake, TlsState};
68pub mod server;
69
70#[derive(Clone)]
72pub struct TlsConnector {
73 inner: Arc<ClientConfig>,
74 #[cfg(feature = "early-data")]
75 early_data: bool,
76}
77
78#[derive(Clone)]
80pub struct TlsAcceptor {
81 inner: Arc<ServerConfig>,
82}
83
84impl From<Arc<ClientConfig>> for TlsConnector {
85 fn from(inner: Arc<ClientConfig>) -> TlsConnector {
86 TlsConnector {
87 inner,
88 #[cfg(feature = "early-data")]
89 early_data: false,
90 }
91 }
92}
93
94impl From<Arc<ServerConfig>> for TlsAcceptor {
95 fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
96 TlsAcceptor { inner }
97 }
98}
99
100impl TlsConnector {
101 #[cfg(feature = "early-data")]
106 pub fn early_data(mut self, flag: bool) -> TlsConnector {
107 self.early_data = flag;
108 self
109 }
110
111 #[inline]
112 pub fn connect<IO>(&self, domain: ServerName<'static>, stream: IO) -> Connect<IO>
113 where
114 IO: AsyncRead + AsyncWrite + Unpin,
115 {
116 self.connect_with(domain, stream, |_| ())
117 }
118
119 pub fn connect_with<IO, F>(&self, domain: ServerName<'static>, stream: IO, f: F) -> Connect<IO>
120 where
121 IO: AsyncRead + AsyncWrite + Unpin,
122 F: FnOnce(&mut ClientConnection),
123 {
124 let mut session = match ClientConnection::new(self.inner.clone(), domain) {
125 Ok(session) => session,
126 Err(error) => {
127 return Connect(MidHandshake::Error {
128 io: stream,
129 error: io::Error::new(io::ErrorKind::Other, error),
132 });
133 }
134 };
135 f(&mut session);
136
137 Connect(MidHandshake::Handshaking(client::TlsStream {
138 io: stream,
139
140 #[cfg(not(feature = "early-data"))]
141 state: TlsState::Stream,
142
143 #[cfg(feature = "early-data")]
144 state: if self.early_data && session.early_data().is_some() {
145 TlsState::EarlyData(0, Vec::new())
146 } else {
147 TlsState::Stream
148 },
149
150 #[cfg(feature = "early-data")]
151 early_waker: None,
152
153 session,
154 }))
155 }
156}
157
158impl TlsAcceptor {
159 #[inline]
160 pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
161 where
162 IO: AsyncRead + AsyncWrite + Unpin,
163 {
164 self.accept_with(stream, |_| ())
165 }
166
167 pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
168 where
169 IO: AsyncRead + AsyncWrite + Unpin,
170 F: FnOnce(&mut ServerConnection),
171 {
172 let mut session = match ServerConnection::new(self.inner.clone()) {
173 Ok(session) => session,
174 Err(error) => {
175 return Accept(MidHandshake::Error {
176 io: stream,
177 error: io::Error::new(io::ErrorKind::Other, error),
180 });
181 }
182 };
183 f(&mut session);
184
185 Accept(MidHandshake::Handshaking(server::TlsStream {
186 session,
187 io: stream,
188 state: TlsState::Stream,
189 }))
190 }
191}
192
193pub struct LazyConfigAcceptor<IO> {
194 acceptor: rustls::server::Acceptor,
195 io: Option<IO>,
196 alert: Option<(rustls::Error, AcceptedAlert)>,
197}
198
199impl<IO> LazyConfigAcceptor<IO>
200where
201 IO: AsyncRead + AsyncWrite + Unpin,
202{
203 #[inline]
204 pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
205 Self {
206 acceptor,
207 io: Some(io),
208 alert: None,
209 }
210 }
211
212 pub fn take_io(&mut self) -> Option<IO> {
254 self.io.take()
255 }
256}
257
/// Drives the ClientHello read for a [`LazyConfigAcceptor`].
///
/// Resolves to a [`StartHandshake`] once `acceptor.accept()` returns
/// `Ok(Some(_))`. Resolves to an error if:
/// - the I/O object has already been removed,
/// - the peer hits EOF first (`UnexpectedEof`),
/// - reading fails, or
/// - rustls rejects the input. In that case the pending alert is written
///   first and the error is reported as `InvalidData`.
impl<IO> Future for LazyConfigAcceptor<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    type Output = Result<StartHandshake<IO>, io::Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        loop {
            // `io` is `None` after a successful accept or after `take_io`.
            let io = match this.io.as_mut() {
                Some(io) => io,
                None => {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::Other,
                        "acceptor cannot be polled after acceptance",
                    )))
                }
            };

            // An earlier `accept()` failed, so finish writing its alert before
            // reporting the error. Whenever more writing may be needed, the
            // alert is stored back in `this.alert`. A later poll (or the next
            // loop iteration) therefore resumes here.
            if let Some((err, mut alert)) = this.alert.take() {
                match alert.write(&mut common::SyncWriteAdapter { io, cx }) {
                    // NOTE(review): presumably the sync adapter reports
                    // `Poll::Pending` from the transport as `WouldBlock`.
                    // Confirm in `common`.
                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                        this.alert = Some((err, alert));
                        return Poll::Pending;
                    }
                    // Zero bytes written means the alert is finished. A write
                    // error also ends here. Either way the original rustls
                    // error is reported and any I/O error is discarded.
                    Ok(0) | Err(_) => {
                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
                    }
                    // Some bytes went out; keep the alert and write again.
                    Ok(_) => {
                        this.alert = Some((err, alert));
                        continue;
                    }
                };
            }

            // Feed more bytes from the transport into the acceptor.
            let mut reader = common::SyncReadAdapter { io, cx };
            match this.acceptor.read_tls(&mut reader) {
                // The peer hit EOF before `accept()` produced a result.
                Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
                Ok(_) => {}
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
                Err(e) => return Err(e).into(),
            }

            match this.acceptor.accept() {
                // The ClientHello is available: hand the transport to the caller.
                // The `unwrap` cannot fail because `this.io` was `Some` at the
                // top of this iteration.
                Ok(Some(accepted)) => {
                    let io = this.io.take().unwrap();
                    return Poll::Ready(Ok(StartHandshake { accepted, io }));
                }
                // Not enough data yet; loop around and read again.
                Ok(None) => {}
                // Queue the alert; the next loop iteration starts writing it.
                Err((err, alert)) => {
                    this.alert = Some((err, alert));
                }
            }
        }
    }
}
314
315pub struct StartHandshake<IO> {
316 accepted: rustls::server::Accepted,
317 io: IO,
318}
319
320impl<IO> StartHandshake<IO>
321where
322 IO: AsyncRead + AsyncWrite + Unpin,
323{
324 pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
325 self.accepted.client_hello()
326 }
327
328 pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
329 self.into_stream_with(config, |_| ())
330 }
331
332 pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
333 where
334 F: FnOnce(&mut ServerConnection),
335 {
336 let mut conn = match self.accepted.into_connection(config) {
337 Ok(conn) => conn,
338 Err((error, alert)) => {
339 return Accept(MidHandshake::SendAlert {
340 io: self.io,
341 alert,
342 error: io::Error::new(io::ErrorKind::InvalidData, error),
345 });
346 }
347 };
348 f(&mut conn);
349
350 Accept(MidHandshake::Handshaking(server::TlsStream {
351 session: conn,
352 io: self.io,
353 state: TlsState::Stream,
354 }))
355 }
356}
357
358pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
361
362pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
365
366pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);
368
369pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
371
372impl<IO> Connect<IO> {
373 #[inline]
374 pub fn into_fallible(self) -> FallibleConnect<IO> {
375 FallibleConnect(self.0)
376 }
377
378 pub fn get_ref(&self) -> Option<&IO> {
379 match &self.0 {
380 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
381 MidHandshake::SendAlert { io, .. } => Some(io),
382 MidHandshake::Error { io, .. } => Some(io),
383 MidHandshake::End => None,
384 }
385 }
386
387 pub fn get_mut(&mut self) -> Option<&mut IO> {
388 match &mut self.0 {
389 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
390 MidHandshake::SendAlert { io, .. } => Some(io),
391 MidHandshake::Error { io, .. } => Some(io),
392 MidHandshake::End => None,
393 }
394 }
395}
396
397impl<IO> Accept<IO> {
398 #[inline]
399 pub fn into_fallible(self) -> FallibleAccept<IO> {
400 FallibleAccept(self.0)
401 }
402
403 pub fn get_ref(&self) -> Option<&IO> {
404 match &self.0 {
405 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
406 MidHandshake::SendAlert { io, .. } => Some(io),
407 MidHandshake::Error { io, .. } => Some(io),
408 MidHandshake::End => None,
409 }
410 }
411
412 pub fn get_mut(&mut self) -> Option<&mut IO> {
413 match &mut self.0 {
414 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
415 MidHandshake::SendAlert { io, .. } => Some(io),
416 MidHandshake::Error { io, .. } => Some(io),
417 MidHandshake::End => None,
418 }
419 }
420}
421
422impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
423 type Output = io::Result<client::TlsStream<IO>>;
424
425 #[inline]
426 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
427 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
428 }
429}
430
431impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
432 type Output = io::Result<server::TlsStream<IO>>;
433
434 #[inline]
435 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
436 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
437 }
438}
439
440impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
441 type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
442
443 #[inline]
444 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
445 Pin::new(&mut self.0).poll(cx)
446 }
447}
448
449impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
450 type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
451
452 #[inline]
453 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
454 Pin::new(&mut self.0).poll(cx)
455 }
456}
457
458#[allow(clippy::large_enum_variant)] #[derive(Debug)]
464pub enum TlsStream<T> {
465 Client(client::TlsStream<T>),
466 Server(server::TlsStream<T>),
467}
468
469impl<T> TlsStream<T> {
470 pub fn get_ref(&self) -> (&T, &CommonState) {
471 use TlsStream::*;
472 match self {
473 Client(io) => {
474 let (io, session) = io.get_ref();
475 (io, session)
476 }
477 Server(io) => {
478 let (io, session) = io.get_ref();
479 (io, session)
480 }
481 }
482 }
483
484 pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
485 use TlsStream::*;
486 match self {
487 Client(io) => {
488 let (io, session) = io.get_mut();
489 (io, &mut *session)
490 }
491 Server(io) => {
492 let (io, session) = io.get_mut();
493 (io, &mut *session)
494 }
495 }
496 }
497}
498
499impl<T> From<client::TlsStream<T>> for TlsStream<T> {
500 fn from(s: client::TlsStream<T>) -> Self {
501 Self::Client(s)
502 }
503}
504
505impl<T> From<server::TlsStream<T>> for TlsStream<T> {
506 fn from(s: server::TlsStream<T>) -> Self {
507 Self::Server(s)
508 }
509}
510
511#[cfg(unix)]
512impl<S> AsRawFd for TlsStream<S>
513where
514 S: AsRawFd,
515{
516 fn as_raw_fd(&self) -> RawFd {
517 self.get_ref().0.as_raw_fd()
518 }
519}
520
/// Exposes the raw socket of the wrapped I/O object (Windows only).
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    fn as_raw_socket(&self) -> RawSocket {
        let (transport, _state) = self.get_ref();
        transport.as_raw_socket()
    }
}
530
531impl<T> AsyncRead for TlsStream<T>
532where
533 T: AsyncRead + AsyncWrite + Unpin,
534{
535 #[inline]
536 fn poll_read(
537 self: Pin<&mut Self>,
538 cx: &mut Context<'_>,
539 buf: &mut ReadBuf<'_>,
540 ) -> Poll<io::Result<()>> {
541 match self.get_mut() {
542 TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
543 TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
544 }
545 }
546}
547
548impl<T> AsyncWrite for TlsStream<T>
549where
550 T: AsyncRead + AsyncWrite + Unpin,
551{
552 #[inline]
553 fn poll_write(
554 self: Pin<&mut Self>,
555 cx: &mut Context<'_>,
556 buf: &[u8],
557 ) -> Poll<io::Result<usize>> {
558 match self.get_mut() {
559 TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
560 TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
561 }
562 }
563
564 #[inline]
565 fn poll_write_vectored(
566 self: Pin<&mut Self>,
567 cx: &mut Context<'_>,
568 bufs: &[io::IoSlice<'_>],
569 ) -> Poll<io::Result<usize>> {
570 match self.get_mut() {
571 TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
572 TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
573 }
574 }
575
576 #[inline]
577 fn is_write_vectored(&self) -> bool {
578 match self {
579 TlsStream::Client(x) => x.is_write_vectored(),
580 TlsStream::Server(x) => x.is_write_vectored(),
581 }
582 }
583
584 #[inline]
585 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
586 match self.get_mut() {
587 TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
588 TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
589 }
590 }
591
592 #[inline]
593 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
594 match self.get_mut() {
595 TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
596 TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
597 }
598 }
599}