rustls/server/
handy.rs

1use alloc::sync::Arc;
2use alloc::vec::Vec;
3use core::fmt::Debug;
4
5use crate::server::ClientHello;
6use crate::{server, sign};
7
8/// Something which never stores sessions.
9#[derive(Debug)]
10pub struct NoServerSessionStorage {}
11
12impl server::StoresServerSessions for NoServerSessionStorage {
13    fn put(&self, _id: Vec<u8>, _sec: Vec<u8>) -> bool {
14        false
15    }
16    fn get(&self, _id: &[u8]) -> Option<Vec<u8>> {
17        None
18    }
19    fn take(&self, _id: &[u8]) -> Option<Vec<u8>> {
20        None
21    }
22    fn can_cache(&self) -> bool {
23        false
24    }
25}
26
#[cfg(any(feature = "std", feature = "hashbrown"))]
mod cache {
    use alloc::sync::Arc;
    use alloc::vec::Vec;
    use core::fmt::{Debug, Formatter};

    use crate::lock::Mutex;
    use crate::{limited_cache, server};

    /// An implementer of `StoresServerSessions` that stores everything
    /// in memory.  It enforces a limit on the number of stored sessions
    /// to bound memory usage.
    pub struct ServerSessionMemoryCache {
        // Session values keyed by session id.  Wrapped in a mutex because every
        // `StoresServerSessions` method below takes `&self` yet may mutate it.
        cache: Mutex<limited_cache::LimitedCache<Vec<u8>, Vec<u8>>>,
    }

    impl ServerSessionMemoryCache {
        /// Make a new ServerSessionMemoryCache.  `size` is the maximum
        /// number of stored sessions, and may be rounded-up for
        /// efficiency.
        #[cfg(feature = "std")]
        pub fn new(size: usize) -> Arc<Self> {
            Arc::new(Self {
                cache: Mutex::new(limited_cache::LimitedCache::new(size)),
            })
        }

        /// Make a new ServerSessionMemoryCache.  `size` is the maximum
        /// number of stored sessions, and may be rounded-up for
        /// efficiency.
        ///
        /// Without the `std` feature the caller chooses the mutex
        /// implementation through the `M: MakeMutex` type parameter.
        #[cfg(not(feature = "std"))]
        pub fn new<M: crate::lock::MakeMutex>(size: usize) -> Arc<Self> {
            Arc::new(Self {
                cache: Mutex::new::<M>(limited_cache::LimitedCache::new(size)),
            })
        }
    }

    // Every method below unwraps the result of `lock()`, so it panics if the
    // lock cannot be acquired (what that means depends on the `crate::lock`
    // backend in use).
    impl server::StoresServerSessions for ServerSessionMemoryCache {
        /// Stores `value` under `key`, replacing any previous value, and always
        /// reports success.  The bounded cache may evict other entries to make room.
        fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
            self.cache
                .lock()
                .unwrap()
                .insert(key, value);
            true
        }

