1use super::WeightedError;
13use crate::{uniform::SampleUniform, Distribution, Uniform};
14use core::fmt;
15use core::iter::Sum;
16use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
17use rand::Rng;
18use alloc::{boxed::Box, vec, vec::Vec};
19#[cfg(feature = "serde1")]
20use serde::{Serialize, Deserialize};
21
22#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
69#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
70#[cfg_attr(feature = "serde1", serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")))]
71#[cfg_attr(feature = "serde1", serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")))]
72pub struct WeightedAliasIndex<W: AliasableWeight> {
73 aliases: Box<[u32]>,
74 no_alias_odds: Box<[W]>,
75 uniform_index: Uniform<u32>,
76 uniform_within_weight_sum: Uniform<W>,
77}
78
79impl<W: AliasableWeight> WeightedAliasIndex<W> {
80 pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> {
89 let n = weights.len();
90 if n == 0 {
91 return Err(WeightedError::NoItem);
92 } else if n > ::core::u32::MAX as usize {
93 return Err(WeightedError::TooMany);
94 }
95 let n = n as u32;
96
97 let max_weight_size = W::try_from_u32_lossy(n)
98 .map(|n| W::MAX / n)
99 .unwrap_or(W::ZERO);
100 if !weights
101 .iter()
102 .all(|&w| W::ZERO <= w && w <= max_weight_size)
103 {
104 return Err(WeightedError::InvalidWeight);
105 }
106
107 let weight_sum = AliasableWeight::sum(weights.as_slice());
109 let weight_sum = if weight_sum > W::MAX {
111 W::MAX
112 } else {
113 weight_sum
114 };
115 if weight_sum == W::ZERO {
116 return Err(WeightedError::AllWeightsZero);
117 }
118
119 let n_converted = W::try_from_u32_lossy(n).unwrap();
121
122 let mut no_alias_odds = weights.into_boxed_slice();
123 for odds in no_alias_odds.iter_mut() {
124 *odds *= n_converted;
125 *odds = if *odds > W::MAX { W::MAX } else { *odds };
127 }
128
129 struct Aliases {
136 aliases: Box<[u32]>,
137 smalls_head: u32,
138 bigs_head: u32,
139 }
140
141 impl Aliases {
142 fn new(size: u32) -> Self {
143 Aliases {
144 aliases: vec![0; size as usize].into_boxed_slice(),
145 smalls_head: ::core::u32::MAX,
146 bigs_head: ::core::u32::MAX,
147 }
148 }
149
150 fn push_small(&mut self, idx: u32) {
151 self.aliases[idx as usize] = self.smalls_head;
152 self.smalls_head = idx;
153 }
154
155 fn push_big(&mut self, idx: u32) {
156 self.aliases[idx as usize] = self.bigs_head;
157 self.bigs_head = idx;
158 }
159
160 fn pop_small(&mut self) -> u32 {
161 let popped = self.smalls_head;
162 self.smalls_head = self.aliases[popped as usize];
163 popped
164 }
165
166 fn pop_big(&mut self) -> u32 {
167 let popped = self.bigs_head;
168 self.bigs_head = self.aliases[popped as usize];
169 popped
170 }
171
172 fn smalls_is_empty(&self) -> bool {
173 self.smalls_head == ::core::u32::MAX
174 }
175
176 fn bigs_is_empty(&self) -> bool {
177 self.bigs_head == ::core::u32::MAX
178 }
179
180 fn set_alias(&mut self, idx: u32, alias: u32) {
181 self.aliases[idx as usize] = alias;
182 }
183 }
184
185 let mut aliases = Aliases::new(n);
186
187 for (index, &odds) in no_alias_odds.iter().enumerate() {
189 if odds < weight_sum {
190 aliases.push_small(index as u32);
191 } else {
192 aliases.push_big(index as u32);
193 }
194 }
195
196 while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
199 let s = aliases.pop_small();
200 let b = aliases.pop_big();
201
202 aliases.set_alias(s, b);
203 no_alias_odds[b as usize] =
204 no_alias_odds[b as usize] - weight_sum + no_alias_odds[s as usize];
205
206 if no_alias_odds[b as usize] < weight_sum {
207 aliases.push_small(b);
208 } else {
209 aliases.push_big(b);
210 }
211 }
212
213 while !aliases.smalls_is_empty() {
216 no_alias_odds[aliases.pop_small() as usize] = weight_sum;
217 }
218 while !aliases.bigs_is_empty() {
219 no_alias_odds[aliases.pop_big() as usize] = weight_sum;
220 }
221
222 let uniform_index = Uniform::new(0, n);
225 let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum);
226
227 Ok(Self {
228 aliases: aliases.aliases,
229 no_alias_odds,
230 uniform_index,
231 uniform_within_weight_sum,
232 })
233 }
234}
235
236impl<W: AliasableWeight> Distribution<usize> for WeightedAliasIndex<W> {
237 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
238 let candidate = rng.sample(self.uniform_index);
239 if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] {
240 candidate as usize
241 } else {
242 self.aliases[candidate as usize] as usize
243 }
244 }
245}
246
247impl<W: AliasableWeight> fmt::Debug for WeightedAliasIndex<W>
248where
249 W: fmt::Debug,
250 Uniform<W>: fmt::Debug,
251{
252 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
253 f.debug_struct("WeightedAliasIndex")
254 .field("aliases", &self.aliases)
255 .field("no_alias_odds", &self.no_alias_odds)
256 .field("uniform_index", &self.uniform_index)
257 .field("uniform_within_weight_sum", &self.uniform_within_weight_sum)
258 .finish()
259 }
260}
261
262impl<W: AliasableWeight> Clone for WeightedAliasIndex<W>
263where Uniform<W>: Clone
264{
265 fn clone(&self) -> Self {
266 Self {
267 aliases: self.aliases.clone(),
268 no_alias_odds: self.no_alias_odds.clone(),
269 uniform_index: self.uniform_index,
270 uniform_within_weight_sum: self.uniform_within_weight_sum.clone(),
271 }
272 }
273}
274
275#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
279pub trait AliasableWeight:
280 Sized
281 + Copy
282 + SampleUniform
283 + PartialOrd
284 + Add<Output = Self>
285 + AddAssign
286 + Sub<Output = Self>
287 + SubAssign
288 + Mul<Output = Self>
289 + MulAssign
290 + Div<Output = Self>
291 + DivAssign
292 + Sum
293{
294 const MAX: Self;
296
297 const ZERO: Self;
299
300 fn try_from_u32_lossy(n: u32) -> Option<Self>;
304
305 fn sum(values: &[Self]) -> Self {
307 values.iter().copied().sum()
308 }
309}
310
/// Implements `AliasableWeight` for the floating point type named by the
/// macro argument (`f32` or `f64`).
macro_rules! impl_weight_for_float {
    ($float: ident) => {
        impl AliasableWeight for $float {
            const MAX: Self = ::core::$float::MAX;
            const ZERO: Self = 0.0;

            /// Always succeeds; the cast may round `n`.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                let converted = n as $float;
                Some(converted)
            }

            /// Delegates to `pairwise_sum` instead of the default
            /// front-to-back summation.
            fn sum(values: &[Self]) -> Self {
                pairwise_sum(values)
            }
        }
    };
}
327
328fn pairwise_sum<T: AliasableWeight>(values: &[T]) -> T {
331 if values.len() <= 32 {
332 values.iter().copied().sum()
333 } else {
334 let mid = values.len() / 2;
335 let (a, b) = values.split_at(mid);
336 pairwise_sum(a) + pairwise_sum(b)
337 }
338}
339
/// Implements `AliasableWeight` for the integer type named by the macro
/// argument. The default `sum` of the trait is kept.
macro_rules! impl_weight_for_int {
    ($int: ident) => {
        impl AliasableWeight for $int {
            const MAX: Self = ::core::$int::MAX;
            const ZERO: Self = 0;

            /// Returns `Some` only if `n` survives the cast unchanged: the
            /// result must be non-negative and must cast back to exactly `n`.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                let candidate = n as Self;
                if candidate < Self::ZERO || candidate as u32 != n {
                    return None;
                }
                Some(candidate)
            }
        }
    };
}
357
358impl_weight_for_float!(f64);
359impl_weight_for_float!(f32);
360impl_weight_for_int!(usize);
361impl_weight_for_int!(u128);
362impl_weight_for_int!(u64);
363impl_weight_for_int!(u32);
364impl_weight_for_int!(u16);
365impl_weight_for_int!(u8);
366impl_weight_for_int!(isize);
367impl_weight_for_int!(i128);
368impl_weight_for_int!(i64);
369impl_weight_for_int!(i32);
370impl_weight_for_int!(i16);
371impl_weight_for_int!(i8);
372
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that `WeightedAliasIndex::new(weights)` fails with `expected`.
    fn assert_new_fails<W: AliasableWeight>(weights: Vec<W>, expected: WeightedError)
    where WeightedAliasIndex<W>: fmt::Debug {
        assert_eq!(WeightedAliasIndex::new(weights).unwrap_err(), expected);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_weighted_index_f32() {
        test_weighted_index(f32::into);

        // Special floating point values.
        assert_new_fails(vec![::core::f32::INFINITY], WeightedError::InvalidWeight);
        assert_new_fails(vec![-0_f32], WeightedError::AllWeightsZero);
        assert_new_fails(vec![-1_f32], WeightedError::InvalidWeight);
        assert_new_fails(vec![-::core::f32::INFINITY], WeightedError::InvalidWeight);
        assert_new_fails(vec![::core::f32::NAN], WeightedError::InvalidWeight);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_weighted_index_u128() {
        test_weighted_index(|x: u128| x as f64);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_weighted_index_i128() {
        test_weighted_index(|x: i128| x as f64);

        // Negative weights are invalid.
        assert_new_fails(vec![-1_i128], WeightedError::InvalidWeight);
        assert_new_fails(vec![::core::i128::MIN], WeightedError::InvalidWeight);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_weighted_index_u8() {
        test_weighted_index(u8::into);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_weighted_index_i8() {
        test_weighted_index(i8::into);

        // Negative weights are invalid.
        assert_new_fails(vec![-1_i8], WeightedError::InvalidWeight);
        assert_new_fails(vec![::core::i8::MIN], WeightedError::InvalidWeight);
    }

    /// Shared body of the per-type tests.
    ///
    /// Draws random weights (one of them forced to zero), samples the
    /// resulting distribution and checks that every index is hit about as
    /// often as its share of the weight sum predicts. Afterwards the error
    /// cases common to all weight types are checked. `w_to_f64` converts a
    /// weight to `f64` for computing the expected counts.
    fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F)
    where WeightedAliasIndex<W>: fmt::Debug {
        const NUM_WEIGHTS: u32 = 10;
        const ZERO_WEIGHT_INDEX: u32 = 3;
        const NUM_SAMPLES: u32 = 15000;
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);

        // Random weights, each small enough to be accepted by `new`.
        let weight_range = Uniform::new_inclusive(
            W::ZERO,
            W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
        );
        let mut weights: Vec<W> = (0..NUM_WEIGHTS)
            .map(|_| rng.sample(&weight_range))
            .collect();
        weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO;

        let total = w_to_f64(weights.iter().copied().sum::<W>());
        let expected_counts: Vec<f64> = weights
            .iter()
            .map(|&w| w_to_f64(w) / total * NUM_SAMPLES as f64)
            .collect();
        let weight_distribution = WeightedAliasIndex::new(weights).unwrap();

        let mut counts = vec![0; NUM_WEIGHTS as usize];
        for _ in 0..NUM_SAMPLES {
            let index: usize = rng.sample(&weight_distribution);
            counts[index] += 1;
        }

        // The zero-weight index must never be drawn.
        assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);

        // Every other index may deviate by at most 10 % of the average count.
        let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
        for (count, expected_count) in counts.into_iter().zip(expected_counts) {
            let difference = (count as f64 - expected_count).abs();
            assert!(difference <= max_allowed_difference);
        }

        assert_new_fails(Vec::<W>::new(), WeightedError::NoItem);
        assert_new_fails(vec![W::ZERO], WeightedError::AllWeightsZero);
        assert_new_fails(vec![W::MAX, W::MAX], WeightedError::InvalidWeight);
    }

    #[test]
    fn value_stability() {
        /// Fills `buf` with samples drawn using a fixed seed and compares
        /// them with `expected`.
        fn test_samples<W: AliasableWeight>(weights: Vec<W>, buf: &mut [usize], expected: &[usize]) {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedAliasIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
            buf.iter_mut().for_each(|slot| *slot = rng.sample(&distr));
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[
            6, 5, 7, 5, 8, 7, 6, 2, 3, 7,
        ]);
        test_samples(vec![0.7f32, 0.1, 0.1, 0.1], &mut buf, &[
            2, 0, 0, 0, 0, 0, 0, 0, 1, 3,
        ]);
        test_samples(vec![1.0f64, 0.999, 0.998, 0.997], &mut buf, &[
            2, 1, 2, 3, 2, 1, 3, 2, 1, 1,
        ]);
    }
}