hyper_util/rt/tokio.rs

1#![allow(dead_code)]
2//! Tokio IO integration for hyper
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll},
7    time::{Duration, Instant},
8};
9
10use hyper::rt::{Executor, Sleep, Timer};
11use pin_project_lite::pin_project;
12
/// An [`Executor`] that hands every future to the current Tokio runtime via
/// [`tokio::spawn`].
///
/// The type carries no state, so it is cheap to create, clone and share.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TokioExecutor {}
17
18pin_project! {
19    /// A wrapper that implements Tokio's IO traits for an inner type that
20    /// implements hyper's IO traits, or vice versa (implements hyper's IO
21    /// traits for a type that implements Tokio's IO traits).
22    #[derive(Debug)]
23    pub struct TokioIo<T> {
24        #[pin]
25        inner: T,
26    }
27}
28
/// A hyper [`Timer`] whose sleeps are created with `tokio::time::sleep` /
/// `tokio::time::sleep_until`, so it must be used from within a Tokio runtime.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TokioTimer;
33
34// Use TokioSleep to get tokio::time::Sleep to implement Unpin.
35// see https://docs.rs/tokio/latest/tokio/time/struct.Sleep.html
36pin_project! {
37    #[derive(Debug)]
38    struct TokioSleep {
39        #[pin]
40        inner: tokio::time::Sleep,
41    }
42}
43
// ===== impl TokioExecutor =====
45
46impl<Fut> Executor<Fut> for TokioExecutor
47where
48    Fut: Future + Send + 'static,
49    Fut::Output: Send + 'static,
50{
51    fn execute(&self, fut: Fut) {
52        tokio::spawn(fut);
53    }
54}
55
56impl TokioExecutor {
57    /// Create new executor that relies on [`tokio::spawn`] to execute futures.
58    pub fn new() -> Self {
59        Self {}
60    }
61}
62
// ===== impl TokioIo =====
64
65impl<T> TokioIo<T> {
66    /// Wrap a type implementing Tokio's or hyper's IO traits.
67    pub fn new(inner: T) -> Self {
68        Self { inner }
69    }
70
71    /// Borrow the inner type.
72    pub fn inner(&self) -> &T {
73        &self.inner
74    }
75
76    /// Mut borrow the inner type.
77    pub fn inner_mut(&mut self) -> &mut T {
78        &mut self.inner
79    }
80
81    /// Consume this wrapper and get the inner type.
82    pub fn into_inner(self) -> T {
83        self.inner
84    }
85}
86
87impl<T> hyper::rt::Read for TokioIo<T>
88where
89    T: tokio::io::AsyncRead,
90{
91    fn poll_read(
92        self: Pin<&mut Self>,
93        cx: &mut Context<'_>,
94        mut buf: hyper::rt::ReadBufCursor<'_>,
95    ) -> Poll<Result<(), std::io::Error>> {
96        let n = unsafe {
97            let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
98            match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
99                Poll::Ready(Ok(())) => tbuf.filled().len(),
100                other => return other,
101            }
102        };
103
104        unsafe {
105            buf.advance(n);
106        }
107        Poll::Ready(Ok(()))
108    }
109}
110
111impl<T> hyper::rt::Write for TokioIo<T>
112where
113    T: tokio::io::AsyncWrite,
114{
115    fn poll_write(
116        self: Pin<&mut Self>,
117        cx: &mut Context<'_>,
118        buf: &[u8],
119    ) -> Poll<Result<usize, std::io::Error>> {
120        tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
121    }
122
123    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
124        tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
125    }
126
127    fn poll_shutdown(
128        self: Pin<&mut Self>,
129        cx: &mut Context<'_>,
130    ) -> Poll<Result<(), std::io::Error>> {
131        tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
132    }
133
134    fn is_write_vectored(&self) -> bool {
135        tokio::io::AsyncWrite::is_write_vectored(&self.inner)
136    }
137
138    fn poll_write_vectored(
139        self: Pin<&mut Self>,
140        cx: &mut Context<'_>,
141        bufs: &[std::io::IoSlice<'_>],
142    ) -> Poll<Result<usize, std::io::Error>> {
143        tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
144    }
145}
146
147impl<T> tokio::io::AsyncRead for TokioIo<T>
148where
149    T: hyper::rt::Read,
150{
151    fn poll_read(
152        self: Pin<&mut Self>,
153        cx: &mut Context<'_>,
154        tbuf: &mut tokio::io::ReadBuf<'_>,
155    ) -> Poll<Result<(), std::io::Error>> {
156        //let init = tbuf.initialized().len();
157        let filled = tbuf.filled().len();
158        let sub_filled = unsafe {
159            let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());
160
161            match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
162                Poll::Ready(Ok(())) => buf.filled().len(),
163                other => return other,
164            }
165        };
166
167        let n_filled = filled + sub_filled;
168        // At least sub_filled bytes had to have been initialized.
169        let n_init = sub_filled;
170        unsafe {
171            tbuf.assume_init(n_init);
172            tbuf.set_filled(n_filled);
173        }
174
175        Poll::Ready(Ok(()))
176    }
177}
178
179impl<T> tokio::io::AsyncWrite for TokioIo<T>
180where
181    T: hyper::rt::Write,
182{
183    fn poll_write(
184        self: Pin<&mut Self>,
185        cx: &mut Context<'_>,
186        buf: &[u8],
187    ) -> Poll<Result<usize, std::io::Error>> {
188        hyper::rt::Write::poll_write(self.project().inner, cx, buf)
189    }
190
191    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
192        hyper::rt::Write::poll_flush(self.project().inner, cx)
193    }
194
195    fn poll_shutdown(
196        self: Pin<&mut Self>,
197        cx: &mut Context<'_>,
198    ) -> Poll<Result<(), std::io::Error>> {
199        hyper::rt::Write::poll_shutdown(self.project().inner, cx)
200    }
201
202    fn is_write_vectored(&self) -> bool {
203        hyper::rt::Write::is_write_vectored(&self.inner)
204    }
205
206    fn poll_write_vectored(
207        self: Pin<&mut Self>,
208        cx: &mut Context<'_>,
209        bufs: &[std::io::IoSlice<'_>],
210    ) -> Poll<Result<usize, std::io::Error>> {
211        hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
212    }
213}
214
// ===== impl TokioTimer =====
216
217impl Timer for TokioTimer {
218    fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> {
219        Box::pin(TokioSleep {
220            inner: tokio::time::sleep(duration),
221        })
222    }
223
224    fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> {
225        Box::pin(TokioSleep {
226            inner: tokio::time::sleep_until(deadline.into()),
227        })
228    }
229
230    fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) {
231        if let Some(sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() {
232            sleep.reset(new_deadline)
233        }
234    }
235}
236
237impl TokioTimer {
238    /// Create a new TokioTimer
239    pub fn new() -> Self {
240        Self {}
241    }
242}
243
244impl Future for TokioSleep {
245    type Output = ();
246
247    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
248        self.project().inner.poll(cx)
249    }
250}
251
252impl Sleep for TokioSleep {}
253
254impl TokioSleep {
255    fn reset(self: Pin<&mut Self>, deadline: Instant) {
256        self.project().inner.as_mut().reset(deadline.into());
257    }
258}
259
#[cfg(test)]
mod tests {
    use crate::rt::TokioExecutor;
    use hyper::rt::Executor;
    use tokio::sync::oneshot;

    /// A future handed to `TokioExecutor` must actually be run by the Tokio runtime.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn simple_execute() -> Result<(), Box<dyn std::error::Error>> {
        let (done_tx, done_rx) = oneshot::channel();
        TokioExecutor::new().execute(async move {
            done_tx.send(()).unwrap();
        });
        done_rx.await?;
        Ok(())
    }
}