        /// Returns a copy of the value stored under `key`, leaving it in the cache.
        fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
            self.cache
                .lock()
                .unwrap()
                .get(key)
                .cloned()
        }

        /// Removes and returns the value stored under `key`, if any.
        fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
            self.cache.lock().unwrap().remove(key)
        }

        /// This storage does cache sessions.
        fn can_cache(&self) -> bool {
            true
        }
    }

    // The cache contents are deliberately not printed: only the type name is emitted.
    impl Debug for ServerSessionMemoryCache {
        fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
            f.debug_struct("ServerSessionMemoryCache")
                .finish()
        }
    }

    #[cfg(test)]
    mod tests {
        use std::vec;

        use super::*;
        use crate::server::StoresServerSessions;

        #[test]
        fn test_serversessionmemorycache_accepts_put() {
            let c = ServerSessionMemoryCache::new(4);
            assert!(c.put(vec![0x01], vec![0x02]));
        }

        #[test]
        fn test_serversessionmemorycache_persists_put() {
            let c = ServerSessionMemoryCache::new(4);
            assert!(c.put(vec![0x01], vec![0x02]));
            assert_eq!(c.get(&[0x01]), Some(vec![0x02]));
            assert_eq!(c.get(&[0x01]), Some(vec![0x02]));
        }

        #[test]
        fn test_serversessionmemorycache_overwrites_put() {
            let c = ServerSessionMemoryCache::new(4);
            assert!(c.put(vec![0x01], vec![0x02]));
            assert!(c.put(vec![0x01], vec![0x04]));
            assert_eq!(c.get(&[0x01]), Some(vec![0x04]));
        }

        #[test]
        fn test_serversessionmemorycache_drops_to_maintain_size_invariant() {
            let c = ServerSessionMemoryCache::new(2);
            assert!(c.put(vec![0x01], vec![0x02]));
            assert!(c.put(vec![0x03], vec![0x04]));
            assert!(c.put(vec![0x05], vec![0x06]));
            assert!(c.put(vec![0x07], vec![0x08]));
            assert!(c.put(vec![0x09], vec![0x0a]));

            // Each `get` contributes 1 if the key survived, 0 if it was evicted.
            let count = c.get(&[0x01]).iter().count()
                + c.get(&[0x03]).iter().count()
                + c.get(&[0x05]).iter().count()
                + c.get(&[0x07]).iter().count()
                + c.get(&[0x09]).iter().count();

            assert!(count < 5);
        }
    }
}
146
147#[cfg(any(feature = "std", feature = "hashbrown"))]
148pub use cache::ServerSessionMemoryCache;
149
150/// Something which never produces tickets.
151#[derive(Debug)]
152pub(super) struct NeverProducesTickets {}
153
154impl server::ProducesTickets for NeverProducesTickets {
155    fn enabled(&self) -> bool {
156        false
157    }
158    fn lifetime(&self) -> u32 {
159        0
160    }
161    fn encrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
162        None
163    }
164    fn decrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
165        None
166    }
167}
168
169/// Something which always resolves to the same cert chain.
170#[derive(Debug)]
171pub(super) struct AlwaysResolvesChain(Arc<sign::CertifiedKey>);
172
173impl AlwaysResolvesChain {
174    /// Creates an `AlwaysResolvesChain`, using the supplied `CertifiedKey`.
175    pub(super) fn new(certified_key: sign::CertifiedKey) -> Self {
176        Self(Arc::new(certified_key))
177    }
178
179    /// Creates an `AlwaysResolvesChain`, using the supplied `CertifiedKey` and OCSP response.
180    ///
181    /// If non-empty, the given OCSP response is attached.
182    pub(super) fn new_with_extras(certified_key: sign::CertifiedKey, ocsp: Vec<u8>) -> Self {
183        let mut r = Self::new(certified_key);
184
185        {
186            let cert = Arc::make_mut(&mut r.0);
187            if !ocsp.is_empty() {
188                cert.ocsp = Some(ocsp);
189            }
190        }
191
192        r
193    }
194}
195
196impl server::ResolvesServerCert for AlwaysResolvesChain {
197    fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<sign::CertifiedKey>> {
198        Some(Arc::clone(&self.0))
199    }
200}
201
202/// An exemplar `ResolvesServerCert` implementation that always resolves to a single
203/// [RFC 7250] raw public key.
204///
205/// [RFC 7250]: https://tools.ietf.org/html/rfc7250
206#[derive(Clone, Debug)]
207pub struct AlwaysResolvesServerRawPublicKeys(Arc<sign::CertifiedKey>);
208
209impl AlwaysResolvesServerRawPublicKeys {
210    /// Create a new `AlwaysResolvesServerRawPublicKeys` instance.
211    pub fn new(certified_key: Arc<sign::CertifiedKey>) -> Self {
212        Self(certified_key)
213    }
214}
215
216impl server::ResolvesServerCert for AlwaysResolvesServerRawPublicKeys {
217    fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<sign::CertifiedKey>> {
218        Some(Arc::clone(&self.0))
219    }
220
221    fn only_raw_public_keys(&self) -> bool {
222        true
223    }
224}
225
#[cfg(any(feature = "std", feature = "hashbrown"))]
mod sni_resolver {
    use alloc::string::{String, ToString};
    use alloc::sync::Arc;
    use core::fmt::Debug;

    use pki_types::{DnsName, ServerName};

    use crate::error::Error;
    use crate::hash_map::HashMap;
    use crate::server::ClientHello;
    use crate::webpki::{verify_server_name, ParsedCertificate};
    use crate::{server, sign};

    /// Something that resolves to different cert chains/keys based
    /// on client-supplied server name (via SNI).
    ///
    /// Clients that send no SNI are never given a certificate by this resolver.
    #[derive(Debug)]
    pub struct ResolvesServerCertUsingSni {
        // Keys are the lowercased DNS names passed to `add`.
        by_name: HashMap<String, Arc<sign::CertifiedKey>>,
    }

    impl ResolvesServerCertUsingSni {
        /// Create a new and empty (i.e., knows no certificates) resolver.
        pub fn new() -> Self {
            Self {
                by_name: HashMap::new(),
            }
        }

