1extern crate proc_macro;
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::*;
6
7mod container_attributes;
8mod field_attributes;
9mod variant_attributes;
10
11use container_attributes::ContainerAttributes;
12use field_attributes::{determine_field_constructor, FieldConstructor};
13use variant_attributes::not_skipped;
14
/// Name of the helper attribute (`#[arbitrary(...)]`) this derive recognizes.
const ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
/// Lifetime introduced on the generated impl and used for `Unstructured<'arbitrary>`.
const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
17
18#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
19pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
20 let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
21 expand_derive_arbitrary(input)
22 .unwrap_or_else(syn::Error::into_compile_error)
23 .into()
24}
25
26fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
27 let container_attrs = ContainerAttributes::from_derive_input(&input)?;
28
29 let (lifetime_without_bounds, lifetime_with_bounds) =
30 build_arbitrary_lifetime(input.generics.clone());
31
32 let recursive_count = syn::Ident::new(
33 &format!("RECURSIVE_COUNT_{}", input.ident),
34 Span::call_site(),
35 );
36
37 let arbitrary_method =
38 gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
39 let size_hint_method = gen_size_hint_method(&input)?;
40 let name = input.ident;
41
42 let generics = apply_trait_bounds(
44 input.generics,
45 lifetime_without_bounds.clone(),
46 &container_attrs,
47 )?;
48
49 let mut generics_with_lifetime = generics.clone();
51 generics_with_lifetime
52 .params
53 .push(GenericParam::Lifetime(lifetime_with_bounds));
54 let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
55
56 let (_, ty_generics, where_clause) = generics.split_for_impl();
58
59 Ok(quote! {
60 const _: () = {
61 ::std::thread_local! {
62 #[allow(non_upper_case_globals)]
63 static #recursive_count: ::core::cell::Cell<u32> = ::core::cell::Cell::new(0);
64 }
65
66 #[automatically_derived]
67 impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
68 #arbitrary_method
69 #size_hint_method
70 }
71 };
72 })
73}
74
75fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam) {
78 let lifetime_without_bounds =
79 LifetimeParam::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));
80 let mut lifetime_with_bounds = lifetime_without_bounds.clone();
81
82 for param in generics.params.iter() {
83 if let GenericParam::Lifetime(lifetime_def) = param {
84 lifetime_with_bounds
85 .bounds
86 .push(lifetime_def.lifetime.clone());
87 }
88 }
89
90 (lifetime_without_bounds, lifetime_with_bounds)
91}
92
93fn apply_trait_bounds(
94 mut generics: Generics,
95 lifetime: LifetimeParam,
96 container_attrs: &ContainerAttributes,
97) -> Result<Generics> {
98 if let Some(config_bounds) = &container_attrs.bounds {
100 let mut config_bounds_applied = 0;
101 for param in generics.params.iter_mut() {
102 if let GenericParam::Type(type_param) = param {
103 if let Some(replacement) = config_bounds
104 .iter()
105 .flatten()
106 .find(|p| p.ident == type_param.ident)
107 {
108 *type_param = replacement.clone();
109 config_bounds_applied += 1;
110 } else {
111 type_param.bounds = Default::default();
114 type_param.default = None;
115 }
116 }
117 }
118 let config_bounds_supplied = config_bounds
119 .iter()
120 .map(|bounds| bounds.len())
121 .sum::<usize>();
122 if config_bounds_applied != config_bounds_supplied {
123 return Err(Error::new(
124 Span::call_site(),
125 format!(
126 "invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
127 ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
128 ),
129 ));
130 }
131 Ok(generics)
132 } else {
133 Ok(add_trait_bounds(generics, lifetime))
135 }
136}
137
138fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics {
140 for param in generics.params.iter_mut() {
141 if let GenericParam::Type(type_param) = param {
142 type_param
143 .bounds
144 .push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
145 }
146 }
147 generics
148}
149
150fn with_recursive_count_guard(
151 recursive_count: &syn::Ident,
152 expr: impl quote::ToTokens,
153) -> impl quote::ToTokens {
154 quote! {
155 let guard_against_recursion = u.is_empty();
156 if guard_against_recursion {
157 #recursive_count.with(|count| {
158 if count.get() > 0 {
159 return Err(arbitrary::Error::NotEnoughData);
160 }
161 count.set(count.get() + 1);
162 Ok(())
163 })?;
164 }
165
166 let result = (|| { #expr })();
167
168 if guard_against_recursion {
169 #recursive_count.with(|count| {
170 count.set(count.get() - 1);
171 });
172 }
173
174 result
175 }
176}
177
/// Generates the `arbitrary` and `arbitrary_take_rest` methods of the impl.
///
/// Structs and unions build every field in declaration order. Enums first draw a `u32`
/// to select one non-skipped variant, then build that variant's fields. Every generated
/// body is wrapped in `with_recursive_count_guard`.
///
/// # Errors
/// Propagates field-constructor errors. Also rejects `arbitrary` attributes placed on a
/// non-skipped enum variant, and enums whose variants are all skipped (or that have none).
fn gen_arbitrary_method(
    input: &DeriveInput,
    lifetime: LifetimeParam,
    recursive_count: &syn::Ident,
) -> Result<TokenStream> {
    // Generates both methods for a struct-like shape: `Ident { .. }`, `Ident(..)` or `Ident`.
    fn arbitrary_structlike(
        fields: &Fields,
        ident: &syn::Ident,
        lifetime: LifetimeParam,
        recursive_count: &syn::Ident,
    ) -> Result<TokenStream> {
        // `arbitrary` borrows `u: &mut Unstructured` for every field.
        let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
        let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });

        // `arbitrary_take_rest` owns `u`; its last `Arbitrary` field consumes what remains.
        let arbitrary_take_rest = construct_take_rest(fields)?;
        let take_rest_body =
            with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });

        Ok(quote! {
            fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
                #body
            }

            fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
                #take_rest_body
            }
        })
    }

    // One match arm: `<index> => Enum::Variant <ctor>`.
    fn arbitrary_variant(
        index: u64,
        enum_name: &Ident,
        variant_name: &Ident,
        ctor: TokenStream,
    ) -> TokenStream {
        quote! { #index => #enum_name::#variant_name #ctor }
    }

    // Wraps the variant arms in a `match` on a random index in `0..count`.
    // `unstructured` is the expression used to read the discriminant (`u` or `&mut u`).
    fn arbitrary_enum_method(
        recursive_count: &syn::Ident,
        unstructured: TokenStream,
        variants: &[TokenStream],
    ) -> impl quote::ToTokens {
        let count = variants.len() as u64;
        // Multiply-and-shift maps a random `u32` onto `0..count` without division
        // (Lemire's technique; slightly biased). Since the index is always below `count`,
        // the trailing `_` arm can never be taken.
        with_recursive_count_guard(
            recursive_count,
            quote! {
                Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count) >> 32 {
                    #(#variants,)*
                    _ => unreachable!()
                })
            },
        )
    }

    // Generates both methods for an enum, considering only non-skipped variants.
    fn arbitrary_enum(
        DataEnum { variants, .. }: &DataEnum,
        enum_name: &Ident,
        lifetime: LifetimeParam,
        recursive_count: &syn::Ident,
    ) -> Result<TokenStream> {
        let filtered_variants = variants.iter().filter(not_skipped);

        // Reject `#[arbitrary(...)]` on the variants themselves before generating anything.
        filtered_variants
            .clone()
            .try_for_each(check_variant_attrs)?;

        // Indices are assigned after filtering, so they are contiguous `0..count`.
        let enumerated_variants = filtered_variants
            .enumerate()
            .map(|(index, variant)| (index as u64, variant));

        let variants = enumerated_variants
            .clone()
            .map(|(index, Variant { fields, ident, .. })| {
                construct(fields, |_, field| gen_constructor_for_field(field))
                    .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
            })
            .collect::<Result<Vec<TokenStream>>>()?;

        let variants_take_rest = enumerated_variants
            .map(|(index, Variant { fields, ident, .. })| {
                construct_take_rest(fields)
                    .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
            })
            .collect::<Result<Vec<TokenStream>>>()?;

        // With zero selectable variants the multiply-shift index would be meaningless,
        // so this case is reported as an error instead.
        (!variants.is_empty())
            .then(|| {
                let arbitrary = arbitrary_enum_method(recursive_count, quote! { u }, &variants);
                // In `arbitrary_take_rest`, the discriminant is read through `&mut u` before the
                // chosen variant's fields, which may consume `u` itself.
                let arbitrary_take_rest = arbitrary_enum_method(
                    recursive_count,
                    quote! { &mut u },
                    &variants_take_rest,
                );

                quote! {
                    fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
                        #arbitrary
                    }

                    fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
                        #arbitrary_take_rest
                    }
                }
            })
            .ok_or_else(|| Error::new_spanned(
                enum_name,
                "Enum must have at least one variant, that is not skipped"
            ))
    }

    let ident = &input.ident;
    match &input.data {
        Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count),
        // NOTE(review): unions are built like a struct with named fields, i.e. one
        // initializer per field. Rust union literals must name exactly one field, so this
        // presumably only compiles for single-field unions — TODO confirm.
        Data::Union(data) => arbitrary_structlike(
            &Fields::Named(data.fields.clone()),
            ident,
            lifetime,
            recursive_count,
        ),
        Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count),
    }
}
309
310fn construct(
311 fields: &Fields,
312 ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
313) -> Result<TokenStream> {
314 let output = match fields {
315 Fields::Named(names) => {
316 let names: Vec<TokenStream> = names
317 .named
318 .iter()
319 .enumerate()
320 .map(|(i, f)| {
321 let name = f.ident.as_ref().unwrap();
322 ctor(i, f).map(|ctor| quote! { #name: #ctor })
323 })
324 .collect::<Result<_>>()?;
325 quote! { { #(#names,)* } }
326 }
327 Fields::Unnamed(names) => {
328 let names: Vec<TokenStream> = names
329 .unnamed
330 .iter()
331 .enumerate()
332 .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor }))
333 .collect::<Result<_>>()?;
334 quote! { ( #(#names),* ) }
335 }
336 Fields::Unit => quote!(),
337 };
338 Ok(output)
339}
340
341fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
342 construct(fields, |idx, field| {
343 determine_field_constructor(field).map(|field_constructor| match field_constructor {
344 FieldConstructor::Default => quote!(::core::default::Default::default()),
345 FieldConstructor::Arbitrary => {
346 if idx + 1 == fields.len() {
347 quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
348 } else {
349 quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
350 }
351 }
352 FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?),
353 FieldConstructor::Value(value) => quote!(#value),
354 })
355 })
356}
357
358fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
359 let size_hint_fields = |fields: &Fields| {
360 fields
361 .iter()
362 .map(|f| {
363 let ty = &f.ty;
364 determine_field_constructor(f).map(|field_constructor| {
365 match field_constructor {
366 FieldConstructor::Default | FieldConstructor::Value(_) => {
367 quote!(Ok((0, Some(0))))
368 }
369 FieldConstructor::Arbitrary => {
370 quote! { <#ty as arbitrary::Arbitrary>::try_size_hint(depth) }
371 }
372
373 FieldConstructor::With(_) => {
377 quote! { Ok((::core::mem::size_of::<#ty>(), None)) }
378 }
379 }
380 })
381 })
382 .collect::<Result<Vec<TokenStream>>>()
383 .map(|hints| {
384 quote! {
385 Ok(arbitrary::size_hint::and_all(&[
386 #( #hints? ),*
387 ]))
388 }
389 })
390 };
391 let size_hint_structlike = |fields: &Fields| {
392 size_hint_fields(fields).map(|hint| {
393 quote! {
394 #[inline]
395 fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
396 Self::try_size_hint(depth).unwrap_or_default()
397 }
398
399 #[inline]
400 fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
401 arbitrary::size_hint::try_recursion_guard(depth, |depth| #hint)
402 }
403 }
404 })
405 };
406 match &input.data {
407 Data::Struct(data) => size_hint_structlike(&data.fields),
408 Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
409 Data::Enum(data) => data
410 .variants
411 .iter()
412 .filter(not_skipped)
413 .map(|Variant { fields, .. }| {
414 size_hint_fields(fields)
417 })
418 .collect::<Result<Vec<TokenStream>>>()
419 .map(|variants| {
420 quote! {
421 fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
422 Self::try_size_hint(depth).unwrap_or_default()
423 }
424 #[inline]
425 fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
426 Ok(arbitrary::size_hint::and(
427 <u32 as arbitrary::Arbitrary>::try_size_hint(depth)?,
428 arbitrary::size_hint::try_recursion_guard(depth, |depth| {
429 Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ]))
430 })?,
431 ))
432 }
433 }
434 }),
435 }
436}
437
438fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
439 let ctor = match determine_field_constructor(field)? {
440 FieldConstructor::Default => quote!(::core::default::Default::default()),
441 FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
442 FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
443 FieldConstructor::Value(value) => quote!(#value),
444 };
445 Ok(ctor)
446}
447
448fn check_variant_attrs(variant: &Variant) -> Result<()> {
449 for attr in &variant.attrs {
450 if attr.path().is_ident(ARBITRARY_ATTRIBUTE_NAME) {
451 return Err(Error::new_spanned(
452 attr,
453 format!(
454 "invalid `{}` attribute. it is unsupported on enum variants. try applying it to a field of the variant instead",
455 ARBITRARY_ATTRIBUTE_NAME
456 ),
457 ));
458 }
459 }
460 Ok(())
461}