1use proc_macro2::TokenStream as TokenStream2;
2use quote::quote;
3use syn::{Fields, Ident, ItemEnum, Path, Variant};
4
5use crate::internals::{
6 attributes::{field, item, BoundType},
7 enum_discriminant::Discriminants,
8 generics, serialize,
9};
10
11pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
12 let enum_ident = &input.ident;
13 let generics = generics::without_defaults(&input.generics);
14 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
15 let mut where_clause = generics::default_where(where_clause);
16 let mut generics_output = serialize::GenericsOutput::new(&generics);
17 let mut all_variants_idx_body = TokenStream2::new();
18 let mut fields_body = TokenStream2::new();
19 let use_discriminant = item::contains_use_discriminant(input)?;
20 let discriminants = Discriminants::new(&input.variants);
21 let mut has_unit_variant = false;
22
23 for (variant_idx, variant) in input.variants.iter().enumerate() {
24 let variant_ident = &variant.ident;
25 let discriminant_value = discriminants.get(variant_ident, use_discriminant, variant_idx)?;
26 let variant_output = process_variant(
27 variant,
28 enum_ident,
29 &discriminant_value,
30 &cratename,
31 &mut generics_output,
32 )?;
33 all_variants_idx_body.extend(variant_output.variant_idx_body);
34 match variant_output.body {
35 VariantBody::Unit => has_unit_variant = true,
36 VariantBody::Fields(VariantFields { header, body }) => fields_body.extend(quote!(
37 #enum_ident::#variant_ident #header => {
38 #body
39 }
40 )),
41 }
42 }
43 let fields_body = optimize_fields_body(fields_body, has_unit_variant);
44 generics_output.extend(&mut where_clause, &cratename);
45
46 Ok(quote! {
47 impl #impl_generics #cratename::ser::BorshSerialize for #enum_ident #ty_generics #where_clause {
48 fn serialize<__W: #cratename::io::Write>(&self, writer: &mut __W) -> ::core::result::Result<(), #cratename::io::Error> {
49 let variant_idx: u8 = match self {
50 #all_variants_idx_body
51 };
52 writer.write_all(&variant_idx.to_le_bytes())?;
53
54 #fields_body
55 Ok(())
56 }
57 }
58 })
59}
60
61fn optimize_fields_body(fields_body: TokenStream2, has_unit_variant: bool) -> TokenStream2 {
62 if fields_body.is_empty() {
63 fields_body
66 } else {
67 let unit_fields_catchall = if has_unit_variant {
68 quote!(
71 _ => {}
72 )
73 } else {
74 TokenStream2::new()
75 };
76 quote!(
80 match self {
81 #fields_body
82 #unit_fields_catchall
83 }
84 )
85 }
86}
87
88#[derive(Default)]
89struct VariantFields {
90 header: TokenStream2,
91 body: TokenStream2,
92}
93
94impl VariantFields {
95 fn named_header(self) -> Self {
96 let header = self.header;
97
98 VariantFields {
99 header: quote! { { #header.. }},
101 body: self.body,
102 }
103 }
104 fn unnamed_header(self) -> Self {
105 let header = self.header;
106
107 VariantFields {
108 header: quote! { ( #header )},
109 body: self.body,
110 }
111 }
112}
113
114enum VariantBody {
115 Unit,
117 Fields(VariantFields),
119}
120
121struct VariantOutput {
122 body: VariantBody,
123 variant_idx_body: TokenStream2,
124}
125
126fn process_variant(
127 variant: &Variant,
128 enum_ident: &Ident,
129 discriminant_value: &TokenStream2,
130 cratename: &Path,
131 generics: &mut serialize::GenericsOutput,
132) -> syn::Result<VariantOutput> {
133 let variant_ident = &variant.ident;
134 let variant_output = match &variant.fields {
135 Fields::Named(fields) => {
136 let mut variant_fields = VariantFields::default();
137 for field in &fields.named {
138 let field_id = serialize::FieldId::Enum(field.ident.clone().unwrap());
139 process_field(field, field_id, cratename, generics, &mut variant_fields)?;
140 }
141 VariantOutput {
142 body: VariantBody::Fields(variant_fields.named_header()),
143 variant_idx_body: quote!(
144 #enum_ident::#variant_ident {..} => #discriminant_value,
145 ),
146 }
147 }
148 Fields::Unnamed(fields) => {
149 let mut variant_fields = VariantFields::default();
150 for (field_idx, field) in fields.unnamed.iter().enumerate() {
151 let field_id = serialize::FieldId::new_enum_unnamed(field_idx)?;
152 process_field(field, field_id, cratename, generics, &mut variant_fields)?;
153 }
154 VariantOutput {
155 body: VariantBody::Fields(variant_fields.unnamed_header()),
156 variant_idx_body: quote!(
157 #enum_ident::#variant_ident(..) => #discriminant_value,
158 ),
159 }
160 }
161 Fields::Unit => VariantOutput {
162 body: VariantBody::Unit,
163 variant_idx_body: quote!(
164 #enum_ident::#variant_ident => #discriminant_value,
165 ),
166 },
167 };
168 Ok(variant_output)
169}
170
171fn process_field(
172 field: &syn::Field,
173 field_id: serialize::FieldId,
174 cratename: &Path,
175 generics: &mut serialize::GenericsOutput,
176 output: &mut VariantFields,
177) -> syn::Result<()> {
178 let parsed = field::Attributes::parse(&field.attrs)?;
179
180 let needs_bounds_derive = parsed.needs_bounds_derive(BoundType::Serialize);
181 generics
182 .overrides
183 .extend(parsed.collect_bounds(BoundType::Serialize));
184
185 let field_variant_header = field_id.enum_variant_header(parsed.skip);
186 if let Some(field_variant_header) = field_variant_header {
187 output.header.extend(field_variant_header);
188 }
189
190 if !parsed.skip {
191 let delta = field_id.serialize_output(cratename, parsed.serialize_with);
192 output.body.extend(delta);
193 if needs_bounds_derive {
194 generics.serialize_visitor.visit_field(field);
195 }
196 }
197 Ok(())
198}
199
#[cfg(test)]
mod tests {
    // Snapshot tests: each test parses an enum definition, runs it through
    // `process`, and compares the pretty-printed output with a stored snapshot
    // via `local_insta_assert_snapshot!`. NOTE(review): snapshots are presumably
    // keyed by test name (insta's default). Confirm in `test_helpers` before
    // renaming any test.
    use crate::internals::test_helpers::{
        default_cratename, local_insta_assert_snapshot, pretty_print_syn_str,
    };

    use super::*;
    // Tuple variant whose fields are all `#[borsh(skip)]`, next to a struct variant.
    #[test]
    fn borsh_skip_tuple_variant_field() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum AATTB {
                B(#[borsh(skip)] i32, #[borsh(skip)] u32),

                NegatedVariant {
                    beta: u8,
                }
            }
        })
        .unwrap();
        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Two struct-like variants with no attributes.
    #[test]
    fn struct_variant_field() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum AB {
                B {
                    c: i32,
                    d: u32,
                },

                NegatedVariant {
                    beta: String,
                }
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Same enum as `struct_variant_field`, but with a custom crate path.
    #[test]
    fn simple_enum_with_custom_crate() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum AB {
                B {
                    c: i32,
                    d: u32,
                },

                NegatedVariant {
                    beta: String,
                }
            }
        })
        .unwrap();

        let crate_: Path = syn::parse2(quote! { reexporter::borsh }).unwrap();
        let actual = process(&item_enum, crate_).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Struct variant with one skipped field.
    #[test]
    fn borsh_skip_struct_variant_field() {
        let item_enum: ItemEnum = syn::parse2(quote! {

            enum AB {
                B {
                    #[borsh(skip)]
                    c: i32,

                    d: u32,
                },

                NegatedVariant {
                    beta: String,
                }
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Struct variant whose fields are all skipped.
    #[test]
    fn borsh_skip_struct_variant_all_fields() {
        let item_enum: ItemEnum = syn::parse2(quote! {

            enum AAB {
                B {
                    #[borsh(skip)]
                    c: i32,

                    #[borsh(skip)]
                    d: u32,
                },

                NegatedVariant {
                    beta: String,
                }
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Generic enum without explicit bounds.
    #[test]
    fn simple_generics() {
        let item_struct: ItemEnum = syn::parse2(quote! {
            enum A<K, V, U> {
                B {
                    x: HashMap<K, V>,
                    y: String,
                },
                C(K, Vec<U>),
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Generic enum with inline and `where`-clause bounds.
    #[test]
    fn bound_generics() {
        let item_struct: ItemEnum = syn::parse2(quote! {
            enum A<K: Key, V, U> where V: Value {
                B {
                    x: HashMap<K, V>,
                    y: String,
                },
                C(K, Vec<U>),
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Enum that refers to itself in a variant field.
    #[test]
    fn recursive_enum() {
        let item_struct: ItemEnum = syn::parse2(quote! {
            enum A<K: Key, V> where V: Value {
                B {
                    x: HashMap<K, V>,
                    y: String,
                },
                C(K, Vec<A>),
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Skipping a generic-typed named field.
    #[test]
    fn generic_borsh_skip_struct_field() {
        let item_struct: ItemEnum = syn::parse2(quote! {
            enum A<K: Key, V, U> where V: Value {
                B {
                    #[borsh(skip)]
                    x: HashMap<K, V>,
                    y: String,
                },
                C(K, Vec<U>),
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Skipping a generic-typed tuple field.
    #[test]
    fn generic_borsh_skip_tuple_field() {
        let item_struct: ItemEnum = syn::parse2(quote! {
            enum A<K: Key, V, U> where V: Value {
                B {
                    x: HashMap<K, V>,
                    y: String,
                },
                C(K, #[borsh(skip)] Vec<U>),
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Explicit `#[borsh(bound(serialize = ...))]` override on a field.
    #[test]
    fn generic_serialize_bound() {
        let item_struct: ItemEnum = syn::parse2(quote! {
            enum A<T: Debug, U> {
                C {
                    a: String,
                    #[borsh(bound(serialize =
                        "T: borsh::ser::BorshSerialize + PartialOrd,
                        U: borsh::ser::BorshSerialize"
                    ))]
                    b: HashMap<T, U>,
                },
                D(u32, u32),
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Field serialized through a custom `serialize_with` function.
    #[test]
    fn check_serialize_with_attr() {
        let item_struct: ItemEnum = syn::parse2(quote! {
            enum C<K: Ord, V> {
                C3(u64, u64),
                C4 {
                    x: u64,
                    #[borsh(serialize_with = "third_party_impl::serialize_third_party")]
                    y: ThirdParty<K, V>
                },
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Unit-only enum with explicit discriminants and `use_discriminant = false`.
    #[test]
    fn borsh_discriminant_false() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            #[borsh(use_discriminant = false)]
            enum X {
                A,
                B = 20,
                C,
                D,
                E = 10,
                F,
            }
        })
        .unwrap();
        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }
    // Same enum as above, with `use_discriminant = true`.
    #[test]
    fn borsh_discriminant_true() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            #[borsh(use_discriminant = true)]
            enum X {
                A,
                B = 20,
                C,
                D,
                E = 10,
                F,
            }
        })
        .unwrap();
        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    // Mix of unit and field variants, which exercises the `_ => {}` catch-all.
    #[test]
    fn mixed_with_unit_variants() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum X {
                A(u16),
                B,
                C {x: i32, y: i32},
                D,
            }
        })
        .unwrap();
        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }
}