tor_proto/channel/
circmap.rs1use crate::{Error, Result};
7use tor_basic_utils::RngExt;
8use tor_cell::chancell::CircId;
9
10use crate::tunnel::circuit::halfcirc::HalfCirc;
11use crate::tunnel::circuit::{celltypes::CreateResponse, CircuitRxSender};
12
13use oneshot_fused_workaround as oneshot;
14
15use rand::distr::Distribution;
16use rand::Rng;
17use std::collections::{hash_map::Entry, HashMap};
18use std::ops::{Deref, DerefMut};
19
20#[derive(Copy, Clone)]
25pub(super) enum CircIdRange {
26 #[allow(dead_code)] Low,
29 High,
31 }
35
36impl rand::distr::Distribution<CircId> for CircIdRange {
37 fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> CircId {
39 let midpoint = 0x8000_0000_u32;
40 let v = match self {
41 CircIdRange::Low => rng.gen_range_checked(1..midpoint),
43 CircIdRange::High => rng.gen_range_checked(midpoint..=u32::MAX),
44 };
45 let v = v.expect("Unexpected empty range passed to gen_range_checked");
46 CircId::new(v).expect("Unexpected zero value")
47 }
48}
49
50#[derive(Debug)]
54pub(super) enum CircEnt {
55 Opening(oneshot::Sender<CreateResponse>, CircuitRxSender),
64
65 Open(CircuitRxSender),
67
68 DestroySent(HalfCirc),
71}
72
73pub(super) struct MutCircEnt<'a> {
79 value: &'a mut CircEnt,
81 open_count: &'a mut usize,
84 was_open: bool,
86}
87
88impl<'a> Drop for MutCircEnt<'a> {
89 fn drop(&mut self) {
90 let is_open = !matches!(self.value, CircEnt::DestroySent(_));
91 match (self.was_open, is_open) {
92 (false, true) => *self.open_count = self.open_count.saturating_add(1),
93 (true, false) => *self.open_count = self.open_count.saturating_sub(1),
94 (_, _) => (),
95 };
96 }
97}
98
99impl<'a> Deref for MutCircEnt<'a> {
100 type Target = CircEnt;
101 fn deref(&self) -> &Self::Target {
102 self.value
103 }
104}
105
106impl<'a> DerefMut for MutCircEnt<'a> {
107 fn deref_mut(&mut self) -> &mut Self::Target {
108 self.value
109 }
110}
111
112pub(super) struct CircMap {
114 m: HashMap<CircId, CircEnt>,
116 range: CircIdRange,
118 open_count: usize,
120}
121
122impl CircMap {
123 pub(super) fn new(idrange: CircIdRange) -> Self {
125 CircMap {
126 m: HashMap::new(),
127 range: idrange,
128 open_count: 0,
129 }
130 }
131
132 pub(super) fn add_ent<R: Rng>(
137 &mut self,
138 rng: &mut R,
139 createdsink: oneshot::Sender<CreateResponse>,
140 sink: CircuitRxSender,
141 ) -> Result<CircId> {
142 const N_ATTEMPTS: usize = 16;
147 let iter = self.range.sample_iter(rng).take(N_ATTEMPTS);
148 let circ_ent = CircEnt::Opening(createdsink, sink);
149 for id in iter {
150 let ent = self.m.entry(id);
151 if let Entry::Vacant(_) = &ent {
152 ent.or_insert(circ_ent);
153 self.open_count += 1;
154 return Ok(id);
155 }
156 }
157 Err(Error::IdRangeFull)
158 }
159
160 #[cfg(test)]
163 pub(super) fn put_unchecked(&mut self, id: CircId, ent: CircEnt) {
164 self.m.insert(id, ent);
165 }
166
167 pub(super) fn get_mut(&mut self, id: CircId) -> Option<MutCircEnt> {
169 let open_count = &mut self.open_count;
170 self.m.get_mut(&id).map(move |ent| MutCircEnt {
171 open_count,
172 was_open: !matches!(ent, CircEnt::DestroySent(_)),
173 value: ent,
174 })
175 }
176
177 pub(super) fn advance_from_opening(
180 &mut self,
181 id: CircId,
182 ) -> Result<oneshot::Sender<CreateResponse>> {
183 let ok = matches!(self.m.get(&id), Some(CircEnt::Opening(_, _)));
188 if ok {
189 if let Some(CircEnt::Opening(oneshot, sink)) = self.m.remove(&id) {
190 self.m.insert(id, CircEnt::Open(sink));
191 Ok(oneshot)
192 } else {
193 panic!("internal error: inconsistent circuit state");
194 }
195 } else {
196 Err(Error::ChanProto(
197 "Unexpected CREATED* cell not on opening circuit".into(),
198 ))
199 }
200 }
201
202 pub(super) fn destroy_sent(&mut self, id: CircId, hs: HalfCirc) {
206 if let Some(replaced) = self.m.insert(id, CircEnt::DestroySent(hs)) {
207 if !matches!(replaced, CircEnt::DestroySent(_)) {
208 self.open_count = self.open_count.saturating_sub(1);
210 }
211 }
212 }
213
214 pub(super) fn remove(&mut self, id: CircId) -> Option<CircEnt> {
216 self.m.remove(&id).map(|removed| {
217 if !matches!(removed, CircEnt::DestroySent(_)) {
218 self.open_count = self.open_count.saturating_sub(1);
219 }
220 removed
221 })
222 }
223
224 pub(super) fn open_ent_count(&self) -> usize {
226 self.open_count
227 }
228
229 }
232
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::fake_mpsc;
    use tor_basic_utils::test_rng::testing_rng;

    #[test]
    fn circmap_basics() {
        let mut map_low = CircMap::new(CircIdRange::Low);
        let mut map_high = CircMap::new(CircIdRange::High);
        let mut ids_low: Vec<CircId> = Vec::new();
        let mut ids_high: Vec<CircId> = Vec::new();
        let mut rng = testing_rng();

        assert!(map_low.get_mut(CircId::new(77).unwrap()).is_none());

        for _ in 0..128 {
            let (csnd, _) = oneshot::channel();
            let (snd, _) = fake_mpsc(8);
            let id_low = map_low.add_ent(&mut rng, csnd, snd).unwrap();
            assert!(u32::from(id_low) > 0);
            assert!(u32::from(id_low) < 0x80000000);
            assert!(!ids_low.contains(&id_low));
            ids_low.push(id_low);

            assert!(matches!(
                *map_low.get_mut(id_low).unwrap(),
                CircEnt::Opening(_, _)
            ));

            let (csnd, _) = oneshot::channel();
            let (snd, _) = fake_mpsc(8);
            let id_high = map_high.add_ent(&mut rng, csnd, snd).unwrap();
            assert!(u32::from(id_high) >= 0x80000000);
            assert!(!ids_high.contains(&id_high));
            ids_high.push(id_high);
        }

        // Every new entry starts out open.
        assert_eq!(128, map_low.open_ent_count());
        assert_eq!(128, map_high.open_ent_count());

        // Removing an open entry drops the count.
        assert!(map_low.get_mut(ids_low[0]).is_some());
        map_low.remove(ids_low[0]);
        assert!(map_low.get_mut(ids_low[0]).is_none());
        assert_eq!(127, map_low.open_ent_count());

        // Adding a DestroySent entry for a fresh ID leaves the count alone.
        map_low.destroy_sent(CircId::new(256).unwrap(), HalfCirc::new(1));
        assert_eq!(127, map_low.open_ent_count());

        // Opening -> Open succeeds once.
        assert!(map_high.get_mut(ids_high[0]).is_some());
        assert!(matches!(
            *map_high.get_mut(ids_high[0]).unwrap(),
            CircEnt::Opening(_, _)
        ));
        let adv = map_high.advance_from_opening(ids_high[0]);
        assert!(adv.is_ok());
        assert!(matches!(
            *map_high.get_mut(ids_high[0]).unwrap(),
            CircEnt::Open(_)
        ));

        // A second advance on an already-open circuit is an error.
        let adv = map_high.advance_from_opening(ids_high[0]);
        assert!(adv.is_err());

        // So is advancing a circuit that does not exist.
        let adv = map_high.advance_from_opening(CircId::new(77).unwrap());
        assert!(adv.is_err());
    }

    #[test]
    fn open_count_tracking() {
        let mut map = CircMap::new(CircIdRange::High);
        let mut rng = testing_rng();

        let (csnd, _) = oneshot::channel();
        let (snd, _) = fake_mpsc(8);
        let id = map.add_ent(&mut rng, csnd, snd).unwrap();
        assert_eq!(1, map.open_ent_count());

        // Replacing an open entry with DestroySent closes it.
        map.destroy_sent(id, HalfCirc::new(1));
        assert_eq!(0, map.open_ent_count());
        assert!(matches!(
            *map.get_mut(id).unwrap(),
            CircEnt::DestroySent(_)
        ));

        // A CREATED* on a half-closed circuit is rejected and leaves it alone.
        assert!(map.advance_from_opening(id).is_err());
        assert!(matches!(
            *map.get_mut(id).unwrap(),
            CircEnt::DestroySent(_)
        ));
        assert_eq!(0, map.open_ent_count());

        // Changing state through a MutCircEnt updates the count on drop.
        {
            let (snd, _) = fake_mpsc(8);
            let mut ent = map.get_mut(id).unwrap();
            *ent = CircEnt::Open(snd);
        }
        assert_eq!(1, map.open_ent_count());
        {
            let mut ent = map.get_mut(id).unwrap();
            *ent = CircEnt::DestroySent(HalfCirc::new(1));
        }
        assert_eq!(0, map.open_ent_count());

        // Removing a closed entry does not touch the count.
        assert!(map.remove(id).is_some());
        assert_eq!(0, map.open_ent_count());
        assert!(map.get_mut(id).is_none());
    }
}