rand_distr/
normal.rs

1// Copyright 2018 Developers of the Rand project.
2// Copyright 2013 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! The normal and derived distributions.
11
12use crate::utils::ziggurat;
13use num_traits::Float;
14use crate::{ziggurat_tables, Distribution, Open01};
15use rand::Rng;
16use core::fmt;
17
/// The standard normal distribution `N(0, 1)` over floating-point numbers
/// (also known as the standard Gaussian).
///
/// Sampling from this type is equivalent to sampling from
/// `Normal::new(0.0, 1.0)`, but faster. Use `Normal` when a different mean
/// or standard deviation is required.
///
/// Samples are generated with the ZIGNOR variant[^1] of the Ziggurat method.
///
/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to
///       Generate Normal Random Samples*](
///       https://www.doornik.com/research/ziggurat.pdf).
///       Nuffield College, Oxford
///
/// # Example
/// ```
/// use rand::prelude::*;
/// use rand_distr::StandardNormal;
///
/// let val: f64 = thread_rng().sample(StandardNormal);
/// println!("{}", val);
/// ```
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct StandardNormal;
42
43impl Distribution<f32> for StandardNormal {
44    #[inline]
45    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
46        // TODO: use optimal 32-bit implementation
47        let x: f64 = self.sample(rng);
48        x as f32
49    }
50}
51
52impl Distribution<f64> for StandardNormal {
53    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
54        #[inline]
55        fn pdf(x: f64) -> f64 {
56            (-x * x / 2.0).exp()
57        }
58        #[inline]
59        fn zero_case<R: Rng + ?Sized>(rng: &mut R, u: f64) -> f64 {
60            // compute a random number in the tail by hand
61
62            // strange initial conditions, because the loop is not
63            // do-while, so the condition should be true on the first
64            // run, they get overwritten anyway (0 < 1, so these are
65            // good).
66            let mut x = 1.0f64;
67            let mut y = 0.0f64;
68
69            while -2.0 * y < x * x {
70                let x_: f64 = rng.sample(Open01);
71                let y_: f64 = rng.sample(Open01);
72
73                x = x_.ln() / ziggurat_tables::ZIG_NORM_R;
74                y = y_.ln();
75            }
76
77            if u < 0.0 {
78                x - ziggurat_tables::ZIG_NORM_R
79            } else {
80                ziggurat_tables::ZIG_NORM_R - x
81            }
82        }
83
84        ziggurat(
85            rng,
86            true, // this is symmetric
87            &ziggurat_tables::ZIG_NORM_X,
88            &ziggurat_tables::ZIG_NORM_F,
89            pdf,
90            zero_case,
91        )
92    }
93}
94
95/// The normal distribution `N(mean, std_dev**2)`.
96///
97/// This uses the ZIGNOR variant of the Ziggurat method, see [`StandardNormal`]
98/// for more details.
99///
100/// Note that [`StandardNormal`] is an optimised implementation for mean 0, and
101/// standard deviation 1.
102///
103/// # Example
104///
105/// ```
106/// use rand_distr::{Normal, Distribution};
107///
108/// // mean 2, standard deviation 3
109/// let normal = Normal::new(2.0, 3.0).unwrap();
110/// let v = normal.sample(&mut rand::thread_rng());
111/// println!("{} is from a N(2, 9) distribution", v)
112/// ```
113///
114/// [`StandardNormal`]: crate::StandardNormal
115#[derive(Clone, Copy, Debug)]
116#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
117pub struct Normal<F>
118where F: Float, StandardNormal: Distribution<F>
119{
120    mean: F,
121    std_dev: F,
122}
123
/// Error returned by the `Normal` and `LogNormal` constructors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The mean value is too small (log-normal samples must be positive).
    MeanTooSmall,
    /// The standard deviation or other dispersion parameter is invalid:
    /// it is not finite, or a coefficient of variation is negative.
    BadVariance,
}

