cuprate_helper/num/rolling_median.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
use std::{
collections::VecDeque,
ops::{Add, Div, Mul, Sub},
};
use crate::num::median;
/// A rolling median over a bounded window of items.
///
/// Items are stored twice: once in insertion order (so the oldest/youngest item can be evicted)
/// and once in ascending order (so [`RollingMedian::median`] can be read off directly).
///
/// # Examples
///
/// ```rust
/// # use cuprate_helper::num::RollingMedian;
/// let mut rolling_median = RollingMedian::new(2);
///
/// rolling_median.push(1);
/// assert_eq!(rolling_median.median(), 1);
/// assert_eq!(rolling_median.window_len(), 1);
///
/// rolling_median.push(3);
/// assert_eq!(rolling_median.median(), 2);
/// assert_eq!(rolling_median.window_len(), 2);
///
/// // The window is already full, so pushing `5` evicts the oldest item (`1`).
/// rolling_median.push(5);
/// assert_eq!(rolling_median.median(), 4);
/// assert_eq!(rolling_median.window_len(), 2);
/// ```
//
// TODO: a more efficient structure is probably possible.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct RollingMedian<T> {
    /// Every item currently in the window, oldest first.
    window: VecDeque<T>,
    /// The same items as `window`, kept in ascending order.
    sorted_window: Vec<T>,
    /// The target length of `window`.
    target_window: usize,
}
impl<T> RollingMedian<T>
where
    // `PartialOrd` is implied by `Ord`, so it does not need to be listed separately.
    T: Ord
        + Add<Output = T>
        + Sub<Output = T>
        + Div<Output = T>
        + Mul<Output = T>
        + Copy
        + From<u8>,
{
    /// Creates a new [`RollingMedian`] with a certain target window length.
    ///
    /// `target_window` is the maximum amount of items to keep in the rolling window.
    ///
    /// Note that [`RollingMedian::push`] always keeps the item it just pushed, so with a
    /// `target_window` of `0` the window will still hold the most recently pushed item.
    pub fn new(target_window: usize) -> Self {
        Self {
            window: VecDeque::with_capacity(target_window),
            sorted_window: Vec::with_capacity(target_window),
            target_window,
        }
    }

    /// Creates a new [`RollingMedian`] from a [`Vec`] with a certain target window length.
    ///
    /// `target_window` is the maximum amount of items to keep in the rolling window.
    /// The first item of `vec` becomes the oldest item in the window.
    ///
    /// # Panics
    /// This function panics if `vec.len() > target_window`.
    pub fn from_vec(vec: Vec<T>, target_window: usize) -> Self {
        assert!(vec.len() <= target_window);

        let mut sorted_window = vec.clone();
        sorted_window.sort_unstable();

        Self {
            window: vec.into(),
            sorted_window,
            target_window,
        }
    }

    /// Inserts `item` into `sorted_window`, keeping it in ascending order.
    fn insert_sorted(&mut self, item: T) {
        // When `item` is already present, any matching index is a valid insertion point,
        // so the `Ok` and `Err` cases are handled identically.
        let (Ok(idx) | Err(idx)) = self.sorted_window.binary_search(&item);
        self.sorted_window.insert(idx, item);
    }

    /// Removes one occurrence of `item` from `sorted_window`.
    ///
    /// # Panics
    /// Panics if `item` is not present, which means `window` and `sorted_window` have
    /// fallen out of sync.
    fn remove_sorted(&mut self, item: T) {
        match self.sorted_window.binary_search(&item) {
            Ok(idx) => {
                self.sorted_window.remove(idx);
            }
            Err(_) => panic!("Value expected to be in sorted_window was not there"),
        }
    }

    /// Pops the front of the window, i.e. the oldest item.
    ///
    /// Does nothing if the window is empty.
    ///
    /// This is often not needed as [`RollingMedian::push`] will handle popping old values when they fall
    /// out of the window.
    ///
    /// # Panics
    /// Panics if the popped item is missing from the internal sorted copy of the window
    /// (an internal invariant violation).
    pub fn pop_front(&mut self) {
        if let Some(item) = self.window.pop_front() {
            self.remove_sorted(item);
        }
    }

    /// Pops the back of the window, i.e. the youngest item.
    ///
    /// Does nothing if the window is empty.
    ///
    /// # Panics
    /// Panics if the popped item is missing from the internal sorted copy of the window
    /// (an internal invariant violation).
    pub fn pop_back(&mut self) {
        if let Some(item) = self.window.pop_back() {
            self.remove_sorted(item);
        }
    }

    /// Push an item to the _back_ of the window.
    ///
    /// This will pop the oldest item in the window if the target length has been reached.
    pub fn push(&mut self, item: T) {
        if self.window.len() >= self.target_window {
            self.pop_front();
        }

        self.window.push_back(item);
        self.insert_sorted(item);
    }

    /// Append some values to the _front_ of the window.
    ///
    /// These new values will be the oldest items in the window. The order of the inputted items will be
    /// kept, i.e. the first item in the [`Vec`] will be the oldest item in the queue.
    ///
    /// If the window grows past the target length, the _youngest_ items (at the back) are
    /// popped to make room.
    pub fn append_front(&mut self, items: Vec<T>) {
        // Walk backwards so that, after repeated `push_front`s, `items[0]` ends up at the front.
        for item in items.into_iter().rev() {
            self.window.push_front(item);
            self.insert_sorted(item);

            if self.window.len() > self.target_window {
                self.pop_back();
            }
        }
    }

    /// Returns the number of items currently in the [`RollingMedian`].
    #[must_use]
    pub fn window_len(&self) -> usize {
        self.window.len()
    }

    /// Calculates the median of the values currently in the [`RollingMedian`].
    #[must_use]
    pub fn median(&self) -> T {
        median(&self.sorted_window)
    }
}