1use super::collect;
7use rayon::iter::plumbing::{Consumer, ProducerCallback, UnindexedConsumer};
8use rayon::prelude::*;
9
10use alloc::boxed::Box;
11use alloc::vec::Vec;
12use core::cmp::Ordering;
13use core::fmt;
14use core::hash::{BuildHasher, Hash};
15use core::ops::RangeBounds;
16
17use crate::set::Slice;
18use crate::IndexSet;
19
// A set entry is a crate-level bucket whose value half is `()`; only the key
// (accessed through `Bucket::key` / `Bucket::key_ref`) carries data.
type Bucket<T> = crate::Bucket<T, ()>;
21
22impl<T, S> IntoParallelIterator for IndexSet<T, S>
23where
24 T: Send,
25{
26 type Item = T;
27 type Iter = IntoParIter<T>;
28
29 fn into_par_iter(self) -> Self::Iter {
30 IntoParIter {
31 entries: self.into_entries(),
32 }
33 }
34}
35
36impl<T> IntoParallelIterator for Box<Slice<T>>
37where
38 T: Send,
39{
40 type Item = T;
41 type Iter = IntoParIter<T>;
42
43 fn into_par_iter(self) -> Self::Iter {
44 IntoParIter {
45 entries: self.into_entries(),
46 }
47 }
48}
49
50pub struct IntoParIter<T> {
55 entries: Vec<Bucket<T>>,
56}
57
58impl<T: fmt::Debug> fmt::Debug for IntoParIter<T> {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 let iter = self.entries.iter().map(Bucket::key_ref);
61 f.debug_list().entries(iter).finish()
62 }
63}
64
65impl<T: Send> ParallelIterator for IntoParIter<T> {
66 type Item = T;
67
68 parallel_iterator_methods!(Bucket::key);
69}
70
71impl<T: Send> IndexedParallelIterator for IntoParIter<T> {
72 indexed_parallel_iterator_methods!(Bucket::key);
73}
74
75impl<'a, T, S> IntoParallelIterator for &'a IndexSet<T, S>
76where
77 T: Sync,
78{
79 type Item = &'a T;
80 type Iter = ParIter<'a, T>;
81
82 fn into_par_iter(self) -> Self::Iter {
83 ParIter {
84 entries: self.as_entries(),
85 }
86 }
87}
88
89impl<'a, T> IntoParallelIterator for &'a Slice<T>
90where
91 T: Sync,
92{
93 type Item = &'a T;
94 type Iter = ParIter<'a, T>;
95
96 fn into_par_iter(self) -> Self::Iter {
97 ParIter {
98 entries: &self.entries,
99 }
100 }
101}
102
103pub struct ParIter<'a, T> {
110 entries: &'a [Bucket<T>],
111}
112
113impl<T> Clone for ParIter<'_, T> {
114 fn clone(&self) -> Self {
115 ParIter { ..*self }
116 }
117}
118
119impl<T: fmt::Debug> fmt::Debug for ParIter<'_, T> {
120 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121 let iter = self.entries.iter().map(Bucket::key_ref);
122 f.debug_list().entries(iter).finish()
123 }
124}
125
126impl<'a, T: Sync> ParallelIterator for ParIter<'a, T> {
127 type Item = &'a T;
128
129 parallel_iterator_methods!(Bucket::key_ref);
130}
131
132impl<T: Sync> IndexedParallelIterator for ParIter<'_, T> {
133 indexed_parallel_iterator_methods!(Bucket::key_ref);
134}
135
136impl<'a, T, S> ParallelDrainRange<usize> for &'a mut IndexSet<T, S>
137where
138 T: Send,
139{
140 type Item = T;
141 type Iter = ParDrain<'a, T>;
142
143 fn par_drain<R: RangeBounds<usize>>(self, range: R) -> Self::Iter {
144 ParDrain {
145 entries: self.map.core.par_drain(range),
146 }
147 }
148}
149
150pub struct ParDrain<'a, T: Send> {
157 entries: rayon::vec::Drain<'a, Bucket<T>>,
158}
159
160impl<T: Send> ParallelIterator for ParDrain<'_, T> {
161 type Item = T;
162
163 parallel_iterator_methods!(Bucket::key);
164}
165
166impl<T: Send> IndexedParallelIterator for ParDrain<'_, T> {
167 indexed_parallel_iterator_methods!(Bucket::key);
168}
169
170impl<T, S> IndexSet<T, S>
176where
177 T: Hash + Eq + Sync,
178 S: BuildHasher + Sync,
179{
180 pub fn par_difference<'a, S2>(
185 &'a self,
186 other: &'a IndexSet<T, S2>,
187 ) -> ParDifference<'a, T, S, S2>
188 where
189 S2: BuildHasher + Sync,
190 {
191 ParDifference {
192 set1: self,
193 set2: other,
194 }
195 }
196
197 pub fn par_symmetric_difference<'a, S2>(
205 &'a self,
206 other: &'a IndexSet<T, S2>,
207 ) -> ParSymmetricDifference<'a, T, S, S2>
208 where
209 S2: BuildHasher + Sync,
210 {
211 ParSymmetricDifference {
212 set1: self,
213 set2: other,
214 }
215 }
216
217 pub fn par_intersection<'a, S2>(
222 &'a self,
223 other: &'a IndexSet<T, S2>,
224 ) -> ParIntersection<'a, T, S, S2>
225 where
226 S2: BuildHasher + Sync,
227 {
228 ParIntersection {
229 set1: self,
230 set2: other,
231 }
232 }
233
234 pub fn par_union<'a, S2>(&'a self, other: &'a IndexSet<T, S2>) -> ParUnion<'a, T, S, S2>
241 where
242 S2: BuildHasher + Sync,
243 {
244 ParUnion {
245 set1: self,
246 set2: other,
247 }
248 }
249
250 pub fn par_eq<S2>(&self, other: &IndexSet<T, S2>) -> bool
253 where
254 S2: BuildHasher + Sync,
255 {
256 self.len() == other.len() && self.par_is_subset(other)
257 }
258
259 pub fn par_is_disjoint<S2>(&self, other: &IndexSet<T, S2>) -> bool
262 where
263 S2: BuildHasher + Sync,
264 {
265 if self.len() <= other.len() {
266 self.par_iter().all(move |value| !other.contains(value))
267 } else {
268 other.par_iter().all(move |value| !self.contains(value))
269 }
270 }
271
272 pub fn par_is_superset<S2>(&self, other: &IndexSet<T, S2>) -> bool
275 where
276 S2: BuildHasher + Sync,
277 {
278 other.par_is_subset(self)
279 }
280
281 pub fn par_is_subset<S2>(&self, other: &IndexSet<T, S2>) -> bool
284 where
285 S2: BuildHasher + Sync,
286 {
287 self.len() <= other.len() && self.par_iter().all(move |value| other.contains(value))
288 }
289}
290
291pub struct ParDifference<'a, T, S1, S2> {
296 set1: &'a IndexSet<T, S1>,
297 set2: &'a IndexSet<T, S2>,
298}
299
300impl<T, S1, S2> Clone for ParDifference<'_, T, S1, S2> {
301 fn clone(&self) -> Self {
302 ParDifference { ..*self }
303 }
304}
305
306impl<T, S1, S2> fmt::Debug for ParDifference<'_, T, S1, S2>
307where
308 T: fmt::Debug + Eq + Hash,
309 S1: BuildHasher,
310 S2: BuildHasher,
311{
312 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
313 f.debug_list()
314 .entries(self.set1.difference(self.set2))
315 .finish()
316 }
317}
318
319impl<'a, T, S1, S2> ParallelIterator for ParDifference<'a, T, S1, S2>
320where
321 T: Hash + Eq + Sync,
322 S1: BuildHasher + Sync,
323 S2: BuildHasher + Sync,
324{
325 type Item = &'a T;
326
327 fn drive_unindexed<C>(self, consumer: C) -> C::Result
328 where
329 C: UnindexedConsumer<Self::Item>,
330 {
331 let Self { set1, set2 } = self;
332
333 set1.par_iter()
334 .filter(move |&item| !set2.contains(item))
335 .drive_unindexed(consumer)
336 }
337}
338
339pub struct ParIntersection<'a, T, S1, S2> {
344 set1: &'a IndexSet<T, S1>,
345 set2: &'a IndexSet<T, S2>,
346}
347
348impl<T, S1, S2> Clone for ParIntersection<'_, T, S1, S2> {
349 fn clone(&self) -> Self {
350 ParIntersection { ..*self }
351 }
352}
353
354impl<T, S1, S2> fmt::Debug for ParIntersection<'_, T, S1, S2>
355where
356 T: fmt::Debug + Eq + Hash,
357 S1: BuildHasher,
358 S2: BuildHasher,
359{
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 f.debug_list()
362 .entries(self.set1.intersection(self.set2))
363 .finish()
364 }
365}
366
367impl<'a, T, S1, S2> ParallelIterator for ParIntersection<'a, T, S1, S2>
368where
369 T: Hash + Eq + Sync,
370 S1: BuildHasher + Sync,
371 S2: BuildHasher + Sync,
372{
373 type Item = &'a T;
374
375 fn drive_unindexed<C>(self, consumer: C) -> C::Result
376 where
377 C: UnindexedConsumer<Self::Item>,
378 {
379 let Self { set1, set2 } = self;
380
381 set1.par_iter()
382 .filter(move |&item| set2.contains(item))
383 .drive_unindexed(consumer)
384 }
385}
386
387pub struct ParSymmetricDifference<'a, T, S1, S2> {
392 set1: &'a IndexSet<T, S1>,
393 set2: &'a IndexSet<T, S2>,
394}
395
396impl<T, S1, S2> Clone for ParSymmetricDifference<'_, T, S1, S2> {
397 fn clone(&self) -> Self {
398 ParSymmetricDifference { ..*self }
399 }
400}
401
402impl<T, S1, S2> fmt::Debug for ParSymmetricDifference<'_, T, S1, S2>
403where
404 T: fmt::Debug + Eq + Hash,
405 S1: BuildHasher,
406 S2: BuildHasher,
407{
408 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
409 f.debug_list()
410 .entries(self.set1.symmetric_difference(self.set2))
411 .finish()
412 }
413}
414
415impl<'a, T, S1, S2> ParallelIterator for ParSymmetricDifference<'a, T, S1, S2>
416where
417 T: Hash + Eq + Sync,
418 S1: BuildHasher + Sync,
419 S2: BuildHasher + Sync,
420{
421 type Item = &'a T;
422
423 fn drive_unindexed<C>(self, consumer: C) -> C::Result
424 where
425 C: UnindexedConsumer<Self::Item>,
426 {
427 let Self { set1, set2 } = self;
428
429 set1.par_difference(set2)
430 .chain(set2.par_difference(set1))
431 .drive_unindexed(consumer)
432 }
433}
434
435pub struct ParUnion<'a, T, S1, S2> {
440 set1: &'a IndexSet<T, S1>,
441 set2: &'a IndexSet<T, S2>,
442}
443
444impl<T, S1, S2> Clone for ParUnion<'_, T, S1, S2> {
445 fn clone(&self) -> Self {
446 ParUnion { ..*self }
447 }
448}
449
450impl<T, S1, S2> fmt::Debug for ParUnion<'_, T, S1, S2>
451where
452 T: fmt::Debug + Eq + Hash,
453 S1: BuildHasher,
454 S2: BuildHasher,
455{
456 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
457 f.debug_list().entries(self.set1.union(self.set2)).finish()
458 }
459}
460
461impl<'a, T, S1, S2> ParallelIterator for ParUnion<'a, T, S1, S2>
462where
463 T: Hash + Eq + Sync,
464 S1: BuildHasher + Sync,
465 S2: BuildHasher + Sync,
466{
467 type Item = &'a T;
468
469 fn drive_unindexed<C>(self, consumer: C) -> C::Result
470 where
471 C: UnindexedConsumer<Self::Item>,
472 {
473 let Self { set1, set2 } = self;
474
475 set1.par_iter()
476 .chain(set2.par_difference(set1))
477 .drive_unindexed(consumer)
478 }
479}
480
481impl<T, S> IndexSet<T, S>
485where
486 T: Send,
487{
488 pub fn par_sort(&mut self)
490 where
491 T: Ord,
492 {
493 self.with_entries(|entries| {
494 entries.par_sort_by(|a, b| T::cmp(&a.key, &b.key));
495 });
496 }
497
498 pub fn par_sort_by<F>(&mut self, cmp: F)
500 where
501 F: Fn(&T, &T) -> Ordering + Sync,
502 {
503 self.with_entries(|entries| {
504 entries.par_sort_by(move |a, b| cmp(&a.key, &b.key));
505 });
506 }
507
508 pub fn par_sorted_by<F>(self, cmp: F) -> IntoParIter<T>
511 where
512 F: Fn(&T, &T) -> Ordering + Sync,
513 {
514 let mut entries = self.into_entries();
515 entries.par_sort_by(move |a, b| cmp(&a.key, &b.key));
516 IntoParIter { entries }
517 }
518
519 pub fn par_sort_by_key<K, F>(&mut self, sort_key: F)
521 where
522 K: Ord,
523 F: Fn(&T) -> K + Sync,
524 {
525 self.with_entries(move |entries| {
526 entries.par_sort_by_key(move |a| sort_key(&a.key));
527 });
528 }
529
530 pub fn par_sort_unstable(&mut self)
532 where
533 T: Ord,
534 {
535 self.with_entries(|entries| {
536 entries.par_sort_unstable_by(|a, b| T::cmp(&a.key, &b.key));
537 });
538 }
539
540 pub fn par_sort_unstable_by<F>(&mut self, cmp: F)
542 where
543 F: Fn(&T, &T) -> Ordering + Sync,
544 {
545 self.with_entries(|entries| {
546 entries.par_sort_unstable_by(move |a, b| cmp(&a.key, &b.key));
547 });
548 }
549
550 pub fn par_sorted_unstable_by<F>(self, cmp: F) -> IntoParIter<T>
553 where
554 F: Fn(&T, &T) -> Ordering + Sync,
555 {
556 let mut entries = self.into_entries();
557 entries.par_sort_unstable_by(move |a, b| cmp(&a.key, &b.key));
558 IntoParIter { entries }
559 }
560
561 pub fn par_sort_unstable_by_key<K, F>(&mut self, sort_key: F)
563 where
564 K: Ord,
565 F: Fn(&T) -> K + Sync,
566 {
567 self.with_entries(move |entries| {
568 entries.par_sort_unstable_by_key(move |a| sort_key(&a.key));
569 });
570 }
571
572 pub fn par_sort_by_cached_key<K, F>(&mut self, sort_key: F)
574 where
575 K: Ord + Send,
576 F: Fn(&T) -> K + Sync,
577 {
578 self.with_entries(move |entries| {
579 entries.par_sort_by_cached_key(move |a| sort_key(&a.key));
580 });
581 }
582}
583
584impl<T, S> FromParallelIterator<T> for IndexSet<T, S>
585where
586 T: Eq + Hash + Send,
587 S: BuildHasher + Default + Send,
588{
589 fn from_par_iter<I>(iter: I) -> Self
590 where
591 I: IntoParallelIterator<Item = T>,
592 {
593 let list = collect(iter);
594 let len = list.iter().map(Vec::len).sum();
595 let mut set = Self::with_capacity_and_hasher(len, S::default());
596 for vec in list {
597 set.extend(vec);
598 }
599 set
600 }
601}
602
603impl<T, S> ParallelExtend<T> for IndexSet<T, S>
604where
605 T: Eq + Hash + Send,
606 S: BuildHasher + Send,
607{
608 fn par_extend<I>(&mut self, iter: I)
609 where
610 I: IntoParallelIterator<Item = T>,
611 {
612 for vec in collect(iter) {
613 self.extend(vec);
614 }
615 }
616}
617
618impl<'a, T: 'a, S> ParallelExtend<&'a T> for IndexSet<T, S>
619where
620 T: Copy + Eq + Hash + Send + Sync,
621 S: BuildHasher + Send,
622{
623 fn par_extend<I>(&mut self, iter: I)
624 where
625 I: IntoParallelIterator<Item = &'a T>,
626 {
627 for vec in collect(iter) {
628 self.extend(vec);
629 }
630 }
631}
632
// Unit tests for the rayon integration of `IndexSet`: order preservation,
// equality, parallel extension, set comparisons and set-operation iterators.
#[cfg(test)]
mod tests {
    use super::*;

