// tower/util/rng.rs
//! [PRNG] utilities for tower middleware.
//!
//! This module provides a generic [`Rng`] trait and a [`HasherRng`] that
//! implements the trait based on [`RandomState`] or any other [`Hasher`].
//!
//! These utilities replace tower's internal usage of `rand` with these smaller,
//! more lightweight methods. Most of the implementations are extracted from
//! their corresponding `rand` implementations.
//!
//! [PRNG]: https://en.wikipedia.org/wiki/Pseudorandom_number_generator
use std::{
collections::hash_map::RandomState,
hash::{BuildHasher, Hasher},
ops::Range,
};
/// A simple [PRNG] trait for use within tower middleware.
///
/// Implementors only have to provide [`Rng::next_u64`]; the other methods are
/// derived from it.
///
/// [PRNG]: https://en.wikipedia.org/wiki/Pseudorandom_number_generator
pub trait Rng {
    /// Generate a random [`u64`].
    fn next_u64(&mut self) -> u64;

    /// Generate a random [`f64`] in the half-open interval `[0, 1)`.
    fn next_f64(&mut self) -> f64 {
        // Borrowed from:
        // https://github.com/rust-random/rand/blob/master/src/distributions/float.rs#L106
        let float_size = std::mem::size_of::<f64>() as u32 * 8;
        // An f64 carries 52 fraction bits plus one implicit leading bit.
        let precision = 52 + 1;
        let scale = 1.0 / ((1u64 << precision) as f64);

        // Keep only the top `precision` bits of the draw so the scaled result
        // stays strictly below 1.
        let value = self.next_u64();
        let value = value >> (float_size - precision);

        scale * value as f64
    }

    /// Randomly pick a value within the range.
    ///
    /// The result is `range.start + (next_u64() % (range.end - range.start))`.
    ///
    /// # Panics
    ///
    /// - If `range.start >= range.end` (an empty or reversed range) this will
    ///   panic in debug mode.
    /// - The check above is a `debug_assert!`, so release builds do not run
    ///   it. An empty range (`start == end`) still panics there because the
    ///   remainder is taken with a zero divisor; a reversed range is not
    ///   validated at all, so callers must never pass one.
    fn next_range(&mut self, range: Range<u64>) -> u64 {
        debug_assert!(
            range.start < range.end,
            "The range start must be smaller than the end"
        );

        let span = range.end - range.start;
        // `draw % span` is below `span`, so adding `start` cannot reach `end`.
        (self.next_u64() % span) + range.start
    }
}
/// Lets a boxed generator (including an unsized one) be used as an [`Rng`] by
/// forwarding every draw to the value inside the box.
impl<R> Rng for Box<R>
where
    R: Rng + ?Sized,
{
    fn next_u64(&mut self) -> u64 {
        // Name the inner type explicitly so the call goes to `R`, not back
        // to this impl.
        R::next_u64(&mut **self)
    }
}
/// An [`Rng`] that derives its output by hashing an internal counter.
///
/// Each call to [`Rng::next_u64`] feeds the current counter value to a hasher
/// obtained from `H` and then advances the counter.
///
/// # Default
///
/// `H` defaults to [`RandomState`], i.e. the libstd way of obtaining a
/// random `u64`.
#[derive(Debug, Clone)]
pub struct HasherRng<H = RandomState> {
    /// Source of the hashers used for each draw.
    hasher: H,
    /// Value handed to the hasher on the next draw.
    counter: u64,
}
impl HasherRng {
/// Create a new default [`HasherRng`].
pub fn new() -> Self {
HasherRng::default()
}
}
impl Default for HasherRng {
fn default() -> Self {
HasherRng::with_hasher(RandomState::default())
}
}
impl<H> HasherRng<H> {
/// Create a new [`HasherRng`] with the provided hasher.
pub fn with_hasher(hasher: H) -> Self {
HasherRng { hasher, counter: 0 }
}
}
/// Produces each value by hashing the running counter with a hasher built
/// from `H`.
impl<H: BuildHasher> Rng for HasherRng<H> {
    fn next_u64(&mut self) -> u64 {
        let seed = self.counter;

        // A new hasher is built for every draw and fed only the counter.
        let mut state = self.hasher.build_hasher();
        state.write_u64(seed);

        // Advance for the next call; wraps around instead of overflowing.
        self.counter = seed.wrapping_add(1);

        state.finish()
    }
}
/// Pick two distinct indices in `0..length`; used internally by the balance
/// middleware.
///
/// This is Floyd's combination algorithm specialised to an amount of 2. It
/// allocates nothing and draws from `rng` exactly twice.
///
/// `length` must be at least 2; this is only checked in debug builds.
///
/// ref: This was borrowed and modified from the following Rand implementation
/// https://github.com/rust-random/rand/blob/b73640705d6714509f8ceccc49e8df996fa19f51/src/seq/index.rs#L375-L411
pub(crate) fn sample_floyd2<R: Rng>(rng: &mut R, length: u64) -> [u64; 2] {
    debug_assert!(2 <= length);

    let last = length - 1;
    // The first draw excludes the last index, the second covers all of them.
    let first = rng.next_range(0..last);
    let second = rng.next_range(0..length);

    if first == second {
        // On a collision use the last index, which the first draw can never
        // produce, so the pair stays distinct.
        [last, second]
    } else {
        [first, second]
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    /// Build a default [`HasherRng`] whose counter starts at `counter`.
    fn rng_at(counter: u64) -> HasherRng {
        let mut rng = HasherRng::default();
        rng.counter = counter;
        rng
    }

    quickcheck! {
        // Every float drawn must lie in `[0, 1)`.
        fn next_f64(counter: u64) -> TestResult {
            let sample = rng_at(counter).next_f64();
            TestResult::from_bool((0.0..1.0).contains(&sample))
        }

        // Every draw from a non-empty range must fall inside that range.
        fn next_range(counter: u64, range: Range<u64>) -> TestResult {
            if range.end <= range.start {
                return TestResult::discard();
            }

            let picked = rng_at(counter).next_range(range.clone());
            TestResult::from_bool(range.contains(&picked))
        }

        // Both indices must be in bounds and different from each other.
        fn sample_floyd2(counter: u64, length: u64) -> TestResult {
            if !(2..=256).contains(&length) {
                return TestResult::discard();
            }

            let [a, b] = super::sample_floyd2(&mut rng_at(counter), length);
            TestResult::from_bool(a < length && b < length && a != b)
        }
    }

    /// With only two slots, the sampler must return both of them.
    #[test]
    fn sample_inplace_boundaries() {
        let mut r = HasherRng::default();
        let array = super::sample_floyd2(&mut r, 2);
        assert!(
            array == [0, 1] || array == [1, 0],
            "unexpected inplace boundaries: {:?}",
            array
        );
    }
}