rand_distr/
weighted_alias.rs

1// Copyright 2019 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
//! This module contains an implementation of the alias method for sampling
//! random indices with probabilities proportional to a collection of weights.
11
12use super::WeightedError;
13use crate::{uniform::SampleUniform, Distribution, Uniform};
14use core::fmt;
15use core::iter::Sum;
16use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
17use rand::Rng;
18use alloc::{boxed::Box, vec, vec::Vec};
19#[cfg(feature = "serde1")]
20use serde::{Serialize, Deserialize};
21
22/// A distribution using weighted sampling to pick a discretely selected item.
23///
24/// Sampling a [`WeightedAliasIndex<W>`] distribution returns the index of a randomly
25/// selected element from the vector used to create the [`WeightedAliasIndex<W>`].
26/// The chance of a given element being picked is proportional to the value of
27/// the element. The weights can have any type `W` for which a implementation of
28/// [`AliasableWeight`] exists.
29///
30/// # Performance
31///
32/// Given that `n` is the number of items in the vector used to create an
33/// [`WeightedAliasIndex<W>`], it will require `O(n)` amount of memory.
34/// More specifically it takes up some constant amount of memory plus
35/// the vector used to create it and a [`Vec<u32>`] with capacity `n`.
36///
37/// Time complexity for the creation of a [`WeightedAliasIndex<W>`] is `O(n)`.
38/// Sampling is `O(1)`, it makes a call to [`Uniform<u32>::sample`] and a call
39/// to [`Uniform<W>::sample`].
40///
41/// # Example
42///
43/// ```
44/// use rand_distr::WeightedAliasIndex;
45/// use rand::prelude::*;
46///
47/// let choices = vec!['a', 'b', 'c'];
48/// let weights = vec![2, 1, 1];
49/// let dist = WeightedAliasIndex::new(weights).unwrap();
50/// let mut rng = thread_rng();
51/// for _ in 0..100 {
52///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
53///     println!("{}", choices[dist.sample(&mut rng)]);
54/// }
55///
56/// let items = [('a', 0), ('b', 3), ('c', 7)];
57/// let dist2 = WeightedAliasIndex::new(items.iter().map(|item| item.1).collect()).unwrap();
58/// for _ in 0..100 {
59///     // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
60///     println!("{}", items[dist2.sample(&mut rng)].0);
61/// }
62/// ```
63///
64/// [`WeightedAliasIndex<W>`]: WeightedAliasIndex
65/// [`Vec<u32>`]: Vec
66/// [`Uniform<u32>::sample`]: Distribution::sample
67/// [`Uniform<W>::sample`]: Distribution::sample
68#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
69#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
70#[cfg_attr(feature = "serde1", serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")))]
71#[cfg_attr(feature = "serde1", serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")))]
72pub struct WeightedAliasIndex<W: AliasableWeight> {
73    aliases: Box<[u32]>,
74    no_alias_odds: Box<[W]>,
75    uniform_index: Uniform<u32>,
76    uniform_within_weight_sum: Uniform<W>,
77}
78
79impl<W: AliasableWeight> WeightedAliasIndex<W> {
80    /// Creates a new [`WeightedAliasIndex`].
81    ///
82    /// Returns an error if:
83    /// - The vector is empty.
84    /// - The vector is longer than `u32::MAX`.
85    /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX /
86    ///   weights.len()`.
87    /// - The sum of weights is zero.
88    pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> {
89        let n = weights.len();
90        if n == 0 {
91            return Err(WeightedError::NoItem);
92        } else if n > ::core::u32::MAX as usize {
93            return Err(WeightedError::TooMany);
94        }
95        let n = n as u32;
96
97        let max_weight_size = W::try_from_u32_lossy(n)
98            .map(|n| W::MAX / n)
99            .unwrap_or(W::ZERO);
100        if !weights
101            .iter()
102            .all(|&w| W::ZERO <= w && w <= max_weight_size)
103        {
104            return Err(WeightedError::InvalidWeight);
105        }
106
107        // The sum of weights will represent 100% of no alias odds.
108        let weight_sum = AliasableWeight::sum(weights.as_slice());
109        // Prevent floating point overflow due to rounding errors.
110        let weight_sum = if weight_sum > W::MAX {
111            W::MAX
112        } else {
113            weight_sum
114        };
115        if weight_sum == W::ZERO {
116            return Err(WeightedError::AllWeightsZero);
117        }
118
119        // `weight_sum` would have been zero if `try_from_lossy` causes an error here.
120        let n_converted = W::try_from_u32_lossy(n).unwrap();
121
122        let mut no_alias_odds = weights.into_boxed_slice();
123        for odds in no_alias_odds.iter_mut() {
124            *odds *= n_converted;
125            // Prevent floating point overflow due to rounding errors.
126            *odds = if *odds > W::MAX { W::MAX } else { *odds };
127        }
128
129        /// This struct is designed to contain three data structures at once,
130        /// sharing the same memory. More precisely it contains two linked lists
131        /// and an alias map, which will be the output of this method. To keep
132        /// the three data structures from getting in each other's way, it must
133        /// be ensured that a single index is only ever in one of them at the
134        /// same time.
135        struct Aliases {
136            aliases: Box<[u32]>,
137            smalls_head: u32,
138            bigs_head: u32,
139        }
140
141        impl Aliases {
142            fn new(size: u32) -> Self {
143                Aliases {
144                    aliases: vec![0; size as usize].into_boxed_slice(),
145                    smalls_head: ::core::u32::MAX,
146                    bigs_head: ::core::u32::MAX,
147                }
148            }
149
150            fn push_small(&mut self, idx: u32) {
151                self.aliases[idx as usize] = self.smalls_head;
152                self.smalls_head = idx;
153            }
154
155            fn push_big(&mut self, idx: u32) {
156                self.aliases[idx as usize] = self.bigs_head;
157                self.bigs_head = idx;
158            }
159
160            fn pop_small(&mut self) -> u32 {
161                let popped = self.smalls_head;
162                self.smalls_head = self.aliases[popped as usize];
163                popped
164            }
165
166            fn pop_big(&mut self) -> u32 {
167                let popped = self.bigs_head;
168                self.bigs_head = self.aliases[popped as usize];
169                popped
170            }
171
172            fn smalls_is_empty(&self) -> bool {
173                self.smalls_head == ::core::u32::MAX
174            }
175
176            fn bigs_is_empty(&self) -> bool {
177                self.bigs_head == ::core::u32::MAX
178            }
179
180            fn set_alias(&mut self, idx: u32, alias: u32) {
181                self.aliases[idx as usize] = alias;
182            }
183        }
184
185        let mut aliases = Aliases::new(n);
186
187        // Split indices into those with small weights and those with big weights.
188        for (index, &odds) in no_alias_odds.iter().enumerate() {
189            if odds < weight_sum {
190                aliases.push_small(index as u32);
191            } else {
192                aliases.push_big(index as u32);
193            }
194        }
195
196        // Build the alias map by finding an alias with big weight for each index with
197        // small weight.
198        while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
199            let s = aliases.pop_small();
200            let b = aliases.pop_big();
201
202            aliases.set_alias(s, b);
203            no_alias_odds[b as usize] =
204                no_alias_odds[b as usize] - weight_sum + no_alias_odds[s as usize];
205
206            if no_alias_odds[b as usize] < weight_sum {
207                aliases.push_small(b);
208            } else {
209                aliases.push_big(b);
210            }
211        }
212
213        // The remaining indices should have no alias odds of about 100%. This is due to
214        // numeric accuracy. Otherwise they would be exactly 100%.
215        while !aliases.smalls_is_empty() {
216            no_alias_odds[aliases.pop_small() as usize] = weight_sum;
217        }
218        while !aliases.bigs_is_empty() {
219            no_alias_odds[aliases.pop_big() as usize] = weight_sum;
220        }
221
222        // Prepare distributions for sampling. Creating them beforehand improves
223        // sampling performance.
224        let uniform_index = Uniform::new(0, n);
225        let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum);
226
227        Ok(Self {
228            aliases: aliases.aliases,
229            no_alias_odds,
230            uniform_index,
231            uniform_within_weight_sum,
232        })
233    }
234}
235
236impl<W: AliasableWeight> Distribution<usize> for WeightedAliasIndex<W> {
237    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
238        let candidate = rng.sample(self.uniform_index);
239        if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] {
240            candidate as usize
241        } else {
242            self.aliases[candidate as usize] as usize
243        }
244    }
245}
246
247impl<W: AliasableWeight> fmt::Debug for WeightedAliasIndex<W>
248where
249    W: fmt::Debug,
250    Uniform<W>: fmt::Debug,
251{
252    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
253        f.debug_struct("WeightedAliasIndex")
254            .field("aliases", &self.aliases)
255            .field("no_alias_odds", &self.no_alias_odds)
256            .field("uniform_index", &self.uniform_index)
257            .field("uniform_within_weight_sum", &self.uniform_within_weight_sum)
258            .finish()
259    }
260}
261
262impl<W: AliasableWeight> Clone for WeightedAliasIndex<W>
263where Uniform<W>: Clone
264{
265    fn clone(&self) -> Self {
266        Self {
267            aliases: self.aliases.clone(),
268            no_alias_odds: self.no_alias_odds.clone(),
269            uniform_index: self.uniform_index,
270            uniform_within_weight_sum: self.uniform_within_weight_sum.clone(),
271        }
272    }
273}
274
275/// Trait that must be implemented for weights, that are used with
276/// [`WeightedAliasIndex`]. Currently no guarantees on the correctness of
277/// [`WeightedAliasIndex`] are given for custom implementations of this trait.
278#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
279pub trait AliasableWeight:
280    Sized
281    + Copy
282    + SampleUniform
283    + PartialOrd
284    + Add<Output = Self>
285    + AddAssign
286    + Sub<Output = Self>
287    + SubAssign
288    + Mul<Output = Self>
289    + MulAssign
290    + Div<Output = Self>
291    + DivAssign
292    + Sum
293{
294    /// Maximum number representable by `Self`.
295    const MAX: Self;
296
297    /// Element of `Self` equivalent to 0.
298    const ZERO: Self;
299
300    /// Produce an instance of `Self` from a `u32` value, or return `None` if
301    /// out of range. Loss of precision (where `Self` is a floating point type)
302    /// is acceptable.
303    fn try_from_u32_lossy(n: u32) -> Option<Self>;
304
305    /// Sums all values in slice `values`.
306    fn sum(values: &[Self]) -> Self {
307        values.iter().copied().sum()
308    }
309}
310
/// Implements [`AliasableWeight`] for a primitive floating point type, using
/// pairwise summation for `sum`.
macro_rules! impl_weight_for_float {
    ($T:ident) => {
        impl AliasableWeight for $T {
            const MAX: Self = ::core::$T::MAX;
            const ZERO: Self = 0.0;

            /// Every `u32` converts (possibly with rounding), so this is
            /// always `Some`.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                let converted = n as Self;
                Some(converted)
            }

            /// Delegates to `pairwise_sum` to limit accumulated rounding error.
            fn sum(values: &[Self]) -> Self {
                pairwise_sum(values)
            }
        }
    };
}
327
328/// In comparison to naive accumulation, the pairwise sum algorithm reduces
329/// rounding errors when there are many floating point values.
330fn pairwise_sum<T: AliasableWeight>(values: &[T]) -> T {
331    if values.len() <= 32 {
332        values.iter().copied().sum()
333    } else {
334        let mid = values.len() / 2;
335        let (a, b) = values.split_at(mid);
336        pairwise_sum(a) + pairwise_sum(b)
337    }
338}
339
/// Implements [`AliasableWeight`] for a primitive integer type, using the
/// trait's default `sum`.
macro_rules! impl_weight_for_int {
    ($T:ident) => {
        impl AliasableWeight for $T {
            const MAX: Self = ::core::$T::MAX;
            const ZERO: Self = 0;

            /// Succeeds only if the converted value is non-negative and
            /// converts back to the same `u32`, i.e. nothing was truncated.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                let converted = n as Self;
                let fits = converted >= Self::ZERO && converted as u32 == n;
                if fits {
                    Some(converted)
                } else {
                    None
                }
            }
        }
    };
}
357
358impl_weight_for_float!(f64);
359impl_weight_for_float!(f32);
360impl_weight_for_int!(usize);
361impl_weight_for_int!(u128);
362impl_weight_for_int!(u64);
363impl_weight_for_int!(u32);
364impl_weight_for_int!(u16);
365impl_weight_for_int!(u8);
366impl_weight_for_int!(isize);
367impl_weight_for_int!(i128);
368impl_weight_for_int!(i64);
369impl_weight_for_int!(i32);
370impl_weight_for_int!(i16);
371impl_weight_for_int!(i8);
372
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_f32() {
        test_weighted_index(f32::into);

        // Floating point special cases
        assert_eq!(
            WeightedAliasIndex::new(vec![::core::f32::INFINITY]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
            WeightedError::AllWeightsZero
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![::core::f32::NAN]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_u128() {
        test_weighted_index(|x: u128| x as f64);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_i128() {
        test_weighted_index(|x: i128| x as f64);

        // Signed integer special cases
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![::core::i128::MIN]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_u8() {
        test_weighted_index(u8::into);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_i8() {
        test_weighted_index(i8::into);

        // Signed integer special cases
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![::core::i8::MIN]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    /// Samples from a random weight vector (with one weight forced to zero)
    /// and checks the observed counts against the expected ones, then checks
    /// the common construction errors.
    ///
    /// `w_to_f64` converts a weight to `f64` for computing expected counts.
    fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F)
    where WeightedAliasIndex<W>: fmt::Debug {
        const NUM_WEIGHTS: u32 = 10;
        const ZERO_WEIGHT_INDEX: u32 = 3;
        const NUM_SAMPLES: u32 = 15000;
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);

        let weights = {
            let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize);
            let random_weight_distribution = Uniform::new_inclusive(
                W::ZERO,
                W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
            );
            for _ in 0..NUM_WEIGHTS {
                weights.push(rng.sample(&random_weight_distribution));
            }
            weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO;
            weights
        };
        let weight_sum = weights.iter().copied().sum::<W>();
        let expected_counts = weights
            .iter()
            .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64)
            .collect::<Vec<f64>>();
        let weight_distribution = WeightedAliasIndex::new(weights).unwrap();

        let mut counts = vec![0; NUM_WEIGHTS as usize];
        for _ in 0..NUM_SAMPLES {
            counts[rng.sample(&weight_distribution)] += 1;
        }

        assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);
        // Tolerate a deviation of 10% of the average per-index count.
        let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
        for (index, (count, expected_count)) in
            counts.into_iter().zip(expected_counts).enumerate()
        {
            let difference = (count as f64 - expected_count).abs();
            assert!(
                difference <= max_allowed_difference,
                "index {}: observed {} samples, expected about {}",
                index,
                count,
                expected_count
            );
        }

        assert_eq!(
            WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
            WeightedError::NoItem
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
            WeightedError::AllWeightsZero
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    #[test]
    fn value_stability() {
        fn test_samples<W: AliasableWeight>(weights: Vec<W>, buf: &mut [usize], expected: &[usize]) {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedAliasIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[
            6, 5, 7, 5, 8, 7, 6, 2, 3, 7,
        ]);
        test_samples(vec![0.7f32, 0.1, 0.1, 0.1], &mut buf, &[
            2, 0, 0, 0, 0, 0, 0, 0, 1, 3,
        ]);
        test_samples(vec![1.0f64, 0.999, 0.998, 0.997], &mut buf, &[
            2, 1, 2, 3, 2, 1, 3, 2, 1, 1,
        ]);
    }
}