1use num_traits::Float;
11use crate::{Beta, Distribution, Exp1, Open01, StandardNormal};
12use rand::Rng;
13use core::fmt;
14
15#[derive(Clone, Copy, Debug)]
34#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
35pub struct Pert<F>
36where
37 F: Float,
38 StandardNormal: Distribution<F>,
39 Exp1: Distribution<F>,
40 Open01: Distribution<F>,
41{
42 min: F,
43 range: F,
44 beta: Beta<F>,
45}
46
/// Error type returned from [`Pert`] constructors.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum PertError {
    /// `min < max` does not hold (this includes a NaN `min` or `max`), or the
    /// Beta parameters derived from the inputs were rejected.
    RangeTooSmall,
    /// `mode` lies outside `[min, max]`, or `mode` is NaN.
    ModeRange,
    /// `shape` is negative or NaN.
    ShapeTooSmall,
}
57
58impl fmt::Display for PertError {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.write_str(match self {
61 PertError::RangeTooSmall => "requirement min < max is not met in PERT distribution",
62 PertError::ModeRange => "mode is outside [min, max] in PERT distribution",
63 PertError::ShapeTooSmall => "shape < 0 or is NaN in PERT distribution",
64 })
65 }
66}
67
// `std::error::Error` only exists with the standard library, so this impl is
// gated on the `std` feature. Its `Debug + Display` supertraits are supplied by
// the derived `Debug` and the `Display` impl on `PertError`.
#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl ::std::error::Error for PertError {}
71
72impl<F> Pert<F>
73where
74 F: Float,
75 StandardNormal: Distribution<F>,
76 Exp1: Distribution<F>,
77 Open01: Distribution<F>,
78{
79 #[inline]
83 pub fn new(min: F, max: F, mode: F) -> Result<Pert<F>, PertError> {
84 Pert::new_with_shape(min, max, mode, F::from(4.).unwrap())
85 }
86
87 pub fn new_with_shape(min: F, max: F, mode: F, shape: F) -> Result<Pert<F>, PertError> {
90 if !(max > min) {
91 return Err(PertError::RangeTooSmall);
92 }
93 if !(mode >= min && max >= mode) {
94 return Err(PertError::ModeRange);
95 }
96 if !(shape >= F::from(0.).unwrap()) {
97 return Err(PertError::ShapeTooSmall);
98 }
99
100 let range = max - min;
101 let mu = (min + max + shape * mode) / (shape + F::from(2.).unwrap());
102 let v = if mu == mode {
103 shape * F::from(0.5).unwrap() + F::from(1.).unwrap()
104 } else {
105 (mu - min) * (F::from(2.).unwrap() * mode - min - max) / ((mode - mu) * (max - min))
106 };
107 let w = v * (max - mu) / (mu - min);
108 let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?;
109 Ok(Pert { min, range, beta })
110 }
111}
112
113impl<F> Distribution<F> for Pert<F>
114where
115 F: Float,
116 StandardNormal: Distribution<F>,
117 Exp1: Distribution<F>,
118 Open01: Distribution<F>,
119{
120 #[inline]
121 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
122 self.beta.sample(rng) * self.range + self.min
123 }
124}
125
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_pert() {
        for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] {
            let _distr = Pert::new(min, max, mode).unwrap();
        }

        for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
            assert!(Pert::new(min, max, mode).is_err());
        }
    }

    /// Each invalid-parameter case reports its specific `PertError` variant,
    /// including NaN inputs.
    #[test]
    fn test_pert_error_variants() {
        let nan = f64::NAN;
        assert_eq!(Pert::new(nan, 1., 0.5).err(), Some(PertError::RangeTooSmall));
        assert_eq!(Pert::new(0., nan, 0.5).err(), Some(PertError::RangeTooSmall));
        assert_eq!(Pert::new(1., 1., 1.).err(), Some(PertError::RangeTooSmall));
        assert_eq!(Pert::new(0., 1., nan).err(), Some(PertError::ModeRange));
        assert_eq!(Pert::new(0., 1., 2.).err(), Some(PertError::ModeRange));
        assert_eq!(Pert::new_with_shape(0., 1., 0.5, -1.).err(), Some(PertError::ShapeTooSmall));
        assert_eq!(Pert::new_with_shape(0., 1., 0.5, nan).err(), Some(PertError::ShapeTooSmall));
    }

    /// Samples always land inside `[min, max]`, including when `mode` sits on a bound.
    #[test]
    fn test_pert_samples_in_range() {
        let mut rng = crate::test::rng(0x5045_5254);
        for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] {
            let distr = Pert::new(min, max, mode).unwrap();
            for _ in 0..1000 {
                let x: f64 = distr.sample(&mut rng);
                assert!(min <= x && x <= max, "sample {} outside [{}, {}]", x, min, max);
            }
        }
    }

    /// Regression test: a mode one ulp-step away from the midpoint must still give
    /// Beta(3, 3) (variance 1/28 on [0, 1]). The old cancellation-prone formula
    /// produced Beta(2, 2) here (variance 1/20).
    #[test]
    fn test_pert_mode_near_midpoint() {
        let distr = Pert::new(0., 1., 0.5 + f64::EPSILON).unwrap();
        let mut rng = crate::test::rng(0x5045_5254);
        let n = 10_000u32;
        let (mut sum, mut sum_sq) = (0.0, 0.0);
        for _ in 0..n {
            let x: f64 = distr.sample(&mut rng);
            sum += x;
            sum_sq += x * x;
        }
        let mean = sum / f64::from(n);
        let var = sum_sq / f64::from(n) - mean * mean;
        assert!(mean > 0.49 && mean < 0.51, "mean {}", mean);
        assert!(var > 1. / 28. - 0.004 && var < 1. / 28. + 0.004, "variance {}", var);
    }
}
149}