1#![allow(clippy::many_single_char_names)]
15
16use self::ChiSquaredRepr::*;
17use self::GammaRepr::*;
18
19use crate::normal::StandardNormal;
20use num_traits::Float;
21use crate::{Distribution, Exp, Exp1, Open01};
22use rand::Rng;
23use core::fmt;
24#[cfg(feature = "serde1")]
25use serde::{Serialize, Deserialize};
26
/// Sampler for the gamma distribution, parameterised by a shape and a scale.
///
/// Values are created with `Gamma::new`, which validates both parameters and
/// selects an internal sampling strategy from the shape.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Gamma<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Strategy picked by `new`: shape below one, exactly one, or above one.
    repr: GammaRepr<F>,
}
68
/// Error type returned from `Gamma::new`.
///
/// Serialisable under the `serde1` feature, like the other error enums in
/// this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum Error {
    /// `shape > 0` did not hold (zero, negative or NaN).
    ShapeTooSmall,
    /// `scale > 0` did not hold (zero, negative or NaN).
    ScaleTooSmall,
    /// `shape` was exactly one and the rate `1 / scale` was rejected when
    /// building the exponential sampler.
    ScaleTooLarge,
}

impl fmt::Display for Error {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Error::ShapeTooSmall => "shape is not positive in gamma distribution",
            Error::ScaleTooSmall => "scale is not positive in gamma distribution",
            Error::ScaleTooLarge => "scale is infinity in gamma distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
93
/// Internal sampling strategy of `Gamma`, chosen by `Gamma::new` from the
/// shape parameter.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum GammaRepr<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Used when the shape is greater than one.
    Large(GammaLargeShape<F>),
    /// Used when the shape is exactly one; wraps an `Exp` built from `1 / scale`.
    One(Exp<F>),
    /// Used when the shape is below one.
    Small(GammaSmallShape<F>),
}
107
/// Gamma sampler state for shapes below one.
///
/// Sampling multiplies a draw from `large_shape` by an `Open01` uniform
/// raised to the power `inv_shape`.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct GammaSmallShape<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Open01: Distribution<F>,
{
    /// `1 / shape`, precomputed by `new_raw`.
    inv_shape: F,
    /// Large-shape sampler built by `new_raw` with `shape + 1` and the same scale.
    large_shape: GammaLargeShape<F>,
}
133
/// Gamma sampler state for shapes above one (also reused, with `shape + 1`,
/// by `GammaSmallShape`).
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct GammaLargeShape<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Scale parameter; multiplies every accepted sample.
    scale: F,
    /// `1 / sqrt(9 * d)`, precomputed by `new_raw`.
    c: F,
    /// `shape - 1/3`, precomputed by `new_raw`.
    d: F,
}
150
151impl<F> Gamma<F>
152where
153 F: Float,
154 StandardNormal: Distribution<F>,
155 Exp1: Distribution<F>,
156 Open01: Distribution<F>,
157{
158 #[inline]
161 pub fn new(shape: F, scale: F) -> Result<Gamma<F>, Error> {
162 if !(shape > F::zero()) {
163 return Err(Error::ShapeTooSmall);
164 }
165 if !(scale > F::zero()) {
166 return Err(Error::ScaleTooSmall);
167 }
168
169 let repr = if shape == F::one() {
170 One(Exp::new(F::one() / scale).map_err(|_| Error::ScaleTooLarge)?)
171 } else if shape < F::one() {
172 Small(GammaSmallShape::new_raw(shape, scale))
173 } else {
174 Large(GammaLargeShape::new_raw(shape, scale))
175 };
176 Ok(Gamma { repr })
177 }
178}
179
180impl<F> GammaSmallShape<F>
181where
182 F: Float,
183 StandardNormal: Distribution<F>,
184 Open01: Distribution<F>,
185{
186 fn new_raw(shape: F, scale: F) -> GammaSmallShape<F> {
187 GammaSmallShape {
188 inv_shape: F::one() / shape,
189 large_shape: GammaLargeShape::new_raw(shape + F::one(), scale),
190 }
191 }
192}
193
194impl<F> GammaLargeShape<F>
195where
196 F: Float,
197 StandardNormal: Distribution<F>,
198 Open01: Distribution<F>,
199{
200 fn new_raw(shape: F, scale: F) -> GammaLargeShape<F> {
201 let d = shape - F::from(1. / 3.).unwrap();
202 GammaLargeShape {
203 scale,
204 c: F::one() / (F::from(9.).unwrap() * d).sqrt(),
205 d,
206 }
207 }
208}
209
210impl<F> Distribution<F> for Gamma<F>
211where
212 F: Float,
213 StandardNormal: Distribution<F>,
214 Exp1: Distribution<F>,
215 Open01: Distribution<F>,
216{
217 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
218 match self.repr {
219 Small(ref g) => g.sample(rng),
220 One(ref g) => g.sample(rng),
221 Large(ref g) => g.sample(rng),
222 }
223 }
224}
225impl<F> Distribution<F> for GammaSmallShape<F>
226where
227 F: Float,
228 StandardNormal: Distribution<F>,
229 Open01: Distribution<F>,
230{
231 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
232 let u: F = rng.sample(Open01);
233
234 self.large_shape.sample(rng) * u.powf(self.inv_shape)
235 }
236}
237impl<F> Distribution<F> for GammaLargeShape<F>
238where
239 F: Float,
240 StandardNormal: Distribution<F>,
241 Open01: Distribution<F>,
242{
243 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
244 loop {
246 let x: F = rng.sample(StandardNormal);
247 let v_cbrt = F::one() + self.c * x;
248 if v_cbrt <= F::zero() {
249 continue;
251 }
252
253 let v = v_cbrt * v_cbrt * v_cbrt;
254 let u: F = rng.sample(Open01);
255
256 let x_sqr = x * x;
257 if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
258 || u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
259 {
260 return self.d * v * self.scale;
261 }
262 }
263 }
264}
265
/// Sampler for the chi-squared distribution with `k` degrees of freedom.
///
/// Created with `ChiSquared::new`.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ChiSquared<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Sampling strategy chosen by `new` (special case for `k == 1`).
    repr: ChiSquaredRepr<F>,
}
294
/// Error type returned from `ChiSquared::new` and `StudentT::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum ChiSquaredError {
    /// `0.5 * k > 0` did not hold (this includes NaN).
    DoFTooSmall,
}

