borsh_derive/internals/attributes/item/
mod.rs

1use crate::internals::attributes::{BORSH, CRATE, INIT, USE_DISCRIMINANT};
2use quote::ToTokens;
3use syn::{spanned::Spanned, Attribute, DeriveInput, Error, Expr, ItemEnum, Path};
4
5use super::{get_one_attribute, parsing};
6
7pub fn check_attributes(derive_input: &DeriveInput) -> Result<(), Error> {
8    let borsh = get_one_attribute(&derive_input.attrs)?;
9
10    if let Some(attr) = borsh {
11        attr.parse_nested_meta(|meta| {
12            if meta.path != USE_DISCRIMINANT && meta.path != INIT && meta.path != CRATE {
13                return Err(syn::Error::new(
14                    meta.path.span(),
15                    "`crate`, `use_discriminant` or `init` are the only supported attributes for `borsh`",
16                ));
17            }
18            if meta.path == USE_DISCRIMINANT {
19                let _expr: Expr = meta.value()?.parse()?;
20                if let syn::Data::Struct(ref _data) = derive_input.data {
21                    return Err(syn::Error::new(
22                        derive_input.ident.span(),
23                        "borsh(use_discriminant=<bool>) does not support structs",
24                    ));
25                }
26            } else if meta.path == INIT || meta.path == CRATE {
27                let _expr: Expr = meta.value()?.parse()?;
28            }
29
30            Ok(())
31        })?;
32    }
33    Ok(())
34}
35
36pub(crate) fn contains_use_discriminant(input: &ItemEnum) -> Result<bool, syn::Error> {
37    if input.variants.len() > 256 {
38        return Err(syn::Error::new(
39            input.span(),
40            "up to 256 enum variants are supported",
41        ));
42    }
43
44    let attrs = &input.attrs;
45    let mut use_discriminant = None;
46    let attr = attrs.iter().find(|attr| attr.path() == BORSH);
47    if let Some(attr) = attr {
48        attr.parse_nested_meta(|meta| {
49            if meta.path == USE_DISCRIMINANT {
50                let value_expr: Expr = meta.value()?.parse()?;
51                let value = value_expr.to_token_stream().to_string();
52                match value.as_str() {
53                    "true" => {
54                        use_discriminant = Some(true);
55                    }
56                    "false" => use_discriminant = Some(false),
57                    _ => {
58                        return Err(syn::Error::new(
59                            value_expr.span(),
60                            "`use_discriminant` accepts only `true` or `false`",
61                        ));
62                    }
63                };
64            } else if meta.path == INIT || meta.path == CRATE {
65                let _value_expr: Expr = meta.value()?.parse()?;
66            }
67            Ok(())
68        })?;
69    }
70    let has_explicit_discriminants = input
71        .variants
72        .iter()
73        .any(|variant| variant.discriminant.is_some());
74    if has_explicit_discriminants && use_discriminant.is_none() {
75        return Err(syn::Error::new(
76                input.ident.span(),
77                "You have to specify `#[borsh(use_discriminant=true)]` or `#[borsh(use_discriminant=false)]` for all enums with explicit discriminant",
78            ));
79    }
80    Ok(use_discriminant.unwrap_or(false))
81}
82
83pub(crate) fn contains_initialize_with(attrs: &[Attribute]) -> Result<Option<Path>, Error> {
84    let mut res = None;
85    let attr = attrs.iter().find(|attr| attr.path() == BORSH);
86    if let Some(attr) = attr {
87        attr.parse_nested_meta(|meta| {
88            if meta.path == INIT {
89                let value_expr: Path = meta.value()?.parse()?;
90                res = Some(value_expr);
91            } else if meta.path == USE_DISCRIMINANT || meta.path == CRATE {
92                let _value_expr: Expr = meta.value()?.parse()?;
93            }
94
95            Ok(())
96        })?;
97    }
98
99    Ok(res)
100}
101
102pub(crate) fn get_crate(attrs: &[Attribute]) -> Result<Option<Path>, Error> {
103    let mut res = None;
104    let attr = attrs.iter().find(|attr| attr.path() == BORSH);
105    if let Some(attr) = attr {
106        attr.parse_nested_meta(|meta| {
107            if meta.path == CRATE {
108                let value_expr: Path = parsing::parse_lit_into(BORSH, CRATE, &meta)?;
109                res = Some(value_expr);
110            } else if meta.path == USE_DISCRIMINANT || meta.path == INIT {
111                let _value_expr: Expr = meta.value()?.parse()?;
112            }
113
114            Ok(())
115        })?;
116    }
117
118    Ok(res)
119}
120
#[cfg(test)]
mod tests {
    use crate::internals::test_helpers::local_insta_assert_debug_snapshot;
    use quote::{quote, ToTokens};
    use syn::ItemEnum;

