bytemuck_derive/
traits.rs

1#![allow(unused_imports)]
2use std::{cmp, convert::TryFrom};
3
4use proc_macro2::{Ident, Span, TokenStream, TokenTree};
5use quote::{quote, ToTokens};
6use syn::{
7  parse::{Parse, ParseStream, Parser},
8  punctuated::Punctuated,
9  spanned::Spanned,
10  Result, *,
11};
12
/// Returns early from the enclosing function with an `Err` wrapping an
/// `Error` (resolved through the `syn::*` glob import).
///
/// * `bail!(msg)` attaches the error to `Span::call_site()`. `msg` is sliced
///   with `[..]`, so it has to be a string-like value that supports that.
/// * `bail!(msg => tokens)` attaches the error to `tokens` through
///   `Error::new_spanned`.
///
/// Both forms accept an optional trailing comma.
macro_rules! bail {
  // Message only: blame the call site of the derive.
  ($msg:expr $(,)?) => {
    return Err(Error::new(Span::call_site(), &$msg[..]))
  };

  // Message plus the tokens whose span should carry the error.
  ( $msg:expr => $span_to_blame:expr $(,)? ) => {
    return Err(Error::new_spanned(&$span_to_blame, $msg))
  };
}
22
23pub trait Derivable {
24  fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>;
25  fn implies_trait(_crate_name: &TokenStream) -> Option<TokenStream> {
26    None
27  }
28  fn asserts(
29    _input: &DeriveInput, _crate_name: &TokenStream,
30  ) -> Result<TokenStream> {
31    Ok(quote!())
32  }
33  fn check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()> {
34    Ok(())
35  }
36  fn trait_impl(
37    _input: &DeriveInput, _crate_name: &TokenStream,
38  ) -> Result<(TokenStream, TokenStream)> {
39    Ok((quote!(), quote!()))
40  }
41  fn requires_where_clause() -> bool {
42    true
43  }
44  fn explicit_bounds_attribute_name() -> Option<&'static str> {
45    None
46  }
47
48  /// If this trait has a custom meaning for "perfect derive", this function
49  /// should be overridden to return `Some`.
50  ///
51  /// The default is "the fields of a struct; unions and enums not supported".
52  fn perfect_derive_fields(_input: &DeriveInput) -> Option<Fields> {
53    None
54  }
55}
56
57pub struct Pod;
58
59impl Derivable for Pod {
60  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
61    Ok(syn::parse_quote!(#crate_name::Pod))
62  }
63
64  fn asserts(
65    input: &DeriveInput, crate_name: &TokenStream,
66  ) -> Result<TokenStream> {
67    let repr = get_repr(&input.attrs)?;
68
69    let completly_packed =
70      repr.packed == Some(1) || repr.repr == Repr::Transparent;
71
72    if !completly_packed && !input.generics.params.is_empty() {
73      bail!("\
74        Pod requires cannot be derived for non-packed types containing \
75        generic parameters because the padding requirements can't be verified \
76        for generic non-packed structs\
77      " => input.generics.params.first().unwrap());
78    }
79
80    match &input.data {
81      Data::Struct(_) => {
82        let assert_no_padding = if !completly_packed {
83          Some(generate_assert_no_padding(input, None, "Pod")?)
84        } else {
85          None
86        };
87        let assert_fields_are_pod = generate_fields_are_trait(
88          input,
89          None,
90          Self::ident(input, crate_name)?,
91        )?;
92
93        Ok(quote!(
94          #assert_no_padding
95          #assert_fields_are_pod
96        ))
97      }
98      Data::Enum(_) => bail!("Deriving Pod is not supported for enums"),
99      Data::Union(_) => bail!("Deriving Pod is not supported for unions"),
100    }
101  }
102
103  fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
104    let repr = get_repr(attributes)?;
105    match repr.repr {
106      Repr::C => Ok(()),
107      Repr::Transparent => Ok(()),
108      _ => {
109        bail!("Pod requires the type to be #[repr(C)] or #[repr(transparent)]")
110      }
111    }
112  }
113}
114
115pub struct AnyBitPattern;
116
117impl Derivable for AnyBitPattern {
118  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
119    Ok(syn::parse_quote!(#crate_name::AnyBitPattern))
120  }
121
122  fn implies_trait(crate_name: &TokenStream) -> Option<TokenStream> {
123    Some(quote!(#crate_name::Zeroable))
124  }
125
126  fn asserts(
127    input: &DeriveInput, crate_name: &TokenStream,
128  ) -> Result<TokenStream> {
129    match &input.data {
130      Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
131      Data::Struct(_) => {
132        generate_fields_are_trait(input, None, Self::ident(input, crate_name)?)
133      }
134      Data::Enum(_) => {
135        bail!("Deriving AnyBitPattern is not supported for enums")
136      }
137    }
138  }
139}
140
/// Marker type carrying the [`Derivable`] implementation for `Zeroable`.
pub struct Zeroable;
142
143/// Helper function to get the variant with discriminant zero (implicit or
144/// explicit).
145fn get_zero_variant(enum_: &DataEnum) -> Result<Option<&Variant>> {
146  let iter = VariantDiscriminantIterator::new(enum_.variants.iter());
147  let mut zero_variant = None;
148  for res in iter {
149    let (discriminant, variant) = res?;
150    if discriminant == 0 {
151      zero_variant = Some(variant);
152      break;
153    }
154  }
155  Ok(zero_variant)
156}
157
158impl Derivable for Zeroable {
159  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
160    Ok(syn::parse_quote!(#crate_name::Zeroable))
161  }
162
163  fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
164    let repr = get_repr(attributes)?;
165    match ty {
166      Data::Struct(_) => Ok(()),
167      Data::Enum(_) => {
168        if !matches!(
169          repr.repr,
170          Repr::C | Repr::Integer(_) | Repr::CWithDiscriminant(_)
171        ) {
172          bail!("Zeroable requires the enum to be an explicit #[repr(Int)] and/or #[repr(C)]")
173        }
174
175        // We ensure there is a zero variant in `asserts`, since it is needed
176        // there anyway.
177
178        Ok(())
179      }
180      Data::Union(_) => Ok(()),
181    }
182  }
183
184  fn asserts(
185    input: &DeriveInput, crate_name: &TokenStream,
186  ) -> Result<TokenStream> {
187    match &input.data {
188      Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
189      Data::Struct(_) => {
190        generate_fields_are_trait(input, None, Self::ident(input, crate_name)?)
191      }
192      Data::Enum(enum_) => {
193        let zero_variant = get_zero_variant(enum_)?;
194
195        if zero_variant.is_none() {
196          bail!("No variant's discriminant is 0")
197        };
198
199        generate_fields_are_trait(
200          input,
201          zero_variant,
202          Self::ident(input, crate_name)?,
203        )
204      }
205    }
206  }
207
208  fn explicit_bounds_attribute_name() -> Option<&'static str> {
209    Some("zeroable")
210  }
211
212  fn perfect_derive_fields(input: &DeriveInput) -> Option<Fields> {
213    match &input.data {
214      Data::Struct(struct_) => Some(struct_.fields.clone()),
215      Data::Enum(enum_) => {
216        // We handle `Err` returns from `get_zero_variant` in `asserts`, so it's
217        // fine to just ignore them here and return `None`.
218        // Otherwise, we clone the `fields` of the zero variant (if any).
219        Some(get_zero_variant(enum_).ok()??.fields.clone())
220      }
221      Data::Union(_) => None,
222    }
223  }
224}
225
226pub struct NoUninit;
227
228impl Derivable for NoUninit {
229  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
230    Ok(syn::parse_quote!(#crate_name::NoUninit))
231  }
232
233  fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
234    let repr = get_repr(attributes)?;
235    match ty {
236      Data::Struct(_) => match repr.repr {
237        Repr::C | Repr::Transparent => Ok(()),
238        _ => bail!("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"),
239      },
240      Data::Enum(DataEnum { variants,.. }) => {
241        if !enum_has_fields(variants.iter()) {
242          if matches!(repr.repr, Repr::C | Repr::Integer(_)) {
243            Ok(())
244          } else {
245            bail!("NoUninit requires the enum to be #[repr(C)] or #[repr(Int)]")
246          }
247        } else if matches!(repr.repr, Repr::Rust) {
248          bail!("NoUninit requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
249        } else {
250          Ok(())
251        }
252      },
253      Data::Union(_) => bail!("NoUninit can only be derived on enums and structs")
254    }
255  }
256
257  fn asserts(
258    input: &DeriveInput, crate_name: &TokenStream,
259  ) -> Result<TokenStream> {
260    if !input.generics.params.is_empty() {
261      bail!("NoUninit cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
262    }
263
264    match &input.data {
265      Data::Struct(DataStruct { .. }) => {
266        let assert_no_padding =
267          generate_assert_no_padding(&input, None, "NoUninit")?;
268        let assert_fields_are_no_padding = generate_fields_are_trait(
269          &input,
270          None,
271          Self::ident(input, crate_name)?,
272        )?;
273
274        Ok(quote!(
275            #assert_no_padding
276            #assert_fields_are_no_padding
277        ))
278      }
279      Data::Enum(DataEnum { variants, .. }) => {
280        if enum_has_fields(variants.iter()) {
281          // There are two different C representations for enums with fields:
282          // There's `#[repr(C)]`/`[repr(C, int)]` and `#[repr(int)]`.
283          // `#[repr(C)]` is equivalent to a struct containing the discriminant
284          // and a union of structs representing each variant's fields.
285          // `#[repr(int)]` is equivalent to a union containing structs of the
286          // discriminant and the fields.
287          //
288          // See https://doc.rust-lang.org/reference/type-layout.html#r-layout.repr.c.adt
289          // and https://doc.rust-lang.org/reference/type-layout.html#r-layout.repr.primitive.adt
290          //
291          // In practice the only difference between the two is whether and
292          // where padding bytes are placed. For `#[repr(C)]` enums, the first
293          // enum fields of all variants start at the same location (the first
294          // byte in the union). For `#[repr(int)]` enums, the structs
295          // representing each variant are layed out individually and padding
296          // does not depend on other variants, but only on the size of the
297          // discriminant and the alignment of the first field. The location of
298          // the first field might differ between variants, potentially
299          // resulting in less padding or padding placed later in the enum.
300          //
301          // The `NoUninit` derive macro asserts that no padding exists by
302          // removing all padding with `#[repr(packed)]` and checking that this
303          // doesn't change the size. Since the location and presence of
304          // padding bytes is the only difference between the two
305          // representations and we're removing all padding bytes, the resuling
306          // layout would identical for both representations. This means that
307          // we can just pick one of the representations and don't have to
308          // implement desugaring for both. We chose to implement the
309          // desugaring for `#[repr(int)]`.
310
311          let enum_discriminant = generate_enum_discriminant(input)?;
312          let variant_assertions = variants
313            .iter()
314            .map(|variant| {
315              let assert_no_padding =
316                generate_assert_no_padding(&input, Some(variant), "NoUninit")?;
317              let assert_fields_are_no_padding = generate_fields_are_trait(
318                &input,
319                Some(variant),
320                Self::ident(input, crate_name)?,
321              )?;
322
323              Ok(quote!(
324                  #assert_no_padding
325                  #assert_fields_are_no_padding
326              ))
327            })
328            .collect::<Result<Vec<_>>>()?;
329          Ok(quote! {
330            const _: () = {
331              #enum_discriminant
332              #(#variant_assertions)*
333            };
334          })
335        } else {
336          Ok(quote!())
337        }
338      }
339      Data::Union(_) => bail!("NoUninit cannot be derived for unions"), /* shouldn't be possible since we already error in attribute check for this case */
340    }
341  }
342
343  fn trait_impl(
344    _input: &DeriveInput, _crate_name: &TokenStream,
345  ) -> Result<(TokenStream, TokenStream)> {
346    Ok((quote!(), quote!()))
347  }
348}
349
350pub struct CheckedBitPattern;
351
352impl Derivable for CheckedBitPattern {
353  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
354    Ok(syn::parse_quote!(#crate_name::CheckedBitPattern))
355  }
356
357  fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
358    let repr = get_repr(attributes)?;
359    match ty {
360      Data::Struct(_) => match repr.repr {
361        Repr::C | Repr::Transparent => Ok(()),
362        _ => bail!("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"),
363      },
364      Data::Enum(DataEnum { variants,.. }) => {
365        if !enum_has_fields(variants.iter()){
366          if matches!(repr.repr, Repr::C | Repr::Integer(_)) {
367            Ok(())
368          } else {
369            bail!("CheckedBitPattern requires the enum to be #[repr(C)] or #[repr(Int)]")
370          }
371        } else if matches!(repr.repr, Repr::Rust) {
372          bail!("CheckedBitPattern requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
373        } else {
374          Ok(())
375        }
376      }
377      Data::Union(_) => bail!("CheckedBitPattern can only be derived on enums and structs")
378    }
379  }
380
381  fn asserts(
382    input: &DeriveInput, crate_name: &TokenStream,
383  ) -> Result<TokenStream> {
384    if !input.generics.params.is_empty() {
385      bail!("CheckedBitPattern cannot be derived for structs containing generic parameters");
386    }
387
388    match &input.data {
389      Data::Struct(DataStruct { .. }) => {
390        let assert_fields_are_maybe_pod = generate_fields_are_trait(
391          &input,
392          None,
393          Self::ident(input, crate_name)?,
394        )?;
395
396        Ok(assert_fields_are_maybe_pod)
397      }
398      // nothing needed, already guaranteed OK by NoUninit.
399      Data::Enum(_) => Ok(quote!()),
400      Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
401    }
402  }
403
404  fn trait_impl(
405    input: &DeriveInput, crate_name: &TokenStream,
406  ) -> Result<(TokenStream, TokenStream)> {
407    match &input.data {
408      Data::Struct(DataStruct { fields, .. }) => {
409        generate_checked_bit_pattern_struct(
410          &input.ident,
411          fields,
412          &input.attrs,
413          crate_name,
414        )
415      }
416      Data::Enum(DataEnum { variants, .. }) => {
417        generate_checked_bit_pattern_enum(input, variants, crate_name)
418      }
419      Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
420    }
421  }
422}
423
424pub struct TransparentWrapper;
425
426impl TransparentWrapper {
427  fn get_wrapper_type(
428    attributes: &[Attribute], fields: &Fields,
429  ) -> Option<TokenStream> {
430    let transparent_param = get_simple_attr(attributes, "transparent");
431    transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
432      let mut types = get_field_types(&fields);
433      let first_type = types.next();
434      if let Some(_) = types.next() {
435        // can't guess param type if there is more than one field
436        return None;
437      } else {
438        first_type.map(|ty| ty.to_token_stream())
439      }
440    })
441  }
442}
443
444impl Derivable for TransparentWrapper {
445  fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
446    let fields = get_struct_fields(input)?;
447
448    let ty = match Self::get_wrapper_type(&input.attrs, &fields) {
449      Some(ty) => ty,
450      None => bail!(
451        "\
452        when deriving TransparentWrapper for a struct with more than one field \
453        you need to specify the transparent field using #[transparent(T)]\
454      "
455      ),
456    };
457
458    Ok(syn::parse_quote!(#crate_name::TransparentWrapper<#ty>))
459  }
460
461  fn asserts(
462    input: &DeriveInput, crate_name: &TokenStream,
463  ) -> Result<TokenStream> {
464    let (impl_generics, _ty_generics, where_clause) =
465      input.generics.split_for_impl();
466    let fields = get_struct_fields(input)?;
467    let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
468      Some(wrapped_type) => wrapped_type.to_string(),
469      None => unreachable!(), /* other code will already reject this derive */
470    };
471    let mut wrapped_field_ty = None;
472    let mut nonwrapped_field_tys = vec![];
473    for field in fields.iter() {
474      let field_ty = &field.ty;
475      if field_ty.to_token_stream().to_string() == wrapped_type {
476        if wrapped_field_ty.is_some() {
477          bail!(
478            "TransparentWrapper can only have one field of the wrapped type"
479          );
480        }
481        wrapped_field_ty = Some(field_ty);
482      } else {
483        nonwrapped_field_tys.push(field_ty);
484      }
485    }
486    if let Some(wrapped_field_ty) = wrapped_field_ty {
487      Ok(quote!(
488        const _: () = {
489          #[repr(transparent)]
490          #[allow(clippy::multiple_bound_locations)]
491          struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause;
492          fn assert_zeroable<Z: #crate_name::Zeroable>() {}
493          #[allow(clippy::multiple_bound_locations)]
494          fn check #impl_generics () #where_clause {
495            #(
496              assert_zeroable::<#nonwrapped_field_tys>();
497            )*
498          }
499        };
500      ))
501    } else {
502      bail!("TransparentWrapper must have one field of the wrapped type")
503    }
504  }
505
506  fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
507    let repr = get_repr(attributes)?;
508
509    match repr.repr {
510      Repr::Transparent => Ok(()),
511      _ => {
512        bail!(
513          "TransparentWrapper requires the struct to be #[repr(transparent)]"
514        )
515      }
516    }
517  }
518
519  fn requires_where_clause() -> bool {
520    false
521  }
522}
523
524pub struct Contiguous;
525
526impl Derivable for Contiguous {
527  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
528    Ok(syn::parse_quote!(#crate_name::Contiguous))
529  }
530
531  fn trait_impl(
532    input: &DeriveInput, _crate_name: &TokenStream,
533  ) -> Result<(TokenStream, TokenStream)> {
534    let repr = get_repr(&input.attrs)?;
535
536    let integer_ty = if let Some(integer_ty) = repr.repr.as_integer() {
537      integer_ty
538    } else {
539      bail!("Contiguous requires the enum to be #[repr(Int)]");
540    };
541
542    let variants = get_enum_variants(input)?;
543    if enum_has_fields(variants.clone()) {
544      return Err(Error::new_spanned(
545        &input,
546        "Only fieldless enums are supported",
547      ));
548    }
549
550    let mut variants_with_discriminant =
551      VariantDiscriminantIterator::new(variants);
552
553    let (min, max, count) = variants_with_discriminant.try_fold(
554      (i128::MAX, i128::MIN, 0),
555      |(min, max, count), res| {
556        let (discriminant, _variant) = res?;
557        Ok::<_, Error>((
558          i128::min(min, discriminant),
559          i128::max(max, discriminant),
560          count + 1,
561        ))
562      },
563    )?;
564
565    if max - min != count - 1 {
566      bail! {
567        "Contiguous requires the enum discriminants to be contiguous",
568      }
569    }
570
571    let min_lit = LitInt::new(&format!("{}", min), input.span());
572    let max_lit = LitInt::new(&format!("{}", max), input.span());
573
574    // `from_integer` and `into_integer` are usually provided by the trait's
575    // default implementation. We override this implementation because it
576    // goes through `transmute_copy`, which can lead to inefficient assembly as seen in https://github.com/Lokathor/bytemuck/issues/175 .
577
578    Ok((
579      quote!(),
580      quote! {
581          type Int = #integer_ty;
582
583          #[allow(clippy::missing_docs_in_private_items)]
584          const MIN_VALUE: #integer_ty = #min_lit;
585
586          #[allow(clippy::missing_docs_in_private_items)]
587          const MAX_VALUE: #integer_ty = #max_lit;
588
589          #[inline]
590          fn from_integer(value: Self::Int) -> Option<Self> {
591            #[allow(clippy::manual_range_contains)]
592            if Self::MIN_VALUE <= value && value <= Self::MAX_VALUE {
593              Some(unsafe { ::core::mem::transmute(value) })
594            } else {
595              None
596            }
597          }
598
599          #[inline]
600          fn into_integer(self) -> Self::Int {
601              self as #integer_ty
602          }
603      },
604    ))
605  }
606}
607
608fn get_struct_fields(input: &DeriveInput) -> Result<&Fields> {
609  if let Data::Struct(DataStruct { fields, .. }) = &input.data {
610    Ok(fields)
611  } else {
612    bail!("deriving this trait is only supported for structs")
613  }
614}
615
616/// Extract the `Fields` off a `DeriveInput`, or, in the `enum` case, off
617/// those of the `enum_variant`, when provided (e.g., for `Zeroable`).
618///
619/// We purposely allow not providing an `enum_variant` for cases where
620/// the caller wants to reject supporting `enum`s (e.g., `NoPadding`).
621fn get_fields(
622  input: &DeriveInput, enum_variant: Option<&Variant>,
623) -> Result<Fields> {
624  match &input.data {
625    Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()),
626    Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())),
627    Data::Enum(_) => match enum_variant {
628      Some(variant) => Ok(variant.fields.clone()),
629      None => bail!("deriving this trait is not supported for enums"),
630    },
631  }
632}
633
634fn get_enum_variants<'a>(
635  input: &'a DeriveInput,
636) -> Result<impl Iterator<Item = &'a Variant> + Clone + 'a> {
637  if let Data::Enum(DataEnum { variants, .. }) = &input.data {
638    Ok(variants.iter())
639  } else {
640    bail!("deriving this trait is only supported for enums")
641  }
642}
643
644fn get_field_types<'a>(
645  fields: &'a Fields,
646) -> impl Iterator<Item = &'a Type> + 'a {
647  fields.iter().map(|field| &field.ty)
648}
649
650fn generate_checked_bit_pattern_struct(
651  input_ident: &Ident, fields: &Fields, attrs: &[Attribute],
652  crate_name: &TokenStream,
653) -> Result<(TokenStream, TokenStream)> {
654  let bits_ty = Ident::new(&format!("{}Bits", input_ident), input_ident.span());
655
656  let repr = get_repr(attrs)?;
657
658  let field_names = fields
659    .iter()
660    .enumerate()
661    .map(|(i, field)| {
662      field.ident.clone().unwrap_or_else(|| {
663        Ident::new(&format!("field{}", i), input_ident.span())
664      })
665    })
666    .collect::<Vec<_>>();
667  let field_tys = fields.iter().map(|field| &field.ty).collect::<Vec<_>>();
668
669  let field_name = &field_names[..];
670  let field_ty = &field_tys[..];
671
672  Ok((
673    quote! {
674        #[doc = #GENERATED_TYPE_DOCUMENTATION]
675        #repr
676        #[derive(Clone, Copy, #crate_name::AnyBitPattern)]
677        #[allow(missing_docs)]
678        pub struct #bits_ty {
679            #(#field_name: <#field_ty as #crate_name::CheckedBitPattern>::Bits,)*
680        }
681
682        #[allow(unexpected_cfgs)]
683        const _: () = {
684          #[cfg(not(target_arch = "spirv"))]
685          impl ::core::fmt::Debug for #bits_ty {
686            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
687              let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty));
688              #(::core::fmt::DebugStruct::field(&mut debug_struct, ::core::stringify!(#field_name), &{ self.#field_name });)*
689              ::core::fmt::DebugStruct::finish(&mut debug_struct)
690            }
691          }
692        };
693    },
694    quote! {
695        type Bits = #bits_ty;
696
697        #[inline]
698        #[allow(clippy::double_comparisons, unused)]
699        fn is_valid_bit_pattern(bits: &#bits_ty) -> bool {
700            #(<#field_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(&{ bits.#field_name }) && )* true
701        }
702    },
703  ))
704}
705
706fn generate_checked_bit_pattern_enum(
707  input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
708  crate_name: &TokenStream,
709) -> Result<(TokenStream, TokenStream)> {
710  if enum_has_fields(variants.iter()) {
711    generate_checked_bit_pattern_enum_with_fields(input, variants, crate_name)
712  } else {
713    generate_checked_bit_pattern_enum_without_fields(
714      input, variants, crate_name,
715    )
716  }
717}
718
719fn generate_checked_bit_pattern_enum_without_fields(
720  input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
721  crate_name: &TokenStream,
722) -> Result<(TokenStream, TokenStream)> {
723  let span = input.span();
724  let mut variants_with_discriminant =
725    VariantDiscriminantIterator::new(variants.iter());
726
727  let (min, max, count) = variants_with_discriminant.try_fold(
728    (i128::MAX, i128::MIN, 0),
729    |(min, max, count), res| {
730      let (discriminant, _variant) = res?;
731      Ok::<_, Error>((
732        i128::min(min, discriminant),
733        i128::max(max, discriminant),
734        count + 1,
735      ))
736    },
737  )?;
738
739  let check = if count == 0 {
740    quote!(false)
741  } else if max - min == count - 1 {
742    // contiguous range
743    let min_lit = LitInt::new(&format!("{}", min), span);
744    let max_lit = LitInt::new(&format!("{}", max), span);
745
746    quote!(*bits >= #min_lit && *bits <= #max_lit)
747  } else {
748    // not contiguous range, check for each
749    let variant_discriminant_lits =
750      VariantDiscriminantIterator::new(variants.iter())
751        .map(|res| {
752          let (discriminant, _variant) = res?;
753          Ok(LitInt::new(&format!("{}", discriminant), span))
754        })
755        .collect::<Result<Vec<_>>>()?;
756
757    // count is at least 1
758    let first = &variant_discriminant_lits[0];
759    let rest = &variant_discriminant_lits[1..];
760
761    quote!(matches!(*bits, #first #(| #rest )*))
762  };
763
764  let (integer, defs) = get_enum_discriminant(input, crate_name)?;
765  Ok((
766    quote!(#defs),
767    quote! {
768        type Bits = #integer;
769
770        #[inline]
771        #[allow(clippy::double_comparisons)]
772        fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
773            #check
774        }
775    },
776  ))
777}
778
779fn generate_checked_bit_pattern_enum_with_fields(
780  input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
781  crate_name: &TokenStream,
782) -> Result<(TokenStream, TokenStream)> {
783  let representation = get_repr(&input.attrs)?;
784  let vis = &input.vis;
785
786  match representation.repr {
787    Repr::Rust => unreachable!(),
788    Repr::C | Repr::CWithDiscriminant(_) => {
789      let (integer, defs) = get_enum_discriminant(input, crate_name)?;
790      let input_ident = &input.ident;
791
792      let bits_repr = Representation { repr: Repr::C, ..representation };
793
794      // the enum manually re-configured as the actual tagged union it
795      // represents, thus circumventing the requirements rust imposes on
796      // the tag even when using #[repr(C)] enum layout
797      // see: https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields
798      let bits_ty_ident =
799        Ident::new(&format!("{input_ident}Bits"), input.span());
800
801      // the variants union part of the tagged union. These get put into a union
802      // which gets the AnyBitPattern derive applied to it, thus checking
803      // that the fields of the union obey the requriements of AnyBitPattern.
804      // The types that actually go in the union are one more level of
805      // indirection deep: we generate new structs for each variant
806      // (`variant_struct_definitions`) which themselves have the
807      // `CheckedBitPattern` derive applied, thus generating
808      // `{variant_struct_ident}Bits` structs, which are the ones that go
809      // into this union.
810      let variants_union_ident =
811        Ident::new(&format!("{}Variants", input.ident), input.span());
812
813      let variant_struct_idents = variants.iter().map(|v| {
814        Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())
815      });
816
817      let variant_struct_definitions =
818        variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
819          let fields = v.fields.iter().map(|v| &v.ty);
820
821          quote! {
822            #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
823            #[repr(C)]
824            #vis struct #variant_struct_ident(#(#fields),*);
825          }
826        });
827
828      let union_fields = variant_struct_idents
829        .clone()
830        .zip(variants.iter())
831        .map(|(variant_struct_ident, v)| {
832          let variant_struct_bits_ident =
833            Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
834          let field_ident = &v.ident;
835          quote! {
836            #field_ident: #variant_struct_bits_ident
837          }
838        });
839
840      let variant_checks = variant_struct_idents
841        .clone()
842        .zip(VariantDiscriminantIterator::new(variants.iter()))
843        .zip(variants.iter())
844        .map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
845          let (discriminant, _variant) = discriminant?;
846          let discriminant = LitInt::new(&discriminant.to_string(), v.span());
847          let ident = &v.ident;
848          Ok(quote! {
849            #discriminant => {
850              let payload = unsafe { &bits.payload.#ident };
851              <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
852            }
853          })
854        })
855        .collect::<Result<Vec<_>>>()?;
856
857      Ok((
858        quote! {
859          #defs
860
861          #[doc = #GENERATED_TYPE_DOCUMENTATION]
862          #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
863          #bits_repr
864          #vis struct #bits_ty_ident {
865            tag: #integer,
866            payload: #variants_union_ident,
867          }
868
869          #[allow(unexpected_cfgs)]
870          const _: () = {
871            #[cfg(not(target_arch = "spirv"))]
872            impl ::core::fmt::Debug for #bits_ty_ident {
873              fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
874                let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
875                ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", &self.tag);
876                ::core::fmt::DebugStruct::field(&mut debug_struct, "payload", &self.payload);
877                ::core::fmt::DebugStruct::finish(&mut debug_struct)
878              }
879            }
880          };
881
882          #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
883          #[repr(C)]
884          #[allow(non_snake_case)]
885          #vis union #variants_union_ident {
886            #(#union_fields,)*
887          }
888
889          #[allow(unexpected_cfgs)]
890          const _: () = {
891            #[cfg(not(target_arch = "spirv"))]
892            impl ::core::fmt::Debug for #variants_union_ident {
893              fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
894                let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#variants_union_ident));
895                ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
896              }
897            }
898          };
899
900          #(#variant_struct_definitions)*
901        },
902        quote! {
903          type Bits = #bits_ty_ident;
904
905          #[inline]
906          #[allow(clippy::double_comparisons)]
907          fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
908            match bits.tag {
909              #(#variant_checks)*
910              _ => false,
911            }
912          }
913        },
914      ))
915    }
916    Repr::Transparent => {
917      if variants.len() != 1 {
918        bail!("enums with more than one variant cannot be transparent")
919      }
920
921      let variant = &variants[0];
922
923      let bits_ty = Ident::new(&format!("{}Bits", input.ident), input.span());
924      let fields = variant.fields.iter().map(|v| &v.ty);
925
926      Ok((
927        quote! {
928          #[doc = #GENERATED_TYPE_DOCUMENTATION]
929          #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
930          #[repr(C)]
931          #vis struct #bits_ty(#(#fields),*);
932        },
933        quote! {
934          type Bits = <#bits_ty as #crate_name::CheckedBitPattern>::Bits;
935
936          #[inline]
937          #[allow(clippy::double_comparisons)]
938          fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
939            <#bits_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(bits)
940          }
941        },
942      ))
943    }
944    Repr::Integer(integer) => {
945      let bits_repr = Representation { repr: Repr::C, ..representation };
946      let input_ident = &input.ident;
947
948      // the enum manually re-configured as the union it represents. such a
949      // union is the union of variants as a repr(c) struct with the
950      // discriminator type inserted at the beginning. in our case we
951      // union the `Bits` representation of each variant rather than the variant
952      // itself, which we generate via a nested `CheckedBitPattern` derive
953      // on the `variant_struct_definitions` generated below.
954      //
955      // see: https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields
956      let bits_ty_ident =
957        Ident::new(&format!("{input_ident}Bits"), input.span());
958
959      let variant_struct_idents = variants.iter().map(|v| {
960        Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())
961      });
962
963      let variant_struct_definitions =
964        variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
965          let fields = v.fields.iter().map(|v| &v.ty);
966
967          // adding the discriminant repr integer as first field, as described above
968          quote! {
969            #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
970            #[repr(C)]
971            #vis struct #variant_struct_ident(#integer, #(#fields),*);
972          }
973        });
974
975      let union_fields = variant_struct_idents
976        .clone()
977        .zip(variants.iter())
978        .map(|(variant_struct_ident, v)| {
979          let variant_struct_bits_ident =
980            Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
981          let field_ident = &v.ident;
982          quote! {
983            #field_ident: #variant_struct_bits_ident
984          }
985        });
986
987      let variant_checks = variant_struct_idents
988        .clone()
989        .zip(VariantDiscriminantIterator::new(variants.iter()))
990        .zip(variants.iter())
991        .map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
992          let (discriminant, _variant) = discriminant?;
993          let discriminant = LitInt::new(&discriminant.to_string(), v.span());
994          let ident = &v.ident;
995          Ok(quote! {
996            #discriminant => {
997              let payload = unsafe { &bits.#ident };
998              <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
999            }
1000          })
1001        })
1002        .collect::<Result<Vec<_>>>()?;
1003
1004      Ok((
1005        quote! {
1006          #[doc = #GENERATED_TYPE_DOCUMENTATION]
1007          #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
1008          #bits_repr
1009          #[allow(non_snake_case)]
1010          #vis union #bits_ty_ident {
1011            __tag: #integer,
1012            #(#union_fields,)*
1013          }
1014
1015          #[allow(unexpected_cfgs)]
1016          const _: () = {
1017            #[cfg(not(target_arch = "spirv"))]
1018            impl ::core::fmt::Debug for #bits_ty_ident {
1019              fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
1020                let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
1021                ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", unsafe { &self.__tag });
1022                ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
1023              }
1024            }
1025          };
1026
1027          #(#variant_struct_definitions)*
1028        },
1029        quote! {
1030          type Bits = #bits_ty_ident;
1031
1032          #[inline]
1033          #[allow(clippy::double_comparisons)]
1034          fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
1035            match unsafe { bits.__tag } {
1036              #(#variant_checks)*
1037              _ => false,
1038            }
1039          }
1040        },
1041      ))
1042    }
1043  }
1044}
1045
1046/// Check that a struct or enum has no padding by asserting that the size of
1047/// the type is equal to the sum of the size of it's fields and discriminant
1048/// (for enums, this must be asserted for each variant).
1049fn generate_assert_no_padding(
1050  input: &DeriveInput, enum_variant: Option<&Variant>, for_trait: &str,
1051) -> Result<TokenStream> {
1052  let struct_type = &input.ident;
1053  let fields = get_fields(input, enum_variant)?;
1054
1055  // If the type is an enum, determine the type of its discriminant.
1056  let enum_discriminant = if matches!(input.data, Data::Enum(_)) {
1057    let ident =
1058      Ident::new(&format!("{}Discriminant", input.ident), input.ident.span());
1059    Some(ident.into_token_stream())
1060  } else {
1061    None
1062  };
1063
1064  // Prepend the type of the discriminant to the types of the fields.
1065  let mut field_types = enum_discriminant
1066    .into_iter()
1067    .chain(get_field_types(&fields).map(ToTokens::to_token_stream));
1068  let size_sum = if let Some(first) = field_types.next() {
1069    let size_first = quote!(::core::mem::size_of::<#first>());
1070    let size_rest = quote!(#( + ::core::mem::size_of::<#field_types>() )*);
1071
1072    quote!(#size_first #size_rest)
1073  } else {
1074    quote!(0)
1075  };
1076
1077  let message =
1078    format!("derive({for_trait}) was applied to a type with padding");
1079
1080  Ok(quote! {const _: () = {
1081    assert!(
1082        ::core::mem::size_of::<#struct_type>() == (#size_sum),
1083        #message,
1084    );
1085  };})
1086}
1087
1088/// Check that all fields implement a given trait
1089fn generate_fields_are_trait(
1090  input: &DeriveInput, enum_variant: Option<&Variant>, trait_: syn::Path,
1091) -> Result<TokenStream> {
1092  let (impl_generics, _ty_generics, where_clause) =
1093    input.generics.split_for_impl();
1094  let fields = get_fields(input, enum_variant)?;
1095  let field_types = get_field_types(&fields);
1096  Ok(quote! {#(const _: fn() = || {
1097      #[allow(clippy::missing_const_for_fn)]
1098      #[doc(hidden)]
1099      fn check #impl_generics () #where_clause {
1100        fn assert_impl<T: #trait_>() {}
1101        assert_impl::<#field_types>();
1102      }
1103    };)*
1104  })
1105}
1106
1107/// Get the type of an enum's discriminant.
1108///
1109/// For `repr(int)` and `repr(C, int)` enums, this will return the known bare
1110/// integer type specified.
1111///
1112/// For `repr(C)` enums, this will extract the underlying size chosen by rustc.
1113/// It will return a token stream which is a type expression that evaluates to
1114/// a primitive integer type of this size, using our `EnumTagIntegerBytes`
1115/// trait.
1116///
1117/// For fieldless `repr(C)` enums, we can feed the size of the enum directly
1118/// into the trait.
1119///
1120/// For `repr(C)` enums with fields, we generate a new fieldless `repr(C)` enum
1121/// with the same variants, then use that in the calculation. This is the
1122/// specified behavior, see https://doc.rust-lang.org/stable/reference/type-layout.html#reprc-enums-with-fields
1123///
1124/// Returns a tuple of (type ident, auxiliary definitions)
1125fn get_enum_discriminant(
1126  input: &DeriveInput, crate_name: &TokenStream,
1127) -> Result<(TokenStream, TokenStream)> {
1128  let repr = get_repr(&input.attrs)?;
1129  match repr.repr {
1130    Repr::C => {
1131      let e = if let Data::Enum(e) = &input.data { e } else { unreachable!() };
1132      if enum_has_fields(e.variants.iter()) {
1133        // If the enum has fields, we must first isolate the discriminant by
1134        // removing all the fields.
1135        let enum_discriminant = generate_enum_discriminant(input)?;
1136        let discriminant_ident = Ident::new(
1137          &format!("{}Discriminant", input.ident),
1138          input.ident.span(),
1139        );
1140        Ok((
1141          quote!(<[::core::primitive::u8; ::core::mem::size_of::<#discriminant_ident>()] as #crate_name::derive::EnumTagIntegerBytes>::Integer),
1142          quote! {
1143            #enum_discriminant
1144          },
1145        ))
1146      } else {
1147        // If the enum doesn't have fields, we can just use it directly.
1148        let ident = &input.ident;
1149        Ok((
1150          quote!(<[::core::primitive::u8; ::core::mem::size_of::<#ident>()] as #crate_name::derive::EnumTagIntegerBytes>::Integer),
1151          quote!(),
1152        ))
1153      }
1154    }
1155    Repr::Integer(integer) | Repr::CWithDiscriminant(integer) => {
1156      Ok((quote!(#integer), quote!()))
1157    }
1158    _ => unreachable!(),
1159  }
1160}
1161
1162fn generate_enum_discriminant(input: &DeriveInput) -> Result<TokenStream> {
1163  let e = if let Data::Enum(e) = &input.data { e } else { unreachable!() };
1164  let repr = get_repr(&input.attrs)?;
1165  let repr = match repr.repr {
1166    Repr::C => quote!(#[repr(C)]),
1167    Repr::Integer(int) | Repr::CWithDiscriminant(int) => quote!(#[repr(#int)]),
1168    Repr::Rust | Repr::Transparent => unreachable!(),
1169  };
1170  let ident =
1171    Ident::new(&format!("{}Discriminant", input.ident), input.ident.span());
1172  let variants = e.variants.iter().cloned().map(|mut e| {
1173    e.fields = Fields::Unit;
1174    e
1175  });
1176  Ok(quote! {
1177    #repr
1178    #[allow(dead_code)]
1179    enum #ident {
1180      #(#variants,)*
1181    }
1182  })
1183}
1184
1185fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
1186  match tokens.into_iter().next() {
1187    Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()),
1188    Some(TokenTree::Ident(ident)) => Some(ident),
1189    _ => None,
1190  }
1191}
1192
1193/// get a simple #[foo(bar)] attribute, returning "bar"
1194fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
1195  for attr in attributes {
1196    if let (AttrStyle::Outer, Meta::List(list)) = (&attr.style, &attr.meta) {
1197      if list.path.is_ident(attr_name) {
1198        if let Some(ident) = get_ident_from_stream(list.tokens.clone()) {
1199          return Some(ident);
1200        }
1201      }
1202    }
1203  }
1204
1205  None
1206}
1207
1208fn get_repr(attributes: &[Attribute]) -> Result<Representation> {
1209  attributes
1210    .iter()
1211    .filter_map(|attr| {
1212      if attr.path().is_ident("repr") {
1213        Some(attr.parse_args::<Representation>())
1214      } else {
1215        None
1216      }
1217    })
1218    .try_fold(Representation::default(), |a, b| {
1219      let b = b?;
1220      Ok(Representation {
1221        repr: match (a.repr, b.repr) {
1222          (a, Repr::Rust) => a,
1223          (Repr::Rust, b) => b,
1224          _ => bail!("conflicting representation hints"),
1225        },
1226        packed: match (a.packed, b.packed) {
1227          (a, None) => a,
1228          (None, b) => b,
1229          _ => bail!("conflicting representation hints"),
1230        },
1231        align: match (a.align, b.align) {
1232          (Some(a), Some(b)) => Some(cmp::max(a, b)),
1233          (a, None) => a,
1234          (None, b) => b,
1235        },
1236      })
1237    })
1238}
1239
1240mk_repr! {
1241  U8 => u8,
1242  I8 => i8,
1243  U16 => u16,
1244  I16 => i16,
1245  U32 => u32,
1246  I32 => i32,
1247  U64 => u64,
1248  I64 => i64,
1249  I128 => i128,
1250  U128 => u128,
1251  Usize => usize,
1252  Isize => isize,
1253}
1254// where
1255macro_rules! mk_repr {(
1256  $(
1257    $Xn:ident => $xn:ident
1258  ),* $(,)?
1259) => (
1260  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
1261  enum IntegerRepr {
1262    $($Xn),*
1263  }
1264
1265  impl<'a> TryFrom<&'a str> for IntegerRepr {
1266    type Error = &'a str;
1267
1268    fn try_from(value: &'a str) -> std::result::Result<Self, &'a str> {
1269      match value {
1270        $(
1271          stringify!($xn) => Ok(Self::$Xn),
1272        )*
1273        _ => Err(value),
1274      }
1275    }
1276  }
1277
1278  impl ToTokens for IntegerRepr {
1279    fn to_tokens(&self, tokens: &mut TokenStream) {
1280      match self {
1281        $(
1282          Self::$Xn => tokens.extend(quote!($xn)),
1283        )*
1284      }
1285    }
1286  }
1287)}
1288use mk_repr;
1289
1290#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1291enum Repr {
1292  Rust,
1293  C,
1294  Transparent,
1295  Integer(IntegerRepr),
1296  CWithDiscriminant(IntegerRepr),
1297}
1298
1299impl Repr {
1300  fn as_integer(&self) -> Option<IntegerRepr> {
1301    if let Self::Integer(v) = self {
1302      Some(*v)
1303    } else {
1304      None
1305    }
1306  }
1307}
1308
1309#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1310struct Representation {
1311  packed: Option<u32>,
1312  align: Option<u32>,
1313  repr: Repr,
1314}
1315
1316impl Default for Representation {
1317  fn default() -> Self {
1318    Self { packed: None, align: None, repr: Repr::Rust }
1319  }
1320}
1321
1322impl Parse for Representation {
1323  fn parse(input: ParseStream<'_>) -> Result<Representation> {
1324    let mut ret = Representation::default();
1325    while !input.is_empty() {
1326      let keyword = input.parse::<Ident>()?;
1327      // preƫmptively call `.to_string()` *once* (rather than on `is_ident()`)
1328      let keyword_str = keyword.to_string();
1329      let new_repr = match keyword_str.as_str() {
1330        "C" => Repr::C,
1331        "transparent" => Repr::Transparent,
1332        "packed" => {
1333          ret.packed = Some(if input.peek(token::Paren) {
1334            let contents;
1335            parenthesized!(contents in input);
1336            LitInt::base10_parse::<u32>(&contents.parse()?)?
1337          } else {
1338            1
1339          });
1340          let _: Option<Token![,]> = input.parse()?;
1341          continue;
1342        }
1343        "align" => {
1344          let contents;
1345          parenthesized!(contents in input);
1346          let new_align = LitInt::base10_parse::<u32>(&contents.parse()?)?;
1347          ret.align = Some(
1348            ret
1349              .align
1350              .map_or(new_align, |old_align| cmp::max(old_align, new_align)),
1351          );
1352          let _: Option<Token![,]> = input.parse()?;
1353          continue;
1354        }
1355        ident => {
1356          let primitive = IntegerRepr::try_from(ident)
1357            .map_err(|_| input.error("unrecognized representation hint"))?;
1358          Repr::Integer(primitive)
1359        }
1360      };
1361      ret.repr = match (ret.repr, new_repr) {
1362        (Repr::Rust, new_repr) => {
1363          // This is the first explicit repr.
1364          new_repr
1365        }
1366        (Repr::C, Repr::Integer(integer))
1367        | (Repr::Integer(integer), Repr::C) => {
1368          // Both the C repr and an integer repr have been specified
1369          // -> merge into a C wit discriminant.
1370          Repr::CWithDiscriminant(integer)
1371        }
1372        (_, _) => {
1373          return Err(input.error("duplicate representation hint"));
1374        }
1375      };
1376      let _: Option<Token![,]> = input.parse()?;
1377    }
1378    Ok(ret)
1379  }
1380}
1381
1382impl ToTokens for Representation {
1383  fn to_tokens(&self, tokens: &mut TokenStream) {
1384    let mut meta = Punctuated::<_, Token![,]>::new();
1385
1386    match self.repr {
1387      Repr::Rust => {}
1388      Repr::C => meta.push(quote!(C)),
1389      Repr::Transparent => meta.push(quote!(transparent)),
1390      Repr::Integer(primitive) => meta.push(quote!(#primitive)),
1391      Repr::CWithDiscriminant(primitive) => {
1392        meta.push(quote!(C));
1393        meta.push(quote!(#primitive));
1394      }
1395    }
1396
1397    if let Some(packed) = self.packed.as_ref() {
1398      let lit = LitInt::new(&packed.to_string(), Span::call_site());
1399      meta.push(quote!(packed(#lit)));
1400    }
1401
1402    if let Some(align) = self.align.as_ref() {
1403      let lit = LitInt::new(&align.to_string(), Span::call_site());
1404      meta.push(quote!(align(#lit)));
1405    }
1406
1407    tokens.extend(quote!(
1408      #[repr(#meta)]
1409    ));
1410  }
1411}
1412
1413fn enum_has_fields<'a>(
1414  mut variants: impl Iterator<Item = &'a Variant>,
1415) -> bool {
1416  variants.any(|v| matches!(v.fields, Fields::Named(_) | Fields::Unnamed(_)))
1417}
1418
1419struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
1420  inner: I,
1421  last_value: i128,
1422}
1423
1424impl<'a, I: Iterator<Item = &'a Variant> + 'a>
1425  VariantDiscriminantIterator<'a, I>
1426{
1427  fn new(inner: I) -> Self {
1428    VariantDiscriminantIterator { inner, last_value: -1 }
1429  }
1430}
1431
1432impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
1433  for VariantDiscriminantIterator<'a, I>
1434{
1435  type Item = Result<(i128, &'a Variant)>;
1436
1437  fn next(&mut self) -> Option<Self::Item> {
1438    let variant = self.inner.next()?;
1439
1440    if let Some((_, discriminant)) = &variant.discriminant {
1441      let discriminant_value = match parse_int_expr(discriminant) {
1442        Ok(value) => value,
1443        Err(e) => return Some(Err(e)),
1444      };
1445      self.last_value = discriminant_value;
1446    } else {
1447      // If this wraps, then either:
1448      // 1. the enum is using repr(u128), so wrapping is correct
1449      // 2. the enum is using repr(i<=128 or u<128), so the compiler will
1450      //    already emit a "wrapping discriminant" E0370 error.
1451      self.last_value = self.last_value.wrapping_add(1);
1452      // Static assert that there is no integer repr > 128 bits. If that
1453      // changes, the above comment is inaccurate and needs to be updated!
1454      // FIXME(zachs18): maybe should also do something to ensure `isize::BITS
1455      // <= 128`?
1456      if let Some(repr) = None::<IntegerRepr> {
1457        match repr {
1458          IntegerRepr::U8
1459          | IntegerRepr::I8
1460          | IntegerRepr::U16
1461          | IntegerRepr::I16
1462          | IntegerRepr::U32
1463          | IntegerRepr::I32
1464          | IntegerRepr::U64
1465          | IntegerRepr::I64
1466          | IntegerRepr::I128
1467          | IntegerRepr::U128
1468          | IntegerRepr::Usize
1469          | IntegerRepr::Isize => (),
1470        }
1471      }
1472    }
1473
1474    Some(Ok((self.last_value, variant)))
1475  }
1476}
1477
1478fn parse_int_expr(expr: &Expr) -> Result<i128> {
1479  match expr {
1480    Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => {
1481      parse_int_expr(expr).map(|int| -int)
1482    }
1483    Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int.base10_parse(),
1484    Expr::Lit(ExprLit { lit: Lit::Byte(byte), .. }) => Ok(byte.value().into()),
1485    _ => bail!("Not an integer expression"),
1486  }
1487}
1488
#[cfg(test)]
mod tests {
  use syn::parse_quote;

  use super::{get_repr, IntegerRepr, Repr, Representation};

  /// Merges `attrs` with `get_repr` and checks the outcome.
  fn assert_repr(attrs: &[syn::Attribute], expected: Representation) {
    assert_eq!(get_repr(attrs).unwrap(), expected);
  }

  /// Each hint on its own, in a single attribute.
  #[test]
  fn parse_basic_repr() {
    let cases: Vec<(syn::Attribute, Representation)> = vec![
      (
        parse_quote!(#[repr(C)]),
        Representation { repr: Repr::C, ..Default::default() },
      ),
      (
        parse_quote!(#[repr(transparent)]),
        Representation { repr: Repr::Transparent, ..Default::default() },
      ),
      (
        parse_quote!(#[repr(u8)]),
        Representation {
          repr: Repr::Integer(IntegerRepr::U8),
          ..Default::default()
        },
      ),
      (
        parse_quote!(#[repr(packed)]),
        Representation { packed: Some(1), ..Default::default() },
      ),
      (
        parse_quote!(#[repr(packed(1))]),
        Representation { packed: Some(1), ..Default::default() },
      ),
      (
        parse_quote!(#[repr(packed(2))]),
        Representation { packed: Some(2), ..Default::default() },
      ),
      (
        parse_quote!(#[repr(align(2))]),
        Representation { align: Some(2), ..Default::default() },
      ),
    ];

    for (attr, expected) in cases {
      assert_repr(&[attr], expected);
    }
  }

  /// Hints that interact: repeated `align`, and `C` combined with an integer.
  #[test]
  fn parse_advanced_repr() {
    let max_align = Representation { align: Some(4), ..Default::default() };
    let c_with_u8 = Representation {
      repr: Repr::CWithDiscriminant(IntegerRepr::U8),
      ..Default::default()
    };

    // Several `align`s inside one attribute keep the largest.
    let combined: syn::Attribute = parse_quote!(#[repr(align(4), align(2))]);
    assert_repr(&[combined], max_align);

    // The same holds when they are spread across attributes.
    let spread: [syn::Attribute; 3] = [
      parse_quote!(#[repr(align(1))]),
      parse_quote!(#[repr(align(4))]),
      parse_quote!(#[repr(align(2))]),
    ];
    assert_repr(&spread, max_align);

    // `C` and an integer merge regardless of order.
    let c_first: syn::Attribute = parse_quote!(#[repr(C, u8)]);
    assert_repr(&[c_first], c_with_u8);

    let int_first: syn::Attribute = parse_quote!(#[repr(u8, C)]);
    assert_repr(&[int_first], c_with_u8);
  }
}
1568
1569pub fn bytemuck_crate_name(input: &DeriveInput) -> TokenStream {
1570  const ATTR_NAME: &'static str = "crate";
1571
1572  let mut crate_name = quote!(::bytemuck);
1573  for attr in &input.attrs {
1574    if !attr.path().is_ident("bytemuck") {
1575      continue;
1576    }
1577
1578    attr.parse_nested_meta(|meta| {
1579      if meta.path.is_ident(ATTR_NAME) {
1580        let expr: syn::Expr = meta.value()?.parse()?;
1581        let mut value = &expr;
1582        while let syn::Expr::Group(e) = value {
1583          value = &e.expr;
1584        }
1585        if let syn::Expr::Lit(syn::ExprLit {
1586          lit: syn::Lit::Str(lit), ..
1587        }) = value
1588        {
1589          let suffix = lit.suffix();
1590          if !suffix.is_empty() {
1591            bail!(format!("Unexpected suffix `{}` on string literal", suffix))
1592          }
1593          let path: syn::Path = match lit.parse() {
1594            Ok(path) => path,
1595            Err(_) => {
1596              bail!(format!("Failed to parse path: {:?}", lit.value()))
1597            }
1598          };
1599          crate_name = path.into_token_stream();
1600        } else {
1601          bail!(
1602            "Expected bytemuck `crate` attribute to be a string: `crate = \"...\"`",
1603          )
1604        }
1605      }
1606      Ok(())
1607    }).unwrap();
1608  }
1609
1610  return crate_name;
1611}
1612
/// Doc text attached, via `#[doc = ...]`, to helper types emitted by the
/// derive macros in this file.
const GENERATED_TYPE_DOCUMENTATION: &str =
  " `bytemuck`-generated type for internal purposes only.";