// derive_arbitrary/field_attributes.rs

1use crate::ARBITRARY_ATTRIBUTE_NAME;
2use proc_macro2::{Span, TokenStream, TokenTree};
3use quote::quote;
4use syn::{spanned::Spanned, *};
5
6/// Determines how a value for a field should be constructed.
7#[cfg_attr(test, derive(Debug))]
8pub enum FieldConstructor {
9    /// Assume that Arbitrary is defined for the type of this field and use it (default)
10    Arbitrary,
11
12    /// Places `Default::default()` as a field value.
13    Default,
14
15    /// Use custom function or closure to generate a value for a field.
16    With(TokenStream),
17
18    /// Set a field always to the given value.
19    Value(TokenStream),
20}
21
22pub fn determine_field_constructor(field: &Field) -> Result<FieldConstructor> {
23    let opt_attr = fetch_attr_from_field(field)?;
24    let ctor = match opt_attr {
25        Some(attr) => parse_attribute(attr)?,
26        None => FieldConstructor::Arbitrary,
27    };
28    Ok(ctor)
29}
30
31fn fetch_attr_from_field(field: &Field) -> Result<Option<&Attribute>> {
32    let found_attributes: Vec<_> = field
33        .attrs
34        .iter()
35        .filter(|a| {
36            let path = a.path();
37            let name = quote!(#path).to_string();
38            name == ARBITRARY_ATTRIBUTE_NAME
39        })
40        .collect();
41    if found_attributes.len() > 1 {
42        let name = field.ident.as_ref().unwrap();
43        let msg = format!(
44            "Multiple conflicting #[{ARBITRARY_ATTRIBUTE_NAME}] attributes found on field `{name}`"
45        );
46        return Err(syn::Error::new(field.span(), msg));
47    }
48    Ok(found_attributes.into_iter().next())
49}
50
51fn parse_attribute(attr: &Attribute) -> Result<FieldConstructor> {
52    if let Meta::List(ref meta_list) = attr.meta {
53        parse_attribute_internals(meta_list)
54    } else {
55        let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] must contain a group");
56        Err(syn::Error::new(attr.span(), msg))
57    }
58}
59
60fn parse_attribute_internals(meta_list: &MetaList) -> Result<FieldConstructor> {
61    let mut tokens_iter = meta_list.tokens.clone().into_iter();
62    let token = tokens_iter.next().ok_or_else(|| {
63        let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] cannot be empty.");
64        syn::Error::new(meta_list.span(), msg)
65    })?;
66    match token.to_string().as_ref() {
67        "default" => Ok(FieldConstructor::Default),
68        "with" => {
69            let func_path = parse_assigned_value("with", tokens_iter, meta_list.span())?;
70            Ok(FieldConstructor::With(func_path))
71        }
72        "value" => {
73            let value = parse_assigned_value("value", tokens_iter, meta_list.span())?;
74            Ok(FieldConstructor::Value(value))
75        }
76        _ => {
77            let msg = format!("Unknown option for #[{ARBITRARY_ATTRIBUTE_NAME}]: `{token}`");
78            Err(syn::Error::new(token.span(), msg))
79        }
80    }
81}
82
83// Input:
84//     = 2 + 2
85// Output:
86//     2 + 2
87fn parse_assigned_value(
88    opt_name: &str,
89    mut tokens_iter: impl Iterator<Item = TokenTree>,
90    default_span: Span,
91) -> Result<TokenStream> {
92    let eq_sign = tokens_iter.next().ok_or_else(|| {
93        let msg = format!(
94            "Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], `{opt_name}` is missing assignment."
95        );
96        syn::Error::new(default_span, msg)
97    })?;
98
99    if eq_sign.to_string() == "=" {
100        Ok(tokens_iter.collect())
101    } else {
102        let msg = format!("Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], expected `=` after `{opt_name}`, got: `{eq_sign}`");
103        Err(syn::Error::new(eq_sign.span(), msg))
104    }
105}