rand_distr/
inverse_gaussian.rs

1use crate::{Distribution, Standard, StandardNormal};
2use num_traits::Float;
3use rand::Rng;
4use core::fmt;
5
/// Reason why [`InverseGaussian::new`] rejected its parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The supplied mean was `<= 0` or NaN.
    MeanNegativeOrNull,
    /// The supplied shape was `<= 0` or NaN.
    ShapeNegativeOrNull,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the fixed message for the variant, then emit it verbatim.
        let message = match *self {
            Self::MeanNegativeOrNull => "mean <= 0 or is NaN in inverse Gaussian distribution",
            Self::ShapeNegativeOrNull => "shape <= 0 or is NaN in inverse Gaussian distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
27
28/// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution)
29#[derive(Debug, Clone, Copy)]
30#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
31pub struct InverseGaussian<F>
32where
33    F: Float,
34    StandardNormal: Distribution<F>,
35    Standard: Distribution<F>,
36{
37    mean: F,
38    shape: F,
39}
40
41impl<F> InverseGaussian<F>
42where
43    F: Float,
44    StandardNormal: Distribution<F>,
45    Standard: Distribution<F>,
46{
47    /// Construct a new `InverseGaussian` distribution with the given mean and
48    /// shape.
49    pub fn new(mean: F, shape: F) -> Result<InverseGaussian<F>, Error> {
50        let zero = F::zero();
51        if !(mean > zero) {
52            return Err(Error::MeanNegativeOrNull);
53        }
54
55        if !(shape > zero) {
56            return Err(Error::ShapeNegativeOrNull);
57        }
58
59        Ok(Self { mean, shape })
60    }
61}
62
63impl<F> Distribution<F> for InverseGaussian<F>
64where
65    F: Float,
66    StandardNormal: Distribution<F>,
67    Standard: Distribution<F>,
68{
69    #[allow(clippy::many_single_char_names)]
70    fn sample<R>(&self, rng: &mut R) -> F
71    where R: Rng + ?Sized {
72        let mu = self.mean;
73        let l = self.shape;
74
75        let v: F = rng.sample(StandardNormal);
76        let y = mu * v * v;
77
78        let mu_2l = mu / (F::from(2.).unwrap() * l);
79
80        let x = mu + mu_2l * (y - (F::from(4.).unwrap() * l * y + y * y).sqrt());
81
82        let u: F = rng.gen();
83
84        if u <= mu / (mu + x) {
85            return x;
86        }
87
88        mu * mu / x
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inverse_gaussian() {
        let inv_gauss = InverseGaussian::new(1.0, 1.0).unwrap();
        let mut rng = crate::test::rng(210);
        let draws: u32 = 1000;
        let mut sum = 0.0;
        for _ in 0..draws {
            let x: f64 = inv_gauss.sample(&mut rng);
            // The inverse Gaussian is supported on (0, ∞).
            assert!(x.is_finite() && x > 0.0, "sample out of support: {}", x);
            sum += x;
        }
        // For μ = λ = 1 the mean is 1 and the variance μ³/λ is 1. The standard
        // error of the mean over 1000 draws is therefore ~0.032, and 0.2 is a
        // ~6σ bound.
        let mean = sum / f64::from(draws);
        assert!((mean - 1.0).abs() < 0.2, "sample mean {} too far from 1", mean);
    }

    #[test]
    fn test_inverse_gaussian_invalid_param() {
        assert!(InverseGaussian::new(-1.0, 1.0).is_err());
        assert!(InverseGaussian::new(-1.0, -1.0).is_err());
        assert!(InverseGaussian::new(1.0, -1.0).is_err());
        assert!(InverseGaussian::new(1.0, 1.0).is_ok());
    }

    #[test]
    fn test_inverse_gaussian_boundary_param() {
        // Zero (of either sign) and NaN are rejected, as documented on `Error`.
        assert_eq!(InverseGaussian::new(0.0, 1.0).unwrap_err(), Error::MeanNegativeOrNull);
        assert_eq!(InverseGaussian::new(-0.0, 1.0).unwrap_err(), Error::MeanNegativeOrNull);
        assert_eq!(InverseGaussian::new(f64::NAN, 1.0).unwrap_err(), Error::MeanNegativeOrNull);
        assert_eq!(InverseGaussian::new(1.0, 0.0).unwrap_err(), Error::ShapeNegativeOrNull);
        assert_eq!(InverseGaussian::new(1.0, f64::NAN).unwrap_err(), Error::ShapeNegativeOrNull);
        // The mean is validated before the shape.
        assert_eq!(InverseGaussian::new(-1.0, -1.0).unwrap_err(), Error::MeanNegativeOrNull);
    }
}