matchit/
tree.rs

1use crate::{InsertError, MatchError, Params};
2
3use std::cell::UnsafeCell;
4use std::cmp::min;
5use std::mem;
6
/// The types of nodes the tree can hold.
///
/// The variant order matters: the derived `PartialOrd`/`Ord` follow the
/// declaration order, so variants must not be reordered casually.
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
pub(crate) enum NodeType {
    /// The root path
    Root,
    /// A route parameter, ex: `/:id`.
    Param,
    /// A catchall parameter, ex: `/*file`
    CatchAll,
    /// Anything else
    Static,
}
19
20/// A radix tree used for URL path matching.
21///
22/// See [the crate documentation](crate) for details.
23pub struct Node<T> {
24    priority: u32,
25    wild_child: bool,
26    indices: Vec<u8>,
27    // see `at` for why an unsafe cell is needed
28    value: Option<UnsafeCell<T>>,
29    pub(crate) param_remapping: ParamRemapping,
30    pub(crate) node_type: NodeType,
31    pub(crate) prefix: Vec<u8>,
32    pub(crate) children: Vec<Self>,
33}
34
35// SAFETY: we expose `value` per rust's usual borrowing rules, so we can just delegate these traits
36unsafe impl<T: Send> Send for Node<T> {}
37unsafe impl<T: Sync> Sync for Node<T> {}
38
39impl<T> Node<T> {
40    pub fn insert(&mut self, route: impl Into<String>, val: T) -> Result<(), InsertError> {
41        let route = route.into().into_bytes();
42        let (route, param_remapping) = normalize_params(route)?;
43        let mut prefix = route.as_ref();
44
45        self.priority += 1;
46
47        // the tree is empty
48        if self.prefix.is_empty() && self.children.is_empty() {
49            let last = self.insert_child(prefix, &route, val)?;
50            last.param_remapping = param_remapping;
51            self.node_type = NodeType::Root;
52            return Ok(());
53        }
54
55        let mut current = self;
56
57        'walk: loop {
58            // find the longest common prefix
59            let len = min(prefix.len(), current.prefix.len());
60            let common_prefix = (0..len)
61                .find(|&i| prefix[i] != current.prefix[i])
62                .unwrap_or(len);
63
64            // the common prefix is a substring of the current node's prefix, split the node
65            if common_prefix < current.prefix.len() {
66                let child = Node {
67                    prefix: current.prefix[common_prefix..].to_owned(),
68                    children: mem::take(&mut current.children),
69                    wild_child: current.wild_child,
70                    indices: current.indices.clone(),
71                    value: current.value.take(),
72                    param_remapping: mem::take(&mut current.param_remapping),
73                    priority: current.priority - 1,
74                    ..Node::default()
75                };
76
77                // the current node now holds only the common prefix
78                current.children = vec![child];
79                current.indices = vec![current.prefix[common_prefix]];
80                current.prefix = prefix[..common_prefix].to_owned();
81                current.wild_child = false;
82            }
83
84            // the route has a common prefix, search deeper
85            if prefix.len() > common_prefix {
86                prefix = &prefix[common_prefix..];
87
88                let next = prefix[0];
89
90                // `/` after param
91                if current.node_type == NodeType::Param
92                    && next == b'/'
93                    && current.children.len() == 1
94                {
95                    current = &mut current.children[0];
96                    current.priority += 1;
97
98                    continue 'walk;
99                }
100
101                // find a child that matches the next path byte
102                for mut i in 0..current.indices.len() {
103                    // found a match
104                    if next == current.indices[i] {
105                        i = current.update_child_priority(i);
106                        current = &mut current.children[i];
107                        continue 'walk;
108                    }
109                }
110
111                // not a wildcard and there is no matching child node, create a new one
112                if !matches!(next, b':' | b'*') && current.node_type != NodeType::CatchAll {
113                    current.indices.push(next);
114                    let mut child = current.add_child(Node::default());
115                    child = current.update_child_priority(child);
116
117                    // insert into the new node
118                    let last = current.children[child].insert_child(prefix, &route, val)?;
119                    last.param_remapping = param_remapping;
120                    return Ok(());
121                }
122
123                // inserting a wildcard, and this node already has a wildcard child
124                if current.wild_child {
125                    // wildcards are always at the end
126                    current = current.children.last_mut().unwrap();
127                    current.priority += 1;
128
129                    // make sure the wildcard matches
130                    if prefix.len() < current.prefix.len()
131                        || current.prefix != prefix[..current.prefix.len()]
132                        // catch-alls cannot have children 
133                        || current.node_type == NodeType::CatchAll
134                        // check for longer wildcard, e.g. :name and :names
135                        || (current.prefix.len() < prefix.len()
136                            && prefix[current.prefix.len()] != b'/')
137                    {
138                        return Err(InsertError::conflict(&route, prefix, current));
139                    }
140
141                    continue 'walk;
142                }
143
144                // otherwise, create the wildcard node
145                let last = current.insert_child(prefix, &route, val)?;
146                last.param_remapping = param_remapping;
147                return Ok(());
148            }
149
150            // exact match, this node should be empty
151            if current.value.is_some() {
152                return Err(InsertError::conflict(&route, prefix, current));
153            }
154
155            // add the value to current node
156            current.value = Some(UnsafeCell::new(val));
157            current.param_remapping = param_remapping;
158
159            return Ok(());
160        }
161    }
162
163    // add a child node, keeping wildcards at the end
164    fn add_child(&mut self, child: Node<T>) -> usize {
165        let len = self.children.len();
166
167        if self.wild_child && len > 0 {
168            self.children.insert(len - 1, child);
169            len - 1
170        } else {
171            self.children.push(child);
172            len
173        }
174    }
175
176    // increments priority of the given child and reorders if necessary.
177    //
178    // returns the new index of the child
179    fn update_child_priority(&mut self, i: usize) -> usize {
180        self.children[i].priority += 1;
181        let priority = self.children[i].priority;
182
183        // adjust position (move to front)
184        let mut updated = i;
185        while updated > 0 && self.children[updated - 1].priority < priority {
186            // swap node positions
187            self.children.swap(updated - 1, updated);
188            updated -= 1;
189        }
190
191        // build new index list
192        if updated != i {
193            self.indices = [
194                &self.indices[..updated],  // unchanged prefix, might be empty
195                &self.indices[i..=i],      // the index char we move
196                &self.indices[updated..i], // rest without char at 'pos'
197                &self.indices[i + 1..],
198            ]
199            .concat();
200        }
201
202        updated
203    }
204
205    // insert a child node at this node
206    fn insert_child(
207        &mut self,
208        mut prefix: &[u8],
209        route: &[u8],
210        val: T,
211    ) -> Result<&mut Node<T>, InsertError> {
212        let mut current = self;
213
214        loop {
215            // search for a wildcard segment
216            let (wildcard, wildcard_index) = match find_wildcard(prefix)? {
217                Some((w, i)) => (w, i),
218                // no wildcard, simply use the current node
219                None => {
220                    current.value = Some(UnsafeCell::new(val));
221                    current.prefix = prefix.to_owned();
222                    return Ok(current);
223                }
224            };
225
226            // regular route parameter
227            if wildcard[0] == b':' {
228                // insert prefix before the current wildcard
229                if wildcard_index > 0 {
230                    current.prefix = prefix[..wildcard_index].to_owned();
231                    prefix = &prefix[wildcard_index..];
232                }
233
234                let child = Self {
235                    node_type: NodeType::Param,
236                    prefix: wildcard.to_owned(),
237                    ..Self::default()
238                };
239
240                let child = current.add_child(child);
241                current.wild_child = true;
242                current = &mut current.children[child];
243                current.priority += 1;
244
245                // if the route doesn't end with the wildcard, then there
246                // will be another non-wildcard subroute starting with '/'
247                if wildcard.len() < prefix.len() {
248                    prefix = &prefix[wildcard.len()..];
249                    let child = Self {
250                        priority: 1,
251                        ..Self::default()
252                    };
253
254                    let child = current.add_child(child);
255                    current = &mut current.children[child];
256                    continue;
257                }
258
259                // otherwise we're done. Insert the value in the new leaf
260                current.value = Some(UnsafeCell::new(val));
261                return Ok(current);
262
263            // catch-all route
264            } else if wildcard[0] == b'*' {
265                // "/foo/*x/bar"
266                if wildcard_index + wildcard.len() != prefix.len() {
267                    return Err(InsertError::InvalidCatchAll);
268                }
269
270                if let Some(i) = wildcard_index.checked_sub(1) {
271                    // "/foo/bar*x"
272                    if prefix[i] != b'/' {
273                        return Err(InsertError::InvalidCatchAll);
274                    }
275                }
276
277                // "*x" without leading `/`
278                if prefix == route && route[0] != b'/' {
279                    return Err(InsertError::InvalidCatchAll);
280                }
281
282                // insert prefix before the current wildcard
283                if wildcard_index > 0 {
284                    current.prefix = prefix[..wildcard_index].to_owned();
285                    prefix = &prefix[wildcard_index..];
286                }
287
288                let child = Self {
289                    prefix: prefix.to_owned(),
290                    node_type: NodeType::CatchAll,
291                    value: Some(UnsafeCell::new(val)),
292                    priority: 1,
293                    ..Self::default()
294                };
295
296                let i = current.add_child(child);
297                current.wild_child = true;
298
299                return Ok(&mut current.children[i]);
300            }
301        }
302    }
303}
304
305struct Skipped<'n, 'p, T> {
306    path: &'p [u8],
307    node: &'n Node<T>,
308    params: usize,
309}
310
// Defines a local `try_backtrack!()` macro bound to the given search state.
//
// `try_backtrack!()` pops skipped nodes until it finds one whose recorded path
// ends with the current path; it then restores that node's state, flags the
// search as backtracking and restarts the `$walk` loop. If no skipped node
// qualifies it falls through to the code following the invocation.
#[rustfmt::skip]
macro_rules! backtracker {
    ($skipped_nodes:ident, $path:ident, $current:ident, $params:ident, $backtracking:ident, $walk:lifetime) => {
        macro_rules! try_backtrack {
            () => {
                // try backtracking to any matching wildcard nodes we skipped while traversing
                // the tree
                while let Some(skipped) = $skipped_nodes.pop() {
                    if skipped.path.ends_with($path) {
                        $path = skipped.path;
                        $current = skipped.node;
                        $params.truncate(skipped.params);
                        $backtracking = true;
                        continue $walk;
                    }
                }
            };
        }
    };
}
331
332impl<T> Node<T> {
333    // it's a bit sad that we have to introduce unsafe here but rust doesn't really have a way
334    // to abstract over mutability, so `UnsafeCell` lets us avoid having to duplicate logic between
335    // `at` and `at_mut`
336    pub fn at<'n, 'p>(
337        &'n self,
338        full_path: &'p [u8],
339    ) -> Result<(&'n UnsafeCell<T>, Params<'n, 'p>), MatchError> {
340        let mut current = self;
341        let mut path = full_path;
342        let mut backtracking = false;
343        let mut params = Params::new();
344        let mut skipped_nodes = Vec::new();
345
346        'walk: loop {
347            backtracker!(skipped_nodes, path, current, params, backtracking, 'walk);
348
349            // the path is longer than this node's prefix, we are expecting a child node
350            if path.len() > current.prefix.len() {
351                let (prefix, rest) = path.split_at(current.prefix.len());
352
353                // the prefix matches
354                if prefix == current.prefix {
355                    let first = rest[0];
356                    let consumed = path;
357                    path = rest;
358
359                    // try searching for a matching static child unless we are currently
360                    // backtracking, which would mean we already traversed them
361                    if !backtracking {
362                        if let Some(i) = current.indices.iter().position(|&c| c == first) {
363                            // keep track of wildcard routes we skipped to backtrack to later if
364                            // we don't find a math
365                            if current.wild_child {
366                                skipped_nodes.push(Skipped {
367                                    path: consumed,
368                                    node: current,
369                                    params: params.len(),
370                                });
371                            }
372
373                            // child won't match because of an extra trailing slash
374                            if path == b"/"
375                                && current.children[i].prefix != b"/"
376                                && current.value.is_some()
377                            {
378                                return Err(MatchError::ExtraTrailingSlash);
379                            }
380
381                            // continue with the child node
382                            current = &current.children[i];
383                            continue 'walk;
384                        }
385                    }
386
387                    // we didn't find a match and there are no children with wildcards, there is no match
388                    if !current.wild_child {
389                        // extra trailing slash
390                        if path == b"/" && current.value.is_some() {
391                            return Err(MatchError::ExtraTrailingSlash);
392                        }
393
394                        // try backtracking
395                        if path != b"/" {
396                            try_backtrack!();
397                        }
398
399                        // nothing found
400                        return Err(MatchError::NotFound);
401                    }
402
403                    // handle the wildcard child, which is always at the end of the list
404                    current = current.children.last().unwrap();
405
406                    match current.node_type {
407                        NodeType::Param => {
408                            // check if there are more segments in the path other than this parameter
409                            match path.iter().position(|&c| c == b'/') {
410                                Some(i) => {
411                                    let (param, rest) = path.split_at(i);
412
413                                    if let [child] = current.children.as_slice() {
414                                        // child won't match because of an extra trailing slash
415                                        if rest == b"/"
416                                            && child.prefix != b"/"
417                                            && current.value.is_some()
418                                        {
419                                            return Err(MatchError::ExtraTrailingSlash);
420                                        }
421
422                                        // store the parameter value
423                                        params.push(&current.prefix[1..], param);
424
425                                        // continue with the child node
426                                        path = rest;
427                                        current = child;
428                                        backtracking = false;
429                                        continue 'walk;
430                                    }
431
432                                    // this node has no children yet the path has more segments...
433                                    // either the path has an extra trailing slash or there is no match
434                                    if path.len() == i + 1 {
435                                        return Err(MatchError::ExtraTrailingSlash);
436                                    }
437
438                                    // try backtracking
439                                    if path != b"/" {
440                                        try_backtrack!();
441                                    }
442
443                                    return Err(MatchError::NotFound);
444                                }
445                                // this is the last path segment
446                                None => {
447                                    // store the parameter value
448                                    params.push(&current.prefix[1..], path);
449
450                                    // found the matching value
451                                    if let Some(ref value) = current.value {
452                                        // remap parameter keys
453                                        params.for_each_key_mut(|(i, key)| {
454                                            *key = &current.param_remapping[i][1..]
455                                        });
456
457                                        return Ok((value, params));
458                                    }
459
460                                    // check the child node in case the path is missing a trailing slash
461                                    if let [child] = current.children.as_slice() {
462                                        current = child;
463
464                                        if (current.prefix == b"/" && current.value.is_some())
465                                            || (current.prefix.is_empty()
466                                                && current.indices == b"/")
467                                        {
468                                            return Err(MatchError::MissingTrailingSlash);
469                                        }
470
471                                        // no match, try backtracking
472                                        if path != b"/" {
473                                            try_backtrack!();
474                                        }
475                                    }
476
477                                    // this node doesn't have the value, no match
478                                    return Err(MatchError::NotFound);
479                                }
480                            }
481                        }
482                        NodeType::CatchAll => {
483                            // catch all segments are only allowed at the end of the route,
484                            // either this node has the value or there is no match
485                            return match current.value {
486                                Some(ref value) => {
487                                    // remap parameter keys
488                                    params.for_each_key_mut(|(i, key)| {
489                                        *key = &current.param_remapping[i][1..]
490                                    });
491
492                                    // store the final catch-all parameter
493                                    params.push(&current.prefix[1..], path);
494
495                                    Ok((value, params))
496                                }
497                                None => Err(MatchError::NotFound),
498                            };
499                        }
500                        _ => unreachable!(),
501                    }
502                }
503            }
504
505            // this is it, we should have reached the node containing the value
506            if path == current.prefix {
507                if let Some(ref value) = current.value {
508                    // remap parameter keys
509                    params.for_each_key_mut(|(i, key)| *key = &current.param_remapping[i][1..]);
510                    return Ok((value, params));
511                }
512
513                // nope, try backtracking
514                if path != b"/" {
515                    try_backtrack!();
516                }
517
518                // TODO: does this *always* means there is an extra trailing slash?
519                if path == b"/" && current.wild_child && current.node_type != NodeType::Root {
520                    return Err(MatchError::unsure(full_path));
521                }
522
523                if !backtracking {
524                    // check if the path is missing a trailing slash
525                    if let Some(i) = current.indices.iter().position(|&c| c == b'/') {
526                        current = &current.children[i];
527
528                        if current.prefix.len() == 1 && current.value.is_some() {
529                            return Err(MatchError::MissingTrailingSlash);
530                        }
531                    }
532                }
533
534                return Err(MatchError::NotFound);
535            }
536
537            // nothing matches, check for a missing trailing slash
538            if current.prefix.split_last() == Some((&b'/', path)) && current.value.is_some() {
539                return Err(MatchError::MissingTrailingSlash);
540            }
541
542            // last chance, try backtracking
543            if path != b"/" {
544                try_backtrack!();
545            }
546
547            return Err(MatchError::NotFound);
548        }
549    }
550
551    #[cfg(feature = "__test_helpers")]
552    pub fn check_priorities(&self) -> Result<u32, (u32, u32)> {
553        let mut priority: u32 = 0;
554        for child in &self.children {
555            priority += child.check_priorities()?;
556        }
557
558        if self.value.is_some() {
559            priority += 1;
560        }
561
562        if self.priority != priority {
563            return Err((self.priority, priority));
564        }
565
566        Ok(priority)
567    }
568}
569
/// An ordered list of route parameters keys for a specific route, stored at leaf nodes.
///
/// Each entry keeps the parameter's leading `:`.
type ParamRemapping = Vec<Vec<u8>>;
572
573/// Returns `path` with normalized route parameters, and a parameter remapping
574/// to store at the leaf node for this route.
575fn normalize_params(mut path: Vec<u8>) -> Result<(Vec<u8>, ParamRemapping), InsertError> {
576    let mut start = 0;
577    let mut original = ParamRemapping::new();
578
579    // parameter names are normalized alphabetically
580    let mut next = b'a';
581
582    loop {
583        let (wildcard, mut wildcard_index) = match find_wildcard(&path[start..])? {
584            Some((w, i)) => (w, i),
585            None => return Ok((path, original)),
586        };
587
588        // makes sure the param has a valid name
589        if wildcard.len() < 2 {
590            return Err(InsertError::UnnamedParam);
591        }
592
593        // don't need to normalize catch-all parameters
594        if wildcard[0] == b'*' {
595            start += wildcard_index + wildcard.len();
596            continue;
597        }
598
599        wildcard_index += start;
600
601        // normalize the parameter
602        let removed = path.splice(
603            (wildcard_index)..(wildcard_index + wildcard.len()),
604            vec![b':', next],
605        );
606
607        // remember the original name for remappings
608        original.push(removed.collect());
609
610        // get the next key
611        next += 1;
612        if next > b'z' {
613            panic!("too many route parameters");
614        }
615
616        start = wildcard_index + 2;
617    }
618}
619
620/// Restores `route` to it's original, denormalized form.
621pub(crate) fn denormalize_params(route: &mut Vec<u8>, params: &ParamRemapping) {
622    let mut start = 0;
623    let mut i = 0;
624
625    loop {
626        // find the next wildcard
627        let (wildcard, mut wildcard_index) = match find_wildcard(&route[start..]).unwrap() {
628            Some((w, i)) => (w, i),
629            None => return,
630        };
631
632        wildcard_index += start;
633
634        let next = match params.get(i) {
635            Some(param) => param.clone(),
636            None => return,
637        };
638
639        // denormalize this parameter
640        route.splice(
641            (wildcard_index)..(wildcard_index + wildcard.len()),
642            next.clone(),
643        );
644
645        i += 1;
646        start = wildcard_index + 2;
647    }
648}
649
650// Searches for a wildcard segment and checks the path for invalid characters.
651fn find_wildcard(path: &[u8]) -> Result<Option<(&[u8], usize)>, InsertError> {
652    for (start, &c) in path.iter().enumerate() {
653        // a wildcard starts with ':' (param) or '*' (catch-all)
654        if c != b':' && c != b'*' {
655            continue;
656        }
657
658        for (end, &c) in path[start + 1..].iter().enumerate() {
659            match c {
660                b'/' => return Ok(Some((&path[start..start + 1 + end], start))),
661                b':' | b'*' => return Err(InsertError::TooManyParams),
662                _ => {}
663            }
664        }
665
666        return Ok(Some((&path[start..], start)));
667    }
668
669    Ok(None)
670}
671
672impl<T> Clone for Node<T>
673where
674    T: Clone,
675{
676    fn clone(&self) -> Self {
677        let value = self.value.as_ref().map(|value| {
678            // safety: we only expose &mut T through &mut self
679            let value = unsafe { &*value.get() };
680            UnsafeCell::new(value.clone())
681        });
682
683        Self {
684            value,
685            prefix: self.prefix.clone(),
686            wild_child: self.wild_child,
687            node_type: self.node_type.clone(),
688            indices: self.indices.clone(),
689            children: self.children.clone(),
690            param_remapping: self.param_remapping.clone(),
691            priority: self.priority,
692        }
693    }
694}
695
696impl<T> Default for Node<T> {
697    fn default() -> Self {
698        Self {
699            param_remapping: ParamRemapping::new(),
700            prefix: Vec::new(),
701            wild_child: false,
702            node_type: NodeType::Static,
703            indices: Vec::new(),
704            children: Vec::new(),
705            value: None,
706            priority: 0,
707        }
708    }
709}
710
#[cfg(test)]
const _: () = {
    use std::fmt::{self, Debug, Formatter};

    // visualize the tree structure when debugging
    impl<T: Debug> Debug for Node<T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            // SAFETY: we only expose &mut T through &mut self
            let value = unsafe { self.value.as_ref().map(|x| &*x.get()) };

            // show the index bytes as characters rather than numbers
            let indices = self
                .indices
                .iter()
                .map(|&x| char::from_u32(x as _))
                .collect::<Vec<_>>();

            // show the original parameter names as text
            let param_names = self
                .param_remapping
                .iter()
                .map(|x| std::str::from_utf8(x).unwrap())
                .collect::<Vec<_>>();

            let mut fmt = f.debug_struct("Node");
            fmt.field("value", &value);
            fmt.field("prefix", &std::str::from_utf8(&self.prefix));
            fmt.field("node_type", &self.node_type);
            fmt.field("children", &self.children);
            fmt.field("param_names", &param_names);
            fmt.field("indices", &indices);
            fmt.finish()
        }
    }
};