tor_netdoc/types/policy/
addrpolicy.rs

//! Implements address policies, based on a series of accept/reject
//! rules.
3
4use std::fmt::Display;
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
6use std::str::FromStr;
7
8use super::{PolicyError, PortRange};
9
10/// A sequence of rules that are applied to an address:port until one
11/// matches.
12///
13/// Each rule is of the form "accept PATTERN" or "reject PATTERN",
14/// where every pattern describes a set of addresses and ports.
15/// Address sets are given as a prefix of 0-128 bits that the address
16/// must have; port sets are given as a low-bound and high-bound that
17/// the target port might lie between.
18///
19/// Relays use this type for defining their own policies, and for
20/// publishing their IPv4 policies.  Clients instead use
21/// [super::portpolicy::PortPolicy] objects to view a summary of the
22/// relays' declared policies.
23///
24/// An example IPv4 policy might be:
25///
26/// ```ignore
27///  reject *:25
28///  reject 127.0.0.0/8:*
29///  reject 192.168.0.0/16:*
30///  accept *:80
31///  accept *:443
32///  accept *:9000-65535
33///  reject *:*
34/// ```
35#[derive(Clone, Debug, Default)]
36pub struct AddrPolicy {
37    /// A list of rules to apply to find out whether an address is
38    /// contained by this policy.
39    ///
40    /// The rules apply in order; the first one to match determines
41    /// whether the address is accepted or rejected.
42    rules: Vec<AddrPolicyRule>,
43}
44
/// The action taken by a policy rule when its pattern matches:
/// either accept or reject the address.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(clippy::exhaustive_enums)]
pub enum RuleKind {
    /// Matching address:port combinations are accepted.
    Accept,
    /// Matching address:port combinations are rejected.
    Reject,
}
55
56impl AddrPolicy {
57    /// Apply this policy to an address:port combination
58    ///
59    /// We do this by applying each rule in sequence, until one
60    /// matches.
61    ///
62    /// Returns None if no rule matches.
63    pub fn allows(&self, addr: &IpAddr, port: u16) -> Option<RuleKind> {
64        self.rules
65            .iter()
66            .find(|rule| rule.pattern.matches(addr, port))
67            .map(|AddrPolicyRule { kind, .. }| *kind)
68    }
69
70    /// As allows, but accept a SocketAddr.
71    pub fn allows_sockaddr(&self, addr: &SocketAddr) -> Option<RuleKind> {
72        self.allows(&addr.ip(), addr.port())
73    }
74
75    /// Create a new AddrPolicy that matches nothing.
76    pub fn new() -> Self {
77        AddrPolicy::default()
78    }
79
80    /// Add a new rule to this policy.
81    ///
82    /// The newly added rule is applied _after_ all previous rules.
83    /// It matches all addresses and ports covered by AddrPortPattern.
84    ///
85    /// If accept is true, the rule is to accept addresses that match;
86    /// if accept is false, the rule rejects such addresses.
87    pub fn push(&mut self, kind: RuleKind, pattern: AddrPortPattern) {
88        self.rules.push(AddrPolicyRule { kind, pattern });
89    }
90}
91
92/// A single rule in an address policy.
93///
94/// Contains a pattern and what to do with things that match it.
95#[derive(Clone, Debug)]
96struct AddrPolicyRule {
97    /// What do we do with items that match the pattern?
98    kind: RuleKind,
99    /// What pattern are we trying to match?
100    pattern: AddrPortPattern,
101}
102
/*
impl Display for AddrPolicyRule {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cmd = match self.kind {
            RuleKind::Accept => "accept",
            RuleKind::Reject => "reject",
        };
        write!(f, "{} {}", cmd, self.pattern)
    }
}
*/
114
115/// A pattern that may or may not match an address and port.
116///
117/// Each AddrPortPattern has an IP pattern, which matches a set of
118/// addresses by prefix, and a port pattern, which matches a range of
119/// ports.
120///
121/// # Example
122///
123/// ```
124/// use tor_netdoc::types::policy::AddrPortPattern;
125/// use std::net::{IpAddr,Ipv4Addr};
126/// let localhost = IpAddr::V4(Ipv4Addr::new(127,3,4,5));
127/// let not_localhost = IpAddr::V4(Ipv4Addr::new(192,0,2,16));
128/// let pat: AddrPortPattern = "127.0.0.0/8:*".parse().unwrap();
129///
130/// assert!(pat.matches(&localhost, 22));
131/// assert!(! pat.matches(&not_localhost, 22));
132/// ```
133#[derive(
134    Clone, Debug, Eq, PartialEq, serde_with::SerializeDisplay, serde_with::DeserializeFromStr,
135)]
136pub struct AddrPortPattern {
137    /// A pattern to match somewhere between zero and all IP addresses.
138    pattern: IpPattern,
139    /// A pattern to match a range of ports.
140    ports: PortRange,
141}
142
143impl AddrPortPattern {
144    /// Return an AddrPortPattern matching all targets.
145    pub fn new_all() -> Self {
146        Self {
147            pattern: IpPattern::Star,
148            ports: PortRange::new_all(),
149        }
150    }
151
152    /// Return true iff this pattern matches a given address and port.
153    pub fn matches(&self, addr: &IpAddr, port: u16) -> bool {
154        self.pattern.matches(addr) && self.ports.contains(port)
155    }
156    /// As matches, but accept a SocketAddr.
157    pub fn matches_sockaddr(&self, addr: &SocketAddr) -> bool {
158        self.matches(&addr.ip(), addr.port())
159    }
160}
161
162impl Display for AddrPortPattern {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        if self.ports.is_all() {
165            write!(f, "{}:*", self.pattern)
166        } else {
167            write!(f, "{}:{}", self.pattern, self.ports)
168        }
169    }
170}
171
172impl FromStr for AddrPortPattern {
173    type Err = PolicyError;
174    fn from_str(s: &str) -> Result<Self, PolicyError> {
175        let last_colon = s.rfind(':').ok_or(PolicyError::InvalidPolicy)?;
176        let pattern: IpPattern = s[..last_colon].parse()?;
177        let ports_s = &s[last_colon + 1..];
178        let ports: PortRange = if ports_s == "*" {
179            PortRange::new_all()
180        } else {
181            ports_s.parse()?
182        };
183
184        Ok(AddrPortPattern { pattern, ports })
185    }
186}
187
/// A pattern describing a set of IP addresses.
//
// TODO(nickm): At present there is no way for Display or FromStr to distinguish
// V4Star, V6Star, and Star.  If we decide it's important to have a syntax for
// "all IPv4 addresses" that isn't "0.0.0.0/0", we'll need to revisit that.
// At present, C tor allows '*', '*4', and '*6'.
#[derive(Debug, Clone, PartialEq, Eq)]
enum IpPattern {
    /// Every address, regardless of family.
    Star,
    /// Every IPv4 address.
    V4Star,
    /// Every IPv6 address.
    V6Star,
    /// Every IPv4 address that begins with the given prefix; the
    /// second field is the prefix length in bits.
    V4(Ipv4Addr, u8),
    /// Every IPv6 address that begins with the given prefix; the
    /// second field is the prefix length in bits.
    V6(Ipv6Addr, u8),
}
207
208impl IpPattern {
209    /// Construct an IpPattern that matches the first `mask` bits of `addr`.
210    fn from_addr_and_mask(addr: IpAddr, mask: u8) -> Result<Self, PolicyError> {
211        match (addr, mask) {
212            (IpAddr::V4(_), 0) => Ok(IpPattern::V4Star),
213            (IpAddr::V6(_), 0) => Ok(IpPattern::V6Star),
214            (IpAddr::V4(a), m) if m <= 32 => Ok(IpPattern::V4(a, m)),
215            (IpAddr::V6(a), m) if m <= 128 => Ok(IpPattern::V6(a, m)),
216            (_, _) => Err(PolicyError::InvalidMask),
217        }
218    }
219    /// Return true iff `addr` is matched by this pattern.
220    fn matches(&self, addr: &IpAddr) -> bool {
221        match (self, addr) {
222            (IpPattern::Star, _) => true,
223            (IpPattern::V4Star, IpAddr::V4(_)) => true,
224            (IpPattern::V6Star, IpAddr::V6(_)) => true,
225            (IpPattern::V4(pat, mask), IpAddr::V4(addr)) => {
226                let p1 = u32::from_be_bytes(pat.octets());
227                let p2 = u32::from_be_bytes(addr.octets());
228                let shift = 32 - mask;
229                (p1 >> shift) == (p2 >> shift)
230            }
231            (IpPattern::V6(pat, mask), IpAddr::V6(addr)) => {
232                let p1 = u128::from_be_bytes(pat.octets());
233                let p2 = u128::from_be_bytes(addr.octets());
234                let shift = 128 - mask;
235                (p1 >> shift) == (p2 >> shift)
236            }
237            (_, _) => false,
238        }
239    }
240}
241
242impl Display for IpPattern {
243    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        use IpPattern::*;
245        match self {
246            Star | V4Star | V6Star => write!(f, "*"),
247            V4(a, 32) => write!(f, "{}", a),
248            V4(a, m) => write!(f, "{}/{}", a, m),
249            V6(a, 128) => write!(f, "[{}]", a),
250            V6(a, m) => write!(f, "[{}]/{}", a, m),
251        }
252    }
253}
254
255/// Helper: try to parse a plain ipv4 address, or an IPv6 address
256/// wrapped in brackets.
257fn parse_addr(mut s: &str) -> Result<IpAddr, PolicyError> {
258    let bracketed = s.starts_with('[') && s.ends_with(']');
259    if bracketed {
260        s = &s[1..s.len() - 1];
261    }
262    let addr: IpAddr = s.parse().map_err(|_| PolicyError::InvalidAddress)?;
263    if addr.is_ipv6() != bracketed {
264        return Err(PolicyError::InvalidAddress);
265    }
266    Ok(addr)
267}
268
269impl FromStr for IpPattern {
270    type Err = PolicyError;
271    fn from_str(s: &str) -> Result<Self, PolicyError> {
272        let (ip_s, mask_s) = match s.find('/') {
273            Some(slash_idx) => (&s[..slash_idx], Some(&s[slash_idx + 1..])),
274            None => (s, None),
275        };
276        match (ip_s, mask_s) {
277            ("*", Some(_)) => Err(PolicyError::MaskWithStar),
278            ("*", None) => Ok(IpPattern::Star),
279            (s, Some(m)) => {
280                let a: IpAddr = parse_addr(s)?;
281                let m: u8 = m.parse().map_err(|_| PolicyError::InvalidMask)?;
282                IpPattern::from_addr_and_mask(a, m)
283            }
284            (s, None) => {
285                let a: IpAddr = parse_addr(s)?;
286                let m = if a.is_ipv4() { 32 } else { 128 };
287                IpPattern::from_addr_and_mask(a, m)
288            }
289        }
290    }
291}
292
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    /// Parsing a pattern and formatting it again yields the expected
    /// canonical text.
    #[test]
    fn test_roundtrip_rules() {
        // Each entry is (input text, expected formatted output).
        let cases = [
            ("127.0.0.2/32:77-10000", "127.0.0.2:77-10000"),
            ("127.0.0.2/32:*", "127.0.0.2:*"),
            ("127.0.0.0/16:9-100", "127.0.0.0/16:9-100"),
            ("127.0.0.0/0:443", "*:443"),
            ("*:443", "*:443"),
            ("[::1]:443", "[::1]:443"),
            ("[ffaa::]/16:80", "[ffaa::]/16:80"),
            ("[ffaa::77]/128:80", "[ffaa::77]:80"),
        ];

        for (input, expected) in cases.iter() {
            let pattern: AddrPortPattern = input.parse().unwrap();
            assert_eq!(pattern.to_string(), *expected);
        }
    }

    /// Malformed patterns are rejected by the parser.
    #[test]
    fn test_bad_rules() {
        let bad_inputs = [
            "marzipan:80",
            "1.2.3.4:90-80",
            "1.2.3.4/100:8888",
            "[1.2.3.4]/16:80",
            "[::1]/130:8888",
        ];

        for input in bad_inputs.iter() {
            let parsed = input.parse::<AddrPortPattern>();
            assert!(parsed.is_err());
        }
    }

    /// A parsed pattern matches exactly the socket addresses it
    /// should.
    #[test]
    fn test_rule_matches() {
        /// Check that `pattern` matches every entry of `accepted`
        /// and no entry of `refused`.
        fn check(pattern: &str, accepted: &[&str], refused: &[&str]) {
            let pattern: AddrPortPattern = pattern.parse().unwrap();
            let expectations = accepted
                .iter()
                .map(|text| (text, true))
                .chain(refused.iter().map(|text| (text, false)));
            for (text, should_match) in expectations {
                let target: SocketAddr = text.parse().unwrap();
                assert_eq!(pattern.matches_sockaddr(&target), should_match);
            }
        }

        check(
            "1.2.3.4/16:80",
            &["1.2.3.4:80", "1.2.44.55:80"],
            &["9.9.9.9:80", "1.3.3.4:80", "1.2.3.4:81"],
        );
        check(
            "*:443-8000",
            &["1.2.3.4:443", "[::1]:500"],
            &["9.0.0.0:80", "[::1]:80"],
        );
        check(
            "[face::]/8:80",
            &["[fab0::7]:80"],
            &["[dd00::]:80", "[face::7]:443"],
        );

        // A zero-length prefix only matches its own address family.
        check("0.0.0.0/0:*", &["127.0.0.1:80"], &["[f00b::]:80"]);
        check("[::]/0:*", &["[f00b::]:80"], &["127.0.0.1:80"]);
    }

    /// A policy returns the verdict of its first matching rule, or
    /// None when nothing matches.
    #[test]
    fn test_policy_matches() -> Result<(), PolicyError> {
        let rules = [
            (RuleKind::Accept, "*:443"),
            (RuleKind::Accept, "[::1]:80"),
            (RuleKind::Reject, "*:80"),
        ];
        let mut building = AddrPolicy::default();
        for (kind, pattern) in rules.iter() {
            building.push(*kind, pattern.parse()?);
        }
        let policy = building; // no longer mutable

        // Each entry is (target address, expected verdict).
        let expectations = [
            ("[::6]:443", Some(RuleKind::Accept)),
            ("127.0.0.1:443", Some(RuleKind::Accept)),
            ("[::1]:80", Some(RuleKind::Accept)),
            ("[::2]:80", Some(RuleKind::Reject)),
            ("127.0.0.1:80", Some(RuleKind::Reject)),
            ("127.0.0.1:66", None),
        ];
        for (target, verdict) in expectations.iter() {
            let target: SocketAddr = target.parse().unwrap();
            assert_eq!(policy.allows_sockaddr(&target), *verdict);
        }
        Ok(())
    }

    /// Patterns survive a trip through serde, and deserialize from
    /// their string form.
    #[test]
    fn serde() {
        #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Eq, PartialEq)]
        struct X {
            p1: AddrPortPattern,
            p2: AddrPortPattern,
        }

        let original = X {
            p1: "127.0.0.1/8:9-10".parse().unwrap(),
            p2: "*:80".parse().unwrap(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let expected = r#"{"p1":"127.0.0.1/8:9-10","p2":"*:80"}"#;
        let from_encoded: X = serde_json::from_str(&encoded).unwrap();
        let from_expected: X = serde_json::from_str(expected).unwrap();
        assert_eq!(&from_encoded, &from_expected);
        assert_eq!(&from_encoded, &original);
    }
}