tor_proto/stream/
queue.rs

//! Queues for stream messages.
//!
//! While these are technically "channels", we call them "queues" to indicate that they're mostly
//! just dumb pipes. They do some tracking (memquota and size), but nothing else. The higher-level
//! object is [`StreamReceiver`](crate::stream::raw::StreamReceiver), which tracks SENDME and END
//! messages. So the idea is that the "queue" (e.g. [`StreamQueueReceiver`]) just holds data, and
//! the "channel" (e.g. `StreamReceiver`) adds the Tor logic.
//!
//! The main purpose of these types is to let us count how many bytes of stream data are
//! stored for the stream. Ideally we'd use a channel type that tracks and reports this as part of
//! its implementation, but popular channel implementations don't seem to do that.

13use std::fmt::Debug;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex};
16use std::task::{Context, Poll};
17
18use futures::{Sink, SinkExt, Stream};
19use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;
20use tor_async_utils::stream_peek::StreamUnobtrusivePeeker;
21use tor_async_utils::SinkTrySend;
22use tor_cell::relaycell::UnparsedRelayMsg;
23use tor_memquota::mq_queue::{self, ChannelSpec, MpscSpec};
24use tor_rtcompat::DynTimeProvider;
25
26use crate::memquota::{SpecificAccount, StreamAccount};
27
28/// Create a new stream queue for incoming messages.
29pub(crate) fn stream_queue(
30    size: usize,
31    memquota: &StreamAccount,
32    time_prov: &DynTimeProvider,
33) -> Result<(StreamQueueSender, StreamQueueReceiver), tor_memquota::Error> {
34    let (sender, receiver) =
35        MpscSpec::new(size).new_mq(time_prov.clone(), memquota.as_raw_account())?;
36    let receiver = StreamUnobtrusivePeeker::new(receiver);
37    let counter = Arc::new(Mutex::new(0));
38    Ok((
39        StreamQueueSender {
40            sender,
41            counter: Arc::clone(&counter),
42        },
43        StreamQueueReceiver { receiver, counter },
44    ))
45}
46
/// For testing purposes, create a stream queue with a no-op memquota account and a fake time
/// provider.
///
/// # Panics
///
/// Panics if the queue cannot be created.
#[cfg(test)]
pub(crate) fn fake_stream_queue(size: usize) -> (StreamQueueSender, StreamQueueReceiver) {
    // The no-op account doesn't care about data ages, so any time provider will do.
    //
    // This would be the wrong approach for tests in general, where we might want to mock time:
    // here we end up with a completely *separate* mocked clock. But it's fine for this purpose,
    // and it saves passing a runtime parameter into this function.
    let account = StreamAccount::new_noop();
    let time_prov = DynTimeProvider::new(tor_rtmock::MockRuntime::default());

    stream_queue(size, &account, &time_prov).expect("create fake stream queue")
}
63
/// The sending end of a channel of incoming stream messages.
///
/// Wraps a memquota-tracked mpsc sender. Every message sent through this type adds its
/// stream-data length to a byte counter shared with the matching [`StreamQueueReceiver`].
#[derive(Debug)]
#[pin_project::pin_project]
pub(crate) struct StreamQueueSender {
    /// The inner sender.
    #[pin]
    sender: mq_queue::Sender<UnparsedRelayMsg, MpscSpec>,
    /// Number of bytes within the queue.
    ///
    /// Shared with the receiver. Only stream data bytes (as reported by `data_len`) are counted.
    /// The mutex is held across each send and its counter update, so that the queue contents
    /// and the count change together.
    counter: Arc<Mutex<usize>>,
}
74
/// The receiving end of a channel of incoming stream messages.
///
/// Every message yielded by this type subtracts its stream-data length from a byte counter
/// shared with the matching [`StreamQueueSender`].
#[derive(Debug)]
#[pin_project::pin_project]
pub(crate) struct StreamQueueReceiver {
    /// The inner receiver.
    ///
    /// We add the [`StreamUnobtrusivePeeker`] here so that peeked messages are included in
    /// `counter`.
    // TODO(arti#534): the possible extra msg held by the `StreamUnobtrusivePeeker` isn't tracked by
    // memquota
    #[pin]
    receiver: StreamUnobtrusivePeeker<mq_queue::Receiver<UnparsedRelayMsg, MpscSpec>>,
    /// Number of bytes within the queue.
    ///
    /// Shared with the sender. It is only decreased when a message is returned from `poll_next`,
    /// so a message that has been peeked but not yet received is still counted.
    counter: Arc<Mutex<usize>>,
}
90
91impl StreamQueueSender {
92    /// Get the approximate number of data bytes queued for this stream.
93    ///
94    /// As messages can be dequeued at any time, the return value may be larger than the actual
95    /// number of bytes queued for this stream.
96    pub(crate) fn approx_stream_bytes(&self) -> usize {
97        *self.counter.lock().expect("poisoned")
98    }
99}
100
101impl StreamQueueReceiver {
102    /// Get the approximate number of data bytes queued for this stream.
103    ///
104    /// As messages can be enqueued at any time, the return value may be smaller than the actual
105    /// number of bytes queued for this stream.
106    pub(crate) fn approx_stream_bytes(&self) -> usize {
107        *self.counter.lock().expect("poisoned")
108    }
109}
110
111impl Sink<UnparsedRelayMsg> for StreamQueueSender {
112    type Error = <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as Sink<UnparsedRelayMsg>>::Error;
113
114    fn poll_ready(
115        mut self: Pin<&mut Self>,
116        cx: &mut Context<'_>,
117    ) -> Poll<std::result::Result<(), Self::Error>> {
118        self.sender.poll_ready_unpin(cx)
119    }
120
121    fn start_send(
122        mut self: Pin<&mut Self>,
123        item: UnparsedRelayMsg,
124    ) -> std::result::Result<(), Self::Error> {
125        let mut self_ = self.as_mut().project();
126
127        let stream_data_len = data_len(&item);
128
129        // This lock ensures that us sending the item and the counter increase are done
130        // "atomically", so that the receiver doesn't see the item and try to decrement the
131        // counter before we've incremented the counter, which could cause an underflow.
132        let mut counter = self_.counter.lock().expect("poisoned");
133
134        self_.sender.start_send_unpin(item)?;
135
136        *counter = counter
137            .checked_add(stream_data_len.into())
138            .expect("queue has more than `usize::MAX` bytes?!");
139
140        Ok(())
141    }
142
143    fn poll_flush(
144        mut self: Pin<&mut Self>,
145        cx: &mut Context<'_>,
146    ) -> Poll<std::result::Result<(), Self::Error>> {
147        self.sender.poll_flush_unpin(cx)
148    }
149
150    fn poll_close(
151        mut self: Pin<&mut Self>,
152        cx: &mut Context<'_>,
153    ) -> Poll<std::result::Result<(), Self::Error>> {
154        self.sender.poll_close_unpin(cx)
155    }
156}
157
158impl SinkTrySend<UnparsedRelayMsg> for StreamQueueSender {
159    type Error =
160        <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as SinkTrySend<UnparsedRelayMsg>>::Error;
161
162    fn try_send_or_return(
163        mut self: Pin<&mut Self>,
164        item: UnparsedRelayMsg,
165    ) -> Result<
166        (),
167        (
168            <Self as SinkTrySend<UnparsedRelayMsg>>::Error,
169            UnparsedRelayMsg,
170        ),
171    > {
172        let self_ = self.as_mut().project();
173
174        let stream_data_len = data_len(&item);
175
176        // See comments in `StreamQueueSender::start_send`.
177        let mut counter = self_.counter.lock().expect("poisoned");
178
179        self_.sender.try_send_or_return(item)?;
180
181        *counter = counter
182            .checked_add(stream_data_len.into())
183            .expect("queue has more than `usize::MAX` bytes?!");
184
185        Ok(())
186    }
187}
188
189impl Stream for StreamQueueReceiver {
190    type Item = UnparsedRelayMsg;
191
192    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
193        let self_ = self.as_mut().project();
194
195        // This lock ensures that us receiving the item and the counter decrease are done
196        // "atomically", so that the sender doesn't send a new item and try to increase the
197        // counter before we've decreased the counter, which could cause an overflow.
198        let mut counter = self_.counter.lock().expect("poisoned");
199
200        let item = match self_.receiver.poll_next(cx) {
201            Poll::Ready(Some(x)) => x,
202            Poll::Ready(None) => return Poll::Ready(None),
203            Poll::Pending => return Poll::Pending,
204        };
205
206        let stream_data_len = data_len(&item);
207
208        if stream_data_len != 0 {
209            *counter = counter
210                .checked_sub(stream_data_len.into())
211                .expect("we've removed more bytes than we've added?!");
212        }
213
214        Poll::Ready(Some(item))
215    }
216}
217
218impl UnobtrusivePeekableStream for StreamQueueReceiver {
219    fn unobtrusive_peek_mut<'s>(
220        self: Pin<&'s mut Self>,
221    ) -> Option<&'s mut <Self as futures::Stream>::Item> {
222        self.project().receiver.unobtrusive_peek_mut()
223    }
224}
225
226/// The `length` field of the message, or 0 if not a data message.
227///
228/// If the RELAY_DATA message had an invalid length field, we just ignore the message.
229/// The receiver will find out eventually when it tries to parse the message.
230/// We could return an error here, but for now I think it's best not to behave as if this
231/// queue is performing any validation.
232///
233/// This is its own function so that all parts of the code use the same logic.
234fn data_len(item: &UnparsedRelayMsg) -> u16 {
235    item.data_len().unwrap_or(0)
236}