rand_distr/
hypergeometric.rs

1//! The hypergeometric distribution.
2
3use crate::Distribution;
4use rand::Rng;
5use rand::distributions::uniform::Uniform;
6use core::fmt;
7#[allow(unused_imports)]
8use num_traits::Float;
9
/// Which of the two reference algorithms `Hypergeometric::sample` runs,
/// together with the constants `Hypergeometric::new` precomputes for it.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
enum SamplingMethod {
    /// Algorithm HIN: a sequential inverse-transform search.
    InverseTransform {
        /// Probability of the first value visited by the search.
        initial_p: f64,
        /// First value visited by the search.
        initial_x: i64,
    },
    /// Algorithm H2PE: acceptance/rejection against a hat made of a central
    /// uniform region and two exponential tails.
    RejectionAcceptance {
        /// Mode of the distribution.
        m: f64,
        /// Sum of log-factorial terms evaluated at the mode.
        a: f64,
        /// Decay rate of the left exponential tail.
        lambda_l: f64,
        /// Decay rate of the right exponential tail.
        lambda_r: f64,
        /// Left edge of the central region.
        x_l: f64,
        /// Right edge of the central region.
        x_r: f64,
        /// Area of the central region (its width; the hat height there is 1).
        p1: f64,
        /// `p1` plus the area of the left tail.
        p2: f64,
        /// `p2` plus the area of the right tail, i.e. the total hat area.
        p3: f64,
    },
}
26
27/// The hypergeometric distribution `Hypergeometric(N, K, n)`.
28/// 
29/// This is the distribution of successes in samples of size `n` drawn without
30/// replacement from a population of size `N` containing `K` success states.
31/// It has the density function:
32/// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`,
33/// where `binomial(a, b) = a! / (b! * (a - b)!)`.
34/// 
35/// The [binomial distribution](crate::Binomial) is the analogous distribution
36/// for sampling with replacement. It is a good approximation when the population
37/// size is much larger than the sample size.
38/// 
39/// # Example
40/// 
41/// ```
42/// use rand_distr::{Distribution, Hypergeometric};
43///
44/// let hypergeo = Hypergeometric::new(60, 24, 7).unwrap();
45/// let v = hypergeo.sample(&mut rand::thread_rng());
46/// println!("{} is from a hypergeometric distribution", v);
47/// ```
48#[derive(Copy, Clone, Debug)]
49#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
50pub struct Hypergeometric {
51    n1: u64,
52    n2: u64,
53    k: u64,
54    offset_x: i64,
55    sign_x: i64,
56    sampling_method: SamplingMethod,
57}
58
/// Error type returned from [`Hypergeometric::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `total_population_size` is too large, causing floating point underflow.
    PopulationTooLarge,
    /// `population_with_feature > total_population_size`.
    ProbabilityTooLarge,
    /// `sample_size > total_population_size`.
    SampleSizeTooLarge,
}