    // Parallel iteration (direct, zipped with a slice, and zipped with an index
    // range) must observe the set's insertion order.
    #[test]
    fn insert_order() {
        let insert = [0, 4, 2, 12, 8, 7, 11, 5, 3, 17, 19, 22, 23];
        let mut set = IndexSet::new();

        for &elt in &insert {
            set.insert(elt);
        }

        assert_eq!(set.par_iter().count(), set.len());
        assert_eq!(set.par_iter().count(), insert.len());
        insert.par_iter().zip(&set).for_each(|(a, b)| {
            assert_eq!(a, b);
        });
        (0..insert.len())
            .into_par_iter()
            .zip(&set)
            .for_each(|(i, v)| {
                assert_eq!(set.get_index(i).unwrap(), v);
            });
    }

    // `par_eq` tracks removals and insertions, and is symmetric after a
    // round-trip through `into_par_iter().collect()`.
    #[test]
    fn partial_eq_and_eq() {
        let mut set_a = IndexSet::new();
        set_a.insert(1);
        set_a.insert(2);
        let mut set_b = set_a.clone();
        assert!(set_a.par_eq(&set_b));
        set_b.swap_remove(&1);
        assert!(!set_a.par_eq(&set_b));
        set_b.insert(3);
        assert!(!set_a.par_eq(&set_b));

        let set_c: IndexSet<_> = set_b.into_par_iter().collect();
        assert!(!set_a.par_eq(&set_c));
        assert!(!set_c.par_eq(&set_a));
    }

