curve25519_dalek_derive/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use syn::spanned::Spanned;
6
/// Bails out of the enclosing function with an "unsupported" compile error
/// when `$value` evaluates to `Some(_)`, attaching the error to the span of the
/// contained value. Does nothing when `$value` is `None`.
///
/// The enclosing function's return type must be reachable from
/// `proc_macro2::TokenStream` via `.into()`.
macro_rules! unsupported_if_some {
    ($value:expr) => {
        match $value {
            Some(offending) => {
                return syn::Error::new(
                    offending.span(),
                    "unsupported by #[unsafe_target_feature(...)]",
                )
                .into_compile_error()
                .into();
            }
            None => {}
        }
    };
}
16
/// Returns an "unsupported by #[unsafe_target_feature(...)]" compile error from
/// the enclosing function, attached to the span of `$value`.
///
/// The expansion diverges, so it can stand wherever an expression of any type
/// is expected (for example as a `match` arm).
macro_rules! unsupported {
    ($value: expr) => {{
        let error = syn::Error::new(
            $value.span(),
            "unsupported by #[unsafe_target_feature(...)]",
        );
        return error.into_compile_error().into();
    }};
}
27
28mod kw {
29    syn::custom_keyword!(conditional);
30}
31
32enum SpecializeArg {
33    LitStr(syn::LitStr),
34    Conditional(Conditional),
35}
36
37impl SpecializeArg {
38    fn lit(&self) -> &syn::LitStr {
39        match self {
40            SpecializeArg::LitStr(lit) => lit,
41            SpecializeArg::Conditional(conditional) => &conditional.lit,
42        }
43    }
44
45    fn condition(&self) -> Option<&TokenStream2> {
46        match self {
47            SpecializeArg::LitStr(..) => None,
48            SpecializeArg::Conditional(conditional) => Some(&conditional.attr),
49        }
50    }
51}
52
53struct Conditional {
54    lit: syn::LitStr,
55    attr: TokenStream2,
56}
57
58impl syn::parse::Parse for Conditional {
59    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
60        let lit = input.parse()?;
61        input.parse::<syn::Token![,]>()?;
62        let attr = input.parse()?;
63
64        Ok(Conditional { lit, attr })
65    }
66}
67
68impl syn::parse::Parse for SpecializeArg {
69    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
70        let lookahead = input.lookahead1();
71        if lookahead.peek(kw::conditional) {
72            input.parse::<kw::conditional>()?;
73
74            let content;
75            syn::parenthesized!(content in input);
76
77            let conditional = content.parse()?;
78            Ok(SpecializeArg::Conditional(conditional))
79        } else {
80            Ok(SpecializeArg::LitStr(input.parse()?))
81        }
82    }
83}
84
85struct SpecializeArgs(syn::punctuated::Punctuated<SpecializeArg, syn::Token![,]>);
86
87impl syn::parse::Parse for SpecializeArgs {
88    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
89        Ok(Self(syn::punctuated::Punctuated::parse_terminated(input)?))
90    }
91}
92
93#[proc_macro_attribute]
94pub fn unsafe_target_feature(attributes: TokenStream, input: TokenStream) -> TokenStream {
95    let attributes = syn::parse_macro_input!(attributes as syn::LitStr);
96    let item = syn::parse_macro_input!(input as syn::Item);
97    process_item(&attributes, item, true)
98}
99
100#[proc_macro_attribute]
101pub fn unsafe_target_feature_specialize(
102    attributes: TokenStream,
103    input: TokenStream,
104) -> TokenStream {
105    let attributes = syn::parse_macro_input!(attributes as SpecializeArgs);
106    let item_mod = syn::parse_macro_input!(input as syn::ItemMod);
107
108    let mut out = Vec::new();
109    for attributes in attributes.0 {
110        let features: Vec<_> = attributes
111            .lit()
112            .value()
113            .split(',')
114            .map(|feature| feature.replace(' ', ""))
115            .collect();
116        let name = format!("{}_{}", item_mod.ident, features.join("_"));
117        let ident = syn::Ident::new(&name, item_mod.ident.span());
118        let mut attrs = item_mod.attrs.clone();
119        if let Some(condition) = attributes.condition() {
120            attrs.push(syn::Attribute {
121                pound_token: Default::default(),
122                style: syn::AttrStyle::Outer,
123                bracket_token: Default::default(),
124                meta: syn::Meta::List(syn::MetaList {
125                    path: syn::Ident::new("cfg", attributes.lit().span()).into(),
126                    delimiter: syn::MacroDelimiter::Paren(Default::default()),
127                    tokens: condition.clone(),
128                }),
129            });
130        }
131
132        let item_mod = process_mod(
133            attributes.lit(),
134            syn::ItemMod {
135                attrs,
136                ident,
137                ..item_mod.clone()
138            },
139            Some(features),
140        );
141
142        out.push(item_mod);
143    }
144
145    quote::quote! {
146        #(#out)*
147    }
148    .into()
149}
150
151fn process_item(attributes: &syn::LitStr, item: syn::Item, strict: bool) -> TokenStream {
152    match item {
153        syn::Item::Fn(function) => process_function(attributes, function, None),
154        syn::Item::Impl(item_impl) => process_impl(attributes, item_impl),
155        syn::Item::Mod(item_mod) => process_mod(attributes, item_mod, None).into(),
156        item => {
157            if strict {
158                unsupported!(item)
159            } else {
160                quote::quote! { #item }.into()
161            }
162        }
163    }
164}
165
166fn process_mod(
167    attributes: &syn::LitStr,
168    mut item_mod: syn::ItemMod,
169    spec_features: Option<Vec<String>>,
170) -> TokenStream2 {
171    if let Some((_, ref mut content)) = item_mod.content {
172        'next_item: for item in content {
173            if let Some(ref spec_features) = spec_features {
174                match item {
175                    syn::Item::Const(syn::ItemConst { ref mut attrs, .. })
176                    | syn::Item::Enum(syn::ItemEnum { ref mut attrs, .. })
177                    | syn::Item::ExternCrate(syn::ItemExternCrate { ref mut attrs, .. })
178                    | syn::Item::Fn(syn::ItemFn { ref mut attrs, .. })
179                    | syn::Item::ForeignMod(syn::ItemForeignMod { ref mut attrs, .. })
180                    | syn::Item::Impl(syn::ItemImpl { ref mut attrs, .. })
181                    | syn::Item::Macro(syn::ItemMacro { ref mut attrs, .. })
182                    | syn::Item::Mod(syn::ItemMod { ref mut attrs, .. })
183                    | syn::Item::Static(syn::ItemStatic { ref mut attrs, .. })
184                    | syn::Item::Struct(syn::ItemStruct { ref mut attrs, .. })
185                    | syn::Item::Trait(syn::ItemTrait { ref mut attrs, .. })
186                    | syn::Item::TraitAlias(syn::ItemTraitAlias { ref mut attrs, .. })
187                    | syn::Item::Type(syn::ItemType { ref mut attrs, .. })
188                    | syn::Item::Union(syn::ItemUnion { ref mut attrs, .. })
189                    | syn::Item::Use(syn::ItemUse { ref mut attrs, .. }) => {
190                        let mut index = 0;
191                        while index < attrs.len() {
192                            let attr = &attrs[index];
193                            if matches!(attr.style, syn::AttrStyle::Outer) {
194                                match attr.meta {
195                                    syn::Meta::List(ref list)
196                                        if is_path_eq(&list.path, "for_target_feature") =>
197                                    {
198                                        let feature: syn::LitStr = match list.parse_args() {
199                                            Ok(feature) => feature,
200                                            Err(error) => {
201                                                return error.into_compile_error();
202                                            }
203                                        };
204
205                                        let feature = feature.value();
206                                        if !spec_features
207                                            .iter()
208                                            .any(|enabled_feature| feature == *enabled_feature)
209                                        {
210                                            *item = syn::Item::Verbatim(Default::default());
211                                            continue 'next_item;
212                                        }
213
214                                        attrs.remove(index);
215                                        continue;
216                                    }
217                                    _ => {}
218                                }
219                            }
220
221                            index += 1;
222                            continue;
223                        }
224                    }
225                    _ => {
226                        unsupported!(item_mod);
227                    }
228                }
229            }
230
231            *item = syn::Item::Verbatim(
232                process_item(
233                    attributes,
234                    std::mem::replace(item, syn::Item::Verbatim(Default::default())),
235                    false,
236                )
237                .into(),
238            );
239        }
240    }
241
242    quote::quote! {
243        #item_mod
244    }
245}
246
247fn process_impl(attributes: &syn::LitStr, mut item_impl: syn::ItemImpl) -> TokenStream {
248    unsupported_if_some!(item_impl.defaultness);
249    unsupported_if_some!(item_impl.unsafety);
250
251    let mut items = Vec::new();
252    for item in item_impl.items.drain(..) {
253        match item {
254            syn::ImplItem::Fn(function) => {
255                unsupported_if_some!(function.defaultness);
256                let function = syn::ItemFn {
257                    attrs: function.attrs,
258                    vis: function.vis,
259                    sig: function.sig,
260                    block: Box::new(function.block),
261                };
262                let output_item = process_function(
263                    attributes,
264                    function,
265                    Some((item_impl.generics.clone(), item_impl.self_ty.clone())),
266                );
267                items.push(syn::ImplItem::Verbatim(output_item.into()));
268            }
269            item => items.push(item),
270        }
271    }
272
273    item_impl.items = items;
274    quote::quote! {
275        #item_impl
276    }
277    .into()
278}
279
280fn is_path_eq(path: &syn::Path, ident: &str) -> bool {
281    let segments: Vec<_> = ident.split("::").collect();
282    path.segments.len() == segments.len()
283        && path
284            .segments
285            .iter()
286            .zip(segments.iter())
287            .all(|(segment, expected)| segment.ident == expected && segment.arguments.is_none())
288}
289
290fn process_function(
291    attributes: &syn::LitStr,
292    function: syn::ItemFn,
293    outer: Option<(syn::Generics, Box<syn::Type>)>,
294) -> TokenStream {
295    if function.sig.unsafety.is_some() {
296        return quote::quote! {
297            #[target_feature(enable = #attributes)]
298            #function
299        }
300        .into();
301    }
302
303    unsupported_if_some!(function.sig.constness);
304    unsupported_if_some!(function.sig.asyncness);
305    unsupported_if_some!(function.sig.abi);
306    unsupported_if_some!(function.sig.variadic);
307
308    let function_visibility = function.vis;
309    let function_name = function.sig.ident;
310    let function_return = function.sig.output;
311    let function_inner_name =
312        syn::Ident::new(&format!("_impl_{}", function_name), function_name.span());
313    let function_args = function.sig.inputs;
314    let function_body = function.block;
315    let mut function_call_args = Vec::new();
316    let mut function_args_outer = Vec::new();
317    let mut function_args_inner = Vec::new();
318    for (index, arg) in function_args.iter().enumerate() {
319        match arg {
320            syn::FnArg::Receiver(receiver) => {
321                unsupported_if_some!(receiver.attrs.first());
322                unsupported_if_some!(receiver.colon_token);
323
324                if outer.is_none() {
325                    return syn::Error::new(receiver.span(), "unsupported by #[unsafe_target_feature(...)]; put the attribute on the outer `impl`").into_compile_error().into();
326                }
327
328                function_args_inner.push(syn::FnArg::Receiver(receiver.clone()));
329                function_args_outer.push(syn::FnArg::Receiver(receiver.clone()));
330                function_call_args.push(syn::Ident::new("self", receiver.self_token.span()));
331            }
332            syn::FnArg::Typed(ty) => {
333                unsupported_if_some!(ty.attrs.first());
334
335                match &*ty.pat {
336                    syn::Pat::Ident(pat_ident) => {
337                        unsupported_if_some!(pat_ident.attrs.first());
338
339                        function_args_inner.push(arg.clone());
340                        function_args_outer.push(syn::FnArg::Typed(syn::PatType {
341                            attrs: Vec::new(),
342                            pat: Box::new(syn::Pat::Ident(syn::PatIdent {
343                                attrs: Vec::new(),
344                                by_ref: None,
345                                mutability: None,
346                                ident: pat_ident.ident.clone(),
347                                subpat: None,
348                            })),
349                            colon_token: ty.colon_token,
350                            ty: ty.ty.clone(),
351                        }));
352                        function_call_args.push(pat_ident.ident.clone());
353                    }
354                    syn::Pat::Wild(pat_wild) => {
355                        unsupported_if_some!(pat_wild.attrs.first());
356
357                        let ident = syn::Ident::new(
358                            &format!("__arg_{}__", index),
359                            pat_wild.underscore_token.span(),
360                        );
361                        function_args_inner.push(arg.clone());
362                        function_args_outer.push(syn::FnArg::Typed(syn::PatType {
363                            attrs: Vec::new(),
364                            pat: Box::new(syn::Pat::Ident(syn::PatIdent {
365                                attrs: Vec::new(),
366                                by_ref: None,
367                                mutability: None,
368                                ident: ident.clone(),
369                                subpat: None,
370                            })),
371                            colon_token: ty.colon_token,
372                            ty: ty.ty.clone(),
373                        }));
374                        function_call_args.push(ident);
375                    }
376                    _ => unsupported!(arg),
377                }
378            }
379        }
380    }
381
382    let mut maybe_inline = quote::quote! {};
383    let mut maybe_outer_attributes = Vec::new();
384    let mut maybe_cfg = quote::quote! {};
385    for attribute in function.attrs {
386        match &attribute.meta {
387            syn::Meta::Path(path) if is_path_eq(path, "inline") => {
388                maybe_inline = quote::quote! { #[inline] };
389            }
390            syn::Meta::Path(path) if is_path_eq(path, "test") => {
391                maybe_outer_attributes.push(attribute);
392                maybe_cfg = quote::quote! { #[cfg(target_feature = #attributes)] };
393            }
394            syn::Meta::List(syn::MetaList { path, tokens, .. })
395                if is_path_eq(path, "inline") && tokens.to_string() == "always" =>
396            {
397                maybe_inline = quote::quote! { #[inline] };
398            }
399            syn::Meta::NameValue(syn::MetaNameValue { path, .. }) if is_path_eq(path, "doc") => {
400                maybe_outer_attributes.push(attribute);
401            }
402            syn::Meta::List(syn::MetaList { path, .. })
403                if is_path_eq(path, "cfg")
404                    || is_path_eq(path, "allow")
405                    || is_path_eq(path, "deny") =>
406            {
407                maybe_outer_attributes.push(attribute);
408            }
409            syn::Meta::Path(path) if is_path_eq(path, "rustfmt::skip") => {
410                maybe_outer_attributes.push(attribute);
411            }
412            _ => unsupported!(attribute),
413        }
414    }
415
416    let (fn_impl_generics, fn_ty_generics, fn_where_clause) =
417        function.sig.generics.split_for_impl();
418    let fn_call_generics = fn_ty_generics.as_turbofish();
419
420    if let Some((generics, self_ty)) = outer {
421        let (outer_impl_generics, outer_ty_generics, outer_where_clause) =
422            generics.split_for_impl();
423        let trait_ident =
424            syn::Ident::new(&format!("__Impl_{}__", function_name), function_name.span());
425        let item_trait = quote::quote! {
426            #[allow(non_camel_case_types)]
427            trait #trait_ident #outer_impl_generics #outer_where_clause {
428                unsafe fn #function_inner_name #fn_impl_generics (#(#function_args_outer),*) #function_return #fn_where_clause;
429            }
430        };
431
432        let item_trait_impl = quote::quote! {
433            impl #outer_impl_generics #trait_ident #outer_ty_generics for #self_ty #outer_where_clause {
434                #[target_feature(enable = #attributes)]
435                #maybe_inline
436                unsafe fn #function_inner_name #fn_impl_generics (#(#function_args_inner),*) #function_return #fn_where_clause #function_body
437            }
438        };
439
440        quote::quote! {
441            #[inline(always)]
442            #(#maybe_outer_attributes)*
443            #function_visibility fn #function_name #fn_impl_generics (#(#function_args_outer),*) #function_return #fn_where_clause {
444                #item_trait
445                #item_trait_impl
446                unsafe {
447                    <Self as #trait_ident #outer_ty_generics> ::#function_inner_name #fn_call_generics (#(#function_call_args),*)
448                }
449            }
450        }.into()
451    } else {
452        quote::quote! {
453            #[inline(always)]
454            #maybe_cfg
455            #(#maybe_outer_attributes)*
456            #function_visibility fn #function_name #fn_impl_generics (#(#function_args_outer),*) #function_return #fn_where_clause {
457                #[target_feature(enable = #attributes)]
458                #maybe_inline
459                unsafe fn #function_inner_name #fn_impl_generics (#(#function_args_inner),*) #function_return #fn_where_clause #function_body
460                unsafe {
461                    #function_inner_name #fn_call_generics (#(#function_call_args),*)
462                }
463            }
464        }.into()
465    }
466}