1use rsa::pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey};
19use std::fmt;
20use subtle::{Choice, ConstantTimeEq};
21
22#[cfg(feature = "memquota-memcost")]
23use {derive_deftly::Deftly, tor_memquota::derive_deftly_template_HasMemoryCost};
24
25use crate::util::ct::CtByteArray;
26
/// Length, in bytes, of an [`RsaIdentity`].
///
/// [`PublicKey::to_rsa_identity`] fills an identity with a SHA-1 digest,
/// which is 160 bits wide, so the length is expressed as bits / 8.
pub const RSA_ID_LEN: usize = 160 / 8;
31
32#[derive(Clone, Copy, Hash, Ord, PartialOrd, Eq, PartialEq)]
44#[cfg_attr(
45 feature = "memquota-memcost",
46 derive(Deftly),
47 derive_deftly(HasMemoryCost)
48)]
49pub struct RsaIdentity {
50 id: CtByteArray<RSA_ID_LEN>,
52}
53
54impl ConstantTimeEq for RsaIdentity {
55 fn ct_eq(&self, other: &Self) -> Choice {
56 self.id.ct_eq(&other.id)
57 }
58}
59
60impl fmt::Display for RsaIdentity {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 write!(f, "${}", hex::encode(&self.id.as_ref()[..]))
63 }
64}
65impl fmt::Debug for RsaIdentity {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 write!(f, "RsaIdentity {{ {} }}", self)
68 }
69}
70
71impl safelog::Redactable for RsaIdentity {
72 fn display_redacted(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 write!(f, "${}…", hex::encode(&self.id.as_ref()[..1]))
76 }
77
78 fn debug_redacted(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 write!(f, "RsaIdentity {{ {} }}", self.redacted())
80 }
81}
82
83impl serde::Serialize for RsaIdentity {
84 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
85 where
86 S: serde::Serializer,
87 {
88 if serializer.is_human_readable() {
89 serializer.serialize_str(&hex::encode(&self.id.as_ref()[..]))
90 } else {
91 serializer.serialize_bytes(&self.id.as_ref()[..])
92 }
93 }
94}
95
96impl<'de> serde::Deserialize<'de> for RsaIdentity {
97 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
98 where
99 D: serde::Deserializer<'de>,
100 {
101 if deserializer.is_human_readable() {
102 struct RsaIdentityVisitor;
104 impl<'de> serde::de::Visitor<'de> for RsaIdentityVisitor {
105 type Value = RsaIdentity;
106 fn expecting(&self, fmt: &mut std::fmt::Formatter<'_>) -> fmt::Result {
107 fmt.write_str("hex-encoded RSA identity")
108 }
109 fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
110 where
111 E: serde::de::Error,
112 {
113 RsaIdentity::from_hex(s)
114 .ok_or_else(|| E::custom("wrong encoding for RSA identity"))
115 }
116 }
117
118 deserializer.deserialize_str(RsaIdentityVisitor)
119 } else {
120 struct RsaIdentityVisitor;
122 impl<'de> serde::de::Visitor<'de> for RsaIdentityVisitor {
123 type Value = RsaIdentity;
124 fn expecting(&self, fmt: &mut std::fmt::Formatter<'_>) -> fmt::Result {
125 fmt.write_str("RSA identity")
126 }
127 fn visit_bytes<E>(self, bytes: &[u8]) -> Result<Self::Value, E>
128 where
129 E: serde::de::Error,
130 {
131 RsaIdentity::from_bytes(bytes)
132 .ok_or_else(|| E::custom("wrong length for RSA identity"))
133 }
134 }
135 deserializer.deserialize_bytes(RsaIdentityVisitor)
136 }
137 }
138}
139
140impl RsaIdentity {
141 pub fn as_bytes(&self) -> &[u8] {
143 &self.id.as_ref()[..]
144 }
145 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
161 Some(RsaIdentity {
162 id: CtByteArray::from(<[u8; RSA_ID_LEN]>::try_from(bytes).ok()?),
163 })
164 }
165 pub fn from_hex(s: &str) -> Option<Self> {
169 let mut array = [0_u8; 20];
170 match hex::decode_to_slice(s, &mut array) {
171 Err(_) => None,
172 Ok(()) => Some(RsaIdentity::from(array)),
173 }
174 }
175
176 pub fn is_zero(&self) -> bool {
183 self.id.ct_eq(&[0; RSA_ID_LEN].into()).into()
185 }
186}
187
188impl From<[u8; 20]> for RsaIdentity {
189 fn from(id: [u8; 20]) -> RsaIdentity {
190 RsaIdentity { id: id.into() }
191 }
192}
193
194#[derive(Clone, Debug)]
199pub struct PublicKey(rsa::RsaPublicKey);
200
201pub struct PrivateKey(rsa::RsaPrivateKey);
207
208impl PrivateKey {
209 pub fn to_public_key(&self) -> PublicKey {
211 PublicKey(self.0.to_public_key())
212 }
213 pub fn from_der(der: &[u8]) -> Option<Self> {
215 Some(PrivateKey(rsa::RsaPrivateKey::from_pkcs1_der(der).ok()?))
216 }
217 }
219impl PublicKey {
220 pub fn exponent_is(&self, e: u32) -> bool {
223 use rsa::traits::PublicKeyParts;
224 *self.0.e() == rsa::BigUint::new(vec![e])
225 }
226 pub fn bits(&self) -> usize {
228 use rsa::traits::PublicKeyParts;
229 self.0.n().bits()
230 }
231 pub fn verify(&self, hashed: &[u8], sig: &[u8]) -> Result<(), signature::Error> {
237 let padding = rsa::pkcs1v15::Pkcs1v15Sign::new_unprefixed();
238 self.0
239 .verify(padding, hashed, sig)
240 .map_err(|_| signature::Error::new())
241 }
242 pub fn from_der(der: &[u8]) -> Option<Self> {
249 Some(PublicKey(rsa::RsaPublicKey::from_pkcs1_der(der).ok()?))
250 }
251 pub fn to_der(&self) -> Vec<u8> {
255 use der_parser::ber::BerObject;
256 use rsa::traits::PublicKeyParts;
257
258 let mut n = self.0.n().to_bytes_be();
259 if n[0] & 0b10000000 != 0 {
261 n.insert(0, 0_u8);
262 }
263 let n = BerObject::from_int_slice(&n);
264
265 let mut e = self.0.e().to_bytes_be();
266 if e[0] & 0b10000000 != 0 {
268 e.insert(0, 0_u8);
269 }
270 let e = BerObject::from_int_slice(&e);
271
272 let asn1 = BerObject::from_seq(vec![n, e]);
273 asn1.to_vec().expect("RSA key not encodable as DER")
274 }
275
276 pub fn to_rsa_identity(&self) -> RsaIdentity {
278 use crate::d::Sha1;
279 use digest::Digest;
280 let id: [u8; RSA_ID_LEN] = Sha1::digest(self.to_der()).into();
281 RsaIdentity { id: id.into() }
282 }
283}
284
285pub struct ValidatableRsaSignature {
287 key: PublicKey,
289 sig: Vec<u8>,
291 expected_hash: Vec<u8>,
293}
294
295impl ValidatableRsaSignature {
296 pub fn new(key: &PublicKey, sig: &[u8], expected_hash: &[u8]) -> Self {
298 ValidatableRsaSignature {
299 key: key.clone(),
300 sig: sig.into(),
301 expected_hash: expected_hash.into(),
302 }
303 }
304}
305
306impl super::ValidatableSignature for ValidatableRsaSignature {
307 fn is_valid(&self) -> bool {
308 self.key
309 .verify(&self.expected_hash[..], &self.sig[..])
310 .is_ok()
311 }
312}