1use crate::{Distribution, OpenClosed01};
12use core::fmt;
13use num_traits::Float;
14use rand::Rng;
15
16#[derive(Clone, Copy, Debug)]
31#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
32pub struct Gumbel<F>
33where
34 F: Float,
35 OpenClosed01: Distribution<F>,
36{
37 location: F,
38 scale: F,
39}
40
/// Error type returned from [`Gumbel::new`] when a parameter is invalid.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The `location` parameter is infinite or NaN.
    LocationNotFinite,
    /// The `scale` parameter is zero, negative, infinite or NaN.
    ScaleNotPositive,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Arms follow the declaration order of the variants.
        let msg = match self {
            Error::LocationNotFinite => "location is not finite in Gumbel distribution",
            Error::ScaleNotPositive => "scale is not positive and finite in Gumbel distribution",
        };
        f.write_str(msg)
    }
}

// Only compiled with the `std` feature; when `doc_cfg` is set, rustdoc
// annotates the impl as requiring that feature.
#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
62
63impl<F> Gumbel<F>
64where
65 F: Float,
66 OpenClosed01: Distribution<F>,
67{
68 pub fn new(location: F, scale: F) -> Result<Gumbel<F>, Error> {
70 if scale <= F::zero() || scale.is_infinite() || scale.is_nan() {
71 return Err(Error::ScaleNotPositive);
72 }
73 if location.is_infinite() || location.is_nan() {
74 return Err(Error::LocationNotFinite);
75 }
76 Ok(Gumbel { location, scale })
77 }
78}
79
80impl<F> Distribution<F> for Gumbel<F>
81where
82 F: Float,
83 OpenClosed01: Distribution<F>,
84{
85 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
86 let x: F = rng.sample(OpenClosed01);
87 self.location - self.scale * (-x.ln()).ln()
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    // Each invalid-parameter test pins the specific error variant, rather than
    // accepting any panic as `#[should_panic]` would.

    #[test]
    fn test_zero_scale() {
        assert_eq!(Gumbel::new(0.0, 0.0).unwrap_err(), Error::ScaleNotPositive);
    }

    #[test]
    fn test_negative_scale() {
        assert_eq!(Gumbel::new(0.0, -1.0).unwrap_err(), Error::ScaleNotPositive);
    }

    #[test]
    fn test_infinite_scale() {
        assert_eq!(
            Gumbel::new(0.0, core::f64::INFINITY).unwrap_err(),
            Error::ScaleNotPositive
        );
    }

    #[test]
    fn test_nan_scale() {
        assert_eq!(
            Gumbel::new(0.0, core::f64::NAN).unwrap_err(),
            Error::ScaleNotPositive
        );
    }

    #[test]
    fn test_infinite_location() {
        assert_eq!(
            Gumbel::new(core::f64::INFINITY, 1.0).unwrap_err(),
            Error::LocationNotFinite
        );
    }

    #[test]
    fn test_nan_location() {
        assert_eq!(
            Gumbel::new(core::f64::NAN, 1.0).unwrap_err(),
            Error::LocationNotFinite
        );
    }

    #[test]
    fn test_sample_against_cdf() {
        // Quantile function of the standard Gumbel distribution
        // (location 0, scale 1): Q(p) = -ln(-ln p).
        fn neg_log_log(p: f64) -> f64 {
            -(-p.ln()).ln()
        }

        let iterations: u32 = 100_000;
        let probabilities = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let mut quantiles = [0.0; 9];
        for (q, &p) in quantiles.iter_mut().zip(probabilities.iter()) {
            *q = neg_log_log(p);
        }

        let d = Gumbel::new(0.0, 1.0).unwrap();
        let mut rng = crate::test::rng(1);

        // counts[i] = number of samples strictly below quantiles[i]. Integer
        // counting avoids accumulating rounding error from repeated `1/n` adds.
        let mut counts = [0u32; 9];
        for _ in 0..iterations {
            let replicate = d.sample(&mut rng);
            for (count, &q) in counts.iter_mut().zip(quantiles.iter()) {
                if replicate < q {
                    *count += 1;
                }
            }
        }

        for (&count, &p) in counts.iter().zip(probabilities.iter()) {
            let p_hat = f64::from(count) / f64::from(iterations);
            assert!(
                (p_hat - p).abs() < 0.003,
                "empirical CDF {} differs from expected probability {}",
                p_hat,
                p
            );
        }
    }
}