rand_distr/
pert.rs

1// Copyright 2018 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//! The PERT distribution.
9
10use num_traits::Float;
11use crate::{Beta, Distribution, Exp1, Open01, StandardNormal};
12use rand::Rng;
13use core::fmt;
14
15/// The PERT distribution.
16///
17/// Similar to the [`Triangular`] distribution, the PERT distribution is
18/// parameterised by a range and a mode within that range. Unlike the
19/// [`Triangular`] distribution, the probability density function of the PERT
20/// distribution is smooth, with a configurable weighting around the mode.
21///
22/// # Example
23///
24/// ```rust
25/// use rand_distr::{Pert, Distribution};
26///
27/// let d = Pert::new(0., 5., 2.5).unwrap();
28/// let v = d.sample(&mut rand::thread_rng());
29/// println!("{} is from a PERT distribution", v);
30/// ```
31///
32/// [`Triangular`]: crate::Triangular
33#[derive(Clone, Copy, Debug)]
34#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
35pub struct Pert<F>
36where
37    F: Float,
38    StandardNormal: Distribution<F>,
39    Exp1: Distribution<F>,
40    Open01: Distribution<F>,
41{
42    min: F,
43    range: F,
44    beta: Beta<F>,
45}
46
/// Error type returned from [`Pert`] constructors.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PertError {
    /// `max <= min`, or `min` or `max` is NaN.
    ///
    /// Also returned when the remaining parameters do not produce a valid
    /// underlying Beta distribution.
    RangeTooSmall,
    /// `mode < min` or `mode > max` or `mode` is NaN.
    ModeRange,
    /// `shape < 0` or `shape` is NaN.
    ShapeTooSmall,
}
57
58impl fmt::Display for PertError {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        f.write_str(match self {
61            PertError::RangeTooSmall => "requirement min < max is not met in PERT distribution",
62            PertError::ModeRange => "mode is outside [min, max] in PERT distribution",
63            PertError::ShapeTooSmall => "shape < 0 or is NaN in PERT distribution",
64        })
65    }
66}
67
68#[cfg(feature = "std")]
69#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
70impl std::error::Error for PertError {}
71
72impl<F> Pert<F>
73where
74    F: Float,
75    StandardNormal: Distribution<F>,
76    Exp1: Distribution<F>,
77    Open01: Distribution<F>,
78{
79    /// Set up the PERT distribution with defined `min`, `max` and `mode`.
80    ///
81    /// This is equivalent to calling `Pert::new_shape` with `shape == 4.0`.
82    #[inline]
83    pub fn new(min: F, max: F, mode: F) -> Result<Pert<F>, PertError> {
84        Pert::new_with_shape(min, max, mode, F::from(4.).unwrap())
85    }
86
87    /// Set up the PERT distribution with defined `min`, `max`, `mode` and
88    /// `shape`.
89    pub fn new_with_shape(min: F, max: F, mode: F, shape: F) -> Result<Pert<F>, PertError> {
90        if !(max > min) {
91            return Err(PertError::RangeTooSmall);
92        }
93        if !(mode >= min && max >= mode) {
94            return Err(PertError::ModeRange);
95        }
96        if !(shape >= F::from(0.).unwrap()) {
97            return Err(PertError::ShapeTooSmall);
98        }
99
100        let range = max - min;
101        let mu = (min + max + shape * mode) / (shape + F::from(2.).unwrap());
102        let v = if mu == mode {
103            shape * F::from(0.5).unwrap() + F::from(1.).unwrap()
104        } else {
105            (mu - min) * (F::from(2.).unwrap() * mode - min - max) / ((mode - mu) * (max - min))
106        };
107        let w = v * (max - mu) / (mu - min);
108        let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?;
109        Ok(Pert { min, range, beta })
110    }
111}
112
113impl<F> Distribution<F> for Pert<F>
114where
115    F: Float,
116    StandardNormal: Distribution<F>,
117    Exp1: Distribution<F>,
118    Open01: Distribution<F>,
119{
120    #[inline]
121    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
122        self.beta.sample(rng) * self.range + self.min
123    }
124}
125
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_pert() {
        for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] {
            let _distr = Pert::new(min, max, mode).unwrap();
        }

        for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
            assert!(Pert::new(min, max, mode).is_err());
        }
    }

    #[test]
    fn test_pert_error_variants() {
        // An empty range is rejected, not only a reversed one.
        assert_eq!(Pert::<f64>::new(1., 1., 1.).err(), Some(PertError::RangeTooSmall));
        assert_eq!(
            Pert::<f64>::new(f64::NAN, 1., 0.5).err(),
            Some(PertError::RangeTooSmall)
        );
        assert_eq!(Pert::<f64>::new(0., 1., f64::NAN).err(), Some(PertError::ModeRange));
        assert_eq!(
            Pert::<f64>::new_with_shape(0., 1., 0.5, -1.).err(),
            Some(PertError::ShapeTooSmall)
        );
        assert_eq!(
            Pert::<f64>::new_with_shape(0., 1., 0.5, f64::NAN).err(),
            Some(PertError::ShapeTooSmall)
        );
        assert!(Pert::<f64>::new_with_shape(0., 1., 0.5, f64::INFINITY).is_err());
        // A zero shape is accepted (it describes a uniform distribution).
        assert!(Pert::<f64>::new_with_shape(0., 1., 0.5, 0.).is_ok());
    }

    #[test]
    fn test_pert_samples_in_range() {
        let mut rng = rand::thread_rng();
        for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] {
            let distr = Pert::<f64>::new(min, max, mode).unwrap();
            for _ in 0..1000 {
                let x = distr.sample(&mut rng);
                assert!(x >= min && x <= max, "{} outside [{}, {}]", x, min, max);
            }
        }
    }

    #[test]
    fn test_pert_mean() {
        // PERT mean is (min + max + shape * mode) / (shape + 2); here
        // (2 + 10 + 4 * 3) / 6 = 4, with variance 12 / 7, so the standard
        // error of a 10_000-sample mean is about 0.013.
        let distr = Pert::<f64>::new(2., 10., 3.).unwrap();
        let mut rng = rand::thread_rng();
        let n = 10_000;
        let mut sum = 0.0;
        for _ in 0..n {
            sum += distr.sample(&mut rng);
        }
        let mean = sum / n as f64;
        assert!(mean > 3.9 && mean < 4.1, "sample mean = {}", mean);
    }
}