axum/
util.rs

1use pin_project_lite::pin_project;
2use std::{ops::Deref, sync::Arc};
3
4pub(crate) use self::mutex::*;
5
6#[derive(Clone, Debug, PartialEq, Eq, Hash)]
7pub(crate) struct PercentDecodedStr(Arc<str>);
8
9impl PercentDecodedStr {
10    pub(crate) fn new<S>(s: S) -> Option<Self>
11    where
12        S: AsRef<str>,
13    {
14        percent_encoding::percent_decode(s.as_ref().as_bytes())
15            .decode_utf8()
16            .ok()
17            .map(|decoded| Self(decoded.as_ref().into()))
18    }
19
20    pub(crate) fn as_str(&self) -> &str {
21        &self.0
22    }
23}
24
25impl Deref for PercentDecodedStr {
26    type Target = str;
27
28    #[inline]
29    fn deref(&self) -> &Self::Target {
30        self.as_str()
31    }
32}
33
34pin_project! {
35    #[project = EitherProj]
36    pub(crate) enum Either<A, B> {
37        A { #[pin] inner: A },
38        B { #[pin] inner: B },
39    }
40}
41
/// Attempts to reinterpret `k` as a `T` without any unsafe code.
///
/// Returns `Ok(value)` when `K` and `T` are the same type. Otherwise it returns `Err(k)`,
/// handing the original value back to the caller untouched.
pub(crate) fn try_downcast<T, K>(k: K) -> Result<T, K>
where
    T: 'static,
    K: Send + 'static,
{
    // Park the value in an `Option` so that, when the types match, it can be moved out
    // through the type-erased `&mut dyn Any` borrow via `take()`.
    let mut slot = Some(k);
    let erased: &mut dyn std::any::Any = &mut slot;
    match erased.downcast_mut::<Option<T>>() {
        Some(same) => Ok(same.take().unwrap()),
        None => Err(slot.unwrap()),
    }
}
54
#[test]
fn test_try_downcast() {
    // Mismatched types hand the original value back unchanged.
    assert_eq!(try_downcast::<i32, _>(5_u32), Err(5_u32));
    // Matching types yield the value as `T`.
    assert_eq!(try_downcast::<i32, _>(5_i32), Ok(5_i32));

    // Owned, non-`Copy` values must be moved through intact on both paths. This is the
    // case the `Option`/`take` dance inside `try_downcast` exists for.
    assert_eq!(
        try_downcast::<String, _>(String::from("axum")),
        Ok(String::from("axum"))
    );
    assert_eq!(
        try_downcast::<Vec<u8>, _>(String::from("axum")),
        Err(String::from("axum"))
    );
}
60
61// `AxumMutex` is a wrapper around `std::sync::Mutex` which, in test mode, tracks the number of
62// times it's been locked on the current task. That way we can write a test to ensure we don't
63// accidentally introduce more locking.
64//
65// When not in test mode, it is just a type alias for `std::sync::Mutex`.
#[cfg(not(test))]
#[allow(clippy::disallowed_types)]
mod mutex {
    use std::sync::Mutex;

    /// Outside of test builds no lock counting is needed, so this is simply the
    /// standard-library mutex under another name.
    pub(crate) type AxumMutex<T> = Mutex<T>;
}
71
#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod mutex {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        LockResult, Mutex, MutexGuard,
    };

    tokio::task_local! {
        // Task-local count of `AxumMutex::lock` calls, installed by `mutex_num_locked`.
        pub(crate) static NUM_LOCKED: AtomicUsize;
    }

    /// Runs the future produced by `f` with a fresh, zeroed lock counter in scope.
    ///
    /// Returns the number of `AxumMutex::lock` calls counted while it ran, paired with the
    /// future's output.
    pub(crate) async fn mutex_num_locked<F, Fut>(f: F) -> (usize, Fut::Output)
    where
        F: FnOnce() -> Fut,
        Fut: std::future::IntoFuture,
    {
        let counted = async move {
            let output = f().await;
            let locks = NUM_LOCKED.with(|counter| counter.load(Ordering::SeqCst));
            (locks, output)
        };
        NUM_LOCKED.scope(AtomicUsize::new(0), counted).await
    }

    /// Test-only wrapper around `std::sync::Mutex` whose `lock` also bumps `NUM_LOCKED`.
    pub(crate) struct AxumMutex<T>(Mutex<T>);

    impl<T> AxumMutex<T> {
        pub(crate) fn new(value: T) -> Self {
            let inner = Mutex::new(value);
            Self(inner)
        }

        /// Mutable access through `&mut self`. This does not call `lock`, so it is not counted.
        pub(crate) fn get_mut(&mut self) -> LockResult<&mut T> {
            Mutex::get_mut(&mut self.0)
        }

        pub(crate) fn into_inner(self) -> LockResult<T> {
            let Self(inner) = self;
            inner.into_inner()
        }

        /// Locks the inner mutex after incrementing the task-local counter, if one is set.
        pub(crate) fn lock(&self) -> LockResult<MutexGuard<'_, T>> {
            // `try_with` fails outside a `NUM_LOCKED.scope`. Such locks simply go uncounted.
            let _ = NUM_LOCKED.try_with(|counter| counter.fetch_add(1, Ordering::SeqCst));
            self.0.lock()
        }
    }
}