tor_circmgr/timeouts/
estimator.rs1use crate::timeouts::{
4 pareto::{ParetoTimeoutEstimator, ParetoTimeoutState},
5 readonly::ReadonlyTimeoutEstimator,
6 Action, TimeoutEstimator,
7};
8use crate::TimeoutStateHandle;
9use std::sync::Mutex;
10use std::time::Duration;
11use tor_error::warn_report;
12use tor_netdir::params::NetParameters;
13use tracing::{debug, warn};
14
15pub(crate) struct Estimator {
18 inner: Mutex<Box<dyn TimeoutEstimator + Send + 'static>>,
20}
21
22impl Estimator {
23 #[cfg(test)]
25 pub(crate) fn new(est: impl TimeoutEstimator + Send + 'static) -> Self {
26 Self {
27 inner: Mutex::new(Box::new(est)),
28 }
29 }
30
31 pub(crate) fn from_storage(storage: &TimeoutStateHandle) -> Self {
34 let (_, est) = estimator_from_storage(storage);
35 Self {
36 inner: Mutex::new(est),
37 }
38 }
39
40 pub(crate) fn upgrade_to_owning_storage(&self, storage: &TimeoutStateHandle) {
43 let (readonly, est) = estimator_from_storage(storage);
44 if readonly {
45 warn!("Unable to upgrade to owned persistent storage.");
46 return;
47 }
48 *self.inner.lock().expect("Timeout estimator lock poisoned") = est;
49 }
50
51 pub(crate) fn reload_readonly_from_storage(&self, storage: &TimeoutStateHandle) {
54 if let Ok(Some(v)) = storage.load() {
55 let est = ReadonlyTimeoutEstimator::from_state(&v);
56 *self.inner.lock().expect("Timeout estimator lock poisoned") = Box::new(est);
57 } else {
58 debug!("Unable to reload timeout state.");
59 }
60 }
61
62 pub(crate) fn note_hop_completed(&self, hop: u8, delay: Duration, is_last: bool) {
71 let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
72
73 inner.note_hop_completed(hop, delay, is_last);
74 }
75
76 pub(crate) fn note_circ_timeout(&self, hop: u8, delay: Duration) {
84 let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
85 inner.note_circ_timeout(hop, delay);
86 }
87
88 pub(crate) fn timeouts(&self, action: &Action) -> (Duration, Duration) {
98 let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
99
100 inner.timeouts(action)
101 }
102
103 pub(crate) fn learning_timeouts(&self) -> bool {
106 let inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
107 inner.learning_timeouts()
108 }
109
110 pub(crate) fn update_params(&self, params: &NetParameters) {
113 let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
114 inner.update_params(params);
115 }
116
117 pub(crate) fn save_state(&self, storage: &TimeoutStateHandle) -> crate::Result<()> {
119 let state = {
120 let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
121 inner.build_state()
122 };
123 if let Some(state) = state {
124 storage.store(&state)?;
125 }
126 Ok(())
127 }
128}
129
130fn estimator_from_storage(
135 storage: &TimeoutStateHandle,
136) -> (bool, Box<dyn TimeoutEstimator + Send + 'static>) {
137 let state = match storage.load() {
138 Ok(Some(v)) => v,
139 Ok(None) => ParetoTimeoutState::default(),
140 Err(e) => {
141 warn_report!(e, "Unable to load timeout state");
142 return (true, Box::new(ReadonlyTimeoutEstimator::new()));
143 }
144 };
145
146 if storage.can_store() {
147 (false, Box::new(ParetoTimeoutEstimator::from_state(state)))
149 } else {
150 (true, Box::new(ReadonlyTimeoutEstimator::from_state(&state)))
151 }
152}
153
#[cfg(test)]
mod test {
    // Lints relaxed for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use tor_persist::StateMgr;

    /// Exercise the estimator's storage lifecycle: load with and without the
    /// storage lock, save, reload read-only, then upgrade to owning storage.
    #[test]
    fn load_estimator() {
        let params = NetParameters::default();

        // First estimator: its state manager holds the lock.
        let storage = tor_persist::TestingStateMgr::new();
        assert!(storage.try_lock().unwrap().held());
        let handle = storage.clone().create_handle("paretorama");

        let est = Estimator::from_storage(&handle);
        assert!(est.learning_timeouts());
        est.save_state(&handle).unwrap();

        // Second estimator: its state manager cannot get the lock while the
        // first one holds it.
        let storage2 = storage.new_manager();
        assert!(!storage2.try_lock().unwrap().held());
        let handle2 = storage2.clone().create_handle("paretorama");

        let est2 = Estimator::from_storage(&handle2);
        assert!(!est2.learning_timeouts());

        est.update_params(&params);
        est2.update_params(&params);

        // Before any observations, both estimators report the same timeouts.
        let act = Action::BuildCircuit { length: 3 };
        assert_eq!(
            est.timeouts(&act),
            (Duration::from_secs(60), Duration::from_secs(60))
        );
        assert_eq!(
            est2.timeouts(&act),
            (Duration::from_secs(60), Duration::from_secs(60))
        );

        // Feed observations to both estimators.
        for _ in 0..500 {
            est.note_hop_completed(2, Duration::from_secs(7), true);
            est.note_hop_completed(2, Duration::from_secs(2), true);
            est2.note_hop_completed(2, Duration::from_secs(4), true);
        }
        assert!(!est.learning_timeouts());

        // The first estimator's timeouts have moved; the second's have not.
        est.save_state(&handle).unwrap();
        let to_1 = est.timeouts(&act);
        assert_ne!(
            est.timeouts(&act),
            (Duration::from_secs(60), Duration::from_secs(60))
        );
        assert_eq!(
            est2.timeouts(&act),
            (Duration::from_secs(60), Duration::from_secs(60))
        );

        // After reloading, the second estimator reports the first
        // estimator's saved timeout for both values.
        est2.reload_readonly_from_storage(&handle2);
        let to_1_secs = to_1.0.as_secs_f64();
        let timeouts = est2.timeouts(&act);
        assert!((timeouts.0.as_secs_f64() - to_1_secs).abs() < 0.001);
        assert!((timeouts.1.as_secs_f64() - to_1_secs).abs() < 0.001);

        // Drop the first estimator and its storage so the second manager can
        // take the lock.
        drop(est);
        drop(handle);
        drop(storage);

        assert!(storage2.try_lock().unwrap().held());
        est2.upgrade_to_owning_storage(&handle2);

        // The upgraded estimator's timeout is within a second of the one
        // that was saved.
        let to_2 = est2.timeouts(&act);
        assert!(to_2.0 > to_1.0 - Duration::from_secs(1));
        assert!(to_2.0 < to_1.0 + Duration::from_secs(1));

        // After the upgrade, new observations change the reported timeout.
        for _ in 0..200 {
            est2.note_hop_completed(2, Duration::from_secs(1), true);
        }
        let to_3 = est2.timeouts(&act);
        assert!(to_3.0 < to_2.0);
    }
}