1use std::fmt::Display;
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
6use std::str::FromStr;
7
8use super::{PolicyError, PortRange};
9
10#[derive(Clone, Debug, Default)]
36pub struct AddrPolicy {
37 rules: Vec<AddrPolicyRule>,
43}
44
/// The outcome attached to a policy rule: what happens when the rule matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(clippy::exhaustive_enums)]
pub enum RuleKind {
    /// A matching address/port is allowed.
    Accept,
    /// A matching address/port is refused.
    Reject,
}
55
56impl AddrPolicy {
57 pub fn allows(&self, addr: &IpAddr, port: u16) -> Option<RuleKind> {
64 self.rules
65 .iter()
66 .find(|rule| rule.pattern.matches(addr, port))
67 .map(|AddrPolicyRule { kind, .. }| *kind)
68 }
69
70 pub fn allows_sockaddr(&self, addr: &SocketAddr) -> Option<RuleKind> {
72 self.allows(&addr.ip(), addr.port())
73 }
74
75 pub fn new() -> Self {
77 AddrPolicy::default()
78 }
79
80 pub fn push(&mut self, kind: RuleKind, pattern: AddrPortPattern) {
88 self.rules.push(AddrPolicyRule { kind, pattern });
89 }
90}
91
92#[derive(Clone, Debug)]
96struct AddrPolicyRule {
97 kind: RuleKind,
99 pattern: AddrPortPattern,
101}
102
103#[derive(
134 Clone, Debug, Eq, PartialEq, serde_with::SerializeDisplay, serde_with::DeserializeFromStr,
135)]
136pub struct AddrPortPattern {
137 pattern: IpPattern,
139 ports: PortRange,
141}
142
143impl AddrPortPattern {
144 pub fn new_all() -> Self {
146 Self {
147 pattern: IpPattern::Star,
148 ports: PortRange::new_all(),
149 }
150 }
151
152 pub fn matches(&self, addr: &IpAddr, port: u16) -> bool {
154 self.pattern.matches(addr) && self.ports.contains(port)
155 }
156 pub fn matches_sockaddr(&self, addr: &SocketAddr) -> bool {
158 self.matches(&addr.ip(), addr.port())
159 }
160}
161
162impl Display for AddrPortPattern {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 if self.ports.is_all() {
165 write!(f, "{}:*", self.pattern)
166 } else {
167 write!(f, "{}:{}", self.pattern, self.ports)
168 }
169 }
170}
171
172impl FromStr for AddrPortPattern {
173 type Err = PolicyError;
174 fn from_str(s: &str) -> Result<Self, PolicyError> {
175 let last_colon = s.rfind(':').ok_or(PolicyError::InvalidPolicy)?;
176 let pattern: IpPattern = s[..last_colon].parse()?;
177 let ports_s = &s[last_colon + 1..];
178 let ports: PortRange = if ports_s == "*" {
179 PortRange::new_all()
180 } else {
181 ports_s.parse()?
182 };
183
184 Ok(AddrPortPattern { pattern, ports })
185 }
186}
187
/// A pattern over IP addresses: a wildcard, or an address plus a prefix
/// length in bits.
#[derive(Debug, Clone, PartialEq, Eq)]
enum IpPattern {
    /// Matches every address, IPv4 or IPv6.
    Star,
    /// Matches every IPv4 address.
    V4Star,
    /// Matches every IPv6 address.
    V6Star,
    /// Matches IPv4 addresses whose first `.1` bits equal those of `.0`.
    V4(Ipv4Addr, u8),
    /// Matches IPv6 addresses whose first `.1` bits equal those of `.0`.
    V6(Ipv6Addr, u8),
}
207
208impl IpPattern {
209 fn from_addr_and_mask(addr: IpAddr, mask: u8) -> Result<Self, PolicyError> {
211 match (addr, mask) {
212 (IpAddr::V4(_), 0) => Ok(IpPattern::V4Star),
213 (IpAddr::V6(_), 0) => Ok(IpPattern::V6Star),
214 (IpAddr::V4(a), m) if m <= 32 => Ok(IpPattern::V4(a, m)),
215 (IpAddr::V6(a), m) if m <= 128 => Ok(IpPattern::V6(a, m)),
216 (_, _) => Err(PolicyError::InvalidMask),
217 }
218 }
219 fn matches(&self, addr: &IpAddr) -> bool {
221 match (self, addr) {
222 (IpPattern::Star, _) => true,
223 (IpPattern::V4Star, IpAddr::V4(_)) => true,
224 (IpPattern::V6Star, IpAddr::V6(_)) => true,
225 (IpPattern::V4(pat, mask), IpAddr::V4(addr)) => {
226 let p1 = u32::from_be_bytes(pat.octets());
227 let p2 = u32::from_be_bytes(addr.octets());
228 let shift = 32 - mask;
229 (p1 >> shift) == (p2 >> shift)
230 }
231 (IpPattern::V6(pat, mask), IpAddr::V6(addr)) => {
232 let p1 = u128::from_be_bytes(pat.octets());
233 let p2 = u128::from_be_bytes(addr.octets());
234 let shift = 128 - mask;
235 (p1 >> shift) == (p2 >> shift)
236 }
237 (_, _) => false,
238 }
239 }
240}
241
242impl Display for IpPattern {
243 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 use IpPattern::*;
245 match self {
246 Star | V4Star | V6Star => write!(f, "*"),
247 V4(a, 32) => write!(f, "{}", a),
248 V4(a, m) => write!(f, "{}/{}", a, m),
249 V6(a, 128) => write!(f, "[{}]", a),
250 V6(a, m) => write!(f, "[{}]/{}", a, m),
251 }
252 }
253}
254
255fn parse_addr(mut s: &str) -> Result<IpAddr, PolicyError> {
258 let bracketed = s.starts_with('[') && s.ends_with(']');
259 if bracketed {
260 s = &s[1..s.len() - 1];
261 }
262 let addr: IpAddr = s.parse().map_err(|_| PolicyError::InvalidAddress)?;
263 if addr.is_ipv6() != bracketed {
264 return Err(PolicyError::InvalidAddress);
265 }
266 Ok(addr)
267}
268
269impl FromStr for IpPattern {
270 type Err = PolicyError;
271 fn from_str(s: &str) -> Result<Self, PolicyError> {
272 let (ip_s, mask_s) = match s.find('/') {
273 Some(slash_idx) => (&s[..slash_idx], Some(&s[slash_idx + 1..])),
274 None => (s, None),
275 };
276 match (ip_s, mask_s) {
277 ("*", Some(_)) => Err(PolicyError::MaskWithStar),
278 ("*", None) => Ok(IpPattern::Star),
279 (s, Some(m)) => {
280 let a: IpAddr = parse_addr(s)?;
281 let m: u8 = m.parse().map_err(|_| PolicyError::InvalidMask)?;
282 IpPattern::from_addr_and_mask(a, m)
283 }
284 (s, None) => {
285 let a: IpAddr = parse_addr(s)?;
286 let m = if a.is_ipv4() { 32 } else { 128 };
287 IpPattern::from_addr_and_mask(a, m)
288 }
289 }
290 }
291}
292
293#[cfg(test)]
294mod test {
295 #![allow(clippy::bool_assert_comparison)]
297 #![allow(clippy::clone_on_copy)]
298 #![allow(clippy::dbg_macro)]
299 #![allow(clippy::mixed_attributes_style)]
300 #![allow(clippy::print_stderr)]
301 #![allow(clippy::print_stdout)]
302 #![allow(clippy::single_char_pattern)]
303 #![allow(clippy::unwrap_used)]
304 #![allow(clippy::unchecked_duration_subtraction)]
305 #![allow(clippy::useless_vec)]
306 #![allow(clippy::needless_pass_by_value)]
307 use super::*;
309
310 #[test]
311 fn test_roundtrip_rules() {
312 fn check(inp: &str, outp: &str) {
313 let policy = inp.parse::<AddrPortPattern>().unwrap();
314 assert_eq!(format!("{}", policy), outp);
315 }
316
317 check("127.0.0.2/32:77-10000", "127.0.0.2:77-10000");
318 check("127.0.0.2/32:*", "127.0.0.2:*");
319 check("127.0.0.0/16:9-100", "127.0.0.0/16:9-100");
320 check("127.0.0.0/0:443", "*:443");
321 check("*:443", "*:443");
322 check("[::1]:443", "[::1]:443");
323 check("[ffaa::]/16:80", "[ffaa::]/16:80");
324 check("[ffaa::77]/128:80", "[ffaa::77]:80");
325 }
326
327 #[test]
328 fn test_bad_rules() {
329 fn check(s: &str) {
330 assert!(s.parse::<AddrPortPattern>().is_err());
331 }
332
333 check("marzipan:80");
334 check("1.2.3.4:90-80");
335 check("1.2.3.4/100:8888");
336 check("[1.2.3.4]/16:80");
337 check("[::1]/130:8888");
338 }
339
340 #[test]
341 fn test_rule_matches() {
342 fn check(addr: &str, yes: &[&str], no: &[&str]) {
343 use std::net::SocketAddr;
344 let policy = addr.parse::<AddrPortPattern>().unwrap();
345 for s in yes {
346 let sa = s.parse::<SocketAddr>().unwrap();
347 assert!(policy.matches_sockaddr(&sa));
348 }
349 for s in no {
350 let sa = s.parse::<SocketAddr>().unwrap();
351 assert!(!policy.matches_sockaddr(&sa));
352 }
353 }
354
355 check(
356 "1.2.3.4/16:80",
357 &["1.2.3.4:80", "1.2.44.55:80"],
358 &["9.9.9.9:80", "1.3.3.4:80", "1.2.3.4:81"],
359 );
360 check(
361 "*:443-8000",
362 &["1.2.3.4:443", "[::1]:500"],
363 &["9.0.0.0:80", "[::1]:80"],
364 );
365 check(
366 "[face::]/8:80",
367 &["[fab0::7]:80"],
368 &["[dd00::]:80", "[face::7]:443"],
369 );
370
371 check("0.0.0.0/0:*", &["127.0.0.1:80"], &["[f00b::]:80"]);
372 check("[::]/0:*", &["[f00b::]:80"], &["127.0.0.1:80"]);
373 }
374
375 #[test]
376 fn test_policy_matches() -> Result<(), PolicyError> {
377 let mut policy = AddrPolicy::default();
378 policy.push(RuleKind::Accept, "*:443".parse()?);
379 policy.push(RuleKind::Accept, "[::1]:80".parse()?);
380 policy.push(RuleKind::Reject, "*:80".parse()?);
381
382 let policy = policy; assert_eq!(
384 policy.allows_sockaddr(&"[::6]:443".parse().unwrap()),
385 Some(RuleKind::Accept)
386 );
387 assert_eq!(
388 policy.allows_sockaddr(&"127.0.0.1:443".parse().unwrap()),
389 Some(RuleKind::Accept)
390 );
391 assert_eq!(
392 policy.allows_sockaddr(&"[::1]:80".parse().unwrap()),
393 Some(RuleKind::Accept)
394 );
395 assert_eq!(
396 policy.allows_sockaddr(&"[::2]:80".parse().unwrap()),
397 Some(RuleKind::Reject)
398 );
399 assert_eq!(
400 policy.allows_sockaddr(&"127.0.0.1:80".parse().unwrap()),
401 Some(RuleKind::Reject)
402 );
403 assert_eq!(
404 policy.allows_sockaddr(&"127.0.0.1:66".parse().unwrap()),
405 None
406 );
407 Ok(())
408 }
409
410 #[test]
411 fn serde() {
412 #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Eq, PartialEq)]
413 struct X {
414 p1: AddrPortPattern,
415 p2: AddrPortPattern,
416 }
417
418 let x = X {
419 p1: "127.0.0.1/8:9-10".parse().unwrap(),
420 p2: "*:80".parse().unwrap(),
421 };
422
423 let encoded = serde_json::to_string(&x).unwrap();
424 let expected = r#"{"p1":"127.0.0.1/8:9-10","p2":"*:80"}"#;
425 let x2: X = serde_json::from_str(&encoded).unwrap();
426 let x3: X = serde_json::from_str(expected).unwrap();
427 assert_eq!(&x2, &x3);
428 assert_eq!(&x2, &x);
429 }
430}