1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
//! `async` related
//!
//! `#[no_std]` compatible.

//---------------------------------------------------------------------------------------------------- Use
use core::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

use futures::{channel::oneshot, FutureExt};

//---------------------------------------------------------------------------------------------------- InfallibleOneshotReceiver
/// A oneshot receiver channel that doesn't return an Error.
///
/// This is a thin wrapper around [`oneshot::Receiver`] whose [`Future`]
/// implementation resolves directly to `T` rather than to a `Result`.
///
/// This requires the sender to always return a response.
///
/// # Panics
/// Polling this future panics if the wrapped receiver resolves to an error,
/// i.e. the oneshot was cancelled before a response was sent.
pub struct InfallibleOneshotReceiver<T>(oneshot::Receiver<T>);

impl<T> From<oneshot::Receiver<T>> for InfallibleOneshotReceiver<T> {
    fn from(value: oneshot::Receiver<T>) -> Self {
        Self(value)
    }
}

impl<T> Future for InfallibleOneshotReceiver<T> {
    type Output = T;

    /// Polls the wrapped receiver and unwraps its result.
    ///
    /// # Panics
    /// Panics if the wrapped receiver completes with an error instead of a value.
    #[inline]
    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.0.poll_unpin(ctx) {
            // The receiver finished: a missing response is treated as a bug in the sender.
            Poll::Ready(response) => {
                Poll::Ready(response.expect("Oneshot must not be cancelled before response!"))
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

//---------------------------------------------------------------------------------------------------- rayon_spawn_async
/// Runs `f` on a rayon thread (via [`rayon::spawn`]) and asynchronously waits for its return value.
///
/// The closure's result is handed back through a oneshot channel, so the calling
/// task only awaits the channel instead of executing `f` itself.
///
/// # Panics
/// Panics if the sending half of the channel is dropped before a value is sent.
pub async fn rayon_spawn_async<F, R>(f: F) -> R
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    let (sender, receiver) = oneshot::channel();

    let job = move || {
        let output = f();
        // A failed send only means the receiving side is gone; the result is discarded.
        drop(sender.send(output));
    };
    rayon::spawn(job);

    let outcome = receiver.await;
    outcome.expect("The sender must not be dropped")
}

//---------------------------------------------------------------------------------------------------- Tests
#[cfg(test)]
mod test {
    use std::{
        sync::{Arc, Barrier},
        thread,
        time::Duration,
    };

    use super::*;

    // Assert that a value sent before awaiting comes out of the wrapper unchanged.
    #[tokio::test]
    async fn infallible_oneshot_receiver() {
        let expected = "hello world!".to_string();
        let (sender, receiver) = oneshot::channel::<String>();

        sender.send(expected.clone()).unwrap();

        let received = InfallibleOneshotReceiver::from(receiver).await;
        assert_eq!(received, expected);
    }

    #[test]
    fn rayon_spawn_async_does_not_block() {
        // Two rayon threads are required: each one will be parked on the barrier below.
        rayon::ThreadPoolBuilder::new()
            .num_threads(2)
            .build_global()
            .unwrap();

        // Each job waits on a 2-party barrier, so neither can finish until both
        // jobs are running at the same time on the rayon threads.
        let first = Arc::new(Barrier::new(2));
        let second = Arc::clone(&first);

        // Used by the helper thread to report that the runtime finished.
        let (done_tx, done_rx) = std::sync::mpsc::channel();

        thread::spawn(move || {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();

            runtime.block_on(async {
                let wait_first = rayon_spawn_async(move || first.wait());
                let wait_second = rayon_spawn_async(move || second.wait());

                // `join!` polls both futures concurrently within the same task: if the
                // first one blocked the task, the second would never run, and without
                // the second the first would never be released from the barrier.
                tokio::join!(wait_first, wait_second)
            });

            // Reaching this point means rayon_spawn_async didn't block the runtime.
            done_tx.send(()).unwrap();
        });

        done_rx
            .recv_timeout(Duration::from_secs(2))
            .expect("rayon_spawn_async blocked");
    }
}