        /// Add a new `sign::CertifiedKey` to be used for the given SNI `name`.
        ///
        /// The name is lowercased before it is stored.  Adding a name that is
        /// already present replaces the key previously stored for it.
        ///
        /// # Errors
        ///
        /// This function fails if `name` is not a valid DNS name, or if
        /// it's not valid for the supplied certificate, or if the certificate
        /// chain is syntactically faulty.
        pub fn add(&mut self, name: &str, ck: sign::CertifiedKey) -> Result<(), Error> {
            let server_name = {
                let checked_name = DnsName::try_from(name)
                    .map_err(|_| Error::General("Bad DNS name".into()))
                    .map(|name| name.to_lowercase_owned())?;
                ServerName::DnsName(checked_name)
            };

            // Check the certificate chain for validity:
            // - it should be non-empty list
            // - the first certificate should be parsable as a x509v3,
            // - the first certificate should quote the given server name
            //   (if provided)
            //
            // These checks are not security-sensitive.  They are the
            // *server* attempting to detect accidental misconfiguration.

            ck.end_entity_cert()
                .and_then(ParsedCertificate::try_from)
                .and_then(|cert| verify_server_name(&cert, &server_name))?;

            // `server_name` was built as `ServerName::DnsName` above, so this
            // branch is always taken.
            if let ServerName::DnsName(name) = server_name {
                self.by_name
                    .insert(name.as_ref().to_string(), Arc::new(ck));
            }
            Ok(())
        }
    }

    impl Default for ResolvesServerCertUsingSni {
        /// Same as [`ResolvesServerCertUsingSni::new`]: a resolver that knows no certificates.
        fn default() -> Self {
            Self::new()
        }
    }

    impl server::ResolvesServerCert for ResolvesServerCertUsingSni {
        fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<sign::CertifiedKey>> {
            if let Some(name) = client_hello.server_name() {
                self.by_name.get(name).cloned()
            } else {
                // This kind of resolver requires SNI
                None
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::server::ResolvesServerCert;

        #[test]
        fn test_resolvesservercertusingsni_requires_sni() {
            let rscsni = ResolvesServerCertUsingSni::new();
            assert!(rscsni
                .resolve(ClientHello {
                    server_name: &None,
                    signature_schemes: &[],
                    alpn: None,
                    server_cert_types: None,
                    client_cert_types: None,
                    cipher_suites: &[]
                })
                .is_none());
        }

        #[test]
        fn test_resolvesservercertusingsni_handles_unknown_name() {
            let rscsni = ResolvesServerCertUsingSni::new();
            let name = DnsName::try_from("hello.com")
                .unwrap()
                .to_owned();
            assert!(rscsni
                .resolve(ClientHello {
                    server_name: &Some(name),
                    signature_schemes: &[],
                    alpn: None,
                    server_cert_types: None,
                    client_cert_types: None,
                    cipher_suites: &[]
                })
                .is_none());
        }

        #[test]
        fn test_resolvesservercertusingsni_default_is_empty() {
            let rscsni = ResolvesServerCertUsingSni::default();
            let name = DnsName::try_from("hello.com")
                .unwrap()
                .to_owned();
            assert!(rscsni
                .resolve(ClientHello {
                    server_name: &Some(name),
                    signature_schemes: &[],
                    alpn: None,
                    server_cert_types: None,
                    client_cert_types: None,
                    cipher_suites: &[]
                })
                .is_none());
        }
    }
}
339
340#[cfg(any(feature = "std", feature = "hashbrown"))]
341pub use sni_resolver::ResolvesServerCertUsingSni;
342
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;
    use crate::server::{ProducesTickets, StoresServerSessions};

    #[test]
    fn test_noserversessionstorage_drops_put() {
        let storage = NoServerSessionStorage {};
        let stored = storage.put(vec![0x01], vec![0x02]);
        assert!(!stored);
    }

    #[test]
    fn test_noserversessionstorage_denies_gets() {
        let storage = NoServerSessionStorage {};
        storage.put(vec![0x01], vec![0x02]);
        let keys: [&[u8]; 3] = [&[], &[0x01], &[0x02]];
        for key in keys {
            assert_eq!(storage.get(key), None);
        }
    }

    #[test]
    fn test_noserversessionstorage_denies_takes() {
        let storage = NoServerSessionStorage {};
        let keys: [&[u8]; 3] = [&[], &[0x01], &[0x02]];
        for key in keys {
            assert_eq!(storage.take(key), None);
        }
    }

    #[test]
    fn test_neverproducestickets_does_nothing() {
        let tickets = NeverProducesTickets {};
        assert!(!tickets.enabled());
        assert_eq!(tickets.lifetime(), 0);
        assert!(tickets.encrypt(&[]).is_none());
        assert!(tickets.decrypt(&[]).is_none());
    }
}