rand_distr/
skew_normal.rs

1// Copyright 2021 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! The Skew Normal distribution.
10
11use crate::{Distribution, StandardNormal};
12use core::fmt;
13use num_traits::Float;
14use rand::Rng;
15
16/// The [skew normal distribution] `SN(location, scale, shape)`.
17///
18/// The skew normal distribution is a generalization of the
19/// [`Normal`] distribution to allow for non-zero skewness.
20///
21/// It has the density function, for `scale > 0`,
22/// `f(x) = 2 / scale * phi((x - location) / scale) * Phi(alpha * (x - location) / scale)`
23/// where `phi` and `Phi` are the density and distribution of a standard normal variable.
24///
25/// # Example
26///
27/// ```
28/// use rand_distr::{SkewNormal, Distribution};
29///
30/// // location 2, scale 3, shape 1
31/// let skew_normal = SkewNormal::new(2.0, 3.0, 1.0).unwrap();
32/// let v = skew_normal.sample(&mut rand::thread_rng());
33/// println!("{} is from a SN(2, 3, 1) distribution", v)
34/// ```
35///
36/// # Implementation details
37///
38/// We are using the algorithm from [A Method to Simulate the Skew Normal Distribution].
39///
40/// [skew normal distribution]: https://en.wikipedia.org/wiki/Skew_normal_distribution
41/// [`Normal`]: struct.Normal.html
42/// [A Method to Simulate the Skew Normal Distribution]: https://dx.doi.org/10.4236/am.2014.513201
43#[derive(Clone, Copy, Debug)]
44#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
45pub struct SkewNormal<F>
46where
47    F: Float,
48    StandardNormal: Distribution<F>,
49{
50    location: F,
51    scale: F,
52    shape: F,
53}
54
/// Reasons why [`SkewNormal::new`] can reject its parameters.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `scale` is NaN, infinite, or less than or equal to zero.
    ScaleTooSmall,
    /// `shape` is NaN or infinite.
    BadShape,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the human-readable description first, then emit it in one write.
        let description = match *self {
            Error::ScaleTooSmall => {
                "scale parameter is either non-finite or it is less or equal to zero in skew normal distribution"
            }
            Error::BadShape => "shape parameter is non-finite in skew normal distribution",
        };
        f.write_str(description)
    }
}

