derive_arbitrary/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::*;
6
7mod container_attributes;
8mod field_attributes;
9mod variant_attributes;
10
11use container_attributes::ContainerAttributes;
12use field_attributes::{determine_field_constructor, FieldConstructor};
13use variant_attributes::not_skipped;
14
// Name of the helper attribute (`#[arbitrary(...)]`) this derive looks for on the input.
const ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
// Name of the lifetime parameter introduced on the generated `Arbitrary` impl.
const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
17
18#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
19pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
20    let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
21    expand_derive_arbitrary(input)
22        .unwrap_or_else(syn::Error::into_compile_error)
23        .into()
24}
25
26fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
27    let container_attrs = ContainerAttributes::from_derive_input(&input)?;
28
29    let (lifetime_without_bounds, lifetime_with_bounds) =
30        build_arbitrary_lifetime(input.generics.clone());
31
32    // This won't be used if `needs_recursive_count` ends up false.
33    let recursive_count = syn::Ident::new(
34        &format!("RECURSIVE_COUNT_{}", input.ident),
35        Span::call_site(),
36    );
37
38    let (arbitrary_method, needs_recursive_count) =
39        gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
40    let size_hint_method = gen_size_hint_method(&input, needs_recursive_count)?;
41    let name = input.ident;
42
43    // Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
44    let generics = apply_trait_bounds(
45        input.generics,
46        lifetime_without_bounds.clone(),
47        &container_attrs,
48    )?;
49
50    // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
51    let mut generics_with_lifetime = generics.clone();
52    generics_with_lifetime
53        .params
54        .push(GenericParam::Lifetime(lifetime_with_bounds));
55    let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
56
57    // Build TypeGenerics and WhereClause without a lifetime
58    let (_, ty_generics, where_clause) = generics.split_for_impl();
59
60    let recursive_count = needs_recursive_count.then(|| {
61        Some(quote! {
62            ::std::thread_local! {
63                #[allow(non_upper_case_globals)]
64                static #recursive_count: ::core::cell::Cell<u32> = const {
65                    ::core::cell::Cell::new(0)
66                };
67            }
68        })
69    });
70
71    Ok(quote! {
72        const _: () = {
73            #recursive_count
74
75            #[automatically_derived]
76            impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds>
77                for #name #ty_generics #where_clause
78            {
79                #arbitrary_method
80                #size_hint_method
81            }
82        };
83    })
84}
85
86// Returns: (lifetime without bounds, lifetime with bounds)
87// Example: ("'arbitrary", "'arbitrary: 'a + 'b")
88fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam) {
89    let lifetime_without_bounds =
90        LifetimeParam::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));
91    let mut lifetime_with_bounds = lifetime_without_bounds.clone();
92
93    for param in generics.params.iter() {
94        if let GenericParam::Lifetime(lifetime_def) = param {
95            lifetime_with_bounds
96                .bounds
97                .push(lifetime_def.lifetime.clone());
98        }
99    }
100
101    (lifetime_without_bounds, lifetime_with_bounds)
102}
103
104fn apply_trait_bounds(
105    mut generics: Generics,
106    lifetime: LifetimeParam,
107    container_attrs: &ContainerAttributes,
108) -> Result<Generics> {
109    // If user-supplied bounds exist, apply them to their matching type parameters.
110    if let Some(config_bounds) = &container_attrs.bounds {
111        let mut config_bounds_applied = 0;
112        for param in generics.params.iter_mut() {
113            if let GenericParam::Type(type_param) = param {
114                if let Some(replacement) = config_bounds
115                    .iter()
116                    .flatten()
117                    .find(|p| p.ident == type_param.ident)
118                {
119                    *type_param = replacement.clone();
120                    config_bounds_applied += 1;
121                } else {
122                    // If no user-supplied bounds exist for this type, delete the original bounds.
123                    // This mimics serde.
124                    type_param.bounds = Default::default();
125                    type_param.default = None;
126                }
127            }
128        }
129        let config_bounds_supplied = config_bounds
130            .iter()
131            .map(|bounds| bounds.len())
132            .sum::<usize>();
133        if config_bounds_applied != config_bounds_supplied {
134            return Err(Error::new(
135                Span::call_site(),
136                format!(
137                    "invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
138                    ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
139                ),
140            ));
141        }
142        Ok(generics)
143    } else {
144        // Otherwise, inject a `T: Arbitrary` bound for every parameter.
145        Ok(add_trait_bounds(generics, lifetime))
146    }
147}
148
149// Add a bound `T: Arbitrary` to every type parameter T.
150fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics {
151    for param in generics.params.iter_mut() {
152        if let GenericParam::Type(type_param) = param {
153            type_param
154                .bounds
155                .push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
156        }
157    }
158    generics
159}
160
161fn gen_arbitrary_method(
162    input: &DeriveInput,
163    lifetime: LifetimeParam,
164    recursive_count: &syn::Ident,
165) -> Result<(TokenStream, bool)> {
166    fn arbitrary_structlike(
167        fields: &Fields,
168        ident: &syn::Ident,
169        lifetime: LifetimeParam,
170        recursive_count: &syn::Ident,
171    ) -> Result<TokenStream> {
172        let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
173        let body = quote! {
174            arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| {
175                Ok(#ident #arbitrary)
176            })
177        };
178
179        let arbitrary_take_rest = construct_take_rest(fields)?;
180        let take_rest_body = quote! {
181            arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| {
182                Ok(#ident #arbitrary_take_rest)
183            })
184        };
185
186        Ok(quote! {
187            fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
188                #body
189            }
190
191            fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
192                #take_rest_body
193            }
194        })
195    }
196
197    fn arbitrary_variant(
198        index: u64,
199        enum_name: &Ident,
200        variant_name: &Ident,
201        ctor: TokenStream,
202    ) -> TokenStream {
203        quote! { #index => #enum_name::#variant_name #ctor }
204    }
205
206    fn arbitrary_enum_method(
207        recursive_count: &syn::Ident,
208        unstructured: TokenStream,
209        variants: &[TokenStream],
210        needs_recursive_count: bool,
211    ) -> TokenStream {
212        let count = variants.len() as u64;
213
214        let do_variants = quote! {
215            // Use a multiply + shift to generate a ranged random number
216            // with slight bias. For details, see:
217            // https://lemire.me/blog/2016/06/30/fast-random-shuffling
218            Ok(match (
219                u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count
220            ) >> 32
221            {
222                #(#variants,)*
223                _ => unreachable!()
224            })
225        };
226
227        if needs_recursive_count {
228            quote! {
229                arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| {
230                    #do_variants
231                })
232            }
233        } else {
234            do_variants
235        }
236    }
237
238    fn arbitrary_enum(
239        DataEnum { variants, .. }: &DataEnum,
240        enum_name: &Ident,
241        lifetime: LifetimeParam,
242        recursive_count: &syn::Ident,
243    ) -> Result<(TokenStream, bool)> {
244        let filtered_variants = variants.iter().filter(not_skipped);
245
246        // Check attributes of all variants:
247        filtered_variants
248            .clone()
249            .try_for_each(check_variant_attrs)?;
250
251        // From here on, we can assume that the attributes of all variants were checked.
252        let enumerated_variants = filtered_variants
253            .enumerate()
254            .map(|(index, variant)| (index as u64, variant));
255
256        // Construct `match`-arms for the `arbitrary` method.
257        let mut needs_recursive_count = false;
258        let variants = enumerated_variants
259            .clone()
260            .map(|(index, Variant { fields, ident, .. })| {
261                construct(fields, |_, field| gen_constructor_for_field(field)).map(|ctor| {
262                    if !ctor.is_empty() {
263                        needs_recursive_count = true;
264                    }
265                    arbitrary_variant(index, enum_name, ident, ctor)
266                })
267            })
268            .collect::<Result<Vec<TokenStream>>>()?;
269
270        // Construct `match`-arms for the `arbitrary_take_rest` method.
271        let variants_take_rest = enumerated_variants
272            .map(|(index, Variant { fields, ident, .. })| {
273                construct_take_rest(fields)
274                    .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
275            })
276            .collect::<Result<Vec<TokenStream>>>()?;
277
278        // Most of the time, `variants` is not empty (the happy path),
279        //   thus `variants_take_rest` will be used,
280        //   so no need to move this check before constructing `variants_take_rest`.
281        // If `variants` is empty, this will emit a compiler-error.
282        (!variants.is_empty())
283            .then(|| {
284                // TODO: Improve dealing with `u` vs. `&mut u`.
285                let arbitrary = arbitrary_enum_method(
286                    recursive_count,
287                    quote! { u },
288                    &variants,
289                    needs_recursive_count,
290                );
291                let arbitrary_take_rest = arbitrary_enum_method(
292                    recursive_count,
293                    quote! { &mut u },
294                    &variants_take_rest,
295                    needs_recursive_count,
296                );
297
298                (
299                    quote! {
300                        fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>)
301                            -> arbitrary::Result<Self>
302                        {
303                            #arbitrary
304                        }
305
306                        fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>)
307                            -> arbitrary::Result<Self>
308                        {
309                            #arbitrary_take_rest
310                        }
311                    },
312                    needs_recursive_count,
313                )
314            })
315            .ok_or_else(|| {
316                Error::new_spanned(
317                    enum_name,
318                    "Enum must have at least one variant, that is not skipped",
319                )
320            })
321    }
322
323    let ident = &input.ident;
324    let needs_recursive_count = true;
325    match &input.data {
326        Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)
327            .map(|ts| (ts, needs_recursive_count)),
328        Data::Union(data) => arbitrary_structlike(
329            &Fields::Named(data.fields.clone()),
330            ident,
331            lifetime,
332            recursive_count,
333        )
334        .map(|ts| (ts, needs_recursive_count)),
335        Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count),
336    }
337}
338
339fn construct(
340    fields: &Fields,
341    ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
342) -> Result<TokenStream> {
343    let output = match fields {
344        Fields::Named(names) => {
345            let names: Vec<TokenStream> = names
346                .named
347                .iter()
348                .enumerate()
349                .map(|(i, f)| {
350                    let name = f.ident.as_ref().unwrap();
351                    ctor(i, f).map(|ctor| quote! { #name: #ctor })
352                })
353                .collect::<Result<_>>()?;
354            quote! { { #(#names,)* } }
355        }
356        Fields::Unnamed(names) => {
357            let names: Vec<TokenStream> = names
358                .unnamed
359                .iter()
360                .enumerate()
361                .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor }))
362                .collect::<Result<_>>()?;
363            quote! { ( #(#names),* ) }
364        }
365        Fields::Unit => quote!(),
366    };
367    Ok(output)
368}
369
370fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
371    construct(fields, |idx, field| {
372        determine_field_constructor(field).map(|field_constructor| match field_constructor {
373            FieldConstructor::Default => quote!(::core::default::Default::default()),
374            FieldConstructor::Arbitrary => {
375                if idx + 1 == fields.len() {
376                    quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
377                } else {
378                    quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
379                }
380            }
381            FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?),
382            FieldConstructor::Value(value) => quote!(#value),
383        })
384    })
385}
386
387fn gen_size_hint_method(input: &DeriveInput, needs_recursive_count: bool) -> Result<TokenStream> {
388    let size_hint_fields = |fields: &Fields| {
389        fields
390            .iter()
391            .map(|f| {
392                let ty = &f.ty;
393                determine_field_constructor(f).map(|field_constructor| {
394                    match field_constructor {
395                        FieldConstructor::Default | FieldConstructor::Value(_) => {
396                            quote!(Ok((0, Some(0))))
397                        }
398                        FieldConstructor::Arbitrary => {
399                            quote! { <#ty as arbitrary::Arbitrary>::try_size_hint(depth) }
400                        }
401
402                        // Note that in this case it's hard to determine what size_hint must be, so
403                        // size_of::<T>() is just an educated guess, although it's gonna be
404                        // inaccurate for dynamically allocated types (Vec, HashMap, etc.).
405                        FieldConstructor::With(_) => {
406                            quote! { Ok((::core::mem::size_of::<#ty>(), None)) }
407                        }
408                    }
409                })
410            })
411            .collect::<Result<Vec<TokenStream>>>()
412            .map(|hints| {
413                quote! {
414                    Ok(arbitrary::size_hint::and_all(&[
415                        #( #hints? ),*
416                    ]))
417                }
418            })
419    };
420    let size_hint_structlike = |fields: &Fields| {
421        assert!(needs_recursive_count);
422        size_hint_fields(fields).map(|hint| {
423            quote! {
424                #[inline]
425                fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
426                    Self::try_size_hint(depth).unwrap_or_default()
427                }
428
429                #[inline]
430                fn try_size_hint(depth: usize)
431                    -> ::core::result::Result<
432                        (usize, ::core::option::Option<usize>),
433                        arbitrary::MaxRecursionReached,
434                    >
435                {
436                    arbitrary::size_hint::try_recursion_guard(depth, |depth| #hint)
437                }
438            }
439        })
440    };
441    match &input.data {
442        Data::Struct(data) => size_hint_structlike(&data.fields),
443        Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
444        Data::Enum(data) => data
445            .variants
446            .iter()
447            .filter(not_skipped)
448            .map(|Variant { fields, .. }| {
449                if !needs_recursive_count {
450                    assert!(fields.is_empty());
451                }
452                // The attributes of all variants are checked in `gen_arbitrary_method` above
453                // and can therefore assume that they are valid.
454                size_hint_fields(fields)
455            })
456            .collect::<Result<Vec<TokenStream>>>()
457            .map(|variants| {
458                if needs_recursive_count {
459                    // The enum might be recursive: `try_size_hint` is the primary one, and
460                    // `size_hint` is defined in terms of it.
461                    quote! {
462                        fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
463                            Self::try_size_hint(depth).unwrap_or_default()
464                        }
465                        #[inline]
466                        fn try_size_hint(depth: usize)
467                            -> ::core::result::Result<
468                                (usize, ::core::option::Option<usize>),
469                                arbitrary::MaxRecursionReached,
470                            >
471                        {
472                            Ok(arbitrary::size_hint::and(
473                                <u32 as arbitrary::Arbitrary>::size_hint(depth),
474                                arbitrary::size_hint::try_recursion_guard(depth, |depth| {
475                                    Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ]))
476                                })?,
477                            ))
478                        }
479                    }
480                } else {
481                    // The enum is guaranteed non-recursive, i.e. fieldless: `size_hint` is the
482                    // primary one, and the default `try_size_hint` is good enough.
483                    quote! {
484                        fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
485                            <u32 as arbitrary::Arbitrary>::size_hint(depth)
486                        }
487                    }
488                }
489            }),
490    }
491}
492
493fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
494    let ctor = match determine_field_constructor(field)? {
495        FieldConstructor::Default => quote!(::core::default::Default::default()),
496        FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
497        FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
498        FieldConstructor::Value(value) => quote!(#value),
499    };
500    Ok(ctor)
501}
502
503fn check_variant_attrs(variant: &Variant) -> Result<()> {
504    for attr in &variant.attrs {
505        if attr.path().is_ident(ARBITRARY_ATTRIBUTE_NAME) {
506            return Err(Error::new_spanned(
507                attr,
508                format!(
509                    "invalid `{}` attribute. it is unsupported on enum variants. try applying it to a field of the variant instead",
510                    ARBITRARY_ATTRIBUTE_NAME
511                ),
512            ));
513        }
514    }
515    Ok(())
516}