tor_proto/stream/
queue.rs1use std::fmt::Debug;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex};
16use std::task::{Context, Poll};
17
18use futures::{Sink, SinkExt, Stream};
19use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;
20use tor_async_utils::stream_peek::StreamUnobtrusivePeeker;
21use tor_async_utils::SinkTrySend;
22use tor_cell::relaycell::UnparsedRelayMsg;
23use tor_memquota::mq_queue::{self, ChannelSpec, MpscSpec};
24use tor_rtcompat::DynTimeProvider;
25
26use crate::memquota::{SpecificAccount, StreamAccount};
27
28pub(crate) fn stream_queue(
30 size: usize,
31 memquota: &StreamAccount,
32 time_prov: &DynTimeProvider,
33) -> Result<(StreamQueueSender, StreamQueueReceiver), tor_memquota::Error> {
34 let (sender, receiver) =
35 MpscSpec::new(size).new_mq(time_prov.clone(), memquota.as_raw_account())?;
36 let receiver = StreamUnobtrusivePeeker::new(receiver);
37 let counter = Arc::new(Mutex::new(0));
38 Ok((
39 StreamQueueSender {
40 sender,
41 counter: Arc::clone(&counter),
42 },
43 StreamQueueReceiver { receiver, counter },
44 ))
45}
46
#[cfg(test)]
/// Test helper: build a stream queue of capacity `size` without needing a real account or runtime.
///
/// Uses a no-op [`StreamAccount`] and a time provider wrapping a freshly created
/// `MockRuntime`. That mock runtime is independent of any runtime the calling test uses,
/// so do not rely on it sharing mocked time with the test.
///
/// # Panics
///
/// Panics if the queue cannot be created.
pub(crate) fn fake_stream_queue(size: usize) -> (StreamQueueSender, StreamQueueReceiver) {
    let account = StreamAccount::new_noop();
    let clock = DynTimeProvider::new(tor_rtmock::MockRuntime::default());

    stream_queue(size, &account, &clock).expect("create fake stream queue")
}
63
64#[derive(Debug)]
66#[pin_project::pin_project]
67pub(crate) struct StreamQueueSender {
68 #[pin]
70 sender: mq_queue::Sender<UnparsedRelayMsg, MpscSpec>,
71 counter: Arc<Mutex<usize>>,
73}
74
75#[derive(Debug)]
77#[pin_project::pin_project]
78pub(crate) struct StreamQueueReceiver {
79 #[pin]
86 receiver: StreamUnobtrusivePeeker<mq_queue::Receiver<UnparsedRelayMsg, MpscSpec>>,
87 counter: Arc<Mutex<usize>>,
89}
90
91impl StreamQueueSender {
92 pub(crate) fn approx_stream_bytes(&self) -> usize {
97 *self.counter.lock().expect("poisoned")
98 }
99}
100
101impl StreamQueueReceiver {
102 pub(crate) fn approx_stream_bytes(&self) -> usize {
107 *self.counter.lock().expect("poisoned")
108 }
109}
110
111impl Sink<UnparsedRelayMsg> for StreamQueueSender {
112 type Error = <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as Sink<UnparsedRelayMsg>>::Error;
113
114 fn poll_ready(
115 mut self: Pin<&mut Self>,
116 cx: &mut Context<'_>,
117 ) -> Poll<std::result::Result<(), Self::Error>> {
118 self.sender.poll_ready_unpin(cx)
119 }
120
121 fn start_send(
122 mut self: Pin<&mut Self>,
123 item: UnparsedRelayMsg,
124 ) -> std::result::Result<(), Self::Error> {
125 let mut self_ = self.as_mut().project();
126
127 let stream_data_len = data_len(&item);
128
129 let mut counter = self_.counter.lock().expect("poisoned");
133
134 self_.sender.start_send_unpin(item)?;
135
136 *counter = counter
137 .checked_add(stream_data_len.into())
138 .expect("queue has more than `usize::MAX` bytes?!");
139
140 Ok(())
141 }
142
143 fn poll_flush(
144 mut self: Pin<&mut Self>,
145 cx: &mut Context<'_>,
146 ) -> Poll<std::result::Result<(), Self::Error>> {
147 self.sender.poll_flush_unpin(cx)
148 }
149
150 fn poll_close(
151 mut self: Pin<&mut Self>,
152 cx: &mut Context<'_>,
153 ) -> Poll<std::result::Result<(), Self::Error>> {
154 self.sender.poll_close_unpin(cx)
155 }
156}
157
158impl SinkTrySend<UnparsedRelayMsg> for StreamQueueSender {
159 type Error =
160 <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as SinkTrySend<UnparsedRelayMsg>>::Error;
161
162 fn try_send_or_return(
163 mut self: Pin<&mut Self>,
164 item: UnparsedRelayMsg,
165 ) -> Result<
166 (),
167 (
168 <Self as SinkTrySend<UnparsedRelayMsg>>::Error,
169 UnparsedRelayMsg,
170 ),
171 > {
172 let self_ = self.as_mut().project();
173
174 let stream_data_len = data_len(&item);
175
176 let mut counter = self_.counter.lock().expect("poisoned");
178
179 self_.sender.try_send_or_return(item)?;
180
181 *counter = counter
182 .checked_add(stream_data_len.into())
183 .expect("queue has more than `usize::MAX` bytes?!");
184
185 Ok(())
186 }
187}
188
189impl Stream for StreamQueueReceiver {
190 type Item = UnparsedRelayMsg;
191
192 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
193 let self_ = self.as_mut().project();
194
195 let mut counter = self_.counter.lock().expect("poisoned");
199
200 let item = match self_.receiver.poll_next(cx) {
201 Poll::Ready(Some(x)) => x,
202 Poll::Ready(None) => return Poll::Ready(None),
203 Poll::Pending => return Poll::Pending,
204 };
205
206 let stream_data_len = data_len(&item);
207
208 if stream_data_len != 0 {
209 *counter = counter
210 .checked_sub(stream_data_len.into())
211 .expect("we've removed more bytes than we've added?!");
212 }
213
214 Poll::Ready(Some(item))
215 }
216}
217
218impl UnobtrusivePeekableStream for StreamQueueReceiver {
219 fn unobtrusive_peek_mut<'s>(
220 self: Pin<&'s mut Self>,
221 ) -> Option<&'s mut <Self as futures::Stream>::Item> {
222 self.project().receiver.unobtrusive_peek_mut()
223 }
224}
225
226fn data_len(item: &UnparsedRelayMsg) -> u16 {
235 item.data_len().unwrap_or(0)
236}