tor_proto/util/token_bucket/
dynamic_writer.rs1use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures::io::Error;
7use futures::stream::FusedStream;
8use futures::AsyncWrite;
9use tor_rtcompat::SleepProvider;
10
11use super::writer::{RateLimitedWriter, RateLimitedWriterConfig};
12
13#[derive(educe::Educe)]
18#[educe(Debug)]
19#[pin_project::pin_project]
20pub(crate) struct DynamicRateLimitedWriter<W: AsyncWrite, S, P: SleepProvider> {
21 #[pin]
23 writer: RateLimitedWriter<W, P>,
24 #[educe(Debug(ignore))]
26 #[pin]
27 updates: S,
28}
29
30impl<W, S, P> DynamicRateLimitedWriter<W, S, P>
31where
32 W: AsyncWrite,
33 P: SleepProvider,
34{
35 pub(crate) fn new(writer: RateLimitedWriter<W, P>, updates: S) -> Self {
39 Self { writer, updates }
40 }
41
42 pub(crate) fn inner(&self) -> &W {
44 self.writer.inner()
45 }
46}
47
48impl<W, S, P> AsyncWrite for DynamicRateLimitedWriter<W, S, P>
49where
50 W: AsyncWrite,
51 S: FusedStream<Item = RateLimitedWriterConfig>,
52 P: SleepProvider,
53{
54 fn poll_write(
55 mut self: Pin<&mut Self>,
56 cx: &mut Context<'_>,
57 buf: &[u8],
58 ) -> Poll<Result<usize, Error>> {
59 let mut self_ = self.as_mut().project();
60
61 let mut iters = 0;
71 while let Poll::Ready(Some(config)) = self_.updates.as_mut().poll_next(cx) {
72 let now = self_.writer.sleep_provider().now();
74 self_.writer.adjust(now, &config);
75
76 iters += 1;
81 if iters > 100_000 {
82 const MSG: &str =
83 "possible infinite loop in `DynamicRateLimitedWriter::poll_write`";
84 tracing::debug!(MSG);
85 return Poll::Ready(Err(Error::other(MSG)));
86 }
87 }
88
89 self_.writer.poll_write(cx, buf)
91 }
92
93 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
94 self.project().writer.poll_flush(cx)
95 }
96
97 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
98 self.project().writer.poll_close(cx)
99 }
100}
101
/// Tokio `AsyncWrite` support for [`DynamicRateLimitedWriter`].
///
/// Each method wraps the pinned writer in a temporary `tokio_util` compatibility adapter and
/// forwards to tokio's trait on that adapter. The adapter in turn calls the futures-io
/// `AsyncWrite` implementation.
#[cfg(feature = "tokio")]
mod tokio_impl {
    use super::*;

    use tokio_crate::io::AsyncWrite as TokioAsyncWrite;
    use tokio_util::compat::FuturesAsyncWriteCompatExt;

    use std::io::Result as IoResult;

    impl<W, S, P> TokioAsyncWrite for DynamicRateLimitedWriter<W, S, P>
    where
        W: AsyncWrite,
        S: FusedStream<Item = RateLimitedWriterConfig>,
        P: SleepProvider,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<IoResult<usize>> {
            let mut adapter = self.compat_write();
            TokioAsyncWrite::poll_write(Pin::new(&mut adapter), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            let mut adapter = self.compat_write();
            TokioAsyncWrite::poll_flush(Pin::new(&mut adapter), cx)
        }

        /// Maps to the futures-io `poll_close` through the compatibility adapter.
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            let mut adapter = self.compat_write();
            TokioAsyncWrite::poll_shutdown(Pin::new(&mut adapter), cx)
        }
    }
}
136
// The only test here needs tokio, so gate the whole module on it. Gating just the test
// function left every import below unused (and warned about) in non-tokio test builds.
#[cfg(all(test, feature = "tokio"))]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use std::num::NonZero;
    use std::time::Duration;

    use futures::task::SpawnExt;
    use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt};
    use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};

    /// Alternate the rate between zero and non-zero values.
    ///
    /// Checks that bytes reach the reader only during intervals with a non-zero rate.
    #[test]
    fn alternating_on_off() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            // Advance mock time 1 ms at a time (8 s in total), letting every task run until
            // it stalls before each step.
            let time_driver = rt.clone();
            rt.spawn(async move {
                for _ in 0..8_000 {
                    time_driver.progress_until_stalled().await;
                    time_driver.advance_by(Duration::from_millis(1)).await;
                }
            })
            .unwrap();

            // Start fully throttled: zero rate and zero burst.
            let config = RateLimitedWriterConfig {
                rate: 0,
                burst: 0,
                wake_when_bytes_available: NonZero::new(10).unwrap(),
            };

            let (writer, reader) = tokio_crate::io::duplex(1000);
            let writer = writer.compat_write();
            let mut reader = reader.compat();

            let writer = RateLimitedWriter::new(writer, &config, rt.clone());

            let (mut rate_tx, rate_rx) = futures::channel::mpsc::unbounded();
            let mut writer = DynamicRateLimitedWriter::new(writer, rate_rx);

            const UPDATE_INTERVAL: Duration = Duration::from_millis(841);

            // After each UPDATE_INTERVAL, switch to the next rate (burst equal to the rate).
            let rate_updater = rt.clone();
            rt.spawn(async move {
                for rate in [100, 0, 200, 0, 400, 0] {
                    rate_updater.sleep(UPDATE_INTERVAL).await;

                    let update = RateLimitedWriterConfig {
                        rate,
                        burst: rate,
                        ..config.clone()
                    };
                    rate_tx.send(update).now_or_never().unwrap().unwrap();
                }
            })
            .unwrap();

            // Write as fast as the rate limiter allows until a write fails.
            rt.spawn(async move { while writer.write(&[0; 100]).await.is_ok() {} })
                .unwrap();

            let mut buf = vec![0; 1000];

            // Shift the reads 1 ms later, so that each read happens just after (not at the same
            // instant as) a rate update.
            rt.sleep(Duration::from_millis(1)).await;

            // Expected result of a non-blocking read at the end of each interval. The rate in
            // effect during the preceding interval is 0, 100, 0, 200, 0, 400, 0 respectively.
            let expected_reads = [None, Some(80), None, Some(160), None, Some(330), None];
            for (interval, expected) in expected_reads.into_iter().enumerate() {
                rt.sleep(UPDATE_INTERVAL).await;
                let got = reader.read(&mut buf).now_or_never().map(Result::unwrap);
                assert_eq!(expected, got, "unexpected read result in interval {interval}");
            }
        });
    }
}