impl fmt::Display for ChiSquaredError {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            ChiSquaredError::DoFTooSmall => {
                "degrees-of-freedom k is not positive in chi-squared distribution"
            }
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for ChiSquaredError {}
316
/// Internal sampling strategy of `ChiSquared`, chosen by `ChiSquared::new`.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum ChiSquaredRepr<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// `k == 1`: a sample is the square of one standard normal draw.
    DoFExactlyOne,
    /// Any other `k`: samples come from a `Gamma` built with shape `0.5 * k`
    /// and scale `2`.
    DoFAnythingElse(Gamma<F>),
}
332
333impl<F> ChiSquared<F>
334where
335 F: Float,
336 StandardNormal: Distribution<F>,
337 Exp1: Distribution<F>,
338 Open01: Distribution<F>,
339{
340 pub fn new(k: F) -> Result<ChiSquared<F>, ChiSquaredError> {
343 let repr = if k == F::one() {
344 DoFExactlyOne
345 } else {
346 if !(F::from(0.5).unwrap() * k > F::zero()) {
347 return Err(ChiSquaredError::DoFTooSmall);
348 }
349 DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap())
350 };
351 Ok(ChiSquared { repr })
352 }
353}
354impl<F> Distribution<F> for ChiSquared<F>
355where
356 F: Float,
357 StandardNormal: Distribution<F>,
358 Exp1: Distribution<F>,
359 Open01: Distribution<F>,
360{
361 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
362 match self.repr {
363 DoFExactlyOne => {
364 let norm: F = rng.sample(StandardNormal);
366 norm * norm
367 }
368 DoFAnythingElse(ref g) => g.sample(rng),
369 }
370 }
371}
372
/// Sampler for the Fisher F distribution with parameters `m` and `n`.
///
/// Created with `FisherF::new`. A sample is the ratio of a draw from `numer`
/// to a draw from `denom`, multiplied by `dof_ratio`.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct FisherF<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Chi-squared sampler built from `m`.
    numer: ChiSquared<F>,
    /// Chi-squared sampler built from `n`.
    denom: ChiSquared<F>,
    /// `n / m`, precomputed by `new`.
    dof_ratio: F,
}
403
/// Error type returned from `FisherF::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum FisherFError {
    /// `m > 0` did not hold (this includes NaN).
    MTooSmall,
    /// `n > 0` did not hold (this includes NaN).
    NTooSmall,
}

