crypto_bigint/
non_zero.rs

1//! Wrapper type for non-zero integers.
2
3use crate::{CtChoice, Encoding, Integer, Limb, Uint, Zero};
4use core::{
5    fmt,
6    num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8},
7    ops::Deref,
8};
9use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11#[cfg(feature = "generic-array")]
12use crate::{ArrayEncoding, ByteArray};
13
14#[cfg(feature = "rand_core")]
15use {crate::Random, rand_core::CryptoRngCore};
16
17#[cfg(feature = "serde")]
18use serdect::serde::{
19    de::{Error, Unexpected},
20    Deserialize, Deserializer, Serialize, Serializer,
21};
22
23/// Wrapper type for non-zero integers.
24#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
25pub struct NonZero<T: Zero>(T);
26
27impl NonZero<Limb> {
28    /// Creates a new non-zero limb in a const context.
29    /// The second return value is `FALSE` if `n` is zero, `TRUE` otherwise.
30    pub const fn const_new(n: Limb) -> (Self, CtChoice) {
31        (Self(n), n.ct_is_nonzero())
32    }
33}
34
35impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
36    /// Creates a new non-zero integer in a const context.
37    /// The second return value is `FALSE` if `n` is zero, `TRUE` otherwise.
38    pub const fn const_new(n: Uint<LIMBS>) -> (Self, CtChoice) {
39        (Self(n), n.ct_is_nonzero())
40    }
41}
42
43impl<T> NonZero<T>
44where
45    T: Zero,
46{
47    /// Create a new non-zero integer.
48    pub fn new(n: T) -> CtOption<Self> {
49        let is_zero = n.is_zero();
50        CtOption::new(Self(n), !is_zero)
51    }
52}
53
54impl<T> NonZero<T>
55where
56    T: Integer,
57{
58    /// The value `1`.
59    pub const ONE: Self = Self(T::ONE);
60
61    /// Maximum value this integer can express.
62    pub const MAX: Self = Self(T::MAX);
63}
64
65impl<T> NonZero<T>
66where
67    T: Encoding + Zero,
68{
69    /// Decode from big endian bytes.
70    pub fn from_be_bytes(bytes: T::Repr) -> CtOption<Self> {
71        Self::new(T::from_be_bytes(bytes))
72    }
73
74    /// Decode from little endian bytes.
75    pub fn from_le_bytes(bytes: T::Repr) -> CtOption<Self> {
76        Self::new(T::from_le_bytes(bytes))
77    }
78}
79
#[cfg(feature = "generic-array")]
impl<T> NonZero<T>
where
    T: ArrayEncoding + Zero,
{
    /// Decode a non-zero integer from big endian bytes.
    ///
    /// The bytes are decoded with `T::from_be_byte_array` and the result is
    /// passed through [`NonZero::new`], so the returned [`CtOption`] is "none"
    /// when the decoded value is zero.
    pub fn from_be_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        Self::new(T::from_be_byte_array(bytes))
    }

    /// Decode a non-zero integer from little endian bytes.
    ///
    /// The bytes are decoded with `T::from_le_byte_array` and the result is
    /// passed through [`NonZero::new`], so the returned [`CtOption`] is "none"
    /// when the decoded value is zero.
    pub fn from_le_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        // Must use the little endian decoder; the big endian one would
        // interpret the input with its byte order reversed.
        Self::new(T::from_le_byte_array(bytes))
    }
}
95
96impl<T> AsRef<T> for NonZero<T>
97where
98    T: Zero,
99{
100    fn as_ref(&self) -> &T {
101        &self.0
102    }
103}
104
105impl<T> ConditionallySelectable for NonZero<T>
106where
107    T: ConditionallySelectable + Zero,
108{
109    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
110        Self(T::conditional_select(&a.0, &b.0, choice))
111    }
112}
113
114impl<T> ConstantTimeEq for NonZero<T>
115where
116    T: Zero,
117{
118    fn ct_eq(&self, other: &Self) -> Choice {
119        self.0.ct_eq(&other.0)
120    }
121}
122
123impl<T> Deref for NonZero<T>
124where
125    T: Zero,
126{
127    type Target = T;
128
129    fn deref(&self) -> &T {
130        &self.0
131    }
132}
133
#[cfg(feature = "rand_core")]
impl<T> Random for NonZero<T>
where
    T: Random + Zero,
{
    /// Generate a random `NonZero<T>`.
    ///
    /// Draws values with `T::random` until one passes the non-zero check of
    /// [`NonZero::new`].
    fn random(mut rng: &mut impl CryptoRngCore) -> Self {
        // Rejection sampling: zero draws are discarded and a new value is
        // sampled. This is not constant-time, but the attacker shouldn't
        // learn anything about unrelated outputs so long as `rng` is a CSRNG.
        loop {
            let candidate: Option<Self> = Self::new(T::random(&mut rng)).into();
            match candidate {
                Some(value) => return value,
                None => continue,
            }
        }
    }
}
151
152impl<T> fmt::Display for NonZero<T>
153where
154    T: fmt::Display + Zero,
155{
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157        fmt::Display::fmt(&self.0, f)
158    }
159}
160
161impl<T> fmt::Binary for NonZero<T>
162where
163    T: fmt::Binary + Zero,
164{
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        fmt::Binary::fmt(&self.0, f)
167    }
168}
169
170impl<T> fmt::Octal for NonZero<T>
171where
172    T: fmt::Octal + Zero,
173{
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        fmt::Octal::fmt(&self.0, f)
176    }
177}
178
179impl<T> fmt::LowerHex for NonZero<T>
180where
181    T: fmt::LowerHex + Zero,
182{
183    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184        fmt::LowerHex::fmt(&self.0, f)
185    }
186}
187
188impl<T> fmt::UpperHex for NonZero<T>
189where
190    T: fmt::UpperHex + Zero,
191{
192    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193        fmt::UpperHex::fmt(&self.0, f)
194    }
195}
196
197impl NonZero<Limb> {
198    /// Create a [`NonZero<Limb>`] from a [`NonZeroU8`] (const-friendly)
199    // TODO(tarcieri): replace with `const impl From<NonZeroU8>` when stable
200    pub const fn from_u8(n: NonZeroU8) -> Self {
201        Self(Limb::from_u8(n.get()))
202    }
203
204    /// Create a [`NonZero<Limb>`] from a [`NonZeroU16`] (const-friendly)
205    // TODO(tarcieri): replace with `const impl From<NonZeroU16>` when stable
206    pub const fn from_u16(n: NonZeroU16) -> Self {
207        Self(Limb::from_u16(n.get()))
208    }
209
210    /// Create a [`NonZero<Limb>`] from a [`NonZeroU32`] (const-friendly)
211    // TODO(tarcieri): replace with `const impl From<NonZeroU32>` when stable
212    pub const fn from_u32(n: NonZeroU32) -> Self {
213        Self(Limb::from_u32(n.get()))
214    }
215
216    /// Create a [`NonZero<Limb>`] from a [`NonZeroU64`] (const-friendly)
217    // TODO(tarcieri): replace with `const impl From<NonZeroU64>` when stable
218    #[cfg(target_pointer_width = "64")]
219    pub const fn from_u64(n: NonZeroU64) -> Self {
220        Self(Limb::from_u64(n.get()))
221    }
222}
223
224impl From<NonZeroU8> for NonZero<Limb> {
225    fn from(integer: NonZeroU8) -> Self {
226        Self::from_u8(integer)
227    }
228}
229
230impl From<NonZeroU16> for NonZero<Limb> {
231    fn from(integer: NonZeroU16) -> Self {
232        Self::from_u16(integer)
233    }
234}
235
236impl From<NonZeroU32> for NonZero<Limb> {
237    fn from(integer: NonZeroU32) -> Self {
238        Self::from_u32(integer)
239    }
240}
241
242#[cfg(target_pointer_width = "64")]
243impl From<NonZeroU64> for NonZero<Limb> {
244    fn from(integer: NonZeroU64) -> Self {
245        Self::from_u64(integer)
246    }
247}
248
249impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
250    /// Create a [`NonZero<Uint>`] from a [`Uint`] (const-friendly)
251    pub const fn from_uint(n: Uint<LIMBS>) -> Self {
252        let mut i = 0;
253        let mut found_non_zero = false;
254        while i < LIMBS {
255            if n.as_limbs()[i].0 != 0 {
256                found_non_zero = true;
257            }
258            i += 1;
259        }
260        assert!(found_non_zero, "found zero");
261        Self(n)
262    }
263
264    /// Create a [`NonZero<Uint>`] from a [`NonZeroU8`] (const-friendly)
265    // TODO(tarcieri): replace with `const impl From<NonZeroU8>` when stable
266    pub const fn from_u8(n: NonZeroU8) -> Self {
267        Self(Uint::from_u8(n.get()))
268    }
269
270    /// Create a [`NonZero<Uint>`] from a [`NonZeroU16`] (const-friendly)
271    // TODO(tarcieri): replace with `const impl From<NonZeroU16>` when stable
272    pub const fn from_u16(n: NonZeroU16) -> Self {
273        Self(Uint::from_u16(n.get()))
274    }
275
276    /// Create a [`NonZero<Uint>`] from a [`NonZeroU32`] (const-friendly)
277    // TODO(tarcieri): replace with `const impl From<NonZeroU32>` when stable
278    pub const fn from_u32(n: NonZeroU32) -> Self {
279        Self(Uint::from_u32(n.get()))
280    }
281
282    /// Create a [`NonZero<Uint>`] from a [`NonZeroU64`] (const-friendly)
283    // TODO(tarcieri): replace with `const impl From<NonZeroU64>` when stable
284    pub const fn from_u64(n: NonZeroU64) -> Self {
285        Self(Uint::from_u64(n.get()))
286    }
287
288    /// Create a [`NonZero<Uint>`] from a [`NonZeroU128`] (const-friendly)
289    // TODO(tarcieri): replace with `const impl From<NonZeroU128>` when stable
290    pub const fn from_u128(n: NonZeroU128) -> Self {
291        Self(Uint::from_u128(n.get()))
292    }
293}
294
295impl<const LIMBS: usize> From<NonZeroU8> for NonZero<Uint<LIMBS>> {
296    fn from(integer: NonZeroU8) -> Self {
297        Self::from_u8(integer)
298    }
299}
300
301impl<const LIMBS: usize> From<NonZeroU16> for NonZero<Uint<LIMBS>> {
302    fn from(integer: NonZeroU16) -> Self {
303        Self::from_u16(integer)
304    }
305}
306
307impl<const LIMBS: usize> From<NonZeroU32> for NonZero<Uint<LIMBS>> {
308    fn from(integer: NonZeroU32) -> Self {
309        Self::from_u32(integer)
310    }
311}
312
313impl<const LIMBS: usize> From<NonZeroU64> for NonZero<Uint<LIMBS>> {
314    fn from(integer: NonZeroU64) -> Self {
315        Self::from_u64(integer)
316    }
317}
318
319impl<const LIMBS: usize> From<NonZeroU128> for NonZero<Uint<LIMBS>> {
320    fn from(integer: NonZeroU128) -> Self {
321        Self::from_u128(integer)
322    }
323}
324
#[cfg(feature = "serde")]
impl<'de, T: Deserialize<'de> + Zero> Deserialize<'de> for NonZero<T> {
    /// Deserializes a `T` and rejects it with an `invalid_value` error when
    /// it is zero.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let parsed = T::deserialize(deserializer)?;

        // `Self::new` performs the zero check; a "none" result becomes the
        // deserialization error.
        Option::<Self>::from(Self::new(parsed)).ok_or_else(|| {
            D::Error::invalid_value(Unexpected::Other("zero"), &"a non-zero value")
        })
    }
}
343
#[cfg(feature = "serde")]
impl<T: Serialize + Zero> Serialize for NonZero<T> {
    /// Serializes the wrapped value using `T`'s own `Serialize` impl; the
    /// wrapper itself adds nothing to the output.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Self(inner) = self;
        inner.serialize(serializer)
    }
}
353
#[cfg(all(test, feature = "serde"))]
#[allow(clippy::unwrap_used)]
mod tests {
    use crate::{NonZero, U64};
    use bincode::ErrorKind;

