crypto_bigint/
non_zero.rs1use crate::{CtChoice, Encoding, Integer, Limb, Uint, Zero};
4use core::{
5 fmt,
6 num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8},
7 ops::Deref,
8};
9use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11#[cfg(feature = "generic-array")]
12use crate::{ArrayEncoding, ByteArray};
13
14#[cfg(feature = "rand_core")]
15use {crate::Random, rand_core::CryptoRngCore};
16
17#[cfg(feature = "serde")]
18use serdect::serde::{
19 de::{Error, Unexpected},
20 Deserialize, Deserializer, Serialize, Serializer,
21};
22
23#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
25pub struct NonZero<T: Zero>(T);
26
27impl NonZero<Limb> {
28 pub const fn const_new(n: Limb) -> (Self, CtChoice) {
31 (Self(n), n.ct_is_nonzero())
32 }
33}
34
35impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
36 pub const fn const_new(n: Uint<LIMBS>) -> (Self, CtChoice) {
39 (Self(n), n.ct_is_nonzero())
40 }
41}
42
43impl<T> NonZero<T>
44where
45 T: Zero,
46{
47 pub fn new(n: T) -> CtOption<Self> {
49 let is_zero = n.is_zero();
50 CtOption::new(Self(n), !is_zero)
51 }
52}
53
54impl<T> NonZero<T>
55where
56 T: Integer,
57{
58 pub const ONE: Self = Self(T::ONE);
60
61 pub const MAX: Self = Self(T::MAX);
63}
64
65impl<T> NonZero<T>
66where
67 T: Encoding + Zero,
68{
69 pub fn from_be_bytes(bytes: T::Repr) -> CtOption<Self> {
71 Self::new(T::from_be_bytes(bytes))
72 }
73
74 pub fn from_le_bytes(bytes: T::Repr) -> CtOption<Self> {
76 Self::new(T::from_le_bytes(bytes))
77 }
78}
79
#[cfg(feature = "generic-array")]
impl<T> NonZero<T>
where
    T: ArrayEncoding + Zero,
{
    /// Decodes a big-endian byte array.
    ///
    /// Returns a none [`CtOption`] if the decoded value is zero.
    pub fn from_be_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        Self::new(T::from_be_byte_array(bytes))
    }

    /// Decodes a little-endian byte array.
    ///
    /// Returns a none [`CtOption`] if the decoded value is zero.
    pub fn from_le_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        // Decode as little-endian. Using the big-endian decoder here would
        // silently reverse the byte order of the caller's input.
        Self::new(T::from_le_byte_array(bytes))
    }
}
95
96impl<T> AsRef<T> for NonZero<T>
97where
98 T: Zero,
99{
100 fn as_ref(&self) -> &T {
101 &self.0
102 }
103}
104
105impl<T> ConditionallySelectable for NonZero<T>
106where
107 T: ConditionallySelectable + Zero,
108{
109 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
110 Self(T::conditional_select(&a.0, &b.0, choice))
111 }
112}
113
114impl<T> ConstantTimeEq for NonZero<T>
115where
116 T: Zero,
117{
118 fn ct_eq(&self, other: &Self) -> Choice {
119 self.0.ct_eq(&other.0)
120 }
121}
122
123impl<T> Deref for NonZero<T>
124where
125 T: Zero,
126{
127 type Target = T;
128
129 fn deref(&self) -> &T {
130 &self.0
131 }
132}
133
#[cfg(feature = "rand_core")]
impl<T> Random for NonZero<T>
where
    T: Random + Zero,
{
    /// Generates a random non-zero value.
    ///
    /// Draws a fresh `T` from `rng` and retries for as long as the draw is zero.
    fn random(mut rng: &mut impl CryptoRngCore) -> Self {
        loop {
            let candidate = T::random(&mut rng);
            let maybe: Option<Self> = Self::new(candidate).into();
            if let Some(nonzero) = maybe {
                return nonzero;
            }
        }
    }
}
151
152impl<T> fmt::Display for NonZero<T>
153where
154 T: fmt::Display + Zero,
155{
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 fmt::Display::fmt(&self.0, f)
158 }
159}
160
161impl<T> fmt::Binary for NonZero<T>
162where
163 T: fmt::Binary + Zero,
164{
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 fmt::Binary::fmt(&self.0, f)
167 }
168}
169
170impl<T> fmt::Octal for NonZero<T>
171where
172 T: fmt::Octal + Zero,
173{
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175 fmt::Octal::fmt(&self.0, f)
176 }
177}
178
179impl<T> fmt::LowerHex for NonZero<T>
180where
181 T: fmt::LowerHex + Zero,
182{
183 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184 fmt::LowerHex::fmt(&self.0, f)
185 }
186}
187
188impl<T> fmt::UpperHex for NonZero<T>
189where
190 T: fmt::UpperHex + Zero,
191{
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193 fmt::UpperHex::fmt(&self.0, f)
194 }
195}
196
197impl NonZero<Limb> {
198 pub const fn from_u8(n: NonZeroU8) -> Self {
201 Self(Limb::from_u8(n.get()))
202 }
203
204 pub const fn from_u16(n: NonZeroU16) -> Self {
207 Self(Limb::from_u16(n.get()))
208 }
209
210 pub const fn from_u32(n: NonZeroU32) -> Self {
213 Self(Limb::from_u32(n.get()))
214 }
215
216 #[cfg(target_pointer_width = "64")]
219 pub const fn from_u64(n: NonZeroU64) -> Self {
220 Self(Limb::from_u64(n.get()))
221 }
222}
223
224impl From<NonZeroU8> for NonZero<Limb> {
225 fn from(integer: NonZeroU8) -> Self {
226 Self::from_u8(integer)
227 }
228}
229
230impl From<NonZeroU16> for NonZero<Limb> {
231 fn from(integer: NonZeroU16) -> Self {
232 Self::from_u16(integer)
233 }
234}
235
236impl From<NonZeroU32> for NonZero<Limb> {
237 fn from(integer: NonZeroU32) -> Self {
238 Self::from_u32(integer)
239 }
240}
241
242#[cfg(target_pointer_width = "64")]
243impl From<NonZeroU64> for NonZero<Limb> {
244 fn from(integer: NonZeroU64) -> Self {
245 Self::from_u64(integer)
246 }
247}
248
249impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
250 pub const fn from_uint(n: Uint<LIMBS>) -> Self {
252 let mut i = 0;
253 let mut found_non_zero = false;
254 while i < LIMBS {
255 if n.as_limbs()[i].0 != 0 {
256 found_non_zero = true;
257 }
258 i += 1;
259 }
260 assert!(found_non_zero, "found zero");
261 Self(n)
262 }
263
264 pub const fn from_u8(n: NonZeroU8) -> Self {
267 Self(Uint::from_u8(n.get()))
268 }
269
270 pub const fn from_u16(n: NonZeroU16) -> Self {
273 Self(Uint::from_u16(n.get()))
274 }
275
276 pub const fn from_u32(n: NonZeroU32) -> Self {
279 Self(Uint::from_u32(n.get()))
280 }
281
282 pub const fn from_u64(n: NonZeroU64) -> Self {
285 Self(Uint::from_u64(n.get()))
286 }
287
288 pub const fn from_u128(n: NonZeroU128) -> Self {
291 Self(Uint::from_u128(n.get()))
292 }
293}
294
295impl<const LIMBS: usize> From<NonZeroU8> for NonZero<Uint<LIMBS>> {
296 fn from(integer: NonZeroU8) -> Self {
297 Self::from_u8(integer)
298 }
299}
300
301impl<const LIMBS: usize> From<NonZeroU16> for NonZero<Uint<LIMBS>> {
302 fn from(integer: NonZeroU16) -> Self {
303 Self::from_u16(integer)
304 }
305}
306
307impl<const LIMBS: usize> From<NonZeroU32> for NonZero<Uint<LIMBS>> {
308 fn from(integer: NonZeroU32) -> Self {
309 Self::from_u32(integer)
310 }
311}
312
313impl<const LIMBS: usize> From<NonZeroU64> for NonZero<Uint<LIMBS>> {
314 fn from(integer: NonZeroU64) -> Self {
315 Self::from_u64(integer)
316 }
317}
318
319impl<const LIMBS: usize> From<NonZeroU128> for NonZero<Uint<LIMBS>> {
320 fn from(integer: NonZeroU128) -> Self {
321 Self::from_u128(integer)
322 }
323}
324
#[cfg(feature = "serde")]
impl<'de, T: Deserialize<'de> + Zero> Deserialize<'de> for NonZero<T> {
    /// Deserializes a `T`. Returns an `invalid_value` error if it is zero;
    /// otherwise wraps it.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value = T::deserialize(deserializer)?;
        let is_zero: bool = value.is_zero().into();
        if is_zero {
            return Err(D::Error::invalid_value(
                Unexpected::Other("zero"),
                &"a non-zero value",
            ));
        }
        Ok(Self(value))
    }
}
343
#[cfg(feature = "serde")]
impl<T: Serialize + Zero> Serialize for NonZero<T> {
    /// Serializes exactly as the wrapped `T` would, with no extra framing.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        <T as Serialize>::serialize(&self.0, serializer)
    }
}
353
#[cfg(all(test, feature = "serde"))]
#[allow(clippy::unwrap_used)]
mod tests {
    use crate::{NonZero, U64};
    use bincode::ErrorKind;