impl fmt::Display for FisherFError {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            FisherFError::MTooSmall => "m is not positive in Fisher F distribution",
            FisherFError::NTooSmall => "n is not positive in Fisher F distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for FisherFError {}
426
427impl<F> FisherF<F>
428where
429 F: Float,
430 StandardNormal: Distribution<F>,
431 Exp1: Distribution<F>,
432 Open01: Distribution<F>,
433{
434 pub fn new(m: F, n: F) -> Result<FisherF<F>, FisherFError> {
436 let zero = F::zero();
437 if !(m > zero) {
438 return Err(FisherFError::MTooSmall);
439 }
440 if !(n > zero) {
441 return Err(FisherFError::NTooSmall);
442 }
443
444 Ok(FisherF {
445 numer: ChiSquared::new(m).unwrap(),
446 denom: ChiSquared::new(n).unwrap(),
447 dof_ratio: n / m,
448 })
449 }
450}
451impl<F> Distribution<F> for FisherF<F>
452where
453 F: Float,
454 StandardNormal: Distribution<F>,
455 Exp1: Distribution<F>,
456 Open01: Distribution<F>,
457{
458 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
459 self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio
460 }
461}
462
/// Sampler for the Student t distribution with `n` degrees of freedom.
///
/// Created with `StudentT::new`.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct StudentT<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Chi-squared sampler built from `n`.
    chi: ChiSquared<F>,
    /// The degrees of freedom `n` as passed to `new`.
    dof: F,
}
487
488impl<F> StudentT<F>
489where
490 F: Float,
491 StandardNormal: Distribution<F>,
492 Exp1: Distribution<F>,
493 Open01: Distribution<F>,
494{
495 pub fn new(n: F) -> Result<StudentT<F>, ChiSquaredError> {
498 Ok(StudentT {
499 chi: ChiSquared::new(n)?,
500 dof: n,
501 })
502 }
503}
504impl<F> Distribution<F> for StudentT<F>
505where
506 F: Float,
507 StandardNormal: Distribution<F>,
508 Exp1: Distribution<F>,
509 Open01: Distribution<F>,
510{
511 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
512 let norm: F = rng.sample(StandardNormal);
513 norm * (self.dof / self.chi.sample(rng)).sqrt()
514 }
515}
516
/// Sampling algorithm used by `Beta`, together with its precomputed constants.
///
/// `Beta::new` picks `BB` when the smaller of the two parameters is greater
/// than one and `BC` otherwise.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum BetaAlgorithm<N> {
    /// Constants for the `BB` sampling loop.
    BB(BB<N>),
    /// Constants for the `BC` sampling loop.
    BC(BC<N>),
}
531
/// Precomputed constants for the `BB` beta sampling loop.
///
/// Below, `a` and `b` are the smaller and larger parameter as ordered by
/// `Beta::new`.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct BB<N> {
    /// `a + b`.
    alpha: N,
    /// `sqrt((alpha - 2) / (2 * a * b - alpha))`.
    beta: N,
    /// `a + 1 / beta`.
    gamma: N,
}
540
/// Precomputed constants for the `BC` beta sampling loop.
///
/// Below, `a` and `b` are the larger and smaller parameter as ordered by
/// `Beta::new` for this algorithm.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct BC<N> {
    /// `a + b`.
    alpha: N,
    /// `1 / b`.
    beta: N,
    /// `1 + a - b`.
    delta: N,
    /// `delta * (1/18/4 + (3/18/4) * b) / (a * beta - 14/18)`; bound used in
    /// the `u1 < 0.5` branch of the sampling loop.
    kappa1: N,
    /// `0.25 + (0.5 + 0.25 / delta) * b`; bound used in the `u1 >= 0.5`
    /// branch of the sampling loop.
    kappa2: N,
}
551
/// Sampler for the beta distribution with parameters `alpha` and `beta`.
///
/// Created with `Beta::new`, which reorders the parameters and precomputes
/// the constants of the chosen sampling algorithm.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Beta<F>
where
    F: Float,
    Open01: Distribution<F>,
{
    /// First parameter after reordering: the smaller one for `BB`, the larger
    /// one for `BC`.
    a: F,
    /// Second parameter after reordering: the larger one for `BB`, the
    /// smaller one for `BC`.
    b: F,
    /// Whether `a` and `b` are swapped relative to the `alpha`/`beta` passed
    /// to `new`; selects the final transformation in `sample`.
    switched_params: bool,
    /// Chosen algorithm and its constants.
    algorithm: BetaAlgorithm<F>,
}
573
/// Error type returned from `Beta::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum BetaError {
    /// `alpha > 0` did not hold (this includes NaN).
    AlphaTooSmall,
    /// `beta > 0` did not hold (this includes NaN).
    BetaTooSmall,
}

