// borsh_derive/internals/serialize/enums/mod.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::quote;
3use syn::{Fields, Ident, ItemEnum, Path, Variant};
4
5use crate::internals::{
6    attributes::{field, item, BoundType},
7    enum_discriminant::Discriminants,
8    generics, serialize,
9};
10
11pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
12    let enum_ident = &input.ident;
13    let generics = generics::without_defaults(&input.generics);
14    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
15    let mut where_clause = generics::default_where(where_clause);
16    let mut generics_output = serialize::GenericsOutput::new(&generics);
17    let mut all_variants_idx_body = TokenStream2::new();
18    let mut fields_body = TokenStream2::new();
19    let use_discriminant = item::contains_use_discriminant(input)?;
20    let discriminants = Discriminants::new(&input.variants);
21    let mut has_unit_variant = false;
22
23    for (variant_idx, variant) in input.variants.iter().enumerate() {
24        let variant_ident = &variant.ident;
25        let discriminant_value = discriminants.get(variant_ident, use_discriminant, variant_idx)?;
26        let variant_output = process_variant(
27            variant,
28            enum_ident,
29            &discriminant_value,
30            &cratename,
31            &mut generics_output,
32        )?;
33        all_variants_idx_body.extend(variant_output.variant_idx_body);
34        match variant_output.body {
35            VariantBody::Unit => has_unit_variant = true,
36            VariantBody::Fields(VariantFields { header, body }) => fields_body.extend(quote!(
37                #enum_ident::#variant_ident #header => {
38                    #body
39                }
40            )),
41        }
42    }
43    let fields_body = optimize_fields_body(fields_body, has_unit_variant);
44    generics_output.extend(&mut where_clause, &cratename);
45
46    Ok(quote! {
47        impl #impl_generics #cratename::ser::BorshSerialize for #enum_ident #ty_generics #where_clause {
48            fn serialize<__W: #cratename::io::Write>(&self, writer: &mut __W) -> ::core::result::Result<(), #cratename::io::Error> {
49                let variant_idx: u8 = match self {
50                    #all_variants_idx_body
51                };
52                writer.write_all(&variant_idx.to_le_bytes())?;
53
54                #fields_body
55                Ok(())
56            }
57        }
58    })
59}
60
61fn optimize_fields_body(fields_body: TokenStream2, has_unit_variant: bool) -> TokenStream2 {
62    if fields_body.is_empty() {
63        // If we no variants with fields, there's nothing to match against. Just
64        // re-use the empty token stream.
65        fields_body
66    } else {
67        let unit_fields_catchall = if has_unit_variant {
68            // We had some variants with unit fields, create a catch-all for
69            // these to be used at the bottom.
70            quote!(
71                _ => {}
72            )
73        } else {
74            TokenStream2::new()
75        };
76        // Create a match that serialises all the fields for each non-unit
77        // variant and add a catch-all at the bottom if we do have unit
78        // variants.
79        quote!(
80            match self {
81                #fields_body
82                #unit_fields_catchall
83            }
84        )
85    }
86}
87
88#[derive(Default)]
89struct VariantFields {
90    header: TokenStream2,
91    body: TokenStream2,
92}
93
94impl VariantFields {
95    fn named_header(self) -> Self {
96        let header = self.header;
97
98        VariantFields {
99            // `..` pattern matching works even if all fields were specified
100            header: quote! { { #header.. }},
101            body: self.body,
102        }
103    }
104    fn unnamed_header(self) -> Self {
105        let header = self.header;
106
107        VariantFields {
108            header: quote! { ( #header )},
109            body: self.body,
110        }
111    }
112}
113
114enum VariantBody {
115    // No body variant, unit enum variant.
116    Unit,
117    // Variant with body (fields)
118    Fields(VariantFields),
119}
120
121struct VariantOutput {
122    body: VariantBody,
123    variant_idx_body: TokenStream2,
124}
125
126fn process_variant(
127    variant: &Variant,
128    enum_ident: &Ident,
129    discriminant_value: &TokenStream2,
130    cratename: &Path,
131    generics: &mut serialize::GenericsOutput,
132) -> syn::Result<VariantOutput> {
133    let variant_ident = &variant.ident;
134    let variant_output = match &variant.fields {
135        Fields::Named(fields) => {
136            let mut variant_fields = VariantFields::default();
137            for field in &fields.named {
138                let field_id = serialize::FieldId::Enum(field.ident.clone().unwrap());
139                process_field(field, field_id, cratename, generics, &mut variant_fields)?;
140            }
141            VariantOutput {
142                body: VariantBody::Fields(variant_fields.named_header()),
143                variant_idx_body: quote!(
144                    #enum_ident::#variant_ident {..} => #discriminant_value,
145                ),
146            }
147        }
148        Fields::Unnamed(fields) => {
149            let mut variant_fields = VariantFields::default();
150            for (field_idx, field) in fields.unnamed.iter().enumerate() {
151                let field_id = serialize::FieldId::new_enum_unnamed(field_idx)?;
152                process_field(field, field_id, cratename, generics, &mut variant_fields)?;
153            }
154            VariantOutput {
155                body: VariantBody::Fields(variant_fields.unnamed_header()),
156                variant_idx_body: quote!(
157                    #enum_ident::#variant_ident(..) => #discriminant_value,
158                ),
159            }
160        }
161        Fields::Unit => VariantOutput {
162            body: VariantBody::Unit,
163            variant_idx_body: quote!(
164                #enum_ident::#variant_ident => #discriminant_value,
165            ),
166        },
167    };
168    Ok(variant_output)
169}
170
171fn process_field(
172    field: &syn::Field,
173    field_id: serialize::FieldId,
174    cratename: &Path,
175    generics: &mut serialize::GenericsOutput,
176    output: &mut VariantFields,
177) -> syn::Result<()> {
178    let parsed = field::Attributes::parse(&field.attrs)?;
179
180    let needs_bounds_derive = parsed.needs_bounds_derive(BoundType::Serialize);
181    generics
182        .overrides
183        .extend(parsed.collect_bounds(BoundType::Serialize));
184
185    let field_variant_header = field_id.enum_variant_header(parsed.skip);
186    if let Some(field_variant_header) = field_variant_header {
187        output.header.extend(field_variant_header);
188    }
189
190    if !parsed.skip {
191        let delta = field_id.serialize_output(cratename, parsed.serialize_with);
192        output.body.extend(delta);
193        if needs_bounds_derive {
194            generics.serialize_visitor.visit_field(field);
195        }
196    }
197    Ok(())
198}
199
#[cfg(test)]
mod tests {
    // Snapshot tests for the enum `BorshSerialize` expansion. Each test feeds an
    // enum definition through `process` and compares the pretty-printed output
    // against a stored snapshot.
    // NOTE(review): snapshot files are presumably keyed by test function name,
    // so renaming a test likely also requires renaming its snapshot file.
    use crate::internals::test_helpers::{
        default_cratename, local_insta_assert_snapshot, pretty_print_syn_str,
    };

    use super::*;

    #[test]
    fn borsh_skip_tuple_variant_field() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum AATTB {
                B(#[borsh(skip)] i32, #[borsh(skip)] u32),

                NegatedVariant {
                    beta: u8,
                }
            }
        })
        .unwrap();
        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn struct_variant_field() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum AB {
                B {
                    c: i32,
                    d: u32,
                },

                NegatedVariant {
                    beta: String,
                }
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn simple_enum_with_custom_crate() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum AB {
                B {
                    c: i32,
                    d: u32,
                },

                NegatedVariant {
                    beta: String,
                }
            }
        })
        .unwrap();

        let custom_crate: Path = syn::parse2(quote! { reexporter::borsh }).unwrap();
        let actual = process(&item_enum, custom_crate).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn borsh_skip_struct_variant_field() {
        let item_enum: ItemEnum = syn::parse2(quote! {

            enum AB {
                B {
                    #[borsh(skip)]
                    c: i32,

                    d: u32,
                },

                NegatedVariant {
                    beta: String,
                }
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn borsh_skip_struct_variant_all_fields() {
        let item_enum: ItemEnum = syn::parse2(quote! {

            enum AAB {
                B {
                    #[borsh(skip)]
                    c: i32,

                    #[borsh(skip)]
                    d: u32,
                },

                NegatedVariant {
                    beta: String,
                }
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn simple_generics() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum A<K, V, U> {
                B {
                    x: HashMap<K, V>,
                    y: String,
                },
                C(K, Vec<U>),
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn bound_generics() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum A<K: Key, V, U> where V: Value {
                B {
                    x: HashMap<K, V>,
                    y: String,
                },
                C(K, Vec<U>),
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn recursive_enum() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum A<K: Key, V> where V: Value {
                B {
                    x: HashMap<K, V>,
                    y: String,
                },
                C(K, Vec<A>),
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_borsh_skip_struct_field() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum A<K: Key, V, U> where V: Value {
                B {
                    #[borsh(skip)]
                    x: HashMap<K, V>,
                    y: String,
                },
                C(K, Vec<U>),
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_borsh_skip_tuple_field() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum A<K: Key, V, U> where V: Value {
                B {
                    x: HashMap<K, V>,
                    y: String,
                },
                C(K, #[borsh(skip)] Vec<U>),
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_serialize_bound() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum A<T: Debug, U> {
                C {
                    a: String,
                    #[borsh(bound(serialize =
                        "T: borsh::ser::BorshSerialize + PartialOrd,
                         U: borsh::ser::BorshSerialize"
                    ))]
                    b: HashMap<T, U>,
                },
                D(u32, u32),
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn check_serialize_with_attr() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum C<K: Ord, V> {
                C3(u64, u64),
                C4 {
                    x: u64,
                    #[borsh(serialize_with = "third_party_impl::serialize_third_party")]
                    y: ThirdParty<K, V>
                },
            }
        })
        .unwrap();

        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn borsh_discriminant_false() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            #[borsh(use_discriminant = false)]
            enum X {
                A,
                B = 20,
                C,
                D,
                E = 10,
                F,
            }
        })
        .unwrap();
        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn borsh_discriminant_true() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            #[borsh(use_discriminant = true)]
            enum X {
                A,
                B = 20,
                C,
                D,
                E = 10,
                F,
            }
        })
        .unwrap();
        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn mixed_with_unit_variants() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum X {
                A(u16),
                B,
                C {x: i32, y: i32},
                D,
            }
        })
        .unwrap();
        let actual = process(&item_enum, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }
}