1use crate::Distribution;
4use rand::Rng;
5use rand::distributions::uniform::Uniform;
6use core::fmt;
7#[allow(unused_imports)]
8use num_traits::Float;
9
/// Sampling strategy selected by [`Hypergeometric::new`] for the normalized
/// parameters `(n1, n2, k)` stored in [`Hypergeometric`].
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
enum SamplingMethod {
    /// Sequential search through the probability mass function, starting at
    /// the lowest value of the support. Chosen when the mode lies fewer than
    /// `HIN_THRESHOLD` (10) steps above that lowest value.
    ///
    /// - `initial_p`: probability of `initial_x`.
    /// - `initial_x`: lowest value of the support (`0`, or `k - n2` when `k >= n2`).
    InverseTransform{ initial_p: f64, initial_x: i64 },
    /// Acceptance/rejection sampling from a hat function made of a uniform
    /// central region `[x_l, x_r)` and two exponential tails.
    RejectionAcceptance{
        /// Mode of the distribution, `floor((k + 1)(n1 + 1) / (n + 2))`.
        m: f64,
        /// Sum of the four (approximate) log-factorials evaluated at `m`;
        /// `exp(a - <same sum at y>)` is `f(y) / f(m)`.
        a: f64,
        /// Decay rate of the left exponential tail.
        lambda_l: f64,
        /// Decay rate of the right exponential tail.
        lambda_r: f64,
        /// Left edge of the central region (`m - d + 0.5`).
        x_l: f64,
        /// Right edge of the central region (`m + d + 0.5`).
        x_r: f64,
        /// Width (and hat area) of the central region, `2 * d`.
        p1: f64,
        /// `p1` plus the hat area of the left tail.
        p2: f64,
        /// `p2` plus the hat area of the right tail, i.e. the total hat area.
        p3: f64
    },
}
26
/// The hypergeometric distribution `Hypergeometric(N, K, n)`.
///
/// A sample is the number of items having a feature among `n` items drawn
/// from a population of `N` items, `K` of which have the feature
/// (`N = total_population_size`, `K = population_with_feature`,
/// `n = sample_size`; see [`Hypergeometric::new`]).
///
/// Internally the parameters are normalized by symmetry so that the feature
/// group is the smaller one and at most half the population is sampled. The
/// drawn value `x` is mapped back to the caller's parameters as
/// `offset_x + sign_x * x`.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Hypergeometric {
    // Smaller of the two sub-population sizes (with / without the feature).
    n1: u64,
    // Larger of the two sub-population sizes.
    n2: u64,
    // Effective sample size, `min(sample_size, N - sample_size)`.
    k: u64,
    // Additive part of the mapping back to the caller's parameterization.
    offset_x: i64,
    // Sign (+1 or -1) applied to the internal sample before adding `offset_x`.
    sign_x: i64,
    // Which algorithm `sample` uses, with its precomputed constants.
    sampling_method: SamplingMethod,
}
58
/// Error type returned from [`Hypergeometric::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `total_population_size` is so large that the starting probability of
    /// the inverse-transform sampler is zero or not finite (floating-point
    /// underflow).
    PopulationTooLarge,
    /// `population_with_feature > total_population_size`.
    ProbabilityTooLarge,
    /// `sample_size > total_population_size`.
    SampleSizeTooLarge,
}
69
70impl fmt::Display for Error {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 f.write_str(match self {
73 Error::PopulationTooLarge => "total_population_size is too large causing underflow in geometric distribution",
74 Error::ProbabilityTooLarge => "population_with_feature > total_population_size in geometric distribution",
75 Error::SampleSizeTooLarge => "sample_size > total_population_size in geometric distribution",
76 })
77 }
78}
79
// `std::error::Error` is only implemented when the `std` feature is enabled;
// with `doc_cfg` set, rustdoc shows that feature requirement on this impl.
// The `Display` impl supplies the message, so no methods are overridden.
#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
83
/// Evaluates `(a! * b!) / (c! * d!)` for `numerator = (a, b)` and
/// `denominator = (c, d)` without materializing any factorial.
///
/// Every integer up to the smallest of the four arguments appears twice in
/// the numerator and twice in the denominator, so those factors cancel and
/// are skipped. The remaining integers are applied one at a time,
/// interleaving multiplications and divisions, which keeps the running value
/// moderate instead of forming two huge products.
fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64, u64)) -> f64 {
    let (top_small, top_large) = if numerator.0 <= numerator.1 {
        (numerator.0, numerator.1)
    } else {
        (numerator.1, numerator.0)
    };
    let (bottom_small, bottom_large) = if denominator.0 <= denominator.1 {
        (denominator.0, denominator.1)
    } else {
        (denominator.1, denominator.0)
    };

    // Factors 1..=first cancel completely; nothing beyond `last` contributes.
    let first = top_small.min(bottom_small);
    let last = top_large.max(bottom_large);

    ((first + 1)..=last).fold(1.0, |mut acc: f64, i| {
        let factor = i as f64;
        if i <= top_small {
            acc *= factor;
        }
        if i <= bottom_small {
            acc /= factor;
        }
        if i <= top_large {
            acc *= factor;
        }
        if i <= bottom_large {
            acc /= factor;
        }
        acc
    })
}
116
/// Approximates `ln(v!)` for a (possibly fractional) `v >= 0`.
///
/// The rejection sampler evaluates log-factorials at non-integer points such
/// as `x_l` and `x_r`, so an exact table cannot be used. Stirling's series
/// `ln(z!) ≈ (z + 1/2) ln z - z + ln √(2π) + 1/(12 z)` is only accurate for
/// moderately large `z`, so it is evaluated at `v + 3` and the extra factors
/// `(v + 1)(v + 2)(v + 3)` are divided back out. The absolute error is at
/// most about `1e-4` (at `v = 0`) and shrinks rapidly as `v` grows.
///
/// The previous form `v ln v - v` omitted the `½ ln(2πv)` and `1/(12v)`
/// terms, which is off by more than 1 for typical arguments. It also produced
/// `NaN` at `v = 0`.
fn ln_of_factorial(v: f64) -> f64 {
    // ln(√(2π))
    const LN_SQRT_2PI: f64 = 0.918_938_533_204_672_8;

    let v3 = v + 3.0;
    let ln_fac_v3 = (v3 + 0.5) * v3.ln() - v3 + LN_SQRT_2PI + 1.0 / (12.0 * v3);
    // ln(v!) = ln((v + 3)!) - ln((v + 1)(v + 2)(v + 3))
    ln_fac_v3 - ((v + 1.0) * (v + 2.0) * (v + 3.0)).ln()
}
122
impl Hypergeometric {
    /// Constructs a new `Hypergeometric` with the shape parameters
    /// `N = total_population_size`,
    /// `K = population_with_feature`,
    /// `n = sample_size`.
    ///
    /// Samples are the number of items with the feature among the `n` drawn.
    ///
    /// # Errors
    ///
    /// - [`Error::ProbabilityTooLarge`] if `population_with_feature > total_population_size`.
    /// - [`Error::SampleSizeTooLarge`] if `sample_size > total_population_size`.
    /// - [`Error::PopulationTooLarge`] if the inverse-transform method is selected and its
    ///   starting probability is zero or not finite (floating-point underflow).
    // Single-letter names are kept to mirror the formulas of the algorithm.
    #[allow(clippy::many_single_char_names)] pub fn new(total_population_size: u64, population_with_feature: u64, sample_size: u64) -> Result<Self, Error> {
        if population_with_feature > total_population_size {
            return Err(Error::ProbabilityTooLarge);
        }

