1use alloc::sync::Arc;
2use alloc::vec::Vec;
3use core::fmt::Debug;
4
5use crate::server::ClientHello;
6use crate::{server, sign};
7
/// A [`server::StoresServerSessions`] implementation that never stores anything.
///
/// Every `put` returns `false`, every `get`/`take` returns `None`, and
/// `can_cache` reports `false`.
#[derive(Debug)]
pub struct NoServerSessionStorage {}
11
12impl server::StoresServerSessions for NoServerSessionStorage {
13 fn put(&self, _id: Vec<u8>, _sec: Vec<u8>) -> bool {
14 false
15 }
16 fn get(&self, _id: &[u8]) -> Option<Vec<u8>> {
17 None
18 }
19 fn take(&self, _id: &[u8]) -> Option<Vec<u8>> {
20 None
21 }
22 fn can_cache(&self) -> bool {
23 false
24 }
25}
26
27#[cfg(any(feature = "std", feature = "hashbrown"))]
28mod cache {
29 use alloc::sync::Arc;
30 use alloc::vec::Vec;
31 use core::fmt::{Debug, Formatter};
32
33 use crate::lock::Mutex;
34 use crate::{limited_cache, server};
35
    /// An in-memory [`server::StoresServerSessions`] implementation.
    ///
    /// Entries are kept in a `limited_cache::LimitedCache` behind a mutex, sized by
    /// the argument to `new`. NOTE(review): the exact eviction policy (and how
    /// strictly `size` is enforced) is defined by `LimitedCache` — confirm there.
    pub struct ServerSessionMemoryCache {
        // Maps each `key` passed to `put` to its `value`.
        cache: Mutex<limited_cache::LimitedCache<Vec<u8>, Vec<u8>>>,
    }
42
43 impl ServerSessionMemoryCache {
44 #[cfg(feature = "std")]
48 pub fn new(size: usize) -> Arc<Self> {
49 Arc::new(Self {
50 cache: Mutex::new(limited_cache::LimitedCache::new(size)),
51 })
52 }
53
54 #[cfg(not(feature = "std"))]
58 pub fn new<M: crate::lock::MakeMutex>(size: usize) -> Arc<Self> {
59 Arc::new(Self {
60 cache: Mutex::new::<M>(limited_cache::LimitedCache::new(size)),
61 })
62 }
63 }
64
65 impl server::StoresServerSessions for ServerSessionMemoryCache {
66 fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
67 self.cache
68 .lock()
69 .unwrap()
70 .insert(key, value);
71 true
72 }
73
74 fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
75 self.cache
76 .lock()
77 .unwrap()
78 .get(key)
79 .cloned()
80 }
81
82 fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
83 self.cache.lock().unwrap().remove(key)
84 }
85
86 fn can_cache(&self) -> bool {
87 true
88 }
89 }
90
91 impl Debug for ServerSessionMemoryCache {
92 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
93 f.debug_struct("ServerSessionMemoryCache")
94 .finish()
95 }
96 }
97
98 #[cfg(test)]
99 mod tests {
100 use std::vec;
101
102 use super::*;
103 use crate::server::StoresServerSessions;
104
105 #[test]
106 fn test_serversessionmemorycache_accepts_put() {
107 let c = ServerSessionMemoryCache::new(4);
108 assert!(c.put(vec![0x01], vec![0x02]));
109 }
110
111 #[test]
112 fn test_serversessionmemorycache_persists_put() {
113 let c = ServerSessionMemoryCache::new(4);
114 assert!(c.put(vec![0x01], vec![0x02]));
115 assert_eq!(c.get(&[0x01]), Some(vec![0x02]));
116 assert_eq!(c.get(&[0x01]), Some(vec![0x02]));
117 }
118
119 #[test]
120 fn test_serversessionmemorycache_overwrites_put() {
121 let c = ServerSessionMemoryCache::new(4);
122 assert!(c.put(vec![0x01], vec![0x02]));
123 assert!(c.put(vec![0x01], vec![0x04]));
124 assert_eq!(c.get(&[0x01]), Some(vec![0x04]));
125 }
126
127 #[test]
128 fn test_serversessionmemorycache_drops_to_maintain_size_invariant() {
129 let c = ServerSessionMemoryCache::new(2);
130 assert!(c.put(vec![0x01], vec![0x02]));
131 assert!(c.put(vec![0x03], vec![0x04]));
132 assert!(c.put(vec![0x05], vec![0x06]));
133 assert!(c.put(vec![0x07], vec![0x08]));
134 assert!(c.put(vec![0x09], vec![0x0a]));
135
136 let count = c.get(&[0x01]).iter().count()
137 + c.get(&[0x03]).iter().count()
138 + c.get(&[0x05]).iter().count()
139 + c.get(&[0x07]).iter().count()
140 + c.get(&[0x09]).iter().count();
141
142 assert!(count < 5);
143 }
144 }
145}
146
147#[cfg(any(feature = "std", feature = "hashbrown"))]
148pub use cache::ServerSessionMemoryCache;
149
/// A [`server::ProducesTickets`] implementation that never issues tickets.
///
/// It reports itself as disabled with a lifetime of 0, and both `encrypt`
/// and `decrypt` return `None`.
#[derive(Debug)]
pub(super) struct NeverProducesTickets {}
153
154impl server::ProducesTickets for NeverProducesTickets {
155 fn enabled(&self) -> bool {
156 false
157 }
158 fn lifetime(&self) -> u32 {
159 0
160 }
161 fn encrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
162 None
163 }
164 fn decrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
165 None
166 }
167}
168
/// A [`server::ResolvesServerCert`] that returns the same certificate chain and
/// key for every `ClientHello`, regardless of its contents.
#[derive(Debug)]
pub(super) struct AlwaysResolvesChain(Arc<sign::CertifiedKey>);
172
173impl AlwaysResolvesChain {
174 pub(super) fn new(certified_key: sign::CertifiedKey) -> Self {
176 Self(Arc::new(certified_key))
177 }
178
179 pub(super) fn new_with_extras(certified_key: sign::CertifiedKey, ocsp: Vec<u8>) -> Self {
183 let mut r = Self::new(certified_key);
184
185 {
186 let cert = Arc::make_mut(&mut r.0);
187 if !ocsp.is_empty() {
188 cert.ocsp = Some(ocsp);
189 }
190 }
191
192 r
193 }
194}
195
196impl server::ResolvesServerCert for AlwaysResolvesChain {
197 fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<sign::CertifiedKey>> {
198 Some(Arc::clone(&self.0))
199 }
200}
201
/// A [`server::ResolvesServerCert`] that returns the same key for every
/// `ClientHello` and reports `true` from `only_raw_public_keys`.
#[derive(Clone, Debug)]
pub struct AlwaysResolvesServerRawPublicKeys(Arc<sign::CertifiedKey>);
208
209impl AlwaysResolvesServerRawPublicKeys {
210 pub fn new(certified_key: Arc<sign::CertifiedKey>) -> Self {
212 Self(certified_key)
213 }
214}
215
216impl server::ResolvesServerCert for AlwaysResolvesServerRawPublicKeys {
217 fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<sign::CertifiedKey>> {
218 Some(Arc::clone(&self.0))
219 }
220
221 fn only_raw_public_keys(&self) -> bool {
222 true
223 }
224}
225
226#[cfg(any(feature = "std", feature = "hashbrown"))]
227mod sni_resolver {
228 use alloc::string::{String, ToString};
229 use alloc::sync::Arc;
230 use core::fmt::Debug;
231
232 use pki_types::{DnsName, ServerName};
233
234 use crate::error::Error;
235 use crate::hash_map::HashMap;
236 use crate::server::ClientHello;
237 use crate::webpki::{verify_server_name, ParsedCertificate};
238 use crate::{server, sign};
239
    /// A [`server::ResolvesServerCert`] that picks a certificate chain and key
    /// based on the server name (SNI) the client sent.
    ///
    /// Names are registered with `add`. `resolve` looks up the name exactly as
    /// returned by `ClientHello::server_name`.
    #[derive(Debug)]
    pub struct ResolvesServerCertUsingSni {
        // Lower-cased DNS name (see `add`) -> certificate chain and key for it.
        by_name: HashMap<String, Arc<sign::CertifiedKey>>,
    }
246
247 impl ResolvesServerCertUsingSni {
248 pub fn new() -> Self {
250 Self {
251 by_name: HashMap::new(),
252 }
253 }
254
255 pub fn add(&mut self, name: &str, ck: sign::CertifiedKey) -> Result<(), Error> {
261 let server_name = {
262 let checked_name = DnsName::try_from(name)
263 .map_err(|_| Error::General("Bad DNS name".into()))
264 .map(|name| name.to_lowercase_owned())?;
265 ServerName::DnsName(checked_name)
266 };
267
268 ck.end_entity_cert()
278 .and_then(ParsedCertificate::try_from)
279 .and_then(|cert| verify_server_name(&cert, &server_name))?;
280
281 if let ServerName::DnsName(name) = server_name {
282 self.by_name
283 .insert(name.as_ref().to_string(), Arc::new(ck));
284 }
285 Ok(())
286 }
287 }
288
289 impl server::ResolvesServerCert for ResolvesServerCertUsingSni {
290 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<sign::CertifiedKey>> {
291 if let Some(name) = client_hello.server_name() {
292 self.by_name.get(name).cloned()
293 } else {
294 None
296 }
297 }
298 }
299
300 #[cfg(test)]
301 mod tests {
302 use super::*;
303 use crate::server::ResolvesServerCert;
304
305 #[test]
306 fn test_resolvesservercertusingsni_requires_sni() {
307 let rscsni = ResolvesServerCertUsingSni::new();
308 assert!(rscsni
309 .resolve(ClientHello {
310 server_name: &None,
311 signature_schemes: &[],
312 alpn: None,
313 server_cert_types: None,
314 client_cert_types: None,
315 cipher_suites: &[]
316 })
317 .is_none());
318 }
319
320 #[test]
321 fn test_resolvesservercertusingsni_handles_unknown_name() {
322 let rscsni = ResolvesServerCertUsingSni::new();
323 let name = DnsName::try_from("hello.com")
324 .unwrap()
325 .to_owned();
326 assert!(rscsni
327 .resolve(ClientHello {
328 server_name: &Some(name),
329 signature_schemes: &[],
330 alpn: None,
331 server_cert_types: None,
332 client_cert_types: None,
333 cipher_suites: &[]
334 })
335 .is_none());
336 }
337 }
338}
339
340#[cfg(any(feature = "std", feature = "hashbrown"))]
341pub use sni_resolver::ResolvesServerCertUsingSni;
342
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;
    use crate::server::{ProducesTickets, StoresServerSessions};

    #[test]
    fn test_noserversessionstorage_drops_put() {
        let storage = NoServerSessionStorage {};
        let stored = storage.put(vec![0x01], vec![0x02]);
        assert!(!stored);
    }

    #[test]
    fn test_noserversessionstorage_denies_gets() {
        let storage = NoServerSessionStorage {};
        // Even a key that was just `put` must not be retrievable.
        storage.put(vec![0x01], vec![0x02]);
        let probes: [&[u8]; 3] = [&[], &[0x01], &[0x02]];
        for key in probes {
            assert_eq!(storage.get(key), None);
        }
    }

    #[test]
    fn test_noserversessionstorage_denies_takes() {
        let storage = NoServerSessionStorage {};
        let probes: [&[u8]; 3] = [&[], &[0x01], &[0x02]];
        for key in probes {
            assert_eq!(storage.take(key), None);
        }
    }

    #[test]
    fn test_neverproducestickets_does_nothing() {
        let tickets = NeverProducesTickets {};
        assert!(!tickets.enabled());
        assert_eq!(tickets.lifetime(), 0);
        assert_eq!(tickets.encrypt(&[]), None);
        assert_eq!(tickets.decrypt(&[]), None);
    }
}