1use crate::{Distribution, StandardNormal};
12use core::fmt;
13use num_traits::Float;
14use rand::Rng;
15
/// The skew normal distribution `SN(location, scale, shape)`.
///
/// It is constructed with [`SkewNormal::new`], which requires `scale` to be finite and
/// strictly positive and `shape` to be finite; `location` is not validated.
/// With `shape == 0` the distribution is the normal distribution
/// `N(location, scale²)`.
///
/// NOTE(review): when the `serde1` feature is enabled, the derived `Deserialize`
/// builds the struct directly from its fields and does not go through `new`, so
/// deserialized values are not validated.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct SkewNormal<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
{
    // Added after scaling a standardized draw; `new` accepts any value (including NaN/±inf).
    location: F,
    // Multiplies the standardized draw; `new` requires it to be finite and > 0.
    scale: F,
    // Skewness parameter; `new` requires it to be finite.
    shape: F,
}
54
/// Error type returned from [`SkewNormal::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The scale parameter is NaN, infinite, or less than or equal to zero.
    ScaleTooSmall,
    /// The shape parameter is NaN or infinite.
    BadShape,
}
63
64impl fmt::Display for Error {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 f.write_str(match self {
67 Error::ScaleTooSmall => {
68 "scale parameter is either non-finite or it is less or equal to zero in skew normal distribution"
69 }
70 Error::BadShape => "shape parameter is non-finite in skew normal distribution",
71 })
72 }
73}
74
75#[cfg(feature = "std")]
76#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
77impl std::error::Error for Error {}
78
79impl<F> SkewNormal<F>
80where
81 F: Float,
82 StandardNormal: Distribution<F>,
83{
84 #[inline]
92 pub fn new(location: F, scale: F, shape: F) -> Result<SkewNormal<F>, Error> {
93 if !scale.is_finite() || !(scale > F::zero()) {
94 return Err(Error::ScaleTooSmall);
95 }
96 if !shape.is_finite() {
97 return Err(Error::BadShape);
98 }
99 Ok(SkewNormal {
100 location,
101 scale,
102 shape,
103 })
104 }
105
106 pub fn location(&self) -> F {
108 self.location
109 }
110
111 pub fn scale(&self) -> F {
113 self.scale
114 }
115
116 pub fn shape(&self) -> F {
118 self.shape
119 }
120}
121
122impl<F> Distribution<F> for SkewNormal<F>
123where
124 F: Float,
125 StandardNormal: Distribution<F>,
126{
127 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
128 let linear_map = |x: F| -> F { x * self.scale + self.location };
129 let u_1: F = rng.sample(StandardNormal);
130 if self.shape == F::zero() {
131 linear_map(u_1)
132 } else {
133 let u_2 = rng.sample(StandardNormal);
134 let (u, v) = (u_1.max(u_2), u_1.min(u_2));
135 if self.shape == -F::one() {
136 linear_map(v)
137 } else if self.shape == F::one() {
138 linear_map(u)
139 } else {
140 let normalized = ((F::one() + self.shape) * u + (F::one() - self.shape) * v)
141 / ((F::one() + self.shape * self.shape).sqrt()
142 * F::from(core::f64::consts::SQRT_2).unwrap());
143 linear_map(normalized)
144 }
145 }
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws four samples from `distr` with the crate's test RNG seeded with 213
    /// and asserts they equal `expected` exactly.
    ///
    /// `zero` is only used to initialise the sample buffer with a value of type `F`.
    fn test_samples<F: Float + core::fmt::Debug, D: Distribution<F>>(
        distr: D, zero: F, expected: &[F],
    ) {
        let mut rng = crate::test::rng(213);
        let mut buf = [zero; 4];
        for x in &mut buf {
            *x = rng.sample(&distr);
        }
        assert_eq!(buf, expected);
    }

    // `new` must reject a NaN scale.
    #[test]
    #[should_panic]
    fn invalid_scale_nan() {
        SkewNormal::new(0.0, core::f64::NAN, 0.0).unwrap();
    }

    // `new` must reject a zero scale (strictly positive is required).
    #[test]
    #[should_panic]
    fn invalid_scale_zero() {
        SkewNormal::new(0.0, 0.0, 0.0).unwrap();
    }

    // `new` must reject a negative scale.
    #[test]
    #[should_panic]
    fn invalid_scale_negative() {
        SkewNormal::new(0.0, -1.0, 0.0).unwrap();
    }

    // `new` must reject an infinite scale.
    #[test]
    #[should_panic]
    fn invalid_scale_infinite() {
        SkewNormal::new(0.0, core::f64::INFINITY, 0.0).unwrap();
    }

    // `new` must reject a NaN shape.
    #[test]
    #[should_panic]
    fn invalid_shape_nan() {
        SkewNormal::new(0.0, 1.0, core::f64::NAN).unwrap();
    }

    // `new` must reject an infinite shape.
    #[test]
    #[should_panic]
    fn invalid_shape_infinite() {
        SkewNormal::new(0.0, 1.0, core::f64::INFINITY).unwrap();
    }

    // The location is not validated, so a NaN location is accepted.
    #[test]
    fn valid_location_nan() {
        SkewNormal::new(core::f64::NAN, 1.0, 0.0).unwrap();
    }

    // Pins exact sample values for seed 213. Shape 0 takes the single-normal-draw
    // path; the f32 and f64 expectations are the same draws at different precision.
    // With an infinite location every sample equals that infinity.
    #[test]
    fn skew_normal_value_stability() {
        test_samples(
            SkewNormal::new(0.0, 1.0, 0.0).unwrap(),
            0f32,
            &[-0.11844189, 0.781378, 0.06563994, -1.1932899],
        );
        test_samples(
            SkewNormal::new(0.0, 1.0, 0.0).unwrap(),
            0f64,
            &[
                -0.11844188827977231,
                0.7813779637772346,
                0.06563993969580051,
                -1.1932899004186373,
            ],
        );
        test_samples(
            SkewNormal::new(core::f64::INFINITY, 1.0, 0.0).unwrap(),
            0f64,
            &[
                core::f64::INFINITY,
                core::f64::INFINITY,
                core::f64::INFINITY,
                core::f64::INFINITY,
            ],
        );
        test_samples(
            SkewNormal::new(core::f64::NEG_INFINITY, 1.0, 0.0).unwrap(),
            0f64,
            &[
                core::f64::NEG_INFINITY,
                core::f64::NEG_INFINITY,
                core::f64::NEG_INFINITY,
                core::f64::NEG_INFINITY,
            ],
        );
    }

    // A NaN location propagates: every sample is NaN. NaN != NaN, so the values
    // are checked with `is_nan` rather than compared via `test_samples`.
    #[test]
    fn skew_normal_value_location_nan() {
        let skew_normal = SkewNormal::new(core::f64::NAN, 1.0, 0.0).unwrap();
        let mut rng = crate::test::rng(213);
        let mut buf = [0.0; 4];
        for x in &mut buf {
            *x = rng.sample(&skew_normal);
        }
        for value in buf.iter() {
            assert!(value.is_nan());
        }
    }
}