impl fmt::Display for Error {
    /// Writes a short, human-readable description naming the offending
    /// parameter(s) and the hypergeometric distribution.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Error::PopulationTooLarge => {
                "total_population_size is too large causing underflow in hypergeometric distribution"
            }
            Error::ProbabilityTooLarge => {
                "population_with_feature > total_population_size in hypergeometric distribution"
            }
            Error::SampleSizeTooLarge => {
                "sample_size > total_population_size in hypergeometric distribution"
            }
        })
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
83
/// Evaluates `(a! * b!) / (c! * d!)` for `numerator = (a, b)` and
/// `denominator = (c, d)` without forming any factorial explicitly.
///
/// Every integer up to the smallest of the four arguments occurs in all four
/// factorials and cancels, so only the integers above that bound are visited.
/// Each is multiplied in or divided out once per factorial it belongs to.
/// Multiplications and divisions are interleaved, so intermediate values stay
/// far smaller than the factorials themselves (which overflow `f64` past 170!).
fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64, u64)) -> f64 {
    let (top_lo, top_hi) = (numerator.0.min(numerator.1), numerator.0.max(numerator.1));
    let (bottom_lo, bottom_hi) = (
        denominator.0.min(denominator.1),
        denominator.0.max(denominator.1),
    );

    let first = top_lo.min(bottom_lo) + 1;
    let last = top_hi.max(bottom_hi);

    (first..=last).fold(1.0, |acc: f64, i: u64| {
        let factor = i as f64;
        let mut acc = acc;
        if i <= top_lo {
            acc *= factor;
        }
        if i <= bottom_lo {
            acc /= factor;
        }
        if i <= top_hi {
            acc *= factor;
        }
        if i <= bottom_hi {
            acc /= factor;
        }
        acc
    })
}
116
/// Approximates `ln(v!)` (i.e. `ln Γ(v + 1)`) for real `v > -1`.
///
/// Algorithm H2PE needs log-factorials at fractional arguments, so Stirling's
/// series is used instead of an exact factorial. The truncated form
/// `v ln v - v` omits the `½·ln(2πv)` term. Its error varies with `v` and
/// does not cancel in the differences H2PE takes, which skews both the hat
/// constants and the final acceptance test; it is also NaN at `v = 0`.
///
/// Stirling's series is least accurate for small arguments, so it is
/// evaluated at `v + 3` and shifted back with
/// `ln(v!) = ln((v + 3)!) - ln((v + 1)(v + 2)(v + 3))`.
fn ln_of_factorial(v: f64) -> f64 {
    let shifted = v + 3.0;
    // ln(s!) ≈ (s + ½)·ln(s) − s + ½·ln(2π) + 1/(12s)
    let ln_shifted_factorial = (shifted + 0.5) * shifted.ln() - shifted
        + 0.5 * (2.0 * core::f64::consts::PI).ln()
        + 1.0 / (12.0 * shifted);
    ln_shifted_factorial - (shifted * (v + 2.0) * (v + 1.0)).ln()
}
122
123impl Hypergeometric {
124    /// Constructs a new `Hypergeometric` with the shape parameters
125    /// `N = total_population_size`,
126    /// `K = population_with_feature`,
127    /// `n = sample_size`.
128    #[allow(clippy::many_single_char_names)] // Same names as in the reference.
129    pub fn new(total_population_size: u64, population_with_feature: u64, sample_size: u64) -> Result<Self, Error> {
130        if population_with_feature > total_population_size {
131            return Err(Error::ProbabilityTooLarge);
132        }
133
134        if sample_size > total_population_size {
135            return Err(Error::SampleSizeTooLarge);
136        }
137
138        // set-up constants as function of original parameters
139        let n = total_population_size;
140        let (mut sign_x, mut offset_x) = (1, 0);
141        let (n1, n2) = {
142            // switch around success and failure states if necessary to ensure n1 <= n2
143            let population_without_feature = n - population_with_feature;
144            if population_with_feature > population_without_feature {
145                sign_x = -1;
146                offset_x = sample_size as i64;
147                (population_without_feature, population_with_feature)
148            } else {
149                (population_with_feature, population_without_feature)
150            }
151        };
152        // when sampling more than half the total population, take the smaller
153        // group as sampled instead (we can then return n1-x instead).
154        // 
155        // Note: the boundary condition given in the paper is `sample_size < n / 2`;
156        // we're deviating here, because when n is even, it doesn't matter whether
157        // we switch here or not, but when n is odd `n/2 < n - n/2`, so switching
158        // when `k == n/2`, we'd actually be taking the _larger_ group as sampled.
159        let k = if sample_size <= n / 2 {
160            sample_size
161        } else {
162            offset_x += n1 as i64 * sign_x;
163            sign_x *= -1;
164            n - sample_size
165        };
166
167        // Algorithm H2PE has bounded runtime only if `M - max(0, k-n2) >= 10`,
168        // where `M` is the mode of the distribution.
169        // Use algorithm HIN for the remaining parameter space.
170        // 
171        // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1985. Computer
172        // generation of hypergeometric random variates.
173        // J. Statist. Comput. Simul. Vol.22 (August 1985), 127-145
174        // https://www.researchgate.net/publication/233212638
175        const HIN_THRESHOLD: f64 = 10.0;
176        let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor();
177        let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD {
178            let (initial_p, initial_x) = if k < n2 {
179                (fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), 0)
180            } else {
181                (fraction_of_products_of_factorials((n1, k), (n, k - n2)), (k - n2) as i64)
182            };
183
184            if initial_p <= 0.0 || !initial_p.is_finite() {
185                return Err(Error::PopulationTooLarge);
186            }
187
188            SamplingMethod::InverseTransform { initial_p, initial_x }
189        } else {
190            let a = ln_of_factorial(m) +
191                ln_of_factorial(n1 as f64 - m) +
192                ln_of_factorial(k as f64 - m) +
193                ln_of_factorial((n2 - k) as f64 + m);
194
195            let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64;
196            let denominator = (n - 1) as f64 * n as f64 * n as f64;
197            let d = 1.5 * (numerator / denominator).sqrt() + 0.5;
198
199            let x_l = m - d + 0.5;
200            let x_r = m + d + 0.5;
201
202            let k_l = f64::exp(a -
203                ln_of_factorial(x_l) -
204                ln_of_factorial(n1 as f64 - x_l) -
205                ln_of_factorial(k as f64 - x_l) -
206                ln_of_factorial((n2 - k) as f64 + x_l));
207            let k_r = f64::exp(a -
208                ln_of_factorial(x_r - 1.0) -
209                ln_of_factorial(n1 as f64 - x_r + 1.0) -
210                ln_of_factorial(k as f64 - x_r + 1.0) -
211                ln_of_factorial((n2 - k) as f64 + x_r - 1.0));
212            
213            let numerator = x_l * ((n2 - k) as f64 + x_l);
214            let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0);
215            let lambda_l = -((numerator / denominator).ln());
216
217            let numerator = (n1 as f64 - x_r + 1.0) * (k as f64 - x_r + 1.0);
218            let denominator = x_r * ((n2 - k) as f64 + x_r);
219            let lambda_r = -((numerator / denominator).ln());
220
221            // the paper literally gives `p2 + kL/lambdaL` where it (probably)
222            // should have been `p2 <- p1 + kL/lambdaL`; another print error?!
223            let p1 = 2.0 * d;
224            let p2 = p1 + k_l / lambda_l;
225            let p3 = p2 + k_r / lambda_r;
226
227            SamplingMethod::RejectionAcceptance {
228                m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3
229            }
230        };
231
232        Ok(Hypergeometric { n1, n2, k, offset_x, sign_x, sampling_method })
233    }
234}
235
236impl Distribution<u64> for Hypergeometric {
237    #[allow(clippy::many_single_char_names)] // Same names as in the reference.
238    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
239        use SamplingMethod::*;
240
241        let Hypergeometric { n1, n2, k, sign_x, offset_x, sampling_method } = *self;
242        let x = match sampling_method {
243            InverseTransform { initial_p: mut p, initial_x: mut x } => {
244                let mut u = rng.gen::<f64>();
245                while u > p && x < k as i64 { // the paper erroneously uses `until n < p`, which doesn't make any sense
246                    u -= p;
247                    p *= ((n1 as i64 - x as i64) * (k as i64 - x as i64)) as f64;
248                    p /= ((x as i64 + 1) * (n2 as i64 - k as i64 + 1 + x as i64)) as f64;
249                    x += 1;
250                }
251                x
252            },
253            RejectionAcceptance { m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 } => {
254                let distr_region_select = Uniform::new(0.0, p3);
255                loop {
256                    let (y, v) = loop {
257                        let u = distr_region_select.sample(rng);
258                        let v = rng.gen::<f64>(); // for the accept/reject decision
259            
260                        if u <= p1 {
261                            // Region 1, central bell
262                            let y = (x_l + u).floor();
263                            break (y, v);
264                        } else if u <= p2 {
265                            // Region 2, left exponential tail
266                            let y = (x_l + v.ln() / lambda_l).floor();
267                            if y as i64 >= i64::max(0, k as i64 - n2 as i64) {
268                                let v = v * (u - p1) * lambda_l;
269                                break (y, v);
270                            }
271                        } else {
272                            // Region 3, right exponential tail
273                            let y = (x_r - v.ln() / lambda_r).floor();
274                            if y as u64 <= u64::min(n1, k) {
275                                let v = v * (u - p2) * lambda_r;
276                                break (y, v);
277                            }
278                        }
279                    };
280        
281                    // Step 4: Acceptance/Rejection Comparison
282                    if m < 100.0 || y <= 50.0 {
283                        // Step 4.1: evaluate f(y) via recursive relationship
284                        let mut f = 1.0;
285                        if m < y {
286                            for i in (m as u64 + 1)..=(y as u64) {
287                                f *= (n1 - i + 1) as f64 * (k - i + 1) as f64;
288                                f /= i as f64 * (n2 - k + i) as f64;
289                            }
290                        } else {
291                            for i in (y as u64 + 1)..=(m as u64) {
292                                f *= i as f64 * (n2 - k + i) as f64;
293                                f /= (n1 - i) as f64 * (k - i) as f64;
294                            }
295                        }
296        
297                        if v <= f { break y as i64; }
298                    } else {
299                        // Step 4.2: Squeezing
300                        let y1 = y + 1.0;
301                        let ym = y - m;
302                        let yn = n1 as f64 - y + 1.0;
303                        let yk = k as f64 - y + 1.0;
304                        let nk = n2 as f64 - k as f64 + y1;
305                        let r = -ym / y1;
306                        let s = ym / yn;
307                        let t = ym / yk;
308                        let e = -ym / nk;
309                        let g = yn * yk / (y1 * nk) - 1.0;
310                        let dg = if g < 0.0 {
311                            1.0 + g
312                        } else {
313                            1.0
314                        };
315                        let gu = g * (1.0 + g * (-0.5 + g / 3.0));
316                        let gl = gu - g.powi(4) / (4.0 * dg);
317                        let xm = m + 0.5;
318                        let xn = n1 as f64 - m + 0.5;
319                        let xk = k as f64 - m + 0.5;
320                        let nm = n2 as f64 - k as f64 + xm;
321                        let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) +
322                            xn * s * (1.0 + s * (-0.5 + s / 3.0)) +
323                            xk * t * (1.0 + t * (-0.5 + t / 3.0)) +
324                            nm * e * (1.0 + e * (-0.5 + e / 3.0)) +
325                            y * gu - m * gl + 0.0034;
326                        let av = v.ln();
327                        if av > ub { continue; }
328                        let dr = if r < 0.0 {
329                            xm * r.powi(4) / (1.0 + r)
330                        } else {
331                            xm * r.powi(4)
332                        };
333                        let ds = if s < 0.0 {
334                            xn * s.powi(4) / (1.0 + s)
335                        } else {
336                            xn * s.powi(4)
337                        };
338                        let dt = if t < 0.0 {
339                            xk * t.powi(4) / (1.0 + t)
340                        } else {
341                            xk * t.powi(4)
342                        };
343                        let de = if e < 0.0 {
344                            nm * e.powi(4) / (1.0 + e)
345                        } else {
346                            nm * e.powi(4)
347                        };
348        
349                        if av < ub - 0.25*(dr + ds + dt + de) + (y + m)*(gl - gu) - 0.0078 {
350                            break y as i64;
351                        }
352        
353                        // Step 4.3: Final Acceptance/Rejection Test
354                        let av_critical = a -
355                            ln_of_factorial(y) -
356                            ln_of_factorial(n1 as f64 - y) - 
357                            ln_of_factorial(k as f64 - y) - 
358                            ln_of_factorial((n2 - k) as f64 + y);
359                        if v.ln() <= av_critical {
360                            break y as i64;
361                        }
362                    }
363                }
364            }
365        };
366
367        (offset_x + sign_x * x) as u64
368    }
369}
370
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_hypergeometric_invalid_params() {
        assert!(Hypergeometric::new(100, 101, 5).is_err());
        assert!(Hypergeometric::new(100, 10, 101).is_err());
        assert!(Hypergeometric::new(100, 101, 101).is_err());
        assert!(Hypergeometric::new(100, 10, 5).is_ok());
    }

    /// Draws samples from `Hypergeometric(n, k, s)` and checks the sample mean
    /// and variance against their closed-form values.
    fn test_hypergeometric_mean_and_variance<R: Rng>(n: u64, k: u64, s: u64, rng: &mut R) {
        let distr = Hypergeometric::new(n, k, s).unwrap();

        // Evaluated in f64: the integer product `s * k * (n - k) * (n - s)`
        // can overflow u64 once the population reaches a few hundred thousand.
        let (nf, kf, sf) = (n as f64, k as f64, s as f64);
        let expected_mean = sf * kf / nf;
        let expected_variance = sf * kf * (nf - kf) * (nf - sf) / (nf * nf * (nf - 1.0));

        let mut results = [0.0; 10_000];
        for i in results.iter_mut() {
            *i = distr.sample(rng) as f64;
        }
        let count = results.len() as f64;

        let mean = results.iter().sum::<f64>() / count;
        // Five standard errors of the sample mean: the check is equally strict
        // whatever the magnitude of the mean.
        let mean_tolerance = 5.0 * (expected_variance / count).sqrt();
        assert!((mean - expected_mean).abs() < mean_tolerance);

        let variance = results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / count;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_hypergeometric() {
        let mut rng = crate::test::rng(737);

        // exercise algorithm HIN:
        test_hypergeometric_mean_and_variance(500, 400, 30, &mut rng);
        test_hypergeometric_mean_and_variance(250, 200, 230, &mut rng);
        test_hypergeometric_mean_and_variance(100, 20, 6, &mut rng);
        test_hypergeometric_mean_and_variance(50, 10, 47, &mut rng);

        // exercise algorithm H2PE
        test_hypergeometric_mean_and_variance(5000, 2500, 500, &mut rng);
        test_hypergeometric_mean_and_variance(10100, 10000, 1000, &mut rng);
        test_hypergeometric_mean_and_variance(100100, 100, 10000, &mut rng);
    }

    #[test]
    fn test_hypergeometric_support() {
        let mut rng = crate::test::rng(738);
        // (N, K, n): both algorithms, plus each combination of the
        // success/failure swap and the sampled/unsampled swap done in `new`.
        let cases: [(u64, u64, u64); 7] = [
            (100, 5, 50),
            (100, 95, 50),
            (100, 5, 95),
            (100, 95, 95),
            (50, 10, 47),
            (5000, 2500, 500),
            (10100, 10000, 1000),
        ];
        for &(n, k, s) in cases.iter() {
            let distr = Hypergeometric::new(n, k, s).unwrap();
            let lowest = s.saturating_sub(n - k);
            let highest = u64::min(k, s);
            for _ in 0..1000 {
                let x = distr.sample(&mut rng);
                assert!(
                    lowest <= x && x <= highest,
                    "Hypergeometric({}, {}, {}) produced {}",
                    n, k, s, x
                );
            }
        }
    }
}
422}