thiserror_impl/
expand.rs

1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::generics::InferredBounds;
4use crate::span::MemberSpan;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote, quote_spanned, ToTokens};
7use std::collections::BTreeSet as Set;
8use syn::{DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type};
9
10pub fn derive(input: &DeriveInput) -> TokenStream {
11    match try_expand(input) {
12        Ok(expanded) => expanded,
13        // If there are invalid attributes in the input, expand to an Error impl
14        // anyway to minimize spurious knock-on errors in other code that uses
15        // this type as an Error.
16        Err(error) => fallback(input, error),
17    }
18}
19
20fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
21    let input = Input::from_syn(input)?;
22    input.validate()?;
23    Ok(match input {
24        Input::Struct(input) => impl_struct(input),
25        Input::Enum(input) => impl_enum(input),
26    })
27}
28
29fn fallback(input: &DeriveInput, error: syn::Error) -> TokenStream {
30    let ty = &input.ident;
31    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32
33    let error = error.to_compile_error();
34
35    quote! {
36        #error
37
38        #[allow(unused_qualifications)]
39        #[automatically_derived]
40        impl #impl_generics std::error::Error for #ty #ty_generics #where_clause
41        where
42            // Work around trivial bounds being unstable.
43            // https://github.com/rust-lang/rust/issues/48214
44            for<'workaround> #ty #ty_generics: ::core::fmt::Debug,
45        {}
46
47        #[allow(unused_qualifications)]
48        #[automatically_derived]
49        impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
50            fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
51                ::core::unreachable!()
52            }
53        }
54    }
55}
56
/// Generates the impls for a struct input: always a `std::error::Error` impl,
/// plus a `Display` impl when the struct is transparent or has a display
/// attribute, and a `From` impl when `from_field()` reports a field.
fn impl_struct(input: Struct) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Extra bounds collected while building the Error impl; merged into its
    // where clause at the end.
    let mut error_inferred_bounds = InferredBounds::new();

    // Body of `Error::source`, if the struct has one.
    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
        // Transparent: forward `source()` to the first (index 0) field.
        let only_field = &input.fields[0];
        if only_field.contains_generic {
            error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
        }
        let member = &only_field.member;
        Some(quote_spanned! {transparent_attr.span=>
            std::error::Error::source(self.#member.as_dyn_error())
        })
    } else if let Some(source_field) = input.source_field() {
        let source = &source_field.member;
        if source_field.contains_generic {
            // The bound is placed on the inner type when the field is `Option<T>`.
            let ty = unoptional_type(source_field.ty);
            error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
        }
        // For an `Option` source, `.as_ref()?` makes the generated `source()`
        // return `None` when the field is `None`.
        let asref = if type_is_option(source_field.ty) {
            Some(quote_spanned!(source.member_span()=> .as_ref()?))
        } else {
            None
        };
        let dyn_error = quote_spanned! {source_field.source_span()=>
            self.#source #asref.as_dyn_error()
        };
        Some(quote! {
            ::core::option::Option::Some(#dyn_error)
        })
    } else {
        None
    };
    // Wrap the body in the method signature; the `AsDynError` import brings
    // the `as_dyn_error()` method used above into scope.
    let source_method = source_body.map(|body| {
        quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError as _;
                #body
            }
        }
    });

    // `Error::provide`, generated only when `backtrace_field()` reports a field.
    let provide_method = input.backtrace_field().map(|backtrace_field| {
        let request = quote!(request);
        let backtrace = &backtrace_field.member;
        let body = if let Some(source_field) = input.source_field() {
            let source = &source_field.member;
            // Delegate to the source first (skipped at runtime when an
            // `Option` source is `None`).
            let source_provide = if type_is_option(source_field.ty) {
                quote_spanned! {source.member_span()=>
                    if let ::core::option::Option::Some(source) = &self.#source {
                        source.thiserror_provide(#request);
                    }
                }
            } else {
                quote_spanned! {source.member_span()=>
                    self.#source.thiserror_provide(#request);
                }
            };
            // When the source and backtrace are the same member, only the
            // delegation above is emitted; otherwise this struct's own
            // backtrace is provided as well.
            let self_provide = if source == backtrace {
                None
            } else if type_is_option(backtrace_field.ty) {
                Some(quote! {
                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                        #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                    }
                })
            } else {
                Some(quote! {
                    #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
                })
            };
            quote! {
                use thiserror::__private::ThiserrorProvide as _;
                #source_provide
                #self_provide
            }
        } else if type_is_option(backtrace_field.ty) {
            // No source field: provide only this struct's own backtrace.
            quote! {
                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                    #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                }
            }
        } else {
            quote! {
                #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
            }
        };
        quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #body
            }
        }
    });

    // Set of (field index, trait) pairs that the Display body relies on.
    let mut display_implied_bounds = Set::new();
    let display_body = if input.attrs.transparent.is_some() {
        // Transparent: delegate formatting to the first field's `Display`.
        let only_field = &input.fields[0].member;
        display_implied_bounds.insert((0, Trait::Display));
        Some(quote! {
            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
        })
    } else if let Some(display) = &input.attrs.display {
        display_implied_bounds.clone_from(&display.implied_bounds);
        let use_as_display = use_as_display(display.has_bonus_display);
        // Destructure `self` so the display tokens can refer to the fields
        // through the bindings produced by `fields_pat`.
        let pat = fields_pat(&input.fields);
        Some(quote! {
            #use_as_display
            #[allow(unused_variables, deprecated)]
            let Self #pat = self;
            #display
        })
    } else {
        None
    };
    let display_impl = display_body.map(|body| {
        // Only fields flagged `contains_generic` contribute bounds to the
        // Display impl's where clause.
        let mut display_inferred_bounds = InferredBounds::new();
        for (field, bound) in display_implied_bounds {
            let field = &input.fields[field];
            if field.contains_generic {
                display_inferred_bounds.insert(field.ty, bound);
            }
        }
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        quote! {
            #[allow(unused_qualifications)]
            #[automatically_derived]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                #[allow(clippy::used_underscore_binding)]
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #body
                }
            }
        }
    });

    // `From<T>` impl for the field reported by `from_field()`. The impl uses
    // the struct's original where clause, without any inferred bounds.
    let from_impl = input.from_field().map(|from_field| {
        let backtrace_field = input.distinct_backtrace_field();
        // For an `Option<T>` field the conversion is from `T`.
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        quote! {
            #[allow(unused_qualifications)]
            #[automatically_derived]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty #body
                }
            }
        }
    });

    // For a struct with type parameters, additionally require
    // `Self: Debug + Display` on the Error impl.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        #[automatically_derived]
        impl #impl_generics std::error::Error for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #from_impl
    }
}
227
/// Generates the impls for an enum input: always a `std::error::Error` impl,
/// plus a `Display` impl when `has_display()` is true, and one `From` impl
/// per variant for which `from_field()` reports a field.
fn impl_enum(input: Enum) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Extra bounds collected while building the Error impl; merged into its
    // where clause at the end.
    let mut error_inferred_bounds = InferredBounds::new();

    let source_method = if input.has_source() {
        // One match arm per variant. `arms` is a lazy iterator: the closure,
        // including its insertions into `error_inferred_bounds`, runs when the
        // `quote!` below consumes it.
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            if let Some(transparent_attr) = &variant.attrs.transparent {
                // Transparent variant: bind the first field as `transparent`
                // and forward `source()` to it.
                let only_field = &variant.fields[0];
                if only_field.contains_generic {
                    error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
                }
                let member = &only_field.member;
                let source = quote_spanned! {transparent_attr.span=>
                    std::error::Error::source(transparent.as_dyn_error())
                };
                quote! {
                    #ty::#ident {#member: transparent} => #source,
                }
            } else if let Some(source_field) = variant.source_field() {
                let source = &source_field.member;
                if source_field.contains_generic {
                    // The bound is placed on the inner type when the field is `Option<T>`.
                    let ty = unoptional_type(source_field.ty);
                    error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
                }
                // For an `Option` source, `.as_ref()?` makes the generated
                // `source()` return `None` when the field is `None`.
                let asref = if type_is_option(source_field.ty) {
                    Some(quote_spanned!(source.member_span()=> .as_ref()?))
                } else {
                    None
                };
                // Name the source field is bound to inside the match arm.
                let varsource = quote!(source);
                let dyn_error = quote_spanned! {source_field.source_span()=>
                    #varsource #asref.as_dyn_error()
                };
                quote! {
                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
                }
            } else {
                // Variant without a source.
                quote! {
                    #ty::#ident {..} => ::core::option::Option::None,
                }
            }
        });
        Some(quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError as _;
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let provide_method = if input.has_backtrace() {
        let request = quote!(request);
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            match (variant.backtrace_field(), variant.source_field()) {
                // Both fields present and the backtrace field has no
                // `backtrace` attribute value: delegate to the source, then
                // provide the variant's own backtrace.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.attrs.backtrace.is_none() =>
                {
                    let backtrace = &backtrace_field.member;
                    let source = &source_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {source.member_span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {source.member_span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    let self_provide = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {
                            #backtrace: backtrace,
                            #source: #varsource,
                            ..
                        } => {
                            use thiserror::__private::ThiserrorProvide as _;
                            #source_provide
                            #self_provide
                        }
                    }
                }
                // The backtrace field and the source field are the same
                // member: only delegate to the source.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.member == source_field.member =>
                {
                    let backtrace = &backtrace_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {backtrace.member_span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {backtrace.member_span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: #varsource, ..} => {
                            use thiserror::__private::ThiserrorProvide as _;
                            #source_provide
                        }
                    }
                }
                // Any remaining variant with a backtrace field: provide only
                // that backtrace.
                (Some(backtrace_field), _) => {
                    let backtrace = &backtrace_field.member;
                    let body = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: backtrace, ..} => {
                            #body
                        }
                    }
                }
                // Variant without a backtrace field: nothing to provide.
                (None, _) => quote! {
                    #ty::#ident {..} => {}
                },
            }
        });
        Some(quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let display_impl = if input.has_display() {
        let mut display_inferred_bounds = InferredBounds::new();
        // The `AsDisplay` import is emitted once for the whole `fmt` body if
        // any variant's display attribute sets `has_bonus_display`.
        let has_bonus_display = input.variants.iter().any(|v| {
            v.attrs
                .display
                .as_ref()
                .map_or(false, |display| display.has_bonus_display)
        });
        let use_as_display = use_as_display(has_bonus_display);
        // For an enum with no variants the generated match is on `*self`
        // rather than `self`.
        // NOTE(review): presumably because an empty `match` must scrutinize
        // the uninhabited value itself, not a reference to it -- confirm.
        let void_deref = if input.variants.is_empty() {
            Some(quote!(*))
        } else {
            None
        };
        let arms = input.variants.iter().map(|variant| {
            // Set of (field index, trait) pairs this variant's arm relies on.
            let mut display_implied_bounds = Set::new();
            let display = match &variant.attrs.display {
                Some(display) => {
                    display_implied_bounds.clone_from(&display.implied_bounds);
                    display.to_token_stream()
                }
                None => {
                    // No display attribute on this variant: format its first
                    // field with `Display`. Unnamed fields are referred to as
                    // `_N`, the same binding names `fields_pat` generates.
                    // NOTE(review): indexing `fields[0]` assumes validation
                    // guarantees such variants have a field -- confirm.
                    let only_field = match &variant.fields[0].member {
                        Member::Named(ident) => ident.clone(),
                        Member::Unnamed(index) => format_ident!("_{}", index),
                    };
                    display_implied_bounds.insert((0, Trait::Display));
                    quote!(::core::fmt::Display::fmt(#only_field, __formatter))
                }
            };
            // Only fields flagged `contains_generic` contribute bounds.
            for (field, bound) in display_implied_bounds {
                let field = &variant.fields[field];
                if field.contains_generic {
                    display_inferred_bounds.insert(field.ty, bound);
                }
            }
            let ident = &variant.ident;
            let pat = fields_pat(&variant.fields);
            quote! {
                #ty::#ident #pat => #display
            }
        });
        // Collect eagerly so every insertion into `display_inferred_bounds`
        // has happened before the where clause is built on the next line.
        let arms = arms.collect::<Vec<_>>();
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        Some(quote! {
            #[allow(unused_qualifications)]
            #[automatically_derived]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #use_as_display
                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
                    match #void_deref self {
                        #(#arms,)*
                    }
                }
            }
        })
    } else {
        None
    };

    // One `From<T>` impl per variant for which `from_field()` reports a
    // field. These impls use the enum's original where clause, without any
    // inferred bounds.
    let from_impls = input.variants.iter().filter_map(|variant| {
        let from_field = variant.from_field()?;
        let backtrace_field = variant.distinct_backtrace_field();
        let variant = &variant.ident;
        // For an `Option<T>` field the conversion is from `T`.
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        Some(quote! {
            #[allow(unused_qualifications)]
            #[automatically_derived]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty::#variant #body
                }
            }
        })
    });

    // For an enum with type parameters, additionally require
    // `Self: Debug + Display` on the Error impl.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        #[automatically_derived]
        impl #impl_generics std::error::Error for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #(#from_impls)*
    }
}
486
487fn fields_pat(fields: &[Field]) -> TokenStream {
488    let mut members = fields.iter().map(|field| &field.member).peekable();
489    match members.peek() {
490        Some(Member::Named(_)) => quote!({ #(#members),* }),
491        Some(Member::Unnamed(_)) => {
492            let vars = members.map(|member| match member {
493                Member::Unnamed(member) => format_ident!("_{}", member),
494                Member::Named(_) => unreachable!(),
495            });
496            quote!((#(#vars),*))
497        }
498        None => quote!({}),
499    }
500}
501
502fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
503    if needs_as_display {
504        Some(quote! {
505            use thiserror::__private::AsDisplay as _;
506        })
507    } else {
508        None
509    }
510}
511
512fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
513    let from_member = &from_field.member;
514    let some_source = if type_is_option(from_field.ty) {
515        quote!(::core::option::Option::Some(source))
516    } else {
517        quote!(source)
518    };
519    let backtrace = backtrace_field.map(|backtrace_field| {
520        let backtrace_member = &backtrace_field.member;
521        if type_is_option(backtrace_field.ty) {
522            quote! {
523                #backtrace_member: ::core::option::Option::Some(std::backtrace::Backtrace::capture()),
524            }
525        } else {
526            quote! {
527                #backtrace_member: ::core::convert::From::from(std::backtrace::Backtrace::capture()),
528            }
529        }
530    });
531    quote!({
532        #from_member: #some_source,
533        #backtrace
534    })
535}
536
537fn type_is_option(ty: &Type) -> bool {
538    type_parameter_of_option(ty).is_some()
539}
540
541fn unoptional_type(ty: &Type) -> TokenStream {
542    let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
543    quote!(#unoptional)
544}
545
546fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
547    let path = match ty {
548        Type::Path(ty) => &ty.path,
549        _ => return None,
550    };
551
552    let last = path.segments.last().unwrap();
553    if last.ident != "Option" {
554        return None;
555    }
556
557    let bracketed = match &last.arguments {
558        PathArguments::AngleBracketed(bracketed) => bracketed,
559        _ => return None,
560    };
561
562    if bracketed.args.len() != 1 {
563        return None;
564    }
565
566    match &bracketed.args[0] {
567        GenericArgument::Type(arg) => Some(arg),
568        _ => None,
569    }
570}