    /// Error text expected when a serialized zero is decoded as a `NonZero`.
    const ZERO_REJECTED: &str = "invalid value: zero, expected a non-zero value";

    /// Non-zero fixture shared by both round-trip tests.
    fn sample() -> NonZero<U64> {
        Option::<NonZero<U64>>::from(NonZero::new(U64::from_u64(0x0011223344556677))).unwrap()
    }

    /// Round-trips through the slice-based bincode API and checks that a
    /// serialized zero is rejected.
    #[test]
    fn serde() {
        let original = sample();

        let encoded = bincode::serialize(&original).unwrap();
        let decoded: NonZero<U64> = bincode::deserialize(&encoded).unwrap();
        assert_eq!(original, decoded);

        let encoded_zero = bincode::serialize(&U64::ZERO).unwrap();
        let err = bincode::deserialize::<NonZero<U64>>(&encoded_zero).unwrap_err();
        assert!(matches!(
            *err,
            ErrorKind::Custom(message) if message == ZERO_REJECTED
        ));
    }

    /// Same checks as `serde`, but through the reader-based bincode API.
    #[test]
    fn serde_owned() {
        let original = sample();

        let encoded = bincode::serialize(&original).unwrap();
        let decoded: NonZero<U64> = bincode::deserialize_from(encoded.as_slice()).unwrap();
        assert_eq!(original, decoded);

        let encoded_zero = bincode::serialize(&U64::ZERO).unwrap();
        let err =
            bincode::deserialize_from::<_, NonZero<U64>>(encoded_zero.as_slice()).unwrap_err();
        assert!(matches!(
            *err,
            ErrorKind::Custom(message) if message == ZERO_REJECTED
        ));
    }
}