crypto_bigint/uint/
sqrt.rs

1//! [`Uint`] square root operations.
2
3use super::Uint;
4use crate::{Limb, Word};
5use subtle::{ConstantTimeEq, CtOption};
6
7impl<const LIMBS: usize> Uint<LIMBS> {
8    /// See [`Self::sqrt_vartime`].
9    #[deprecated(
10        since = "0.5.3",
11        note = "This functionality will be moved to `sqrt_vartime` in a future release."
12    )]
13    pub const fn sqrt(&self) -> Self {
14        self.sqrt_vartime()
15    }
16
17    /// Computes √(`self`)
18    /// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
19    ///
20    /// Callers can check if `self` is a square by squaring the result
21    pub const fn sqrt_vartime(&self) -> Self {
22        let max_bits = (self.bits_vartime() + 1) >> 1;
23        let cap = Self::ONE.shl_vartime(max_bits);
24        let mut guess = cap; // ≥ √(`self`)
25        let mut xn = {
26            let q = self.wrapping_div(&guess);
27            let t = guess.wrapping_add(&q);
28            t.shr_vartime(1)
29        };
30
31        // If guess increased, the initial guess was low.
32        // Repeat until reverse course.
33        while Uint::ct_lt(&guess, &xn).is_true_vartime() {
34            // Sometimes an increase is too far, especially with large
35            // powers, and then takes a long time to walk back.  The upper
36            // bound is based on bit size, so saturate on that.
37            let le = Limb::ct_le(Limb(xn.bits_vartime() as Word), Limb(max_bits as Word));
38            guess = Self::ct_select(&cap, &xn, le);
39            xn = {
40                let q = self.wrapping_div(&guess);
41                let t = guess.wrapping_add(&q);
42                t.shr_vartime(1)
43            };
44        }
45
46        // Repeat while guess decreases.
47        while Uint::ct_gt(&guess, &xn).is_true_vartime() && xn.ct_is_nonzero().is_true_vartime() {
48            guess = xn;
49            xn = {
50                let q = self.wrapping_div(&guess);
51                let t = guess.wrapping_add(&q);
52                t.shr_vartime(1)
53            };
54        }
55
56        Self::ct_select(&Self::ZERO, &guess, self.ct_is_nonzero())
57    }
58
59    /// See [`Self::wrapping_sqrt_vartime`].
60    #[deprecated(
61        since = "0.5.3",
62        note = "This functionality will be moved to `wrapping_sqrt_vartime` in a future release."
63    )]
64    pub const fn wrapping_sqrt(&self) -> Self {
65        self.wrapping_sqrt_vartime()
66    }
67
68    /// Wrapped sqrt is just normal √(`self`)
69    /// There’s no way wrapping could ever happen.
70    /// This function exists, so that all operations are accounted for in the wrapping operations.
71    pub const fn wrapping_sqrt_vartime(&self) -> Self {
72        self.sqrt_vartime()
73    }
74
75    /// See [`Self::checked_sqrt_vartime`].
76    #[deprecated(
77        since = "0.5.3",
78        note = "This functionality will be moved to `checked_sqrt_vartime` in a future release."
79    )]
80    pub fn checked_sqrt(&self) -> CtOption<Self> {
81        self.checked_sqrt_vartime()
82    }
83
84    /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
85    /// only if the √(`self`)² == self
86    pub fn checked_sqrt_vartime(&self) -> CtOption<Self> {
87        let r = self.sqrt_vartime();
88        let s = r.wrapping_mul(&r);
89        CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
90    }
91}
92
#[cfg(test)]
mod tests {
    use crate::{Limb, U256};

    #[cfg(feature = "rand")]
    use {
        crate::{CheckedMul, Random, U512},
        rand_chacha::ChaChaRng,
        rand_core::{RngCore, SeedableRng},
    };

    /// Returns `2^128 - 1` (the low half of the limbs set to all ones), which is ⌊√(`U256::MAX`)⌋.
    fn half_max() -> U256 {
        let mut half = U256::ZERO;
        let low_limbs = half.limbs.len() / 2;
        half.limbs[..low_limbs].fill(Limb::MAX);
        half
    }