    // Both `ParallelExtend` impls (by reference and by value) append in order.
    #[test]
    fn extend() {
        let mut set = IndexSet::new();
        set.par_extend(vec![&1, &2, &3, &4]);
        set.par_extend(vec![5, 6]);
        assert_eq!(
            set.into_par_iter().collect::<Vec<_>>(),
            vec![1, 2, 3, 4, 5, 6]
        );
    }

    // Disjoint / subset / superset predicates across equal, disjoint, nested
    // and overlapping sets.
    #[test]
    fn comparisons() {
        let set_a: IndexSet<_> = (0..3).collect();
        let set_b: IndexSet<_> = (3..6).collect();
        let set_c: IndexSet<_> = (0..6).collect();
        let set_d: IndexSet<_> = (3..9).collect();

        assert!(!set_a.par_is_disjoint(&set_a));
        assert!(set_a.par_is_subset(&set_a));
        assert!(set_a.par_is_superset(&set_a));

        assert!(set_a.par_is_disjoint(&set_b));
        assert!(set_b.par_is_disjoint(&set_a));
        assert!(!set_a.par_is_subset(&set_b));
        assert!(!set_b.par_is_subset(&set_a));
        assert!(!set_a.par_is_superset(&set_b));
        assert!(!set_b.par_is_superset(&set_a));

        assert!(!set_a.par_is_disjoint(&set_c));
        assert!(!set_c.par_is_disjoint(&set_a));
        assert!(set_a.par_is_subset(&set_c));
        assert!(!set_c.par_is_subset(&set_a));
        assert!(!set_a.par_is_superset(&set_c));
        assert!(set_c.par_is_superset(&set_a));

        assert!(!set_c.par_is_disjoint(&set_d));
        assert!(!set_d.par_is_disjoint(&set_c));
        assert!(!set_c.par_is_subset(&set_d));
        assert!(!set_d.par_is_subset(&set_c));
        assert!(!set_c.par_is_superset(&set_d));
        assert!(!set_d.par_is_superset(&set_c));
    }

