axum/routing/
strip_prefix.rs

1use http::{Request, Uri};
2use std::{
3    sync::Arc,
4    task::{Context, Poll},
5};
6use tower::Layer;
7use tower_layer::layer_fn;
8use tower_service::Service;
9
10#[derive(Clone)]
11pub(super) struct StripPrefix<S> {
12    inner: S,
13    prefix: Arc<str>,
14}
15
16impl<S> StripPrefix<S> {
17    pub(super) fn layer(prefix: &str) -> impl Layer<S, Service = Self> + Clone {
18        let prefix = Arc::from(prefix);
19        layer_fn(move |inner| Self {
20            inner,
21            prefix: Arc::clone(&prefix),
22        })
23    }
24}
25
26impl<S, B> Service<Request<B>> for StripPrefix<S>
27where
28    S: Service<Request<B>>,
29{
30    type Response = S::Response;
31    type Error = S::Error;
32    type Future = S::Future;
33
34    #[inline]
35    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
36        self.inner.poll_ready(cx)
37    }
38
39    fn call(&mut self, mut req: Request<B>) -> Self::Future {
40        if let Some(new_uri) = strip_prefix(req.uri(), &self.prefix) {
41            *req.uri_mut() = new_uri;
42        }
43        self.inner.call(req)
44    }
45}
46
47fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
48    let path_and_query = uri.path_and_query()?;
49
50    // Check whether the prefix matches the path and if so how long the matching prefix is.
51    //
52    // For example:
53    //
54    // prefix = /api
55    // path   = /api/users
56    //          ^^^^ this much is matched and the length is 4. Thus if we chop off the first 4
57    //          characters we get the remainder
58    //
59    // prefix = /api/:version
60    // path   = /api/v0/users
61    //          ^^^^^^^ this much is matched and the length is 7.
62    let mut matching_prefix_length = Some(0);
63    for item in zip_longest(segments(path_and_query.path()), segments(prefix)) {
64        // count the `/`
65        *matching_prefix_length.as_mut().unwrap() += 1;
66
67        match item {
68            Item::Both(path_segment, prefix_segment) => {
69                if prefix_segment.starts_with(':') || path_segment == prefix_segment {
70                    // the prefix segment is either a param, which matches anything, or
71                    // it actually matches the path segment
72                    *matching_prefix_length.as_mut().unwrap() += path_segment.len();
73                } else if prefix_segment.is_empty() {
74                    // the prefix ended in a `/` so we got a match.
75                    //
76                    // For example:
77                    //
78                    // prefix = /foo/
79                    // path   = /foo/bar
80                    //
81                    // The prefix matches and the new path should be `/bar`
82                    break;
83                } else {
84                    // the prefix segment didn't match so there is no match
85                    matching_prefix_length = None;
86                    break;
87                }
88            }
89            // the path had more segments than the prefix but we got a match.
90            //
91            // For example:
92            //
93            // prefix = /foo
94            // path   = /foo/bar
95            Item::First(_) => {
96                break;
97            }
98            // the prefix had more segments than the path so there is no match
99            Item::Second(_) => {
100                matching_prefix_length = None;
101                break;
102            }
103        }
104    }
105
106    // if the prefix matches it will always do so up until a `/`, it cannot match only
107    // part of a segment. Therefore this will always be at a char boundary and `split_at` won't
108    // panic
109    let after_prefix = uri.path().split_at(matching_prefix_length?).1;
110
111    let new_path_and_query = match (after_prefix.starts_with('/'), path_and_query.query()) {
112        (true, None) => after_prefix.parse().unwrap(),
113        (true, Some(query)) => format!("{after_prefix}?{query}").parse().unwrap(),
114        (false, None) => format!("/{after_prefix}").parse().unwrap(),
115        (false, Some(query)) => format!("/{after_prefix}?{query}").parse().unwrap(),
116    };
117
118    let mut parts = uri.clone().into_parts();
119    parts.path_and_query = Some(new_path_and_query);
120
121    Some(Uri::from_parts(parts).unwrap())
122}
123
/// Splits an absolute path into its `/`-separated segments.
///
/// The empty piece in front of the leading `/` is dropped, so `"/a/b"` yields
/// `["a", "b"]`, `"/"` yields `[""]` and `"/a/"` yields `["a", ""]`.
///
/// # Panics
///
/// Panics if `s` does not start with `/`.
fn segments(s: &str) -> impl Iterator<Item = &str> {
    match s.strip_prefix('/') {
        Some(rest) => rest.split('/'),
        None => panic!("path didn't start with '/'. axum should have caught this higher up."),
    }
}
135
136fn zip_longest<I, I2>(a: I, b: I2) -> impl Iterator<Item = Item<I::Item>>
137where
138    I: Iterator,
139    I2: Iterator<Item = I::Item>,
140{
141    let a = a.map(Some).chain(std::iter::repeat_with(|| None));
142    let b = b.map(Some).chain(std::iter::repeat_with(|| None));
143    a.zip(b).map_while(|(a, b)| match (a, b) {
144        (Some(a), Some(b)) => Some(Item::Both(a, b)),
145        (Some(a), None) => Some(Item::First(a)),
146        (None, Some(b)) => Some(Item::Second(b)),
147        (None, None) => None,
148    })
149}
150
/// One element produced by [`zip_longest`], recording which input(s) supplied it.
#[derive(Debug)]
enum Item<T> {
    /// Only the first iterator had a value; the second is exhausted.
    First(T),
    /// Only the second iterator had a value; the first is exhausted.
    Second(T),
    /// Both iterators had a value.
    Both(T, T),
}
157
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::Arbitrary;
    use quickcheck_macros::quickcheck;

    /// Declares a unit test that strips `prefix` from `uri` and compares the
    /// rewritten URI, rendered with `to_string`, against `expected`. An
    /// `expected` of `None` means the prefix must not match.
    macro_rules! test {
        (
            $name:ident,
            uri = $uri:literal,
            prefix = $prefix:literal,
            expected = $expected:expr,
        ) => {
            #[test]
            fn $name() {
                let uri = $uri.parse().unwrap();
                let new_uri = strip_prefix(&uri, $prefix).map(|uri| uri.to_string());
                assert_eq!(new_uri.as_deref(), $expected);
            }
        };
    }

    test!(empty, uri = "/", prefix = "/", expected = Some("/"),);

    test!(
        single_segment,
        uri = "/a",
        prefix = "/a",
        expected = Some("/"),
    );

    test!(
        single_segment_root_uri,
        uri = "/",
        prefix = "/a",
        expected = None,
    );

    // the prefix is empty, so removing it should have no effect
    test!(
        single_segment_root_prefix,
        uri = "/a",
        prefix = "/",
        expected = Some("/a"),
    );

    test!(
        single_segment_no_match,
        uri = "/a",
        prefix = "/b",
        expected = None,
    );

    test!(
        single_segment_trailing_slash,
        uri = "/a/",
        prefix = "/a/",
        expected = Some("/"),
    );

    test!(
        single_segment_trailing_slash_2,
        uri = "/a",
        prefix = "/a/",
        expected = None,
    );

    test!(
        single_segment_trailing_slash_3,
        uri = "/a/",
        prefix = "/a",
        expected = Some("/"),
    );

    test!(
        multi_segment,
        uri = "/a/b",
        prefix = "/a",
        expected = Some("/b"),
    );

    test!(
        multi_segment_2,
        uri = "/b/a",
        prefix = "/a",
        expected = None,
    );

    test!(
        multi_segment_3,
        uri = "/a",
        prefix = "/a/b",
        expected = None,
    );

    test!(
        multi_segment_4,
        uri = "/a/b",
        prefix = "/b",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash,
        uri = "/a/b/",
        prefix = "/a/b/",
        expected = Some("/"),
    );

    test!(
        multi_segment_trailing_slash_2,
        uri = "/a/b",
        prefix = "/a/b/",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash_3,
        uri = "/a/b/",
        prefix = "/a/b",
        expected = Some("/"),
    );

    test!(param_0, uri = "/", prefix = "/:param", expected = Some("/"),);

    test!(
        param_1,
        uri = "/a",
        prefix = "/:param",
        expected = Some("/"),
    );

    test!(
        param_2,
        uri = "/a/b",
        prefix = "/:param",
        expected = Some("/b"),
    );

    test!(
        param_3,
        uri = "/b/a",
        prefix = "/:param",
        expected = Some("/a"),
    );

    test!(
        param_4,
        uri = "/a/b",
        prefix = "/a/:param",
        expected = Some("/"),
    );

    test!(param_5, uri = "/b/a", prefix = "/a/:param", expected = None,);

    test!(param_6, uri = "/a/b", prefix = "/:param/a", expected = None,);

    test!(
        param_7,
        uri = "/b/a",
        prefix = "/:param/a",
        expected = Some("/"),
    );

    test!(
        param_8,
        uri = "/a/b/c",
        prefix = "/a/:param/c",
        expected = Some("/"),
    );

    test!(
        param_9,
        uri = "/c/b/a",
        prefix = "/a/:param/c",
        expected = None,
    );

    test!(
        param_10,
        uri = "/a/",
        prefix = "/:param",
        expected = Some("/"),
    );

    test!(param_11, uri = "/a", prefix = "/:param/", expected = None,);

    test!(
        param_12,
        uri = "/a/",
        prefix = "/:param/",
        expected = Some("/"),
    );

    test!(
        param_13,
        uri = "/a/a",
        prefix = "/a/",
        expected = Some("/a"),
    );

    // the query string must survive the rewrite
    test!(
        query_is_preserved,
        uri = "/a/b?x=1",
        prefix = "/a",
        expected = Some("/b?x=1"),
    );

    test!(
        query_is_preserved_at_root,
        uri = "/a?x=1",
        prefix = "/a",
        expected = Some("/?x=1"),
    );

    // scheme and authority must survive the rewrite
    test!(
        absolute_uri,
        uri = "http://example.com/a/b",
        prefix = "/a",
        expected = Some("http://example.com/b"),
    );

    #[quickcheck]
    fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool {
        let UriAndPrefix { uri, prefix } = uri_and_prefix;
        strip_prefix(&uri, &prefix);
        true
    }

    /// Whenever the prefix matches, the rewritten path must still be absolute.
    #[quickcheck]
    fn stripped_path_is_absolute(uri_and_prefix: UriAndPrefix) -> bool {
        let UriAndPrefix { uri, prefix } = uri_and_prefix;
        strip_prefix(&uri, &prefix).map_or(true, |new_uri| new_uri.path().starts_with('/'))
    }

    /// A randomly generated URI paired with a prefix whose segments are
    /// independently either a copy of the URI's segment, a `:a` capture, or an
    /// unrelated segment.
    #[derive(Clone, Debug)]
    struct UriAndPrefix {
        uri: Uri,
        prefix: String,
    }

    impl Arbitrary for UriAndPrefix {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            let mut uri = String::new();
            let mut prefix = String::new();

            let size = u8_between(1, 20, g);

            for _ in 0..size {
                let segment = ascii_alphanumeric(g);

                uri.push('/');
                uri.push_str(&segment);

                prefix.push('/');

                let make_matching_segment = bool::arbitrary(g);
                let make_capture = bool::arbitrary(g);

                match (make_matching_segment, make_capture) {
                    (_, true) => {
                        prefix.push_str(":a");
                    }
                    (true, false) => {
                        prefix.push_str(&segment);
                    }
                    (false, false) => {
                        prefix.push_str(&ascii_alphanumeric(g));
                    }
                }
            }

            if bool::arbitrary(g) {
                uri.push('/');
            }

            if bool::arbitrary(g) {
                prefix.push('/');
            }

            Self {
                uri: uri.parse().unwrap(),
                prefix,
            }
        }
    }

    /// Generates a non-empty string of 1 to 20 ASCII alphanumeric characters.
    fn ascii_alphanumeric(g: &mut quickcheck::Gen) -> String {
        let size = usize::from(u8_between(1, 20, g));
        let mut out = String::with_capacity(size);

        // Rejection-sample characters until the target length is reached.
        while out.len() < size {
            let c = char::arbitrary(g);
            if c.is_ascii_alphanumeric() {
                out.push(c);
            }
        }
        out
    }

    /// Draws a `u8` in the inclusive range `lower..=upper` by rejection sampling.
    ///
    /// The previous version rejected `lower` itself, so for example
    /// `u8_between(1, 20, g)` could never return 1.
    fn u8_between(lower: u8, upper: u8, g: &mut quickcheck::Gen) -> u8 {
        assert!(lower <= upper, "empty range {lower}..={upper}");
        loop {
            let size = u8::arbitrary(g);
            if (lower..=upper).contains(&size) {
                break size;
            }
        }
    }
}