        if sample_size > total_population_size {
            return Err(Error::SampleSizeTooLarge);
        }

        let n = total_population_size;
        // The sample is reported as `offset_x + sign_x * x`, where `x` is drawn
        // for the normalized parameters `(n1, n2, k)` computed below.
        let (mut sign_x, mut offset_x) = (1, 0);
        let (n1, n2) = {
            let population_without_feature = n - population_with_feature;
            // Make `n1` the smaller group. If the feature is the majority, count
            // the items *without* it instead: x -> sample_size - x.
            if population_with_feature > population_without_feature {
                sign_x = -1;
                offset_x = sample_size as i64;
                (population_without_feature, population_with_feature)
            } else {
                (population_with_feature, population_without_feature)
            }
        };
        // Sample at most half the population: the count among the drawn items
        // is `n1` minus the count among the `n - sample_size` undrawn items,
        // so fold that mapping into `offset_x` / `sign_x`.
        let k = if sample_size <= n / 2 {
            sample_size
        } else {
            offset_x += n1 as i64 * sign_x;
            sign_x *= -1;
            n - sample_size
        };

        // If the mode is within this many steps of the lowest support value,
        // a short CDF search (inverse transform) is cheap; otherwise use
        // rejection sampling. NOTE(review): the name suggests the HIN / H2PE
        // algorithms of Kachitvichyanukul & Schmeiser (1985) — confirm.
        const HIN_THRESHOLD: f64 = 10.0;
        // Mode of the distribution.
        let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor();
        let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD {
            // Probability of the lowest support value: x = 0 when k < n2,
            // otherwise x = k - n2.
            let (initial_p, initial_x) = if k < n2 {
                (fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), 0)
            } else {
                (fraction_of_products_of_factorials((n1, k), (n, k - n2)), (k - n2) as i64)
            };

