1use num_traits::Float;
12use crate::{Distribution, Standard};
13use rand::{Rng, distributions::OpenClosed01};
14use core::fmt;
15
16#[derive(Clone, Copy, Debug)]
50pub struct Zeta<F>
51where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
52{
53 a_minus_1: F,
54 b: F,
55}
56
/// Error returned by `Zeta::new` when its parameter is invalid.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ZetaError {
    /// The parameter `a` was not strictly greater than 1 (NaN included).
    ATooSmall,
}

impl fmt::Display for ZetaError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            ZetaError::ATooSmall => "a <= 1 or is NaN in Zeta distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for ZetaError {}
75
76impl<F> Zeta<F>
77where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
78{
79 #[inline]
81 pub fn new(a: F) -> Result<Zeta<F>, ZetaError> {
82 if !(a > F::one()) {
83 return Err(ZetaError::ATooSmall);
84 }
85 let a_minus_1 = a - F::one();
86 let two = F::one() + F::one();
87 Ok(Zeta {
88 a_minus_1,
89 b: two.powf(a_minus_1),
90 })
91 }
92}
93
94impl<F> Distribution<F> for Zeta<F>
95where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
96{
97 #[inline]
98 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
99 loop {
100 let u = rng.sample(OpenClosed01);
101 let x = u.powf(-F::one() / self.a_minus_1).floor();
102 debug_assert!(x >= F::one());
103 if x.is_infinite() {
104 return x;
108 }
109
110 let t = (F::one() + F::one() / x).powf(self.a_minus_1);
111
112 let v = rng.sample(Standard);
113 if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) {
114 return x;
115 }
116 }
117 }
118}
119
120#[derive(Clone, Copy, Debug)]
146pub struct Zipf<F>
147where F: Float, Standard: Distribution<F> {
148 n: F,
149 s: F,
150 t: F,
151 q: F,
152}
153
/// Error returned by `Zipf::new` when a parameter is invalid.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ZipfError {
    /// The exponent `s` was negative or NaN.
    STooSmall,
    /// The element count `n` was zero.
    NTooSmall,
}

impl fmt::Display for ZipfError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            ZipfError::STooSmall => "s < 0 or is NaN in Zipf distribution",
            ZipfError::NTooSmall => "n < 1 in Zipf distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for ZipfError {}
