borsh_derive/internals/
enum_discriminant.rs
1use std::collections::HashMap;
2use std::convert::TryFrom;
3
4use proc_macro2::{Ident, TokenStream};
5use quote::quote;
6use syn::{punctuated::Punctuated, token::Comma, Variant};
7
8pub struct Discriminants(HashMap<Ident, TokenStream>);
9impl Discriminants {
10 pub fn new(variants: &Punctuated<Variant, Comma>) -> Self {
13 let mut map = HashMap::new();
14 let mut next_discriminant_if_not_specified = quote! {0};
15
16 for variant in variants {
17 let this_discriminant = variant.discriminant.clone().map_or_else(
18 || quote! { #next_discriminant_if_not_specified },
19 |(_, e)| quote! { #e },
20 );
21
22 next_discriminant_if_not_specified = quote! { #this_discriminant + 1 };
23 map.insert(variant.ident.clone(), this_discriminant);
24 }
25
26 Self(map)
27 }
28
29 pub fn get(
30 &self,
31 variant_ident: &Ident,
32 use_discriminant: bool,
33 variant_idx: usize,
34 ) -> syn::Result<TokenStream> {
35 let variant_idx = u8::try_from(variant_idx).map_err(|err| {
36 syn::Error::new(
37 variant_ident.span(),
38 format!("up to 256 enum variants are supported: {}", err),
39 )
40 })?;
41 let result = if use_discriminant {
42 let discriminant_value = self.0.get(variant_ident).unwrap();
43 quote! { #discriminant_value }
44 } else {
45 quote! { #variant_idx }
46 };
47 Ok(result)
48 }
49}