rustls/client/
handy.rs

1use alloc::sync::Arc;
2
3use pki_types::ServerName;
4
5use crate::enums::SignatureScheme;
6use crate::error::Error;
7use crate::msgs::handshake::CertificateChain;
8use crate::msgs::persist;
9use crate::{client, sign, NamedGroup};
10
11/// An implementer of `ClientSessionStore` which does nothing.
12#[derive(Debug)]
13pub(super) struct NoClientSessionStorage;
14
15impl client::ClientSessionStore for NoClientSessionStorage {
16    fn set_kx_hint(&self, _: ServerName<'static>, _: NamedGroup) {}
17
18    fn kx_hint(&self, _: &ServerName<'_>) -> Option<NamedGroup> {
19        None
20    }
21
22    fn set_tls12_session(&self, _: ServerName<'static>, _: persist::Tls12ClientSessionValue) {}
23
24    fn tls12_session(&self, _: &ServerName<'_>) -> Option<persist::Tls12ClientSessionValue> {
25        None
26    }
27
28    fn remove_tls12_session(&self, _: &ServerName<'_>) {}
29
30    fn insert_tls13_ticket(&self, _: ServerName<'static>, _: persist::Tls13ClientSessionValue) {}
31
32    fn take_tls13_ticket(&self, _: &ServerName<'_>) -> Option<persist::Tls13ClientSessionValue> {
33        None
34    }
35}
36
37#[cfg(any(feature = "std", feature = "hashbrown"))]
38mod cache {
39    use alloc::collections::VecDeque;
40    use core::fmt;
41
42    use pki_types::ServerName;
43
44    use crate::lock::Mutex;
45    use crate::msgs::persist;
46    use crate::{limited_cache, NamedGroup};
47
48    const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;
49
50    struct ServerData {
51        kx_hint: Option<NamedGroup>,
52
53        // Zero or one TLS1.2 sessions.
54        #[cfg(feature = "tls12")]
55        tls12: Option<persist::Tls12ClientSessionValue>,
56
57        // Up to MAX_TLS13_TICKETS_PER_SERVER TLS1.3 tickets, oldest first.
58        tls13: VecDeque<persist::Tls13ClientSessionValue>,
59    }
60
61    impl Default for ServerData {
62        fn default() -> Self {
63            Self {
64                kx_hint: None,
65                #[cfg(feature = "tls12")]
66                tls12: None,
67                tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER),
68            }
69        }
70    }
71
72    /// An implementer of `ClientSessionStore` that stores everything
73    /// in memory.
74    ///
75    /// It enforces a limit on the number of entries to bound memory usage.
76    pub struct ClientSessionMemoryCache {
77        servers: Mutex<limited_cache::LimitedCache<ServerName<'static>, ServerData>>,
78    }
79
80    impl ClientSessionMemoryCache {
81        /// Make a new ClientSessionMemoryCache.  `size` is the
82        /// maximum number of stored sessions.
83        #[cfg(feature = "std")]
84        pub fn new(size: usize) -> Self {
85            let max_servers = size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1)
86                / MAX_TLS13_TICKETS_PER_SERVER;
87            Self {
88                servers: Mutex::new(limited_cache::LimitedCache::new(max_servers)),
89            }
90        }
91
92        /// Make a new ClientSessionMemoryCache.  `size` is the
93        /// maximum number of stored sessions.
94        #[cfg(not(feature = "std"))]
95        pub fn new<M: crate::lock::MakeMutex>(size: usize) -> Self {
96            let max_servers = size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1)
97                / MAX_TLS13_TICKETS_PER_SERVER;
98            Self {
99                servers: Mutex::new::<M>(limited_cache::LimitedCache::new(max_servers)),
100            }
101        }
102    }
103
104    impl super::client::ClientSessionStore for ClientSessionMemoryCache {
105        fn set_kx_hint(&self, server_name: ServerName<'static>, group: NamedGroup) {
106            self.servers
107                .lock()
108                .unwrap()
109                .get_or_insert_default_and_edit(server_name, |data| data.kx_hint = Some(group));
110        }
111
112        fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<NamedGroup> {
113            self.servers
114                .lock()
115                .unwrap()
116                .get(server_name)
117                .and_then(|sd| sd.kx_hint)
118        }
119
120        fn set_tls12_session(
121            &self,
122            _server_name: ServerName<'static>,
123            _value: persist::Tls12ClientSessionValue,
124        ) {
125            #[cfg(feature = "tls12")]
126            self.servers
127                .lock()
128                .unwrap()
129                .get_or_insert_default_and_edit(_server_name.clone(), |data| {
130                    data.tls12 = Some(_value)
131                });
132        }
133
134        fn tls12_session(
135            &self,
136            _server_name: &ServerName<'_>,
137        ) -> Option<persist::Tls12ClientSessionValue> {
138            #[cfg(not(feature = "tls12"))]
139            return None;
140
141            #[cfg(feature = "tls12")]
142            self.servers
143                .lock()
144                .unwrap()
145                .get(_server_name)
146                .and_then(|sd| sd.tls12.as_ref().cloned())
147        }
148
149        fn remove_tls12_session(&self, _server_name: &ServerName<'static>) {
150            #[cfg(feature = "tls12")]
151            self.servers
152                .lock()
153                .unwrap()
154                .get_mut(_server_name)
155                .and_then(|data| data.tls12.take());
156        }
157
158        fn insert_tls13_ticket(
159            &self,
160            server_name: ServerName<'static>,
161            value: persist::Tls13ClientSessionValue,
162        ) {
163            self.servers
164                .lock()
165                .unwrap()
166                .get_or_insert_default_and_edit(server_name.clone(), |data| {
167                    if data.tls13.len() == data.tls13.capacity() {
168                        data.tls13.pop_front();
169                    }
170                    data.tls13.push_back(value);
171                });
172        }
173
174        fn take_tls13_ticket(
175            &self,
176            server_name: &ServerName<'static>,
177        ) -> Option<persist::Tls13ClientSessionValue> {
178            self.servers
179                .lock()
180                .unwrap()
181                .get_mut(server_name)
182                .and_then(|data| data.tls13.pop_back())
183        }
184    }
185
186    impl fmt::Debug for ClientSessionMemoryCache {
187        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188            // Note: we omit self.servers as it may contain sensitive data.
189            f.debug_struct("ClientSessionMemoryCache")
190                .finish()
191        }
192    }
193}
194
195#[cfg(any(feature = "std", feature = "hashbrown"))]
196pub use cache::ClientSessionMemoryCache;
197
198#[derive(Debug)]
199pub(super) struct FailResolveClientCert {}
200
201impl client::ResolvesClientCert for FailResolveClientCert {
202    fn resolve(
203        &self,
204        _root_hint_subjects: &[&[u8]],
205        _sigschemes: &[SignatureScheme],
206    ) -> Option<Arc<sign::CertifiedKey>> {
207        None
208    }
209
210    fn has_certs(&self) -> bool {
211        false
212    }
213}
214
215#[derive(Debug)]
216pub(super) struct AlwaysResolvesClientCert(Arc<sign::CertifiedKey>);
217
218impl AlwaysResolvesClientCert {
219    pub(super) fn new(
220        private_key: Arc<dyn sign::SigningKey>,
221        chain: CertificateChain<'static>,
222    ) -> Result<Self, Error> {
223        Ok(Self(Arc::new(sign::CertifiedKey::new(
224            chain.0,
225            private_key,
226        ))))
227    }
228}
229
230impl client::ResolvesClientCert for AlwaysResolvesClientCert {
231    fn resolve(
232        &self,
233        _root_hint_subjects: &[&[u8]],
234        _sigschemes: &[SignatureScheme],
235    ) -> Option<Arc<sign::CertifiedKey>> {
236        Some(Arc::clone(&self.0))
237    }
238
239    fn has_certs(&self) -> bool {
240        true
241    }
242}
243
244/// An exemplar `ResolvesClientCert` implementation that always resolves to a single
245/// [RFC 7250] raw public key.
246///
247/// [RFC 7250]: https://tools.ietf.org/html/rfc7250
248#[derive(Clone, Debug)]
249pub struct AlwaysResolvesClientRawPublicKeys(Arc<sign::CertifiedKey>);
250impl AlwaysResolvesClientRawPublicKeys {
251    /// Create a new `AlwaysResolvesClientRawPublicKeys` instance.
252    pub fn new(certified_key: Arc<sign::CertifiedKey>) -> Self {
253        Self(certified_key)
254    }
255}
256
257impl client::ResolvesClientCert for AlwaysResolvesClientRawPublicKeys {
258    fn resolve(
259        &self,
260        _root_hint_subjects: &[&[u8]],
261        _sigschemes: &[SignatureScheme],
262    ) -> Option<Arc<sign::CertifiedKey>> {
263        Some(Arc::clone(&self.0))
264    }
265
266    fn only_raw_public_keys(&self) -> bool {
267        true
268    }
269
270    /// Returns true if the resolver is ready to present an identity.
271    ///
272    /// Even though the function is called `has_certs`, it returns true
273    /// although only an RPK (Raw Public Key) is available, not an actual certificate.
274    fn has_certs(&self) -> bool {
275        true
276    }
277}
278
279#[cfg(test)]
280#[macro_rules_attribute::apply(test_for_each_provider)]
281mod tests {
282    use alloc::sync::Arc;
283    use std::prelude::v1::*;
284
285    use pki_types::{ServerName, UnixTime};
286
287    use super::provider::cipher_suite;
288    use super::NoClientSessionStorage;
289    use crate::client::ClientSessionStore;
290    use crate::msgs::base::PayloadU16;
291    use crate::msgs::enums::NamedGroup;
292    use crate::msgs::handshake::CertificateChain;
293    #[cfg(feature = "tls12")]
294    use crate::msgs::handshake::SessionId;
295    use crate::msgs::persist::Tls13ClientSessionValue;
296    use crate::suites::SupportedCipherSuite;
297
298    #[test]
299    fn test_noclientsessionstorage_does_nothing() {
300        let c = NoClientSessionStorage {};
301        let name = ServerName::try_from("example.com").unwrap();
302        let now = UnixTime::now();
303
304        c.set_kx_hint(name.clone(), NamedGroup::X25519);
305        assert_eq!(None, c.kx_hint(&name));
306
307        #[cfg(feature = "tls12")]
308        {
309            use crate::msgs::persist::Tls12ClientSessionValue;
310            let SupportedCipherSuite::Tls12(tls12_suite) =
311                cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
312            else {
313                unreachable!()
314            };
315
316            c.set_tls12_session(
317                name.clone(),
318                Tls12ClientSessionValue::new(
319                    tls12_suite,
320                    SessionId::empty(),
321                    Arc::new(PayloadU16::empty()),
322                    &[],
323                    CertificateChain::default(),
324                    now,
325                    0,
326                    true,
327                ),
328            );
329            assert!(c.tls12_session(&name).is_none());
330            c.remove_tls12_session(&name);
331        }
332
333        let SupportedCipherSuite::Tls13(tls13_suite) = cipher_suite::TLS13_AES_256_GCM_SHA384
334        else {
335            unreachable!();
336        };
337        c.insert_tls13_ticket(
338            name.clone(),
339            Tls13ClientSessionValue::new(
340                tls13_suite,
341                Arc::new(PayloadU16::empty()),
342                &[],
343                CertificateChain::default(),
344                now,
345                0,
346                0,
347                0,
348            ),
349        );
350        assert!(c.take_tls13_ticket(&name).is_none());
351    }
352}