rand_distr/
inverse_gaussian.rs1use crate::{Distribution, Standard, StandardNormal};
2use num_traits::Float;
3use rand::Rng;
4use core::fmt;
5
/// Error type returned from [`InverseGaussian::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The mean was `<= 0` or NaN.
    MeanNegativeOrNull,
    /// The shape was `<= 0` or NaN.
    ShapeNegativeOrNull,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Select the static message first, then emit it in a single write.
        let message = match *self {
            Error::MeanNegativeOrNull => {
                "mean <= 0 or is NaN in inverse Gaussian distribution"
            }
            Error::ShapeNegativeOrNull => {
                "shape <= 0 or is NaN in inverse Gaussian distribution"
            }
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
27
28#[derive(Debug, Clone, Copy)]
30#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
31pub struct InverseGaussian<F>
32where
33 F: Float,
34 StandardNormal: Distribution<F>,
35 Standard: Distribution<F>,
36{
37 mean: F,
38 shape: F,
39}
40
41impl<F> InverseGaussian<F>
42where
43 F: Float,
44 StandardNormal: Distribution<F>,
45 Standard: Distribution<F>,
46{
47 pub fn new(mean: F, shape: F) -> Result<InverseGaussian<F>, Error> {
50 let zero = F::zero();
51 if !(mean > zero) {
52 return Err(Error::MeanNegativeOrNull);
53 }
54
55 if !(shape > zero) {
56 return Err(Error::ShapeNegativeOrNull);
57 }
58
59 Ok(Self { mean, shape })
60 }
61}
62
63impl<F> Distribution<F> for InverseGaussian<F>
64where
65 F: Float,
66 StandardNormal: Distribution<F>,
67 Standard: Distribution<F>,
68{
69 #[allow(clippy::many_single_char_names)]
70 fn sample<R>(&self, rng: &mut R) -> F
71 where R: Rng + ?Sized {
72 let mu = self.mean;
73 let l = self.shape;
74
75 let v: F = rng.sample(StandardNormal);
76 let y = mu * v * v;
77
78 let mu_2l = mu / (F::from(2.).unwrap() * l);
79
80 let x = mu + mu_2l * (y - (F::from(4.).unwrap() * l * y + y * y).sqrt());
81
82 let u: F = rng.gen();
83
84 if u <= mu / (mu + x) {
85 return x;
86 }
87
88 mu * mu / x
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Samples must lie in the support (0, inf) of the distribution; the old
    /// version of this test discarded every sample and asserted nothing.
    #[test]
    fn test_inverse_gaussian() {
        let inv_gauss = InverseGaussian::new(1.0, 1.0).unwrap();
        let mut rng = crate::test::rng(210);
        for _ in 0..1000 {
            let x: f64 = inv_gauss.sample(&mut rng);
            assert!(x.is_finite() && x > 0.0, "invalid sample {}", x);
        }
    }

    /// Non-positive and NaN parameters are rejected with the variant that
    /// names the offending parameter; `mean` is validated before `shape`.
    #[test]
    fn test_inverse_gaussian_invalid_param() {
        assert!(InverseGaussian::new(-1.0, 1.0).is_err());
        assert!(InverseGaussian::new(-1.0, -1.0).is_err());
        assert!(InverseGaussian::new(1.0, -1.0).is_err());
        assert!(InverseGaussian::new(1.0, 1.0).is_ok());

        assert_eq!(
            InverseGaussian::new(0.0f64, 1.0).err(),
            Some(Error::MeanNegativeOrNull)
        );
        assert_eq!(
            InverseGaussian::new(f64::NAN, 1.0).err(),
            Some(Error::MeanNegativeOrNull)
        );
        assert_eq!(
            InverseGaussian::new(-1.0f64, -1.0).err(),
            Some(Error::MeanNegativeOrNull)
        );
        assert_eq!(
            InverseGaussian::new(1.0f64, 0.0).err(),
            Some(Error::ShapeNegativeOrNull)
        );
        assert_eq!(
            InverseGaussian::new(1.0, f64::NAN).err(),
            Some(Error::ShapeNegativeOrNull)
        );
    }
}