1use super::collect;
7use rayon::iter::plumbing::{Consumer, ProducerCallback, UnindexedConsumer};
8use rayon::prelude::*;
9
10use crate::vec::Vec;
11use alloc::boxed::Box;
12use core::cmp::Ordering;
13use core::fmt;
14use core::hash::{BuildHasher, Hash};
15use core::ops::RangeBounds;
16
17use crate::set::Slice;
18use crate::Entries;
19use crate::IndexSet;
20
/// Set entries are stored as the crate's generic `Bucket` with a unit `()` value,
/// so only the key (the set value) carries data.
type Bucket<T> = crate::Bucket<T, ()>;
22
23impl<T, S> IntoParallelIterator for IndexSet<T, S>
24where
25 T: Send,
26{
27 type Item = T;
28 type Iter = IntoParIter<T>;
29
30 fn into_par_iter(self) -> Self::Iter {
31 IntoParIter {
32 entries: self.into_entries(),
33 }
34 }
35}
36
37impl<T> IntoParallelIterator for Box<Slice<T>>
38where
39 T: Send,
40{
41 type Item = T;
42 type Iter = IntoParIter<T>;
43
44 fn into_par_iter(self) -> Self::Iter {
45 IntoParIter {
46 entries: self.into_entries(),
47 }
48 }
49}
50
/// A parallel owning iterator over the values of an `IndexSet`.
///
/// Produced by `into_par_iter` on `IndexSet` or `Box<Slice<T>>` (rayon's
/// `IntoParallelIterator` trait), and returned by `par_sorted_by` /
/// `par_sorted_unstable_by`.
pub struct IntoParIter<T> {
    // The owned buckets; each bucket's key is a set value.
    entries: Vec<Bucket<T>>,
}
58
59impl<T: fmt::Debug> fmt::Debug for IntoParIter<T> {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 let iter = self.entries.iter().map(Bucket::key_ref);
62 f.debug_list().entries(iter).finish()
63 }
64}
65
// Owning iteration: items are the set values themselves. The method bodies come
// from `parallel_iterator_methods!`, which is given `Bucket::key` — presumably to
// project each owned bucket onto its key (TODO confirm against the macro definition).
impl<T: Send> ParallelIterator for IntoParIter<T> {
    type Item = T;

    parallel_iterator_methods!(Bucket::key);
}

// Indexed (exact-length, splittable) iteration over the same owned buckets, using
// the same `Bucket::key` projection.
impl<T: Send> IndexedParallelIterator for IntoParIter<T> {
    indexed_parallel_iterator_methods!(Bucket::key);
}
75
76impl<'a, T, S> IntoParallelIterator for &'a IndexSet<T, S>
77where
78 T: Sync,
79{
80 type Item = &'a T;
81 type Iter = ParIter<'a, T>;
82
83 fn into_par_iter(self) -> Self::Iter {
84 ParIter {
85 entries: self.as_entries(),
86 }
87 }
88}
89
90impl<'a, T> IntoParallelIterator for &'a Slice<T>
91where
92 T: Sync,
93{
94 type Item = &'a T;
95 type Iter = ParIter<'a, T>;
96
97 fn into_par_iter(self) -> Self::Iter {
98 ParIter {
99 entries: &self.entries,
100 }
101 }
102}
103
/// A parallel iterator over shared references to the values of an `IndexSet`.
///
/// Produced by `into_par_iter` on `&IndexSet` or `&Slice<T>` (and therefore by
/// rayon's `par_iter` on those types).
pub struct ParIter<'a, T> {
    // Borrowed buckets; each bucket's key is a set value.
    entries: &'a [Bucket<T>],
}
113
114impl<T> Clone for ParIter<'_, T> {
115 fn clone(&self) -> Self {
116 ParIter { ..*self }
117 }
118}
119
120impl<T: fmt::Debug> fmt::Debug for ParIter<'_, T> {
121 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122 let iter = self.entries.iter().map(Bucket::key_ref);
123 f.debug_list().entries(iter).finish()
124 }
125}
126
// Borrowing iteration: items are `&'a T`. The macro is given `Bucket::key_ref`,
// which (as used by the `Debug` impls in this file) yields a reference to a bucket's key.
impl<'a, T: Sync> ParallelIterator for ParIter<'a, T> {
    type Item = &'a T;