    use super::*;

    // NOTE(review): test function names and the expressions passed to
    // `local_insta_assert_debug_snapshot!` are kept as they were; stored
    // snapshots are presumably keyed on them — confirm before renaming.

    /// `use_discriminant = false` is read back as `false`.
    #[test]
    fn test_use_discriminant() {
        let tokens = quote! {
            #[derive(BorshDeserialize, Debug)]
            #[borsh(use_discriminant = false)]
            enum AWithUseDiscriminantFalse {
                X,
                Y,
            }
        };
        let item_enum = syn::parse2::<ItemEnum>(tokens).unwrap();

        let actual = contains_use_discriminant(&item_enum);
        assert!(!actual.unwrap());
    }

    /// `use_discriminant = true` is read back as `true`.
    #[test]
    fn test_use_discriminant_true() {
        let tokens = quote! {
            #[derive(BorshDeserialize, Debug)]
            #[borsh(use_discriminant = true)]
            enum AWithUseDiscriminantTrue {
                X,
                Y,
            }
        };
        let item_enum = syn::parse2::<ItemEnum>(tokens).unwrap();

        let actual = contains_use_discriminant(&item_enum);
        assert!(actual.unwrap());
    }

    /// A value other than `true`/`false` is rejected.
    #[test]
    fn test_use_discriminant_wrong_value() {
        let tokens = quote! {
            #[derive(BorshDeserialize, Debug)]
            #[borsh(use_discriminant = 111)]
            enum AWithUseDiscriminantFalse {
                X,
                Y,
            }
        };
        let item_enum = syn::parse2::<ItemEnum>(tokens).unwrap();

        let actual = contains_use_discriminant(&item_enum);
        let err = match actual {
            Err(err) => err,
            Ok(..) => unreachable!("expecting error here"),
        };
        local_insta_assert_debug_snapshot!(err);
    }

    /// `use_discriminant` on a struct is rejected by `check_attributes`.
    #[test]
    fn test_check_attrs_use_discriminant_on_struct() {
        let tokens = quote! {
            #[derive(BorshDeserialize, Debug)]
            #[borsh(use_discriminant = false)]
            struct AWithUseDiscriminantFalse {
                x: X,
                y: Y,
            }
        };
        let derive_input = syn::parse2::<DeriveInput>(tokens).unwrap();

        let actual = check_attributes(&derive_input);
        local_insta_assert_debug_snapshot!(actual.unwrap_err());
    }

    /// `skip` is not a supported item-level key.
    #[test]
    fn test_check_attrs_borsh_skip_on_whole_item() {
        let tokens = quote! {
            #[derive(BorshDeserialize, Debug)]
            #[borsh(skip)]
            struct AWithUseDiscriminantFalse {
                x: X,
                y: Y,
            }
        };
        let derive_input = syn::parse2::<DeriveInput>(tokens).unwrap();

        let actual = check_attributes(&derive_input);
        local_insta_assert_debug_snapshot!(actual.unwrap_err());
    }

    /// An unknown key is rejected.
    #[test]
    fn test_check_attrs_borsh_invalid_on_whole_item() {
        let tokens = quote! {
            #[derive(BorshDeserialize, Debug)]
            #[borsh(invalid)]
            enum AWithUseDiscriminantFalse {
                X,
                Y,
            }
        };
        let derive_input = syn::parse2::<DeriveInput>(tokens).unwrap();

        let actual = check_attributes(&derive_input);
        local_insta_assert_debug_snapshot!(actual.unwrap_err());
    }

    /// `init = <path>` alone passes validation.
    #[test]
    fn test_check_attrs_init_function() {
        let tokens = quote! {
            #[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
            #[borsh(init = initialization_method)]
            struct A<'a> {
                x: u64,
            }
        };
        let derive_input = syn::parse2::<DeriveInput>(tokens).unwrap();

