rand_distr/
normal_inverse_gaussian.rs1use crate::{Distribution, InverseGaussian, Standard, StandardNormal};
2use num_traits::Float;
3use rand::Rng;
4use core::fmt;
5
/// Error type returned from [`NormalInverseGaussian::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// `alpha <= 0` or `alpha` is NaN.
    AlphaNegativeOrNull,
    /// `|beta| >= alpha` or `beta` is NaN.
    AbsoluteBetaNotLessThanAlpha,
}
14
15impl fmt::Display for Error {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 f.write_str(match self {
18 Error::AlphaNegativeOrNull => "alpha <= 0 or is NaN in normal inverse Gaussian distribution",
19 Error::AbsoluteBetaNotLessThanAlpha => "|beta| >= alpha or is NaN in normal inverse Gaussian distribution",
20 })
21 }
22}
23
// Only available with the `std` feature, because `std::error::Error` lives in `std`.
#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
27
28#[derive(Debug, Clone, Copy)]
30#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
31pub struct NormalInverseGaussian<F>
32where
33 F: Float,
34 StandardNormal: Distribution<F>,
35 Standard: Distribution<F>,
36{
37 alpha: F,
38 beta: F,
39 inverse_gaussian: InverseGaussian<F>,
40}
41
42impl<F> NormalInverseGaussian<F>
43where
44 F: Float,
45 StandardNormal: Distribution<F>,
46 Standard: Distribution<F>,
47{
48 pub fn new(alpha: F, beta: F) -> Result<NormalInverseGaussian<F>, Error> {
51 if !(alpha > F::zero()) {
52 return Err(Error::AlphaNegativeOrNull);
53 }
54
55 if !(beta.abs() < alpha) {
56 return Err(Error::AbsoluteBetaNotLessThanAlpha);
57 }
58
59 let gamma = (alpha * alpha - beta * beta).sqrt();
60
61 let mu = F::one() / gamma;
62
63 let inverse_gaussian = InverseGaussian::new(mu, F::one()).unwrap();
64
65 Ok(Self {
66 alpha,
67 beta,
68 inverse_gaussian,
69 })
70 }
71}
72
73impl<F> Distribution<F> for NormalInverseGaussian<F>
74where
75 F: Float,
76 StandardNormal: Distribution<F>,
77 Standard: Distribution<F>,
78{
79 fn sample<R>(&self, rng: &mut R) -> F
80 where R: Rng + ?Sized {
81 let inv_gauss = rng.sample(&self.inverse_gaussian);
82
83 self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal)
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: sampling from a valid distribution must not panic.
    #[test]
    fn test_normal_inverse_gaussian() {
        let norm_inv_gauss = NormalInverseGaussian::new(2.0, 1.0).unwrap();
        let mut rng = crate::test::rng(210);
        for _ in 0..1000 {
            norm_inv_gauss.sample(&mut rng);
        }
    }

    /// Parameter validation: each rejected input maps to the expected error variant.
    #[test]
    fn test_normal_inverse_gaussian_invalid_param() {
        assert!(NormalInverseGaussian::new(-1.0, 1.0).is_err());
        assert!(NormalInverseGaussian::new(-1.0, -1.0).is_err());
        assert!(NormalInverseGaussian::new(1.0, 2.0).is_err());
        assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok());

        // `alpha` must be strictly positive and not NaN; it is checked before `beta`.
        assert_eq!(
            NormalInverseGaussian::new(0.0, 0.0).err(),
            Some(Error::AlphaNegativeOrNull)
        );
        assert_eq!(
            NormalInverseGaussian::new(-1.0, 1.0).err(),
            Some(Error::AlphaNegativeOrNull)
        );
        assert_eq!(
            NormalInverseGaussian::new(f64::NAN, 1.0).err(),
            Some(Error::AlphaNegativeOrNull)
        );

        // `|beta|` must be strictly less than `alpha` and not NaN.
        assert_eq!(
            NormalInverseGaussian::new(1.0, 1.0).err(),
            Some(Error::AbsoluteBetaNotLessThanAlpha)
        );
        assert_eq!(
            NormalInverseGaussian::new(1.0, -1.0).err(),
            Some(Error::AbsoluteBetaNotLessThanAlpha)
        );
        assert_eq!(
            NormalInverseGaussian::new(1.0, 2.0).err(),
            Some(Error::AbsoluteBetaNotLessThanAlpha)
        );
        assert_eq!(
            NormalInverseGaussian::new(1.0, f64::NAN).err(),
            Some(Error::AbsoluteBetaNotLessThanAlpha)
        );
    }
}