    parallel_iterator_methods!(Bucket::key_ref);
}

// Indexed (exact-length, splittable) iteration over the borrowed buckets.
impl<T: Sync> IndexedParallelIterator for ParIter<'_, T> {
    indexed_parallel_iterator_methods!(Bucket::key_ref);
}
136
137impl<'a, T, S> ParallelDrainRange<usize> for &'a mut IndexSet<T, S>
138where
139 T: Send,
140{
141 type Item = T;
142 type Iter = ParDrain<'a, T>;
143
144 fn par_drain<R: RangeBounds<usize>>(self, range: R) -> Self::Iter {
145 ParDrain {
146 entries: self.map.core.par_drain(range),
147 }
148 }
149}
150
/// A parallel draining iterator over a range of an `IndexSet`'s values.
///
/// Produced by `par_drain` (rayon's `ParallelDrainRange` trait) on `&mut IndexSet`.
// NOTE(review): the `T: Send` bound on the struct itself appears to be required by
// `rayon::vec::Drain`'s own definition — confirm before removing it.
pub struct ParDrain<'a, T: Send> {
    entries: rayon::vec::Drain<'a, Bucket<T>>,
}
160
// Draining yields owned values; like `IntoParIter`, the macro is given `Bucket::key`
// to turn each drained bucket into its value (TODO confirm against the macro).
impl<T: Send> ParallelIterator for ParDrain<'_, T> {
    type Item = T;

    parallel_iterator_methods!(Bucket::key);
}

