borsh_derive/internals/
generics.rs

1use std::collections::{HashMap, HashSet};
2
3use quote::{quote, ToTokens};
4use syn::{
5    punctuated::Pair, Field, GenericArgument, Generics, Ident, Macro, Path, PathArguments,
6    PathSegment, ReturnType, Type, TypeParamBound, TypePath, WhereClause, WherePredicate,
7};
8
9pub fn default_where(where_clause: Option<&WhereClause>) -> WhereClause {
10    where_clause.map_or_else(
11        || WhereClause {
12            where_token: Default::default(),
13            predicates: Default::default(),
14        },
15        Clone::clone,
16    )
17}
18
19pub fn compute_predicates(params: Vec<Type>, traitname: &Path) -> Vec<WherePredicate> {
20    params
21        .into_iter()
22        .map(|param| {
23            syn::parse2(quote! {
24                #param: #traitname
25            })
26            .unwrap()
27        })
28        .collect()
29}
30
31// Remove the default from every type parameter because in the generated impls
32// they look like associated types: "error: associated type bindings are not
33// allowed here".
34pub fn without_defaults(generics: &Generics) -> Generics {
35    syn::Generics {
36        params: generics
37            .params
38            .iter()
39            .map(|param| match param {
40                syn::GenericParam::Type(param) => syn::GenericParam::Type(syn::TypeParam {
41                    eq_token: None,
42                    default: None,
43                    ..param.clone()
44                }),
45                _ => param.clone(),
46            })
47            .collect(),
48        ..generics.clone()
49    }
50}
51
/// Reports whether `type_` mentions any identifier from `params`. A mention is either a
/// direct use as a single-segment path (e.g. `T` in `Vec<T>`) or a top-level path rooted
/// at the parameter (e.g. `T::Item`).
///
/// Uses inside `PhantomData<...>` and inside type macros are not counted.
#[cfg(feature = "schema")]
pub fn type_contains_some_param(type_: &Type, params: &HashSet<Ident>) -> bool {
    let mut finder = FindTyParams::from_params(params.iter());
    finder.visit_type_top_level(type_);
    finder.at_least_one_hit()
}
60
61/// a Visitor-like struct, which helps determine, if a type parameter is found in field
62#[derive(Clone)]
63pub struct FindTyParams {
64    // Set of all generic type parameters on the current struct . Initialized up front.
65    all_type_params: HashSet<Ident>,
66    all_type_params_ordered: Vec<Ident>,
67
68    // Set of generic type parameters used in fields for which filter
69    // returns true . Filled in as the visitor sees them.
70    relevant_type_params: HashSet<Ident>,
71
72    // [Param] => [Type, containing Param] mapping
73    associated_type_params_usage: HashMap<Ident, Vec<Type>>,
74}
75
76fn ungroup(mut ty: &Type) -> &Type {
77    while let Type::Group(group) = ty {
78        ty = &group.elem;
79    }
80    ty
81}
82
83impl FindTyParams {
84    pub fn new(generics: &Generics) -> Self {
85        let all_type_params = generics
86            .type_params()
87            .map(|param| param.ident.clone())
88            .collect();
89
90        let all_type_params_ordered = generics
91            .type_params()
92            .map(|param| param.ident.clone())
93            .collect();
94
95        FindTyParams {
96            all_type_params,
97            all_type_params_ordered,
98            relevant_type_params: HashSet::new(),
99            associated_type_params_usage: HashMap::new(),
100        }
101    }
102    pub fn process_for_bounds(self) -> Vec<Type> {
103        let relevant_type_params = self.relevant_type_params;
104        let associated_type_params_usage = self.associated_type_params_usage;
105        let mut new_predicates: Vec<Type> = vec![];
106        let mut new_predicates_set: HashSet<String> = HashSet::new();
107
108        self.all_type_params_ordered.iter().for_each(|param| {
109            if relevant_type_params.contains(param) {
110                let ty = Type::Path(TypePath {
111                    qself: None,
112                    path: param.clone().into(),
113                });
114                let ty_str_repr = ty.to_token_stream().to_string();
115                if !new_predicates_set.contains(&ty_str_repr) {
116                    new_predicates.push(ty);
117                    new_predicates_set.insert(ty_str_repr);
118                }
119            }
120            if let Some(vec_type) = associated_type_params_usage.get(param) {
121                for type_ in vec_type {
122                    let ty_str_repr = type_.to_token_stream().to_string();
123                    if !new_predicates_set.contains(&ty_str_repr) {
124                        new_predicates.push(type_.clone());
125                        new_predicates_set.insert(ty_str_repr);
126                    }
127                }
128            }
129        });
130
131        new_predicates
132    }
133}
134
#[cfg(feature = "schema")]
impl FindTyParams {
    /// Creates a finder that searches for exactly the given parameter identifiers and
    /// remembers the order in which `params` yields them.
    pub fn from_params<'a>(params: impl Iterator<Item = &'a Ident>) -> Self {
        let all_type_params_ordered: Vec<Ident> = params.cloned().collect();
        // Clone element-wise into the set rather than cloning the whole Vec first.
        let all_type_params = all_type_params_ordered.iter().cloned().collect();
        FindTyParams {
            all_type_params,
            all_type_params_ordered,
            relevant_type_params: HashSet::new(),
            associated_type_params_usage: HashMap::new(),
        }
    }

    /// Consumes the finder and returns every parameter that was hit. A hit is either a
    /// direct use or a use as the root of an associated-type path. The result is in
    /// constructor order and contains no duplicates.
    pub fn process_for_params(self) -> Vec<Ident> {
        let mut params: Vec<Ident> = Vec::new();
        let mut params_set: HashSet<Ident> = HashSet::new();

        for param in &self.all_type_params_ordered {
            let hit = self.relevant_type_params.contains(param)
                || self.associated_type_params_usage.contains_key(param);
            // `insert` returns false if this parameter was already emitted.
            if hit && params_set.insert(param.clone()) {
                params.push(param.clone());
            }
        }
        params
    }

    /// Returns `true` if at least one tracked parameter has been seen so far, either
    /// directly or as the root of an associated-type path.
    pub fn at_least_one_hit(&self) -> bool {
        !self.relevant_type_params.is_empty() || !self.associated_type_params_usage.is_empty()
    }
}
170
171impl FindTyParams {
172    pub fn visit_field(&mut self, field: &Field) {
173        self.visit_type_top_level(&field.ty);
174    }
175
176    pub fn visit_type_top_level(&mut self, type_: &Type) {
177        if let Type::Path(ty) = ungroup(type_) {
178            if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() {
179                if self.all_type_params.contains(&t.ident) {
180                    self.param_associated_type_insert(t.ident.clone(), type_.clone());
181                }
182            }
183        }
184        self.visit_type(type_);
185    }
186
187    pub fn param_associated_type_insert(&mut self, param: Ident, type_: Type) {
188        if let Some(type_vec) = self.associated_type_params_usage.get_mut(&param) {
189            type_vec.push(type_);
190        } else {
191            let type_vec = vec![type_];
192            self.associated_type_params_usage.insert(param, type_vec);
193        }
194    }
195
196    fn visit_return_type(&mut self, return_type: &ReturnType) {
197        match return_type {
198            ReturnType::Default => {}
199            ReturnType::Type(_, output) => self.visit_type(output),
200        }
201    }
202
203    fn visit_path_segment(&mut self, segment: &PathSegment) {
204        self.visit_path_arguments(&segment.arguments);
205    }
206
207    fn visit_path_arguments(&mut self, arguments: &PathArguments) {
208        match arguments {
209            PathArguments::None => {}
210            PathArguments::AngleBracketed(arguments) => {
211                for arg in &arguments.args {
212                    #[cfg_attr(
213                        feature = "force_exhaustive_checks",
214                        deny(non_exhaustive_omitted_patterns)
215                    )]
216                    match arg {
217                        GenericArgument::Type(arg) => self.visit_type(arg),
218                        GenericArgument::AssocType(arg) => self.visit_type(&arg.ty),
219                        GenericArgument::Lifetime(_)
220                        | GenericArgument::Const(_)
221                        | GenericArgument::AssocConst(_)
222                        | GenericArgument::Constraint(_) => {}
223                        _ => {}
224                    }
225                }
226            }
227            PathArguments::Parenthesized(arguments) => {
228                for argument in &arguments.inputs {
229                    self.visit_type(argument);
230                }
231                self.visit_return_type(&arguments.output);
232            }
233        }
234    }
235
236    fn visit_path(&mut self, path: &Path) {
237        if let Some(seg) = path.segments.last() {
238            if seg.ident == "PhantomData" {
239                // Hardcoded exception, because PhantomData<T> implements
240                // Serialize and Deserialize and Schema whether or not T implements it.
241                return;
242            }
243        }
244        if path.leading_colon.is_none() && path.segments.len() == 1 {
245            let id = &path.segments[0].ident;
246            if self.all_type_params.contains(id) {
247                self.relevant_type_params.insert(id.clone());
248            }
249        }
250        for segment in &path.segments {
251            self.visit_path_segment(segment);
252        }
253    }
254
255    fn visit_type_param_bound(&mut self, bound: &TypeParamBound) {
256        #[cfg_attr(
257            feature = "force_exhaustive_checks",
258            deny(non_exhaustive_omitted_patterns)
259        )]
260        match bound {
261            TypeParamBound::Trait(bound) => self.visit_path(&bound.path),
262            TypeParamBound::Lifetime(_)
263            | TypeParamBound::Verbatim(_)
264            | TypeParamBound::PreciseCapture(_) => {}
265            _ => {}
266        }
267    }
268    // Type parameter should not be considered used by a macro path.
269    //
270    //     struct TypeMacro<T> {
271    //         mac: T!(),
272    //         marker: PhantomData<T>,
273    //     }
274    fn visit_macro(&mut self, _mac: &Macro) {}
275
276    fn visit_type(&mut self, ty: &Type) {
277        #[cfg_attr(
278            feature = "force_exhaustive_checks",
279            deny(non_exhaustive_omitted_patterns)
280        )]
281        match ty {
282            Type::Array(ty) => self.visit_type(&ty.elem),
283            Type::BareFn(ty) => {
284                for arg in &ty.inputs {
285                    self.visit_type(&arg.ty);
286                }
287                self.visit_return_type(&ty.output);
288            }
289            Type::Group(ty) => self.visit_type(&ty.elem),
290            Type::ImplTrait(ty) => {
291                for bound in &ty.bounds {
292                    self.visit_type_param_bound(bound);
293                }
294            }
295            Type::Macro(ty) => self.visit_macro(&ty.mac),
296            Type::Paren(ty) => self.visit_type(&ty.elem),
297            Type::Path(ty) => {
298                if let Some(qself) = &ty.qself {
299                    self.visit_type(&qself.ty);
300                }
301                self.visit_path(&ty.path);
302            }
303            Type::Ptr(ty) => self.visit_type(&ty.elem),
304            Type::Reference(ty) => self.visit_type(&ty.elem),
305            Type::Slice(ty) => self.visit_type(&ty.elem),
306            Type::TraitObject(ty) => {
307                for bound in &ty.bounds {
308                    self.visit_type_param_bound(bound);
309                }
310            }
311            Type::Tuple(ty) => {
312                for elem in &ty.elems {
313                    self.visit_type(elem);
314                }
315            }
316
317            Type::Infer(_) | Type::Never(_) | Type::Verbatim(_) => {}
318
319            _ => {}
320        }
321    }
322}