175
176impl<F> Zipf<F>
177where F: Float, Standard: Distribution<F> {
178 #[inline]
183 pub fn new(n: u64, s: F) -> Result<Zipf<F>, ZipfError> {
184 if !(s >= F::zero()) {
185 return Err(ZipfError::STooSmall);
186 }
187 if n < 1 {
188 return Err(ZipfError::NTooSmall);
189 }
190 let n = F::from(n).unwrap(); let q = if s != F::one() {
192 F::one() / (F::one() - s)
194 } else {
195 F::zero()
197 };
198 let t = if s != F::one() {
199 (n.powf(F::one() - s) - s) * q
200 } else {
201 F::one() + n.ln()
202 };
203 debug_assert!(t > F::zero());
204 Ok(Zipf {
205 n, s, t, q
206 })
207 }
208
209 #[inline]
211 fn inv_cdf(&self, p: F) -> F {
212 let one = F::one();
213 let pt = p * self.t;
214 if pt <= one {
215 pt
216 } else if self.s != one {
217 (pt * (one - self.s) + self.s).powf(self.q)
218 } else {
219 (pt - one).exp()
220 }
221 }
222}
223
224impl<F> Distribution<F> for Zipf<F>
225where F: Float, Standard: Distribution<F>
226{
227 #[inline]
228 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
229 let one = F::one();
230 loop {
231 let inv_b = self.inv_cdf(rng.sample(Standard));
232 let x = (inv_b + one).floor();
233 let mut ratio = x.powf(-self.s);
234 if x > one {
235 ratio = ratio * inv_b.powf(self.s)
236 };
237
238 let y = rng.sample(Standard);
239 if y < ratio {
240 return x;
241 }
242 }
243 }
244}
245
246#[cfg(test)]
247mod tests {
248 use super::*;
249
250 fn test_samples<F: Float + core::fmt::Debug, D: Distribution<F>>(
251 distr: D, zero: F, expected: &[F],
252 ) {
253 let mut rng = crate::test::rng(213);
254 let mut buf = [zero; 4];
255 for x in &mut buf {
256 *x = rng.sample(&distr);
257 }
258 assert_eq!(buf, expected);
259 }
260
261 #[test]
262 #[should_panic]
263 fn zeta_invalid() {
264 Zeta::new(1.).unwrap();
265 }
266
267 #[test]
268 #[should_panic]
269 fn zeta_nan() {
270 Zeta::new(core::f64::NAN).unwrap();
271 }
272
273 #[test]
274 fn zeta_sample() {
275 let a = 2.0;
276 let d = Zeta::new(a).unwrap();
277 let mut rng = crate::test::rng(1);
278 for _ in 0..1000 {
279 let r = d.sample(&mut rng);
280 assert!(r >= 1.);
281 }
282 }
283
284 #[test]
285 fn zeta_small_a() {
286 let a = 1. + 1e-15;
287 let d = Zeta::new(a).unwrap();
288 let mut rng = crate::test::rng(2);
289 for _ in 0..1000 {
290 let r = d.sample(&mut rng);
291 assert!(r >= 1.);
292 }
293 }
294
295 #[test]
296 fn zeta_value_stability() {
297 test_samples(Zeta::new(1.5).unwrap(), 0f32, &[
298 1.0, 2.0, 1.0, 1.0,
299 ]);
300 test_samples(Zeta::new(2.0).unwrap(), 0f64, &[
301 2.0, 1.0, 1.0, 1.0,
302 ]);
303 }
304
305 #[test]
306 #[should_panic]
307 fn zipf_s_too_small() {
308 Zipf::new(10, -1.).unwrap();
309 }
310
311 #[test]
312 #[should_panic]
313 fn zipf_n_too_small() {
314 Zipf::new(0, 1.).unwrap();
315 }
316
317 #[test]
318 #[should_panic]
319 fn zipf_nan() {
320 Zipf::new(10, core::f64::NAN).unwrap();
321 }
322
323 #[test]
324 fn zipf_sample() {
325 let d = Zipf::new(10, 0.5).unwrap();
326 let mut rng = crate::test::rng(2);
327 for _ in 0..1000 {
328 let r = d.sample(&mut rng);
329 assert!(r >= 1.);
330 }
331 }
332
333 #[test]
334 fn zipf_sample_s_1() {
335 let d = Zipf::new(10, 1.).unwrap();
336 let mut rng = crate::test::rng(2);
337 for _ in 0..1000 {
338 let r = d.sample(&mut rng);
339 assert!(r >= 1.);
340 }
341 }
342
343 #[test]
344 fn zipf_sample_s_0() {
345 let d = Zipf::new(10, 0.).unwrap();
346 let mut rng = crate::test::rng(2);
347 for _ in 0..1000 {
348 let r = d.sample(&mut rng);
349 assert!(r >= 1.);
350 }
351 }
353
354 #[test]
355 fn zipf_sample_large_n() {
356 let d = Zipf::new(core::u64::MAX, 1.5).unwrap();
357 let mut rng = crate::test::rng(2);
358 for _ in 0..1000 {
359 let r = d.sample(&mut rng);
360 assert!(r >= 1.);
361 }
362 }
364
365 #[test]
366 fn zipf_value_stability() {
367 test_samples(Zipf::new(10, 0.5).unwrap(), 0f32, &[
368 10.0, 2.0, 6.0, 7.0
369 ]);
370 test_samples(Zipf::new(10, 2.0).unwrap(), 0f64, &[
371 1.0, 2.0, 3.0, 2.0
372 ]);
373 }
374}