asn1_rs_derive/
container.rs

1use proc_macro2::{Literal, Span, TokenStream};
2use quote::{quote, ToTokens};
3use syn::{
4    parse::ParseStream, parse_quote, spanned::Spanned, Attribute, DataStruct, DeriveInput, Field,
5    Fields, Ident, Lifetime, LitInt, Meta, Type, WherePredicate,
6};
7
/// Shape of the struct a BER/DER container derive is applied to.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum ContainerType {
    /// Tuple struct with exactly one field, wrapping another ASN.1 type.
    Alias,
    /// Named-field struct encoded with the SEQUENCE tag.
    Sequence,
    /// Named-field struct encoded with the SET tag.
    Set,
}
14
15impl ToTokens for ContainerType {
16    fn to_tokens(&self, tokens: &mut TokenStream) {
17        let s = match self {
18            ContainerType::Alias => quote! {},
19            ContainerType::Sequence => quote! { asn1_rs::Tag::Sequence },
20            ContainerType::Set => quote! { asn1_rs::Tag::Set },
21        };
22        s.to_tokens(tokens)
23    }
24}
25
/// Encoding rules the generated field parsers use.
#[derive(Debug, PartialEq, Clone, Copy)]
enum Asn1Type {
    /// Parse fields with `FromBer::from_ber`.
    Ber,
    /// Parse fields with `FromDer::from_der`.
    Der,
}
31
/// Tagging mode of a field, chosen with `#[tag_explicit(..)]` or `#[tag_implicit(..)]`.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum Asn1TagKind {
    /// Tag applied with `#[tag_explicit(..)]`.
    Explicit,
    /// Tag applied with `#[tag_implicit(..)]`.
    Implicit,
}
37
38impl ToTokens for Asn1TagKind {
39    fn to_tokens(&self, tokens: &mut TokenStream) {
40        let s = match self {
41            Asn1TagKind::Explicit => quote! { asn1_rs::Explicit },
42            Asn1TagKind::Implicit => quote! { asn1_rs::Implicit },
43        };
44        s.to_tokens(tokens)
45    }
46}
47
/// ASN.1 tag class given in a field's tag attribute.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Asn1TagClass {
    /// `UNIVERSAL` class.
    Universal,
    /// `APPLICATION` class.
    Application,
    /// Context-specific class (the default when no class is written).
    ContextSpecific,
    /// `PRIVATE` class.
    Private,
}
55
56impl ToTokens for Asn1TagClass {
57    fn to_tokens(&self, tokens: &mut TokenStream) {
58        let s = match self {
59            Asn1TagClass::Application => quote! { asn1_rs::Class::APPLICATION },
60            Asn1TagClass::ContextSpecific => quote! { asn1_rs::Class::CONTEXT_SPECIFIC },
61            Asn1TagClass::Private => quote! { asn1_rs::Class::PRIVATE },
62            Asn1TagClass::Universal => quote! { asn1_rs::Class::UNIVERSAL },
63        };
64        s.to_tokens(tokens)
65    }
66}
67
68pub struct Container {
69    pub container_type: ContainerType,
70    pub fields: Vec<FieldInfo>,
71    pub where_predicates: Vec<WherePredicate>,
72    pub error: Option<Attribute>,
73
74    is_any: bool,
75}
76
77impl Container {
78    pub fn from_datastruct(
79        ds: &DataStruct,
80        ast: &DeriveInput,
81        container_type: ContainerType,
82    ) -> Self {
83        let mut is_any = false;
84        match (container_type, &ds.fields) {
85            (ContainerType::Alias, Fields::Unnamed(f)) => {
86                if f.unnamed.len() != 1 {
87                    panic!("Alias: only tuple fields with one element are supported");
88                }
89                match &f.unnamed[0].ty {
90                    Type::Path(type_path)
91                        if type_path
92                            .clone()
93                            .into_token_stream()
94                            .to_string()
95                            .starts_with("Any") =>
96                    {
97                        is_any = true;
98                    }
99                    _ => (),
100                }
101            }
102            (ContainerType::Alias, _) => panic!("BER/DER alias must be used with tuple strucs"),
103            (_, Fields::Unnamed(_)) => panic!("BER/DER sequence cannot be used on tuple structs"),
104            _ => (),
105        }
106
107        let fields = ds.fields.iter().map(FieldInfo::from).collect();
108
109        // get lifetimes from generics
110        let lfts: Vec<_> = ast.generics.lifetimes().collect();
111        let mut where_predicates = Vec::new();
112        if !lfts.is_empty() {
113            // input slice must outlive all lifetimes from Self
114            let lft = Lifetime::new("'ber", Span::call_site());
115            let wh: WherePredicate = parse_quote! { #lft: #(#lfts)+* };
116            where_predicates.push(wh);
117        };
118
119        // get custom attributes on container
120        let error = ast
121            .attrs
122            .iter()
123            .find(|attr| {
124                attr.meta
125                    .path()
126                    .is_ident(&Ident::new("error", Span::call_site()))
127            })
128            .cloned();
129
130        Container {
131            container_type,
132            fields,
133            where_predicates,
134            error,
135            is_any,
136        }
137    }
138
139    pub fn gen_tryfrom(&self) -> TokenStream {
140        let field_names = &self.fields.iter().map(|f| &f.name).collect::<Vec<_>>();
141        let parse_content =
142            derive_ber_sequence_content(&self.fields, Asn1Type::Ber, self.error.is_some());
143        let lifetime = Lifetime::new("'ber", Span::call_site());
144        let wh = &self.where_predicates;
145        let error = if let Some(attr) = &self.error {
146            get_attribute_meta(attr).expect("Invalid error attribute format")
147        } else {
148            quote! { asn1_rs::Error }
149        };
150
151        let fn_content = if self.container_type == ContainerType::Alias {
152            // special case: is this an alias for Any
153            if self.is_any {
154                quote! { Ok(Self(any)) }
155            } else {
156                quote! {
157                    let res = TryFrom::try_from(any)?;
158                    Ok(Self(res))
159                }
160            }
161        } else {
162            quote! {
163                use asn1_rs::nom::*;
164                any.tag().assert_eq(Self::TAG)?;
165
166                // no need to parse sequence, we already have content
167                let i = any.data;
168                //
169                #parse_content
170                //
171                let _ = i; // XXX check if empty?
172                Ok(Self{#(#field_names),*})
173            }
174        };
175        // note: `gen impl` in synstructure takes care of appending extra where clauses if any, and removing
176        // the `where` statement if there are none.
177        quote! {
178            use asn1_rs::{Any, FromBer};
179            use core::convert::TryFrom;
180
181            gen impl<#lifetime> TryFrom<Any<#lifetime>> for @Self where #(#wh)+* {
182                type Error = #error;
183
184                fn try_from(any: Any<#lifetime>) -> asn1_rs::Result<Self, #error> {
185                    #fn_content
186                }
187            }
188        }
189    }
190
191    pub fn gen_tagged(&self) -> TokenStream {
192        let tag = if self.container_type == ContainerType::Alias {
193            // special case: is this an alias for Any
194            if self.is_any {
195                return quote! {};
196            }
197            // find type of sub-item
198            let ty = &self.fields[0].type_;
199            quote! { <#ty as asn1_rs::Tagged>::TAG }
200        } else {
201            let container_type = self.container_type;
202            quote! { #container_type }
203        };
204        quote! {
205            gen impl<'ber> asn1_rs::Tagged for @Self {
206                const TAG: asn1_rs::Tag = #tag;
207            }
208        }
209    }
210
211    pub fn gen_checkconstraints(&self) -> TokenStream {
212        let lifetime = Lifetime::new("'ber", Span::call_site());
213        let wh = &self.where_predicates;
214        // let parse_content = derive_ber_sequence_content(&field_names, Asn1Type::Der);
215
216        let fn_content = if self.container_type == ContainerType::Alias {
217            // special case: is this an alias for Any
218            if self.is_any {
219                return quote! {};
220            }
221            let ty = &self.fields[0].type_;
222            quote! {
223                any.tag().assert_eq(Self::TAG)?;
224                <#ty>::check_constraints(any)
225            }
226        } else {
227            let check_fields: Vec<_> = self
228                .fields
229                .iter()
230                .map(|field| {
231                    let ty = &field.type_;
232                    quote! {
233                        let (rem, any) = Any::from_der(rem)?;
234                        <#ty as CheckDerConstraints>::check_constraints(&any)?;
235                    }
236                })
237                .collect();
238            quote! {
239                any.tag().assert_eq(Self::TAG)?;
240                let rem = &any.data;
241                #(#check_fields)*
242                Ok(())
243            }
244        };
245
246        // note: `gen impl` in synstructure takes care of appending extra where clauses if any, and removing
247        // the `where` statement if there are none.
248        quote! {
249            use asn1_rs::{CheckDerConstraints, Tagged};
250            gen impl<#lifetime> CheckDerConstraints for @Self where #(#wh)+* {
251                fn check_constraints(any: &Any) -> asn1_rs::Result<()> {
252                    #fn_content
253                }
254            }
255        }
256    }
257
258    pub fn gen_fromder(&self) -> TokenStream {
259        let lifetime = Lifetime::new("'ber", Span::call_site());
260        let wh = &self.where_predicates;
261        let field_names = &self.fields.iter().map(|f| &f.name).collect::<Vec<_>>();
262        let parse_content =
263            derive_ber_sequence_content(&self.fields, Asn1Type::Der, self.error.is_some());
264        let error = if let Some(attr) = &self.error {
265            get_attribute_meta(attr).expect("Invalid error attribute format")
266        } else {
267            quote! { asn1_rs::Error }
268        };
269
270        let fn_content = if self.container_type == ContainerType::Alias {
271            // special case: is this an alias for Any
272            if self.is_any {
273                quote! {
274                    let (rem, any) = asn1_rs::Any::from_der(bytes).map_err(asn1_rs::nom::Err::convert)?;
275                    Ok((rem,Self(any)))
276                }
277            } else {
278                quote! {
279                    let (rem, any) = asn1_rs::Any::from_der(bytes).map_err(asn1_rs::nom::Err::convert)?;
280                    any.header.assert_tag(Self::TAG).map_err(|e| asn1_rs::nom::Err::Error(e.into()))?;
281                    let res = TryFrom::try_from(any)?;
282                    Ok((rem,Self(res)))
283                }
284            }
285        } else {
286            quote! {
287                let (rem, any) = asn1_rs::Any::from_der(bytes).map_err(asn1_rs::nom::Err::convert)?;
288                any.header.assert_tag(Self::TAG).map_err(|e| asn1_rs::nom::Err::Error(e.into()))?;
289                let i = any.data;
290                //
291                #parse_content
292                //
293                // let _ = i; // XXX check if empty?
294                Ok((rem,Self{#(#field_names),*}))
295            }
296        };
297        // note: `gen impl` in synstructure takes care of appending extra where clauses if any, and removing
298        // the `where` statement if there are none.
299        quote! {
300            use asn1_rs::FromDer;
301
302            gen impl<#lifetime> asn1_rs::FromDer<#lifetime, #error> for @Self where #(#wh)+* {
303                fn from_der(bytes: &#lifetime [u8]) -> asn1_rs::ParseResult<#lifetime, Self, #error> {
304                    #fn_content
305                }
306            }
307        }
308    }
309
310    pub fn gen_to_der_len(&self) -> TokenStream {
311        let field_names = &self.fields.iter().map(|f| &f.name).collect::<Vec<_>>();
312        let add_len_instructions = field_names.iter().fold(Vec::new(), |mut instrs, field| {
313            instrs.push(quote! {total_len += self.#field.to_der_len()?;});
314            instrs
315        });
316        quote! {
317            fn to_der_len(&self) -> asn1_rs::Result<usize> {
318                let mut total_len = 0;
319                #(#add_len_instructions)*
320                // now add header length computation
321                if total_len < 127 {
322                    // 1 (class+tag) + 1 (length) + len
323                    Ok(2 + total_len)
324                } else {
325                    // 1 (class+tag) + n (length) + len
326                    let n = asn1_rs::Length::Definite(total_len).to_der_len()?;
327                    Ok(1 + n + total_len)
328                }
329            }
330        }
331    }
332
333    pub fn gen_write_der_header(&self) -> TokenStream {
334        quote! {
335            fn write_der_header(&self, writer: &mut dyn std::io::Write) -> asn1_rs::SerializeResult<usize> {
336                let mut empty = std::io::empty();
337                let num_bytes = self.write_der_content(&mut empty)?;
338                let header = asn1_rs::Header::new(
339                    asn1_rs::Class::Universal,
340                    true,
341                    asn1_rs::Sequence::TAG,
342                    asn1_rs::Length::Definite(num_bytes),
343                );
344                header.write_der_header(writer).map_err(Into::into)
345            }
346        }
347    }
348
349    pub fn gen_write_der_content(&self) -> TokenStream {
350        let field_names = &self.fields.iter().map(|f| &f.name).collect::<Vec<_>>();
351        let write_instructions = field_names.iter().fold(Vec::new(), |mut instrs, field| {
352            instrs.push(quote! {num_bytes += self.#field.write_der_header(writer)?;});
353            instrs.push(quote! {num_bytes += self.#field.write_der_content(writer)?;});
354            instrs
355        });
356        quote! {
357            fn write_der_content(&self, writer: &mut dyn std::io::Write) -> asn1_rs::SerializeResult<usize> {
358                let mut num_bytes = 0;
359                #(#write_instructions)*
360                Ok(num_bytes)
361            }
362        }
363    }
364}
365
366#[derive(Debug)]
367pub struct FieldInfo {
368    pub name: Ident,
369    pub type_: Type,
370    pub default: Option<TokenStream>,
371    pub optional: bool,
372    pub tag: Option<(Asn1TagKind, Asn1TagClass, u16)>,
373    pub map_err: Option<TokenStream>,
374}
375
376impl From<&Field> for FieldInfo {
377    fn from(field: &Field) -> Self {
378        // parse attributes and keep supported ones
379        let mut optional = false;
380        let mut tag = None;
381        let mut map_err = None;
382        let mut default = None;
383        let name = field
384            .ident
385            .as_ref()
386            .map_or_else(|| Ident::new("_", Span::call_site()), |s| s.clone());
387        for attr in &field.attrs {
388            let ident = match attr.meta.path().get_ident() {
389                Some(ident) => ident.to_string(),
390                None => continue,
391            };
392            match ident.as_str() {
393                "map_err" => {
394                    let expr: syn::Expr = attr.parse_args().expect("could not parse map_err");
395                    map_err = Some(quote! { #expr });
396                }
397                "default" => {
398                    let expr: syn::Expr = attr.parse_args().expect("could not parse default");
399                    default = Some(quote! { #expr });
400                    optional = true;
401                }
402                "optional" => optional = true,
403                "tag_explicit" => {
404                    if tag.is_some() {
405                        panic!("tag cannot be set twice!");
406                    }
407                    let (class, value) = attr.parse_args_with(parse_tag_args).unwrap();
408                    tag = Some((Asn1TagKind::Explicit, class, value));
409                }
410                "tag_implicit" => {
411                    if tag.is_some() {
412                        panic!("tag cannot be set twice!");
413                    }
414                    let (class, value) = attr.parse_args_with(parse_tag_args).unwrap();
415                    tag = Some((Asn1TagKind::Implicit, class, value));
416                }
417                // ignore unknown attributes
418                _ => (),
419            }
420        }
421        FieldInfo {
422            name,
423            type_: field.ty.clone(),
424            default,
425            optional,
426            tag,
427            map_err,
428        }
429    }
430}
431
432fn parse_tag_args(stream: ParseStream) -> Result<(Asn1TagClass, u16), syn::Error> {
433    let tag_class: Option<Ident> = stream.parse()?;
434    let tag_class = if let Some(ident) = tag_class {
435        let s = ident.to_string().to_uppercase();
436        match s.as_str() {
437            "UNIVERSAL" => Asn1TagClass::Universal,
438            "CONTEXT-SPECIFIC" => Asn1TagClass::ContextSpecific,
439            "APPLICATION" => Asn1TagClass::Application,
440            "PRIVATE" => Asn1TagClass::Private,
441            _ => {
442                return Err(syn::Error::new(stream.span(), "Invalid tag class"));
443            }
444        }
445    } else {
446        Asn1TagClass::ContextSpecific
447    };
448    let lit: LitInt = stream.parse()?;
449    let value = lit.base10_parse::<u16>()?;
450    Ok((tag_class, value))
451}
452
453fn derive_ber_sequence_content(
454    fields: &[FieldInfo],
455    asn1_type: Asn1Type,
456    custom_errors: bool,
457) -> TokenStream {
458    let field_parsers: Vec<_> = fields
459        .iter()
460        .map(|f| get_field_parser(f, asn1_type, custom_errors))
461        .collect();
462
463    quote! {
464        #(#field_parsers)*
465    }
466}
467
468fn get_field_parser(f: &FieldInfo, asn1_type: Asn1Type, custom_errors: bool) -> TokenStream {
469    let from = match asn1_type {
470        Asn1Type::Ber => quote! {FromBer::from_ber},
471        Asn1Type::Der => quote! {FromDer::from_der},
472    };
473    let name = &f.name;
474    let default = f
475        .default
476        .as_ref()
477        // use a type hint, otherwise compiler will not know what type provides .unwrap_or
478        .map(|x| quote! {let #name: Option<_> = #name; let #name = #name.unwrap_or(#x);});
479    let map_err = if let Some(tt) = f.map_err.as_ref() {
480        if asn1_type == Asn1Type::Ber {
481            Some(quote! {
482                .map_err(|err| err.map(#tt))
483                .map_err(asn1_rs::from_nom_error::<_, Self::Error>)
484            })
485        } else {
486            // Some(quote! { .map_err(|err| nom::Err::convert(#tt)) })
487            Some(quote! { .map_err(|err| err.map(#tt)) })
488        }
489    } else {
490        // add mapping functions only if custom errors are used
491        if custom_errors {
492            if asn1_type == Asn1Type::Ber {
493                Some(quote! { .map_err(asn1_rs::from_nom_error::<_, Self::Error>) })
494            } else {
495                Some(quote! { .map_err(nom::Err::convert) })
496            }
497        } else {
498            None
499        }
500    };
501    if let Some((tag_kind, class, n)) = f.tag {
502        let tag = Literal::u16_unsuffixed(n);
503        // test if tagged + optional
504        if f.optional {
505            return quote! {
506                let (i, #name) = {
507                    if i.is_empty() {
508                        (i, None)
509                    } else {
510                        let (_, header): (_, asn1_rs::Header) = #from(i)#map_err?;
511                        if header.tag().0 == #tag {
512                            let (i, t): (_, asn1_rs::TaggedValue::<_, _, #tag_kind, {#class}, #tag>) = #from(i)#map_err?;
513                            (i, Some(t.into_inner()))
514                        } else {
515                            (i, None)
516                        }
517                    }
518                };
519                #default
520            };
521        } else {
522            // tagged, but not OPTIONAL
523            return quote! {
524                let (i, #name) = {
525                    let (i, t): (_, asn1_rs::TaggedValue::<_, _, #tag_kind, {#class}, #tag>) = #from(i)#map_err?;
526                    (i, t.into_inner())
527                };
528                #default
529            };
530        }
531    } else {
532        // neither tagged nor optional
533        quote! {
534            let (i, #name) = #from(i)#map_err?;
535            #default
536        }
537    }
538}
539
540fn get_attribute_meta(attr: &Attribute) -> Result<TokenStream, syn::Error> {
541    if let Meta::List(meta) = &attr.meta {
542        let content = &meta.tokens;
543        Ok(quote! { #content })
544    } else {
545        Err(syn::Error::new(
546            attr.span(),
547            "Invalid error attribute format",
548        ))
549    }
550}