serde_derive/
bound.rs

1use crate::internals::ast::{Container, Data};
2use crate::internals::{attr, ungroup};
3use proc_macro2::Span;
4use std::collections::HashSet;
5use syn::punctuated::{Pair, Punctuated};
6use syn::Token;
7
8// Remove the default from every type parameter because in the generated impls
9// they look like associated types: "error: associated type bindings are not
10// allowed here".
11pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
12    syn::Generics {
13        params: generics
14            .params
15            .iter()
16            .map(|param| match param {
17                syn::GenericParam::Type(param) => syn::GenericParam::Type(syn::TypeParam {
18                    eq_token: None,
19                    default: None,
20                    ..param.clone()
21                }),
22                _ => param.clone(),
23            })
24            .collect(),
25        ..generics.clone()
26    }
27}
28
29pub fn with_where_predicates(
30    generics: &syn::Generics,
31    predicates: &[syn::WherePredicate],
32) -> syn::Generics {
33    let mut generics = generics.clone();
34    generics
35        .make_where_clause()
36        .predicates
37        .extend(predicates.iter().cloned());
38    generics
39}
40
41pub fn with_where_predicates_from_fields(
42    cont: &Container,
43    generics: &syn::Generics,
44    from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
45) -> syn::Generics {
46    let predicates = cont
47        .data
48        .all_fields()
49        .filter_map(|field| from_field(&field.attrs))
50        .flat_map(<[syn::WherePredicate]>::to_vec);
51
52    let mut generics = generics.clone();
53    generics.make_where_clause().predicates.extend(predicates);
54    generics
55}
56
57pub fn with_where_predicates_from_variants(
58    cont: &Container,
59    generics: &syn::Generics,
60    from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>,
61) -> syn::Generics {
62    let variants = match &cont.data {
63        Data::Enum(variants) => variants,
64        Data::Struct(_, _) => {
65            return generics.clone();
66        }
67    };
68
69    let predicates = variants
70        .iter()
71        .filter_map(|variant| from_variant(&variant.attrs))
72        .flat_map(<[syn::WherePredicate]>::to_vec);
73
74    let mut generics = generics.clone();
75    generics.make_where_clause().predicates.extend(predicates);
76    generics
77}
78
79// Puts the given bound on any generic type parameters that are used in fields
80// for which filter returns true.
81//
82// For example, the following struct needs the bound `A: Serialize, B:
83// Serialize`.
84//
85//     struct S<'b, A, B: 'b, C> {
86//         a: A,
87//         b: Option<&'b B>
88//         #[serde(skip_serializing)]
89//         c: C,
90//     }
91pub fn with_bound(
92    cont: &Container,
93    generics: &syn::Generics,
94    filter: fn(&attr::Field, Option<&attr::Variant>) -> bool,
95    bound: &syn::Path,
96) -> syn::Generics {
97    struct FindTyParams<'ast> {
98        // Set of all generic type parameters on the current struct (A, B, C in
99        // the example). Initialized up front.
100        all_type_params: HashSet<syn::Ident>,
101
102        // Set of generic type parameters used in fields for which filter
103        // returns true (A and B in the example). Filled in as the visitor sees
104        // them.
105        relevant_type_params: HashSet<syn::Ident>,
106
107        // Fields whose type is an associated type of one of the generic type
108        // parameters.
109        associated_type_usage: Vec<&'ast syn::TypePath>,
110    }
111
112    impl<'ast> FindTyParams<'ast> {
113        fn visit_field(&mut self, field: &'ast syn::Field) {
114            if let syn::Type::Path(ty) = ungroup(&field.ty) {
115                if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() {
116                    if self.all_type_params.contains(&t.ident) {
117                        self.associated_type_usage.push(ty);
118                    }
119                }
120            }
121            self.visit_type(&field.ty);
122        }
123
124        fn visit_path(&mut self, path: &'ast syn::Path) {
125            if let Some(seg) = path.segments.last() {
126                if seg.ident == "PhantomData" {
127                    // Hardcoded exception, because PhantomData<T> implements
128                    // Serialize and Deserialize whether or not T implements it.
129                    return;
130                }
131            }
132            if path.leading_colon.is_none() && path.segments.len() == 1 {
133                let id = &path.segments[0].ident;
134                if self.all_type_params.contains(id) {
135                    self.relevant_type_params.insert(id.clone());
136                }
137            }
138            for segment in &path.segments {
139                self.visit_path_segment(segment);
140            }
141        }
142
143        // Everything below is simply traversing the syntax tree.
144
145        fn visit_type(&mut self, ty: &'ast syn::Type) {
146            match ty {
147                #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
148                syn::Type::Array(ty) => self.visit_type(&ty.elem),
149                syn::Type::BareFn(ty) => {
150                    for arg in &ty.inputs {
151                        self.visit_type(&arg.ty);
152                    }
153                    self.visit_return_type(&ty.output);
154                }
155                syn::Type::Group(ty) => self.visit_type(&ty.elem),
156                syn::Type::ImplTrait(ty) => {
157                    for bound in &ty.bounds {
158                        self.visit_type_param_bound(bound);
159                    }
160                }
161                syn::Type::Macro(ty) => self.visit_macro(&ty.mac),
162                syn::Type::Paren(ty) => self.visit_type(&ty.elem),
163                syn::Type::Path(ty) => {
164                    if let Some(qself) = &ty.qself {
165                        self.visit_type(&qself.ty);
166                    }
167                    self.visit_path(&ty.path);
168                }
169                syn::Type::Ptr(ty) => self.visit_type(&ty.elem),
170                syn::Type::Reference(ty) => self.visit_type(&ty.elem),
171                syn::Type::Slice(ty) => self.visit_type(&ty.elem),
172                syn::Type::TraitObject(ty) => {
173                    for bound in &ty.bounds {
174                        self.visit_type_param_bound(bound);
175                    }
176                }
177                syn::Type::Tuple(ty) => {
178                    for elem in &ty.elems {
179                        self.visit_type(elem);
180                    }
181                }
182
183                syn::Type::Infer(_) | syn::Type::Never(_) | syn::Type::Verbatim(_) => {}
184
185                _ => {}
186            }
187        }
188
189        fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) {
190            self.visit_path_arguments(&segment.arguments);
191        }
192
193        fn visit_path_arguments(&mut self, arguments: &'ast syn::PathArguments) {
194            match arguments {
195                syn::PathArguments::None => {}
196                syn::PathArguments::AngleBracketed(arguments) => {
197                    for arg in &arguments.args {
198                        match arg {
199                            #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
200                            syn::GenericArgument::Type(arg) => self.visit_type(arg),
201                            syn::GenericArgument::AssocType(arg) => self.visit_type(&arg.ty),
202                            syn::GenericArgument::Lifetime(_)
203                            | syn::GenericArgument::Const(_)
204                            | syn::GenericArgument::AssocConst(_)
205                            | syn::GenericArgument::Constraint(_) => {}
206                            _ => {}
207                        }
208                    }
209                }
210                syn::PathArguments::Parenthesized(arguments) => {
211                    for argument in &arguments.inputs {
212                        self.visit_type(argument);
213                    }
214                    self.visit_return_type(&arguments.output);
215                }
216            }
217        }
218
219        fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) {
220            match return_type {
221                syn::ReturnType::Default => {}
222                syn::ReturnType::Type(_, output) => self.visit_type(output),
223            }
224        }
225
226        fn visit_type_param_bound(&mut self, bound: &'ast syn::TypeParamBound) {
227            match bound {
228                #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
229                syn::TypeParamBound::Trait(bound) => self.visit_path(&bound.path),
230                syn::TypeParamBound::Lifetime(_)
231                | syn::TypeParamBound::PreciseCapture(_)
232                | syn::TypeParamBound::Verbatim(_) => {}
233                _ => {}
234            }
235        }
236
237        // Type parameter should not be considered used by a macro path.
238        //
239        //     struct TypeMacro<T> {
240        //         mac: T!(),
241        //         marker: PhantomData<T>,
242        //     }
243        fn visit_macro(&mut self, _mac: &'ast syn::Macro) {}
244    }
245
246    let all_type_params = generics
247        .type_params()
248        .map(|param| param.ident.clone())
249        .collect();
250
251    let mut visitor = FindTyParams {
252        all_type_params,
253        relevant_type_params: HashSet::new(),
254        associated_type_usage: Vec::new(),
255    };
256    match &cont.data {
257        Data::Enum(variants) => {
258            for variant in variants {
259                let relevant_fields = variant
260                    .fields
261                    .iter()
262                    .filter(|field| filter(&field.attrs, Some(&variant.attrs)));
263                for field in relevant_fields {
264                    visitor.visit_field(field.original);
265                }
266            }
267        }
268        Data::Struct(_, fields) => {
269            for field in fields.iter().filter(|field| filter(&field.attrs, None)) {
270                visitor.visit_field(field.original);
271            }
272        }
273    }
274
275    let relevant_type_params = visitor.relevant_type_params;
276    let associated_type_usage = visitor.associated_type_usage;
277    let new_predicates = generics
278        .type_params()
279        .map(|param| param.ident.clone())
280        .filter(|id| relevant_type_params.contains(id))
281        .map(|id| syn::TypePath {
282            qself: None,
283            path: id.into(),
284        })
285        .chain(associated_type_usage.into_iter().cloned())
286        .map(|bounded_ty| {
287            syn::WherePredicate::Type(syn::PredicateType {
288                lifetimes: None,
289                // the type parameter that is being bounded e.g. T
290                bounded_ty: syn::Type::Path(bounded_ty),
291                colon_token: <Token![:]>::default(),
292                // the bound e.g. Serialize
293                bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
294                    paren_token: None,
295                    modifier: syn::TraitBoundModifier::None,
296                    lifetimes: None,
297                    path: bound.clone(),
298                })]
299                .into_iter()
300                .collect(),
301            })
302        });
303
304    let mut generics = generics.clone();
305    generics
306        .make_where_clause()
307        .predicates
308        .extend(new_predicates);
309    generics
310}
311
312pub fn with_self_bound(
313    cont: &Container,
314    generics: &syn::Generics,
315    bound: &syn::Path,
316) -> syn::Generics {
317    let mut generics = generics.clone();
318    generics
319        .make_where_clause()
320        .predicates
321        .push(syn::WherePredicate::Type(syn::PredicateType {
322            lifetimes: None,
323            // the type that is being bounded e.g. MyStruct<'a, T>
324            bounded_ty: type_of_item(cont),
325            colon_token: <Token![:]>::default(),
326            // the bound e.g. Default
327            bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
328                paren_token: None,
329                modifier: syn::TraitBoundModifier::None,
330                lifetimes: None,
331                path: bound.clone(),
332            })]
333            .into_iter()
334            .collect(),
335        }));
336    generics
337}
338
339pub fn with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics {
340    let bound = syn::Lifetime::new(lifetime, Span::call_site());
341    let def = syn::LifetimeParam {
342        attrs: Vec::new(),
343        lifetime: bound.clone(),
344        colon_token: None,
345        bounds: Punctuated::new(),
346    };
347
348    let params = Some(syn::GenericParam::Lifetime(def))
349        .into_iter()
350        .chain(generics.params.iter().cloned().map(|mut param| {
351            match &mut param {
352                syn::GenericParam::Lifetime(param) => {
353                    param.bounds.push(bound.clone());
354                }
355                syn::GenericParam::Type(param) => {
356                    param
357                        .bounds
358                        .push(syn::TypeParamBound::Lifetime(bound.clone()));
359                }
360                syn::GenericParam::Const(_) => {}
361            }
362            param
363        }))
364        .collect();
365
366    syn::Generics {
367        params,
368        ..generics.clone()
369    }
370}
371
372fn type_of_item(cont: &Container) -> syn::Type {
373    syn::Type::Path(syn::TypePath {
374        qself: None,
375        path: syn::Path {
376            leading_colon: None,
377            segments: vec![syn::PathSegment {
378                ident: cont.ident.clone(),
379                arguments: syn::PathArguments::AngleBracketed(
380                    syn::AngleBracketedGenericArguments {
381                        colon2_token: None,
382                        lt_token: <Token![<]>::default(),
383                        args: cont
384                            .generics
385                            .params
386                            .iter()
387                            .map(|param| match param {
388                                syn::GenericParam::Type(param) => {
389                                    syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
390                                        qself: None,
391                                        path: param.ident.clone().into(),
392                                    }))
393                                }
394                                syn::GenericParam::Lifetime(param) => {
395                                    syn::GenericArgument::Lifetime(param.lifetime.clone())
396                                }
397                                syn::GenericParam::Const(_) => {
398                                    panic!("Serde does not support const generics yet");
399                                }
400                            })
401                            .collect(),
402                        gt_token: <Token![>]>::default(),
403                    },
404                ),
405            }]
406            .into_iter()
407            .collect(),
408        },
409    })
410}