1use std::fmt::Display;
6use std::str::FromStr;
7use std::sync::Arc;
8
9use super::{PolicyError, PortRange};
10use crate::util::intern::InternCache;
11
12#[derive(Clone, Debug, PartialEq, Eq, Hash)]
35pub struct PortPolicy {
36 allowed: Vec<PortRange>,
40}
41
42impl Display for PortPolicy {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 if self.allowed.is_empty() {
45 write!(f, "reject 1-65535")?;
46 } else {
47 write!(f, "accept ")?;
48 let mut comma = "";
49 for range in &self.allowed {
50 write!(f, "{}{}", comma, range)?;
51 comma = ",";
52 }
53 }
54 Ok(())
55 }
56}
57
58impl PortPolicy {
59 pub fn new_reject_all() -> Self {
61 PortPolicy {
62 allowed: Vec::new(),
63 }
64 }
65
66 pub fn from_allowed_port_list(ports: Vec<u16>) -> Self {
69 let mut ports = ports;
70 ports.sort();
71 let mut ports = ports.iter().peekable();
72
73 let mut out = PortPolicy::new_reject_all();
74
75 let mut current_min = None;
76 while let Some(port) = ports.next() {
77 if current_min.is_none() {
78 current_min = Some(port);
79 }
80 if let Some(next_port) = ports.peek() {
81 if **next_port != port + 1 {
82 let _ = out.push_policy(PortRange::new_unchecked(
83 *current_min.expect("Don't have min port number"),
84 *port,
85 ));
86 current_min = None;
87 }
88 } else {
89 let _ = out.push_policy(PortRange::new_unchecked(
90 *current_min.expect("Don't have min port number"),
91 *port,
92 ));
93 }
94 }
95
96 out
97 }
98
99 fn invert(&mut self) {
101 let mut prev_hi = 0;
102 let mut new_allowed = Vec::new();
103 for entry in &self.allowed {
104 if entry.lo > prev_hi + 1 {
107 new_allowed.push(PortRange::new_unchecked(prev_hi + 1, entry.lo - 1));
108 }
109 prev_hi = entry.hi;
110 }
111 if prev_hi < 65535 {
112 new_allowed.push(PortRange::new_unchecked(prev_hi + 1, 65535));
113 }
114 self.allowed = new_allowed;
115 }
116 fn push_policy(&mut self, item: PortRange) -> Result<(), PolicyError> {
120 if let Some(prev) = self.allowed.last() {
121 if prev.hi >= item.lo {
124 return Err(PolicyError::InvalidPolicy);
125 } else if prev.hi == item.lo - 1 {
126 let r = PortRange::new_unchecked(prev.lo, item.hi);
128 self.allowed.pop();
129 self.allowed.push(r);
130 return Ok(());
131 }
132 }
133
134 self.allowed.push(item);
135 Ok(())
136 }
137 pub fn allows_port(&self, port: u16) -> bool {
139 self.allowed
140 .binary_search_by(|range| range.compare_to_port(port))
141 .is_ok()
142 }
143 pub fn intern(self) -> Arc<Self> {
145 POLICY_CACHE.intern(self)
146 }
147 pub fn allows_some_port(&self) -> bool {
159 !self.allowed.is_empty()
160 }
161}
162
163impl FromStr for PortPolicy {
164 type Err = PolicyError;
165 fn from_str(mut s: &str) -> Result<Self, PolicyError> {
166 let invert = if s.starts_with("accept ") {
167 false
168 } else if s.starts_with("reject ") {
169 true
170 } else {
171 return Err(PolicyError::InvalidPolicy);
172 };
173 let mut result = PortPolicy {
174 allowed: Vec::new(),
175 };
176 s = &s[7..];
177 for item in s.split(',') {
178 let r: PortRange = item.parse()?;
179 result.push_policy(r)?;
180 }
181 if invert {
182 result.invert();
183 }
184 Ok(result)
185 }
186}
187
188static POLICY_CACHE: InternCache<PortPolicy> = InternCache::new();
193
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use itertools::Itertools;

    use super::*;

    /// Parse `input` and check that it formats as `expected`, that every port
    /// in `allowed` is accepted, and that every port in `denied` is rejected.
    fn check_policy(input: &str, expected: &str, allowed: &[u16], denied: &[u16]) {
        let policy: PortPolicy = input.parse().unwrap();
        assert_eq!(policy.to_string(), expected);
        for &port in allowed {
            assert!(policy.allows_port(port));
        }
        for &port in denied {
            assert!(!policy.allows_port(port));
        }
    }

    #[test]
    fn test_roundtrip() {
        check_policy(
            "accept 1-10,30-50,600",
            "accept 1-10,30-50,600",
            &[1, 10, 35, 600],
            &[0, 11, 55, 599, 601],
        );
        check_policy("accept 1-10,11-20", "accept 1-20", &[], &[]);
        check_policy(
            "reject 1-30",
            "accept 31-65535",
            &[31, 10001, 65535],
            &[0, 1, 30],
        );
        check_policy(
            "reject 300-500",
            "accept 1-299,501-65535",
            &[31, 10001, 65535],
            &[300, 301, 500],
        );
        check_policy("reject 10,11,12,13,15", "accept 1-9,14,16-65535", &[], &[]);
        check_policy(
            "reject 1-65535",
            "reject 1-65535",
            &[],
            &[1, 300, 301, 500, 10001, 65535],
        );
    }

    #[test]
    fn test_bad() {
        let malformed = [
            "ignore 1-10",
            "allow 1-100",
            "accept",
            "reject",
            "accept x-y",
            "accept 1-20,19-30",
            "accept 1-20,20-30",
            "reject 1,1,1,1",
            "reject 1,2,foo,4",
            "reject 5,4,3,2",
        ];
        for text in malformed.iter() {
            assert!(text.parse::<PortPolicy>().is_err());
        }
    }

    #[test]
    fn test_from_allowed_port_list() {
        let cases: Vec<(Vec<u16>, &str)> = vec![
            (vec![1, 2, 3, 7, 8, 10, 42], "accept 1-3,7-8,10,42"),
            (vec![1, 3, 5], "accept 1,3,5"),
            (vec![1, 2, 3, 4], "accept 1-4"),
            (vec![65535], "accept 65535"),
            (vec![], "reject 1-65535"),
        ];

        for (ports, summary) in cases {
            let expected: PortPolicy = summary.parse().unwrap();
            for ordering in ports.iter().copied().permutations(ports.len()) {
                assert_eq!(PortPolicy::from_allowed_port_list(ordering), expected);
            }
        }
    }
}