// strum_macros/macros/enum_iter.rs
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{Data, DeriveInput, Fields, Ident};

use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};

7pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8 let name = &ast.ident;
9 let gen = &ast.generics;
10 let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
11 let vis = &ast.vis;
12 let type_properties = ast.get_type_properties()?;
13 let strum_module_path = type_properties.crate_module_path();
14 let doc_comment = format!("An iterator over the variants of [{}]", name);
15
16 if gen.lifetimes().count() > 0 {
17 return Err(syn::Error::new(
18 Span::call_site(),
19 "This macro doesn't support enums with lifetimes. \
20 The resulting enums would be unbounded.",
21 ));
22 }
23
24 let phantom_data = if gen.type_params().count() > 0 {
25 let g = gen.type_params().map(|param| ¶m.ident);
26 quote! { < fn() -> ( #(#g),* ) > }
27 } else {
28 quote! { < fn() -> () > }
29 };
30
31 let variants = match &ast.data {
32 Data::Enum(v) => &v.variants,
33 _ => return Err(non_enum_error()),
34 };
35
36 let mut arms = Vec::new();
37 let mut idx = 0usize;
38 for variant in variants {
39 if variant.get_variant_properties()?.disabled.is_some() {
40 continue;
41 }
42
43 let ident = &variant.ident;
44 let params = match &variant.fields {
45 Fields::Unit => quote! {},
46 Fields::Unnamed(fields) => {
47 let defaults = ::core::iter::repeat(quote!(::core::default::Default::default()))
48 .take(fields.unnamed.len());
49 quote! { (#(#defaults),*) }
50 }
51 Fields::Named(fields) => {
52 let fields = fields
53 .named
54 .iter()
55 .map(|field| field.ident.as_ref().unwrap());
56 quote! { {#(#fields: ::core::default::Default::default()),*} }
57 }
58 };
59
60 arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)});
61 idx += 1;
62 }
63
64 let variant_count = arms.len();
65 arms.push(quote! { _ => ::core::option::Option::None });
66 let iter_name = syn::parse_str::<Ident>(&format!("{}Iter", name)).unwrap();
67
68 let iter_name_debug_struct =
70 syn::parse_str::<syn::LitStr>(&format!("\"{}\"", iter_name)).unwrap();
71
72 Ok(quote! {
73 #[doc = #doc_comment]
74 #[allow(
75 missing_copy_implementations,
76 )]
77 #vis struct #iter_name #impl_generics {
78 idx: usize,
79 back_idx: usize,
80 marker: ::core::marker::PhantomData #phantom_data,
81 }
82
83 #[automatically_derived]
84 impl #impl_generics ::core::fmt::Debug for #iter_name #ty_generics #where_clause {
85 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
86 f.debug_struct(#iter_name_debug_struct)
89 .field("len", &self.len())
90 .finish()
91 }
92 }
93
94 #[automatically_derived]
95 impl #impl_generics #iter_name #ty_generics #where_clause {
96 fn get(&self, idx: usize) -> ::core::option::Option<#name #ty_generics> {
97 match idx {
98 #(#arms),*
99 }
100 }
101 }
102
103 #[automatically_derived]
104 impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause {
105 type Iterator = #iter_name #ty_generics;
106
107 #[inline]
108 fn iter() -> #iter_name #ty_generics {
109 #iter_name {
110 idx: 0,
111 back_idx: 0,
112 marker: ::core::marker::PhantomData,
113 }
114 }
115 }
116
117 #[automatically_derived]
118 impl #impl_generics Iterator for #iter_name #ty_generics #where_clause {
119 type Item = #name #ty_generics;
120
121 #[inline]
122 fn next(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> {
123 self.nth(0)
124 }
125
126 #[inline]
127 fn size_hint(&self) -> (usize, ::core::option::Option<usize>) {
128 let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx };
129 (t, Some(t))
130 }
131
132 #[inline]
133 fn nth(&mut self, n: usize) -> ::core::option::Option<<Self as Iterator>::Item> {
134 let idx = self.idx + n + 1;
135 if idx + self.back_idx > #variant_count {
136 self.idx = #variant_count;
140 ::core::option::Option::None
141 } else {
142 self.idx = idx;
143 #iter_name::get(self, idx - 1)
144 }
145 }
146 }
147
148 #[automatically_derived]
149 impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause {
150 #[inline]
151 fn len(&self) -> usize {
152 self.size_hint().0
153 }
154 }
155
156 #[automatically_derived]
157 impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause {
158 #[inline]
159 fn next_back(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> {
160 let back_idx = self.back_idx + 1;
161
162 if self.idx + back_idx > #variant_count {
163 self.back_idx = #variant_count;
167 ::core::option::Option::None
168 } else {
169 self.back_idx = back_idx;
170 #iter_name::get(self, #variant_count - self.back_idx)
171 }
172 }
173 }
174
175 #[automatically_derived]
176 impl #impl_generics ::core::iter::FusedIterator for #iter_name #ty_generics #where_clause { }
177
178 #[automatically_derived]
179 impl #impl_generics Clone for #iter_name #ty_generics #where_clause {
180 #[inline]
181 fn clone(&self) -> #iter_name #ty_generics {
182 #iter_name {
183 idx: self.idx,
184 back_idx: self.back_idx,
185 marker: self.marker.clone(),
186 }
187 }
188 }
189 })
190}