cuprate_helper/num/rolling_median.rs

1use std::{
2    collections::VecDeque,
3    ops::{Add, Div, Mul, Sub},
4};
5
6use crate::num::median;
7
/// A rolling median over a bounded window of items.
///
/// Items are tracked twice: once in insertion order, so the oldest item can be evicted when
/// the window is full, and once in ascending order, so [`RollingMedian::median`] can be read
/// straight from the sorted copy.
///
/// # Examples
///
/// ```rust
/// # use cuprate_helper::num::RollingMedian;
/// let mut rolling_median = RollingMedian::new(2);
///
/// rolling_median.push(1);
/// assert_eq!(rolling_median.median(), 1);
/// assert_eq!(rolling_median.window_len(), 1);
///
/// rolling_median.push(3);
/// assert_eq!(rolling_median.median(), 2);
/// assert_eq!(rolling_median.window_len(), 2);
///
/// // The window is already full, so pushing `5` evicts the oldest item (`1`).
/// rolling_median.push(5);
/// assert_eq!(rolling_median.median(), 4);
/// assert_eq!(rolling_median.window_len(), 2);
/// ```
// TODO: a more efficient structure is probably possible.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct RollingMedian<T> {
    /// Every item currently in the window: oldest at the front, youngest at the back.
    window: VecDeque<T>,
    /// The same items as `window`, kept in ascending order.
    sorted_window: Vec<T>,

    /// The maximum number of items the window is meant to hold.
    target_window: usize,
}
41
42impl<T> RollingMedian<T>
43where
44    T: Ord
45        + PartialOrd
46        + Add<Output = T>
47        + Sub<Output = T>
48        + Div<Output = T>
49        + Mul<Output = T>
50        + Copy
51        + From<u8>,
52{
53    /// Creates a new [`RollingMedian`] with a certain target window length.
54    ///
55    /// `target_window` is the maximum amount of items to keep in the rolling window.
56    pub fn new(target_window: usize) -> Self {
57        Self {
58            window: VecDeque::with_capacity(target_window),
59            sorted_window: Vec::with_capacity(target_window),
60            target_window,
61        }
62    }
63
64    /// Creates a new [`RollingMedian`] from a [`Vec`] with a certain target window length.
65    ///
66    /// `target_window` is the maximum amount of items to keep in the rolling window.
67    ///
68    /// # Panics
69    /// This function panics if `vec.len() > target_window`.
70    pub fn from_vec(vec: Vec<T>, target_window: usize) -> Self {
71        assert!(vec.len() <= target_window);
72
73        let mut sorted_window = vec.clone();
74        sorted_window.sort_unstable();
75
76        Self {
77            window: vec.into(),
78            sorted_window,
79            target_window,
80        }
81    }
82
83    /// Pops the front of the window, i.e. the oldest item.
84    ///
85    /// This is often not needed as [`RollingMedian::push`] will handle popping old values when they fall
86    /// out of the window.
87    pub fn pop_front(&mut self) {
88        if let Some(item) = self.window.pop_front() {
89            match self.sorted_window.binary_search(&item) {
90                Ok(idx) => {
91                    self.sorted_window.remove(idx);
92                }
93                Err(_) => panic!("Value expected to be in sorted_window was not there"),
94            }
95        }
96    }
97
98    /// Pops the back of the window, i.e. the youngest item.
99    pub fn pop_back(&mut self) {
100        if let Some(item) = self.window.pop_back() {
101            match self.sorted_window.binary_search(&item) {
102                Ok(idx) => {
103                    self.sorted_window.remove(idx);
104                }
105                Err(_) => panic!("Value expected to be in sorted_window was not there"),
106            }
107        }
108    }
109
110    /// Push an item to the _back_ of the window.
111    ///
112    /// This will pop the oldest item in the window if the target length has been exceeded.
113    pub fn push(&mut self, item: T) {
114        if self.window.len() >= self.target_window {
115            self.pop_front();
116        }
117
118        self.window.push_back(item);
119        match self.sorted_window.binary_search(&item) {
120            Ok(idx) | Err(idx) => self.sorted_window.insert(idx, item),
121        }
122    }
123
124    /// Append some values to the _front_ of the window.
125    ///
126    /// These new values will be the oldest items in the window. The order of the inputted items will be
127    /// kept, i.e. the first item in the [`Vec`] will be the oldest item in the queue.
128    pub fn append_front(&mut self, items: Vec<T>) {
129        for item in items.into_iter().rev() {
130            self.window.push_front(item);
131            match self.sorted_window.binary_search(&item) {
132                Ok(idx) | Err(idx) => self.sorted_window.insert(idx, item),
133            }
134
135            if self.window.len() > self.target_window {
136                self.pop_back();
137            }
138        }
139    }
140
141    /// Returns the number of items currently in the [`RollingMedian`].
142    pub fn window_len(&self) -> usize {
143        self.window.len()
144    }
145
146    /// Calculates the median of the values currently in the [`RollingMedian`].
147    pub fn median(&self) -> T {
148        median(&self.sorted_window)
149    }
150}