impl fmt::Display for BetaError {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            BetaError::AlphaTooSmall => "alpha is not positive in beta distribution",
            BetaError::BetaTooSmall => "beta is not positive in beta distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for BetaError {}
596
597impl<F> Beta<F>
598where
599 F: Float,
600 Open01: Distribution<F>,
601{
602 pub fn new(alpha: F, beta: F) -> Result<Beta<F>, BetaError> {
605 if !(alpha > F::zero()) {
606 return Err(BetaError::AlphaTooSmall);
607 }
608 if !(beta > F::zero()) {
609 return Err(BetaError::BetaTooSmall);
610 }
611 let (a0, b0) = (alpha, beta);
614 let (a, b, switched_params) = if a0 < b0 {
615 (a0, b0, false)
616 } else {
617 (b0, a0, true)
618 };
619 if a > F::one() {
620 let alpha = a + b;
622 let beta = ((alpha - F::from(2.).unwrap())
623 / (F::from(2.).unwrap()*a*b - alpha)).sqrt();
624 let gamma = a + F::one() / beta;
625
626 Ok(Beta {
627 a, b, switched_params,
628 algorithm: BetaAlgorithm::BB(BB {
629 alpha, beta, gamma,
630 })
631 })
632 } else {
633 let (a, b, switched_params) = (b, a, !switched_params);
637 let alpha = a + b;
638 let beta = F::one() / b;
639 let delta = F::one() + a - b;
640 let kappa1 = delta
641 * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap()*b)
642 / (a*beta - F::from(14. / 18.).unwrap());
643 let kappa2 = F::from(0.25).unwrap()
644 + (F::from(0.5).unwrap() + F::from(0.25).unwrap()/delta)*b;
645
646 Ok(Beta {
647 a, b, switched_params,
648 algorithm: BetaAlgorithm::BC(BC {
649 alpha, beta, delta, kappa1, kappa2,
650 })
651 })
652 }
653 }
654}
655
656impl<F> Distribution<F> for Beta<F>
657where
658 F: Float,
659 Open01: Distribution<F>,
660{
661 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
662 let mut w;
663 match self.algorithm {
664 BetaAlgorithm::BB(algo) => {
665 loop {
666 let u1 = rng.sample(Open01);
668 let u2 = rng.sample(Open01);
669 let v = algo.beta * (u1 / (F::one() - u1)).ln();
670 w = self.a * v.exp();
671 let z = u1*u1 * u2;
672 let r = algo.gamma * v - F::from(4.).unwrap().ln();
673 let s = self.a + r - w;
674 if s + F::one() + F::from(5.).unwrap().ln()
676 >= F::from(5.).unwrap() * z {
677 break;
678 }
679 let t = z.ln();
681 if s >= t {
682 break;
683 }
684 if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) {
686 break;
687 }
688 }
689 },
690 BetaAlgorithm::BC(algo) => {
691 loop {
692 let z;
693 let u1 = rng.sample(Open01);
695 let u2 = rng.sample(Open01);
696 if u1 < F::from(0.5).unwrap() {
697 let y = u1 * u2;
699 z = u1 * y;
700 if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 {
701 continue;
702 }
703 } else {
704 z = u1 * u1 * u2;
706 if z <= F::from(0.25).unwrap() {
707 let v = algo.beta * (u1 / (F::one() - u1)).ln();
708 w = self.a * v.exp();
709 break;
710 }
711 if z >= algo.kappa2 {
713 continue;
714 }
715 }
716 let v = algo.beta * (u1 / (F::one() - u1)).ln();
718 w = self.a * v.exp();
719 if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
720 - F::from(4.).unwrap().ln() < z.ln()) {
721 break;
722 };
723 }
724 },
725 };
726 if !self.switched_params {
728 if w == F::infinity() {
729 return F::one();
731 }
732 w / (self.b + w)
733 } else {
734 self.b / (self.b + w)
735 }
736 }
737}
738
#[cfg(test)]
mod test {
    use super::*;

