1use alloc::sync::Arc;
2
3use pki_types::ServerName;
4
5use crate::enums::SignatureScheme;
6use crate::error::Error;
7use crate::msgs::handshake::CertificateChain;
8use crate::msgs::persist;
9use crate::{client, sign, NamedGroup};
10
11#[derive(Debug)]
13pub(super) struct NoClientSessionStorage;
14
15impl client::ClientSessionStore for NoClientSessionStorage {
16 fn set_kx_hint(&self, _: ServerName<'static>, _: NamedGroup) {}
17
18 fn kx_hint(&self, _: &ServerName<'_>) -> Option<NamedGroup> {
19 None
20 }
21
22 fn set_tls12_session(&self, _: ServerName<'static>, _: persist::Tls12ClientSessionValue) {}
23
24 fn tls12_session(&self, _: &ServerName<'_>) -> Option<persist::Tls12ClientSessionValue> {
25 None
26 }
27
28 fn remove_tls12_session(&self, _: &ServerName<'_>) {}
29
30 fn insert_tls13_ticket(&self, _: ServerName<'static>, _: persist::Tls13ClientSessionValue) {}
31
32 fn take_tls13_ticket(&self, _: &ServerName<'_>) -> Option<persist::Tls13ClientSessionValue> {
33 None
34 }
35}
36
#[cfg(any(feature = "std", feature = "hashbrown"))]
mod cache {
    use alloc::collections::VecDeque;
    use core::fmt;

    use pki_types::ServerName;

    use crate::lock::Mutex;
    use crate::msgs::persist;
    use crate::{limited_cache, NamedGroup};

    /// Upper bound on the number of TLS1.3 tickets kept for a single server.
    const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;

    /// Everything remembered about one server.
    struct ServerData {
        /// Key-exchange group hint last recorded via `set_kx_hint`, if any.
        kx_hint: Option<NamedGroup>,

        /// Most recently stored TLS1.2 session, if any.
        #[cfg(feature = "tls12")]
        tls12: Option<persist::Tls12ClientSessionValue>,

        /// Stored TLS1.3 tickets, oldest at the front and newest at the back.
        tls13: VecDeque<persist::Tls13ClientSessionValue>,
    }

    impl Default for ServerData {
        /// An entry with no hint, no TLS1.2 session, and an empty ticket queue
        /// pre-allocated for [`MAX_TLS13_TICKETS_PER_SERVER`] tickets.
        fn default() -> Self {
            Self {
                kx_hint: None,
                #[cfg(feature = "tls12")]
                tls12: None,
                tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER),
            }
        }
    }

    /// Converts a requested session count into a server-entry limit.
    ///
    /// This is `size / MAX_TLS13_TICKETS_PER_SERVER` rounded up; the intermediate
    /// addition saturates instead of overflowing for very large `size`.
    fn max_servers_for(size: usize) -> usize {
        size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1) / MAX_TLS13_TICKETS_PER_SERVER
    }

    /// An in-memory [`ClientSessionStore`](super::client::ClientSessionStore).
    ///
    /// Per-server data is kept in a `limited_cache::LimitedCache` keyed by server
    /// name, guarded by a mutex.
    pub struct ClientSessionMemoryCache {
        servers: Mutex<limited_cache::LimitedCache<ServerName<'static>, ServerData>>,
    }

    impl ClientSessionMemoryCache {
        /// Creates a cache sized for roughly `size` sessions.
        ///
        /// The underlying `LimitedCache` is constructed with a limit of
        /// `size / MAX_TLS13_TICKETS_PER_SERVER` (rounded up) servers.
        #[cfg(feature = "std")]
        pub fn new(size: usize) -> Self {
            Self {
                servers: Mutex::new(limited_cache::LimitedCache::new(max_servers_for(size))),
            }
        }

        /// Creates a cache sized for roughly `size` sessions, using `M` to build
        /// the mutex.
        ///
        /// The underlying `LimitedCache` is constructed with a limit of
        /// `size / MAX_TLS13_TICKETS_PER_SERVER` (rounded up) servers.
        #[cfg(not(feature = "std"))]
        pub fn new<M: crate::lock::MakeMutex>(size: usize) -> Self {
            Self {
                servers: Mutex::new::<M>(limited_cache::LimitedCache::new(max_servers_for(size))),
            }
        }
    }

    impl super::client::ClientSessionStore for ClientSessionMemoryCache {
        /// Records `group` as the key-exchange hint for `server_name`, creating the
        /// server's entry if needed.
        fn set_kx_hint(&self, server_name: ServerName<'static>, group: NamedGroup) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| data.kx_hint = Some(group));
        }

        /// Returns the hint recorded for `server_name`, if any.
        fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<NamedGroup> {
            self.servers
                .lock()
                .unwrap()
                .get(server_name)
                .and_then(|sd| sd.kx_hint)
        }

        /// Stores `_value` as the TLS1.2 session for `_server_name`, replacing any
        /// previous one. Does nothing when the `tls12` feature is disabled.
        fn set_tls12_session(
            &self,
            _server_name: ServerName<'static>,
            _value: persist::Tls12ClientSessionValue,
        ) {
            // The name is already owned, so it is moved into the cache rather than cloned.
            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(_server_name, |data| data.tls12 = Some(_value));
        }

        /// Returns a clone of the TLS1.2 session stored for `_server_name`, if any.
        /// Always `None` when the `tls12` feature is disabled.
        fn tls12_session(
            &self,
            _server_name: &ServerName<'_>,
        ) -> Option<persist::Tls12ClientSessionValue> {
            #[cfg(not(feature = "tls12"))]
            return None;

            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get(_server_name)
                .and_then(|sd| sd.tls12.clone())
        }

        /// Forgets the TLS1.2 session stored for `_server_name`, if any.
        /// Does nothing when the `tls12` feature is disabled.
        fn remove_tls12_session(&self, _server_name: &ServerName<'static>) {
            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get_mut(_server_name)
                .and_then(|data| data.tls12.take());
        }

        /// Appends `value` to the tickets stored for `server_name`, evicting the
        /// oldest ticket first if [`MAX_TLS13_TICKETS_PER_SERVER`] are already held.
        fn insert_tls13_ticket(
            &self,
            server_name: ServerName<'static>,
            value: persist::Tls13ClientSessionValue,
        ) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| {
                    // Compare against the constant rather than `capacity()`:
                    // `VecDeque::with_capacity` only promises *at least* the
                    // requested capacity, so `capacity()` is not a reliable limit.
                    if data.tls13.len() >= MAX_TLS13_TICKETS_PER_SERVER {
                        data.tls13.pop_front();
                    }
                    data.tls13.push_back(value);
                });
        }

        /// Removes and returns the most recently inserted ticket for `server_name`,
        /// if any.
        fn take_tls13_ticket(
            &self,
            server_name: &ServerName<'static>,
        ) -> Option<persist::Tls13ClientSessionValue> {
            self.servers
                .lock()
                .unwrap()
                .get_mut(server_name)
                .and_then(|data| data.tls13.pop_back())
        }
    }

    impl fmt::Debug for ClientSessionMemoryCache {
        /// Prints only the type name; the contents of `servers` are not printed.
        // NOTE(review): presumably omitted because the cache holds session secrets — confirm.
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("ClientSessionMemoryCache")
                .finish()
        }
    }
}
194
195#[cfg(any(feature = "std", feature = "hashbrown"))]
196pub use cache::ClientSessionMemoryCache;
197
198#[derive(Debug)]
199pub(super) struct FailResolveClientCert {}
200
201impl client::ResolvesClientCert for FailResolveClientCert {
202 fn resolve(
203 &self,
204 _root_hint_subjects: &[&[u8]],
205 _sigschemes: &[SignatureScheme],
206 ) -> Option<Arc<sign::CertifiedKey>> {
207 None
208 }
209
210 fn has_certs(&self) -> bool {
211 false
212 }
213}
214
215#[derive(Debug)]
216pub(super) struct AlwaysResolvesClientCert(Arc<sign::CertifiedKey>);
217
218impl AlwaysResolvesClientCert {
219 pub(super) fn new(
220 private_key: Arc<dyn sign::SigningKey>,
221 chain: CertificateChain<'static>,
222 ) -> Result<Self, Error> {
223 Ok(Self(Arc::new(sign::CertifiedKey::new(
224 chain.0,
225 private_key,
226 ))))
227 }
228}
229
230impl client::ResolvesClientCert for AlwaysResolvesClientCert {
231 fn resolve(
232 &self,
233 _root_hint_subjects: &[&[u8]],
234 _sigschemes: &[SignatureScheme],
235 ) -> Option<Arc<sign::CertifiedKey>> {
236 Some(Arc::clone(&self.0))
237 }
238
239 fn has_certs(&self) -> bool {
240 true
241 }
242}
243
244#[derive(Clone, Debug)]
249pub struct AlwaysResolvesClientRawPublicKeys(Arc<sign::CertifiedKey>);
250impl AlwaysResolvesClientRawPublicKeys {
251 pub fn new(certified_key: Arc<sign::CertifiedKey>) -> Self {
253 Self(certified_key)
254 }
255}
256
257impl client::ResolvesClientCert for AlwaysResolvesClientRawPublicKeys {
258 fn resolve(
259 &self,
260 _root_hint_subjects: &[&[u8]],
261 _sigschemes: &[SignatureScheme],
262 ) -> Option<Arc<sign::CertifiedKey>> {
263 Some(Arc::clone(&self.0))
264 }
265
266 fn only_raw_public_keys(&self) -> bool {
267 true
268 }
269
270 fn has_certs(&self) -> bool {
275 true
276 }
277}
278
#[cfg(test)]
#[macro_rules_attribute::apply(test_for_each_provider)]
mod tests {
    use alloc::sync::Arc;
    use std::prelude::v1::*;

