1use crate::{InsertError, MatchError, Params};
2
3use std::cell::UnsafeCell;
4use std::cmp::min;
5use std::mem;
6
/// The role a [`Node`] plays in the routing tree.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare variants in
/// declaration order.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum NodeType {
    /// The root node of the tree.
    Root,
    /// A named route parameter, e.g. `:id`.
    Param,
    /// A catch-all wildcard, e.g. `*path`, which captures the remainder of the path.
    CatchAll,
    /// A plain static segment; the default for new nodes.
    Static,
}
19
/// A node in the radix tree used to store and match routes.
pub struct Node<T> {
    // Intended to equal the number of values stored in this node's subtree, including
    // its own (this is what `check_priorities` verifies). `update_child_priority` uses
    // it to move busier children toward the front.
    priority: u32,
    // Whether the last entry of `children` is a wildcard (parameter or catch-all) child.
    wild_child: bool,
    // Leading bytes used to select a static child by the next path byte;
    // `update_child_priority` keeps this aligned with `children`.
    indices: Vec<u8>,
    // The value of the route that ends at this node, if any.
    // NOTE(review): stored in an `UnsafeCell`, presumably so the router can hand out
    // mutable access through the tree — confirm against the callers of `at`.
    value: Option<UnsafeCell<T>>,
    // Original parameter names for the route stored here; the tree itself only
    // contains the normalized names produced by `normalize_params`.
    pub(crate) param_remapping: ParamRemapping,
    pub(crate) node_type: NodeType,
    // The path bytes matched by this node.
    pub(crate) prefix: Vec<u8>,
    pub(crate) children: Vec<Self>,
}
34
// SAFETY: the only `T` a `Node<T>` owns is inside `Option<UnsafeCell<T>>` (plus those of
// its children), and `UnsafeCell<T>` is `Send` whenever `T: Send`, so moving a node to
// another thread just moves its values along with it.
unsafe impl<T: Send> Send for Node<T> {}
// SAFETY: through `&Node<T>` the value is only exposed as `&UnsafeCell<T>` (see `at`).
// NOTE(review): soundness relies on no `&mut T` being created from that cell while the
// tree is shared between threads — confirm against the router code that dereferences it.
unsafe impl<T: Sync> Sync for Node<T> {}
38
39impl<T> Node<T> {
40 pub fn insert(&mut self, route: impl Into<String>, val: T) -> Result<(), InsertError> {
41 let route = route.into().into_bytes();
42 let (route, param_remapping) = normalize_params(route)?;
43 let mut prefix = route.as_ref();
44
45 self.priority += 1;
46
47 if self.prefix.is_empty() && self.children.is_empty() {
49 let last = self.insert_child(prefix, &route, val)?;
50 last.param_remapping = param_remapping;
51 self.node_type = NodeType::Root;
52 return Ok(());
53 }
54
55 let mut current = self;
56
57 'walk: loop {
58 let len = min(prefix.len(), current.prefix.len());
60 let common_prefix = (0..len)
61 .find(|&i| prefix[i] != current.prefix[i])
62 .unwrap_or(len);
63
64 if common_prefix < current.prefix.len() {
66 let child = Node {
67 prefix: current.prefix[common_prefix..].to_owned(),
68 children: mem::take(&mut current.children),
69 wild_child: current.wild_child,
70 indices: current.indices.clone(),
71 value: current.value.take(),
72 param_remapping: mem::take(&mut current.param_remapping),
73 priority: current.priority - 1,
74 ..Node::default()
75 };
76
77 current.children = vec![child];
79 current.indices = vec![current.prefix[common_prefix]];
80 current.prefix = prefix[..common_prefix].to_owned();
81 current.wild_child = false;
82 }
83
84 if prefix.len() > common_prefix {
86 prefix = &prefix[common_prefix..];
87
88 let next = prefix[0];
89
90 if current.node_type == NodeType::Param
92 && next == b'/'
93 && current.children.len() == 1
94 {
95 current = &mut current.children[0];
96 current.priority += 1;
97
98 continue 'walk;
99 }
100
101 for mut i in 0..current.indices.len() {
103 if next == current.indices[i] {
105 i = current.update_child_priority(i);
106 current = &mut current.children[i];
107 continue 'walk;
108 }
109 }
110
111 if !matches!(next, b':' | b'*') && current.node_type != NodeType::CatchAll {
113 current.indices.push(next);
114 let mut child = current.add_child(Node::default());
115 child = current.update_child_priority(child);
116
117 let last = current.children[child].insert_child(prefix, &route, val)?;
119 last.param_remapping = param_remapping;
120 return Ok(());
121 }
122
123 if current.wild_child {
125 current = current.children.last_mut().unwrap();
127 current.priority += 1;
128
129 if prefix.len() < current.prefix.len()
131 || current.prefix != prefix[..current.prefix.len()]
132 || current.node_type == NodeType::CatchAll
134 || (current.prefix.len() < prefix.len()
136 && prefix[current.prefix.len()] != b'/')
137 {
138 return Err(InsertError::conflict(&route, prefix, current));
139 }
140
141 continue 'walk;
142 }
143
144 let last = current.insert_child(prefix, &route, val)?;
146 last.param_remapping = param_remapping;
147 return Ok(());
148 }
149
150 if current.value.is_some() {
152 return Err(InsertError::conflict(&route, prefix, current));
153 }
154
155 current.value = Some(UnsafeCell::new(val));
157 current.param_remapping = param_remapping;
158
159 return Ok(());
160 }
161 }
162
163 fn add_child(&mut self, child: Node<T>) -> usize {
165 let len = self.children.len();
166
167 if self.wild_child && len > 0 {
168 self.children.insert(len - 1, child);
169 len - 1
170 } else {
171 self.children.push(child);
172 len
173 }
174 }
175
176 fn update_child_priority(&mut self, i: usize) -> usize {
180 self.children[i].priority += 1;
181 let priority = self.children[i].priority;
182
183 let mut updated = i;
185 while updated > 0 && self.children[updated - 1].priority < priority {
186 self.children.swap(updated - 1, updated);
188 updated -= 1;
189 }
190
191 if updated != i {
193 self.indices = [
194 &self.indices[..updated], &self.indices[i..=i], &self.indices[updated..i], &self.indices[i + 1..],
198 ]
199 .concat();
200 }
201
202 updated
203 }
204
205 fn insert_child(
207 &mut self,
208 mut prefix: &[u8],
209 route: &[u8],
210 val: T,
211 ) -> Result<&mut Node<T>, InsertError> {
212 let mut current = self;
213
214 loop {
215 let (wildcard, wildcard_index) = match find_wildcard(prefix)? {
217 Some((w, i)) => (w, i),
218 None => {
220 current.value = Some(UnsafeCell::new(val));
221 current.prefix = prefix.to_owned();
222 return Ok(current);
223 }
224 };
225
226 if wildcard[0] == b':' {
228 if wildcard_index > 0 {
230 current.prefix = prefix[..wildcard_index].to_owned();
231 prefix = &prefix[wildcard_index..];
232 }
233
234 let child = Self {
235 node_type: NodeType::Param,
236 prefix: wildcard.to_owned(),
237 ..Self::default()
238 };
239
240 let child = current.add_child(child);
241 current.wild_child = true;
242 current = &mut current.children[child];
243 current.priority += 1;
244
245 if wildcard.len() < prefix.len() {
248 prefix = &prefix[wildcard.len()..];
249 let child = Self {
250 priority: 1,
251 ..Self::default()
252 };
253
254 let child = current.add_child(child);
255 current = &mut current.children[child];
256 continue;
257 }
258
259 current.value = Some(UnsafeCell::new(val));
261 return Ok(current);
262
263 } else if wildcard[0] == b'*' {
265 if wildcard_index + wildcard.len() != prefix.len() {
267 return Err(InsertError::InvalidCatchAll);
268 }
269
270 if let Some(i) = wildcard_index.checked_sub(1) {
271 if prefix[i] != b'/' {
273 return Err(InsertError::InvalidCatchAll);
274 }
275 }
276
277 if prefix == route && route[0] != b'/' {
279 return Err(InsertError::InvalidCatchAll);
280 }
281
282 if wildcard_index > 0 {
284 current.prefix = prefix[..wildcard_index].to_owned();
285 prefix = &prefix[wildcard_index..];
286 }
287
288 let child = Self {
289 prefix: prefix.to_owned(),
290 node_type: NodeType::CatchAll,
291 value: Some(UnsafeCell::new(val)),
292 priority: 1,
293 ..Self::default()
294 };
295
296 let i = current.add_child(child);
297 current.wild_child = true;
298
299 return Ok(&mut current.children[i]);
300 }
301 }
302 }
303}
304
305struct Skipped<'n, 'p, T> {
306 path: &'p [u8],
307 node: &'n Node<T>,
308 params: usize,
309}
310
// Defines a local `try_backtrack!()` macro bound to the given search-state variables.
//
// `try_backtrack!()` pops recorded branch points newest-first and discards every one
// whose saved path does not end with the current remaining path. On the first one that
// does, it restores the path, node and parameter count from it, sets the backtracking
// flag, and restarts the labelled loop. If none qualifies, it falls through with the
// skip stack empty.
#[rustfmt::skip]
macro_rules! backtracker {
    ($skipped_nodes:ident, $path:ident, $current:ident, $params:ident, $backtracking:ident, $walk:lifetime) => {
        macro_rules! try_backtrack {
            () => {
                while let Some(candidate) = $skipped_nodes.pop() {
                    if !candidate.path.ends_with($path) {
                        continue;
                    }

                    $path = candidate.path;
                    $current = candidate.node;
                    $params.truncate(candidate.params);
                    $backtracking = true;
                    continue $walk;
                }
            };
        }
    };
}
331
332impl<T> Node<T> {
333 pub fn at<'n, 'p>(
337 &'n self,
338 full_path: &'p [u8],
339 ) -> Result<(&'n UnsafeCell<T>, Params<'n, 'p>), MatchError> {
340 let mut current = self;
341 let mut path = full_path;
342 let mut backtracking = false;
343 let mut params = Params::new();
344 let mut skipped_nodes = Vec::new();
345
346 'walk: loop {
347 backtracker!(skipped_nodes, path, current, params, backtracking, 'walk);
348
349 if path.len() > current.prefix.len() {
351 let (prefix, rest) = path.split_at(current.prefix.len());
352
353 if prefix == current.prefix {
355 let first = rest[0];
356 let consumed = path;
357 path = rest;
358
359 if !backtracking {
362 if let Some(i) = current.indices.iter().position(|&c| c == first) {
363 if current.wild_child {
366 skipped_nodes.push(Skipped {
367 path: consumed,
368 node: current,
369 params: params.len(),
370 });
371 }
372
373 if path == b"/"
375 && current.children[i].prefix != b"/"
376 && current.value.is_some()
377 {
378 return Err(MatchError::ExtraTrailingSlash);
379 }
380
381 current = ¤t.children[i];
383 continue 'walk;
384 }
385 }
386
387 if !current.wild_child {
389 if path == b"/" && current.value.is_some() {
391 return Err(MatchError::ExtraTrailingSlash);
392 }
393
394 if path != b"/" {
396 try_backtrack!();
397 }
398
399 return Err(MatchError::NotFound);
401 }
402
403 current = current.children.last().unwrap();
405
406 match current.node_type {
407 NodeType::Param => {
408 match path.iter().position(|&c| c == b'/') {
410 Some(i) => {
411 let (param, rest) = path.split_at(i);
412
413 if let [child] = current.children.as_slice() {
414 if rest == b"/"
416 && child.prefix != b"/"
417 && current.value.is_some()
418 {
419 return Err(MatchError::ExtraTrailingSlash);
420 }
421
422 params.push(¤t.prefix[1..], param);
424
425 path = rest;
427 current = child;
428 backtracking = false;
429 continue 'walk;
430 }
431
432 if path.len() == i + 1 {
435 return Err(MatchError::ExtraTrailingSlash);
436 }
437
438 if path != b"/" {
440 try_backtrack!();
441 }
442
443 return Err(MatchError::NotFound);
444 }
445 None => {
447 params.push(¤t.prefix[1..], path);
449
450 if let Some(ref value) = current.value {
452 params.for_each_key_mut(|(i, key)| {
454 *key = ¤t.param_remapping[i][1..]
455 });
456
457 return Ok((value, params));
458 }
459
460 if let [child] = current.children.as_slice() {
462 current = child;
463
464 if (current.prefix == b"/" && current.value.is_some())
465 || (current.prefix.is_empty()
466 && current.indices == b"/")
467 {
468 return Err(MatchError::MissingTrailingSlash);
469 }
470
471 if path != b"/" {
473 try_backtrack!();
474 }
475 }
476
477 return Err(MatchError::NotFound);
479 }
480 }
481 }
482 NodeType::CatchAll => {
483 return match current.value {
486 Some(ref value) => {
487 params.for_each_key_mut(|(i, key)| {
489 *key = ¤t.param_remapping[i][1..]
490 });
491
492 params.push(¤t.prefix[1..], path);
494
495 Ok((value, params))
496 }
497 None => Err(MatchError::NotFound),
498 };
499 }
500 _ => unreachable!(),
501 }
502 }
503 }
504
505 if path == current.prefix {
507 if let Some(ref value) = current.value {
508 params.for_each_key_mut(|(i, key)| *key = ¤t.param_remapping[i][1..]);
510 return Ok((value, params));
511 }
512
513 if path != b"/" {
515 try_backtrack!();
516 }
517
518 if path == b"/" && current.wild_child && current.node_type != NodeType::Root {
520 return Err(MatchError::unsure(full_path));
521 }
522
523 if !backtracking {
524 if let Some(i) = current.indices.iter().position(|&c| c == b'/') {
526 current = ¤t.children[i];
527
528 if current.prefix.len() == 1 && current.value.is_some() {
529 return Err(MatchError::MissingTrailingSlash);
530 }
531 }
532 }
533
534 return Err(MatchError::NotFound);
535 }
536
537 if current.prefix.split_last() == Some((&b'/', path)) && current.value.is_some() {
539 return Err(MatchError::MissingTrailingSlash);
540 }
541
542 if path != b"/" {
544 try_backtrack!();
545 }
546
547 return Err(MatchError::NotFound);
548 }
549 }
550
551 #[cfg(feature = "__test_helpers")]
552 pub fn check_priorities(&self) -> Result<u32, (u32, u32)> {
553 let mut priority: u32 = 0;
554 for child in &self.children {
555 priority += child.check_priorities()?;
556 }
557
558 if self.value.is_some() {
559 priority += 1;
560 }
561
562 if self.priority != priority {
563 return Err((self.priority, priority));
564 }
565
566 Ok(priority)
567 }
568}
569
/// Original parameter names of a route, each including its leading `:`, ordered by the
/// position of the parameter in the route. Index `i` restores the `i`-th normalized name.
type ParamRemapping = Vec<Vec<u8>>;
572
573fn normalize_params(mut path: Vec<u8>) -> Result<(Vec<u8>, ParamRemapping), InsertError> {
576 let mut start = 0;
577 let mut original = ParamRemapping::new();
578
579 let mut next = b'a';
581
582 loop {
583 let (wildcard, mut wildcard_index) = match find_wildcard(&path[start..])? {
584 Some((w, i)) => (w, i),
585 None => return Ok((path, original)),
586 };
587
588 if wildcard.len() < 2 {
590 return Err(InsertError::UnnamedParam);
591 }
592
593 if wildcard[0] == b'*' {
595 start += wildcard_index + wildcard.len();
596 continue;
597 }
598
599 wildcard_index += start;
600
601 let removed = path.splice(
603 (wildcard_index)..(wildcard_index + wildcard.len()),
604 vec![b':', next],
605 );
606
607 original.push(removed.collect());
609
610 next += 1;
612 if next > b'z' {
613 panic!("too many route parameters");
614 }
615
616 start = wildcard_index + 2;
617 }
618}
619
620pub(crate) fn denormalize_params(route: &mut Vec<u8>, params: &ParamRemapping) {
622 let mut start = 0;
623 let mut i = 0;
624
625 loop {
626 let (wildcard, mut wildcard_index) = match find_wildcard(&route[start..]).unwrap() {
628 Some((w, i)) => (w, i),
629 None => return,
630 };
631
632 wildcard_index += start;
633
634 let next = match params.get(i) {
635 Some(param) => param.clone(),
636 None => return,
637 };
638
639 route.splice(
641 (wildcard_index)..(wildcard_index + wildcard.len()),
642 next.clone(),
643 );
644
645 i += 1;
646 start = wildcard_index + 2;
647 }
648}
649
650fn find_wildcard(path: &[u8]) -> Result<Option<(&[u8], usize)>, InsertError> {
652 for (start, &c) in path.iter().enumerate() {
653 if c != b':' && c != b'*' {
655 continue;
656 }
657
658 for (end, &c) in path[start + 1..].iter().enumerate() {
659 match c {
660 b'/' => return Ok(Some((&path[start..start + 1 + end], start))),
661 b':' | b'*' => return Err(InsertError::TooManyParams),
662 _ => {}
663 }
664 }
665
666 return Ok(Some((&path[start..], start)));
667 }
668
669 Ok(None)
670}
671
672impl<T> Clone for Node<T>
673where
674 T: Clone,
675{
676 fn clone(&self) -> Self {
677 let value = self.value.as_ref().map(|value| {
678 let value = unsafe { &*value.get() };
680 UnsafeCell::new(value.clone())
681 });
682
683 Self {
684 value,
685 prefix: self.prefix.clone(),
686 wild_child: self.wild_child,
687 node_type: self.node_type.clone(),
688 indices: self.indices.clone(),
689 children: self.children.clone(),
690 param_remapping: self.param_remapping.clone(),
691 priority: self.priority,
692 }
693 }
694}
695
696impl<T> Default for Node<T> {
697 fn default() -> Self {
698 Self {
699 param_remapping: ParamRemapping::new(),
700 prefix: Vec::new(),
701 wild_child: false,
702 node_type: NodeType::Static,
703 indices: Vec::new(),
704 children: Vec::new(),
705 value: None,
706 priority: 0,
707 }
708 }
709}
710
#[cfg(test)]
const _: () = {
    use std::fmt::{self, Debug, Formatter};

    // Test-only `Debug` output used to inspect the tree's structure.
    impl<T: Debug> Debug for Node<T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            // SAFETY: NOTE(review) — reads through the cell while only `&self` is held.
            // This assumes no `&mut T` derived from it is alive during formatting; confirm.
            let value = unsafe { self.value.as_ref().map(|x| &*x.get()) };

            let indices = self
                .indices
                .iter()
                .map(|&x| char::from_u32(x as _))
                .collect::<Vec<_>>();

            let param_names = self
                .param_remapping
                .iter()
                .map(|x| std::str::from_utf8(x).unwrap())
                .collect::<Vec<_>>();

            let mut fmt = f.debug_struct("Node");
            fmt.field("value", &value);
            fmt.field("prefix", &std::str::from_utf8(&self.prefix));
            fmt.field("node_type", &self.node_type);
            fmt.field("children", &self.children);
            fmt.field("param_names", &param_names);
            fmt.field("indices", &indices);
            fmt.finish()
        }
    }
};