borsh_derive/internals/deserialize/enums/
mod.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::quote;
3use syn::{Fields, ItemEnum, Path, Variant};
4
5use crate::internals::{attributes::item, deserialize, enum_discriminant::Discriminants, generics};
6
7pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
8    let name = &input.ident;
9    let generics = generics::without_defaults(&input.generics);
10    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
11    let mut where_clause = generics::default_where(where_clause);
12    let mut variant_arms = TokenStream2::new();
13    let use_discriminant = item::contains_use_discriminant(input)?;
14    let discriminants = Discriminants::new(&input.variants);
15    let mut generics_output = deserialize::GenericsOutput::new(&generics);
16
17    for (variant_idx, variant) in input.variants.iter().enumerate() {
18        let variant_body = process_variant(variant, &cratename, &mut generics_output)?;
19        let variant_ident = &variant.ident;
20
21        let discriminant_value = discriminants.get(variant_ident, use_discriminant, variant_idx)?;
22        variant_arms.extend(quote! {
23            if variant_tag == #discriminant_value { #name::#variant_ident #variant_body } else
24        });
25    }
26    let init = if let Some(method_ident) = item::contains_initialize_with(&input.attrs)? {
27        quote! {
28            return_value.#method_ident();
29        }
30    } else {
31        quote! {}
32    };
33    generics_output.extend(&mut where_clause, &cratename);
34
35    Ok(quote! {
36        impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause {
37            fn deserialize_reader<__R: #cratename::io::Read>(reader: &mut __R) -> ::core::result::Result<Self, #cratename::io::Error> {
38                let tag = <u8 as #cratename::de::BorshDeserialize>::deserialize_reader(reader)?;
39                <Self as #cratename::de::EnumExt>::deserialize_variant(reader, tag)
40            }
41        }
42
43        impl #impl_generics #cratename::de::EnumExt for #name #ty_generics #where_clause {
44            fn deserialize_variant<__R: #cratename::io::Read>(
45                reader: &mut __R,
46                variant_tag: u8,
47            ) -> ::core::result::Result<Self, #cratename::io::Error> {
48                let mut return_value =
49                    #variant_arms {
50                    return Err(#cratename::io::Error::new(
51                        #cratename::io::ErrorKind::InvalidData,
52                        #cratename::__private::maybestd::format!("Unexpected variant tag: {:?}", variant_tag),
53                    ))
54                };
55                #init
56                Ok(return_value)
57            }
58        }
59    })
60}
61
62fn process_variant(
63    variant: &Variant,
64    cratename: &Path,
65    generics: &mut deserialize::GenericsOutput,
66) -> syn::Result<TokenStream2> {
67    let mut body = TokenStream2::new();
68    match &variant.fields {
69        Fields::Named(fields) => {
70            for field in &fields.named {
71                deserialize::process_field(field, cratename, &mut body, generics)?;
72            }
73            body = quote! { { #body }};
74        }
75        Fields::Unnamed(fields) => {
76            for field in fields.unnamed.iter() {
77                deserialize::process_field(field, cratename, &mut body, generics)?;
78            }
79            body = quote! { ( #body )};
80        }
81        Fields::Unit => {}
82    }
83    Ok(body)
84}
85
#[cfg(test)]
mod tests {
    use crate::internals::test_helpers::{
        default_cratename, local_insta_assert_snapshot, pretty_print_syn_str,
    };

    use super::*;

    /// Parses `tokens` as an enum item and runs [`process`] on it with
    /// `cratename`, panicking if either step fails.
    fn expand(tokens: TokenStream2, cratename: Path) -> TokenStream2 {
        let item_enum: ItemEnum = syn::parse2(tokens).unwrap();
        process(&item_enum, cratename).unwrap()
    }

    /// Struct-like variant with a `#[borsh(skip)]` field.
    #[test]
    fn borsh_skip_struct_variant_field() {
        let actual = expand(
            quote! {
                enum AA {
                    B {
                        #[borsh(skip)]
                        c: i32,
                        d: u32,
                    },
                    NegatedVariant {
                        beta: u8,
                    }
                }
            },
            default_cratename(),
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Tuple variant with a `#[borsh(skip)]` field.
    #[test]
    fn borsh_skip_tuple_variant_field() {
        let actual = expand(
            quote! {
                enum AAT {
                    B(#[borsh(skip)] i32, u32),

                    NegatedVariant {
                        beta: u8,
                    }
                }
            },
            default_cratename(),
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Expansion against a crate path other than the default one.
    #[test]
    fn simple_enum_with_custom_crate() {
        let crate_: Path = syn::parse2(quote! { reexporter::borsh }).unwrap();
        let actual = expand(
            quote! {
                enum A {
                    B {
                        x: HashMap<u32, String>,
                        y: String,
                    },
                    C(K, Vec<u64>),
                }
            },
            crate_,
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Enum with unbounded type parameters.
    #[test]
    fn simple_generics() {
        let actual = expand(
            quote! {
                enum A<K, V, U> {
                    B {
                        x: HashMap<K, V>,
                        y: String,
                    },
                    C(K, Vec<U>),
                }
            },
            default_cratename(),
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Enum with an inline bound and a `where` clause.
    #[test]
    fn bound_generics() {
        let actual = expand(
            quote! {
                enum A<K: Key, V, U> where V: Value {
                    B {
                        x: HashMap<K, V>,
                        y: String,
                    },
                    C(K, Vec<U>),
                }
            },
            default_cratename(),
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Enum whose tuple variant refers back to the enum itself.
    #[test]
    fn recursive_enum() {
        let actual = expand(
            quote! {
                enum A<K: Key, V> where V: Value {
                    B {
                        x: HashMap<K, V>,
                        y: String,
                    },
                    C(K, Vec<A>),
                }
            },
            default_cratename(),
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Generic enum with a skipped named field that mentions type parameters.
    #[test]
    fn generic_borsh_skip_struct_field() {
        let actual = expand(
            quote! {
                enum A<K: Key, V, U> where V: Value {
                    B {
                        #[borsh(skip)]
                        x: HashMap<K, V>,
                        y: String,
                    },
                    C(K, Vec<U>),
                }
            },
            default_cratename(),
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Generic enum with a skipped tuple field that mentions a type parameter.
    #[test]
    fn generic_borsh_skip_tuple_field() {
        let actual = expand(
            quote! {
                enum A<K: Key, V, U> where V: Value {
                    B {
                        x: HashMap<K, V>,
                        y: String,
                    },
                    C(K, #[borsh(skip)] Vec<U>),
                }
            },
            default_cratename(),
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Field carrying an explicit `#[borsh(bound(deserialize = ...))]` attribute.
    #[test]
    fn generic_deserialize_bound() {
        let actual = expand(
            quote! {
                enum A<T: Debug, U> {
                    C {
                        a: String,
                        #[borsh(bound(deserialize =
                        "T: PartialOrd + Hash + Eq + borsh::de::BorshDeserialize,
                         U: borsh::de::BorshDeserialize"
                        ))]
                        b: HashMap<T, U>,
                    },
                    D(u32, u32),
                }
            },
            default_cratename(),
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Field carrying a `#[borsh(deserialize_with = ...)]` attribute.
    #[test]
    fn check_deserialize_with_attr() {
        let actual = expand(
            quote! {
                enum C<K: Ord, V> {
                    C3(u64, u64),
                    C4 {
                        x: u64,
                        #[borsh(deserialize_with = "third_party_impl::deserialize_third_party")]
                        y: ThirdParty<K, V>
                    },
                }
            },
            default_cratename(),
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Explicit discriminants with `use_discriminant = false`.
    #[test]
    fn borsh_discriminant_false() {
        let actual = expand(
            quote! {
                #[borsh(use_discriminant = false)]
                enum X {
                    A,
                    B = 20,
                    C,
                    D,
                    E = 10,
                    F,
                }
            },
            default_cratename(),
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Explicit discriminants with `use_discriminant = true`.
    #[test]
    fn borsh_discriminant_true() {
        let actual = expand(
            quote! {
                #[borsh(use_discriminant = true)]
                enum X {
                    A,
                    B = 20,
                    C,
                    D,
                    E = 10,
                    F,
                }
            },
            default_cratename(),
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Enum annotated with `#[borsh(init = ...)]`.
    #[test]
    fn borsh_init_func() {
        let actual = expand(
            quote! {
                #[borsh(init = initialization_method)]
                enum A {
                    A,
                    B,
                    C,
                    D,
                    E,
                    F,
                }
            },
            default_cratename(),
        );

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }
}