    #[test]
    fn edge() {
        assert_eq!(U256::ZERO.sqrt_vartime(), U256::ZERO);
        assert_eq!(U256::ONE.sqrt_vartime(), U256::ONE);
        assert_eq!(U256::MAX.sqrt_vartime(), half_max());
    }

    #[test]
    fn edge_checked() {
        // 0 = 0² and 1 = 1² are perfect squares.
        assert_eq!(U256::ZERO.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
        assert_eq!(U256::ONE.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
        // MAX = (2^128 - 1)(2^128 + 1) is not a perfect square.
        assert_eq!(U256::MAX.checked_sqrt_vartime().is_some().unwrap_u8(), 0);

        // (2^128 - 1)² is the largest perfect square representable in a U256.
        let half = half_max();
        let top_square = half.wrapping_mul(&half);
        assert_eq!(top_square.sqrt_vartime(), half);
        assert_eq!(top_square.checked_sqrt_vartime().is_some().unwrap_u8(), 1);

        // One less than it is not a square and must round down.
        let below = top_square.wrapping_sub(&U256::ONE);
        assert_eq!(below.sqrt_vartime(), half.wrapping_sub(&U256::ONE));
        assert_eq!(below.checked_sqrt_vartime().is_some().unwrap_u8(), 0);
    }

    #[test]
    fn powers_of_two() {
        // Exercises even bit lengths, where the initial guess 2^⌈b/2⌉ overshoots the most.
        for k in 0..128 {
            let root = U256::ONE.shl_vartime(k);
            let square = U256::ONE.shl_vartime(2 * k);
            assert_eq!(square.sqrt_vartime(), root);
            assert_eq!(square.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
        }
    }

    #[test]
    fn simple() {
        let tests = [
            (4u8, 2u8),
            (9, 3),
            (16, 4),
            (25, 5),
            (36, 6),
            (49, 7),
            (64, 8),
            (81, 9),
            (100, 10),
            (121, 11),
            (144, 12),
            (169, 13),
        ];
        for (a, e) in &tests {
            let l = U256::from(*a);
            let r = U256::from(*e);
            assert_eq!(l.sqrt_vartime(), r);
            assert_eq!(l.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8);
        }
    }

    #[test]
    fn nonsquares() {
        assert_eq!(U256::from(2u8).sqrt_vartime(), U256::from(1u8));
        assert_eq!(
            U256::from(2u8).checked_sqrt_vartime().is_some().unwrap_u8(),
            0
        );
        assert_eq!(U256::from(3u8).sqrt_vartime(), U256::from(1u8));
        assert_eq!(
            U256::from(3u8).checked_sqrt_vartime().is_some().unwrap_u8(),
            0
        );
        assert_eq!(U256::from(5u8).sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(6u8).sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(7u8).sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(8u8).sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(10u8).sqrt_vartime(), U256::from(3u8));
    }

    #[test]
    // The deprecated aliases are exercised on purpose: they must keep forwarding
    // to their `_vartime` counterparts until they are removed.
    #[allow(deprecated)]
    fn deprecated_aliases_match() {
        for n in [0u8, 1, 2, 15, 16, 99, 255] {
            let x = U256::from(n);
            assert_eq!(x.sqrt(), x.sqrt_vartime());
            assert_eq!(x.wrapping_sqrt(), x.wrapping_sqrt_vartime());
            assert_eq!(
                x.checked_sqrt().is_some().unwrap_u8(),
                x.checked_sqrt_vartime().is_some().unwrap_u8()
            );
        }
    }

    #[cfg(feature = "rand")]
    #[test]
    fn fuzz() {
        let mut rng = ChaChaRng::from_seed([7u8; 32]);
        for _ in 0..50 {
            let t = rng.next_u32() as u64;
            let s = U256::from(t);
            let s2 = s.checked_mul(&s).unwrap();
            assert_eq!(s2.sqrt_vartime(), s);
            assert_eq!(s2.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
        }

        for _ in 0..50 {
            let s = U256::random(&mut rng);
            let mut s2 = U512::ZERO;
            s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
            assert_eq!(s.square().sqrt_vartime(), s2);
        }
    }
}