rand_distr/
zipf.rs

1// Copyright 2021 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! The Zeta and related distributions.
10
11use num_traits::Float;
12use crate::{Distribution, Standard};
13use rand::{Rng, distributions::OpenClosed01};
14use core::fmt;
15
16/// Samples integers according to the [zeta distribution].
17///
18/// The zeta distribution is a limit of the [`Zipf`] distribution. Sometimes it
19/// is called one of the following: discrete Pareto, Riemann-Zeta, Zipf, or
20/// Zipf–Estoup distribution.
21///
22/// It has the density function `f(k) = k^(-a) / C(a)` for `k >= 1`, where `a`
23/// is the parameter and `C(a)` is the Riemann zeta function.
24///
25/// # Example
26/// ```
27/// use rand::prelude::*;
28/// use rand_distr::Zeta;
29///
30/// let val: f64 = thread_rng().sample(Zeta::new(1.5).unwrap());
31/// println!("{}", val);
32/// ```
33///
34/// # Remarks
35///
36/// The zeta distribution has no upper limit. Sampled values may be infinite.
37/// In particular, a value of infinity might be returned for the following
38/// reasons:
39/// 1. it is the best representation in the type `F` of the actual sample.
40/// 2. to prevent infinite loops for very small `a`.
41///
42/// # Implementation details
43///
44/// We are using the algorithm from [Non-Uniform Random Variate Generation],
45/// Section 6.1, page 551.
46///
47/// [zeta distribution]: https://en.wikipedia.org/wiki/Zeta_distribution
48/// [Non-Uniform Random Variate Generation]: https://doi.org/10.1007/978-1-4613-8643-8
49#[derive(Clone, Copy, Debug)]
50pub struct Zeta<F>
51where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
52{
53    a_minus_1: F,
54    b: F,
55}
56
/// Error returned by [`Zeta::new`] when its parameter is invalid.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ZetaError {
    /// `a <= 1` or `nan`.
    ATooSmall,
}

impl fmt::Display for ZetaError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // One static message per variant; written verbatim (no padding).
        let description = match *self {
            Self::ATooSmall => "a <= 1 or is NaN in Zeta distribution",
        };
        f.write_str(description)
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for ZetaError {}
75
76impl<F> Zeta<F>
77where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
78{
79    /// Construct a new `Zeta` distribution with given `a` parameter.
80    #[inline]
81    pub fn new(a: F) -> Result<Zeta<F>, ZetaError> {
82        if !(a > F::one()) {
83            return Err(ZetaError::ATooSmall);
84        }
85        let a_minus_1 = a - F::one();
86        let two = F::one() + F::one();
87        Ok(Zeta {
88            a_minus_1,
89            b: two.powf(a_minus_1),
90        })
91    }
92}
93
94impl<F> Distribution<F> for Zeta<F>
95where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
96{
97    #[inline]
98    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
99        loop {
100            let u = rng.sample(OpenClosed01);
101            let x = u.powf(-F::one() / self.a_minus_1).floor();
102            debug_assert!(x >= F::one());
103            if x.is_infinite() {
104                // For sufficiently small `a`, `x` will always be infinite,
105                // which is rejected, resulting in an infinite loop. We avoid
106                // this by always returning infinity instead.
107                return x;
108            }
109
110            let t = (F::one() + F::one() / x).powf(self.a_minus_1);
111
112            let v = rng.sample(Standard);
113            if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) {
114                return x;
115            }
116        }
117    }
118}
119
120/// Samples integers according to the Zipf distribution.
121///
122/// The samples follow Zipf's law: The frequency of each sample from a finite
123/// set of size `n` is inversely proportional to a power of its frequency rank
124/// (with exponent `s`).
125///
126/// For large `n`, this converges to the [`Zeta`] distribution.
127///
128/// For `s = 0`, this becomes a uniform distribution.
129///
130/// # Example
131/// ```
132/// use rand::prelude::*;
133/// use rand_distr::Zipf;
134///
135/// let val: f64 = thread_rng().sample(Zipf::new(10, 1.5).unwrap());
136/// println!("{}", val);
137/// ```
138///
139/// # Implementation details
140///
141/// Implemented via [rejection sampling](https://en.wikipedia.org/wiki/Rejection_sampling),
142/// due to Jason Crease[1].
143///
144/// [1]: https://jasoncrease.medium.com/rejection-sampling-the-zipf-distribution-6b359792cffa
145#[derive(Clone, Copy, Debug)]
146pub struct Zipf<F>
147where F: Float, Standard: Distribution<F> {
148    n: F,
149    s: F,
150    t: F,
151    q: F,
152}
153
/// Error returned by [`Zipf::new`] when one of its parameters is invalid.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ZipfError {
    /// `s < 0` or `nan`.
    STooSmall,
    /// `n < 1`.
    NTooSmall,
}

