async_native_tls/
handshake.rs1use std::future::Future;
2use std::io::{Read, Write};
3use std::marker::Unpin;
4use std::pin::Pin;
5use std::ptr::null_mut;
6use std::task::{Context, Poll};
7
8use native_tls::{Error, HandshakeError, MidHandshakeTlsStream};
9
10use crate::runtime::{AsyncRead, AsyncWrite};
11use crate::std_adapter::StdAdapter;
12use crate::TlsStream;
13
14pub(crate) async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, Error>
15where
16 F: FnOnce(
17 StdAdapter<S>,
18 )
19 -> Result<native_tls::TlsStream<StdAdapter<S>>, HandshakeError<StdAdapter<S>>>
20 + Unpin,
21 S: AsyncRead + AsyncWrite + Unpin,
22{
23 let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));
24
25 match start.await {
26 Err(e) => Err(e),
27 Ok(StartedHandshake::Done(s)) => Ok(s),
28 Ok(StartedHandshake::Mid(s)) => MidHandshake(Some(s)).await,
29 }
30}
31
32struct MidHandshake<S>(Option<MidHandshakeTlsStream<StdAdapter<S>>>);
33
34enum StartedHandshake<S> {
35 Done(TlsStream<S>),
36 Mid(MidHandshakeTlsStream<StdAdapter<S>>),
37}
38
39struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
/// Inputs consumed by the first poll of [`StartedHandshakeFuture`].
struct StartedHandshakeFutureInner<F, S> {
    /// Handshake entry point, invoked with the adapted stream.
    f: F,
    /// The raw asynchronous stream that gets wrapped in a `StdAdapter`.
    stream: S,
}
44
45impl<F, S> Future for StartedHandshakeFuture<F, S>
46where
47 F: FnOnce(
48 StdAdapter<S>,
49 )
50 -> Result<native_tls::TlsStream<StdAdapter<S>>, HandshakeError<StdAdapter<S>>>
51 + Unpin,
52 S: Unpin,
53 StdAdapter<S>: Read + Write,
54{
55 type Output = Result<StartedHandshake<S>, Error>;
56
57 fn poll(
58 mut self: Pin<&mut Self>,
59 ctx: &mut Context<'_>,
60 ) -> Poll<Result<StartedHandshake<S>, Error>> {
61 let inner = self.0.take().expect("future polled after completion");
62 let stream = StdAdapter {
63 inner: inner.stream,
64 context: ctx as *mut _ as *mut (),
65 };
66
67 match (inner.f)(stream) {
68 Ok(mut s) => {
69 s.get_mut().context = null_mut();
70 Poll::Ready(Ok(StartedHandshake::Done(TlsStream::new(s))))
71 }
72 Err(HandshakeError::WouldBlock(mut s)) => {
73 s.get_mut().context = null_mut();
74 Poll::Ready(Ok(StartedHandshake::Mid(s)))
75 }
76 Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
77 }
78 }
79}
80
81impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshake<S> {
82 type Output = Result<TlsStream<S>, Error>;
83
84 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85 let mut_self = self.get_mut();
86 let mut s = mut_self.0.take().expect("future polled after completion");
87
88 s.get_mut().context = cx as *mut _ as *mut ();
89 match s.handshake() {
90 Ok(stream) => Poll::Ready(Ok(TlsStream::new(stream))),
91 Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
92 Err(HandshakeError::WouldBlock(mut s)) => {
93 s.get_mut().context = null_mut();
94 mut_self.0 = Some(s);
95 Poll::Pending
96 }
97 }
98 }
99}