        let actual = check_attributes(&derive_input);
        assert!(actual.is_ok());
    }

    /// `use_discriminant` followed by `init` passes validation on an enum.
    #[test]
    fn test_check_attrs_init_function_with_use_discriminant_reversed() {
        let tokens = quote! {
            #[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
            #[borsh(use_discriminant=true, init = initialization_method)]
            enum A {
                B,
                C,
                D = 10,
            }
        };
        let derive_input = syn::parse2::<DeriveInput>(tokens).unwrap();

        let actual = check_attributes(&derive_input);
        assert!(actual.is_ok());
    }

    /// Two separate `#[borsh(...)]` attributes on one item are rejected.
    #[test]
    fn test_reject_multiple_borsh_attrs() {
        let tokens = quote! {
            #[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
            #[borsh(use_discriminant=true)]
            #[borsh(init = initialization_method)]
            enum A {
                B,
                C,
                D = 10,
            }
        };
        let derive_input = syn::parse2::<DeriveInput>(tokens).unwrap();

        let actual = check_attributes(&derive_input);
        local_insta_assert_debug_snapshot!(actual.unwrap_err());
    }

    /// `init` followed by `use_discriminant` passes validation on an enum.
    #[test]
    fn test_check_attrs_init_function_with_use_discriminant() {
        let tokens = quote! {
            #[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
            #[borsh(init = initialization_method, use_discriminant=true)]
            enum A {
                B,
                C,
                D = 10,
            }
        };
        let derive_input = syn::parse2::<DeriveInput>(tokens).unwrap();

        let actual = check_attributes(&derive_input);
        assert!(actual.is_ok());
    }

    /// A misspelled key (`init_func`) is rejected.
    #[test]
    fn test_check_attrs_init_function_wrong_format() {
        let tokens = quote! {
            #[derive(BorshDeserialize, Debug)]
            #[borsh(init_func = initialization_method)]
            struct A<'a> {
                x: u64,
                b: B,
                y: f32,
                z: String,
                v: Vec<String>,

            }
        };
        let derive_input = syn::parse2::<DeriveInput>(tokens).unwrap();

        let actual = check_attributes(&derive_input);
        local_insta_assert_debug_snapshot!(actual.unwrap_err());
    }

    /// The `init` path is returned by `contains_initialize_with`.
    #[test]
    fn test_init_function() {
        let tokens = quote! {
            #[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
            #[borsh(init = initialization_method)]
            struct A<'a> {
                x: u64,
            }
        };
        let derive_input = syn::parse2::<DeriveInput>(tokens).unwrap();

        let actual = contains_initialize_with(&derive_input.attrs);
        assert_eq!(
            actual.unwrap().to_token_stream().to_string(),
            "initialization_method"
        );
    }

    /// A block expression is not accepted as the `init` value.
    #[test]
    fn test_init_function_parsing_error() {
        let tokens = quote! {
            #[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
            #[borsh(init={strange; blocky})]
            struct A {
                lazy: Option<u64>,
            }
        };
        let derive_input = syn::parse2::<DeriveInput>(tokens).unwrap();

        let actual = contains_initialize_with(&derive_input.attrs);
        let err = match actual {
            Err(err) => err,
            Ok(..) => unreachable!("expecting error here"),
        };
        local_insta_assert_debug_snapshot!(err);
    }

    /// `init` and `use_discriminant` are both read when `init` comes first.
    #[test]
    fn test_init_function_with_use_discriminant() {
        let tokens = quote! {
            #[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
            #[borsh(init = initialization_method, use_discriminant=true)]
            enum A {
                B,
                C,
                D,
            }
        };
        let item_enum = syn::parse2::<ItemEnum>(tokens).unwrap();

        let actual = contains_initialize_with(&item_enum.attrs);
        assert_eq!(
            actual.unwrap().to_token_stream().to_string(),
            "initialization_method"
        );
        let actual = contains_use_discriminant(&item_enum);
        assert!(actual.unwrap());
    }

    /// `init` and `use_discriminant` are both read when `use_discriminant` comes first.
    #[test]
    fn test_init_function_with_use_discriminant_reversed() {
        let tokens = quote! {
            #[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
            #[borsh(use_discriminant=true, init = initialization_method)]
            enum A {
                B,
                C,
                D,
            }
        };
        let item_enum = syn::parse2::<ItemEnum>(tokens).unwrap();

        let actual = contains_initialize_with(&item_enum.attrs);
        assert_eq!(
            actual.unwrap().to_token_stream().to_string(),
            "initialization_method"
        );
        let actual = contains_use_discriminant(&item_enum);
        assert!(actual.unwrap());
    }

    /// All three keys can be combined and each reader picks out its own value.
    #[test]
    fn test_init_function_with_use_discriminant_with_crate() {
        let tokens = quote! {
            #[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
            #[borsh(init = initialization_method, crate = "reexporter::borsh", use_discriminant=true)]
            enum A {
                B,
                C,
                D,
            }
        };
        let item_enum = syn::parse2::<ItemEnum>(tokens).unwrap();

        let actual = contains_initialize_with(&item_enum.attrs);
        assert_eq!(
            actual.unwrap().to_token_stream().to_string(),
            "initialization_method"
        );
        let actual = contains_use_discriminant(&item_enum);
        assert!(actual.unwrap());

        let crate_ = get_crate(&item_enum.attrs);
        assert_eq!(
            crate_.unwrap().to_token_stream().to_string(),
            "reexporter :: borsh"
        );
    }
}