    /// Error text serde produces when a zero is deserialized into `NonZero`.
    const ZERO_REJECTED: &str = "invalid value: zero, expected a non-zero value";

    /// Fixed non-zero value shared by the round-trip tests.
    fn sample() -> NonZero<U64> {
        Option::<NonZero<U64>>::from(NonZero::new(U64::from_u64(0x0011223344556677))).unwrap()
    }

    /// Returns true if `err` is the custom error emitted for a zero input.
    fn is_zero_rejection(err: &ErrorKind) -> bool {
        matches!(err, ErrorKind::Custom(message) if message == ZERO_REJECTED)
    }

    #[test]
    fn serde() {
        let original = sample();
        let bytes = bincode::serialize(&original).unwrap();
        let decoded: NonZero<U64> = bincode::deserialize(&bytes).unwrap();
        assert_eq!(original, decoded);

        let zero_bytes = bincode::serialize(&U64::ZERO).unwrap();
        let err = bincode::deserialize::<NonZero<U64>>(&zero_bytes).unwrap_err();
        assert!(is_zero_rejection(&err));
    }

    #[test]
    fn serde_owned() {
        let original = sample();
        let bytes = bincode::serialize(&original).unwrap();
        let decoded: NonZero<U64> = bincode::deserialize_from(bytes.as_slice()).unwrap();
        assert_eq!(original, decoded);

        let zero_bytes = bincode::serialize(&U64::ZERO).unwrap();
        let err =
            bincode::deserialize_from::<_, NonZero<U64>>(zero_bytes.as_slice()).unwrap_err();
        assert!(is_zero_rejection(&err));
    }
}