educe/trait_handlers/default/default_enum.rs
1use std::{fmt::Write, str::FromStr};
2
3use proc_macro2::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{Data, DeriveInput, Fields, Generics, Lit, Meta};
6
7use super::{
8 super::TraitHandler,
9 models::{FieldAttributeBuilder, TypeAttributeBuilder},
10};
11use crate::{panic, Trait};
12
/// Handler that derives `core::default::Default` (and optionally an inherent `new()`)
/// for enum types annotated with `#[educe(Default)]`.
pub struct DefaultEnumHandler;
14
15impl TraitHandler for DefaultEnumHandler {
16 fn trait_meta_handler(
17 ast: &DeriveInput,
18 tokens: &mut TokenStream,
19 traits: &[Trait],
20 meta: &Meta,
21 ) {
22 let type_attribute = TypeAttributeBuilder {
23 enable_flag: true,
24 enable_new: true,
25 enable_expression: true,
26 enable_bound: true,
27 }
28 .from_default_meta(meta);
29
30 let bound = type_attribute
31 .bound
32 .into_punctuated_where_predicates_by_generic_parameters(&ast.generics.params);
33
34 let mut builder_tokens = TokenStream::new();
35
36 if let Data::Enum(data) = &ast.data {
37 match type_attribute.expression {
38 Some(expression) => {
39 for variant in data.variants.iter() {
40 let _ = TypeAttributeBuilder {
41 enable_flag: false,
42 enable_new: false,
43 enable_expression: false,
44 enable_bound: false,
45 }
46 .from_attributes(&variant.attrs, traits);
47
48 ensure_fields_no_attribute(&variant.fields, traits);
49 }
50
51 builder_tokens.extend(quote!(#expression));
52 },
53 None => {
54 let variant = {
55 let variants = &data.variants;
56
57 if variants.len() == 1 {
58 let variant = &variants[0];
59
60 let _ = TypeAttributeBuilder {
61 enable_flag: true,
62 enable_new: false,
63 enable_expression: false,
64 enable_bound: false,
65 }
66 .from_attributes(&variant.attrs, traits);
67
68 variant
69 } else {
70 let mut variants_iter = variants.iter();
71
72 loop {
73 let variant = variants_iter.next();
74
75 match variant {
76 Some(variant) => {
77 let variant_attribute = TypeAttributeBuilder {
78 enable_flag: true,
79 enable_new: false,
80 enable_expression: false,
81 enable_bound: false,
82 }
83 .from_attributes(&variant.attrs, traits);
84
85 if variant_attribute.flag {
86 loop {
87 let variant = variants_iter.next();
88
89 match variant {
90 Some(variant) => {
91 let variant_attribute = TypeAttributeBuilder {
92 enable_flag: true,
93 enable_new: false,
94 enable_expression: false,
95 enable_bound: false,
96 }.from_attributes(&variant.attrs, traits);
97
98 if variant_attribute.flag {
99 panic::multiple_default_variants();
100 } else {
101 ensure_fields_no_attribute(
102 &variant.fields,
103 traits,
104 );
105 }
106 },
107 None => break,
108 }
109 }
110
111 break variant;
112 } else {
113 ensure_fields_no_attribute(&variant.fields, traits);
114 }
115 },
116 None => panic::no_default_variant(),
117 }
118 }
119 }
120 };
121
122 let enum_name = ast.ident.to_string();
123 let variant_name = variant.ident.to_string();
124
125 let mut enum_tokens = format!(
126 "{enum_name}::{variant_name}",
127 enum_name = enum_name,
128 variant_name = variant_name
129 );
130
131 match &variant.fields {
132 Fields::Unit => (), // TODO Unit
133 Fields::Named(fields) => {
134 // TODO Struct
135 enum_tokens.push('{');
136
137 for field in fields.named.iter() {
138 let field_attribute = FieldAttributeBuilder {
139 enable_flag: false,
140 enable_literal: true,
141 enable_expression: true,
142 }
143 .from_attributes(&field.attrs, traits);
144
145 let field_name = field.ident.as_ref().unwrap().to_string();
146
147 enum_tokens
148 .write_fmt(format_args!(
149 "{field_name}: ",
150 field_name = field_name
151 ))
152 .unwrap();
153
154 match field_attribute.literal {
155 Some(value) => match &value {
156 Lit::Str(s) => {
157 enum_tokens
158 .write_fmt(format_args!(
159 "core::convert::Into::into({s})",
160 s = s.into_token_stream()
161 ))
162 .unwrap();
163 },
164 _ => {
165 enum_tokens
166 .push_str(&value.into_token_stream().to_string());
167 },
168 },
169 None => {
170 match field_attribute.expression {
171 Some(expression) => {
172 enum_tokens.push_str(&expression);
173 },
174 None => {
175 let typ = field
176 .ty
177 .clone()
178 .into_token_stream()
179 .to_string();
180
181 enum_tokens.write_fmt(format_args!("<{typ} as core::default::Default>::default()", typ = typ)).unwrap();
182 },
183 }
184 },
185 }
186
187 enum_tokens.push(',');
188 }
189
190 enum_tokens.push('}');
191 },
192 Fields::Unnamed(fields) => {
193 // TODO Tuple
194 enum_tokens.push('(');
195
196 for field in fields.unnamed.iter() {
197 let field_attribute = FieldAttributeBuilder {
198 enable_flag: false,
199 enable_literal: true,
200 enable_expression: true,
201 }
202 .from_attributes(&field.attrs, traits);
203
204 match field_attribute.literal {
205 Some(value) => match &value {
206 Lit::Str(s) => {
207 enum_tokens
208 .write_fmt(format_args!(
209 "core::convert::Into::into({s})",
210 s = s.into_token_stream()
211 ))
212 .unwrap();
213 },
214 _ => {
215 enum_tokens
216 .push_str(&value.into_token_stream().to_string());
217 },
218 },
219 None => {
220 match field_attribute.expression {
221 Some(expression) => {
222 enum_tokens.push_str(&expression);
223 },
224 None => {
225 let typ = field
226 .ty
227 .clone()
228 .into_token_stream()
229 .to_string();
230
231 enum_tokens.write_fmt(format_args!("<{typ} as core::default::Default>::default()", typ = typ)).unwrap();
232 },
233 }
234 },
235 }
236
237 enum_tokens.push(',');
238 }
239
240 enum_tokens.push(')');
241 },
242 }
243
244 builder_tokens.extend(TokenStream::from_str(&enum_tokens).unwrap());
245 },
246 }
247 }
248
249 let ident = &ast.ident;
250
251 let mut generics_cloned: Generics = ast.generics.clone();
252
253 let where_clause = generics_cloned.make_where_clause();
254
255 for where_predicate in bound {
256 where_clause.predicates.push(where_predicate);
257 }
258
259 let (impl_generics, ty_generics, where_clause) = generics_cloned.split_for_impl();
260
261 let default_impl = quote! {
262 impl #impl_generics core::default::Default for #ident #ty_generics #where_clause {
263 #[inline]
264 fn default() -> Self {
265 #builder_tokens
266 }
267 }
268 };
269
270 tokens.extend(default_impl);
271
272 if type_attribute.new {
273 let new_impl = quote! {
274 impl #impl_generics #ident #ty_generics #where_clause {
275 /// Returns the "default value" for a type.
276 #[inline]
277 pub fn new() -> Self {
278 <Self as core::default::Default>::default()
279 }
280 }
281 };
282
283 tokens.extend(new_impl);
284 }
285 }
286}
287
288fn ensure_fields_no_attribute(fields: &Fields, traits: &[Trait]) {
289 match fields {
290 Fields::Unit => (),
291 Fields::Named(fields) => {
292 for field in fields.named.iter() {
293 let _ = FieldAttributeBuilder {
294 enable_flag: false,
295 enable_literal: false,
296 enable_expression: false,
297 }
298 .from_attributes(&field.attrs, traits);
299 }
300 },
301 Fields::Unnamed(fields) => {
302 for field in fields.unnamed.iter() {
303 let _ = FieldAttributeBuilder {
304 enable_flag: false,
305 enable_literal: false,
306 enable_expression: false,
307 }
308 .from_attributes(&field.attrs, traits);
309 }
310 },
311 }
312}