cuprate_helper/num/rolling_median.rs
1use std::{
2 collections::VecDeque,
3 ops::{Add, Div, Mul, Sub},
4};
5
6use crate::num::median;
7
/// A rolling median type.
///
/// Tracks a bounded window of the most recently pushed items so that
/// [`RollingMedian::median`] can be computed over them at any time.
///
/// # Examples
/// ```rust
/// # use cuprate_helper::num::RollingMedian;
/// let mut rolling_median = RollingMedian::new(2);
///
/// rolling_median.push(1);
/// assert_eq!(rolling_median.median(), 1);
/// assert_eq!(rolling_median.window_len(), 1);
///
/// rolling_median.push(3);
/// assert_eq!(rolling_median.median(), 2);
/// assert_eq!(rolling_median.window_len(), 2);
///
/// // The window is full, so pushing `5` evicts the oldest item (`1`).
/// rolling_median.push(5);
/// assert_eq!(rolling_median.median(), 4);
/// assert_eq!(rolling_median.window_len(), 2);
/// ```
// TODO: a more efficient structure is probably possible.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct RollingMedian<T> {
    /// Every item currently in the window, oldest first.
    window: VecDeque<T>,
    /// The same items as `window`, kept in ascending order.
    sorted_window: Vec<T>,

    /// Target length of the window: the maximum number of items it is meant to hold.
    target_window: usize,
}
41
42impl<T> RollingMedian<T>
43where
44 T: Ord
45 + PartialOrd
46 + Add<Output = T>
47 + Sub<Output = T>
48 + Div<Output = T>
49 + Mul<Output = T>
50 + Copy
51 + From<u8>,
52{
53 /// Creates a new [`RollingMedian`] with a certain target window length.
54 ///
55 /// `target_window` is the maximum amount of items to keep in the rolling window.
56 pub fn new(target_window: usize) -> Self {
57 Self {
58 window: VecDeque::with_capacity(target_window),
59 sorted_window: Vec::with_capacity(target_window),
60 target_window,
61 }
62 }
63
64 /// Creates a new [`RollingMedian`] from a [`Vec`] with a certain target window length.
65 ///
66 /// `target_window` is the maximum amount of items to keep in the rolling window.
67 ///
68 /// # Panics
69 /// This function panics if `vec.len() > target_window`.
70 pub fn from_vec(vec: Vec<T>, target_window: usize) -> Self {
71 assert!(vec.len() <= target_window);
72
73 let mut sorted_window = vec.clone();
74 sorted_window.sort_unstable();
75
76 Self {
77 window: vec.into(),
78 sorted_window,
79 target_window,
80 }
81 }
82
83 /// Pops the front of the window, i.e. the oldest item.
84 ///
85 /// This is often not needed as [`RollingMedian::push`] will handle popping old values when they fall
86 /// out of the window.
87 pub fn pop_front(&mut self) {
88 if let Some(item) = self.window.pop_front() {
89 match self.sorted_window.binary_search(&item) {
90 Ok(idx) => {
91 self.sorted_window.remove(idx);
92 }
93 Err(_) => panic!("Value expected to be in sorted_window was not there"),
94 }
95 }
96 }
97
98 /// Pops the back of the window, i.e. the youngest item.
99 pub fn pop_back(&mut self) {
100 if let Some(item) = self.window.pop_back() {
101 match self.sorted_window.binary_search(&item) {
102 Ok(idx) => {
103 self.sorted_window.remove(idx);
104 }
105 Err(_) => panic!("Value expected to be in sorted_window was not there"),
106 }
107 }
108 }
109
110 /// Push an item to the _back_ of the window.
111 ///
112 /// This will pop the oldest item in the window if the target length has been exceeded.
113 pub fn push(&mut self, item: T) {
114 if self.window.len() >= self.target_window {
115 self.pop_front();
116 }
117
118 self.window.push_back(item);
119 match self.sorted_window.binary_search(&item) {
120 Ok(idx) | Err(idx) => self.sorted_window.insert(idx, item),
121 }
122 }
123
124 /// Append some values to the _front_ of the window.
125 ///
126 /// These new values will be the oldest items in the window. The order of the inputted items will be
127 /// kept, i.e. the first item in the [`Vec`] will be the oldest item in the queue.
128 pub fn append_front(&mut self, items: Vec<T>) {
129 for item in items.into_iter().rev() {
130 self.window.push_front(item);
131 match self.sorted_window.binary_search(&item) {
132 Ok(idx) | Err(idx) => self.sorted_window.insert(idx, item),
133 }
134
135 if self.window.len() > self.target_window {
136 self.pop_back();
137 }
138 }
139 }
140
141 /// Returns the number of items currently in the [`RollingMedian`].
142 pub fn window_len(&self) -> usize {
143 self.window.len()
144 }
145
146 /// Calculates the median of the values currently in the [`RollingMedian`].
147 pub fn median(&self) -> T {
148 median(&self.sorted_window)
149 }
150}