tor_proto/util/token_bucket/
dynamic_writer.rs

1//! An [`AsyncWrite`] rate limiter which receives rate limit changes from a [`FusedStream`].
2
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures::io::Error;
7use futures::stream::FusedStream;
8use futures::AsyncWrite;
9use tor_rtcompat::SleepProvider;
10
11use super::writer::{RateLimitedWriter, RateLimitedWriterConfig};
12
13/// A rate-limited async [writer](AsyncWrite).
14///
15/// This wraps a [`RateLimitedWriter`] and watches a stream for configuration changes (such as rate
16/// limit changes).
17#[derive(educe::Educe)]
18#[educe(Debug)]
19#[pin_project::pin_project]
20pub(crate) struct DynamicRateLimitedWriter<W: AsyncWrite, S, P: SleepProvider> {
21    /// The rate-limited writer.
22    #[pin]
23    writer: RateLimitedWriter<W, P>,
24    /// A stream that provides configuration updates, including rate limit updates.
25    #[educe(Debug(ignore))]
26    #[pin]
27    updates: S,
28}
29
30impl<W, S, P> DynamicRateLimitedWriter<W, S, P>
31where
32    W: AsyncWrite,
33    P: SleepProvider,
34{
35    /// Create a new [`DynamicRateLimitedWriter`].
36    ///
37    /// This wraps the `writer` and watches for configuration changes from the `updates` stream.
38    pub(crate) fn new(writer: RateLimitedWriter<W, P>, updates: S) -> Self {
39        Self { writer, updates }
40    }
41
42    /// Access the inner [`AsyncWrite`] writer of the [`RateLimitedWriter`].
43    pub(crate) fn inner(&self) -> &W {
44        self.writer.inner()
45    }
46}
47
48impl<W, S, P> AsyncWrite for DynamicRateLimitedWriter<W, S, P>
49where
50    W: AsyncWrite,
51    S: FusedStream<Item = RateLimitedWriterConfig>,
52    P: SleepProvider,
53{
54    fn poll_write(
55        mut self: Pin<&mut Self>,
56        cx: &mut Context<'_>,
57        buf: &[u8],
58    ) -> Poll<Result<usize, Error>> {
59        let mut self_ = self.as_mut().project();
60
61        // Try getting any update to the rate limit and burst.
62        //
63        // We loop until we receive `Ready(None)` or `Pending`. The former indicates that we
64        // shouldn't receive any more updates. The latter indicates that there aren't currently more
65        // to read, and that we've registered the waker with the stream so that we'll wake when the
66        // rate limit is later updated.
67        //
68        // Since `S` is a `FusedStream`, it's fine to call `poll_next()` even if `Ready(None)` was
69        // returned in the past.
70        let mut iters = 0;
71        while let Poll::Ready(Some(config)) = self_.updates.as_mut().poll_next(cx) {
72            // update the writer's configuration
73            let now = self_.writer.sleep_provider().now();
74            self_.writer.adjust(now, &config);
75
76            // It's possible that `DynamicRateLimitedWriter` was constructed with a stream where an
77            // infinite number of items will be immediately ready, for example with
78            // `futures::stream::repeat()`. We escape the possible infinite loop by returning an
79            // error.
80            iters += 1;
81            if iters > 100_000 {
82                const MSG: &str =
83                    "possible infinite loop in `DynamicRateLimitedWriter::poll_write`";
84                tracing::debug!(MSG);
85                return Poll::Ready(Err(Error::other(MSG)));
86            }
87        }
88
89        // Try writing the bytes. This also registers the waker with the `RateLimitedWriter`.
90        self_.writer.poll_write(cx, buf)
91    }
92
93    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
94        self.project().writer.poll_flush(cx)
95    }
96
97    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
98        self.project().writer.poll_close(cx)
99    }
100}
101
/// Implementations of tokio's I/O traits, gathered in one module so the `cfg(feature = "tokio")`
/// gate only has to be written once.
#[cfg(feature = "tokio")]
mod tokio_impl {
    use super::*;

    use std::io::Result as IoResult;

    use tokio_crate::io::AsyncWrite as TokioAsyncWrite;

