borsh_derive/internals/enum_discriminant.rs

1use std::collections::HashMap;
2use std::convert::TryFrom;
3
4use proc_macro2::{Ident, TokenStream};
5use quote::quote;
6use syn::{punctuated::Punctuated, token::Comma, Variant};
7
8pub struct Discriminants(HashMap<Ident, TokenStream>);
9impl Discriminants {
10    /// Calculates the discriminant that will be assigned by the compiler.
11    /// See: https://doc.rust-lang.org/reference/items/enumerations.html#assigning-discriminant-values
12    pub fn new(variants: &Punctuated<Variant, Comma>) -> Self {
13        let mut map = HashMap::new();
14        let mut next_discriminant_if_not_specified = quote! {0};
15
16        for variant in variants {
17            let this_discriminant = variant.discriminant.clone().map_or_else(
18                || quote! { #next_discriminant_if_not_specified },
19                |(_, e)| quote! { #e },
20            );
21
22            next_discriminant_if_not_specified = quote! { #this_discriminant + 1 };
23            map.insert(variant.ident.clone(), this_discriminant);
24        }
25
26        Self(map)
27    }
28
29    pub fn get(
30        &self,
31        variant_ident: &Ident,
32        use_discriminant: bool,
33        variant_idx: usize,
34    ) -> syn::Result<TokenStream> {
35        let variant_idx = u8::try_from(variant_idx).map_err(|err| {
36            syn::Error::new(
37                variant_ident.span(),
38                format!("up to 256 enum variants are supported: {}", err),
39            )
40        })?;
41        let result = if use_discriminant {
42            let discriminant_value = self.0.get(variant_ident).unwrap();
43            quote! { #discriminant_value }
44        } else {
45            quote! { #variant_idx }
46        };
47        Ok(result)
48    }
49}