thiserror_impl/
fmt.rs

1use crate::ast::Field;
2use crate::attr::{Display, Trait};
3use crate::scan_expr::scan_expr;
4use proc_macro2::{TokenStream, TokenTree};
5use quote::{format_ident, quote, quote_spanned};
6use std::collections::{BTreeSet as Set, HashMap as Map};
7use syn::ext::IdentExt;
8use syn::parse::discouraged::Speculative;
9use syn::parse::{ParseStream, Parser};
10use syn::{Expr, Ident, Index, LitStr, Member, Result, Token};
11
impl Display<'_> {
    /// Rewrites the format string so that every `{name}` reference, and every
    /// `{N}` reference that names one of `fields`, becomes an explicit named
    /// format argument.
    ///
    /// For example `"error {var}"` becomes `"error {var}", var = var`; a
    /// positional field reference `{0}` becomes `{_0}` with `_0 = _0`
    /// appended; a raw identifier `{r#x}` becomes `{r_x}` with `r_x = r#x`
    /// appended. Names that are already explicit arguments are not added
    /// again. When a reference names one of `fields` and is immediately
    /// closed by `}` (no format spec), the appended argument gets
    /// `.as_display()` and `has_bonus_display` is set. The locals `_0`, `var`,
    /// ... are presumably bound by generated code that destructures the
    /// fields; that code is not visible here.
    ///
    /// For every referenced field, the formatting trait selected by the last
    /// character of its format spec (`?` -> Debug, `o` -> Octal, `x` ->
    /// LowerHex, `X` -> UpperHex, `p` -> Pointer, `b` -> Binary, `e` ->
    /// LowerExp, `E` -> UpperExp, anything else -> Display) is recorded in
    /// `implied_bounds` together with the field's position in `fields`.
    ///
    /// On success this overwrites `self.fmt`, `self.args`,
    /// `self.has_bonus_display` and `self.implied_bounds`. If the string is
    /// malformed (a `{` at the very end, an index that does not fit in `u32`,
    /// or a field reference with no closing `}`), it returns early and those
    /// four fields keep their previous values. The likely intent is that
    /// rustc then reports the error against the user's original string.
    /// `self.requires_fmt_machinery` is set to `true` whenever the string
    /// contains `{` or `}`, including on the early-return paths.
    pub fn expand_shorthand(&mut self, fields: &[Field]) {
        // Names already passed explicitly after the format string
        // (`name = expr`); these must not be appended a second time.
        let raw_args = self.args.clone();
        let mut named_args = explicit_named_args.parse2(raw_args).unwrap().named;
        // Map each field's member (`foo` or `0`) to its position in `fields`.
        let mut member_index = Map::new();
        for (i, field) in fields.iter().enumerate() {
            member_index.insert(&field.member, i);
        }

        let span = self.fmt.span();
        let fmt = self.fmt.value();
        // `read` is the not-yet-processed tail of the format string; `out`
        // accumulates the rewritten string.
        let mut read = fmt.as_str();
        let mut out = String::new();
        let mut args = self.args.clone();
        let mut has_bonus_display = false;
        let mut implied_bounds = Set::new();

        // If the user's argument list already ends in `,`, the first appended
        // argument must not get another leading comma.
        let mut has_trailing_comma = false;
        if let Some(TokenTree::Punct(punct)) = args.clone().into_iter().last() {
            if punct.as_char() == ',' {
                has_trailing_comma = true;
            }
        }

        self.requires_fmt_machinery = self.requires_fmt_machinery || fmt.contains('}');

        while let Some(brace) = read.find('{') {
            self.requires_fmt_machinery = true;
            // Copy everything up to and including the `{`.
            out += &read[..brace + 1];
            read = &read[brace + 1..];
            // `{{` is an escaped literal brace: copy the second `{` and move on.
            if read.starts_with('{') {
                out.push('{');
                read = &read[1..];
                continue;
            }
            // A `{` at the very end is malformed: give up without rewriting.
            let next = match read.chars().next() {
                Some(next) => next,
                None => return,
            };
            let member = match next {
                // Positional reference such as `{0}`. Indices that do not name
                // a field are copied through unchanged.
                '0'..='9' => {
                    let int = take_int(&mut read);
                    let member = match int.parse::<u32>() {
                        Ok(index) => Member::Unnamed(Index { index, span }),
                        Err(_) => return,
                    };
                    if !member_index.contains_key(&member) {
                        out += &int;
                        continue;
                    }
                    member
                }
                // Named reference such as `{var}` or `{r#var}`; rewritten even
                // when it does not name a field.
                'a'..='z' | 'A'..='Z' | '_' => {
                    let mut ident = take_ident(&mut read);
                    ident.set_span(span);
                    Member::Named(ident)
                }
                // `{}`, `{:?}`, non-ASCII names etc. are left untouched.
                _ => continue,
            };
            if let Some(&field) = member_index.get(&member) {
                // The format spec runs up to the next `}`; its last character
                // selects the formatting trait this field must implement.
                let end_spec = match read.find('}') {
                    Some(end_spec) => end_spec,
                    None => return,
                };
                let bound = match read[..end_spec].chars().next_back() {
                    Some('?') => Trait::Debug,
                    Some('o') => Trait::Octal,
                    Some('x') => Trait::LowerHex,
                    Some('X') => Trait::UpperHex,
                    Some('p') => Trait::Pointer,
                    Some('b') => Trait::Binary,
                    Some('e') => Trait::LowerExp,
                    Some('E') => Trait::UpperExp,
                    Some(_) | None => Trait::Display,
                };
                implied_bounds.insert((field, bound));
            }
            // Binding the value is read from: `_N` for a positional field,
            // otherwise the identifier itself.
            let local = match &member {
                Member::Unnamed(index) => format_ident!("_{}", index),
                Member::Named(ident) => ident.clone(),
            };
            // Name written into the format string. A raw identifier `r#x` is
            // renamed to `r_x`, presumably because `r#` is not accepted
            // inside a format string.
            let mut formatvar = local.clone();
            if formatvar.to_string().starts_with("r#") {
                formatvar = format_ident!("r_{}", formatvar);
            }
            out += &formatvar.to_string();
            if !named_args.insert(formatvar.clone()) {
                // Already an explicit argument, or already appended for an
                // earlier reference to the same name.
                continue;
            }
            if !has_trailing_comma {
                args.extend(quote_spanned!(span=> ,));
            }
            args.extend(quote_spanned!(span=> #formatvar = #local));
            // A field referenced with no format spec is passed through
            // `.as_display()`.
            if read.starts_with('}') && member_index.contains_key(&member) {
                has_bonus_display = true;
                args.extend(quote_spanned!(span=> .as_display()));
            }
            has_trailing_comma = false;
        }

        // Copy the tail after the last `{` and commit the rewrite.
        out += read;
        self.fmt = LitStr::new(&out, self.fmt.span());
        self.args = args;
        self.has_bonus_display = has_bonus_display;
        self.implied_bounds = implied_bounds;
    }
}
121
122struct FmtArguments {
123    named: Set<Ident>,
124    unnamed: bool,
125}
126
127#[allow(clippy::unnecessary_wraps)]
128fn explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
129    let ahead = input.fork();
130    if let Ok(set) = try_explicit_named_args(&ahead) {
131        input.advance_to(&ahead);
132        return Ok(set);
133    }
134
135    let ahead = input.fork();
136    if let Ok(set) = fallback_explicit_named_args(&ahead) {
137        input.advance_to(&ahead);
138        return Ok(set);
139    }
140
141    input.parse::<TokenStream>().unwrap();
142    Ok(FmtArguments {
143        named: Set::new(),
144        unnamed: false,
145    })
146}
147
148fn try_explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
149    let mut syn_full = None;
150    let mut args = FmtArguments {
151        named: Set::new(),
152        unnamed: false,
153    };
154
155    while !input.is_empty() {
156        input.parse::<Token![,]>()?;
157        if input.is_empty() {
158            break;
159        }
160        if input.peek(Ident::peek_any) && input.peek2(Token![=]) && !input.peek2(Token![==]) {
161            let ident = input.call(Ident::parse_any)?;
162            input.parse::<Token![=]>()?;
163            args.named.insert(ident);
164        } else {
165            args.unnamed = true;
166        }
167        if *syn_full.get_or_insert_with(is_syn_full) {
168            let ahead = input.fork();
169            if ahead.parse::<Expr>().is_ok() {
170                input.advance_to(&ahead);
171                continue;
172            }
173        }
174        scan_expr(input)?;
175    }
176
177    Ok(args)
178}
179
180fn fallback_explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
181    let mut args = FmtArguments {
182        named: Set::new(),
183        unnamed: false,
184    };
185
186    while !input.is_empty() {
187        if input.peek(Token![,])
188            && input.peek2(Ident::peek_any)
189            && input.peek3(Token![=])
190            && !input.peek3(Token![==])
191        {
192            input.parse::<Token![,]>()?;
193            let ident = input.call(Ident::parse_any)?;
194            input.parse::<Token![=]>()?;
195            args.named.insert(ident);
196        } else {
197            input.parse::<TokenTree>()?;
198        }
199    }
200
201    Ok(args)
202}
203
204fn is_syn_full() -> bool {
205    // Expr::Block contains syn::Block which contains Vec<syn::Stmt>. In the
206    // current version of Syn, syn::Stmt is exhaustive and could only plausibly
207    // represent `trait Trait {}` in Stmt::Item which contains syn::Item. Most
208    // of the point of syn's non-"full" mode is to avoid compiling Item and the
209    // entire expansive syntax tree it comprises. So the following expression
210    // being parsed to Expr::Block is a reliable indication that "full" is
211    // enabled.
212    let test = quote!({
213        trait Trait {}
214    });
215    match syn::parse2(test) {
216        Ok(Expr::Verbatim(_)) | Err(_) => false,
217        Ok(Expr::Block(_)) => true,
218        Ok(_) => unreachable!(),
219    }
220}
221
/// Splits the leading run of ASCII digits off the front of `read`.
///
/// Returns the digits (possibly empty) and advances `read` past them. `read`
/// is advanced even when the digits run to the very end of the input. The
/// previous loop-based version only advanced on seeing a non-digit, so an
/// input such as `"12"` was left unconsumed and its digits were emitted
/// twice by the caller.
fn take_int(read: &mut &str) -> String {
    let s = *read;
    // ASCII digits are single bytes, so this count is always a char boundary.
    let len = s.bytes().take_while(u8::is_ascii_digit).count();
    let (digits, rest) = s.split_at(len);
    *read = rest;
    digits.to_owned()
}
235
236fn take_ident(read: &mut &str) -> Ident {
237    let mut ident = String::new();
238    let raw = read.starts_with("r#");
239    if raw {
240        ident.push_str("r#");
241        *read = &read[2..];
242    }
243    for (i, ch) in read.char_indices() {
244        match ch {
245            'a'..='z' | 'A'..='Z' | '0'..='9' | '_' => ident.push(ch),
246            _ => {
247                *read = &read[i..];
248                break;
249            }
250        }
251    }
252    Ident::parse_any.parse_str(&ident).unwrap()
253}