thiserror_impl/
attr.rs

1use proc_macro2::{Delimiter, Group, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
2use quote::{format_ident, quote, ToTokens};
3use std::collections::BTreeSet as Set;
4use syn::parse::discouraged::Speculative;
5use syn::parse::{End, ParseStream};
6use syn::{
7    braced, bracketed, parenthesized, token, Attribute, Error, Ident, Index, LitFloat, LitInt,
8    LitStr, Meta, Result, Token,
9};
10
11pub struct Attrs<'a> {
12    pub display: Option<Display<'a>>,
13    pub source: Option<&'a Attribute>,
14    pub backtrace: Option<&'a Attribute>,
15    pub from: Option<&'a Attribute>,
16    pub transparent: Option<Transparent<'a>>,
17}
18
19#[derive(Clone)]
20pub struct Display<'a> {
21    pub original: &'a Attribute,
22    pub fmt: LitStr,
23    pub args: TokenStream,
24    pub requires_fmt_machinery: bool,
25    pub has_bonus_display: bool,
26    pub implied_bounds: Set<(usize, Trait)>,
27}
28
29#[derive(Copy, Clone)]
30pub struct Transparent<'a> {
31    pub original: &'a Attribute,
32    pub span: Span,
33}
34
/// A `core::fmt` formatting trait. Its `ToTokens` impl renders it as the path
/// `::core::fmt::<Name>`.
///
/// The derived `Ord` follows declaration order, and that order decides how
/// `(usize, Trait)` entries sort inside the `BTreeSet` in
/// `Display::implied_bounds`. Keep the variants in this order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Trait {
    Debug,
    Display,
    Octal,
    LowerHex,
    UpperHex,
    Pointer,
    Binary,
    LowerExp,
    UpperExp,
}
47
48pub fn get(input: &[Attribute]) -> Result<Attrs> {
49    let mut attrs = Attrs {
50        display: None,
51        source: None,
52        backtrace: None,
53        from: None,
54        transparent: None,
55    };
56
57    for attr in input {
58        if attr.path().is_ident("error") {
59            parse_error_attribute(&mut attrs, attr)?;
60        } else if attr.path().is_ident("source") {
61            attr.meta.require_path_only()?;
62            if attrs.source.is_some() {
63                return Err(Error::new_spanned(attr, "duplicate #[source] attribute"));
64            }
65            attrs.source = Some(attr);
66        } else if attr.path().is_ident("backtrace") {
67            attr.meta.require_path_only()?;
68            if attrs.backtrace.is_some() {
69                return Err(Error::new_spanned(attr, "duplicate #[backtrace] attribute"));
70            }
71            attrs.backtrace = Some(attr);
72        } else if attr.path().is_ident("from") {
73            match attr.meta {
74                Meta::Path(_) => {}
75                Meta::List(_) | Meta::NameValue(_) => {
76                    // Assume this is meant for derive_more crate or something.
77                    continue;
78                }
79            }
80            if attrs.from.is_some() {
81                return Err(Error::new_spanned(attr, "duplicate #[from] attribute"));
82            }
83            attrs.from = Some(attr);
84        }
85    }
86
87    Ok(attrs)
88}
89
90fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> {
91    syn::custom_keyword!(transparent);
92
93    attr.parse_args_with(|input: ParseStream| {
94        let lookahead = input.lookahead1();
95        let fmt = if lookahead.peek(LitStr) {
96            input.parse::<LitStr>()?
97        } else if lookahead.peek(transparent) {
98            let kw: transparent = input.parse()?;
99            if attrs.transparent.is_some() {
100                return Err(Error::new_spanned(
101                    attr,
102                    "duplicate #[error(transparent)] attribute",
103                ));
104            }
105            attrs.transparent = Some(Transparent {
106                original: attr,
107                span: kw.span,
108            });
109            return Ok(());
110        } else {
111            return Err(lookahead.error());
112        };
113
114        let args = if input.is_empty() || input.peek(Token![,]) && input.peek2(End) {
115            input.parse::<Option<Token![,]>>()?;
116            TokenStream::new()
117        } else {
118            parse_token_expr(input, false)?
119        };
120
121        let requires_fmt_machinery = !args.is_empty();
122
123        let display = Display {
124            original: attr,
125            fmt,
126            args,
127            requires_fmt_machinery,
128            has_bonus_display: false,
129            implied_bounds: Set::new(),
130        };
131        if attrs.display.is_some() {
132            return Err(Error::new_spanned(
133                attr,
134                "only one #[error(...)] attribute is allowed",
135            ));
136        }
137        attrs.display = Some(display);
138        Ok(())
139    })
140}
141
142fn parse_token_expr(input: ParseStream, mut begin_expr: bool) -> Result<TokenStream> {
143    let mut tokens = Vec::new();
144    while !input.is_empty() {
145        if input.peek(token::Group) {
146            let group: TokenTree = input.parse()?;
147            tokens.push(group);
148            begin_expr = false;
149            continue;
150        }
151
152        if begin_expr && input.peek(Token![.]) {
153            if input.peek2(Ident) {
154                input.parse::<Token![.]>()?;
155                begin_expr = false;
156                continue;
157            } else if input.peek2(LitInt) {
158                input.parse::<Token![.]>()?;
159                let int: Index = input.parse()?;
160                tokens.push({
161                    let ident = format_ident!("_{}", int.index, span = int.span);
162                    TokenTree::Ident(ident)
163                });
164                begin_expr = false;
165                continue;
166            } else if input.peek2(LitFloat) {
167                let ahead = input.fork();
168                ahead.parse::<Token![.]>()?;
169                let float: LitFloat = ahead.parse()?;
170                let repr = float.to_string();
171                let mut indices = repr.split('.').map(syn::parse_str::<Index>);
172                if let (Some(Ok(first)), Some(Ok(second)), None) =
173                    (indices.next(), indices.next(), indices.next())
174                {
175                    input.advance_to(&ahead);
176                    tokens.push({
177                        let ident = format_ident!("_{}", first, span = float.span());
178                        TokenTree::Ident(ident)
179                    });
180                    tokens.push({
181                        let mut punct = Punct::new('.', Spacing::Alone);
182                        punct.set_span(float.span());
183                        TokenTree::Punct(punct)
184                    });
185                    tokens.push({
186                        let mut literal = Literal::u32_unsuffixed(second.index);
187                        literal.set_span(float.span());
188                        TokenTree::Literal(literal)
189                    });
190                    begin_expr = false;
191                    continue;
192                }
193            }
194        }
195
196        begin_expr = input.peek(Token![break])
197            || input.peek(Token![continue])
198            || input.peek(Token![if])
199            || input.peek(Token![in])
200            || input.peek(Token![match])
201            || input.peek(Token![mut])
202            || input.peek(Token![return])
203            || input.peek(Token![while])
204            || input.peek(Token![+])
205            || input.peek(Token![&])
206            || input.peek(Token![!])
207            || input.peek(Token![^])
208            || input.peek(Token![,])
209            || input.peek(Token![/])
210            || input.peek(Token![=])
211            || input.peek(Token![>])
212            || input.peek(Token![<])
213            || input.peek(Token![|])
214            || input.peek(Token![%])
215            || input.peek(Token![;])
216            || input.peek(Token![*])
217            || input.peek(Token![-]);
218
219        let token: TokenTree = if input.peek(token::Paren) {
220            let content;
221            let delimiter = parenthesized!(content in input);
222            let nested = parse_token_expr(&content, true)?;
223            let mut group = Group::new(Delimiter::Parenthesis, nested);
224            group.set_span(delimiter.span.join());
225            TokenTree::Group(group)
226        } else if input.peek(token::Brace) {
227            let content;
228            let delimiter = braced!(content in input);
229            let nested = parse_token_expr(&content, true)?;
230            let mut group = Group::new(Delimiter::Brace, nested);
231            group.set_span(delimiter.span.join());
232            TokenTree::Group(group)
233        } else if input.peek(token::Bracket) {
234            let content;
235            let delimiter = bracketed!(content in input);
236            let nested = parse_token_expr(&content, true)?;
237            let mut group = Group::new(Delimiter::Bracket, nested);
238            group.set_span(delimiter.span.join());
239            TokenTree::Group(group)
240        } else {
241            input.parse()?
242        };
243        tokens.push(token);
244    }
245    Ok(TokenStream::from_iter(tokens))
246}
247
248impl ToTokens for Display<'_> {
249    fn to_tokens(&self, tokens: &mut TokenStream) {
250        let fmt = &self.fmt;
251        let args = &self.args;
252
253        // Currently `write!(f, "text")` produces less efficient code than
254        // `f.write_str("text")`. We recognize the case when the format string
255        // has no braces and no interpolated values, and generate simpler code.
256        tokens.extend(if self.requires_fmt_machinery {
257            quote! {
258                ::core::write!(__formatter, #fmt #args)
259            }
260        } else {
261            quote! {
262                __formatter.write_str(#fmt)
263            }
264        });
265    }
266}
267
268impl ToTokens for Trait {
269    fn to_tokens(&self, tokens: &mut TokenStream) {
270        let trait_name = match self {
271            Trait::Debug => "Debug",
272            Trait::Display => "Display",
273            Trait::Octal => "Octal",
274            Trait::LowerHex => "LowerHex",
275            Trait::UpperHex => "UpperHex",
276            Trait::Pointer => "Pointer",
277            Trait::Binary => "Binary",
278            Trait::LowerExp => "LowerExp",
279            Trait::UpperExp => "UpperExp",
280        };
281        let ident = Ident::new(trait_name, Span::call_site());
282        tokens.extend(quote!(::core::fmt::#ident));
283    }
284}