1use crate::Distribution;
4use rand::Rng;
5use core::fmt;
6#[allow(unused_imports)]
7use num_traits::Float;
8
/// The geometric distribution `Geometric(p)`, sampled as a `u64`.
///
/// A sample is the number of failures observed before the first success in a
/// sequence of independent Bernoulli trials with success probability `p`, so
/// the mean is `(1 - p) / p`. Build one with `Geometric::new`; with `p == 0`
/// every sample is `u64::MAX`.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Geometric {
    /// Success probability of a single trial, as passed to `Geometric::new`.
    p: f64,
    /// When `k > 0`: the precomputed value `(1 - p)^(2^k)`, which is at most 0.5.
    pi: f64,
    /// Exponent of the `2^k` block size used by the sampler; `0` means one of the
    /// special-case sampling paths is taken instead.
    k: u64,
}
38
/// Error returned by `Geometric::new` when its argument is not a valid probability.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// `p` is NaN, infinite, negative, or greater than 1.
    InvalidProbability,
}
45
46impl fmt::Display for Error {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 f.write_str(match self {
49 Error::InvalidProbability => "p is NaN or outside the interval [0, 1] in geometric distribution",
50 })
51 }
52}
53
// Lets `Error` be used wherever a `std::error::Error` is expected; the required
// `Debug` and `Display` supertraits come from the derive and the `Display` impl.
// Only compiled when the crate's `std` feature is enabled, and annotated as such
// in rustdoc output when built with the `doc_cfg` cfg.
#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
57
58impl Geometric {
59 pub fn new(p: f64) -> Result<Self, Error> {
62 if !p.is_finite() || p < 0.0 || p > 1.0 {
63 Err(Error::InvalidProbability)
64 } else if p == 0.0 || p >= 2.0 / 3.0 {
65 Ok(Geometric { p, pi: p, k: 0 })
66 } else {
67 let (pi, k) = {
68 let mut k = 1;
70 let mut pi = (1.0 - p).powi(2);
71 while pi > 0.5 {
72 k += 1;
73 pi = pi * pi;
74 }
75 (pi, k)
76 };
77
78 Ok(Geometric { p, pi, k })
79 }
80 }
81}
82
83impl Distribution<u64> for Geometric
84{
85 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
86 if self.p >= 2.0 / 3.0 {
87 let mut failures = 0;
89 loop {
90 let u = rng.gen::<f64>();
91 if u <= self.p { break; }
92 failures += 1;
93 }
94 return failures;
95 }
96
97 if self.p == 0.0 { return core::u64::MAX; }
98
99 let Geometric { p, pi, k } = *self;
100
101 let d = {
110 let mut failures = 0;
111 while rng.gen::<f64>() < pi {
112 failures += 1;
113 }
114 failures
115 };
116
117 let m = loop {
123 let m = rng.gen::<u64>() & ((1 << k) - 1);
124 let p_reject = if m <= core::i32::MAX as u64 {
125 (1.0 - p).powi(m as i32)
126 } else {
127 (1.0 - p).powf(m as f64)
128 };
129
130 let u = rng.gen::<f64>();
131 if u < p_reject {
132 break m;
133 }
134 };
135
136 (d << k) + m
137 }
138}
139
/// The geometric distribution with fixed success probability `p = 0.5`.
///
/// Samples follow the same law as `Geometric::new(0.5)` (mean 1, variance 2),
/// but are produced by counting the leading zero bits of random `u64` words
/// instead of by floating-point comparisons.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct StandardGeometric;
159
160impl Distribution<u64> for StandardGeometric {
161 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
162 let mut result = 0;
163 loop {
164 let x = rng.gen::<u64>().leading_zeros() as u64;
165 result += x;
166 if x < 64 { break; }
167 }
168 result
169 }
170}
171
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_geo_invalid_p() {
        assert!(Geometric::new(core::f64::NAN).is_err());
        assert!(Geometric::new(core::f64::INFINITY).is_err());
        assert!(Geometric::new(core::f64::NEG_INFINITY).is_err());

        assert!(Geometric::new(-0.5).is_err());
        assert!(Geometric::new(0.0).is_ok());
        assert!(Geometric::new(1.0).is_ok());
        assert!(Geometric::new(1.0 + core::f64::EPSILON).is_err());
        assert!(Geometric::new(2.0).is_err());
    }

    #[test]
    fn test_geo_degenerate_p() {
        // p = 0 never succeeds (saturates at u64::MAX); p = 1 succeeds on the first trial.
        let mut rng = crate::test::rng(123);
        let never = Geometric::new(0.0).unwrap();
        let always = Geometric::new(1.0).unwrap();
        for _ in 0..100 {
            assert_eq!(never.sample(&mut rng), core::u64::MAX);
            assert_eq!(always.sample(&mut rng), 0);
        }
    }

    /// Compares the empirical mean and variance of 10 000 `Geometric(p)` draws
    /// against `(1 - p) / p` and `(1 - p) / p^2`.
    fn test_geo_mean_and_variance<R: Rng>(p: f64, rng: &mut R) {
        let distr = Geometric::new(p).unwrap();

        let expected_mean = (1.0 - p) / p;
        let expected_variance = (1.0 - p) / (p * p);

        let mut results = [0.0; 10000];
        for i in results.iter_mut() {
            *i = distr.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 40.0);

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_geometric() {
        let mut rng = crate::test::rng(12345);

        // 0.10 / 0.25 / 0.50 exercise the block-decomposition sampler,
        // 0.75 / 0.90 the direct-simulation path (p >= 2/3).
        test_geo_mean_and_variance(0.10, &mut rng);
        test_geo_mean_and_variance(0.25, &mut rng);
        test_geo_mean_and_variance(0.50, &mut rng);
        test_geo_mean_and_variance(0.75, &mut rng);
        test_geo_mean_and_variance(0.90, &mut rng);
    }

    #[test]
    fn test_standard_geometric() {
        let mut rng = crate::test::rng(654321);

        let distr = StandardGeometric;
        let expected_mean = 1.0;
        let expected_variance = 2.0;

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            *i = distr.sample(&mut rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }
}