impl fmt::Display for ZipfError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // One static message per variant; written verbatim (no padding).
        let description = match *self {
            Self::STooSmall => "s < 0 or is NaN in Zipf distribution",
            Self::NTooSmall => "n < 1 in Zipf distribution",
        };
        f.write_str(description)
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for ZipfError {}
175
176impl<F> Zipf<F>
177where F: Float, Standard: Distribution<F> {
178    /// Construct a new `Zipf` distribution for a set with `n` elements and a
179    /// frequency rank exponent `s`.
180    ///
181    /// For large `n`, rounding may occur to fit the number into the float type.
182    #[inline]
183    pub fn new(n: u64, s: F) -> Result<Zipf<F>, ZipfError> {
184        if !(s >= F::zero()) {
185            return Err(ZipfError::STooSmall);
186        }
187        if n < 1 {
188            return Err(ZipfError::NTooSmall);
189        }
190        let n = F::from(n).unwrap();  // This does not fail.
191        let q = if s != F::one() {
192            // Make sure to calculate the division only once.
193            F::one() / (F::one() - s)
194        } else {
195            // This value is never used.
196            F::zero()
197        };
198        let t = if s != F::one() {
199            (n.powf(F::one() - s) - s) * q
200        } else {
201            F::one() + n.ln()
202        };
203        debug_assert!(t > F::zero());
204        Ok(Zipf {
205            n, s, t, q
206        })
207    }
208
209    /// Inverse cumulative density function
210    #[inline]
211    fn inv_cdf(&self, p: F) -> F {
212        let one = F::one();
213        let pt = p * self.t;
214        if pt <= one {
215            pt
216        } else if self.s != one {
217            (pt * (one - self.s) + self.s).powf(self.q)
218        } else {
219            (pt - one).exp()
220        }
221    }
222}
223
224impl<F> Distribution<F> for Zipf<F>
225where F: Float, Standard: Distribution<F>
226{
227    #[inline]
228    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
229        let one = F::one();
230        loop {
231            let inv_b = self.inv_cdf(rng.sample(Standard));
232            let x = (inv_b + one).floor();
233            let mut ratio = x.powf(-self.s);
234            if x > one {
235                ratio = ratio * inv_b.powf(self.s)
236            };
237
238            let y = rng.sample(Standard);
239            if y < ratio {
240                return x;
241            }
242        }
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that the first four samples of `distr` (seed 213) equal `expected`.
    fn test_samples<F: Float + core::fmt::Debug, D: Distribution<F>>(
        distr: D, zero: F, expected: &[F],
    ) {
        let mut rng = crate::test::rng(213);
        let mut buf = [zero; 4];
        for x in &mut buf {
            *x = rng.sample(&distr);
        }
        assert_eq!(buf, expected);
    }

    // The constructor tests assert the exact error variant. A bare
    // `#[should_panic]` would also pass on an unrelated panic.

    #[test]
    fn zeta_invalid() {
        assert_eq!(Zeta::new(1.).err(), Some(ZetaError::ATooSmall));
    }

    #[test]
    fn zeta_nan() {
        assert_eq!(Zeta::new(core::f64::NAN).err(), Some(ZetaError::ATooSmall));
    }

    #[test]
    fn zeta_sample() {
        let a = 2.0;
        let d = Zeta::new(a).unwrap();
        let mut rng = crate::test::rng(1);
        for _ in 0..1000 {
            let r: f64 = d.sample(&mut rng);
            assert!(r >= 1.);
            assert_eq!(r, r.floor(), "sample {} is not an integer", r);
        }
    }

    #[test]
    fn zeta_small_a() {
        let a = 1. + 1e-15;
        let d = Zeta::new(a).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r = d.sample(&mut rng);
            assert!(r >= 1.);
        }
    }

    #[test]
    fn zeta_value_stability() {
        test_samples(Zeta::new(1.5).unwrap(), 0f32, &[
            1.0, 2.0, 1.0, 1.0,
        ]);
        test_samples(Zeta::new(2.0).unwrap(), 0f64, &[
            2.0, 1.0, 1.0, 1.0,
        ]);
    }

    #[test]
    fn zipf_s_too_small() {
        assert_eq!(Zipf::new(10, -1.).err(), Some(ZipfError::STooSmall));
    }

    #[test]
    fn zipf_n_too_small() {
        assert_eq!(Zipf::new(0, 1.).err(), Some(ZipfError::NTooSmall));
    }

    #[test]
    fn zipf_nan() {
        assert_eq!(Zipf::new(10, core::f64::NAN).err(), Some(ZipfError::STooSmall));
    }

    #[test]
    fn zipf_sample() {
        let d = Zipf::new(10, 0.5).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r: f64 = d.sample(&mut rng);
            assert!(r >= 1. && r <= 10., "sample {} outside 1..=10", r);
            assert_eq!(r, r.floor(), "sample {} is not an integer", r);
        }
    }

    #[test]
    fn zipf_sample_s_1() {
        let d = Zipf::new(10, 1.).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r: f64 = d.sample(&mut rng);
            assert!(r >= 1. && r <= 10., "sample {} outside 1..=10", r);
            assert_eq!(r, r.floor(), "sample {} is not an integer", r);
        }
    }

    #[test]
    fn zipf_sample_s_0() {
        let d = Zipf::new(10, 0.).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r: f64 = d.sample(&mut rng);
            assert!(r >= 1. && r <= 10., "sample {} outside 1..=10", r);
            assert_eq!(r, r.floor(), "sample {} is not an integer", r);
        }
        // TODO: verify that this is a uniform distribution
    }

    #[test]
    fn zipf_sample_n_1() {
        // A single-element set is degenerate: every sample is rank 1.
        let d = Zipf::new(1, 2.).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r: f64 = d.sample(&mut rng);
            assert_eq!(r, 1.);
        }
    }

    #[test]
    fn zipf_sample_large_n() {
        let d = Zipf::new(core::u64::MAX, 1.5).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r = d.sample(&mut rng);
            assert!(r >= 1.);
        }
        // TODO: verify that this is a zeta distribution
    }

    #[test]
    fn zipf_value_stability() {
        test_samples(Zipf::new(10, 0.5).unwrap(), 0f32, &[
            10.0, 2.0, 6.0, 7.0
        ]);
        test_samples(Zipf::new(10, 2.0).unwrap(), 0f64, &[
            1.0, 2.0, 3.0, 2.0
        ]);
    }
}
372        ]);
373    }
374}