borsh_derive/internals/
generics.rs
1use std::collections::{HashMap, HashSet};
2
3use quote::{quote, ToTokens};
4use syn::{
5 punctuated::Pair, Field, GenericArgument, Generics, Ident, Macro, Path, PathArguments,
6 PathSegment, ReturnType, Type, TypeParamBound, TypePath, WhereClause, WherePredicate,
7};
8
9pub fn default_where(where_clause: Option<&WhereClause>) -> WhereClause {
10 where_clause.map_or_else(
11 || WhereClause {
12 where_token: Default::default(),
13 predicates: Default::default(),
14 },
15 Clone::clone,
16 )
17}
18
19pub fn compute_predicates(params: Vec<Type>, traitname: &Path) -> Vec<WherePredicate> {
20 params
21 .into_iter()
22 .map(|param| {
23 syn::parse2(quote! {
24 #param: #traitname
25 })
26 .unwrap()
27 })
28 .collect()
29}
30
31pub fn without_defaults(generics: &Generics) -> Generics {
35 syn::Generics {
36 params: generics
37 .params
38 .iter()
39 .map(|param| match param {
40 syn::GenericParam::Type(param) => syn::GenericParam::Type(syn::TypeParam {
41 eq_token: None,
42 default: None,
43 ..param.clone()
44 }),
45 _ => param.clone(),
46 })
47 .collect(),
48 ..generics.clone()
49 }
50}
51
/// Reports whether `type_` mentions at least one of the identifiers in `params`.
///
/// A fresh [`FindTyParams`] visitor is seeded with `params`, run over `type_`
/// as a top-level type, and then asked whether it recorded any hit.
#[cfg(feature = "schema")]
pub fn type_contains_some_param(type_: &Type, params: &HashSet<Ident>) -> bool {
    let mut finder = FindTyParams::from_params(params.iter());
    finder.visit_type_top_level(type_);
    finder.at_least_one_hit()
}
60
/// Visitor state that records which type parameters are used by the types it
/// is shown (see `visit_field` / `visit_type_top_level`).
#[derive(Clone)]
pub struct FindTyParams {
    /// Identifiers of every type parameter under consideration; used for
    /// membership checks while visiting paths.
    all_type_params: HashSet<Ident>,
    /// The same identifiers as a sequence, in the order they were supplied, so
    /// that the `process_for_*` results are emitted in a stable order.
    all_type_params_ordered: Vec<Ident>,

    /// Parameters that were seen as a bare single-segment path (e.g. `T`)
    /// somewhere inside a visited type.
    relevant_type_params: HashSet<Ident>,

