1extern crate proc_macro;
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::*;
6
7mod container_attributes;
8mod field_attributes;
9mod variant_attributes;
10
11use container_attributes::ContainerAttributes;
12use field_attributes::{determine_field_constructor, FieldConstructor};
13use variant_attributes::not_skipped;
14
/// Name of the helper attribute (`#[arbitrary(...)]`) this derive registers and inspects.
const ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
/// Lifetime added to the generated `impl` and threaded through
/// `arbitrary::Arbitrary<'arbitrary>` / `arbitrary::Unstructured<'arbitrary>`.
const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
17
18#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
19pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
20 let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
21 expand_derive_arbitrary(input)
22 .unwrap_or_else(syn::Error::into_compile_error)
23 .into()
24}
25
26fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
27 let container_attrs = ContainerAttributes::from_derive_input(&input)?;
28
29 let (lifetime_without_bounds, lifetime_with_bounds) =
30 build_arbitrary_lifetime(input.generics.clone());
31
32 let recursive_count = syn::Ident::new(
34 &format!("RECURSIVE_COUNT_{}", input.ident),
35 Span::call_site(),
36 );
37
38 let (arbitrary_method, needs_recursive_count) =
39 gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
40 let size_hint_method = gen_size_hint_method(&input, needs_recursive_count)?;
41 let name = input.ident;
42
43 let generics = apply_trait_bounds(
45 input.generics,
46 lifetime_without_bounds.clone(),
47 &container_attrs,
48 )?;
49
50 let mut generics_with_lifetime = generics.clone();
52 generics_with_lifetime
53 .params
54 .push(GenericParam::Lifetime(lifetime_with_bounds));
55 let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
56
57 let (_, ty_generics, where_clause) = generics.split_for_impl();
59
60 let recursive_count = needs_recursive_count.then(|| {
61 Some(quote! {
62 ::std::thread_local! {
63 #[allow(non_upper_case_globals)]
64 static #recursive_count: ::core::cell::Cell<u32> = const {
65 ::core::cell::Cell::new(0)
66 };
67 }
68 })
69 });
70
71 Ok(quote! {
72 const _: () = {
73 #recursive_count
74
75 #[automatically_derived]
76 impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds>
77 for #name #ty_generics #where_clause
78 {
79 #arbitrary_method
80 #size_hint_method
81 }
82 };
83 })
84}
85
86fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam) {
89 let lifetime_without_bounds =
90 LifetimeParam::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));
91 let mut lifetime_with_bounds = lifetime_without_bounds.clone();
92
93 for param in generics.params.iter() {
94 if let GenericParam::Lifetime(lifetime_def) = param {
95 lifetime_with_bounds
96 .bounds
97 .push(lifetime_def.lifetime.clone());
98 }
99 }
100
101 (lifetime_without_bounds, lifetime_with_bounds)
102}
103
104fn apply_trait_bounds(
105 mut generics: Generics,
106 lifetime: LifetimeParam,
107 container_attrs: &ContainerAttributes,
108) -> Result<Generics> {
109 if let Some(config_bounds) = &container_attrs.bounds {
111 let mut config_bounds_applied = 0;
112 for param in generics.params.iter_mut() {
113 if let GenericParam::Type(type_param) = param {
114 if let Some(replacement) = config_bounds
115 .iter()
116 .flatten()
117 .find(|p| p.ident == type_param.ident)
118 {
119 *type_param = replacement.clone();
120 config_bounds_applied += 1;
121 } else {
122 type_param.bounds = Default::default();
125 type_param.default = None;
126 }
127 }
128 }
129 let config_bounds_supplied = config_bounds
130 .iter()
131 .map(|bounds| bounds.len())
132 .sum::<usize>();
133 if config_bounds_applied != config_bounds_supplied {
134 return Err(Error::new(
135 Span::call_site(),
136 format!(
137 "invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
138 ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
139 ),
140 ));
141 }
142 Ok(generics)
143 } else {
144 Ok(add_trait_bounds(generics, lifetime))
146 }
147}
148
149fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics {
151 for param in generics.params.iter_mut() {
152 if let GenericParam::Type(type_param) = param {
153 type_param
154 .bounds
155 .push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
156 }
157 }
158 generics
159}
160
/// Generates the `arbitrary` and `arbitrary_take_rest` methods for `input`.
///
/// Returns the method tokens, plus a flag saying whether those tokens reference the
/// thread-local counter named `recursive_count`. The caller only declares the counter when
/// the flag is `true`. Structs and unions always set it. Enums set it only if some
/// non-skipped variant has fields syntax (i.e. is not a unit variant).
///
/// # Errors
///
/// Propagates field-attribute errors. Also fails for an enum whose variants are all skipped,
/// or one carrying an `#[arbitrary]` attribute on a variant.
fn gen_arbitrary_method(
    input: &DeriveInput,
    lifetime: LifetimeParam,
    recursive_count: &syn::Ident,
) -> Result<(TokenStream, bool)> {
    /// Emits both methods for a struct-like body (`ident { .. }`, `ident( .. )` or `ident`).
    /// Each body runs inside `arbitrary::details::with_recursive_count`, passing the
    /// thread-local counter. NOTE(review): presumably this tracks recursion depth for
    /// recursive types; confirm against `arbitrary::details`.
    fn arbitrary_structlike(
        fields: &Fields,
        ident: &syn::Ident,
        lifetime: LifetimeParam,
        recursive_count: &syn::Ident,
    ) -> Result<TokenStream> {
        let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
        let body = quote! {
            arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| {
                Ok(#ident #arbitrary)
            })
        };

        // The take-rest variant hands the final `Arbitrary` field the remaining input.
        let arbitrary_take_rest = construct_take_rest(fields)?;
        let take_rest_body = quote! {
            arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| {
                Ok(#ident #arbitrary_take_rest)
            })
        };

        Ok(quote! {
            fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
                #body
            }

            fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
                #take_rest_body
            }
        })
    }

    /// One `match` arm: `index => Enum::Variant <ctor>`.
    fn arbitrary_variant(
        index: u64,
        enum_name: &Ident,
        variant_name: &Ident,
        ctor: TokenStream,
    ) -> TokenStream {
        quote! { #index => #enum_name::#variant_name #ctor }
    }

    /// Emits a body that draws a `u32` from `unstructured` and selects a variant arm.
    fn arbitrary_enum_method(
        recursive_count: &syn::Ident,
        unstructured: TokenStream,
        variants: &[TokenStream],
        needs_recursive_count: bool,
    ) -> TokenStream {
        let count = variants.len() as u64;

        // `(x * count) >> 32` with `x < 2^32` maps the whole `u32` range onto `0..count`,
        // so the `_` arm can never be reached.
        let do_variants = quote! {
            Ok(match (
                u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count
            ) >> 32
            {
                #(#variants,)*
                _ => unreachable!()
            })
        };

        if needs_recursive_count {
            quote! {
                arbitrary::details::with_recursive_count(u, &#recursive_count, |mut u| {
                    #do_variants
                })
            }
        } else {
            // Only unit variants: no field construction, so the counter is not used.
            do_variants
        }
    }

    /// Emits both methods for an enum, numbering only the variants that are not skipped.
    fn arbitrary_enum(
        DataEnum { variants, .. }: &DataEnum,
        enum_name: &Ident,
        lifetime: LifetimeParam,
        recursive_count: &syn::Ident,
    ) -> Result<(TokenStream, bool)> {
        let filtered_variants = variants.iter().filter(not_skipped);

        // Reject `#[arbitrary]` on any (non-skipped) variant before generating code.
        filtered_variants
            .clone()
            .try_for_each(check_variant_attrs)?;

        let enumerated_variants = filtered_variants
            .enumerate()
            .map(|(index, variant)| (index as u64, variant));

        // Set as a side effect while the arms are generated: unit variants produce an
        // empty constructor, anything else means field values will be drawn.
        let mut needs_recursive_count = false;
        let variants = enumerated_variants
            .clone()
            .map(|(index, Variant { fields, ident, .. })| {
                construct(fields, |_, field| gen_constructor_for_field(field)).map(|ctor| {
                    if !ctor.is_empty() {
                        needs_recursive_count = true;
                    }
                    arbitrary_variant(index, enum_name, ident, ctor)
                })
            })
            .collect::<Result<Vec<TokenStream>>>()?;

        let variants_take_rest = enumerated_variants
            .map(|(index, Variant { fields, ident, .. })| {
                construct_take_rest(fields)
                    .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
            })
            .collect::<Result<Vec<TokenStream>>>()?;

        (!variants.is_empty())
            .then(|| {
                // `arbitrary` receives `u: &mut Unstructured`; `arbitrary_take_rest` owns
                // `u`, so the discriminant is read through `&mut u`.
                let arbitrary = arbitrary_enum_method(
                    recursive_count,
                    quote! { u },
                    &variants,
                    needs_recursive_count,
                );
                let arbitrary_take_rest = arbitrary_enum_method(
                    recursive_count,
                    quote! { &mut u },
                    &variants_take_rest,
                    needs_recursive_count,
                );

                (
                    quote! {
                        fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>)
                            -> arbitrary::Result<Self>
                        {
                            #arbitrary
                        }

                        fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>)
                            -> arbitrary::Result<Self>
                        {
                            #arbitrary_take_rest
                        }
                    },
                    needs_recursive_count,
                )
            })
            .ok_or_else(|| {
                Error::new_spanned(
                    enum_name,
                    "Enum must have at least one variant, that is not skipped",
                )
            })
    }

    let ident = &input.ident;
    // Struct-like bodies always go through `with_recursive_count`.
    let needs_recursive_count = true;
    match &input.data {
        Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)
            .map(|ts| (ts, needs_recursive_count)),
        // NOTE(review): this emits a literal naming every union field, but a Rust union
        // literal must name exactly one field, so multi-field unions likely fail to
        // compile. Confirm the intended union support.
        Data::Union(data) => arbitrary_structlike(
            &Fields::Named(data.fields.clone()),
            ident,
            lifetime,
            recursive_count,
        )
        .map(|ts| (ts, needs_recursive_count)),
        Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count),
    }
}
338
339fn construct(
340 fields: &Fields,
341 ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
342) -> Result<TokenStream> {
343 let output = match fields {
344 Fields::Named(names) => {
345 let names: Vec<TokenStream> = names
346 .named
347 .iter()
348 .enumerate()
349 .map(|(i, f)| {
350 let name = f.ident.as_ref().unwrap();
351 ctor(i, f).map(|ctor| quote! { #name: #ctor })
352 })
353 .collect::<Result<_>>()?;
354 quote! { { #(#names,)* } }
355 }
356 Fields::Unnamed(names) => {
357 let names: Vec<TokenStream> = names
358 .unnamed
359 .iter()
360 .enumerate()
361 .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor }))
362 .collect::<Result<_>>()?;
363 quote! { ( #(#names),* ) }
364 }
365 Fields::Unit => quote!(),
366 };
367 Ok(output)
368}
369
370fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
371 construct(fields, |idx, field| {
372 determine_field_constructor(field).map(|field_constructor| match field_constructor {
373 FieldConstructor::Default => quote!(::core::default::Default::default()),
374 FieldConstructor::Arbitrary => {
375 if idx + 1 == fields.len() {
376 quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
377 } else {
378 quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
379 }
380 }
381 FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?),
382 FieldConstructor::Value(value) => quote!(#value),
383 })
384 })
385}
386
387fn gen_size_hint_method(input: &DeriveInput, needs_recursive_count: bool) -> Result<TokenStream> {
388 let size_hint_fields = |fields: &Fields| {
389 fields
390 .iter()
391 .map(|f| {
392 let ty = &f.ty;
393 determine_field_constructor(f).map(|field_constructor| {
394 match field_constructor {
395 FieldConstructor::Default | FieldConstructor::Value(_) => {
396 quote!(Ok((0, Some(0))))
397 }
398 FieldConstructor::Arbitrary => {
399 quote! { <#ty as arbitrary::Arbitrary>::try_size_hint(depth) }
400 }
401
402 FieldConstructor::With(_) => {
406 quote! { Ok((::core::mem::size_of::<#ty>(), None)) }
407 }
408 }
409 })
410 })
411 .collect::<Result<Vec<TokenStream>>>()
412 .map(|hints| {
413 quote! {
414 Ok(arbitrary::size_hint::and_all(&[
415 #( #hints? ),*
416 ]))
417 }
418 })
419 };
420 let size_hint_structlike = |fields: &Fields| {
421 assert!(needs_recursive_count);
422 size_hint_fields(fields).map(|hint| {
423 quote! {
424 #[inline]
425 fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
426 Self::try_size_hint(depth).unwrap_or_default()
427 }
428
429 #[inline]
430 fn try_size_hint(depth: usize)
431 -> ::core::result::Result<
432 (usize, ::core::option::Option<usize>),
433 arbitrary::MaxRecursionReached,
434 >
435 {
436 arbitrary::size_hint::try_recursion_guard(depth, |depth| #hint)
437 }
438 }
439 })
440 };
441 match &input.data {
442 Data::Struct(data) => size_hint_structlike(&data.fields),
443 Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
444 Data::Enum(data) => data
445 .variants
446 .iter()
447 .filter(not_skipped)
448 .map(|Variant { fields, .. }| {
449 if !needs_recursive_count {
450 assert!(fields.is_empty());
451 }
452 size_hint_fields(fields)
455 })
456 .collect::<Result<Vec<TokenStream>>>()
457 .map(|variants| {
458 if needs_recursive_count {
459 quote! {
462 fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
463 Self::try_size_hint(depth).unwrap_or_default()
464 }
465 #[inline]
466 fn try_size_hint(depth: usize)
467 -> ::core::result::Result<
468 (usize, ::core::option::Option<usize>),
469 arbitrary::MaxRecursionReached,
470 >
471 {
472 Ok(arbitrary::size_hint::and(
473 <u32 as arbitrary::Arbitrary>::size_hint(depth),
474 arbitrary::size_hint::try_recursion_guard(depth, |depth| {
475 Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ]))
476 })?,
477 ))
478 }
479 }
480 } else {
481 quote! {
484 fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
485 <u32 as arbitrary::Arbitrary>::size_hint(depth)
486 }
487 }
488 }
489 }),
490 }
491}
492
493fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
494 let ctor = match determine_field_constructor(field)? {
495 FieldConstructor::Default => quote!(::core::default::Default::default()),
496 FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
497 FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
498 FieldConstructor::Value(value) => quote!(#value),
499 };
500 Ok(ctor)
501}
502
503fn check_variant_attrs(variant: &Variant) -> Result<()> {
504 for attr in &variant.attrs {
505 if attr.path().is_ident(ARBITRARY_ATTRIBUTE_NAME) {
506 return Err(Error::new_spanned(
507 attr,
508 format!(
509 "invalid `{}` attribute. it is unsupported on enum variants. try applying it to a field of the variant instead",
510 ARBITRARY_ATTRIBUTE_NAME
511 ),
512 ));
513 }
514 }
515 Ok(())
516}