impl fmt::Display for Error {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Error::MeanTooSmall => "mean < 0 or NaN in log-normal distribution",
            Error::BadVariance => "variation parameter is non-finite in (log)normal distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
145
146impl<F> Normal<F>
147where F: Float, StandardNormal: Distribution<F>
148{
149    /// Construct, from mean and standard deviation
150    ///
151    /// Parameters:
152    ///
153    /// -   mean (`μ`, unrestricted)
154    /// -   standard deviation (`σ`, must be finite)
155    #[inline]
156    pub fn new(mean: F, std_dev: F) -> Result<Normal<F>, Error> {
157        if !std_dev.is_finite() {
158            return Err(Error::BadVariance);
159        }
160        Ok(Normal { mean, std_dev })
161    }
162
163    /// Construct, from mean and coefficient of variation
164    ///
165    /// Parameters:
166    ///
167    /// -   mean (`μ`, unrestricted)
168    /// -   coefficient of variation (`cv = abs(σ / μ)`)
169    #[inline]
170    pub fn from_mean_cv(mean: F, cv: F) -> Result<Normal<F>, Error> {
171        if !cv.is_finite() || cv < F::zero() {
172            return Err(Error::BadVariance);
173        }
174        let std_dev = cv * mean;
175        Ok(Normal { mean, std_dev })
176    }
177
178    /// Sample from a z-score
179    ///
180    /// This may be useful for generating correlated samples `x1` and `x2`
181    /// from two different distributions, as follows.
182    /// ```
183    /// # use rand::prelude::*;
184    /// # use rand_distr::{Normal, StandardNormal};
185    /// let mut rng = thread_rng();
186    /// let z = StandardNormal.sample(&mut rng);
187    /// let x1 = Normal::new(0.0, 1.0).unwrap().from_zscore(z);
188    /// let x2 = Normal::new(2.0, -3.0).unwrap().from_zscore(z);
189    /// ```
190    #[inline]
191    pub fn from_zscore(&self, zscore: F) -> F {
192        self.mean + self.std_dev * zscore
193    }
194
195    /// Returns the mean (`μ`) of the distribution.
196    pub fn mean(&self) -> F {
197        self.mean
198    }
199
200    /// Returns the standard deviation (`σ`) of the distribution.
201    pub fn std_dev(&self) -> F {
202        self.std_dev
203    }
204}
205
206impl<F> Distribution<F> for Normal<F>
207where F: Float, StandardNormal: Distribution<F>
208{
209    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
210        self.from_zscore(rng.sample(StandardNormal))
211    }
212}
213
214
215/// The log-normal distribution `ln N(mean, std_dev**2)`.
216///
217/// If `X` is log-normal distributed, then `ln(X)` is `N(mean, std_dev**2)`
218/// distributed.
219///
220/// # Example
221///
222/// ```
223/// use rand_distr::{LogNormal, Distribution};
224///
225/// // mean 2, standard deviation 3
226/// let log_normal = LogNormal::new(2.0, 3.0).unwrap();
227/// let v = log_normal.sample(&mut rand::thread_rng());
228/// println!("{} is from an ln N(2, 9) distribution", v)
229/// ```
230#[derive(Clone, Copy, Debug)]
231#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
232pub struct LogNormal<F>
233where F: Float, StandardNormal: Distribution<F>
234{
235    norm: Normal<F>,
236}
237
238impl<F> LogNormal<F>
239where F: Float, StandardNormal: Distribution<F>
240{
241    /// Construct, from (log-space) mean and standard deviation
242    ///
243    /// Parameters are the "standard" log-space measures (these are the mean
244    /// and standard deviation of the logarithm of samples):
245    ///
246    /// -   `mu` (`μ`, unrestricted) is the mean of the underlying distribution
247    /// -   `sigma` (`σ`, must be finite) is the standard deviation of the
248    ///     underlying Normal distribution
249    #[inline]
250    pub fn new(mu: F, sigma: F) -> Result<LogNormal<F>, Error> {
251        let norm = Normal::new(mu, sigma)?;
252        Ok(LogNormal { norm })
253    }
254
255    /// Construct, from (linear-space) mean and coefficient of variation
256    ///
257    /// Parameters are linear-space measures:
258    ///
259    /// -   mean (`μ > 0`) is the (real) mean of the distribution
260    /// -   coefficient of variation (`cv = σ / μ`, requiring `cv ≥ 0`) is a
261    ///     standardized measure of dispersion
262    ///
263    /// As a special exception, `μ = 0, cv = 0` is allowed (samples are `-inf`).
264    #[inline]
265    pub fn from_mean_cv(mean: F, cv: F) -> Result<LogNormal<F>, Error> {
266        if cv == F::zero() {
267            let mu = mean.ln();
268            let norm = Normal::new(mu, F::zero()).unwrap();
269            return Ok(LogNormal { norm });
270        }
271        if !(mean > F::zero()) {
272            return Err(Error::MeanTooSmall);
273        }
274        if !(cv >= F::zero()) {
275            return Err(Error::BadVariance);
276        }
277
278        // Using X ~ lognormal(μ, σ), CV² = Var(X) / E(X)²
279        // E(X) = exp(μ + σ² / 2) = exp(μ) × exp(σ² / 2)
280        // Var(X) = exp(2μ + σ²)(exp(σ²) - 1) = E(X)² × (exp(σ²) - 1)
281        // but Var(X) = (CV × E(X))² so CV² = exp(σ²) - 1
282        // thus σ² = log(CV² + 1)
283        // and exp(μ) = E(X) / exp(σ² / 2) = E(X) / sqrt(CV² + 1)
284        let a = F::one() + cv * cv; // e
285        let mu = F::from(0.5).unwrap() * (mean * mean / a).ln();
286        let sigma = a.ln().sqrt();
287        let norm = Normal::new(mu, sigma)?;
288        Ok(LogNormal { norm })
289    }
290
291    /// Sample from a z-score
292    ///
293    /// This may be useful for generating correlated samples `x1` and `x2`
294    /// from two different distributions, as follows.
295    /// ```
296    /// # use rand::prelude::*;
297    /// # use rand_distr::{LogNormal, StandardNormal};
298    /// let mut rng = thread_rng();
299    /// let z = StandardNormal.sample(&mut rng);
300    /// let x1 = LogNormal::from_mean_cv(3.0, 1.0).unwrap().from_zscore(z);
301    /// let x2 = LogNormal::from_mean_cv(2.0, 4.0).unwrap().from_zscore(z);
302    /// ```
303    #[inline]
304    pub fn from_zscore(&self, zscore: F) -> F {
305        self.norm.from_zscore(zscore).exp()
306    }
307}
308
309impl<F> Distribution<F> for LogNormal<F>
310where F: Float, StandardNormal: Distribution<F>
311{
312    #[inline]
313    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
314        self.norm.sample(rng).exp()
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: sampling from a valid `Normal` does not panic.
    #[test]
    fn test_normal() {
        let dist = Normal::new(10.0, 10.0).unwrap();
        let mut rng = crate::test::rng(210);
        for _ in 0..1000 {
            dist.sample(&mut rng);
        }
    }

    /// `from_mean_cv` stores the mean unchanged and `cv * mean` as the
    /// standard deviation.
    #[test]
    fn test_normal_cv() {
        let dist = Normal::from_mean_cv(1024.0, 1.0 / 256.0).unwrap();
        assert_eq!(dist.mean, 1024.0);
        assert_eq!(dist.std_dev, 4.0);
    }

    /// A negative coefficient of variation is rejected.
    #[test]
    fn test_normal_invalid_sd() {
        let result = Normal::from_mean_cv(10.0, -1.0);
        assert!(result.is_err());
    }

    /// Smoke test: sampling from a valid `LogNormal` does not panic.
    #[test]
    fn test_log_normal() {
        let dist = LogNormal::new(10.0, 10.0).unwrap();
        let mut rng = crate::test::rng(211);
        for _ in 0..1000 {
            dist.sample(&mut rng);
        }
    }

    /// Checks the log-space parameters derived from linear-space
    /// mean / coefficient-of-variation pairs.
    #[test]
    fn test_log_normal_cv() {
        // Special case: zero mean with zero dispersion.
        let dist = LogNormal::from_mean_cv(0.0, 0.0).unwrap();
        assert_eq!(dist.norm.mean, core::f64::NEG_INFINITY);
        assert_eq!(dist.norm.std_dev, 0.0);

        // Zero dispersion around a mean of one.
        let dist = LogNormal::from_mean_cv(1.0, 0.0).unwrap();
        assert_eq!(dist.norm.mean, 0.0);
        assert_eq!(dist.norm.std_dev, 0.0);

        let e = core::f64::consts::E;
        let cv = (e - 1.0).sqrt();

        // Expected log-space parameters: mean 0, standard deviation 1.
        let dist = LogNormal::from_mean_cv(e.sqrt(), cv).unwrap();
        assert_almost_eq!(dist.norm.mean, 0.0, 2e-16);
        assert_almost_eq!(dist.norm.std_dev, 1.0, 2e-16);

        // Expected log-space parameters: mean 1, standard deviation 1.
        let dist = LogNormal::from_mean_cv(e.powf(1.5), cv).unwrap();
        assert_almost_eq!(dist.norm.mean, 1.0, 1e-15);
        assert_eq!(dist.norm.std_dev, 1.0);
    }

    /// Non-positive means (with non-zero cv) and negative cv are rejected.
    #[test]
    fn test_log_normal_invalid_sd() {
        let negative_mean = LogNormal::from_mean_cv(-1.0, 1.0);
        assert!(negative_mean.is_err());

        let zero_mean = LogNormal::from_mean_cv(0.0, 1.0);
        assert!(zero_mean.is_err());

        let negative_cv = LogNormal::from_mean_cv(1.0, -1.0);
        assert!(negative_cv.is_err());
    }
}