1use http::{Request, Uri};
2use std::{
3 sync::Arc,
4 task::{Context, Poll},
5};
6use tower::Layer;
7use tower_layer::layer_fn;
8use tower_service::Service;
9
10#[derive(Clone)]
11pub(super) struct StripPrefix<S> {
12 inner: S,
13 prefix: Arc<str>,
14}
15
16impl<S> StripPrefix<S> {
17 pub(super) fn layer(prefix: &str) -> impl Layer<S, Service = Self> + Clone {
18 let prefix = Arc::from(prefix);
19 layer_fn(move |inner| Self {
20 inner,
21 prefix: Arc::clone(&prefix),
22 })
23 }
24}
25
26impl<S, B> Service<Request<B>> for StripPrefix<S>
27where
28 S: Service<Request<B>>,
29{
30 type Response = S::Response;
31 type Error = S::Error;
32 type Future = S::Future;
33
34 #[inline]
35 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
36 self.inner.poll_ready(cx)
37 }
38
39 fn call(&mut self, mut req: Request<B>) -> Self::Future {
40 if let Some(new_uri) = strip_prefix(req.uri(), &self.prefix) {
41 *req.uri_mut() = new_uri;
42 }
43 self.inner.call(req)
44 }
45}
46
47fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
48 let path_and_query = uri.path_and_query()?;
49
50 let mut matching_prefix_length = Some(0);
63 for item in zip_longest(segments(path_and_query.path()), segments(prefix)) {
64 *matching_prefix_length.as_mut().unwrap() += 1;
66
67 match item {
68 Item::Both(path_segment, prefix_segment) => {
69 if prefix_segment.starts_with(':') || path_segment == prefix_segment {
70 *matching_prefix_length.as_mut().unwrap() += path_segment.len();
73 } else if prefix_segment.is_empty() {
74 break;
83 } else {
84 matching_prefix_length = None;
86 break;
87 }
88 }
89 Item::First(_) => {
96 break;
97 }
98 Item::Second(_) => {
100 matching_prefix_length = None;
101 break;
102 }
103 }
104 }
105
106 let after_prefix = uri.path().split_at(matching_prefix_length?).1;
110
111 let new_path_and_query = match (after_prefix.starts_with('/'), path_and_query.query()) {
112 (true, None) => after_prefix.parse().unwrap(),
113 (true, Some(query)) => format!("{after_prefix}?{query}").parse().unwrap(),
114 (false, None) => format!("/{after_prefix}").parse().unwrap(),
115 (false, Some(query)) => format!("/{after_prefix}?{query}").parse().unwrap(),
116 };
117
118 let mut parts = uri.clone().into_parts();
119 parts.path_and_query = Some(new_path_and_query);
120
121 Some(Uri::from_parts(parts).unwrap())
122}
123
/// Splits an absolute path into its `/`-separated segments.
///
/// The leading `/` is removed before splitting, so `/a/b` yields `a`, `b`. A path ending in
/// `/` keeps a trailing empty segment (`/` yields `""`, and `/a/` yields `a`, `""`).
///
/// # Panics
///
/// Panics if `s` does not start with `/`.
fn segments(s: &str) -> impl Iterator<Item = &str> {
    match s.strip_prefix('/') {
        Some(without_leading_slash) => without_leading_slash.split('/'),
        None => panic!("path didn't start with '/'. axum should have caught this higher up."),
    }
}
135
/// Pairs up the items of `a` and `b`, continuing until *both* iterators are exhausted.
///
/// While both have items this yields [`Item::Both`]. Once one runs out, the leftovers of the
/// other are yielded as [`Item::First`] (from `a`) or [`Item::Second`] (from `b`).
fn zip_longest<I, I2>(a: I, b: I2) -> impl Iterator<Item = Item<I::Item>>
where
    I: Iterator,
    I2: Iterator<Item = I::Item>,
{
    // Fusing guarantees an exhausted side is never polled again.
    let (mut left, mut right) = (a.fuse(), b.fuse());
    std::iter::from_fn(move || {
        let pair = (left.next(), right.next());
        match pair {
            (Some(x), Some(y)) => Some(Item::Both(x, y)),
            (Some(x), None) => Some(Item::First(x)),
            (None, Some(y)) => Some(Item::Second(y)),
            (None, None) => None,
        }
    })
}