    /// Number of samples each smoke test draws.
    const DRAWS: usize = 1000;

    /// Draws `DRAWS` values from `distr` and discards them; the test passes
    /// as long as sampling does not panic.
    fn draw_many<D, R>(distr: &D, rng: &mut R)
    where
        D: Distribution<f64>,
        R: Rng,
    {
        for _ in 0..DRAWS {
            distr.sample(rng);
        }
    }

    #[test]
    fn test_chi_squared_one() {
        let mut rng = crate::test::rng(201);
        let chi = ChiSquared::new(1.0).unwrap();
        draw_many(&chi, &mut rng);
    }

    #[test]
    fn test_chi_squared_small() {
        let mut rng = crate::test::rng(202);
        let chi = ChiSquared::new(0.5).unwrap();
        draw_many(&chi, &mut rng);
    }

    #[test]
    fn test_chi_squared_large() {
        let mut rng = crate::test::rng(203);
        let chi = ChiSquared::new(30.0).unwrap();
        draw_many(&chi, &mut rng);
    }

    #[test]
    #[should_panic]
    fn test_chi_squared_invalid_dof() {
        ChiSquared::new(-1.0).unwrap();
    }

    #[test]
    fn test_f() {
        let mut rng = crate::test::rng(204);
        let f = FisherF::new(2.0, 32.0).unwrap();
        draw_many(&f, &mut rng);
    }

    #[test]
    fn test_t() {
        let mut rng = crate::test::rng(205);
        let t = StudentT::new(11.0).unwrap();
        draw_many(&t, &mut rng);
    }

    #[test]
    fn test_beta() {
        let mut rng = crate::test::rng(201);
        let beta = Beta::new(1.0, 2.0).unwrap();
        draw_many(&beta, &mut rng);
    }

    #[test]
    #[should_panic]
    fn test_beta_invalid_dof() {
        Beta::new(0., 0.).unwrap();
    }

    /// Very small parameters must never produce NaN samples.
    #[test]
    fn test_beta_small_param() {
        let beta = Beta::<f64>::new(1e-3, 1e-3).unwrap();
        let mut rng = crate::test::rng(206);
        for i in 0..1000 {
            let value = beta.sample(&mut rng);
            assert!(!value.is_nan(), "failed at i={}", i);
        }
    }
}