    /// Exposes the `futures` [`AsyncWrite`] implementation through tokio's
    /// [`AsyncWrite`](TokioAsyncWrite) trait.
    ///
    /// Every method forwards to its `futures` counterpart; tokio's `poll_shutdown` corresponds
    /// to `poll_close`.
    impl<W, S, P> TokioAsyncWrite for DynamicRateLimitedWriter<W, S, P>
    where
        W: AsyncWrite,
        S: FusedStream<Item = RateLimitedWriterConfig>,
        P: SleepProvider,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<IoResult<usize>> {
            <Self as AsyncWrite>::poll_write(self, cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            <Self as AsyncWrite>::poll_flush(self, cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            <Self as AsyncWrite>::poll_close(self, cx)
        }
    }
}
136
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use std::num::NonZero;
    use std::time::Duration;

    use futures::task::SpawnExt;
    use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt};

    #[cfg(feature = "tokio")]
    use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};

    /// This test ensures that a [`DynamicRateLimitedWriter`] writes the expected number of bytes
    /// while a background task periodically switches the rate limit between on and off.
    #[cfg(feature = "tokio")]
    #[test]
    fn alternating_on_off() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            // drive time forward from 0 to 8_000 ms in 1 ms intervals
            let rt_clone = rt.clone();
            rt.spawn(async move {
                for _ in 0..8_000 {
                    rt_clone.progress_until_stalled().await;
                    rt_clone.advance_by(Duration::from_millis(1)).await;
                }
            })
            .unwrap();

            // start with a rate limiter that doesn't allow any bytes
            let config = RateLimitedWriterConfig {
                rate: 0,
                burst: 0,
                // wake up the writer each time the rate limiter allows 10 bytes to be sent
                wake_when_bytes_available: NonZero::new(10).unwrap(),
            };

            // there are some other crates which allow you to make a data "pipe" without tokio, but
            // I don't think it's worth bringing in a new dev-dependency for this
            let (writer, reader) = tokio_crate::io::duplex(/* max_buf_size= */ 1000);
            let writer = writer.compat_write();
            let mut reader = reader.compat();

            let writer = RateLimitedWriter::new(writer, &config, rt.clone());

            // how we send rate updates to the rate-limited writer
            let (mut rate_tx, rate_rx) = futures::channel::mpsc::unbounded();

            // our rate-limited writer which can receive rate limit changes
            let mut writer = DynamicRateLimitedWriter::new(writer, rate_rx);

            /// Duration between updates. A prime number is used so that smaller intervals don't
            /// coincide with this interval. Otherwise the outcome can depend on the order in which
            /// `MockRuntime::test_with_various` wakes tasks.
            const UPDATE_INTERVAL: Duration = Duration::from_millis(841);

            // a background task which sends alternating on/off rate limits every `UPDATE_INTERVAL`
            let rt_clone = rt.clone();
            rt.spawn(async move {
                for rate in [100, 0, 200, 0, 400, 0] {
                    rt_clone.sleep(UPDATE_INTERVAL).await;

                    // update the rate/burst
                    let mut config = config.clone();
                    config.rate = rate;
                    config.burst = rate;

                    // we expect the send() to succeed immediately
                    rate_tx.send(config).now_or_never().unwrap().unwrap();
                }
            })
            .unwrap();

            // a background task which writes as much as possible
            rt.spawn(async move {
                // write until the receiving end goes away
                while writer.write(&[0; 100]).await.is_ok() {}
            })
            .unwrap();

            // helper to make the `assert_eq` a single line
            let res_unwrap = Result::unwrap;

            let mut buf = vec![0; 1000];
            let buf = &mut buf;

            // sleep for 1 ms so that our upcoming sleeps end 1 ms after the rate limit changes
            rt.sleep(Duration::from_millis(1)).await;

            // Rate is 0, so no bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(None, reader.read(buf).now_or_never().map(res_unwrap));

            // Rate is 100 bytes/s, so 841/(1000/100) = 84 bytes expected.
            // Woken every `wake_when_bytes_available` = 10 bytes, so 80 bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(Some(80), reader.read(buf).now_or_never().map(res_unwrap));

            // Rate is 0, so no bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(None, reader.read(buf).now_or_never().map(res_unwrap));

            // Rate is 200 bytes/s, so 841/(1000/200) = 168 bytes expected.
            // Woken every `wake_when_bytes_available` = 10 bytes, so 160 bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(Some(160), reader.read(buf).now_or_never().map(res_unwrap));

            // Rate is 0, so no bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(None, reader.read(buf).now_or_never().map(res_unwrap));

            // Rate is 400 bytes/s, so 841/(1000/400) = 336 bytes expected.
            // Woken every `wake_when_bytes_available` = 10 bytes, so 330 bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(Some(330), reader.read(buf).now_or_never().map(res_unwrap));

            // Rate is 0, so no bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(None, reader.read(buf).now_or_never().map(res_unwrap));
        });
    }
}