derive_more_impl/
try_unwrap.rs

1use crate::utils::{AttrParams, DeriveType, State};
2use convert_case::{Case, Casing};
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{DeriveInput, Fields, Ident, Result, Type};
6
7pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
8    let state = State::with_attr_params(
9        input,
10        trait_name,
11        "try_unwrap".into(),
12        AttrParams {
13            enum_: vec!["ignore", "owned", "ref", "ref_mut"],
14            variant: vec!["ignore", "owned", "ref", "ref_mut"],
15            struct_: vec!["ignore"],
16            field: vec!["ignore"],
17        },
18    )?;
19    assert!(
20        state.derive_type == DeriveType::Enum,
21        "TryUnwrap can only be derived for enums",
22    );
23
24    let enum_name = &input.ident;
25    let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl();
26
27    let variant_data = state.enabled_variant_data();
28
29    let mut funcs = vec![];
30    for (variant_state, info) in
31        Iterator::zip(variant_data.variant_states.iter(), variant_data.infos)
32    {
33        let variant = variant_state.variant.unwrap();
34        let fn_name = format_ident!(
35            "try_unwrap_{ident}",
36            ident = variant.ident.to_string().to_case(Case::Snake),
37            span = variant.ident.span(),
38        );
39        let ref_fn_name = format_ident!(
40            "try_unwrap_{ident}_ref",
41            ident = variant.ident.to_string().to_case(Case::Snake),
42            span = variant.ident.span(),
43        );
44        let mut_fn_name = format_ident!(
45            "try_unwrap_{ident}_mut",
46            ident = variant.ident.to_string().to_case(Case::Snake),
47            span = variant.ident.span(),
48        );
49        let variant_ident = &variant.ident;
50        let (data_pattern, ret_value, data_types) = get_field_info(&variant.fields);
51        let pattern = quote! { #enum_name :: #variant_ident #data_pattern };
52
53        let (failed_block, failed_block_ref, failed_block_mut) = (
54            failed_block(&state, enum_name, &fn_name),
55            failed_block(&state, enum_name, &ref_fn_name),
56            failed_block(&state, enum_name, &mut_fn_name),
57        );
58
59        let doc_owned = format!(
60            "Attempts to unwrap this value to the `{enum_name}::{variant_ident}` variant.\n",
61        );
62        let doc_ref = format!(
63            "Attempts to unwrap this reference to the `{enum_name}::{variant_ident}` variant.\n",
64        );
65        let doc_mut = format!(
66            "Attempts to unwrap this mutable reference to the `{enum_name}::{variant_ident}` variant.\n",
67        );
68        let doc_else = "Returns a [TryUnwrapError] with the original value if this value is of any other type.";
69        let func = quote! {
70            #[inline]
71            #[track_caller]
72            #[doc = #doc_owned]
73            #[doc = #doc_else]
74            pub fn #fn_name(self) -> derive_more::core::result::Result<
75                (#(#data_types),*), derive_more::TryUnwrapError<Self>
76            > {
77                match self {
78                    #pattern => derive_more::core::result::Result::Ok(#ret_value),
79                    val @ _ => #failed_block,
80                }
81            }
82        };
83
84        let ref_func = quote! {
85            #[inline]
86            #[track_caller]
87            #[doc = #doc_ref]
88            #[doc = #doc_else]
89            pub fn #ref_fn_name(&self) -> derive_more::core::result::Result<
90                (#(&#data_types),*), derive_more::TryUnwrapError<&Self>
91            > {
92                match self {
93                    #pattern => derive_more::core::result::Result::Ok(#ret_value),
94                    val @ _ => #failed_block_ref,
95                }
96            }
97        };
98
99        let mut_func = quote! {
100            #[inline]
101            #[track_caller]
102            #[doc = #doc_mut]
103            #[doc = #doc_else]
104            pub fn #mut_fn_name(&mut self) -> derive_more::core::result::Result<
105                (#(&mut #data_types),*), derive_more::TryUnwrapError<&mut Self>
106            > {
107                match self {
108                    #pattern => derive_more::core::result::Result::Ok(#ret_value),
109                    val @ _ => #failed_block_mut,
110                }
111            }
112        };
113
114        if info.owned && state.default_info.owned {
115            funcs.push(func);
116        }
117        if info.ref_ && state.default_info.ref_ {
118            funcs.push(ref_func);
119        }
120        if info.ref_mut && state.default_info.ref_mut {
121            funcs.push(mut_func);
122        }
123    }
124
125    let imp = quote! {
126        #[allow(unreachable_code)] // omit warnings for `!` and other unreachable types
127        #[automatically_derived]
128        impl #imp_generics #enum_name #type_generics #where_clause {
129            #(#funcs)*
130        }
131    };
132
133    Ok(imp)
134}
135
136fn get_field_info(fields: &Fields) -> (TokenStream, TokenStream, Vec<&Type>) {
137    match fields {
138        Fields::Named(_) => panic!("cannot unwrap anonymous records"),
139        Fields::Unnamed(ref fields) => {
140            let (idents, types) = fields
141                .unnamed
142                .iter()
143                .enumerate()
144                .map(|(n, it)| (format_ident!("field_{n}"), &it.ty))
145                .unzip::<_, _, Vec<_>, Vec<_>>();
146            (quote! { (#(#idents),*) }, quote! { (#(#idents),*) }, types)
147        }
148        Fields::Unit => (quote! {}, quote! { () }, vec![]),
149    }
150}
151
152fn failed_block(state: &State, enum_name: &Ident, func_name: &Ident) -> TokenStream {
153    let arms = state
154        .variant_states
155        .iter()
156        .map(|it| it.variant.unwrap())
157        .map(|variant| {
158            let data_pattern = match variant.fields {
159                Fields::Named(_) => quote! { {..} },
160                Fields::Unnamed(_) => quote! { (..) },
161                Fields::Unit => quote! {},
162            };
163            let variant_ident = &variant.ident;
164            let error = quote! {
165                derive_more::TryUnwrapError::<_>::new(
166                    val,
167                    stringify!(#enum_name),
168                    stringify!(#variant_ident),
169                    stringify!(#func_name),
170                )
171            };
172            quote! {
173                val @ #enum_name :: #variant_ident #data_pattern
174                    => derive_more::core::result::Result::Err(#error)
175            }
176        });
177
178    quote! {
179        match val {
180            #(#arms),*
181        }
182    }
183}