            if initial_p <= 0.0 || !initial_p.is_finite() {
                return Err(Error::PopulationTooLarge);
            }

            SamplingMethod::InverseTransform { initial_p, initial_x }
        } else {
            // f(x) is proportional to 1 / (x! (n1-x)! (k-x)! (n2-k+x)!), so
            // exp(a - <same four terms at y>) approximates f(y) / f(m).
            let a = ln_of_factorial(m) +
                ln_of_factorial(n1 as f64 - m) +
                ln_of_factorial(k as f64 - m) +
                ln_of_factorial((n2 - k) as f64 + m);

            // numerator / denominator is the variance; the central region spans
            // about 1.5 standard deviations on each side of the mode.
            let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64;
            let denominator = (n - 1) as f64 * n as f64 * n as f64;
            let d = 1.5 * (numerator / denominator).sqrt() + 0.5;

            let x_l = m - d + 0.5;
            let x_r = m + d + 0.5;

            // Hat heights where the tails start: f(x_l) / f(m) and f(x_r - 1) / f(m).
            let k_l = f64::exp(a -
                ln_of_factorial(x_l) -
                ln_of_factorial(n1 as f64 - x_l) -
                ln_of_factorial(k as f64 - x_l) -
                ln_of_factorial((n2 - k) as f64 + x_l));
            let k_r = f64::exp(a -
                ln_of_factorial(x_r - 1.0) -
                ln_of_factorial(n1 as f64 - x_r + 1.0) -
                ln_of_factorial(k as f64 - x_r + 1.0) -
                ln_of_factorial((n2 - k) as f64 + x_r - 1.0));

            // Tail decay rates from the ratio of successive probabilities at the
            // edges of the central region.
            let numerator = x_l * ((n2 - k) as f64 + x_l);
            let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0);
            let lambda_l = -((numerator / denominator).ln());

            let numerator = (n1 as f64 - x_r + 1.0) * (k as f64 - x_r + 1.0);
            let denominator = x_r * ((n2 - k) as f64 + x_r);
            let lambda_r = -((numerator / denominator).ln());

            // Cumulative hat areas: central region, + left tail, + right tail.
            let p1 = 2.0 * d;
            let p2 = p1 + k_l / lambda_l;
            let p3 = p2 + k_r / lambda_r;

            SamplingMethod::RejectionAcceptance {
                m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3
            }
        };

        Ok(Hypergeometric { n1, n2, k, offset_x, sign_x, sampling_method })
    }
}
235
236impl Distribution<u64> for Hypergeometric {
237 #[allow(clippy::many_single_char_names)] fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
239 use SamplingMethod::*;
240
241 let Hypergeometric { n1, n2, k, sign_x, offset_x, sampling_method } = *self;
242 let x = match sampling_method {
243 InverseTransform { initial_p: mut p, initial_x: mut x } => {
244 let mut u = rng.gen::<f64>();
245 while u > p && x < k as i64 { u -= p;
247 p *= ((n1 as i64 - x as i64) * (k as i64 - x as i64)) as f64;
248 p /= ((x as i64 + 1) * (n2 as i64 - k as i64 + 1 + x as i64)) as f64;
249 x += 1;
250 }
251 x
252 },
253 RejectionAcceptance { m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 } => {
254 let distr_region_select = Uniform::new(0.0, p3);
255 loop {
256 let (y, v) = loop {
257 let u = distr_region_select.sample(rng);
258 let v = rng.gen::<f64>(); if u <= p1 {
261 let y = (x_l + u).floor();
263 break (y, v);
264 } else if u <= p2 {
265 let y = (x_l + v.ln() / lambda_l).floor();
267 if y as i64 >= i64::max(0, k as i64 - n2 as i64) {
268 let v = v * (u - p1) * lambda_l;
269 break (y, v);
270 }
271 } else {
272 let y = (x_r - v.ln() / lambda_r).floor();
274 if y as u64 <= u64::min(n1, k) {
275 let v = v * (u - p2) * lambda_r;
276 break (y, v);
277 }
278 }
279 };
280
281 if m < 100.0 || y <= 50.0 {
283 let mut f = 1.0;
285 if m < y {
286 for i in (m as u64 + 1)..=(y as u64) {
287 f *= (n1 - i + 1) as f64 * (k - i + 1) as f64;
288 f /= i as f64 * (n2 - k + i) as f64;
289 }
290 } else {
291 for i in (y as u64 + 1)..=(m as u64) {
292 f *= i as f64 * (n2 - k + i) as f64;
293 f /= (n1 - i) as f64 * (k - i) as f64;
294 }
295 }
296
297 if v <= f { break y as i64; }
298 } else {
299 let y1 = y + 1.0;
301 let ym = y - m;
302 let yn = n1 as f64 - y + 1.0;
303 let yk = k as f64 - y + 1.0;
304 let nk = n2 as f64 - k as f64 + y1;
305 let r = -ym / y1;
306 let s = ym / yn;
307 let t = ym / yk;
308 let e = -ym / nk;
309 let g = yn * yk / (y1 * nk) - 1.0;
310 let dg = if g < 0.0 {
311 1.0 + g
312 } else {
313 1.0
314 };
315 let gu = g * (1.0 + g * (-0.5 + g / 3.0));
316 let gl = gu - g.powi(4) / (4.0 * dg);
317 let xm = m + 0.5;
318 let xn = n1 as f64 - m + 0.5;
319 let xk = k as f64 - m + 0.5;
320 let nm = n2 as f64 - k as f64 + xm;
321 let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) +
322 xn * s * (1.0 + s * (-0.5 + s / 3.0)) +
323 xk * t * (1.0 + t * (-0.5 + t / 3.0)) +
324 nm * e * (1.0 + e * (-0.5 + e / 3.0)) +
325 y * gu - m * gl + 0.0034;
326 let av = v.ln();
327 if av > ub { continue; }
328 let dr = if r < 0.0 {
329 xm * r.powi(4) / (1.0 + r)
330 } else {
331 xm * r.powi(4)
332 };
333 let ds = if s < 0.0 {
334 xn * s.powi(4) / (1.0 + s)
335 } else {
336 xn * s.powi(4)
337 };
338 let dt = if t < 0.0 {
339 xk * t.powi(4) / (1.0 + t)
340 } else {
341 xk * t.powi(4)
342 };
343 let de = if e < 0.0 {
344 nm * e.powi(4) / (1.0 + e)
345 } else {
346 nm * e.powi(4)
347 };
348
349 if av < ub - 0.25*(dr + ds + dt + de) + (y + m)*(gl - gu) - 0.0078 {
350 break y as i64;
351 }
352
353 let av_critical = a -
355 ln_of_factorial(y) -
356 ln_of_factorial(n1 as f64 - y) -
357 ln_of_factorial(k as f64 - y) -
358 ln_of_factorial((n2 - k) as f64 + y);
359 if v.ln() <= av_critical {
360 break y as i64;
361 }
362 }
363 }
364 }
365 };
366
367 (offset_x + sign_x * x) as u64
368 }
369}
370
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_hypergeometric_invalid_params() {
        assert!(Hypergeometric::new(100, 101, 5).is_err());
        assert!(Hypergeometric::new(100, 10, 101).is_err());
        assert!(Hypergeometric::new(100, 101, 101).is_err());
        assert!(Hypergeometric::new(100, 10, 5).is_ok());

        // The specific variant tells the caller which argument was wrong.
        assert_eq!(Hypergeometric::new(100, 101, 5).unwrap_err(), Error::ProbabilityTooLarge);
        assert_eq!(Hypergeometric::new(100, 10, 101).unwrap_err(), Error::SampleSizeTooLarge);
        // `population_with_feature` is validated before `sample_size`.
        assert_eq!(Hypergeometric::new(100, 101, 101).unwrap_err(), Error::ProbabilityTooLarge);
    }

    #[test]
    fn test_hypergeometric_degenerate() {
        let mut rng = crate::test::rng(738);
        // No item has the feature: every sample is 0.
        let none = Hypergeometric::new(10, 0, 4).unwrap();
        // Every item has the feature: every sample equals the sample size.
        let all = Hypergeometric::new(10, 10, 4).unwrap();
        // The whole population is drawn: every sample equals `population_with_feature`.
        let whole = Hypergeometric::new(10, 3, 10).unwrap();
        for _ in 0..100 {
            assert_eq!(none.sample(&mut rng), 0);
            assert_eq!(all.sample(&mut rng), 4);
            assert_eq!(whole.sample(&mut rng), 3);
        }
    }

    /// Draws 1000 samples for `Hypergeometric(n, k, s)` and compares the
    /// empirical mean and variance with their closed forms.
    fn test_hypergeometric_mean_and_variance<R: Rng>(n: u64, k: u64, s: u64, rng: &mut R) {
        let distr = Hypergeometric::new(n, k, s).unwrap();

        let (n_f, k_f, s_f) = (n as f64, k as f64, s as f64);
        let expected_mean = s_f * k_f / n_f;
        // Computed in f64: the u64 product s * k * (n - k) * (n - s) overflows
        // for moderately large parameters.
        let expected_variance = s_f * k_f * (n_f - k_f) * (n_f - s_f) / (n_f * n_f * (n_f - 1.0));

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            *i = distr.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_hypergeometric() {
        let mut rng = crate::test::rng(737);

        // Small parameters: inverse-transform method.
        test_hypergeometric_mean_and_variance(500, 400, 30, &mut rng);
        test_hypergeometric_mean_and_variance(250, 200, 230, &mut rng);
        test_hypergeometric_mean_and_variance(100, 20, 6, &mut rng);
        test_hypergeometric_mean_and_variance(50, 10, 47, &mut rng);

        // Larger parameters: rejection-acceptance method.
        test_hypergeometric_mean_and_variance(5000, 2500, 500, &mut rng);
        test_hypergeometric_mean_and_variance(10100, 10000, 1000, &mut rng);
        test_hypergeometric_mean_and_variance(100100, 100, 10000, &mut rng);
    }
}