/// Implemented only when the `std` feature is enabled.
#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
78
79impl<F> SkewNormal<F>
80where
81    F: Float,
82    StandardNormal: Distribution<F>,
83{
84    /// Construct, from location, scale and shape.
85    ///
86    /// Parameters:
87    ///
88    /// -   location (unrestricted)
89    /// -   scale (must be finite and larger than zero)
90    /// -   shape (must be finite)
91    #[inline]
92    pub fn new(location: F, scale: F, shape: F) -> Result<SkewNormal<F>, Error> {
93        if !scale.is_finite() || !(scale > F::zero()) {
94            return Err(Error::ScaleTooSmall);
95        }
96        if !shape.is_finite() {
97            return Err(Error::BadShape);
98        }
99        Ok(SkewNormal {
100            location,
101            scale,
102            shape,
103        })
104    }
105
106    /// Returns the location of the distribution.
107    pub fn location(&self) -> F {
108        self.location
109    }
110
111    /// Returns the scale of the distribution.
112    pub fn scale(&self) -> F {
113        self.scale
114    }
115
116    /// Returns the shape of the distribution.
117    pub fn shape(&self) -> F {
118        self.shape
119    }
120}
121
122impl<F> Distribution<F> for SkewNormal<F>
123where
124    F: Float,
125    StandardNormal: Distribution<F>,
126{
127    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
128        let linear_map = |x: F| -> F { x * self.scale + self.location };
129        let u_1: F = rng.sample(StandardNormal);
130        if self.shape == F::zero() {
131            linear_map(u_1)
132        } else {
133            let u_2 = rng.sample(StandardNormal);
134            let (u, v) = (u_1.max(u_2), u_1.min(u_2));
135            if self.shape == -F::one() {
136                linear_map(v)
137            } else if self.shape == F::one() {
138                linear_map(u)
139            } else {
140                let normalized = ((F::one() + self.shape) * u + (F::one() - self.shape) * v)
141                    / ((F::one() + self.shape * self.shape).sqrt()
142                        * F::from(core::f64::consts::SQRT_2).unwrap());
143                linear_map(normalized)
144            }
145        }
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws four samples from `distr` with a fixed seed and compares them exactly.
    fn test_samples<F: Float + core::fmt::Debug, D: Distribution<F>>(
        distr: D, zero: F, expected: &[F],
    ) {
        let mut rng = crate::test::rng(213);
        let mut buf = [zero; 4];
        for x in &mut buf {
            *x = rng.sample(&distr);
        }
        assert_eq!(buf, expected);
    }

    // The constructor tests check the exact error variant rather than merely
    // expecting a panic, so a wrong variant cannot slip through.

    #[test]
    fn invalid_scale_nan() {
        assert_eq!(
            SkewNormal::new(0.0, core::f64::NAN, 0.0).unwrap_err(),
            Error::ScaleTooSmall
        );
    }

    #[test]
    fn invalid_scale_zero() {
        assert_eq!(
            SkewNormal::new(0.0, 0.0, 0.0).unwrap_err(),
            Error::ScaleTooSmall
        );
    }

    #[test]
    fn invalid_scale_negative() {
        assert_eq!(
            SkewNormal::new(0.0, -1.0, 0.0).unwrap_err(),
            Error::ScaleTooSmall
        );
    }

    #[test]
    fn invalid_scale_infinite() {
        assert_eq!(
            SkewNormal::new(0.0, core::f64::INFINITY, 0.0).unwrap_err(),
            Error::ScaleTooSmall
        );
    }

    #[test]
    fn invalid_shape_nan() {
        assert_eq!(
            SkewNormal::new(0.0, 1.0, core::f64::NAN).unwrap_err(),
            Error::BadShape
        );
    }

    #[test]
    fn invalid_shape_infinite() {
        assert_eq!(
            SkewNormal::new(0.0, 1.0, core::f64::INFINITY).unwrap_err(),
            Error::BadShape
        );
    }

    #[test]
    fn invalid_scale_reported_before_shape() {
        assert_eq!(
            SkewNormal::new(0.0, 0.0, core::f64::NAN).unwrap_err(),
            Error::ScaleTooSmall
        );
    }

    #[test]
    fn valid_location_nan() {
        SkewNormal::new(core::f64::NAN, 1.0, 0.0).unwrap();
    }

    #[test]
    fn valid_shape_large() {
        SkewNormal::new(0.0, 1.0, core::f64::MAX).unwrap();
        SkewNormal::new(0.0, 1.0, -core::f64::MAX).unwrap();
    }

    #[test]
    fn skew_normal_value_stability() {
        test_samples(
            SkewNormal::new(0.0, 1.0, 0.0).unwrap(),
            0f32,
            &[-0.11844189, 0.781378, 0.06563994, -1.1932899],
        );
        test_samples(
            SkewNormal::new(0.0, 1.0, 0.0).unwrap(),
            0f64,
            &[
                -0.11844188827977231,
                0.7813779637772346,
                0.06563993969580051,
                -1.1932899004186373,
            ],
        );
        test_samples(
            SkewNormal::new(core::f64::INFINITY, 1.0, 0.0).unwrap(),
            0f64,
            &[
                core::f64::INFINITY,
                core::f64::INFINITY,
                core::f64::INFINITY,
                core::f64::INFINITY,
            ],
        );
        test_samples(
            SkewNormal::new(core::f64::NEG_INFINITY, 1.0, 0.0).unwrap(),
            0f64,
            &[
                core::f64::NEG_INFINITY,
                core::f64::NEG_INFINITY,
                core::f64::NEG_INFINITY,
                core::f64::NEG_INFINITY,
            ],
        );
    }

    #[test]
    fn skew_normal_value_location_nan() {
        let skew_normal = SkewNormal::new(core::f64::NAN, 1.0, 0.0).unwrap();
        let mut rng = crate::test::rng(213);
        let mut buf = [0.0; 4];
        for x in &mut buf {
            *x = rng.sample(&skew_normal);
        }
        for value in buf.iter() {
            assert!(value.is_nan());
        }
    }

    #[test]
    fn shape_plus_minus_one_are_max_and_min() {
        // Shapes 1 and -1 consume the same two normal draws per sample and return
        // their maximum and minimum respectively, so with identical seeds the
        // positive-shape sample can never be below the negative-shape one.
        let positive = SkewNormal::new(0.0, 1.0, 1.0).unwrap();
        let negative = SkewNormal::new(0.0, 1.0, -1.0).unwrap();
        let mut rng_pos = crate::test::rng(213);
        let mut rng_neg = crate::test::rng(213);
        for _ in 0..16 {
            let hi: f64 = rng_pos.sample(&positive);
            let lo: f64 = rng_neg.sample(&negative);
            assert!(hi >= lo);
        }
    }
}