rand_distr/binomial.rs
1// Copyright 2018 Developers of the Rand project.
2// Copyright 2016-2017 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! The binomial distribution.
11
12use crate::{Distribution, Uniform};
13use rand::Rng;
14use core::fmt;
15use core::cmp::Ordering;
16#[allow(unused_imports)]
17use num_traits::Float;
18
/// The binomial distribution `Binomial(n, p)`.
///
/// Models the number of successes among `n` trials that each succeed with
/// probability `p`. This distribution has probability mass function:
/// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for integer `0 <= k <= n`.
///
/// Values are compared field by field (`n` and `p`), so two distributions
/// built from the same parameters compare equal.
///
/// # Example
///
/// ```
/// use rand_distr::{Binomial, Distribution};
///
/// let bin = Binomial::new(20, 0.3).unwrap();
/// let v = bin.sample(&mut rand::thread_rng());
/// println!("{} is from a binomial distribution", v);
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Binomial {
    /// Number of trials.
    n: u64,
    /// Probability of success.
    p: f64,
}
41
/// Reasons why `Binomial::new` can reject its probability argument.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The probability was negative or NaN.
    ProbabilityTooSmall,
    /// The probability was greater than one.
    ProbabilityTooLarge,
}

impl fmt::Display for Error {
    /// Writes a fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Error::ProbabilityTooSmall => "p < 0 or is NaN in binomial distribution",
            Error::ProbabilityTooLarge => "p > 1 in binomial distribution",
        };
        f.write_str(message)
    }
}

// The `std::error::Error` trait only exists when `std` is available, so the
// impl is gated on the `std` feature.
#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
63
64impl Binomial {
65 /// Construct a new `Binomial` with the given shape parameters `n` (number
66 /// of trials) and `p` (probability of success).
67 pub fn new(n: u64, p: f64) -> Result<Binomial, Error> {
68 if !(p >= 0.0) {
69 return Err(Error::ProbabilityTooSmall);
70 }
71 if !(p <= 1.0) {
72 return Err(Error::ProbabilityTooLarge);
73 }
74 Ok(Binomial { n, p })
75 }
76}
77
/// Truncates an `f64` towards zero into an `i64`.
///
/// # Panics
///
/// Panics when `x` is not strictly below `i64::MAX as f64`; this also covers
/// NaN, because every comparison with NaN is false.
///
/// Only the upper bound is checked. Values below the `i64` range are left to
/// the `as` cast, which saturates to `i64::MIN` on Rust 1.45 and later.
fn f64_to_i64(x: f64) -> i64 {
    // Guard the upper end of the range (and NaN) before converting.
    assert!(x < (core::i64::MAX as f64));

    x as i64
}
83
84impl Distribution<u64> for Binomial {
85 #[allow(clippy::many_single_char_names)] // Same names as in the reference.
86 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
87 // Handle these values directly.
88 if self.p == 0.0 {
89 return 0;
90 } else if self.p == 1.0 {
91 return self.n;
92 }
93
94 // The binomial distribution is symmetrical with respect to p -> 1-p,
95 // k -> n-k switch p so that it is less than 0.5 - this allows for lower
96 // expected values we will just invert the result at the end
97 let p = if self.p <= 0.5 { self.p } else { 1.0 - self.p };
98
99 let result;
100 let q = 1. - p;
101
102 // For small n * min(p, 1 - p), the BINV algorithm based on the inverse
103 // transformation of the binomial distribution is efficient. Otherwise,
104 // the BTPE algorithm is used.
105 //
106 // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial
107 // random variate generation. Commun. ACM 31, 2 (February 1988),
108 // 216-222. http://dx.doi.org/10.1145/42372.42381
109
110 // Threshold for preferring the BINV algorithm. The paper suggests 10,
111 // Ranlib uses 30, and GSL uses 14.
112 const BINV_THRESHOLD: f64 = 10.;
113
114 if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (core::i32::MAX as u64) {
115 // Use the BINV algorithm.
116 let s = p / q;
117 let a = ((self.n + 1) as f64) * s;
118 let mut r = q.powi(self.n as i32);
119 let mut u: f64 = rng.gen();
120 let mut x = 0;
121 while u > r as f64 {
122 u -= r;
123 x += 1;
124 r *= a / (x as f64) - s;
125 }
126 result = x;
127 } else {
128 // Use the BTPE algorithm.
129
130 // Threshold for using the squeeze algorithm. This can be freely
131 // chosen based on performance. Ranlib and GSL use 20.
132 const SQUEEZE_THRESHOLD: i64 = 20;
133
134 // Step 0: Calculate constants as functions of `n` and `p`.
135 let n = self.n as f64;
136 let np = n * p;
137 let npq = np * q;
138 let f_m = np + p;
139 let m = f64_to_i64(f_m);
140 // radius of triangle region, since height=1 also area of region
141 let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
142 // tip of triangle
143 let x_m = (m as f64) + 0.5;
144 // left edge of triangle
145 let x_l = x_m - p1;
146 // right edge of triangle
147 let x_r = x_m + p1;
148 let c = 0.134 + 20.5 / (15.3 + (m as f64));
149 // p1 + area of parallelogram region
150 let p2 = p1 * (1. + 2. * c);
151
152 fn lambda(a: f64) -> f64 {
153 a * (1. + 0.5 * a)
154 }
155
156 let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p));
157 let lambda_r = lambda((x_r - f_m) / (x_r * q));
158 // p1 + area of left tail
159 let p3 = p2 + c / lambda_l;
160 // p1 + area of right tail
161 let p4 = p3 + c / lambda_r;
162
163 // return value
164 let mut y: i64;
165
166 let gen_u = Uniform::new(0., p4);
167 let gen_v = Uniform::new(0., 1.);
168
169 loop {
170 // Step 1: Generate `u` for selecting the region. If region 1 is
171 // selected, generate a triangularly distributed variate.
172 let u = gen_u.sample(rng);
173 let mut v = gen_v.sample(rng);
174 if !(u > p1) {
175 y = f64_to_i64(x_m - p1 * v + u);
176 break;
177 }
178
179 if !(u > p2) {
180 // Step 2: Region 2, parallelograms. Check if region 2 is
181 // used. If so, generate `y`.
182 let x = x_l + (u - p1) / c;
183 v = v * c + 1.0 - (x - x_m).abs() / p1;
184 if v > 1. {
185 continue;
186 } else {
187 y = f64_to_i64(x);
188 }
189 } else if !(u > p3) {
190 // Step 3: Region 3, left exponential tail.
191 y = f64_to_i64(x_l + v.ln() / lambda_l);
192 if y < 0 {
193 continue;
194 } else {
195 v *= (u - p2) * lambda_l;
196 }
197 } else {
198 // Step 4: Region 4, right exponential tail.
199 y = f64_to_i64(x_r - v.ln() / lambda_r);
200 if y > 0 && (y as u64) > self.n {
201 continue;
202 } else {
203 v *= (u - p3) * lambda_r;
204 }
205 }
206
207 // Step 5: Acceptance/rejection comparison.
208
209 // Step 5.0: Test for appropriate method of evaluating f(y).
210 let k = (y - m).abs();
211 if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
212 // Step 5.1: Evaluate f(y) via the recursive relationship. Start the
213 // search from the mode.
214 let s = p / q;
215 let a = s * (n + 1.);
216 let mut f = 1.0;
217 match m.cmp(&y) {
218 Ordering::Less => {
219 let mut i = m;
220 loop {
221 i += 1;
222 f *= a / (i as f64) - s;
223 if i == y {
224 break;
225 }
226 }
227 },
228 Ordering::Greater => {
229 let mut i = y;
230 loop {
231 i += 1;
232 f /= a / (i as f64) - s;
233 if i == m {
234 break;
235 }
236 }
237 },
238 Ordering::Equal => {},
239 }
240 if v > f {
241 continue;
242 } else {
243 break;
244 }
245 }
246
247 // Step 5.2: Squeezing. Check the value of ln(v) against upper and
248 // lower bound of ln(f(y)).
249 let k = k as f64;
250 let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5);
251 let t = -0.5 * k * k / npq;
252 let alpha = v.ln();
253 if alpha < t - rho {
254 break;
255 }
256 if alpha > t + rho {
257 continue;
258 }
259
260 // Step 5.3: Final acceptance/rejection test.
261 let x1 = (y + 1) as f64;
262 let f1 = (m + 1) as f64;
263 let z = (f64_to_i64(n) + 1 - m) as f64;
264 let w = (f64_to_i64(n) - y + 1) as f64;
265
266 fn stirling(a: f64) -> f64 {
267 let a2 = a * a;
268 (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
269 }
270
271 if alpha
272 > x_m * (f1 / x1).ln()
273 + (n - (m as f64) + 0.5) * (z / w).ln()
274 + ((y - m) as f64) * (w * p / (x1 * q)).ln()
275 // We use the signs from the GSL implementation, which are
276 // different than the ones in the reference. According to
277 // the GSL authors, the new signs were verified to be
278 // correct by one of the original designers of the
279 // algorithm.
280 + stirling(f1)
281 + stirling(z)
282 - stirling(x1)
283 - stirling(w)
284 {
285 continue;
286 }
287
288 break;
289 }
290 assert!(y >= 0);
291 result = y as u64;
292 }
293
294 // Invert the result for p < 0.5.
295 if p != self.p {
296 self.n - result
297 } else {
298 result
299 }
300 }
301}
302
#[cfg(test)]
mod test {
    use super::Binomial;
    use crate::Distribution;
    use rand::Rng;

