rand_distr/
geometric.rs

1//! The geometric distribution.
2
3use crate::Distribution;
4use rand::Rng;
5use core::fmt;
6#[allow(unused_imports)]
7use num_traits::Float;
8
/// The geometric distribution `Geometric(p)` bounded to `[0, u64::MAX]`.
///
/// Describes how many failures occur before the first success in a series
/// of Bernoulli trials, where each trial succeeds with probability `p`.
/// Its probability mass function is `f(k) = (1 - p)^k p` for `k >= 0`.
///
/// This is the discrete analogue of the [exponential distribution](crate::Exp).
///
/// For `p = 0.5`, [`StandardGeometric`](crate::StandardGeometric) is an
/// optimised implementation.
///
/// # Example
///
/// ```
/// use rand_distr::{Geometric, Distribution};
///
/// let geo = Geometric::new(0.25).unwrap();
/// let v = geo.sample(&mut rand::thread_rng());
/// println!("{} is from a Geometric(0.25) distribution", v);
/// ```
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Geometric {
    /// Probability of success on each trial.
    p: f64,
    /// Precomputed `(1 - p)^(2^k)` for the block-wise sampler; set to `p`
    /// itself in the cases where `k` is 0 and the value is not needed.
    pi: f64,
    /// Exponent of the block width `2^k` used by the sampler; 0 when the
    /// sampler takes one of its trivial paths.
    k: u64,
}
38
/// Error type returned from `Geometric::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The probability was NaN, below 0, or above 1.
    InvalidProbability,
}

impl fmt::Display for Error {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Error::InvalidProbability => {
                "p is NaN or outside the interval [0, 1] in geometric distribution"
            }
        };
        f.write_str(message)
    }
}

// Only available with `std`, since that is where the `Error` trait lives.
#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
57
58impl Geometric {
59    /// Construct a new `Geometric` with the given shape parameter `p`
60    /// (probability of success on each trial).
61    pub fn new(p: f64) -> Result<Self, Error> {
62        if !p.is_finite() || p < 0.0 || p > 1.0 {
63            Err(Error::InvalidProbability)
64        } else if p == 0.0 || p >= 2.0 / 3.0 {
65            Ok(Geometric { p, pi: p, k: 0 })
66        } else {
67            let (pi, k) = {
68                // choose smallest k such that pi = (1 - p)^(2^k) <= 0.5
69                let mut k = 1;
70                let mut pi = (1.0 - p).powi(2);
71                while pi > 0.5 {
72                    k += 1;
73                    pi = pi * pi;
74                }
75                (pi, k)
76            };
77
78            Ok(Geometric { p, pi, k })
79        }
80    }
81}
82
83impl Distribution<u64> for Geometric
84{
85    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
86        if self.p >= 2.0 / 3.0 {
87            // use the trivial algorithm:
88            let mut failures = 0;
89            loop {
90                let u = rng.gen::<f64>();
91                if u <= self.p { break; }
92                failures += 1;
93            }
94            return failures;
95        }
96        
97        if self.p == 0.0 { return core::u64::MAX; }
98
99        let Geometric { p, pi, k } = *self;
100
101        // Based on the algorithm presented in section 3 of
102        // Karl Bringmann and Tobias Friedrich (July 2013) - Exact and Efficient
103        // Generation of Geometric Random Variates and Random Graphs, published
104        // in International Colloquium on Automata, Languages and Programming
105        // (pp.267-278)
106        // https://people.mpi-inf.mpg.de/~kbringma/paper/2013ICALP-1.pdf
107
108        // Use the trivial algorithm to sample D from Geo(pi) = Geo(p) / 2^k:
109        let d = {
110            let mut failures = 0;
111            while rng.gen::<f64>() < pi {
112                failures += 1;
113            }
114            failures
115        };
116
117        // Use rejection sampling for the remainder M from Geo(p) % 2^k:
118        // choose M uniformly from [0, 2^k), but reject with probability (1 - p)^M
119        // NOTE: The paper suggests using bitwise sampling here, which is 
120        // currently unsupported, but should improve performance by requiring
121        // fewer iterations on average.                 ~ October 28, 2020
122        let m = loop {
123            let m = rng.gen::<u64>() & ((1 << k) - 1);
124            let p_reject = if m <= core::i32::MAX as u64 {
125                (1.0 - p).powi(m as i32)
126            } else {
127                (1.0 - p).powf(m as f64)
128            };
129            
130            let u = rng.gen::<f64>();
131            if u < p_reject {
132                break m;
133            }
134        };
135
136        (d << k) + m
137    }
138}
139
/// The geometric distribution with success probability `p = 0.5`.
///
/// Samples integers from the same distribution as `Geometric::new(0.5)`,
/// but faster.
///
/// See [`Geometric`](crate::Geometric) for the general geometric distribution.
///
/// Implemented via iterated `Rng::gen::<u64>().leading_zeros()`.
///
/// # Example
/// ```
/// use rand::prelude::*;
/// use rand_distr::StandardGeometric;
///
/// let v = StandardGeometric.sample(&mut thread_rng());
/// println!("{} is from a Geometric(0.5) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct StandardGeometric;
159
160impl Distribution<u64> for StandardGeometric {
161    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
162        let mut result = 0;
163        loop {
164            let x = rng.gen::<u64>().leading_zeros() as u64;
165            result += x;
166            if x < 64 { break; }
167        }
168        result
169    }
170}
171
#[cfg(test)]
mod test {
    use super::*;

    /// Population mean and variance (dividing by the sample count) of `samples`.
    fn mean_and_variance(samples: &[f64]) -> (f64, f64) {
        let count = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / count;
        let variance = samples
            .iter()
            .map(|x| (x - mean) * (x - mean))
            .sum::<f64>()
            / count;
        (mean, variance)
    }

    #[test]
    fn test_geo_invalid_p() {
        // NaN, infinities and values outside [0, 1] must be rejected.
        let rejected = [
            core::f64::NAN,
            core::f64::INFINITY,
            core::f64::NEG_INFINITY,
            -0.5,
            2.0,
        ];
        for &p in &rejected {
            assert!(Geometric::new(p).is_err());
        }

        // Both endpoints of the interval are valid.
        for &p in &[0.0, 1.0] {
            assert!(Geometric::new(p).is_ok());
        }
    }

    /// Draw 10000 samples from `Geometric(p)` and compare their mean and
    /// variance with the theoretical values.
    fn test_geo_mean_and_variance<R: Rng>(p: f64, rng: &mut R) {
        let distr = Geometric::new(p).unwrap();

        let mut samples = [0.0; 10000];
        for slot in samples.iter_mut() {
            *slot = distr.sample(rng) as f64;
        }
        let (mean, variance) = mean_and_variance(&samples);

        let expected_mean = (1.0 - p) / p;
        assert!((mean - expected_mean).abs() < expected_mean / 40.0);

        let expected_variance = (1.0 - p) / (p * p);
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_geometric() {
        let mut rng = crate::test::rng(12345);

        for &p in &[0.10, 0.25, 0.50, 0.75, 0.90] {
            test_geo_mean_and_variance(p, &mut rng);
        }
    }

    #[test]
    fn test_standard_geometric() {
        let mut rng = crate::test::rng(654321);
        let distr = StandardGeometric;

        let mut samples = [0.0; 1000];
        for slot in samples.iter_mut() {
            *slot = distr.sample(&mut rng) as f64;
        }
        let (mean, variance) = mean_and_variance(&samples);

        // Geometric(0.5) has mean 1 and variance 2.
        let expected_mean = 1.0;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        let expected_variance = 2.0;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }
}