    /// For each parameter, the top-level types whose path starts with that
    /// parameter followed by further segments (e.g. `T::Assoc`).
    associated_type_params_usage: HashMap<Ident, Vec<Type>>,
}
75
76fn ungroup(mut ty: &Type) -> &Type {
77 while let Type::Group(group) = ty {
78 ty = &group.elem;
79 }
80 ty
81}
82
83impl FindTyParams {
84 pub fn new(generics: &Generics) -> Self {
85 let all_type_params = generics
86 .type_params()
87 .map(|param| param.ident.clone())
88 .collect();
89
90 let all_type_params_ordered = generics
91 .type_params()
92 .map(|param| param.ident.clone())
93 .collect();
94
95 FindTyParams {
96 all_type_params,
97 all_type_params_ordered,
98 relevant_type_params: HashSet::new(),
99 associated_type_params_usage: HashMap::new(),
100 }
101 }
102 pub fn process_for_bounds(self) -> Vec<Type> {
103 let relevant_type_params = self.relevant_type_params;
104 let associated_type_params_usage = self.associated_type_params_usage;
105 let mut new_predicates: Vec<Type> = vec![];
106 let mut new_predicates_set: HashSet<String> = HashSet::new();
107
108 self.all_type_params_ordered.iter().for_each(|param| {
109 if relevant_type_params.contains(param) {
110 let ty = Type::Path(TypePath {
111 qself: None,
112 path: param.clone().into(),
113 });
114 let ty_str_repr = ty.to_token_stream().to_string();
115 if !new_predicates_set.contains(&ty_str_repr) {
116 new_predicates.push(ty);
117 new_predicates_set.insert(ty_str_repr);
118 }
119 }
120 if let Some(vec_type) = associated_type_params_usage.get(param) {
121 for type_ in vec_type {
122 let ty_str_repr = type_.to_token_stream().to_string();
123 if !new_predicates_set.contains(&ty_str_repr) {
124 new_predicates.push(type_.clone());
125 new_predicates_set.insert(ty_str_repr);
126 }
127 }
128 }
129 });
130
131 new_predicates
132 }
133}
134
#[cfg(feature = "schema")]
impl FindTyParams {
    /// Creates a visitor that tracks exactly the identifiers yielded by
    /// `params`, remembering the order in which they were yielded.
    pub fn from_params<'a>(params: impl Iterator<Item = &'a Ident>) -> Self {
        let all_type_params_ordered: Vec<Ident> = params.cloned().collect();
        let all_type_params: HashSet<Ident> = all_type_params_ordered.iter().cloned().collect();

        Self {
            all_type_params,
            all_type_params_ordered,
            relevant_type_params: HashSet::new(),
            associated_type_params_usage: HashMap::new(),
        }
    }

    /// Consumes the visitor and returns, in recorded order and without
    /// duplicates, every tracked parameter that was hit either as a bare path
    /// or through an associated-type usage.
    pub fn process_for_params(self) -> Vec<Ident> {
        let mut emitted: HashSet<&Ident> = HashSet::new();
        let mut hit_params: Vec<Ident> = vec![];

        for param in &self.all_type_params_ordered {
            let was_hit = self.relevant_type_params.contains(param)
                || self.associated_type_params_usage.contains_key(param);
            // Short-circuit keeps non-hit parameters out of `emitted`.
            if was_hit && emitted.insert(param) {
                hit_params.push(param.clone());
            }
        }

        hit_params
    }

    /// Returns `true` when at least one parameter was recorded, either as a
    /// bare path or through an associated-type usage.
    pub fn at_least_one_hit(&self) -> bool {
        !(self.relevant_type_params.is_empty() && self.associated_type_params_usage.is_empty())
    }
}
170
171impl FindTyParams {
172 pub fn visit_field(&mut self, field: &Field) {
173 self.visit_type_top_level(&field.ty);
174 }
175
176 pub fn visit_type_top_level(&mut self, type_: &Type) {
177 if let Type::Path(ty) = ungroup(type_) {
178 if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() {
179 if self.all_type_params.contains(&t.ident) {
180 self.param_associated_type_insert(t.ident.clone(), type_.clone());
181 }
182 }
183 }
184 self.visit_type(type_);
185 }
186
187 pub fn param_associated_type_insert(&mut self, param: Ident, type_: Type) {
188 if let Some(type_vec) = self.associated_type_params_usage.get_mut(¶m) {
189 type_vec.push(type_);
190 } else {
191 let type_vec = vec![type_];
192 self.associated_type_params_usage.insert(param, type_vec);
193 }
194 }
195
196 fn visit_return_type(&mut self, return_type: &ReturnType) {
197 match return_type {
198 ReturnType::Default => {}
199 ReturnType::Type(_, output) => self.visit_type(output),
200 }
201 }
202
203 fn visit_path_segment(&mut self, segment: &PathSegment) {
204 self.visit_path_arguments(&segment.arguments);
205 }
206
207 fn visit_path_arguments(&mut self, arguments: &PathArguments) {
208 match arguments {
209 PathArguments::None => {}
210 PathArguments::AngleBracketed(arguments) => {
211 for arg in &arguments.args {
212 #[cfg_attr(
213 feature = "force_exhaustive_checks",
214 deny(non_exhaustive_omitted_patterns)
215 )]
216 match arg {
217 GenericArgument::Type(arg) => self.visit_type(arg),
218 GenericArgument::AssocType(arg) => self.visit_type(&arg.ty),
219 GenericArgument::Lifetime(_)
220 | GenericArgument::Const(_)
221 | GenericArgument::AssocConst(_)
222 | GenericArgument::Constraint(_) => {}
223 _ => {}
224 }
225 }
226 }
227 PathArguments::Parenthesized(arguments) => {
228 for argument in &arguments.inputs {
229 self.visit_type(argument);
230 }
231 self.visit_return_type(&arguments.output);
232 }
233 }
234 }
235
236 fn visit_path(&mut self, path: &Path) {
237 if let Some(seg) = path.segments.last() {
238 if seg.ident == "PhantomData" {
239 return;
242 }
243 }
244 if path.leading_colon.is_none() && path.segments.len() == 1 {
245 let id = &path.segments[0].ident;
246 if self.all_type_params.contains(id) {
247 self.relevant_type_params.insert(id.clone());
248 }
249 }
250 for segment in &path.segments {
251 self.visit_path_segment(segment);
252 }
253 }
254
255 fn visit_type_param_bound(&mut self, bound: &TypeParamBound) {
256 #[cfg_attr(
257 feature = "force_exhaustive_checks",
258 deny(non_exhaustive_omitted_patterns)
259 )]
260 match bound {
261 TypeParamBound::Trait(bound) => self.visit_path(&bound.path),
262 TypeParamBound::Lifetime(_)
263 | TypeParamBound::Verbatim(_)
264 | TypeParamBound::PreciseCapture(_) => {}
265 _ => {}
266 }
267 }
268 fn visit_macro(&mut self, _mac: &Macro) {}
275
276 fn visit_type(&mut self, ty: &Type) {
277 #[cfg_attr(
278 feature = "force_exhaustive_checks",
279 deny(non_exhaustive_omitted_patterns)
280 )]
281 match ty {
282 Type::Array(ty) => self.visit_type(&ty.elem),
283 Type::BareFn(ty) => {
284 for arg in &ty.inputs {
285 self.visit_type(&arg.ty);
286 }
287 self.visit_return_type(&ty.output);
288 }
289 Type::Group(ty) => self.visit_type(&ty.elem),
290 Type::ImplTrait(ty) => {
291 for bound in &ty.bounds {
292 self.visit_type_param_bound(bound);
293 }
294 }
295 Type::Macro(ty) => self.visit_macro(&ty.mac),
296 Type::Paren(ty) => self.visit_type(&ty.elem),
297 Type::Path(ty) => {
298 if let Some(qself) = &ty.qself {
299 self.visit_type(&qself.ty);
300 }
301 self.visit_path(&ty.path);
302 }
303 Type::Ptr(ty) => self.visit_type(&ty.elem),
304 Type::Reference(ty) => self.visit_type(&ty.elem),
305 Type::Slice(ty) => self.visit_type(&ty.elem),
306 Type::TraitObject(ty) => {
307 for bound in &ty.bounds {
308 self.visit_type_param_bound(bound);
309 }
310 }
311 Type::Tuple(ty) => {
312 for elem in &ty.elems {
313 self.visit_type(elem);
314 }
315 }
316
317 Type::Infer(_) | Type::Never(_) | Type::Verbatim(_) => {}
318
319 _ => {}
320 }
321 }
322}