proptest_derive/
use_tracking.rs

// Copyright 2018 The proptest developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
8
//! Provides `UseTracker` as well as the `UseMarkable` trait, which together
//! track uses of type variables that need `Arbitrary` bounds in our
//! impls.
12
13use std::borrow::Borrow;
14use std::collections::HashSet;
15
16use syn;
17
18use crate::attr;
19use crate::error::{Ctx, DeriveResult};
20use crate::util;
21
//==============================================================================
// API: Type variable use tracking
//==============================================================================
25
26/// `UseTracker` tracks what type variables that have used in `any_with::<Type>`
27/// or similar and thus needs an `Arbitrary` bound added to them.
28pub struct UseTracker {
29    /// Tracks 'usage' of a type variable name.
30    /// Allocation of this "map" will happen at once and no further
31    /// allocation will happen after that. Only potential updates
32    /// will happen after initial allocation.
33    /// We need to preserve insertion order, thus using Vec instead of BTreeMap
34    /// or HashMap. A potential alternative would be indexmap crate, but our
35    /// maps are so small that it would not bring any significant benefit.
36    used_map: Vec<(syn::Ident, bool)>,
37    /// Extra types to bound by `Arbitrary` in the `where` clause.
38    where_types: HashSet<syn::Type>,
39    /// The generics that we are doing this for.
40    /// This what we will modify later once we're done.
41    generics: syn::Generics,
42    /// If set to `true`, then `mark_used` has no effect.
43    track: bool,
44}
45
46/// Models a thing that may have type variables in it that
47/// can be marked as 'used' as defined by `UseTracker`.
48pub trait UseMarkable {
49    fn mark_uses(&self, tracker: &mut UseTracker);
50}
51
52impl UseTracker {
53    /// Constructs the tracker for the given `generics`.
54    pub fn new(generics: syn::Generics) -> Self {
55        // Construct the map by setting all type variables as being unused
56        // initially. This is the only time we will allocate for the map.
57        let used_map = generics
58            .type_params()
59            .map(|v| (v.ident.clone(), false))
60            .collect();
61        Self {
62            generics,
63            used_map,
64            where_types: HashSet::default(),
65            track: true,
66        }
67    }
68
69    /// Stop tracking. `.mark_used` will have no effect.
70    pub fn no_track(&mut self) {
71        self.track = false;
72    }
73
74    /// Mark the _potential_ type variable `tyvar` as used.
75    /// If the tracker does not know about the name, it is not
76    /// a type variable and this call has no effect.
77    fn use_tyvar(&mut self, tyvar: impl Borrow<syn::Ident>) {
78        if self.track {
79            if let Some(used) = self
80                .used_map
81                .iter_mut()
82                .find_map(|(ty, used)| (ty == tyvar.borrow()).then(|| used))
83            {
84                *used = true;
85            }
86        }
87    }
88
89    /// Returns true iff the type variable given exists.
90    fn has_tyvar(&self, ty_var: impl Borrow<syn::Ident>) -> bool {
91        self.used_map.iter().any(|(ty, _)| ty == ty_var.borrow())
92    }
93
94    /// Mark the type as used.
95    fn use_type(&mut self, ty: syn::Type) {
96        self.where_types.insert(ty);
97    }
98
99    /// Adds the bound in `for_used` on used type variables and
100    /// the bound in `for_not` (`if .is_some()`) on unused type variables.
101    pub fn add_bounds(
102        &mut self,
103        ctx: Ctx,
104        for_used: &syn::TypeParamBound,
105        for_not: Option<syn::TypeParamBound>,
106    ) -> DeriveResult<()> {
107        {
108            let mut iter = self
109                .used_map
110                .iter()
111                .map(|(_, used)| used)
112                .zip(self.generics.type_params_mut());
113            if let Some(for_not) = for_not {
114                iter.try_for_each(|(&used, tv)| {
115                    // Steal the attributes:
116                    let no_bound = attr::has_no_bound(ctx, &tv.attrs)?;
117                    let bound = if used && !no_bound {
118                        for_used
119                    } else {
120                        &for_not
121                    };
122                    tv.bounds.push(bound.clone());
123                    Ok(())
124                })?;
125            } else {
126                iter.for_each(|(&used, tv)| {
127                    if used {
128                        tv.bounds.push(for_used.clone())
129                    }
130                })
131            }
132        }
133
134        self.generics.make_where_clause().predicates.extend(
135            self.where_types.iter().cloned().map(|ty| {
136                syn::WherePredicate::Type(syn::PredicateType {
137                    lifetimes: None,
138                    bounded_ty: ty,
139                    colon_token: <Token![:]>::default(),
140                    bounds: ::std::iter::once(for_used.clone()).collect(),
141                })
142            }),
143        );
144
145        Ok(())
146    }
147
148    /// Consumes the (potentially) modified generics that the
149    /// tracker was originally constructed with and returns it.
150    pub fn consume(self) -> syn::Generics {
151        self.generics
152    }
153}
154
//==============================================================================
// Impls
//==============================================================================
158
159impl UseMarkable for syn::Type {
160    fn mark_uses(&self, ut: &mut UseTracker) {
161        use syn::visit;
162
163        visit::visit_type(&mut PathVisitor(ut), self);
164
165        struct PathVisitor<'ut>(&'ut mut UseTracker);
166
167        impl<'ut, 'ast> visit::Visit<'ast> for PathVisitor<'ut> {
168            fn visit_macro(&mut self, _: &syn::Macro) {}
169
170            fn visit_type_path(&mut self, tpath: &syn::TypePath) {
171                if matches_prj_tyvar(self.0, tpath) {
172                    self.0.use_type(adjust_simple_prj(tpath).into());
173                    return;
174                }
175                visit::visit_type_path(self, tpath);
176            }
177
178            fn visit_path(&mut self, path: &syn::Path) {
179                // If path is PhantomData<T> do not mark innards.
180                if util::is_phantom_data(path) {
181                    return;
182                }
183
184                if let Some(ident) = util::extract_simple_path(path) {
185                    self.0.use_tyvar(ident);
186                }
187
188                visit::visit_path(self, path);
189            }
190        }
191    }
192}
193
194fn matches_prj_tyvar(ut: &mut UseTracker, tpath: &syn::TypePath) -> bool {
195    let path = &tpath.path;
196    let segs = &path.segments;
197
198    if let Some(qself) = &tpath.qself {
199        // < $qself > :: $path
200        if let Some(sub_tp) = extract_path(&qself.ty) {
201            return sub_tp.qself.is_none()
202                && util::match_singleton(segs.iter().skip(qself.position))
203                    .filter(|ps| ps.arguments.is_empty())
204                    .and_then(|_| util::extract_simple_path(&sub_tp.path))
205                    .filter(|&ident| ut.has_tyvar(ident))
206                    .is_some() // < $tyvar as? $path? > :: $path
207                || matches_prj_tyvar(ut, sub_tp);
208        }
209
210        false
211    } else {
212        // true => $tyvar :: $projection
213        return !util::path_is_global(path)
214            && segs.len() == 2
215            && ut.has_tyvar(&segs[0].ident)
216            && segs[0].arguments.is_empty()
217            && segs[1].arguments.is_empty();
218    }
219}
220
221fn adjust_simple_prj(tpath: &syn::TypePath) -> syn::TypePath {
222    let segments = tpath
223        .qself
224        .as_ref()
225        .filter(|qp| qp.as_token.is_none())
226        .and_then(|qp| extract_path(&*qp.ty))
227        .filter(|tp| tp.qself.is_none())
228        .map(|tp| &tp.path.segments);
229
230    if let Some(segments) = segments {
231        let tpath = tpath.clone();
232        let mut segments = segments.clone();
233        segments.push_punct(<Token![::]>::default());
234        segments.extend(tpath.path.segments.into_pairs());
235        syn::TypePath {
236            qself: None,
237            path: syn::Path {
238                leading_colon: None,
239                segments,
240            },
241        }
242    } else {
243        tpath.clone()
244    }
245}
246
247fn extract_path(ty: &syn::Type) -> Option<&syn::TypePath> {
248    if let syn::Type::Path(tpath) = ty {
249        Some(tpath)
250    } else {
251        None
252    }
253}