derive_arbitrary/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::*;
6
7mod container_attributes;
8mod field_attributes;
9mod variant_attributes;
10
11use container_attributes::ContainerAttributes;
12use field_attributes::{determine_field_constructor, FieldConstructor};
13use variant_attributes::not_skipped;
14
/// Name of the helper attribute (`#[arbitrary(...)]`) recognised by this derive;
/// also interpolated into the derive's error messages.
const ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
/// Name of the lifetime parameter added to the generated impl; it is the lifetime
/// of the `arbitrary::Unstructured<'arbitrary>` the generated methods receive.
const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
17
18#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
19pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
20    let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
21    expand_derive_arbitrary(input)
22        .unwrap_or_else(syn::Error::into_compile_error)
23        .into()
24}
25
26fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
27    let container_attrs = ContainerAttributes::from_derive_input(&input)?;
28
29    let (lifetime_without_bounds, lifetime_with_bounds) =
30        build_arbitrary_lifetime(input.generics.clone());
31
32    let recursive_count = syn::Ident::new(
33        &format!("RECURSIVE_COUNT_{}", input.ident),
34        Span::call_site(),
35    );
36
37    let arbitrary_method =
38        gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
39    let size_hint_method = gen_size_hint_method(&input)?;
40    let name = input.ident;
41
42    // Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
43    let generics = apply_trait_bounds(
44        input.generics,
45        lifetime_without_bounds.clone(),
46        &container_attrs,
47    )?;
48
49    // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
50    let mut generics_with_lifetime = generics.clone();
51    generics_with_lifetime
52        .params
53        .push(GenericParam::Lifetime(lifetime_with_bounds));
54    let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
55
56    // Build TypeGenerics and WhereClause without a lifetime
57    let (_, ty_generics, where_clause) = generics.split_for_impl();
58
59    Ok(quote! {
60        const _: () = {
61            ::std::thread_local! {
62                #[allow(non_upper_case_globals)]
63                static #recursive_count: ::core::cell::Cell<u32> = ::core::cell::Cell::new(0);
64            }
65
66            #[automatically_derived]
67            impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
68                #arbitrary_method
69                #size_hint_method
70            }
71        };
72    })
73}
74
75// Returns: (lifetime without bounds, lifetime with bounds)
76// Example: ("'arbitrary", "'arbitrary: 'a + 'b")
77fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam) {
78    let lifetime_without_bounds =
79        LifetimeParam::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));
80    let mut lifetime_with_bounds = lifetime_without_bounds.clone();
81
82    for param in generics.params.iter() {
83        if let GenericParam::Lifetime(lifetime_def) = param {
84            lifetime_with_bounds
85                .bounds
86                .push(lifetime_def.lifetime.clone());
87        }
88    }
89
90    (lifetime_without_bounds, lifetime_with_bounds)
91}
92
93fn apply_trait_bounds(
94    mut generics: Generics,
95    lifetime: LifetimeParam,
96    container_attrs: &ContainerAttributes,
97) -> Result<Generics> {
98    // If user-supplied bounds exist, apply them to their matching type parameters.
99    if let Some(config_bounds) = &container_attrs.bounds {
100        let mut config_bounds_applied = 0;
101        for param in generics.params.iter_mut() {
102            if let GenericParam::Type(type_param) = param {
103                if let Some(replacement) = config_bounds
104                    .iter()
105                    .flatten()
106                    .find(|p| p.ident == type_param.ident)
107                {
108                    *type_param = replacement.clone();
109                    config_bounds_applied += 1;
110                } else {
111                    // If no user-supplied bounds exist for this type, delete the original bounds.
112                    // This mimics serde.
113                    type_param.bounds = Default::default();
114                    type_param.default = None;
115                }
116            }
117        }
118        let config_bounds_supplied = config_bounds
119            .iter()
120            .map(|bounds| bounds.len())
121            .sum::<usize>();
122        if config_bounds_applied != config_bounds_supplied {
123            return Err(Error::new(
124                Span::call_site(),
125                format!(
126                    "invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
127                    ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
128                ),
129            ));
130        }
131        Ok(generics)
132    } else {
133        // Otherwise, inject a `T: Arbitrary` bound for every parameter.
134        Ok(add_trait_bounds(generics, lifetime))
135    }
136}
137
138// Add a bound `T: Arbitrary` to every type parameter T.
139fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics {
140    for param in generics.params.iter_mut() {
141        if let GenericParam::Type(type_param) = param {
142            type_param
143                .bounds
144                .push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
145        }
146    }
147    generics
148}
149
150fn with_recursive_count_guard(
151    recursive_count: &syn::Ident,
152    expr: impl quote::ToTokens,
153) -> impl quote::ToTokens {
154    quote! {
155        let guard_against_recursion = u.is_empty();
156        if guard_against_recursion {
157            #recursive_count.with(|count| {
158                if count.get() > 0 {
159                    return Err(arbitrary::Error::NotEnoughData);
160                }
161                count.set(count.get() + 1);
162                Ok(())
163            })?;
164        }
165
166        let result = (|| { #expr })();
167
168        if guard_against_recursion {
169            #recursive_count.with(|count| {
170                count.set(count.get() - 1);
171            });
172        }
173
174        result
175    }
176}
177
178fn gen_arbitrary_method(
179    input: &DeriveInput,
180    lifetime: LifetimeParam,
181    recursive_count: &syn::Ident,
182) -> Result<TokenStream> {
183    fn arbitrary_structlike(
184        fields: &Fields,
185        ident: &syn::Ident,
186        lifetime: LifetimeParam,
187        recursive_count: &syn::Ident,
188    ) -> Result<TokenStream> {
189        let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
190        let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });
191
192        let arbitrary_take_rest = construct_take_rest(fields)?;
193        let take_rest_body =
194            with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });
195
196        Ok(quote! {
197            fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
198                #body
199            }
200
201            fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
202                #take_rest_body
203            }
204        })
205    }
206
207    fn arbitrary_variant(
208        index: u64,
209        enum_name: &Ident,
210        variant_name: &Ident,
211        ctor: TokenStream,
212    ) -> TokenStream {
213        quote! { #index => #enum_name::#variant_name #ctor }
214    }
215
216    fn arbitrary_enum_method(
217        recursive_count: &syn::Ident,
218        unstructured: TokenStream,
219        variants: &[TokenStream],
220    ) -> impl quote::ToTokens {
221        let count = variants.len() as u64;
222        with_recursive_count_guard(
223            recursive_count,
224            quote! {
225                // Use a multiply + shift to generate a ranged random number
226                // with slight bias. For details, see:
227                // https://lemire.me/blog/2016/06/30/fast-random-shuffling
228                Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count) >> 32 {
229                    #(#variants,)*
230                    _ => unreachable!()
231                })
232            },
233        )
234    }
235
236    fn arbitrary_enum(
237        DataEnum { variants, .. }: &DataEnum,
238        enum_name: &Ident,
239        lifetime: LifetimeParam,
240        recursive_count: &syn::Ident,
241    ) -> Result<TokenStream> {
242        let filtered_variants = variants.iter().filter(not_skipped);
243
244        // Check attributes of all variants:
245        filtered_variants
246            .clone()
247            .try_for_each(check_variant_attrs)?;
248
249        // From here on, we can assume that the attributes of all variants were checked.
250        let enumerated_variants = filtered_variants
251            .enumerate()
252            .map(|(index, variant)| (index as u64, variant));
253
254        // Construct `match`-arms for the `arbitrary` method.
255        let variants = enumerated_variants
256            .clone()
257            .map(|(index, Variant { fields, ident, .. })| {
258                construct(fields, |_, field| gen_constructor_for_field(field))
259                    .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
260            })
261            .collect::<Result<Vec<TokenStream>>>()?;
262
263        // Construct `match`-arms for the `arbitrary_take_rest` method.
264        let variants_take_rest = enumerated_variants
265            .map(|(index, Variant { fields, ident, .. })| {
266                construct_take_rest(fields)
267                    .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
268            })
269            .collect::<Result<Vec<TokenStream>>>()?;
270
271        // Most of the time, `variants` is not empty (the happy path),
272        //   thus `variants_take_rest` will be used,
273        //   so no need to move this check before constructing `variants_take_rest`.
274        // If `variants` is empty, this will emit a compiler-error.
275        (!variants.is_empty())
276            .then(|| {
277                // TODO: Improve dealing with `u` vs. `&mut u`.
278                let arbitrary = arbitrary_enum_method(recursive_count, quote! { u }, &variants);
279                let arbitrary_take_rest = arbitrary_enum_method(recursive_count, quote! { &mut u }, &variants_take_rest);
280
281                quote! {
282                    fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
283                        #arbitrary
284                    }
285
286                    fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
287                        #arbitrary_take_rest
288                    }
289                }
290            })
291            .ok_or_else(|| Error::new_spanned(
292                enum_name,
293                "Enum must have at least one variant, that is not skipped"
294            ))
295    }
296
297    let ident = &input.ident;
298    match &input.data {
299        Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count),
300        Data::Union(data) => arbitrary_structlike(
301            &Fields::Named(data.fields.clone()),
302            ident,
303            lifetime,
304            recursive_count,
305        ),
306        Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count),
307    }
308}
309
310fn construct(
311    fields: &Fields,
312    ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
313) -> Result<TokenStream> {
314    let output = match fields {
315        Fields::Named(names) => {
316            let names: Vec<TokenStream> = names
317                .named
318                .iter()
319                .enumerate()
320                .map(|(i, f)| {
321                    let name = f.ident.as_ref().unwrap();
322                    ctor(i, f).map(|ctor| quote! { #name: #ctor })
323                })
324                .collect::<Result<_>>()?;
325            quote! { { #(#names,)* } }
326        }
327        Fields::Unnamed(names) => {
328            let names: Vec<TokenStream> = names
329                .unnamed
330                .iter()
331                .enumerate()
332                .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor }))
333                .collect::<Result<_>>()?;
334            quote! { ( #(#names),* ) }
335        }
336        Fields::Unit => quote!(),
337    };
338    Ok(output)
339}
340
341fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
342    construct(fields, |idx, field| {
343        determine_field_constructor(field).map(|field_constructor| match field_constructor {
344            FieldConstructor::Default => quote!(::core::default::Default::default()),
345            FieldConstructor::Arbitrary => {
346                if idx + 1 == fields.len() {
347                    quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
348                } else {
349                    quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
350                }
351            }
352            FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?),
353            FieldConstructor::Value(value) => quote!(#value),
354        })
355    })
356}
357
358fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
359    let size_hint_fields = |fields: &Fields| {
360        fields
361            .iter()
362            .map(|f| {
363                let ty = &f.ty;
364                determine_field_constructor(f).map(|field_constructor| {
365                    match field_constructor {
366                        FieldConstructor::Default | FieldConstructor::Value(_) => {
367                            quote!(Ok((0, Some(0))))
368                        }
369                        FieldConstructor::Arbitrary => {
370                            quote! { <#ty as arbitrary::Arbitrary>::try_size_hint(depth) }
371                        }
372
373                        // Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
374                        // just an educated guess, although it's gonna be inaccurate for dynamically
375                        // allocated types (Vec, HashMap, etc.).
376                        FieldConstructor::With(_) => {
377                            quote! { Ok((::core::mem::size_of::<#ty>(), None)) }
378                        }
379                    }
380                })
381            })
382            .collect::<Result<Vec<TokenStream>>>()
383            .map(|hints| {
384                quote! {
385                    Ok(arbitrary::size_hint::and_all(&[
386                        #( #hints? ),*
387                    ]))
388                }
389            })
390    };
391    let size_hint_structlike = |fields: &Fields| {
392        size_hint_fields(fields).map(|hint| {
393            quote! {
394                #[inline]
395                fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
396                    Self::try_size_hint(depth).unwrap_or_default()
397                }
398
399                #[inline]
400                fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
401                    arbitrary::size_hint::try_recursion_guard(depth, |depth| #hint)
402                }
403            }
404        })
405    };
406    match &input.data {
407        Data::Struct(data) => size_hint_structlike(&data.fields),
408        Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
409        Data::Enum(data) => data
410            .variants
411            .iter()
412            .filter(not_skipped)
413            .map(|Variant { fields, .. }| {
414                // The attributes of all variants are checked in `gen_arbitrary_method` above
415                //   and can therefore assume that they are valid.
416                size_hint_fields(fields)
417            })
418            .collect::<Result<Vec<TokenStream>>>()
419            .map(|variants| {
420                quote! {
421                    fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
422                        Self::try_size_hint(depth).unwrap_or_default()
423                    }
424                    #[inline]
425                    fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
426                        Ok(arbitrary::size_hint::and(
427                            <u32 as arbitrary::Arbitrary>::try_size_hint(depth)?,
428                            arbitrary::size_hint::try_recursion_guard(depth, |depth| {
429                                Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ]))
430                            })?,
431                        ))
432                    }
433                }
434            }),
435    }
436}
437
438fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
439    let ctor = match determine_field_constructor(field)? {
440        FieldConstructor::Default => quote!(::core::default::Default::default()),
441        FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
442        FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
443        FieldConstructor::Value(value) => quote!(#value),
444    };
445    Ok(ctor)
446}
447
448fn check_variant_attrs(variant: &Variant) -> Result<()> {
449    for attr in &variant.attrs {
450        if attr.path().is_ident(ARBITRARY_ATTRIBUTE_NAME) {
451            return Err(Error::new_spanned(
452                attr,
453                format!(
454                    "invalid `{}` attribute. it is unsupported on enum variants. try applying it to a field of the variant instead",
455                    ARBITRARY_ATTRIBUTE_NAME
456                ),
457            ));
458        }
459    }
460    Ok(())
461}