rand_distr/normal.rs
1// Copyright 2018 Developers of the Rand project.
2// Copyright 2013 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! The normal and derived distributions.
11
12use crate::utils::ziggurat;
13use num_traits::Float;
14use crate::{ziggurat_tables, Distribution, Open01};
15use rand::Rng;
16use core::fmt;
17
/// The standard normal (Gaussian) distribution `N(0, 1)`, producing
/// floating-point samples.
///
/// Sampling from this type gives the same distribution as
/// `Normal::new(0.0, 1.0)`, but is faster. Use `Normal` when an arbitrary
/// mean and standard deviation are required.
///
/// The implementation is the ZIGNOR variant[^1] of the Ziggurat method.
///
/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to
///       Generate Normal Random Samples*](
///       https://www.doornik.com/research/ziggurat.pdf).
///       Nuffield College, Oxford
///
/// # Example
/// ```
/// use rand::prelude::*;
/// use rand_distr::StandardNormal;
///
/// let val: f64 = thread_rng().sample(StandardNormal);
/// println!("{}", val);
/// ```
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct StandardNormal;
42
43impl Distribution<f32> for StandardNormal {
44 #[inline]
45 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
46 // TODO: use optimal 32-bit implementation
47 let x: f64 = self.sample(rng);
48 x as f32
49 }
50}
51
52impl Distribution<f64> for StandardNormal {
53 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
54 #[inline]
55 fn pdf(x: f64) -> f64 {
56 (-x * x / 2.0).exp()
57 }
58 #[inline]
59 fn zero_case<R: Rng + ?Sized>(rng: &mut R, u: f64) -> f64 {
60 // compute a random number in the tail by hand
61
62 // strange initial conditions, because the loop is not
63 // do-while, so the condition should be true on the first
64 // run, they get overwritten anyway (0 < 1, so these are
65 // good).
66 let mut x = 1.0f64;
67 let mut y = 0.0f64;
68
69 while -2.0 * y < x * x {
70 let x_: f64 = rng.sample(Open01);
71 let y_: f64 = rng.sample(Open01);
72
73 x = x_.ln() / ziggurat_tables::ZIG_NORM_R;
74 y = y_.ln();
75 }
76
77 if u < 0.0 {
78 x - ziggurat_tables::ZIG_NORM_R
79 } else {
80 ziggurat_tables::ZIG_NORM_R - x
81 }
82 }
83
84 ziggurat(
85 rng,
86 true, // this is symmetric
87 &ziggurat_tables::ZIG_NORM_X,
88 &ziggurat_tables::ZIG_NORM_F,
89 pdf,
90 zero_case,
91 )
92 }
93}
94
95/// The normal distribution `N(mean, std_dev**2)`.
96///
97/// This uses the ZIGNOR variant of the Ziggurat method, see [`StandardNormal`]
98/// for more details.
99///
100/// Note that [`StandardNormal`] is an optimised implementation for mean 0, and
101/// standard deviation 1.
102///
103/// # Example
104///
105/// ```
106/// use rand_distr::{Normal, Distribution};
107///
108/// // mean 2, standard deviation 3
109/// let normal = Normal::new(2.0, 3.0).unwrap();
110/// let v = normal.sample(&mut rand::thread_rng());
111/// println!("{} is from a N(2, 9) distribution", v)
112/// ```
113///
114/// [`StandardNormal`]: crate::StandardNormal
115#[derive(Clone, Copy, Debug)]
116#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
117pub struct Normal<F>
118where F: Float, StandardNormal: Distribution<F>
119{
120 mean: F,
121 std_dev: F,
122}
123
/// Error type returned from the `Normal` and `LogNormal` constructors.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The mean value is too small: a log-normal distribution built from a
    /// linear-space mean requires that mean to be positive.
    MeanTooSmall,
    /// The standard deviation is not finite, or the coefficient of variation
    /// is negative or not finite.
    BadVariance,
}
132
133impl fmt::Display for Error {
134 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135 f.write_str(match self {
136 Error::MeanTooSmall => "mean < 0 or NaN in log-normal distribution",
137 Error::BadVariance => "variation parameter is non-finite in (log)normal distribution",
138 })
139 }
140}
141
// Marker impl so `Error` can be used as a `std::error::Error`; only compiled
// when the `std` feature is enabled. All trait methods use their defaults.
#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
145
146impl<F> Normal<F>
147where F: Float, StandardNormal: Distribution<F>
148{
149 /// Construct, from mean and standard deviation
150 ///
151 /// Parameters:
152 ///
153 /// - mean (`μ`, unrestricted)
154 /// - standard deviation (`σ`, must be finite)
155 #[inline]
156 pub fn new(mean: F, std_dev: F) -> Result<Normal<F>, Error> {
157 if !std_dev.is_finite() {
158 return Err(Error::BadVariance);
159 }
160 Ok(Normal { mean, std_dev })
161 }
162
163 /// Construct, from mean and coefficient of variation
164 ///
165 /// Parameters:
166 ///
167 /// - mean (`μ`, unrestricted)
168 /// - coefficient of variation (`cv = abs(σ / μ)`)
169 #[inline]
170 pub fn from_mean_cv(mean: F, cv: F) -> Result<Normal<F>, Error> {
171 if !cv.is_finite() || cv < F::zero() {
172 return Err(Error::BadVariance);
173 }
174 let std_dev = cv * mean;
175 Ok(Normal { mean, std_dev })
176 }
177
178 /// Sample from a z-score
179 ///
180 /// This may be useful for generating correlated samples `x1` and `x2`
181 /// from two different distributions, as follows.
182 /// ```
183 /// # use rand::prelude::*;
184 /// # use rand_distr::{Normal, StandardNormal};
185 /// let mut rng = thread_rng();
186 /// let z = StandardNormal.sample(&mut rng);
187 /// let x1 = Normal::new(0.0, 1.0).unwrap().from_zscore(z);
188 /// let x2 = Normal::new(2.0, -3.0).unwrap().from_zscore(z);
189 /// ```
190 #[inline]
191 pub fn from_zscore(&self, zscore: F) -> F {
192 self.mean + self.std_dev * zscore
193 }
194
195 /// Returns the mean (`μ`) of the distribution.
196 pub fn mean(&self) -> F {
197 self.mean
198 }
199
200 /// Returns the standard deviation (`σ`) of the distribution.
201 pub fn std_dev(&self) -> F {
202 self.std_dev
203 }
204}
205
206impl<F> Distribution<F> for Normal<F>
207where F: Float, StandardNormal: Distribution<F>
208{
209 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
210 self.from_zscore(rng.sample(StandardNormal))
211 }
212}
213
214
215/// The log-normal distribution `ln N(mean, std_dev**2)`.
216///
217/// If `X` is log-normal distributed, then `ln(X)` is `N(mean, std_dev**2)`
218/// distributed.
219///
220/// # Example
221///
222/// ```
223/// use rand_distr::{LogNormal, Distribution};
224///
225/// // mean 2, standard deviation 3
226/// let log_normal = LogNormal::new(2.0, 3.0).unwrap();
227/// let v = log_normal.sample(&mut rand::thread_rng());
228/// println!("{} is from an ln N(2, 9) distribution", v)
229/// ```
230#[derive(Clone, Copy, Debug)]
231#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
232pub struct LogNormal<F>
233where F: Float, StandardNormal: Distribution<F>
234{
235 norm: Normal<F>,
236}
237
238impl<F> LogNormal<F>
239where F: Float, StandardNormal: Distribution<F>
240{
241 /// Construct, from (log-space) mean and standard deviation
242 ///
243 /// Parameters are the "standard" log-space measures (these are the mean
244 /// and standard deviation of the logarithm of samples):
245 ///
246 /// - `mu` (`μ`, unrestricted) is the mean of the underlying distribution
247 /// - `sigma` (`σ`, must be finite) is the standard deviation of the
248 /// underlying Normal distribution
249 #[inline]
250 pub fn new(mu: F, sigma: F) -> Result<LogNormal<F>, Error> {
251 let norm = Normal::new(mu, sigma)?;
252 Ok(LogNormal { norm })
253 }
254
255 /// Construct, from (linear-space) mean and coefficient of variation
256 ///
257 /// Parameters are linear-space measures:
258 ///
259 /// - mean (`μ > 0`) is the (real) mean of the distribution
260 /// - coefficient of variation (`cv = σ / μ`, requiring `cv ≥ 0`) is a
261 /// standardized measure of dispersion
262 ///
263 /// As a special exception, `μ = 0, cv = 0` is allowed (samples are `-inf`).
264 #[inline]
265 pub fn from_mean_cv(mean: F, cv: F) -> Result<LogNormal<F>, Error> {
266 if cv == F::zero() {
267 let mu = mean.ln();
268 let norm = Normal::new(mu, F::zero()).unwrap();
269 return Ok(LogNormal { norm });
270 }
271 if !(mean > F::zero()) {
272 return Err(Error::MeanTooSmall);
273 }
274 if !(cv >= F::zero()) {
275 return Err(Error::BadVariance);
276 }
277
278 // Using X ~ lognormal(μ, σ), CV² = Var(X) / E(X)²
279 // E(X) = exp(μ + σ² / 2) = exp(μ) × exp(σ² / 2)
280 // Var(X) = exp(2μ + σ²)(exp(σ²) - 1) = E(X)² × (exp(σ²) - 1)
281 // but Var(X) = (CV × E(X))² so CV² = exp(σ²) - 1
282 // thus σ² = log(CV² + 1)
283 // and exp(μ) = E(X) / exp(σ² / 2) = E(X) / sqrt(CV² + 1)
284 let a = F::one() + cv * cv; // e
285 let mu = F::from(0.5).unwrap() * (mean * mean / a).ln();
286 let sigma = a.ln().sqrt();
287 let norm = Normal::new(mu, sigma)?;
288 Ok(LogNormal { norm })
289 }
290
291 /// Sample from a z-score
292 ///
293 /// This may be useful for generating correlated samples `x1` and `x2`
294 /// from two different distributions, as follows.
295 /// ```
296 /// # use rand::prelude::*;
297 /// # use rand_distr::{LogNormal, StandardNormal};
298 /// let mut rng = thread_rng();
299 /// let z = StandardNormal.sample(&mut rng);
300 /// let x1 = LogNormal::from_mean_cv(3.0, 1.0).unwrap().from_zscore(z);
301 /// let x2 = LogNormal::from_mean_cv(2.0, 4.0).unwrap().from_zscore(z);
302 /// ```
303 #[inline]
304 pub fn from_zscore(&self, zscore: F) -> F {
305 self.norm.from_zscore(zscore).exp()
306 }
307}
308
309impl<F> Distribution<F> for LogNormal<F>
310where F: Float, StandardNormal: Distribution<F>
311{
312 #[inline]
313 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
314 self.norm.sample(rng).exp()
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: sampling a plain normal distribution must not panic.
    #[test]
    fn test_normal() {
        let dist = Normal::new(10.0, 10.0).unwrap();
        let mut rng = crate::test::rng(210);
        for _ in 0..1000 {
            dist.sample(&mut rng);
        }
    }

    // The standard deviation is derived as `cv * mean`.
    #[test]
    fn test_normal_cv() {
        let dist = Normal::from_mean_cv(1024.0, 1.0 / 256.0).unwrap();
        assert_eq!((dist.mean, dist.std_dev), (1024.0, 4.0));
    }

    // A negative coefficient of variation is rejected.
    #[test]
    fn test_normal_invalid_sd() {
        assert!(Normal::from_mean_cv(10.0, -1.0).is_err());
    }

    // Smoke test: sampling a log-normal distribution must not panic.
    #[test]
    fn test_log_normal() {
        let dist = LogNormal::new(10.0, 10.0).unwrap();
        let mut rng = crate::test::rng(211);
        for _ in 0..1000 {
            dist.sample(&mut rng);
        }
    }

    // Checks the log-space parameters derived from linear-space inputs.
    #[test]
    fn test_log_normal_cv() {
        // Special case: zero mean with zero dispersion.
        let dist = LogNormal::from_mean_cv(0.0, 0.0).unwrap();
        assert_eq!((dist.norm.mean, dist.norm.std_dev), (core::f64::NEG_INFINITY, 0.0));

        // Zero dispersion: log-space mean is ln(1) = 0.
        let dist = LogNormal::from_mean_cv(1.0, 0.0).unwrap();
        assert_eq!((dist.norm.mean, dist.norm.std_dev), (0.0, 0.0));

        // cv = sqrt(e - 1) corresponds to a log-space sigma of 1.
        let e = core::f64::consts::E;
        let cv = (e - 1.0).sqrt();

        let dist = LogNormal::from_mean_cv(e.sqrt(), cv).unwrap();
        assert_almost_eq!(dist.norm.mean, 0.0, 2e-16);
        assert_almost_eq!(dist.norm.std_dev, 1.0, 2e-16);

        let dist = LogNormal::from_mean_cv(e.powf(1.5), cv).unwrap();
        assert_almost_eq!(dist.norm.mean, 1.0, 1e-15);
        assert_eq!(dist.norm.std_dev, 1.0);
    }

    // Non-positive means (with non-zero cv) and negative cv are rejected.
    #[test]
    fn test_log_normal_invalid_sd() {
        assert!(LogNormal::from_mean_cv(-1.0, 1.0).is_err());
        assert!(LogNormal::from_mean_cv(0.0, 1.0).is_err());
        assert!(LogNormal::from_mean_cv(1.0, -1.0).is_err());
    }
}
371}