// Indexed (exact-length, splittable) draining with the same projection.
impl<T: Send> IndexedParallelIterator for ParDrain<'_, T> {
    indexed_parallel_iterator_methods!(Bucket::key);
}
170
171impl<T, S> IndexSet<T, S>
177where
178 T: Hash + Eq + Sync,
179 S: BuildHasher + Sync,
180{
181 pub fn par_difference<'a, S2>(
186 &'a self,
187 other: &'a IndexSet<T, S2>,
188 ) -> ParDifference<'a, T, S, S2>
189 where
190 S2: BuildHasher + Sync,
191 {
192 ParDifference {
193 set1: self,
194 set2: other,
195 }
196 }
197
198 pub fn par_symmetric_difference<'a, S2>(
206 &'a self,
207 other: &'a IndexSet<T, S2>,
208 ) -> ParSymmetricDifference<'a, T, S, S2>
209 where
210 S2: BuildHasher + Sync,
211 {
212 ParSymmetricDifference {
213 set1: self,
214 set2: other,
215 }
216 }
217
218 pub fn par_intersection<'a, S2>(
223 &'a self,
224 other: &'a IndexSet<T, S2>,
225 ) -> ParIntersection<'a, T, S, S2>
226 where
227 S2: BuildHasher + Sync,
228 {
229 ParIntersection {
230 set1: self,
231 set2: other,
232 }
233 }
234
235 pub fn par_union<'a, S2>(&'a self, other: &'a IndexSet<T, S2>) -> ParUnion<'a, T, S, S2>
242 where
243 S2: BuildHasher + Sync,
244 {
245 ParUnion {
246 set1: self,
247 set2: other,
248 }
249 }
250
251 pub fn par_eq<S2>(&self, other: &IndexSet<T, S2>) -> bool
254 where
255 S2: BuildHasher + Sync,
256 {
257 self.len() == other.len() && self.par_is_subset(other)
258 }
259
260 pub fn par_is_disjoint<S2>(&self, other: &IndexSet<T, S2>) -> bool
263 where
264 S2: BuildHasher + Sync,
265 {
266 if self.len() <= other.len() {
267 self.par_iter().all(move |value| !other.contains(value))
268 } else {
269 other.par_iter().all(move |value| !self.contains(value))
270 }
271 }
272
273 pub fn par_is_superset<S2>(&self, other: &IndexSet<T, S2>) -> bool
276 where
277 S2: BuildHasher + Sync,
278 {
279 other.par_is_subset(self)
280 }
281
282 pub fn par_is_subset<S2>(&self, other: &IndexSet<T, S2>) -> bool
285 where
286 S2: BuildHasher + Sync,
287 {
288 self.len() <= other.len() && self.par_iter().all(move |value| other.contains(value))
289 }
290}
291
/// A lazy parallel iterator producing the values of `set1` that are not in `set2`.
///
/// Returned by `IndexSet::par_difference`.
pub struct ParDifference<'a, T, S1, S2> {
    set1: &'a IndexSet<T, S1>,
    set2: &'a IndexSet<T, S2>,
}
300
301impl<T, S1, S2> Clone for ParDifference<'_, T, S1, S2> {
302 fn clone(&self) -> Self {
303 ParDifference { ..*self }
304 }
305}
306
307impl<T, S1, S2> fmt::Debug for ParDifference<'_, T, S1, S2>
308where
309 T: fmt::Debug + Eq + Hash,
310 S1: BuildHasher,
311 S2: BuildHasher,
312{
313 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314 f.debug_list()
315 .entries(self.set1.difference(self.set2))
316 .finish()
317 }
318}
319
320impl<'a, T, S1, S2> ParallelIterator for ParDifference<'a, T, S1, S2>
321where
322 T: Hash + Eq + Sync,
323 S1: BuildHasher + Sync,
324 S2: BuildHasher + Sync,
325{
326 type Item = &'a T;
327
328 fn drive_unindexed<C>(self, consumer: C) -> C::Result
329 where
330 C: UnindexedConsumer<Self::Item>,
331 {
332 let Self { set1, set2 } = self;
333
334 set1.par_iter()
335 .filter(move |&item| !set2.contains(item))
336 .drive_unindexed(consumer)
337 }
338}
339
/// A lazy parallel iterator producing the values of `set1` that are also in `set2`.
///
/// Returned by `IndexSet::par_intersection`.
pub struct ParIntersection<'a, T, S1, S2> {
    set1: &'a IndexSet<T, S1>,
    set2: &'a IndexSet<T, S2>,
}
348
349impl<T, S1, S2> Clone for ParIntersection<'_, T, S1, S2> {
350 fn clone(&self) -> Self {
351 ParIntersection { ..*self }
352 }
353}
354
355impl<T, S1, S2> fmt::Debug for ParIntersection<'_, T, S1, S2>
356where
357 T: fmt::Debug + Eq + Hash,
358 S1: BuildHasher,
359 S2: BuildHasher,
360{
361 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362 f.debug_list()
363 .entries(self.set1.intersection(self.set2))
364 .finish()
365 }
366}
367
368impl<'a, T, S1, S2> ParallelIterator for ParIntersection<'a, T, S1, S2>
369where
370 T: Hash + Eq + Sync,
371 S1: BuildHasher + Sync,
372 S2: BuildHasher + Sync,
373{
374 type Item = &'a T;
375
376 fn drive_unindexed<C>(self, consumer: C) -> C::Result
377 where
378 C: UnindexedConsumer<Self::Item>,
379 {
380 let Self { set1, set2 } = self;
381
382 set1.par_iter()
383 .filter(move |&item| set2.contains(item))
384 .drive_unindexed(consumer)
385 }
386}
387
/// A lazy parallel iterator producing the values found in exactly one of `set1` and `set2`.
///
/// Returned by `IndexSet::par_symmetric_difference`.
pub struct ParSymmetricDifference<'a, T, S1, S2> {
    set1: &'a IndexSet<T, S1>,
    set2: &'a IndexSet<T, S2>,
}
396
397impl<T, S1, S2> Clone for ParSymmetricDifference<'_, T, S1, S2> {
398 fn clone(&self) -> Self {
399 ParSymmetricDifference { ..*self }
400 }
401}
402
403impl<T, S1, S2> fmt::Debug for ParSymmetricDifference<'_, T, S1, S2>
404where
405 T: fmt::Debug + Eq + Hash,
406 S1: BuildHasher,
407 S2: BuildHasher,
408{
409 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
410 f.debug_list()
411 .entries(self.set1.symmetric_difference(self.set2))
412 .finish()
413 }
414}
415
416impl<'a, T, S1, S2> ParallelIterator for ParSymmetricDifference<'a, T, S1, S2>
417where
418 T: Hash + Eq + Sync,
419 S1: BuildHasher + Sync,
420 S2: BuildHasher + Sync,
421{
422 type Item = &'a T;
423
424 fn drive_unindexed<C>(self, consumer: C) -> C::Result
425 where
426 C: UnindexedConsumer<Self::Item>,
427 {
428 let Self { set1, set2 } = self;
429
430 set1.par_difference(set2)
431 .chain(set2.par_difference(set1))
432 .drive_unindexed(consumer)
433 }
434}
435
/// A lazy parallel iterator producing every value of `set1`, then the values of `set2`
/// that `set1` lacks.
///
/// Returned by `IndexSet::par_union`.
pub struct ParUnion<'a, T, S1, S2> {
    set1: &'a IndexSet<T, S1>,
    set2: &'a IndexSet<T, S2>,
}
444
445impl<T, S1, S2> Clone for ParUnion<'_, T, S1, S2> {
446 fn clone(&self) -> Self {
447 ParUnion { ..*self }
448 }
449}
450
451impl<T, S1, S2> fmt::Debug for ParUnion<'_, T, S1, S2>
452where
453 T: fmt::Debug + Eq + Hash,
454 S1: BuildHasher,
455 S2: BuildHasher,
456{
457 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
458 f.debug_list().entries(self.set1.union(self.set2)).finish()
459 }
460}
461
462impl<'a, T, S1, S2> ParallelIterator for ParUnion<'a, T, S1, S2>
463where
464 T: Hash + Eq + Sync,
465 S1: BuildHasher + Sync,
466 S2: BuildHasher + Sync,
467{
468 type Item = &'a T;
469
470 fn drive_unindexed<C>(self, consumer: C) -> C::Result
471 where
472 C: UnindexedConsumer<Self::Item>,
473 {
474 let Self { set1, set2 } = self;
475
476 set1.par_iter()
477 .chain(set2.par_difference(set1))
478 .drive_unindexed(consumer)
479 }
480}
481
482impl<T, S> IndexSet<T, S>
486where
487 T: Send,
488{
489 pub fn par_sort(&mut self)
491 where
492 T: Ord,
493 {
494 self.with_entries(|entries| {
495 entries.par_sort_by(|a, b| T::cmp(&a.key, &b.key));
496 });
497 }
498
499 pub fn par_sort_by<F>(&mut self, cmp: F)
501 where
502 F: Fn(&T, &T) -> Ordering + Sync,
503 {
504 self.with_entries(|entries| {
505 entries.par_sort_by(move |a, b| cmp(&a.key, &b.key));
506 });
507 }
508
509 pub fn par_sorted_by<F>(self, cmp: F) -> IntoParIter<T>
512 where
513 F: Fn(&T, &T) -> Ordering + Sync,
514 {
515 let mut entries = self.into_entries();
516 entries.par_sort_by(move |a, b| cmp(&a.key, &b.key));
517 IntoParIter { entries }
518 }
519
520 pub fn par_sort_unstable(&mut self)
522 where
523 T: Ord,
524 {
525 self.with_entries(|entries| {
526 entries.par_sort_unstable_by(|a, b| T::cmp(&a.key, &b.key));
527 });
528 }
529
530 pub fn par_sort_unstable_by<F>(&mut self, cmp: F)
532 where
533 F: Fn(&T, &T) -> Ordering + Sync,
534 {
535 self.with_entries(|entries| {
536 entries.par_sort_unstable_by(move |a, b| cmp(&a.key, &b.key));
537 });
538 }
539
540 pub fn par_sorted_unstable_by<F>(self, cmp: F) -> IntoParIter<T>
543 where
544 F: Fn(&T, &T) -> Ordering + Sync,
545 {
546 let mut entries = self.into_entries();
547 entries.par_sort_unstable_by(move |a, b| cmp(&a.key, &b.key));
548 IntoParIter { entries }
549 }
550
551 pub fn par_sort_by_cached_key<K, F>(&mut self, sort_key: F)
553 where
554 K: Ord + Send,
555 F: Fn(&T) -> K + Sync,
556 {
557 self.with_entries(move |entries| {
558 entries.par_sort_by_cached_key(move |a| sort_key(&a.key));
559 });
560 }
561}
562
563impl<T, S> FromParallelIterator<T> for IndexSet<T, S>
564where
565 T: Eq + Hash + Send,
566 S: BuildHasher + Default + Send,
567{
568 fn from_par_iter<I>(iter: I) -> Self
569 where
570 I: IntoParallelIterator<Item = T>,
571 {
572 let list = collect(iter);
573 let len = list.iter().map(Vec::len).sum();
574 let mut set = Self::with_capacity_and_hasher(len, S::default());
575 for vec in list {
576 set.extend(vec);
577 }
578 set
579 }
580}
581
582impl<T, S> ParallelExtend<T> for IndexSet<T, S>
583where
584 T: Eq + Hash + Send,
585 S: BuildHasher + Send,
586{
587 fn par_extend<I>(&mut self, iter: I)
588 where
589 I: IntoParallelIterator<Item = T>,
590 {
591 for vec in collect(iter) {
592 self.extend(vec);
593 }
594 }
595}
596
597impl<'a, T: 'a, S> ParallelExtend<&'a T> for IndexSet<T, S>
598where
599 T: Copy + Eq + Hash + Send + Sync,
600 S: BuildHasher + Send,
601{
602 fn par_extend<I>(&mut self, iter: I)
603 where
604 I: IntoParallelIterator<Item = &'a T>,
605 {
606 for vec in collect(iter) {
607 self.extend(vec);
608 }
609 }
610}
611
// Unit tests for the parallel `IndexSet` API in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // Parallel iteration must report the right length, and indexed zips must pair
    // values with the insertion sequence and with `get_index`.
    #[test]
    fn insert_order() {
        let insert = [0, 4, 2, 12, 8, 7, 11, 5, 3, 17, 19, 22, 23];
        let mut set = IndexSet::new();

        for &elt in &insert {
            set.insert(elt);
        }

        assert_eq!(set.par_iter().count(), set.len());
        assert_eq!(set.par_iter().count(), insert.len());
        insert.par_iter().zip(&set).for_each(|(a, b)| {
            assert_eq!(a, b);
        });
        (0..insert.len())
            .into_par_iter()
            .zip(&set)
            .for_each(|(i, v)| {
                assert_eq!(set.get_index(i).unwrap(), v);
            });
    }

    // `par_eq` must detect differing sizes and differing contents in both directions,
    // including for a set rebuilt via `into_par_iter().collect()`.
    #[test]
    fn partial_eq_and_eq() {
        let mut set_a = IndexSet::new();
        set_a.insert(1);
        set_a.insert(2);
        let mut set_b = set_a.clone();
        assert!(set_a.par_eq(&set_b));
        set_b.swap_remove(&1);
        assert!(!set_a.par_eq(&set_b));
        set_b.insert(3);
        assert!(!set_a.par_eq(&set_b));

        let set_c: IndexSet<_> = set_b.into_par_iter().collect();
        assert!(!set_a.par_eq(&set_c));
        assert!(!set_c.par_eq(&set_a));
    }

    // Exercises both `ParallelExtend` impls (by reference and by value) and checks
    // the resulting order.
    #[test]
    fn extend() {
        let mut set = IndexSet::new();
        set.par_extend(vec![&1, &2, &3, &4]);
        set.par_extend(vec![5, 6]);
        assert_eq!(
            set.into_par_iter().collect::<Vec<_>>(),
            vec![1, 2, 3, 4, 5, 6]
        );
    }

    // Disjoint / subset / superset predicates over identical, disjoint, nested and
    // overlapping sets.
    #[test]
    fn comparisons() {
        let set_a: IndexSet<_> = (0..3).collect();
        let set_b: IndexSet<_> = (3..6).collect();
        let set_c: IndexSet<_> = (0..6).collect();
        let set_d: IndexSet<_> = (3..9).collect();

        assert!(!set_a.par_is_disjoint(&set_a));
        assert!(set_a.par_is_subset(&set_a));
        assert!(set_a.par_is_superset(&set_a));

        assert!(set_a.par_is_disjoint(&set_b));
        assert!(set_b.par_is_disjoint(&set_a));
        assert!(!set_a.par_is_subset(&set_b));
        assert!(!set_b.par_is_subset(&set_a));
        assert!(!set_a.par_is_superset(&set_b));
        assert!(!set_b.par_is_superset(&set_a));

        assert!(!set_a.par_is_disjoint(&set_c));
        assert!(!set_c.par_is_disjoint(&set_a));
        assert!(set_a.par_is_subset(&set_c));
        assert!(!set_c.par_is_subset(&set_a));
        assert!(!set_a.par_is_superset(&set_c));
        assert!(set_c.par_is_superset(&set_a));

        assert!(!set_c.par_is_disjoint(&set_d));
        assert!(!set_d.par_is_disjoint(&set_c));
        assert!(!set_c.par_is_subset(&set_d));
        assert!(!set_d.par_is_subset(&set_c));
        assert!(!set_c.par_is_superset(&set_d));
        assert!(!set_d.par_is_superset(&set_c));
    }

    // Set-algebra iterators must yield exactly the expected values in the expected
    // order (`set_d` is built in reverse to make ordering differences visible).
    #[test]
    fn iter_comparisons() {
        use std::iter::empty;

        // Collects the parallel iterator and compares it, in order, to the expected sequence.
        fn check<'a, I1, I2>(iter1: I1, iter2: I2)
        where
            I1: ParallelIterator<Item = &'a i32>,
            I2: Iterator<Item = i32>,
        {
            let v1: Vec<_> = iter1.copied().collect();
            let v2: Vec<_> = iter2.collect();
            assert_eq!(v1, v2);
        }

        let set_a: IndexSet<_> = (0..3).collect();
        let set_b: IndexSet<_> = (3..6).collect();
        let set_c: IndexSet<_> = (0..6).collect();
        let set_d: IndexSet<_> = (3..9).rev().collect();

        check(set_a.par_difference(&set_a), empty());
        check(set_a.par_symmetric_difference(&set_a), empty());
        check(set_a.par_intersection(&set_a), 0..3);
        check(set_a.par_union(&set_a), 0..3);

        check(set_a.par_difference(&set_b), 0..3);
        check(set_b.par_difference(&set_a), 3..6);
        check(set_a.par_symmetric_difference(&set_b), 0..6);
        check(set_b.par_symmetric_difference(&set_a), (3..6).chain(0..3));
        check(set_a.par_intersection(&set_b), empty());
        check(set_b.par_intersection(&set_a), empty());
        check(set_a.par_union(&set_b), 0..6);
        check(set_b.par_union(&set_a), (3..6).chain(0..3));

        check(set_a.par_difference(&set_c), empty());
        check(set_c.par_difference(&set_a), 3..6);
        check(set_a.par_symmetric_difference(&set_c), 3..6);
        check(set_c.par_symmetric_difference(&set_a), 3..6);
        check(set_a.par_intersection(&set_c), 0..3);
        check(set_c.par_intersection(&set_a), 0..3);
        check(set_a.par_union(&set_c), 0..6);
        check(set_c.par_union(&set_a), 0..6);

        check(set_c.par_difference(&set_d), 0..3);
        check(set_d.par_difference(&set_c), (6..9).rev());
        check(
            set_c.par_symmetric_difference(&set_d),
            (0..3).chain((6..9).rev()),
        );
        check(
            set_d.par_symmetric_difference(&set_c),
            (6..9).rev().chain(0..3),
        );
        check(set_c.par_intersection(&set_d), 3..6);
        check(set_d.par_intersection(&set_c), (3..6).rev());
        check(set_c.par_union(&set_d), (0..6).chain((6..9).rev()));
        check(set_d.par_union(&set_c), (3..9).rev().chain(0..3));
    }
}
756}