derive_more_impl/
unwrap.rs1use crate::utils::{AttrParams, DeriveType, State};
2use convert_case::{Case, Casing};
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{DeriveInput, Fields, Ident, Result, Type};
6
7pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
8 let state = State::with_attr_params(
9 input,
10 trait_name,
11 "unwrap".into(),
12 AttrParams {
13 enum_: vec!["ignore", "owned", "ref", "ref_mut"],
14 variant: vec!["ignore", "owned", "ref", "ref_mut"],
15 struct_: vec!["ignore"],
16 field: vec!["ignore"],
17 },
18 )?;
19 assert!(
20 state.derive_type == DeriveType::Enum,
21 "Unwrap can only be derived for enums",
22 );
23
24 let enum_name = &input.ident;
25 let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl();
26
27 let variant_data = state.enabled_variant_data();
28
29 let mut funcs = vec![];
30 for (variant_state, info) in
31 Iterator::zip(variant_data.variant_states.iter(), variant_data.infos)
32 {
33 let variant = variant_state.variant.unwrap();
34 let fn_name = format_ident!(
35 "unwrap_{ident}",
36 ident = variant.ident.to_string().to_case(Case::Snake),
37 span = variant.ident.span(),
38 );
39 let ref_fn_name = format_ident!(
40 "unwrap_{ident}_ref",
41 ident = variant.ident.to_string().to_case(Case::Snake),
42 span = variant.ident.span(),
43 );
44 let mut_fn_name = format_ident!(
45 "unwrap_{ident}_mut",
46 ident = variant.ident.to_string().to_case(Case::Snake),
47 span = variant.ident.span(),
48 );
49 let variant_ident = &variant.ident;
50 let (data_pattern, ret_value, data_types) = get_field_info(&variant.fields);
51 let pattern = quote! { #enum_name :: #variant_ident #data_pattern };
52
53 let (failed_block, failed_block_ref, failed_block_mut) = (
54 failed_block(&state, enum_name, &fn_name),
55 failed_block(&state, enum_name, &ref_fn_name),
56 failed_block(&state, enum_name, &mut_fn_name),
57 );
58
59 let doc_owned = format!(
60 "Unwraps this value to the `{enum_name}::{variant_ident}` variant.\n",
61 );
62 let doc_ref = format!(
63 "Unwraps this reference to the `{enum_name}::{variant_ident}` variant.\n",
64 );
65 let doc_mut = format!(
66 "Unwraps this mutable reference to the `{enum_name}::{variant_ident}` variant.\n",
67 );
68 let doc_else = "Panics if this value is of any other type.";
69
70 let func = quote! {
71 #[inline]
72 #[track_caller]
73 #[doc = #doc_owned]
74 #[doc = #doc_else]
75 pub fn #fn_name(self) -> (#(#data_types),*) {
76 match self {
77 #pattern => #ret_value,
78 val @ _ => #failed_block,
79 }
80 }
81 };
82
83 let ref_func = quote! {
84 #[inline]
85 #[track_caller]
86 #[doc = #doc_ref]
87 #[doc = #doc_else]
88 pub fn #ref_fn_name(&self) -> (#(&#data_types),*) {
89 match self {
90 #pattern => #ret_value,
91 val @ _ => #failed_block_ref,
92 }
93 }
94 };
95
96 let mut_func = quote! {
97 #[inline]
98 #[track_caller]
99 #[doc = #doc_mut]
100 #[doc = #doc_else]
101 pub fn #mut_fn_name(&mut self) -> (#(&mut #data_types),*) {
102 match self {
103 #pattern => #ret_value,
104 val @ _ => #failed_block_mut,
105 }
106 }
107 };
108
109 if info.owned && state.default_info.owned {
110 funcs.push(func);
111 }
112 if info.ref_ && state.default_info.ref_ {
113 funcs.push(ref_func);
114 }
115 if info.ref_mut && state.default_info.ref_mut {
116 funcs.push(mut_func);
117 }
118 }
119
120 let imp = quote! {
121 #[allow(unreachable_code)] #[automatically_derived]
123 impl #imp_generics #enum_name #type_generics #where_clause {
124 #(#funcs)*
125 }
126 };
127
128 Ok(imp)
129}
130
131fn get_field_info(fields: &Fields) -> (TokenStream, TokenStream, Vec<&Type>) {
132 match fields {
133 Fields::Named(_) => panic!("cannot unwrap anonymous records"),
134 Fields::Unnamed(ref fields) => {
135 let (idents, types) = fields
136 .unnamed
137 .iter()
138 .enumerate()
139 .map(|(n, it)| (format_ident!("field_{n}"), &it.ty))
140 .unzip::<_, _, Vec<_>, Vec<_>>();
141 (quote! { (#(#idents),*) }, quote! { (#(#idents),*) }, types)
142 }
143 Fields::Unit => (quote! {}, quote! { () }, vec![]),
144 }
145}
146
147fn failed_block(state: &State, enum_name: &Ident, fn_name: &Ident) -> TokenStream {
148 let arms = state
149 .variant_states
150 .iter()
151 .map(|it| it.variant.unwrap())
152 .map(|variant| {
153 let data_pattern = match variant.fields {
154 Fields::Named(_) => quote! { {..} },
155 Fields::Unnamed(_) => quote! { (..) },
156 Fields::Unit => quote! {},
157 };
158 let variant_ident = &variant.ident;
159 let panic_msg = format!(
160 "called `{enum_name}::{fn_name}()` on a `{enum_name}::{variant_ident}` value"
161 );
162 quote! { #enum_name :: #variant_ident #data_pattern => panic!(#panic_msg) }
163 });
164
165 quote! {
166 match val {
167 #(#arms),*
168 }
169 }
170}