    use pki_types::{ServerName, UnixTime};

    use super::provider::cipher_suite;
    use super::NoClientSessionStorage;
    use crate::client::ClientSessionStore;
    use crate::msgs::base::PayloadU16;
    use crate::msgs::enums::NamedGroup;
    use crate::msgs::handshake::CertificateChain;
    #[cfg(feature = "tls12")]
    use crate::msgs::handshake::SessionId;
    use crate::msgs::persist::Tls13ClientSessionValue;
    use crate::suites::SupportedCipherSuite;

    /// Whatever is handed to `NoClientSessionStorage`, nothing can be read back.
    #[test]
    fn test_noclientsessionstorage_does_nothing() {
        let storage = NoClientSessionStorage {};
        let server_name = ServerName::try_from("example.com").unwrap();
        let timestamp = UnixTime::now();

        // Key-exchange hint: set, then confirm it was not kept.
        storage.set_kx_hint(server_name.clone(), NamedGroup::X25519);
        assert_eq!(None, storage.kx_hint(&server_name));

        // TLS1.2 session: set, confirm it was not kept, then remove (a no-op).
        #[cfg(feature = "tls12")]
        {
            use crate::msgs::persist::Tls12ClientSessionValue;
            let SupportedCipherSuite::Tls12(tls12_suite) =
                cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
            else {
                unreachable!()
            };

            let tls12_value = Tls12ClientSessionValue::new(
                tls12_suite,
                SessionId::empty(),
                Arc::new(PayloadU16::empty()),
                &[],
                CertificateChain::default(),
                timestamp,
                0,
                true,
            );
            storage.set_tls12_session(server_name.clone(), tls12_value);
            assert!(storage.tls12_session(&server_name).is_none());
            storage.remove_tls12_session(&server_name);
        }

        // TLS1.3 ticket: insert, then confirm there is nothing to take.
        let SupportedCipherSuite::Tls13(tls13_suite) = cipher_suite::TLS13_AES_256_GCM_SHA384
        else {
            unreachable!();
        };
        let tls13_value = Tls13ClientSessionValue::new(
            tls13_suite,
            Arc::new(PayloadU16::empty()),
            &[],
            CertificateChain::default(),
            timestamp,
            0,
            0,
            0,
        );
        storage.insert_tls13_ticket(server_name.clone(), tls13_value);
        assert!(storage.take_tls13_ticket(&server_name).is_none());
    }
}