cuprate_helper/num/
rolling_median.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
use std::{
    collections::VecDeque,
    ops::{Add, Div, Mul, Sub},
};

use crate::num::median;

/// A rolling median type.
///
/// Tracks a bounded window of items and can compute the [`RollingMedian::median`] of that window.
///
/// # Examples
/// ```rust
/// # use cuprate_helper::num::RollingMedian;
/// let mut rolling_median = RollingMedian::new(2);
///
/// rolling_median.push(1);
/// assert_eq!(rolling_median.median(), 1);
/// assert_eq!(rolling_median.window_len(), 1);
///
/// rolling_median.push(3);
/// assert_eq!(rolling_median.median(), 2);
/// assert_eq!(rolling_median.window_len(), 2);
///
/// // The window is already full, so pushing 5 evicts the oldest item (1).
/// rolling_median.push(5);
/// assert_eq!(rolling_median.median(), 4);
/// assert_eq!(rolling_median.window_len(), 2);
/// ```
// TODO: a more efficient structure is probably possible.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct RollingMedian<T> {
    // NOTE: the field order below determines the derived `Ord`/`PartialOrd` comparison
    // (lexicographic, top to bottom) and the `Debug` output, so do not reorder it.
    /// Every item currently in the window, oldest first (insertion order).
    window: VecDeque<T>,
    /// The same items as `window`, kept in ascending order.
    sorted_window: Vec<T>,
    /// The target length of the window.
    target_window: usize,
}

impl<T> RollingMedian<T>
where
    T: Ord
        + Add<Output = T>
        + Sub<Output = T>
        + Div<Output = T>
        + Mul<Output = T>
        + Copy
        + From<u8>,
{
    /// Creates a new, empty [`RollingMedian`] with a certain target window length.
    ///
    /// `target_window` is the maximum amount of items to keep in the rolling window.
    pub fn new(target_window: usize) -> Self {
        Self {
            window: VecDeque::with_capacity(target_window),
            sorted_window: Vec::with_capacity(target_window),
            target_window,
        }
    }

    /// Creates a new [`RollingMedian`] from a [`Vec`] with a certain target window length.
    ///
    /// The items of `vec` are treated as already being in insertion order, i.e. the first item
    /// is the oldest. `target_window` is the maximum amount of items to keep in the rolling window.
    ///
    /// # Panics
    /// This function panics if `vec.len() > target_window`.
    pub fn from_vec(vec: Vec<T>, target_window: usize) -> Self {
        assert!(vec.len() <= target_window);

        let mut sorted_window = vec.clone();
        sorted_window.sort_unstable();

        Self {
            window: vec.into(),
            sorted_window,
            target_window,
        }
    }

    /// Pops the front of the window, i.e. the oldest item. Does nothing if the window is empty.
    ///
    /// This is often not needed as [`RollingMedian::push`] will handle popping old values when they fall
    /// out of the window.
    ///
    /// # Panics
    /// Panics if the popped item is missing from the sorted copy of the window. That would mean
    /// the internal bookkeeping has diverged, which should be unreachable through this API.
    pub fn pop_front(&mut self) {
        if let Some(item) = self.window.pop_front() {
            self.remove_sorted(&item);
        }
    }

    /// Pops the back of the window, i.e. the youngest item. Does nothing if the window is empty.
    ///
    /// # Panics
    /// Panics if the popped item is missing from the sorted copy of the window. That would mean
    /// the internal bookkeeping has diverged, which should be unreachable through this API.
    pub fn pop_back(&mut self) {
        if let Some(item) = self.window.pop_back() {
            self.remove_sorted(&item);
        }
    }

    /// Push an item to the _back_ of the window.
    ///
    /// This will pop the oldest item in the window if the target length has been reached.
    /// Note that with a `target_window` of `0` the most recently pushed item is still retained.
    pub fn push(&mut self, item: T) {
        // Make room first so the window never grows past `target_window` (for `target_window > 0`).
        if self.window.len() >= self.target_window {
            self.pop_front();
        }

        self.window.push_back(item);
        self.insert_sorted(item);
    }

    /// Append some values to the _front_ of the window.
    ///
    /// These new values will be the oldest items in the window. The order of the inputted items will be
    /// kept, i.e. the first item in the [`Vec`] will be the oldest item in the queue.
    ///
    /// If the window grows past the target length, the _youngest_ items (at the back) are popped
    /// to make room.
    pub fn append_front(&mut self, items: Vec<T>) {
        // Walk the items newest-first so that repeated `push_front`s leave them in their original
        // order, with `items[0]` ending up at the very front.
        for item in items.into_iter().rev() {
            self.window.push_front(item);
            self.insert_sorted(item);

            if self.window.len() > self.target_window {
                self.pop_back();
            }
        }
    }

    /// Returns the number of items currently in the [`RollingMedian`].
    pub fn window_len(&self) -> usize {
        self.window.len()
    }

    /// Calculates the median of the values currently in the [`RollingMedian`].
    ///
    /// NOTE(review): the result for an empty window depends on `crate::num::median`; confirm
    /// whether it panics before calling this on an empty [`RollingMedian`].
    pub fn median(&self) -> T {
        median(&self.sorted_window)
    }

    /// Inserts `item` into `sorted_window` at a position that keeps it in ascending order.
    fn insert_sorted(&mut self, item: T) {
        // `Ok` (an equal item exists) and `Err` (no equal item) both carry a valid insertion point.
        let (Ok(idx) | Err(idx)) = self.sorted_window.binary_search(&item);
        self.sorted_window.insert(idx, item);
    }

    /// Removes one occurrence of `item` from `sorted_window`.
    ///
    /// # Panics
    /// Panics if `item` is not present, i.e. `window` and `sorted_window` have diverged.
    fn remove_sorted(&mut self, item: &T) {
        // With duplicates `binary_search` may return any matching index; removing any one of the
        // equal items keeps `sorted_window` sorted and holding the same multiset as `window`.
        match self.sorted_window.binary_search(item) {
            Ok(idx) => {
                self.sorted_window.remove(idx);
            }
            Err(_) => panic!("Value expected to be in sorted_window was not there"),
        }
    }
}