1use crate::traits::SleepProvider;
4use futures::{Future, FutureExt};
5use pin_project::pin_project;
6use std::{
7 pin::Pin,
8 task::{Context, Poll},
9 time::{Duration, SystemTime},
10};
11
/// Error produced by a [`Timeout`] future when its sleep finishes before the
/// wrapped future does.
///
/// Converts into a [`std::io::Error`] of kind `TimedOut`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
// A field-less unit struct: there is nothing to add later, so silence the
// exhaustiveness lint.
#[allow(clippy::exhaustive_structs)]
pub struct TimeoutError;

impl std::error::Error for TimeoutError {}

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Timeout expired")
    }
}

impl From<TimeoutError> for std::io::Error {
    /// Wrap the timeout as an I/O error whose kind is `ErrorKind::TimedOut`.
    fn from(err: TimeoutError) -> Self {
        Self::new(std::io::ErrorKind::TimedOut, err)
    }
}
32
33pub trait SleepProviderExt: SleepProvider {
35 #[must_use = "timeout() returns a future, which does nothing unless used"]
46 fn timeout<F: Future>(&self, duration: Duration, future: F) -> Timeout<F, Self::SleepFuture> {
47 let sleep_future = self.sleep(duration);
48
49 Timeout {
50 future,
51 sleep_future,
52 }
53 }
54
55 #[must_use = "sleep_until_wallclock() returns a future, which does nothing unless used"]
74 fn sleep_until_wallclock(&self, when: SystemTime) -> SleepUntilWallclock<'_, Self> {
75 SleepUntilWallclock {
76 provider: self,
77 target: when,
78 sleep_future: None,
79 }
80 }
81}
82
83impl<T: SleepProvider> SleepProviderExt for T {}
84
85#[pin_project]
87pub struct Timeout<T, S> {
88 #[pin]
90 future: T,
91 #[pin]
93 sleep_future: S,
94}
95
96impl<T, S> Future for Timeout<T, S>
97where
98 T: Future,
99 S: Future<Output = ()>,
100{
101 type Output = Result<T::Output, TimeoutError>;
102
103 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
104 let this = self.project();
105 if let Poll::Ready(x) = this.future.poll(cx) {
106 return Poll::Ready(Ok(x));
107 }
108
109 match this.sleep_future.poll(cx) {
110 Poll::Pending => Poll::Pending,
111 Poll::Ready(()) => Poll::Ready(Err(TimeoutError)),
112 }
113 }
114}
115
116pub struct SleepUntilWallclock<'a, SP: SleepProvider> {
118 provider: &'a SP,
120 target: SystemTime,
122 sleep_future: Option<Pin<Box<SP::SleepFuture>>>,
124}
125
126impl<'a, SP> Future for SleepUntilWallclock<'a, SP>
127where
128 SP: SleepProvider,
129{
130 type Output = ();
131
132 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
133 let target = self.target;
139 loop {
140 let now = self.provider.wallclock();
141 if now >= target {
142 return Poll::Ready(());
143 }
144
145 let (last_delay, delay) = calc_next_delay(now, target);
146
147 self.sleep_future.take();
154
155 let mut sleep_future = Box::pin(self.provider.sleep(delay));
156 match sleep_future.poll_unpin(cx) {
157 Poll::Pending => {
158 self.sleep_future = Some(sleep_future);
159 return Poll::Pending;
160 }
161 Poll::Ready(()) => {
162 if last_delay {
163 return Poll::Ready(());
164 }
165 }
166 }
167 }
168 }
169}
170
/// Longest single sleep used while waiting for a wall-clock time (ten minutes).
///
/// The wall clock is re-read after each sleep, so keeping individual sleeps
/// short bounds how late a forward jump of the system clock is noticed.
const MAX_SLEEP: Duration = Duration::from_secs(600);

/// Compute the next sleep to take toward `when`, given the current time `now`.
///
/// Returns `(is_last, delay)`:
/// * `delay` is the time remaining until `when`, capped at [`MAX_SLEEP`], or
///   zero if `when` is not after `now`;
/// * `is_last` is `true` when `delay` covers the whole remaining wait, i.e.
///   the cap was not applied.
pub(crate) fn calc_next_delay(now: SystemTime, when: SystemTime) -> (bool, Duration) {
    // `duration_since` errors when `when` is earlier than `now`; treat that as
    // nothing left to wait (Duration's default is zero).
    let remaining = when.duration_since(now).unwrap_or_default();
    let is_last = remaining <= MAX_SLEEP;
    (is_last, remaining.min(MAX_SLEEP))
}
193
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::erasing_op)]

    #[cfg(not(miri))]
    use super::*;

    /// The delay returned by `calc_next_delay` never exceeds `MAX_SLEEP` and
    /// is zero once the target has been reached or passed.
    #[cfg(not(miri))]
    #[test]
    fn sleep_delay() {
        fn calc(now: SystemTime, when: SystemTime) -> Duration {
            calc_next_delay(now, when).1
        }
        let minute = Duration::from_secs(60);
        let second = Duration::from_secs(1);
        let start = SystemTime::now();

        let target = start + 30 * minute;

        assert_eq!(calc(start, target), minute * 10);
        assert_eq!(calc(target + minute, target), minute * 0);
        assert_eq!(calc(target, target), minute * 0);
        assert_eq!(calc(target - second, target), second);
        assert_eq!(calc(target - minute * 9, target), minute * 9);
        assert_eq!(calc(target - minute * 11, target), minute * 10);
    }

    /// The "last delay" flag returned alongside the delay is what lets
    /// `SleepUntilWallclock::poll` stop looping; check it at and around the
    /// `MAX_SLEEP` boundary, using a fixed base time for determinism.
    #[cfg(not(miri))]
    #[test]
    fn sleep_delay_last_flag() {
        let second = Duration::from_secs(1);
        let zero = Duration::from_secs(0);
        let target = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000_000);

        // More than MAX_SLEEP remaining: capped, and more sleeps are needed.
        assert_eq!(
            calc_next_delay(target - MAX_SLEEP * 3, target),
            (false, MAX_SLEEP)
        );
        assert_eq!(
            calc_next_delay(target - (MAX_SLEEP + second), target),
            (false, MAX_SLEEP)
        );

        // Exactly MAX_SLEEP remaining: this sleep is the final one.
        assert_eq!(
            calc_next_delay(target - MAX_SLEEP, target),
            (true, MAX_SLEEP)
        );
        assert_eq!(calc_next_delay(target - second, target), (true, second));

        // Target reached or already passed: a zero-length final sleep.
        assert_eq!(calc_next_delay(target, target), (true, zero));
        assert_eq!(calc_next_delay(target + second, target), (true, zero));
    }
}