tor_proto/congestion/
sendme.rs1use std::collections::VecDeque;
14
15use tor_cell::relaycell::RelayCmd;
16use tor_cell::relaycell::UnparsedRelayMsg;
17use tor_error::internal;
18
19use crate::{Error, Result};
20
21pub(crate) type CircSendWindow = SendWindow<CircParams>;
23pub(crate) type StreamSendWindow = SendWindow<StreamParams>;
25
26pub(crate) type CircRecvWindow = RecvWindow<CircParams>;
28pub(crate) type StreamRecvWindow = RecvWindow<StreamParams>;
30
31#[derive(Clone, Debug)]
37pub(crate) struct SendWindow<P>
38where
39 P: WindowParams,
40{
41 window: u16,
43 _dummy: std::marker::PhantomData<P>,
45}
46
/// Helper: parametrizes a flow-control window by its starting value, its
/// increment, and its largest allowable value.
pub(crate) trait WindowParams {
    /// Largest allowable value for this window.
    ///
    /// `SendWindow::put` rejects any SENDME that would push the window above
    /// this value.
    fn maximum() -> u16;
    /// Amount a window grows by on each `put`.
    ///
    /// Window values that are multiples of this are where SENDME tags are
    /// recorded (`SendWindow::should_record_tag`) and where `RecvWindow::take`
    /// reports `true`.
    fn increment() -> u16;
    /// Starting value for a window of this kind.
    ///
    /// NOTE(review): not referenced in this file; presumably callers pass it
    /// to `SendWindow::new` / `RecvWindow::new`.
    fn start() -> u16;
}
57
58#[derive(Clone, Debug)]
61pub(crate) struct CircParams;
62impl WindowParams for CircParams {
63 fn maximum() -> u16 {
64 1000
65 }
66 fn increment() -> u16 {
67 100
68 }
69 fn start() -> u16 {
70 1000
71 }
72}
73
74#[derive(Clone, Debug)]
77pub(crate) struct StreamParams;
78impl WindowParams for StreamParams {
79 fn maximum() -> u16 {
80 500
81 }
82 fn increment() -> u16 {
83 50
84 }
85 fn start() -> u16 {
86 500
87 }
88}
89
90#[derive(Clone, Debug)]
92pub(crate) struct SendmeValidator<T>
93where
94 T: PartialEq + Eq + Clone,
95{
96 tags: VecDeque<T>,
99}
100
101impl<T> SendmeValidator<T>
102where
103 T: PartialEq + Eq + Clone,
104{
105 pub(crate) fn new() -> Self {
107 Self {
108 tags: VecDeque::new(),
109 }
110 }
111
112 pub(crate) fn record<U>(&mut self, tag: &U)
114 where
115 U: Clone + Into<T>,
116 {
117 self.tags.push_back(tag.clone().into());
118 }
119
120 pub(crate) fn validate<U>(&mut self, tag: Option<U>) -> Result<()>
123 where
124 T: PartialEq<U>,
125 {
126 match (self.tags.front(), tag) {
127 (Some(t), Some(tag)) if t == &tag => {} (Some(_), None) => {} (Some(_), Some(_)) => {
130 return Err(Error::CircProto("Mismatched tag on circuit SENDME".into()));
131 }
132 (None, _) => {
133 return Err(Error::CircProto(
134 "Received a SENDME when none was expected".into(),
135 ));
136 }
137 }
138 self.tags.pop_front();
139 Ok(())
140 }
141
142 #[cfg(test)]
143 pub(crate) fn expected_tags(&self) -> Vec<T> {
144 self.tags.iter().map(Clone::clone).collect()
145 }
146}
147
148impl<P> SendWindow<P>
149where
150 P: WindowParams,
151{
152 pub(crate) fn new(window: u16) -> SendWindow<P> {
154 SendWindow {
155 window,
156 _dummy: std::marker::PhantomData,
157 }
158 }
159
160 pub(crate) fn should_record_tag(&self) -> bool {
162 self.window % P::increment() == 0
163 }
164
165 pub(crate) fn take(&mut self) -> Result<()> {
168 self.window = self.window.checked_sub(1).ok_or(Error::CircProto(
169 "Called SendWindow::take() on empty SendWindow".into(),
170 ))?;
171 Ok(())
172 }
173
174 #[must_use = "didn't check whether SENDME was expected."]
178 pub(crate) fn put(&mut self) -> Result<()> {
179 let new_window = self
181 .window
182 .checked_add(P::increment())
183 .ok_or(Error::from(internal!("Overflow on SENDME window")))?;
184 if new_window > P::maximum() {
186 return Err(Error::CircProto("Unexpected stream SENDME".into()));
187 }
188 self.window = new_window;
189 Ok(())
190 }
191
192 pub(crate) fn window(&self) -> u16 {
194 self.window
195 }
196}
197
198#[derive(Clone, Debug)]
200pub(crate) struct RecvWindow<P: WindowParams> {
201 window: u16,
204 _dummy: std::marker::PhantomData<P>,
206}
207
208impl<P: WindowParams> RecvWindow<P> {
209 pub(crate) fn new(window: u16) -> RecvWindow<P> {
211 RecvWindow {
212 window,
213 _dummy: std::marker::PhantomData,
214 }
215 }
216
217 pub(crate) fn take(&mut self) -> Result<bool> {
223 let v = self.window.checked_sub(1);
224 if let Some(x) = v {
225 self.window = x;
226 Ok(x % P::increment() == 0)
229 } else {
230 Err(Error::CircProto(
231 "Received a data cell in violation of a window".into(),
232 ))
233 }
234 }
235
236 pub(crate) fn decrement_n(&mut self, n: u16) -> crate::Result<()> {
238 self.window = self.window.checked_sub(n).ok_or(Error::CircProto(
239 "Received too many cells on a stream".into(),
240 ))?;
241 Ok(())
242 }
243
244 pub(crate) fn put(&mut self) {
246 self.window = self
247 .window
248 .checked_add(P::increment())
249 .expect("Overflow detected while attempting to increment window");
250 }
251}
252
253pub(crate) fn cmd_counts_towards_windows(cmd: RelayCmd) -> bool {
255 cmd == RelayCmd::DATA
256}
257
/// Test helper: return true if the parsed relay message `msg` counts towards
/// the send/receive windows, judged by its command alone.
#[cfg(test)]
pub(crate) fn msg_counts_towards_windows(msg: &tor_cell::relaycell::msg::AnyRelayMsg) -> bool {
    // Brings the `cmd()` method into scope.
    use tor_cell::relaycell::RelayMsg;

    let cmd = msg.cmd();
    cmd_counts_towards_windows(cmd)
}
264
265pub(crate) fn cell_counts_towards_windows(cell: &UnparsedRelayMsg) -> bool {
267 cmd_counts_towards_windows(cell.cmd())
268}
269
#[cfg(test)]
mod test {
    // Lint allowances relaxed for test code (unwraps, printing, etc.).
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;

