1#![cfg(feature = "alloc")]
12use num_traits::Float;
13use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal};
14use rand::Rng;
15use core::fmt;
16use alloc::{boxed::Box, vec, vec::Vec};
17
18#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
35#[derive(Clone, Debug)]
36#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
37pub struct Dirichlet<F>
38where
39 F: Float,
40 StandardNormal: Distribution<F>,
41 Exp1: Distribution<F>,
42 Open01: Distribution<F>,
43{
44 alpha: Box<[F]>,
46}
47
/// Reasons why constructing a [`Dirichlet`] distribution can fail.
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Error {
    /// The `alpha` slice given to [`Dirichlet::new`] has fewer than 2 entries.
    AlphaTooShort,
    /// An `alpha` value is not strictly positive (NaN is rejected as well).
    AlphaTooSmall,
    /// The `size` given to [`Dirichlet::new_with_size`] is less than 2.
    SizeTooSmall,
}
59
60impl fmt::Display for Error {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 f.write_str(match self {
63 Error::AlphaTooShort | Error::SizeTooSmall => {
64 "less than 2 dimensions in Dirichlet distribution"
65 }
66 Error::AlphaTooSmall => "alpha is not positive in Dirichlet distribution",
67 })
68 }
69}
70
// The variants carry no payload and wrap no underlying cause, so the trait's
// provided methods suffice; `Debug` is derived and `Display` is implemented
// by hand for this type.
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
#[cfg(feature = "std")]
impl std::error::Error for Error {}
74
75impl<F> Dirichlet<F>
76where
77 F: Float,
78 StandardNormal: Distribution<F>,
79 Exp1: Distribution<F>,
80 Open01: Distribution<F>,
81{
82 #[inline]
86 pub fn new(alpha: &[F]) -> Result<Dirichlet<F>, Error> {
87 if alpha.len() < 2 {
88 return Err(Error::AlphaTooShort);
89 }
90 for &ai in alpha.iter() {
91 if !(ai > F::zero()) {
92 return Err(Error::AlphaTooSmall);
93 }
94 }
95
96 Ok(Dirichlet { alpha: alpha.to_vec().into_boxed_slice() })
97 }
98
99 #[inline]
103 pub fn new_with_size(alpha: F, size: usize) -> Result<Dirichlet<F>, Error> {
104 if !(alpha > F::zero()) {
105 return Err(Error::AlphaTooSmall);
106 }
107 if size < 2 {
108 return Err(Error::SizeTooSmall);
109 }
110 Ok(Dirichlet {
111 alpha: vec![alpha; size].into_boxed_slice(),
112 })
113 }
114}
115
116impl<F> Distribution<Vec<F>> for Dirichlet<F>
117where
118 F: Float,
119 StandardNormal: Distribution<F>,
120 Exp1: Distribution<F>,
121 Open01: Distribution<F>,
122{
123 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<F> {
124 let n = self.alpha.len();
125 let mut samples = vec![F::zero(); n];
126 let mut sum = F::zero();
127
128 for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) {
129 let g = Gamma::new(a, F::one()).unwrap();
130 *s = g.sample(rng);
131 sum = sum + (*s);
132 }
133 let invacc = F::one() / sum;
134 for s in samples.iter_mut() {
135 *s = (*s)*invacc;
136 }
137 samples
138 }
139}
140
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_dirichlet() {
        let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
        let mut rng = crate::test::rng(221);
        let samples: Vec<f64> = d.sample(&mut rng);
        assert_eq!(samples.len(), 3);
        assert!(samples.iter().all(|&x| x > 0.0));
    }

    #[test]
    fn test_dirichlet_with_param() {
        let alpha = 0.5f64;
        let size = 2;
        let d = Dirichlet::new_with_size(alpha, size).unwrap();
        let mut rng = crate::test::rng(221);
        let samples: Vec<f64> = d.sample(&mut rng);
        assert_eq!(samples.len(), size);
        assert!(samples.iter().all(|&x| x > 0.0));
    }

    // Assert the exact error variant rather than `#[should_panic]`, which
    // would accept any panic at all.
    #[test]
    fn test_dirichlet_invalid_length() {
        assert_eq!(
            Dirichlet::new_with_size(0.5f64, 1).unwrap_err(),
            Error::SizeTooSmall
        );
        assert_eq!(Dirichlet::new(&[0.5f64]).unwrap_err(), Error::AlphaTooShort);
    }

    #[test]
    fn test_dirichlet_invalid_alpha() {
        assert_eq!(
            Dirichlet::new_with_size(0.0f64, 2).unwrap_err(),
            Error::AlphaTooSmall
        );
        assert_eq!(
            Dirichlet::new(&[1.0f64, -1.0]).unwrap_err(),
            Error::AlphaTooSmall
        );
        assert_eq!(
            Dirichlet::new(&[1.0, f64::NAN]).unwrap_err(),
            Error::AlphaTooSmall
        );
    }
}