/// One step produced by [`zip_longest`].
#[derive(Debug)]
enum Item<T> {
    /// Both iterators produced an item: `(from a, from b)`.
    Both(T, T),
    /// Only the first iterator (`a`) still had an item.
    First(T),
    /// Only the second iterator (`b`) still had an item.
    Second(T),
}
157
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use quickcheck::Arbitrary;
    use quickcheck_macros::quickcheck;

    // Generates a `#[test]` that strips `prefix` from `uri` and compares the result, rendered
    // as a string, with `expected` (`None` meaning the prefix must not match).
    macro_rules! test {
        (
            $name:ident,
            uri = $uri:literal,
            prefix = $prefix:literal,
            expected = $expected:expr,
        ) => {
            #[test]
            fn $name() {
                let uri = $uri.parse().unwrap();
                let new_uri = strip_prefix(&uri, $prefix).map(|uri| uri.to_string());
                assert_eq!(new_uri.as_deref(), $expected);
            }
        };
    }

    test!(empty, uri = "/", prefix = "/", expected = Some("/"),);

    test!(
        single_segment,
        uri = "/a",
        prefix = "/a",
        expected = Some("/"),
    );

    test!(
        single_segment_root_uri,
        uri = "/",
        prefix = "/a",
        expected = None,
    );

    test!(
        single_segment_root_prefix,
        uri = "/a",
        prefix = "/",
        expected = Some("/a"),
    );

    test!(
        single_segment_no_match,
        uri = "/a",
        prefix = "/b",
        expected = None,
    );

    test!(
        single_segment_trailing_slash,
        uri = "/a/",
        prefix = "/a/",
        expected = Some("/"),
    );

    test!(
        single_segment_trailing_slash_2,
        uri = "/a",
        prefix = "/a/",
        expected = None,
    );

    test!(
        single_segment_trailing_slash_3,
        uri = "/a/",
        prefix = "/a",
        expected = Some("/"),
    );

    test!(
        multi_segment,
        uri = "/a/b",
        prefix = "/a",
        expected = Some("/b"),
    );

    test!(
        multi_segment_2,
        uri = "/b/a",
        prefix = "/a",
        expected = None,
    );

    test!(
        multi_segment_3,
        uri = "/a",
        prefix = "/a/b",
        expected = None,
    );

    test!(
        multi_segment_4,
        uri = "/a/b",
        prefix = "/b",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash,
        uri = "/a/b/",
        prefix = "/a/b/",
        expected = Some("/"),
    );

    test!(
        multi_segment_trailing_slash_2,
        uri = "/a/b",
        prefix = "/a/b/",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash_3,
        uri = "/a/b/",
        prefix = "/a/b",
        expected = Some("/"),
    );

    test!(param_0, uri = "/", prefix = "/:param", expected = Some("/"),);

    test!(
        param_1,
        uri = "/a",
        prefix = "/:param",
        expected = Some("/"),
    );

    test!(
        param_2,
        uri = "/a/b",
        prefix = "/:param",
        expected = Some("/b"),
    );

    test!(
        param_3,
        uri = "/b/a",
        prefix = "/:param",
        expected = Some("/a"),
    );

    test!(
        param_4,
        uri = "/a/b",
        prefix = "/a/:param",
        expected = Some("/"),
    );

    test!(param_5, uri = "/b/a", prefix = "/a/:param", expected = None,);

    test!(param_6, uri = "/a/b", prefix = "/:param/a", expected = None,);

    test!(
        param_7,
        uri = "/b/a",
        prefix = "/:param/a",
        expected = Some("/"),
    );

    test!(
        param_8,
        uri = "/a/b/c",
        prefix = "/a/:param/c",
        expected = Some("/"),
    );

    test!(
        param_9,
        uri = "/c/b/a",
        prefix = "/a/:param/c",
        expected = None,
    );

    test!(
        param_10,
        uri = "/a/",
        prefix = "/:param",
        expected = Some("/"),
    );

    test!(param_11, uri = "/a", prefix = "/:param/", expected = None,);

    test!(
        param_12,
        uri = "/a/",
        prefix = "/:param/",
        expected = Some("/"),
    );

    test!(
        param_13,
        uri = "/a/a",
        prefix = "/a/",
        expected = Some("/a"),
    );

    #[quickcheck]
    fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool {
        let UriAndPrefix { uri, prefix } = uri_and_prefix;
        strip_prefix(&uri, &prefix);
        true
    }

    /// A random URI paired with a prefix of the same segment count whose segments are either
    /// a copy of the URI's segment, an unrelated segment, or a `:a` parameter.
    #[derive(Clone, Debug)]
    struct UriAndPrefix {
        uri: Uri,
        prefix: String,
    }

    impl Arbitrary for UriAndPrefix {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            let mut uri = String::new();
            let mut prefix = String::new();

            let size = u8_between(1, 20, g);

            for _ in 0..size {
                let segment = ascii_alphanumeric(g);

                uri.push('/');
                uri.push_str(&segment);

                prefix.push('/');

                let make_matching_segment = bool::arbitrary(g);
                let make_capture = bool::arbitrary(g);

                match (make_matching_segment, make_capture) {
                    (_, true) => {
                        prefix.push_str(":a");
                    }
                    (true, false) => {
                        prefix.push_str(&segment);
                    }
                    (false, false) => {
                        prefix.push_str(&ascii_alphanumeric(g));
                    }
                }
            }

            if bool::arbitrary(g) {
                uri.push('/');
            }

            if bool::arbitrary(g) {
                prefix.push('/');
            }

            Self {
                uri: uri.parse().unwrap(),
                prefix,
            }
        }
    }

    /// Returns a non-empty random string of ASCII alphanumeric characters.
    fn ascii_alphanumeric(g: &mut quickcheck::Gen) -> String {
        #[derive(Clone)]
        struct AsciiAlphanumeric(String);

        impl Arbitrary for AsciiAlphanumeric {
            fn arbitrary(g: &mut quickcheck::Gen) -> Self {
                let mut out = String::new();

                let size = u8_between(1, 20, g) as usize;

                while out.len() < size {
                    let c = char::arbitrary(g);
                    if c.is_ascii_alphanumeric() {
                        out.push(c);
                    }
                }
                Self(out)
            }
        }

        let out = AsciiAlphanumeric::arbitrary(g).0;
        assert!(!out.is_empty());
        out
    }

    /// Returns a random `u8` in the inclusive range `lower..=upper`, by rejection sampling.
    ///
    /// The previous `size > lower` check excluded `lower` itself, so single-segment URIs and
    /// one-character segments were never generated.
    fn u8_between(lower: u8, upper: u8, g: &mut quickcheck::Gen) -> u8 {
        // An empty range would make the rejection loop below spin forever.
        assert!(lower <= upper, "empty range {lower}..={upper}");
        loop {
            let size = u8::arbitrary(g);
            if (lower..=upper).contains(&size) {
                break size;
            }
        }
    }
}