1#![allow(
2 clippy::blocks_in_conditions,
3 clippy::needless_pass_by_value,
4 clippy::if_not_else
5)]
6
7extern crate proc_macro;
8
9use proc_macro::TokenStream;
10use proc_macro2::{Ident, Span, TokenStream as TokenStream2, TokenTree};
11use quote::{quote, quote_spanned};
12use syn::parse::{Nothing, ParseStream, Parser};
13use syn::punctuated::Punctuated;
14use syn::{
15 parenthesized, parse_macro_input, token, Abi, Attribute, Data, DeriveInput, Error, Expr, Field,
16 Generics, Path, Result, Token, Type, Visibility,
17};
18
19#[proc_macro_derive(RefCast, attributes(trivial))]
59pub fn derive_ref_cast(input: TokenStream) -> TokenStream {
60 let input = parse_macro_input!(input as DeriveInput);
61 expand_ref_cast(input)
62 .unwrap_or_else(Error::into_compile_error)
63 .into()
64}
65
66#[proc_macro_derive(RefCastCustom, attributes(trivial))]
73pub fn derive_ref_cast_custom(input: TokenStream) -> TokenStream {
74 let input = parse_macro_input!(input as DeriveInput);
75 expand_ref_cast_custom(input)
76 .unwrap_or_else(Error::into_compile_error)
77 .into()
78}
79
80#[proc_macro_attribute]
131pub fn ref_cast_custom(args: TokenStream, input: TokenStream) -> TokenStream {
132 let input = TokenStream2::from(input);
133 let expanded = match (|input: ParseStream| {
134 let attrs = input.call(Attribute::parse_outer)?;
135 let vis: Visibility = input.parse()?;
136 let constness: Option<Token![const]> = input.parse()?;
137 let asyncness: Option<Token![async]> = input.parse()?;
138 let unsafety: Option<Token![unsafe]> = input.parse()?;
139 let abi: Option<Abi> = input.parse()?;
140 let fn_token: Token![fn] = input.parse()?;
141 let ident: Ident = input.parse()?;
142 let mut generics: Generics = input.parse()?;
143
144 let content;
145 let paren_token = parenthesized!(content in input);
146 let arg: Ident = content.parse()?;
147 let colon_token: Token![:] = content.parse()?;
148 let from_type: Type = content.parse()?;
149 let _trailing_comma: Option<Token![,]> = content.parse()?;
150 if !content.is_empty() {
151 let rest: TokenStream2 = content.parse()?;
152 return Err(Error::new_spanned(
153 rest,
154 "ref_cast_custom function is required to have a single argument",
155 ));
156 }
157
158 let arrow_token: Token![->] = input.parse()?;
159 let to_type: Type = input.parse()?;
160 generics.where_clause = input.parse()?;
161 let semi_token: Token![;] = input.parse()?;
162
163 let _: Nothing = syn::parse::<Nothing>(args)?;
164
165 Ok(Function {
166 attrs,
167 vis,
168 constness,
169 asyncness,
170 unsafety,
171 abi,
172 fn_token,
173 ident,
174 generics,
175 paren_token,
176 arg,
177 colon_token,
178 from_type,
179 arrow_token,
180 to_type,
181 semi_token,
182 })
183 })
184 .parse2(input.clone())
185 {
186 Ok(function) => expand_function_body(function),
187 Err(parse_error) => {
188 let compile_error = parse_error.to_compile_error();
189 quote!(#compile_error #input)
190 }
191 };
192 TokenStream::from(expanded)
193}
194
/// Parsed signature of a function annotated with `#[ref_cast_custom]`.
///
/// The signature has no body (it ends in `;`) and exactly one argument. It is
/// parsed from the input of the `ref_cast_custom` attribute macro and turned
/// back into tokens by `expand_function_body`.
struct Function {
    /// Outer attributes written on the function (doc comments, `#[inline]`, ...).
    attrs: Vec<Attribute>,
    /// Visibility of the function (possibly inherited / empty).
    vis: Visibility,
    /// Optional `const` qualifier.
    constness: Option<Token![const]>,
    /// Optional `async` qualifier.
    asyncness: Option<Token![async]>,
    /// Optional `unsafe` qualifier.
    unsafety: Option<Token![unsafe]>,
    /// Optional `extern "..."` ABI.
    abi: Option<Abi>,
    /// The `fn` keyword.
    fn_token: Token![fn],
    /// The function name.
    ident: Ident,
    /// Generic parameters; the `where` clause is stored in `generics.where_clause`.
    generics: Generics,
    /// The parentheses around the argument list (its span is reused on output).
    paren_token: token::Paren,
    /// Name of the single argument.
    arg: Ident,
    /// The `:` between the argument name and its type.
    colon_token: Token![:],
    /// Type of the single argument (the cast source).
    from_type: Type,
    /// The `->` before the return type.
    arrow_token: Token![->],
    /// Return type (the cast target).
    to_type: Type,
    /// The terminating `;` (its span is reused for the generated body).
    semi_token: Token![;],
}
213
214fn expand_ref_cast(input: DeriveInput) -> Result<TokenStream2> {
215 check_repr(&input)?;
216
217 let name = &input.ident;
218 let name_str = name.to_string();
219 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
220
221 let fields = fields(&input)?;
222 let from = only_field_ty(fields)?;
223 let trivial = trivial_fields(fields)?;
224
225 let assert_trivial_fields = if !trivial.is_empty() {
226 Some(quote! {
227 if false {
228 #(
229 ::ref_cast::__private::assert_trivial::<#trivial>();
230 )*
231 }
232 })
233 } else {
234 None
235 };
236
237 Ok(quote! {
238 impl #impl_generics ::ref_cast::RefCast for #name #ty_generics #where_clause {
239 type From = #from;
240
241 #[inline]
242 fn ref_cast(_from: &Self::From) -> &Self {
243 #assert_trivial_fields
244 #[cfg(debug_assertions)]
245 {
246 #[allow(unused_imports)]
247 use ::ref_cast::__private::LayoutUnsized;
248 ::ref_cast::__private::assert_layout::<Self, Self::From>(
249 #name_str,
250 ::ref_cast::__private::Layout::<Self>::SIZE,
251 ::ref_cast::__private::Layout::<Self::From>::SIZE,
252 ::ref_cast::__private::Layout::<Self>::ALIGN,
253 ::ref_cast::__private::Layout::<Self::From>::ALIGN,
254 );
255 }
256 unsafe {
257 &*(_from as *const Self::From as *const Self)
258 }
259 }
260
261 #[inline]
262 fn ref_cast_mut(_from: &mut Self::From) -> &mut Self {
263 #[cfg(debug_assertions)]
264 {
265 #[allow(unused_imports)]
266 use ::ref_cast::__private::LayoutUnsized;
267 ::ref_cast::__private::assert_layout::<Self, Self::From>(
268 #name_str,
269 ::ref_cast::__private::Layout::<Self>::SIZE,
270 ::ref_cast::__private::Layout::<Self::From>::SIZE,
271 ::ref_cast::__private::Layout::<Self>::ALIGN,
272 ::ref_cast::__private::Layout::<Self::From>::ALIGN,
273 );
274 }
275 unsafe {
276 &mut *(_from as *mut Self::From as *mut Self)
277 }
278 }
279 }
280 })
281}
282
283fn expand_ref_cast_custom(input: DeriveInput) -> Result<TokenStream2> {
284 check_repr(&input)?;
285
286 let vis = &input.vis;
287 let name = &input.ident;
288 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
289
290 let fields = fields(&input)?;
291 let from = only_field_ty(fields)?;
292 let trivial = trivial_fields(fields)?;
293
294 let assert_trivial_fields = if !trivial.is_empty() {
295 Some(quote! {
296 fn __static_assert() {
297 if false {
298 #(
299 ::ref_cast::__private::assert_trivial::<#trivial>();
300 )*
301 }
302 }
303 })
304 } else {
305 None
306 };
307
308 Ok(quote! {
309 const _: () = {
310 #[non_exhaustive]
311 #vis struct RefCastCurrentCrate {}
312
313 unsafe impl #impl_generics ::ref_cast::__private::RefCastCustom<#from> for #name #ty_generics #where_clause {
314 type CurrentCrate = RefCastCurrentCrate;
315 #assert_trivial_fields
316 }
317 };
318 })
319}
320
/// Generates the complete function for a `#[ref_cast_custom]` signature.
///
/// The caller's attributes, visibility, qualifiers, name, generics, argument
/// and return type are re-emitted unchanged. `#[inline]` is added unless one
/// of the caller's attributes is already an `inline` attribute. The generated
/// body:
/// 1. mentions `::ref_cast::__private::ref_cast_custom::<From, To>(arg)`
///    inside a closure that is never called, so it is only type-checked;
/// 2. constructs `::ref_cast::__private::CurrentCrate::<From, To> {}`;
/// 3. returns `::ref_cast::__private::transmute::<From, To>(arg)` inside an
///    `unsafe` block.
///
/// NOTE(review): steps 1 and 2 look like compile-time guards. Step 1
/// presumably checks lifetimes. Step 2 presumably only compiles where the
/// `#[non_exhaustive]` `RefCastCurrentCrate` marker from
/// `#[derive(RefCastCustom)]` is constructible, i.e. in the deriving crate.
/// Confirm against `ref_cast::__private`.
fn expand_function_body(function: Function) -> TokenStream2 {
    let Function {
        attrs,
        vis,
        constness,
        asyncness,
        unsafety,
        abi,
        fn_token,
        ident,
        generics,
        paren_token,
        arg,
        colon_token,
        from_type,
        arrow_token,
        to_type,
        semi_token,
    } = function;

    // Re-emit the argument list carrying the span of the original parentheses.
    let args = quote_spanned! {paren_token.span=>
        (#arg #colon_token #from_type)
    };

    // An `unsafe fn` body may already be an unsafe context, in which case the
    // inner `unsafe` block below would trigger the `unused_unsafe` lint.
    let allow_unused_unsafe = if unsafety.is_some() {
        Some(quote!(#[allow(unused_unsafe)]))
    } else {
        None
    };

    // Respect an explicit `#[inline(...)]` written by the caller.
    let mut inline_attr = Some(quote!(#[inline]));
    for attr in &attrs {
        if attr.path().is_ident("inline") {
            inline_attr = None;
            break;
        }
    }

    // The `unsafe` of the generated block gets a macro-generated (call-site)
    // span instead of the caller's `#unsafety` or `semi_token.span` used for
    // everything else below. Presumably this keeps lints such as `unsafe_code`
    // in the caller's crate from attributing the block to the caller's own
    // tokens. TODO confirm.
    let macro_generated_unsafe = quote!(unsafe);

    quote_spanned! {semi_token.span=>
        #(#attrs)*
        #inline_attr
        #vis #constness #asyncness #unsafety #abi
        #fn_token #ident #generics #args #arrow_token #to_type {
            // Never called: type-checked only.
            let _ = || {
                ::ref_cast::__private::ref_cast_custom::<#from_type, #to_type>(#arg);
            };

            let _ = ::ref_cast::__private::CurrentCrate::<#from_type, #to_type> {};

            #allow_unused_unsafe #[allow(clippy::transmute_ptr_to_ptr)]
            #macro_generated_unsafe {
                ::ref_cast::__private::transmute::<#from_type, #to_type>(#arg)
            }
        }
    }
}
388
389fn check_repr(input: &DeriveInput) -> Result<()> {
390 let mut has_repr = false;
391 let mut errors = None;
392 let mut push_error = |error| match &mut errors {
393 Some(errors) => Error::combine(errors, error),
394 None => errors = Some(error),
395 };
396
397 for attr in &input.attrs {
398 if attr.path().is_ident("repr") {
399 if let Err(error) = attr.parse_args_with(|input: ParseStream| {
400 while !input.is_empty() {
401 let path = input.call(Path::parse_mod_style)?;
402 if path.is_ident("transparent") || path.is_ident("C") {
403 has_repr = true;
404 } else if path.is_ident("packed") {
405 } else {
407 let meta_item_span = if input.peek(token::Paren) {
408 let group: TokenTree = input.parse()?;
409 quote!(#path #group)
410 } else if input.peek(Token![=]) {
411 let eq_token: Token![=] = input.parse()?;
412 let value: Expr = input.parse()?;
413 quote!(#path #eq_token #value)
414 } else {
415 quote!(#path)
416 };
417 let msg = if path.is_ident("align") {
418 "aligned repr on struct that implements RefCast is not supported"
419 } else {
420 "unrecognized repr on struct that implements RefCast"
421 };
422 push_error(Error::new_spanned(meta_item_span, msg));
423 }
424 if !input.is_empty() {
425 input.parse::<Token![,]>()?;
426 }
427 }
428 Ok(())
429 }) {
430 push_error(error);
431 }
432 }
433 }
434
435 if !has_repr {
436 let mut requires_repr = Error::new(
437 Span::call_site(),
438 "RefCast trait requires #[repr(transparent)]",
439 );
440 if let Some(errors) = errors {
441 requires_repr.combine(errors);
442 }
443 errors = Some(requires_repr);
444 }
445
446 match errors {
447 None => Ok(()),
448 Some(errors) => Err(errors),
449 }
450}
451
452type Fields = Punctuated<Field, Token![,]>;
453
454fn fields(input: &DeriveInput) -> Result<&Fields> {
455 use syn::Fields;
456
457 match &input.data {
458 Data::Struct(data) => match &data.fields {
459 Fields::Named(fields) => Ok(&fields.named),
460 Fields::Unnamed(fields) => Ok(&fields.unnamed),
461 Fields::Unit => Err(Error::new(
462 Span::call_site(),
463 "RefCast does not support unit structs",
464 )),
465 },
466 Data::Enum(_) => Err(Error::new(
467 Span::call_site(),
468 "RefCast does not support enums",
469 )),
470 Data::Union(_) => Err(Error::new(
471 Span::call_site(),
472 "RefCast does not support unions",
473 )),
474 }
475}
476
477fn only_field_ty(fields: &Fields) -> Result<&Type> {
478 let is_trivial = decide_trivial(fields)?;
479 let mut only_field = None;
480
481 for field in fields {
482 if !is_trivial(field)? {
483 if only_field.take().is_some() {
484 break;
485 }
486 only_field = Some(&field.ty);
487 }
488 }
489
490 only_field.ok_or_else(|| {
491 Error::new(
492 Span::call_site(),
493 "RefCast requires a struct with a single field",
494 )
495 })
496}
497
498fn trivial_fields(fields: &Fields) -> Result<Vec<&Type>> {
499 let is_trivial = decide_trivial(fields)?;
500 let mut trivial = Vec::new();
501
502 for field in fields {
503 if is_trivial(field)? {
504 trivial.push(&field.ty);
505 }
506 }
507
508 Ok(trivial)
509}
510
511fn decide_trivial(fields: &Fields) -> Result<fn(&Field) -> Result<bool>> {
512 for field in fields {
513 if is_explicit_trivial(field)? {
514 return Ok(is_explicit_trivial);
515 }
516 }
517 Ok(is_implicit_trivial)
518}
519
520#[allow(clippy::unnecessary_wraps)] fn is_implicit_trivial(field: &Field) -> Result<bool> {
522 match &field.ty {
523 Type::Tuple(ty) => Ok(ty.elems.is_empty()),
524 Type::Path(ty) => {
525 let ident = &ty.path.segments.last().unwrap().ident;
526 Ok(ident == "PhantomData" || ident == "PhantomPinned")
527 }
528 _ => Ok(false),
529 }
530}
531
532fn is_explicit_trivial(field: &Field) -> Result<bool> {
533 for attr in &field.attrs {
534 if attr.path().is_ident("trivial") {
535 attr.meta.require_path_only()?;
536 return Ok(true);
537 }
538 }
539 Ok(false)
540}