    /// Draws 1000 samples from `Binomial(n, p)` and checks that the sample
    /// mean is within 2% and the sample variance within 10% of the
    /// theoretical values `n * p` and `n * p * (1 - p)`.
    fn test_binomial_mean_and_variance<R: Rng>(n: u64, p: f64, rng: &mut R) {
        let distribution = Binomial::new(n, p).unwrap();

        let trials = n as f64;
        let expected_mean = trials * p;
        let expected_variance = trials * p * (1.0 - p);

        let mut samples = [0.0; 1000];
        for slot in samples.iter_mut() {
            *slot = distribution.sample(rng) as f64;
        }
        let count = samples.len() as f64;

        let mean = samples.iter().sum::<f64>() / count;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        let squared_deviations = samples.iter().map(|x| (x - mean) * (x - mean));
        let variance = squared_deviations.sum::<f64>() / count;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_binomial() {
        let mut rng = crate::test::rng(351);
        // (number of trials, success probability); evaluated in this order
        // so that the random stream is consumed deterministically.
        let cases = [(150, 0.1), (70, 0.6), (40, 0.5), (20, 0.7), (20, 0.5)];
        for &(n, p) in cases.iter() {
            test_binomial_mean_and_variance(n, p, &mut rng);
        }
    }

    #[test]
    fn test_binomial_end_points() {
        let mut rng = crate::test::rng(352);

        // p = 0 can never succeed, p = 1 always succeeds.
        let never = Binomial::new(20, 0.0).unwrap();
        assert_eq!(rng.sample(never), 0);

        let always = Binomial::new(20, 1.0).unwrap();
        assert_eq!(rng.sample(always), 20);
    }

    #[test]
    #[should_panic]
    fn test_binomial_invalid_lambda_neg() {
        // A negative probability is rejected, so `unwrap` panics.
        Binomial::new(20, -10.0).unwrap();
    }
}
350}