amplify_derive/
getters.rs

1// Rust language amplification derive library providing multiple generic trait
2// implementations, type wrappers, derive macros and other language enhancements
3//
4// Written in 2019-2020 by
5//     Dr. Maxim Orlovsky <orlovsky@pandoracore.com>
6//
7// To the extent possible under law, the author(s) have dedicated all
8// copyright and related and neighboring rights to this software to
9// the public domain worldwide. This software is distributed without
10// any warranty.
11//
12// You should have received a copy of the MIT License
13// along with this software.
14// If not, see <https://opensource.org/licenses/MIT>.
15
16use std::collections::HashMap;
17use std::convert::TryInto;
18use std::iter::FromIterator;
19
20use amplify_syn::{ArgValue, ArgValueReq, AttrReq, ParametrizedAttr, ValueClass};
21use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
22use quote::ToTokens;
23use syn::spanned::Spanned;
24use syn::{
25    Attribute, Data, DataStruct, DeriveInput, Error, Field, Fields, ImplGenerics, LitStr, Result,
26    TypeGenerics, WhereClause,
27};
28
29pub(crate) fn derive(input: DeriveInput) -> Result<TokenStream2> {
30    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
31    let struct_name = &input.ident;
32
33    let mut global_param = ParametrizedAttr::with("getter", &input.attrs)?;
34    let _ = GetterDerive::try_from(&mut global_param, true)?;
35
36    match input.data {
37        Data::Struct(data) => derive_struct_impl(
38            data,
39            struct_name,
40            global_param,
41            impl_generics,
42            ty_generics,
43            where_clause,
44        ),
45        Data::Enum(_) => {
46            Err(Error::new_spanned(&input, "Deriving getters is not supported in enums"))
47        }
48        Data::Union(_) => {
49            Err(Error::new_spanned(&input, "Deriving getters is not supported in unions"))
50        }
51    }
52}
53
/// Parsed configuration of a `#[getter(...)]` attribute, describing which
/// accessor methods are generated for a field and how they are named.
///
/// Generated method names have the form `<prefix><base><suffix>`, where the
/// suffix comes from `main`, `as_ref` or `as_mut` depending on the accessor.
#[derive(Clone)]
struct GetterDerive {
    /// Value of the `prefix` argument, prepended to every generated method
    /// name; an empty literal when the argument is absent.
    pub prefix: LitStr,
    // pub doc: Attribute,
    /// `true` when the `skip` argument is present; no accessors are generated
    /// for such a field.
    pub skip: bool,
    /// `true` when the `as_copy` argument is present; the by-value accessor
    /// then returns the field itself instead of calling `.clone()` on it.
    pub copy: bool,
    /// Value of the `base_name` argument; when set it is used instead of the
    /// field name as the base of the generated method names.
    pub base: Option<LitStr>,
    /// Name suffix of the by-value accessor, taken from `as_copy` or
    /// `as_clone`; `None` when no by-value accessor is generated.
    pub main: Option<LitStr>,
    /// Name suffix of the borrowing accessor (`as_ref` argument); `None` when
    /// no borrowing accessor is generated.
    pub as_ref: Option<LitStr>,
    /// Name suffix of the mutably borrowing accessor (`as_mut` argument);
    /// `None` when no such accessor is generated.
    pub as_mut: Option<LitStr>,
}
65
66impl GetterDerive {
67    #[allow(clippy::blocks_in_if_conditions)]
68    fn try_from(attr: &mut ParametrizedAttr, global: bool) -> Result<GetterDerive> {
69        let mut map = HashMap::from_iter(vec![
70            ("prefix", ArgValueReq::with_default("")),
71            ("all", ArgValueReq::Prohibited),
72            ("as_copy", ArgValueReq::with_default("")),
73            ("as_clone", ArgValueReq::with_default("")),
74            ("as_ref", ArgValueReq::with_default("")),
75            ("as_mut", ArgValueReq::with_default("_mut")),
76        ]);
77
78        if !global {
79            map.insert("skip", ArgValueReq::Prohibited);
80            map.insert("base_name", ArgValueReq::Optional(ValueClass::str()));
81        }
82
83        attr.check(AttrReq::with(map))?;
84
85        if attr.args.contains_key("all") {
86            if attr.args.contains_key("as_clone") ||
87                attr.args.contains_key("as_ref") ||
88                attr.args.contains_key("as_mut")
89            {
90                return Err(Error::new(
91                    Span::call_site(),
92                    "`all` attribute can't be combined with other",
93                ));
94            }
95            attr.args.remove("all");
96            attr.args.insert("as_clone".to_owned(), ArgValue::from(""));
97            attr.args
98                .insert("as_ref".to_owned(), ArgValue::from("_ref"));
99            attr.args
100                .insert("as_mut".to_owned(), ArgValue::from("_mut"));
101        }
102
103        if attr.args.contains_key("as_clone") && attr.args.contains_key("as_copy") {
104            return Err(Error::new(
105                Span::call_site(),
106                "`as_clone` and `as_copy` attributes can't be present together",
107            ));
108        }
109
110        // If we have to return copy or a clone of value and did not explicitly
111        // specified different prefix for borrowing accessor, we need not to derive it
112        // since we will have a naming conflict
113        if (attr.args.contains_key("as_clone") || attr.args.contains_key("as_copy")) &&
114            attr.args
115                .get("as_ref")
116                .map(|a| {
117                    if let ArgValue::Literal(lit) = a {
118                        lit.to_token_stream().to_string() == "\"\""
119                    } else {
120                        false
121                    }
122                })
123                .unwrap_or_default()
124        {
125            attr.args.remove("as_ref");
126        }
127
128        // If we are not provided with any options, default to deriving borrows
129        if !(attr.args.contains_key("as_clone") ||
130            attr.args.contains_key("as_copy") ||
131            attr.args.contains_key("as_ref") ||
132            attr.args.contains_key("as_mut"))
133        {
134            attr.args.insert("as_ref".to_owned(), ArgValue::from(""));
135        }
136
137        Ok(GetterDerive {
138            prefix: attr
139                .args
140                .get("prefix")
141                .map(|a| a.clone().try_into())
142                .transpose()?
143                .unwrap_or_else(|| LitStr::new("", Span::call_site())),
144            skip: attr.args.get("skip").is_some(),
145            copy: attr.args.contains_key("as_copy"),
146            base: attr
147                .args
148                .get("base_name")
149                .map(|a| a.clone().try_into())
150                .transpose()?,
151            main: attr
152                .args
153                .get("as_copy")
154                .or_else(|| attr.args.get("as_clone"))
155                .map(|a| a.clone().try_into())
156                .transpose()?,
157            as_ref: attr
158                .args
159                .get("as_ref")
160                .map(|a| a.clone().try_into())
161                .transpose()?,
162            as_mut: attr
163                .args
164                .get("as_mut")
165                .map(|a| a.clone().try_into())
166                .transpose()?,
167        })
168    }
169}
170
/// Kind of accessor method that can be generated for a single field.
///
/// `Ord`/`PartialOrd` are derived, so the declaration order of the variants
/// is part of the type's comparison behaviour.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
enum GetterMethod {
    /// By-value accessor: returns the field itself when `copy` is `true` and
    /// the result of `.clone()` otherwise.
    Main { copy: bool },
    /// Accessor returning a shared borrow (`&`) of the field.
    AsRef,
    /// Accessor taking `&mut self` and returning a mutable borrow (`&mut`)
    /// of the field.
    AsMut,
}
177
178impl GetterMethod {
179    fn doc_phrase(&self) -> &'static str {
180        match self {
181            GetterMethod::Main { copy: true } => "returning copy of",
182            GetterMethod::Main { copy: false } => "cloning",
183            GetterMethod::AsRef => "borrowing",
184            GetterMethod::AsMut => "returning mutable borrow of",
185        }
186    }
187
188    fn mut_prefix(&self) -> TokenStream2 {
189        match self {
190            GetterMethod::Main { copy: true } => quote! {},
191            GetterMethod::Main { copy: false } => quote! {},
192            GetterMethod::AsRef => quote! {},
193            GetterMethod::AsMut => quote! { mut },
194        }
195    }
196
197    fn ret_prefix(&self) -> TokenStream2 {
198        match self {
199            GetterMethod::Main { copy: true } => quote! {},
200            GetterMethod::Main { copy: false } => quote! {},
201            GetterMethod::AsRef => quote! { & },
202            GetterMethod::AsMut => quote! { &mut },
203        }
204    }
205
206    fn ret_suffix(&self) -> TokenStream2 {
207        match self {
208            GetterMethod::Main { copy: true } => quote! {},
209            GetterMethod::Main { copy: false } => quote! { .clone() },
210            GetterMethod::AsRef => quote! {},
211            GetterMethod::AsMut => quote! {},
212        }
213    }
214}
215
216impl GetterDerive {
217    pub fn all_methods(&self) -> Vec<GetterMethod> {
218        let mut methods = Vec::with_capacity(3);
219        if self.main.is_some() {
220            methods.push(GetterMethod::Main { copy: self.copy });
221        }
222        if self.as_ref.is_some() {
223            methods.push(GetterMethod::AsRef);
224        }
225        if self.as_mut.is_some() {
226            methods.push(GetterMethod::AsMut);
227        }
228        methods
229    }
230
231    pub fn getter_fn_ident(
232        &self,
233        method: GetterMethod,
234        field_name: Option<&Ident>,
235        span: Span,
236    ) -> Result<Ident> {
237        let base_string = self
238            .base
239            .as_ref()
240            .map(LitStr::value)
241            .or_else(|| field_name.map(Ident::to_string))
242            .ok_or_else(|| {
243                Error::new(
244                    span,
245                    "Unnamed fields must be equipped with `#[getter(base_name = \"name\"]` \
246                     attribute",
247                )
248            })?;
249
250        let name_lit = match method {
251            GetterMethod::Main { .. } => &self.main,
252            GetterMethod::AsRef => &self.as_ref,
253            GetterMethod::AsMut => &self.as_mut,
254        }
255        .clone()
256        .expect("Internal inconsistency in getter derivation macro implementation");
257
258        let s = format!("{}{}{}", self.prefix.value(), base_string, name_lit.value());
259
260        Ok(Ident::new(&s, span))
261    }
262
263    pub fn getter_fn_doc(
264        &self,
265        method: GetterMethod,
266        struct_name: &Ident,
267        field_name: Option<&Ident>,
268        field_index: usize,
269        field_doc: Option<&Attribute>,
270    ) -> TokenStream2 {
271        let fn_doc = format!(
272            "Method {} [`{}::{}`] field.\n",
273            method.doc_phrase(),
274            struct_name,
275            field_name
276                .map(Ident::to_string)
277                .unwrap_or_else(|| field_index.to_string())
278        );
279
280        if let Some(field_doc) = field_doc {
281            quote! {
282                #[doc = #fn_doc]
283                #field_doc
284            }
285        } else {
286            quote! {
287                #[doc = #fn_doc]
288            }
289        }
290    }
291}
292
293fn derive_struct_impl(
294    data: DataStruct,
295    struct_name: &Ident,
296    global_param: ParametrizedAttr,
297    impl_generics: ImplGenerics,
298    ty_generics: TypeGenerics,
299    where_clause: Option<&WhereClause>,
300) -> Result<TokenStream2> {
301    let mut methods = Vec::with_capacity(data.fields.len());
302    match data.fields {
303        Fields::Named(ref fields) => {
304            for (index, field) in fields.named.iter().enumerate() {
305                methods.extend(derive_field_methods(field, index, struct_name, &global_param)?)
306            }
307        }
308        Fields::Unnamed(_) => {
309            return Err(Error::new(
310                Span::call_site(),
311                "Deriving getters is not supported for tuple-bases structs",
312            ));
313        }
314        Fields::Unit => {
315            return Err(Error::new(
316                Span::call_site(),
317                "Deriving getters is meaningless for unit structs",
318            ));
319        }
320    };
321
322    Ok(quote! {
323        #[automatically_derived]
324        impl #impl_generics #struct_name #ty_generics #where_clause {
325            #( #methods )*
326        }
327    })
328}
329
330fn derive_field_methods(
331    field: &Field,
332    index: usize,
333    struct_name: &Ident,
334    global_param: &ParametrizedAttr,
335) -> Result<Vec<TokenStream2>> {
336    let mut local_param = ParametrizedAttr::with("getter", &field.attrs)?;
337
338    // First, test individual attribute
339    let _ = GetterDerive::try_from(&mut local_param, false)?;
340    // Second, combine global and local together
341    let mut local_args = local_param.args.clone();
342    let mut params = global_param.clone().merged(local_param)?;
343    if local_args
344        .keys()
345        .any(|k| k == "as_copy" || k == "as_clone" || k == "as_ref")
346    {
347        // we have to use local arguments since they do override globals
348        params.args.remove("as_copy");
349        params.args.remove("as_clone");
350        params.args.remove("as_ref");
351        local_args
352            .remove("as_copy")
353            .map(|a| params.args.insert("as_copy".to_owned(), a));
354        local_args
355            .remove("as_clone")
356            .map(|a| params.args.insert("as_clone".to_owned(), a));
357        local_args
358            .remove("as_ref")
359            .map(|a| params.args.insert("as_ref".to_owned(), a));
360    }
361    let getter = GetterDerive::try_from(&mut params, false)?;
362
363    if getter.skip {
364        return Ok(Vec::new());
365    }
366
367    let field_name = field.ident.as_ref();
368    let ty = &field.ty;
369    let doc = field.attrs.iter().find(|a| a.path.is_ident("doc"));
370
371    let mut res = Vec::with_capacity(3);
372    for method in getter.all_methods() {
373        let fn_name = getter.getter_fn_ident(method, field_name, field.span())?;
374        let fn_doc = getter.getter_fn_doc(method, struct_name, field_name, index, doc);
375        let ret_prefix = method.ret_prefix();
376        let ret_suffix = method.ret_suffix();
377        let mut_prefix = method.mut_prefix();
378
379        res.push(quote_spanned! { field.span() =>
380            #fn_doc
381            #[inline]
382            pub fn #fn_name(&#mut_prefix self) -> #ret_prefix #ty {
383                #ret_prefix self.#field_name #ret_suffix
384            }
385        })
386    }
387
388    Ok(res)
389}