    use tor_basic_utils::test_rng::testing_rng;
    use tor_cell::relaycell::{msg, AnyRelayMsgOuter, RelayCellFormat, StreamId};

    /// Only DATA counts towards windows. A BEGIN (does not count) and a DATA
    /// (counts) are each checked as a parsed message and again after being
    /// encoded into a V0 relay cell and re-read as an `UnparsedRelayMsg`.
    #[test]
    fn what_counts() {
        let mut rng = testing_rng();
        let fmt = RelayCellFormat::V0;
        let m = msg::Begin::new("www.torproject.org", 443, 0)
            .unwrap()
            .into();
        assert!(!msg_counts_towards_windows(&m));
        assert!(!cell_counts_towards_windows(
            &UnparsedRelayMsg::from_singleton_body(
                RelayCellFormat::V0,
                AnyRelayMsgOuter::new(StreamId::new(77), m)
                    .encode(fmt, &mut rng)
                    .unwrap()
            )
            .unwrap()
        ));

        let m = msg::Data::new(&b"Education is not a prerequisite to political control-political control is the cause of popular education."[..]).unwrap().into();
        assert!(msg_counts_towards_windows(&m));
        assert!(cell_counts_towards_windows(
            &UnparsedRelayMsg::from_singleton_body(
                RelayCellFormat::V0,
                AnyRelayMsgOuter::new(StreamId::new(128), m)
                    .encode(fmt, &mut rng)
                    .unwrap()
            )
            .unwrap()
        ));
    }

    /// Exercise a `RecvWindow<StreamParams>` (increment 50) starting at 500.
    #[test]
    fn recvwindow() {
        let mut w: RecvWindow<StreamParams> = RecvWindow::new(500);

        // Takes 1..=49 leave the window at 499..=451, none a multiple of 50.
        for _ in 0..49 {
            assert!(!w.take().unwrap());
        }
        // The 50th take reaches 450, a multiple of the increment.
        assert!(w.take().unwrap());
        assert_eq!(w.window, 450);

        assert!(w.decrement_n(123).is_ok());
        assert_eq!(w.window, 327);

        // `put` grows the window by the increment (50).
        w.put();
        assert_eq!(w.window, 377);

        // Decrementing past zero fails and leaves the window at 377...
        assert!(w.decrement_n(400).is_err());
        // ...so removing exactly 377 succeeds and empties it.
        assert!(w.decrement_n(377).is_ok());
        // An empty window rejects any further cell.
        assert!(w.take().is_err());
    }

    /// Helper: a circuit-parameter send window starting at 1000.
    fn new_sendwindow() -> SendWindow<CircParams> {
        SendWindow::new(1000)
    }

    /// Take cells and accept SENDMEs on a circuit window (increment 100,
    /// maximum 1000), checking the window value after each step.
    #[test]
    fn sendwindow_basic() -> Result<()> {
        let mut w = new_sendwindow();

        w.take()?;
        assert_eq!(w.window(), 999);
        for _ in 0_usize..98 {
            w.take()?;
        }
        assert_eq!(w.window(), 901);

        w.take()?;
        assert_eq!(w.window(), 900);

        w.take()?;
        assert_eq!(w.window(), 899);

        // A SENDME adds the increment: 899 + 100 = 999, still <= maximum.
        w.put()?;
        assert_eq!(w.window(), 999);

        for _ in 0_usize..300 {
            w.take()?;
        }

        // 999 - 300 = 699, plus one increment of 100.
        w.put()?;
        assert_eq!(w.window(), 799);

        Ok(())
    }

    /// Exhausting the whole window makes the next `take` fail.
    #[test]
    fn sendwindow_erroring() -> Result<()> {
        let mut w = new_sendwindow();
        for _ in 0_usize..1000 {
            w.take()?;
        }
        assert_eq!(w.window(), 0);

        let ready = w.take();
        assert!(ready.is_err());
        Ok(())
    }
}