    // The set-operation iterators must yield exactly the expected values in
    // the expected order (`set_d` is built in reverse to exercise ordering).
    #[test]
    fn iter_comparisons() {
        use std::iter::empty;

        // Collects the parallel iterator and compares it, in order, with the
        // expected sequential iterator.
        fn check<'a, I1, I2>(iter1: I1, iter2: I2)
        where
            I1: ParallelIterator<Item = &'a i32>,
            I2: Iterator<Item = i32>,
        {
            let v1: Vec<_> = iter1.copied().collect();
            let v2: Vec<_> = iter2.collect();
            assert_eq!(v1, v2);
        }

        let set_a: IndexSet<_> = (0..3).collect();
        let set_b: IndexSet<_> = (3..6).collect();
        let set_c: IndexSet<_> = (0..6).collect();
        let set_d: IndexSet<_> = (3..9).rev().collect();

        check(set_a.par_difference(&set_a), empty());
        check(set_a.par_symmetric_difference(&set_a), empty());
        check(set_a.par_intersection(&set_a), 0..3);
        check(set_a.par_union(&set_a), 0..3);

        check(set_a.par_difference(&set_b), 0..3);
        check(set_b.par_difference(&set_a), 3..6);
        check(set_a.par_symmetric_difference(&set_b), 0..6);
        check(set_b.par_symmetric_difference(&set_a), (3..6).chain(0..3));
        check(set_a.par_intersection(&set_b), empty());
        check(set_b.par_intersection(&set_a), empty());
        check(set_a.par_union(&set_b), 0..6);
        check(set_b.par_union(&set_a), (3..6).chain(0..3));

        check(set_a.par_difference(&set_c), empty());
        check(set_c.par_difference(&set_a), 3..6);
        check(set_a.par_symmetric_difference(&set_c), 3..6);
        check(set_c.par_symmetric_difference(&set_a), 3..6);
        check(set_a.par_intersection(&set_c), 0..3);
        check(set_c.par_intersection(&set_a), 0..3);
        check(set_a.par_union(&set_c), 0..6);
        check(set_c.par_union(&set_a), 0..6);

        check(set_c.par_difference(&set_d), 0..3);
        check(set_d.par_difference(&set_c), (6..9).rev());
        check(
            set_c.par_symmetric_difference(&set_d),
            (0..3).chain((6..9).rev()),
        );
        check(
            set_d.par_symmetric_difference(&set_c),
            (6..9).rev().chain(0..3),
        );
        check(set_c.par_intersection(&set_d), 3..6);
        check(set_d.par_intersection(&set_c), (3..6).rev());
        check(set_c.par_union(&set_d), (0..6).chain((6..9).rev()));
        check(set_d.par_union(&set_c), (3..9).rev().chain(0..3));
    }
}