ref_cast_impl/
lib.rs

1#![allow(
2    clippy::blocks_in_conditions,
3    clippy::needless_pass_by_value,
4    clippy::if_not_else
5)]
6
7extern crate proc_macro;
8
9use proc_macro::TokenStream;
10use proc_macro2::{Ident, Span, TokenStream as TokenStream2, TokenTree};
11use quote::{quote, quote_spanned};
12use syn::parse::{Nothing, ParseStream, Parser};
13use syn::punctuated::Punctuated;
14use syn::{
15    parenthesized, parse_macro_input, token, Abi, Attribute, Data, DeriveInput, Error, Expr, Field,
16    Generics, Path, Result, Token, Type, Visibility,
17};
18
19/// Derive the `RefCast` trait.
20///
21/// See the [crate-level documentation](./index.html) for usage examples!
22///
23/// # Attributes
24///
25/// Use the `#[trivial]` attribute to mark any zero-sized fields that are *not*
26/// the one that references are going to be converted from.
27///
28/// ```
29/// use ref_cast::RefCast;
30/// use std::marker::PhantomData;
31///
32/// #[derive(RefCast)]
33/// #[repr(transparent)]
34/// pub struct Generic<T, U> {
35///     raw: Vec<U>,
36///     #[trivial]
37///     aux: Variance<T, U>,
38/// }
39///
40/// type Variance<T, U> = PhantomData<fn(T) -> U>;
41/// ```
42///
43/// Fields with a type named `PhantomData` or `PhantomPinned` are automatically
44/// recognized and do not need to be marked with this attribute.
45///
46/// ```
47/// use ref_cast::RefCast;
48/// use std::marker::{PhantomData, PhantomPinned};
49///
50/// #[derive(RefCast)]  // generates a conversion from &[u8] to &Bytes<'_>
51/// #[repr(transparent)]
52/// pub struct Bytes<'arena> {
53///     lifetime: PhantomData<&'arena ()>,
54///     pin: PhantomPinned,
55///     bytes: [u8],
56/// }
57/// ```
58#[proc_macro_derive(RefCast, attributes(trivial))]
59pub fn derive_ref_cast(input: TokenStream) -> TokenStream {
60    let input = parse_macro_input!(input as DeriveInput);
61    expand_ref_cast(input)
62        .unwrap_or_else(Error::into_compile_error)
63        .into()
64}
65
66/// Derive that makes the `ref_cast_custom` attribute able to generate
67/// freestanding reference casting functions for a type.
68///
69/// Please refer to the documentation of
70/// [`#[ref_cast_custom]`][macro@ref_cast_custom] where these two macros are
71/// documented together.
72#[proc_macro_derive(RefCastCustom, attributes(trivial))]
73pub fn derive_ref_cast_custom(input: TokenStream) -> TokenStream {
74    let input = parse_macro_input!(input as DeriveInput);
75    expand_ref_cast_custom(input)
76        .unwrap_or_else(Error::into_compile_error)
77        .into()
78}
79
80/// Create a function for a RefCast-style reference cast. Call site gets control
81/// of the visibility, function name, argument name, `const`ness, unsafety, and
82/// documentation.
83///
84/// The `derive(RefCast)` macro produces a trait impl, which means the function
85/// names are predefined, and public if your type is public, and not callable in
86/// `const` (at least today on stable Rust). As an alternative to that,
87/// `derive(RefCastCustom)` exposes greater flexibility so that instead of a
88/// trait impl, the casting functions can be made associated functions or free
89/// functions, can be named what you want, documented, `const` or `unsafe` if
90/// you want, and have your exact choice of visibility.
91///
92/// ```rust
93/// use ref_cast::{ref_cast_custom, RefCastCustom};
94///
95/// #[derive(RefCastCustom)]  // does not generate any public API by itself
96/// #[repr(transparent)]
97/// pub struct Frame([u8]);
98///
99/// impl Frame {
100///     #[ref_cast_custom]  // requires derive(RefCastCustom) on the return type
101///     pub(crate) const fn new(bytes: &[u8]) -> &Self;
102///
103///     #[ref_cast_custom]
104///     pub(crate) fn new_mut(bytes: &mut [u8]) -> &mut Self;
105/// }
106///
107/// // example use of the const fn
108/// const FRAME: &Frame = Frame::new(b"...");
109/// ```
110///
111/// The above shows associated functions, but you might alternatively want to
112/// generate free functions:
113///
114/// ```rust
115/// # use ref_cast::{ref_cast_custom, RefCastCustom};
116/// #
117/// # #[derive(RefCastCustom)]
118/// # #[repr(transparent)]
119/// # pub struct Frame([u8]);
120/// #
121/// impl Frame {
122///     pub fn new<T: AsRef<[u8]>>(bytes: &T) -> &Self {
123///         #[ref_cast_custom]
124///         fn ref_cast(bytes: &[u8]) -> &Frame;
125///
126///         ref_cast(bytes.as_ref())
127///     }
128/// }
129/// ```
130#[proc_macro_attribute]
131pub fn ref_cast_custom(args: TokenStream, input: TokenStream) -> TokenStream {
132    let input = TokenStream2::from(input);
133    let expanded = match (|input: ParseStream| {
134        let attrs = input.call(Attribute::parse_outer)?;
135        let vis: Visibility = input.parse()?;
136        let constness: Option<Token![const]> = input.parse()?;
137        let asyncness: Option<Token![async]> = input.parse()?;
138        let unsafety: Option<Token![unsafe]> = input.parse()?;
139        let abi: Option<Abi> = input.parse()?;
140        let fn_token: Token![fn] = input.parse()?;
141        let ident: Ident = input.parse()?;
142        let mut generics: Generics = input.parse()?;
143
144        let content;
145        let paren_token = parenthesized!(content in input);
146        let arg: Ident = content.parse()?;
147        let colon_token: Token![:] = content.parse()?;
148        let from_type: Type = content.parse()?;
149        let _trailing_comma: Option<Token![,]> = content.parse()?;
150        if !content.is_empty() {
151            let rest: TokenStream2 = content.parse()?;
152            return Err(Error::new_spanned(
153                rest,
154                "ref_cast_custom function is required to have a single argument",
155            ));
156        }
157
158        let arrow_token: Token![->] = input.parse()?;
159        let to_type: Type = input.parse()?;
160        generics.where_clause = input.parse()?;
161        let semi_token: Token![;] = input.parse()?;
162
163        let _: Nothing = syn::parse::<Nothing>(args)?;
164
165        Ok(Function {
166            attrs,
167            vis,
168            constness,
169            asyncness,
170            unsafety,
171            abi,
172            fn_token,
173            ident,
174            generics,
175            paren_token,
176            arg,
177            colon_token,
178            from_type,
179            arrow_token,
180            to_type,
181            semi_token,
182        })
183    })
184    .parse2(input.clone())
185    {
186        Ok(function) => expand_function_body(function),
187        Err(parse_error) => {
188            let compile_error = parse_error.to_compile_error();
189            quote!(#compile_error #input)
190        }
191    };
192    TokenStream::from(expanded)
193}
194
195struct Function {
196    attrs: Vec<Attribute>,
197    vis: Visibility,
198    constness: Option<Token![const]>,
199    asyncness: Option<Token![async]>,
200    unsafety: Option<Token![unsafe]>,
201    abi: Option<Abi>,
202    fn_token: Token![fn],
203    ident: Ident,
204    generics: Generics,
205    paren_token: token::Paren,
206    arg: Ident,
207    colon_token: Token![:],
208    from_type: Type,
209    arrow_token: Token![->],
210    to_type: Type,
211    semi_token: Token![;],
212}
213
214fn expand_ref_cast(input: DeriveInput) -> Result<TokenStream2> {
215    check_repr(&input)?;
216
217    let name = &input.ident;
218    let name_str = name.to_string();
219    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
220
221    let fields = fields(&input)?;
222    let from = only_field_ty(fields)?;
223    let trivial = trivial_fields(fields)?;
224
225    let assert_trivial_fields = if !trivial.is_empty() {
226        Some(quote! {
227            if false {
228                #(
229                    ::ref_cast::__private::assert_trivial::<#trivial>();
230                )*
231            }
232        })
233    } else {
234        None
235    };
236
237    Ok(quote! {
238        impl #impl_generics ::ref_cast::RefCast for #name #ty_generics #where_clause {
239            type From = #from;
240
241            #[inline]
242            fn ref_cast(_from: &Self::From) -> &Self {
243                #assert_trivial_fields
244                #[cfg(debug_assertions)]
245                {
246                    #[allow(unused_imports)]
247                    use ::ref_cast::__private::LayoutUnsized;
248                    ::ref_cast::__private::assert_layout::<Self, Self::From>(
249                        #name_str,
250                        ::ref_cast::__private::Layout::<Self>::SIZE,
251                        ::ref_cast::__private::Layout::<Self::From>::SIZE,
252                        ::ref_cast::__private::Layout::<Self>::ALIGN,
253                        ::ref_cast::__private::Layout::<Self::From>::ALIGN,
254                    );
255                }
256                unsafe {
257                    &*(_from as *const Self::From as *const Self)
258                }
259            }
260
261            #[inline]
262            fn ref_cast_mut(_from: &mut Self::From) -> &mut Self {
263                #[cfg(debug_assertions)]
264                {
265                    #[allow(unused_imports)]
266                    use ::ref_cast::__private::LayoutUnsized;
267                    ::ref_cast::__private::assert_layout::<Self, Self::From>(
268                        #name_str,
269                        ::ref_cast::__private::Layout::<Self>::SIZE,
270                        ::ref_cast::__private::Layout::<Self::From>::SIZE,
271                        ::ref_cast::__private::Layout::<Self>::ALIGN,
272                        ::ref_cast::__private::Layout::<Self::From>::ALIGN,
273                    );
274                }
275                unsafe {
276                    &mut *(_from as *mut Self::From as *mut Self)
277                }
278            }
279        }
280    })
281}
282
283fn expand_ref_cast_custom(input: DeriveInput) -> Result<TokenStream2> {
284    check_repr(&input)?;
285
286    let vis = &input.vis;
287    let name = &input.ident;
288    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
289
290    let fields = fields(&input)?;
291    let from = only_field_ty(fields)?;
292    let trivial = trivial_fields(fields)?;
293
294    let assert_trivial_fields = if !trivial.is_empty() {
295        Some(quote! {
296            fn __static_assert() {
297                if false {
298                    #(
299                        ::ref_cast::__private::assert_trivial::<#trivial>();
300                    )*
301                }
302            }
303        })
304    } else {
305        None
306    };
307
308    Ok(quote! {
309        const _: () = {
310            #[non_exhaustive]
311            #vis struct RefCastCurrentCrate {}
312
313            unsafe impl #impl_generics ::ref_cast::__private::RefCastCustom<#from> for #name #ty_generics #where_clause {
314                type CurrentCrate = RefCastCurrentCrate;
315                #assert_trivial_fields
316            }
317        };
318    })
319}
320
321fn expand_function_body(function: Function) -> TokenStream2 {
322    let Function {
323        attrs,
324        vis,
325        constness,
326        asyncness,
327        unsafety,
328        abi,
329        fn_token,
330        ident,
331        generics,
332        paren_token,
333        arg,
334        colon_token,
335        from_type,
336        arrow_token,
337        to_type,
338        semi_token,
339    } = function;
340
341    let args = quote_spanned! {paren_token.span=>
342        (#arg #colon_token #from_type)
343    };
344
345    let allow_unused_unsafe = if unsafety.is_some() {
346        Some(quote!(#[allow(unused_unsafe)]))
347    } else {
348        None
349    };
350
351    let mut inline_attr = Some(quote!(#[inline]));
352    for attr in &attrs {
353        if attr.path().is_ident("inline") {
354            inline_attr = None;
355            break;
356        }
357    }
358
359    // Apply a macro-generated span to the "unsafe" token for the unsafe block.
360    // This is instead of reusing the caller's function signature's #unsafety
361    // across both the generated function signature and generated unsafe block,
362    // and instead of using `semi_token.span` like for the rest of the generated
363    // code below, both of which would cause `forbid(unsafe_code)` located in
364    // the caller to reject the expanded code.
365    let macro_generated_unsafe = quote!(unsafe);
366
367    quote_spanned! {semi_token.span=>
368        #(#attrs)*
369        #inline_attr
370        #vis #constness #asyncness #unsafety #abi
371        #fn_token #ident #generics #args #arrow_token #to_type {
372            // check lifetime
373            let _ = || {
374                ::ref_cast::__private::ref_cast_custom::<#from_type, #to_type>(#arg);
375            };
376
377            // check same crate
378            let _ = ::ref_cast::__private::CurrentCrate::<#from_type, #to_type> {};
379
380            #allow_unused_unsafe // in case they are building with deny(unsafe_op_in_unsafe_fn)
381            #[allow(clippy::transmute_ptr_to_ptr)]
382            #macro_generated_unsafe {
383                ::ref_cast::__private::transmute::<#from_type, #to_type>(#arg)
384            }
385        }
386    }
387}
388
389fn check_repr(input: &DeriveInput) -> Result<()> {
390    let mut has_repr = false;
391    let mut errors = None;
392    let mut push_error = |error| match &mut errors {
393        Some(errors) => Error::combine(errors, error),
394        None => errors = Some(error),
395    };
396
397    for attr in &input.attrs {
398        if attr.path().is_ident("repr") {
399            if let Err(error) = attr.parse_args_with(|input: ParseStream| {
400                while !input.is_empty() {
401                    let path = input.call(Path::parse_mod_style)?;
402                    if path.is_ident("transparent") || path.is_ident("C") {
403                        has_repr = true;
404                    } else if path.is_ident("packed") {
405                        // ignore
406                    } else {
407                        let meta_item_span = if input.peek(token::Paren) {
408                            let group: TokenTree = input.parse()?;
409                            quote!(#path #group)
410                        } else if input.peek(Token![=]) {
411                            let eq_token: Token![=] = input.parse()?;
412                            let value: Expr = input.parse()?;
413                            quote!(#path #eq_token #value)
414                        } else {
415                            quote!(#path)
416                        };
417                        let msg = if path.is_ident("align") {
418                            "aligned repr on struct that implements RefCast is not supported"
419                        } else {
420                            "unrecognized repr on struct that implements RefCast"
421                        };
422                        push_error(Error::new_spanned(meta_item_span, msg));
423                    }
424                    if !input.is_empty() {
425                        input.parse::<Token![,]>()?;
426                    }
427                }
428                Ok(())
429            }) {
430                push_error(error);
431            }
432        }
433    }
434
435    if !has_repr {
436        let mut requires_repr = Error::new(
437            Span::call_site(),
438            "RefCast trait requires #[repr(transparent)]",
439        );
440        if let Some(errors) = errors {
441            requires_repr.combine(errors);
442        }
443        errors = Some(requires_repr);
444    }
445
446    match errors {
447        None => Ok(()),
448        Some(errors) => Err(errors),
449    }
450}
451
452type Fields = Punctuated<Field, Token![,]>;
453
454fn fields(input: &DeriveInput) -> Result<&Fields> {
455    use syn::Fields;
456
457    match &input.data {
458        Data::Struct(data) => match &data.fields {
459            Fields::Named(fields) => Ok(&fields.named),
460            Fields::Unnamed(fields) => Ok(&fields.unnamed),
461            Fields::Unit => Err(Error::new(
462                Span::call_site(),
463                "RefCast does not support unit structs",
464            )),
465        },
466        Data::Enum(_) => Err(Error::new(
467            Span::call_site(),
468            "RefCast does not support enums",
469        )),
470        Data::Union(_) => Err(Error::new(
471            Span::call_site(),
472            "RefCast does not support unions",
473        )),
474    }
475}
476
477fn only_field_ty(fields: &Fields) -> Result<&Type> {
478    let is_trivial = decide_trivial(fields)?;
479    let mut only_field = None;
480
481    for field in fields {
482        if !is_trivial(field)? {
483            if only_field.take().is_some() {
484                break;
485            }
486            only_field = Some(&field.ty);
487        }
488    }
489
490    only_field.ok_or_else(|| {
491        Error::new(
492            Span::call_site(),
493            "RefCast requires a struct with a single field",
494        )
495    })
496}
497
498fn trivial_fields(fields: &Fields) -> Result<Vec<&Type>> {
499    let is_trivial = decide_trivial(fields)?;
500    let mut trivial = Vec::new();
501
502    for field in fields {
503        if is_trivial(field)? {
504            trivial.push(&field.ty);
505        }
506    }
507
508    Ok(trivial)
509}
510
511fn decide_trivial(fields: &Fields) -> Result<fn(&Field) -> Result<bool>> {
512    for field in fields {
513        if is_explicit_trivial(field)? {
514            return Ok(is_explicit_trivial);
515        }
516    }
517    Ok(is_implicit_trivial)
518}
519
520#[allow(clippy::unnecessary_wraps)] // match signature of is_explicit_trivial
521fn is_implicit_trivial(field: &Field) -> Result<bool> {
522    match &field.ty {
523        Type::Tuple(ty) => Ok(ty.elems.is_empty()),
524        Type::Path(ty) => {
525            let ident = &ty.path.segments.last().unwrap().ident;
526            Ok(ident == "PhantomData" || ident == "PhantomPinned")
527        }
528        _ => Ok(false),
529    }
530}
531
532fn is_explicit_trivial(field: &Field) -> Result<bool> {
533    for attr in &field.attrs {
534        if attr.path().is_ident("trivial") {
535            attr.meta.require_path_only()?;
536            return Ok(true);
537        }
538    }
539    Ok(false)
540}