1use crate::utils::{AttrParams, DeriveType, State};
2use convert_case::{Case, Casing};
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{DeriveInput, Fields, Ident, Result, Type};
6
7pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
8 let state = State::with_attr_params(
9 input,
10 trait_name,
11 "try_unwrap".into(),
12 AttrParams {
13 enum_: vec!["ignore", "owned", "ref", "ref_mut"],
14 variant: vec!["ignore", "owned", "ref", "ref_mut"],
15 struct_: vec!["ignore"],
16 field: vec!["ignore"],
17 },
18 )?;
19 assert!(
20 state.derive_type == DeriveType::Enum,
21 "TryUnwrap can only be derived for enums",
22 );
23
24 let enum_name = &input.ident;
25 let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl();
26
27 let variant_data = state.enabled_variant_data();
28
29 let mut funcs = vec![];
30 for (variant_state, info) in
31 Iterator::zip(variant_data.variant_states.iter(), variant_data.infos)
32 {
33 let variant = variant_state.variant.unwrap();
34 let fn_name = format_ident!(
35 "try_unwrap_{ident}",
36 ident = variant.ident.to_string().to_case(Case::Snake),
37 span = variant.ident.span(),
38 );
39 let ref_fn_name = format_ident!(
40 "try_unwrap_{ident}_ref",
41 ident = variant.ident.to_string().to_case(Case::Snake),
42 span = variant.ident.span(),
43 );
44 let mut_fn_name = format_ident!(
45 "try_unwrap_{ident}_mut",
46 ident = variant.ident.to_string().to_case(Case::Snake),
47 span = variant.ident.span(),
48 );
49 let variant_ident = &variant.ident;
50 let (data_pattern, ret_value, data_types) = get_field_info(&variant.fields);
51 let pattern = quote! { #enum_name :: #variant_ident #data_pattern };
52
53 let (failed_block, failed_block_ref, failed_block_mut) = (
54 failed_block(&state, enum_name, &fn_name),
55 failed_block(&state, enum_name, &ref_fn_name),
56 failed_block(&state, enum_name, &mut_fn_name),
57 );
58
59 let doc_owned = format!(
60 "Attempts to unwrap this value to the `{enum_name}::{variant_ident}` variant.\n",
61 );
62 let doc_ref = format!(
63 "Attempts to unwrap this reference to the `{enum_name}::{variant_ident}` variant.\n",
64 );
65 let doc_mut = format!(
66 "Attempts to unwrap this mutable reference to the `{enum_name}::{variant_ident}` variant.\n",
67 );
68 let doc_else = "Returns a [TryUnwrapError] with the original value if this value is of any other type.";
69 let func = quote! {
70 #[inline]
71 #[track_caller]
72 #[doc = #doc_owned]
73 #[doc = #doc_else]
74 pub fn #fn_name(self) -> derive_more::core::result::Result<
75 (#(#data_types),*), derive_more::TryUnwrapError<Self>
76 > {
77 match self {
78 #pattern => derive_more::core::result::Result::Ok(#ret_value),
79 val @ _ => #failed_block,
80 }
81 }
82 };
83
84 let ref_func = quote! {
85 #[inline]
86 #[track_caller]
87 #[doc = #doc_ref]
88 #[doc = #doc_else]
89 pub fn #ref_fn_name(&self) -> derive_more::core::result::Result<
90 (#(&#data_types),*), derive_more::TryUnwrapError<&Self>
91 > {
92 match self {
93 #pattern => derive_more::core::result::Result::Ok(#ret_value),
94 val @ _ => #failed_block_ref,
95 }
96 }
97 };
98
99 let mut_func = quote! {
100 #[inline]
101 #[track_caller]
102 #[doc = #doc_mut]
103 #[doc = #doc_else]
104 pub fn #mut_fn_name(&mut self) -> derive_more::core::result::Result<
105 (#(&mut #data_types),*), derive_more::TryUnwrapError<&mut Self>
106 > {
107 match self {
108 #pattern => derive_more::core::result::Result::Ok(#ret_value),
109 val @ _ => #failed_block_mut,
110 }
111 }
112 };
113
114 if info.owned && state.default_info.owned {
115 funcs.push(func);
116 }
117 if info.ref_ && state.default_info.ref_ {
118 funcs.push(ref_func);
119 }
120 if info.ref_mut && state.default_info.ref_mut {
121 funcs.push(mut_func);
122 }
123 }
124
125 let imp = quote! {
126 #[allow(unreachable_code)] #[automatically_derived]
128 impl #imp_generics #enum_name #type_generics #where_clause {
129 #(#funcs)*
130 }
131 };
132
133 Ok(imp)
134}
135
136fn get_field_info(fields: &Fields) -> (TokenStream, TokenStream, Vec<&Type>) {
137 match fields {
138 Fields::Named(_) => panic!("cannot unwrap anonymous records"),
139 Fields::Unnamed(ref fields) => {
140 let (idents, types) = fields
141 .unnamed
142 .iter()
143 .enumerate()
144 .map(|(n, it)| (format_ident!("field_{n}"), &it.ty))
145 .unzip::<_, _, Vec<_>, Vec<_>>();
146 (quote! { (#(#idents),*) }, quote! { (#(#idents),*) }, types)
147 }
148 Fields::Unit => (quote! {}, quote! { () }, vec![]),
149 }
150}
151
152fn failed_block(state: &State, enum_name: &Ident, func_name: &Ident) -> TokenStream {
153 let arms = state
154 .variant_states
155 .iter()
156 .map(|it| it.variant.unwrap())
157 .map(|variant| {
158 let data_pattern = match variant.fields {
159 Fields::Named(_) => quote! { {..} },
160 Fields::Unnamed(_) => quote! { (..) },
161 Fields::Unit => quote! {},
162 };
163 let variant_ident = &variant.ident;
164 let error = quote! {
165 derive_more::TryUnwrapError::<_>::new(
166 val,
167 stringify!(#enum_name),
168 stringify!(#variant_ident),
169 stringify!(#func_name),
170 )
171 };
172 quote! {
173 val @ #enum_name :: #variant_ident #data_pattern
174 => derive_more::core::result::Result::Err(#error)
175 }
176 });
177
178 quote! {
179 match val {
180 #(#arms),*
181 }
182 }
183}