curve25519_dalek/backend/serial/u64/
scalar.rs

//! Arithmetic mod \\(2\^{252} + 27742317777372353535851937790883648493\\)
//! with five \\(52\\)-bit unsigned limbs.
//!
//! \\(51\\)-bit limbs would cover the desired bit range (\\(253\\)
//! bits), but aren't large enough to reduce a \\(512\\)-bit number with
//! Montgomery multiplication, so \\(52\\) bits are used instead.  To see
//! that this is safe for intermediate results, note that the largest
//! limb in a \\(5\times 5\\) product of \\(52\\)-bit limbs will be
//!
//! ```text
//! (0xfffffffffffff^2) * 5 = 0x4ffffffffffff60000000000005 (107 bits).
//! ```
13
14use core::fmt::Debug;
15use core::ops::{Index, IndexMut};
16
17#[cfg(feature = "zeroize")]
18use zeroize::Zeroize;
19
20use crate::constants;
21
/// An element of \\(\mathbb Z / \ell \mathbb Z\\), stored as five `u64`
/// limbs that each carry up to \\(52\\) bits of the value.
#[derive(Clone, Copy)]
pub struct Scalar52(pub [u64; 5]);
26
27impl Debug for Scalar52 {
28    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
29        write!(f, "Scalar52: {:?}", &self.0[..])
30    }
31}
32
#[cfg(feature = "zeroize")]
impl Zeroize for Scalar52 {
    /// Wipes the scalar by zeroizing its limb array.
    fn zeroize(&mut self) {
        let Scalar52(limbs) = self;
        limbs.zeroize();
    }
}
39
40impl Index<usize> for Scalar52 {
41    type Output = u64;
42    fn index(&self, _index: usize) -> &u64 {
43        &(self.0[_index])
44    }
45}
46
47impl IndexMut<usize> for Scalar52 {
48    fn index_mut(&mut self, _index: usize) -> &mut u64 {
49        &mut (self.0[_index])
50    }
51}
52
/// Widening multiply: the full 128-bit product of two `u64` values.
///
/// Both operands are widened before multiplying, so the product cannot
/// overflow.
#[inline(always)]
fn m(x: u64, y: u64) -> u128 {
    let wide_x = u128::from(x);
    let wide_y = u128::from(y);
    wide_x * wide_y
}
58
59impl Scalar52 {
60    /// The scalar \\( 0 \\).
61    pub const ZERO: Scalar52 = Scalar52([0, 0, 0, 0, 0]);
62
63    /// Unpack a 32 byte / 256 bit scalar into 5 52-bit limbs.
64    #[rustfmt::skip] // keep alignment of s[*] calculations
65    pub fn from_bytes(bytes: &[u8; 32]) -> Scalar52 {
66        let mut words = [0u64; 4];
67        for i in 0..4 {
68            for j in 0..8 {
69                words[i] |= (bytes[(i * 8) + j] as u64) << (j * 8);
70            }
71        }
72
73        let mask = (1u64 << 52) - 1;
74        let top_mask = (1u64 << 48) - 1;
75        let mut s = Scalar52::ZERO;
76
77        s[0] =   words[0]                            & mask;
78        s[1] = ((words[0] >> 52) | (words[1] << 12)) & mask;
79        s[2] = ((words[1] >> 40) | (words[2] << 24)) & mask;
80        s[3] = ((words[2] >> 28) | (words[3] << 36)) & mask;
81        s[4] =  (words[3] >> 16)                     & top_mask;
82
83        s
84    }
85
86    /// Reduce a 64 byte / 512 bit scalar mod l
87    #[rustfmt::skip] // keep alignment of lo[*] and hi[*] calculations
88    pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar52 {
89        let mut words = [0u64; 8];
90        for i in 0..8 {
91            for j in 0..8 {
92                words[i] |= (bytes[(i * 8) + j] as u64) << (j * 8);
93            }
94        }
95
96        let mask = (1u64 << 52) - 1;
97        let mut lo = Scalar52::ZERO;
98        let mut hi = Scalar52::ZERO;
99
100        lo[0] =   words[0]                             & mask;
101        lo[1] = ((words[0] >> 52) | (words[ 1] << 12)) & mask;
102        lo[2] = ((words[1] >> 40) | (words[ 2] << 24)) & mask;
103        lo[3] = ((words[2] >> 28) | (words[ 3] << 36)) & mask;
104        lo[4] = ((words[3] >> 16) | (words[ 4] << 48)) & mask;
105        hi[0] =  (words[4] >>  4)                      & mask;
106        hi[1] = ((words[4] >> 56) | (words[ 5] <<  8)) & mask;
107        hi[2] = ((words[5] >> 44) | (words[ 6] << 20)) & mask;
108        hi[3] = ((words[6] >> 32) | (words[ 7] << 32)) & mask;
109        hi[4] =   words[7] >> 20                             ;
110
111        lo = Scalar52::montgomery_mul(&lo, &constants::R);  // (lo * R) / R = lo
112        hi = Scalar52::montgomery_mul(&hi, &constants::RR); // (hi * R^2) / R = hi * R
113
114        Scalar52::add(&hi, &lo)
115    }
116
117    /// Pack the limbs of this `Scalar52` into 32 bytes
118    #[rustfmt::skip] // keep alignment of s[*] calculations
119    #[allow(clippy::identity_op)]
120    pub fn as_bytes(&self) -> [u8; 32] {
121        let mut s = [0u8; 32];
122
123        s[ 0] =  (self.0[ 0] >>  0)                      as u8;
124        s[ 1] =  (self.0[ 0] >>  8)                      as u8;
125        s[ 2] =  (self.0[ 0] >> 16)                      as u8;
126        s[ 3] =  (self.0[ 0] >> 24)                      as u8;
127        s[ 4] =  (self.0[ 0] >> 32)                      as u8;
128        s[ 5] =  (self.0[ 0] >> 40)                      as u8;
129        s[ 6] = ((self.0[ 0] >> 48) | (self.0[ 1] << 4)) as u8;
130        s[ 7] =  (self.0[ 1] >>  4)                      as u8;
131        s[ 8] =  (self.0[ 1] >> 12)                      as u8;
132        s[ 9] =  (self.0[ 1] >> 20)                      as u8;
133        s[10] =  (self.0[ 1] >> 28)                      as u8;
134        s[11] =  (self.0[ 1] >> 36)                      as u8;
135        s[12] =  (self.0[ 1] >> 44)                      as u8;
136        s[13] =  (self.0[ 2] >>  0)                      as u8;
137        s[14] =  (self.0[ 2] >>  8)                      as u8;
138        s[15] =  (self.0[ 2] >> 16)                      as u8;
139        s[16] =  (self.0[ 2] >> 24)                      as u8;
140        s[17] =  (self.0[ 2] >> 32)                      as u8;
141        s[18] =  (self.0[ 2] >> 40)                      as u8;
142        s[19] = ((self.0[ 2] >> 48) | (self.0[ 3] << 4)) as u8;
143        s[20] =  (self.0[ 3] >>  4)                      as u8;
144        s[21] =  (self.0[ 3] >> 12)                      as u8;
145        s[22] =  (self.0[ 3] >> 20)                      as u8;
146        s[23] =  (self.0[ 3] >> 28)                      as u8;
147        s[24] =  (self.0[ 3] >> 36)                      as u8;
148        s[25] =  (self.0[ 3] >> 44)                      as u8;
149        s[26] =  (self.0[ 4] >>  0)                      as u8;
150        s[27] =  (self.0[ 4] >>  8)                      as u8;
151        s[28] =  (self.0[ 4] >> 16)                      as u8;
152        s[29] =  (self.0[ 4] >> 24)                      as u8;
153        s[30] =  (self.0[ 4] >> 32)                      as u8;
154        s[31] =  (self.0[ 4] >> 40)                      as u8;
155
156        s
157    }
158
159    /// Compute `a + b` (mod l)
160    pub fn add(a: &Scalar52, b: &Scalar52) -> Scalar52 {
161        let mut sum = Scalar52::ZERO;
162        let mask = (1u64 << 52) - 1;
163
164        // a + b
165        let mut carry: u64 = 0;
166        for i in 0..5 {
167            carry = a[i] + b[i] + (carry >> 52);
168            sum[i] = carry & mask;
169        }
170
171        // subtract l if the sum is >= l
172        Scalar52::sub(&sum, &constants::L)
173    }
174
175    /// Compute `a - b` (mod l)
176    pub fn sub(a: &Scalar52, b: &Scalar52) -> Scalar52 {
177        // Optimization barrier to prevent compiler from inserting branch instructions
178        // TODO(tarcieri): find a better home (or abstraction) for this
179        fn black_box(value: u64) -> u64 {
180            // SAFETY: `u64` is a simple integer `Copy` type and `value` lives on the stack so
181            // a pointer to it will be valid.
182            unsafe { core::ptr::read_volatile(&value) }
183        }
184
185        let mut difference = Scalar52::ZERO;
186        let mask = (1u64 << 52) - 1;
187
188        // a - b
189        let mut borrow: u64 = 0;
190        for i in 0..5 {
191            borrow = a[i].wrapping_sub(b[i] + (borrow >> 63));
192            difference[i] = borrow & mask;
193        }
194
195        // conditionally add l if the difference is negative
196        let underflow_mask = ((borrow >> 63) ^ 1).wrapping_sub(1);
197        let mut carry: u64 = 0;
198        for i in 0..5 {
199            // SECURITY: `black_box` prevents LLVM from inserting a `jns` conditional on x86(_64)
200            // which can be used to bypass this section when `underflow_mask` is zero.
201            carry = (carry >> 52) + difference[i] + (constants::L[i] & black_box(underflow_mask));
202            difference[i] = carry & mask;
203        }
204
205        difference
206    }
207
208    /// Compute `a * b`
209    #[inline(always)]
210    #[rustfmt::skip] // keep alignment of z[*] calculations
211    pub (crate) fn mul_internal(a: &Scalar52, b: &Scalar52) -> [u128; 9] {
212        let mut z = [0u128; 9];
213
214        z[0] = m(a[0], b[0]);
215        z[1] = m(a[0], b[1]) + m(a[1], b[0]);
216        z[2] = m(a[0], b[2]) + m(a[1], b[1]) + m(a[2], b[0]);
217        z[3] = m(a[0], b[3]) + m(a[1], b[2]) + m(a[2], b[1]) + m(a[3], b[0]);
218        z[4] = m(a[0], b[4]) + m(a[1], b[3]) + m(a[2], b[2]) + m(a[3], b[1]) + m(a[4], b[0]);
219        z[5] =                 m(a[1], b[4]) + m(a[2], b[3]) + m(a[3], b[2]) + m(a[4], b[1]);
220        z[6] =                                 m(a[2], b[4]) + m(a[3], b[3]) + m(a[4], b[2]);
221        z[7] =                                                 m(a[3], b[4]) + m(a[4], b[3]);
222        z[8] =                                                                 m(a[4], b[4]);
223
224        z
225    }
226
227    /// Compute `a^2`
228    #[inline(always)]
229    #[rustfmt::skip] // keep alignment of return calculations
230    fn square_internal(a: &Scalar52) -> [u128; 9] {
231        let aa = [
232            a[0] * 2,
233            a[1] * 2,
234            a[2] * 2,
235            a[3] * 2,
236        ];
237
238        [
239            m( a[0], a[0]),
240            m(aa[0], a[1]),
241            m(aa[0], a[2]) + m( a[1], a[1]),
242            m(aa[0], a[3]) + m(aa[1], a[2]),
243            m(aa[0], a[4]) + m(aa[1], a[3]) + m( a[2], a[2]),
244                             m(aa[1], a[4]) + m(aa[2], a[3]),
245                                              m(aa[2], a[4]) + m( a[3], a[3]),
246                                                               m(aa[3], a[4]),
247                                                                                m(a[4], a[4])
248        ]
249    }
250
251    /// Compute `limbs/R` (mod l), where R is the Montgomery modulus 2^260
252    #[inline(always)]
253    #[rustfmt::skip] // keep alignment of n* and r* calculations
254    pub (crate) fn montgomery_reduce(limbs: &[u128; 9]) -> Scalar52 {
255
256        #[inline(always)]
257        fn part1(sum: u128) -> (u128, u64) {
258            let p = (sum as u64).wrapping_mul(constants::LFACTOR) & ((1u64 << 52) - 1);
259            ((sum + m(p, constants::L[0])) >> 52, p)
260        }
261
262        #[inline(always)]
263        fn part2(sum: u128) -> (u128, u64) {
264            let w = (sum as u64) & ((1u64 << 52) - 1);
265            (sum >> 52, w)
266        }
267
268        // note: l[3] is zero, so its multiples can be skipped
269        let l = &constants::L;
270
271        // the first half computes the Montgomery adjustment factor n, and begins adding n*l to make limbs divisible by R
272        let (carry, n0) = part1(        limbs[0]);
273        let (carry, n1) = part1(carry + limbs[1] + m(n0, l[1]));
274        let (carry, n2) = part1(carry + limbs[2] + m(n0, l[2]) + m(n1, l[1]));
275        let (carry, n3) = part1(carry + limbs[3]               + m(n1, l[2]) + m(n2, l[1]));
276        let (carry, n4) = part1(carry + limbs[4] + m(n0, l[4])               + m(n2, l[2]) + m(n3, l[1]));
277
278        // limbs is divisible by R now, so we can divide by R by simply storing the upper half as the result
279        let (carry, r0) = part2(carry + limbs[5]               + m(n1, l[4])               + m(n3, l[2])   + m(n4, l[1]));
280        let (carry, r1) = part2(carry + limbs[6]                             + m(n2,l[4])                  + m(n4, l[2]));
281        let (carry, r2) = part2(carry + limbs[7]                                           + m(n3, l[4])                );
282        let (carry, r3) = part2(carry + limbs[8]                                                           + m(n4, l[4]));
283        let         r4 = carry as u64;
284
285        // result may be >= l, so attempt to subtract l
286        Scalar52::sub(&Scalar52([r0, r1, r2, r3, r4]), l)
287    }
288
289    /// Compute `a * b` (mod l)
290    #[inline(never)]
291    pub fn mul(a: &Scalar52, b: &Scalar52) -> Scalar52 {
292        let ab = Scalar52::montgomery_reduce(&Scalar52::mul_internal(a, b));
293        Scalar52::montgomery_reduce(&Scalar52::mul_internal(&ab, &constants::RR))
294    }
295
296    /// Compute `a^2` (mod l)
297    #[inline(never)]
298    #[allow(dead_code)] // XXX we don't expose square() via the Scalar API
299    pub fn square(&self) -> Scalar52 {
300        let aa = Scalar52::montgomery_reduce(&Scalar52::square_internal(self));
301        Scalar52::montgomery_reduce(&Scalar52::mul_internal(&aa, &constants::RR))
302    }
303
304    /// Compute `(a * b) / R` (mod l), where R is the Montgomery modulus 2^260
305    #[inline(never)]
306    pub fn montgomery_mul(a: &Scalar52, b: &Scalar52) -> Scalar52 {
307        Scalar52::montgomery_reduce(&Scalar52::mul_internal(a, b))
308    }
309
310    /// Compute `(a^2) / R` (mod l) in Montgomery form, where R is the Montgomery modulus 2^260
311    #[inline(never)]
312    pub fn montgomery_square(&self) -> Scalar52 {
313        Scalar52::montgomery_reduce(&Scalar52::square_internal(self))
314    }
315
316    /// Puts a Scalar52 in to Montgomery form, i.e. computes `a*R (mod l)`
317    #[inline(never)]
318    pub fn as_montgomery(&self) -> Scalar52 {
319        Scalar52::montgomery_mul(self, &constants::RR)
320    }
321
322    /// Takes a Scalar52 out of Montgomery form, i.e. computes `a/R (mod l)`
323    #[allow(clippy::wrong_self_convention)]
324    #[inline(never)]
325    pub fn from_montgomery(&self) -> Scalar52 {
326        let mut limbs = [0u128; 9];
327        for i in 0..5 {
328            limbs[i] = self[i] as u128;
329        }
330        Scalar52::montgomery_reduce(&limbs)
331    }
332}
333
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that two scalars agree limb by limb.
    ///
    /// On failure the message names the differing limb and `assert_eq!`
    /// prints both values; `#[track_caller]` attributes the panic to the
    /// calling test rather than to this helper.
    #[track_caller]
    fn assert_limbs_eq(actual: &Scalar52, expected: &Scalar52) {
        for i in 0..5 {
            assert_eq!(actual[i], expected[i], "limb {} differs", i);
        }
    }

    /// Note: x is 2^253-1 which is slightly larger than the largest scalar produced by
    /// this implementation (l-1), and should show there are no overflows for valid scalars
    ///
    /// x = 14474011154664524427946373126085988481658748083205070504932198000989141204991
    /// x = 7237005577332262213973186563042994240801631723825162898930247062703686954002 mod l
    /// x = 3057150787695215392275360544382990118917283750546154083604586903220563173085*R mod l in Montgomery form
    pub static X: Scalar52 = Scalar52([
        0x000fffffffffffff,
        0x000fffffffffffff,
        0x000fffffffffffff,
        0x000fffffffffffff,
        0x00001fffffffffff,
    ]);

    /// x^2 = 3078544782642840487852506753550082162405942681916160040940637093560259278169 mod l
    pub static XX: Scalar52 = Scalar52([
        0x0001668020217559,
        0x000531640ffd0ec0,
        0x00085fd6f9f38a31,
        0x000c268f73bb1cf4,
        0x000006ce65046df0,
    ]);

    /// x^2 = 4413052134910308800482070043710297189082115023966588301924965890668401540959*R mod l in Montgomery form
    pub static XX_MONT: Scalar52 = Scalar52([
        0x000c754eea569a5c,
        0x00063b6ed36cb215,
        0x0008ffa36bf25886,
        0x000e9183614e7543,
        0x0000061db6c6f26f,
    ]);

    /// y = 6145104759870991071742105800796537629880401874866217824609283457819451087098
    pub static Y: Scalar52 = Scalar52([
        0x000b75071e1458fa,
        0x000bf9d75e1ecdac,
        0x000433d2baf0672b,
        0x0005fffcc11fad13,
        0x00000d96018bb825,
    ]);

    /// x*y = 36752150652102274958925982391442301741 mod l
    pub static XY: Scalar52 = Scalar52([
        0x000ee6d76ba7632d,
        0x000ed50d71d84e02,
        0x00000000001ba634,
        0x0000000000000000,
        0x0000000000000000,
    ]);

    /// x*y = 658448296334113745583381664921721413881518248721417041768778176391714104386*R mod l in Montgomery form
    pub static XY_MONT: Scalar52 = Scalar52([
        0x0006d52bf200cfd5,
        0x00033fb1d7021570,
        0x000f201bc07139d8,
        0x0001267e3e49169e,
        0x000007b839c00268,
    ]);

    /// a = 2351415481556538453565687241199399922945659411799870114962672658845158063753
    pub static A: Scalar52 = Scalar52([
        0x0005236c07b3be89,
        0x0001bc3d2a67c0c4,
        0x000a4aa782aae3ee,
        0x0006b3f6e4fec4c4,
        0x00000532da9fab8c,
    ]);

    /// b = 4885590095775723760407499321843594317911456947580037491039278279440296187236
    pub static B: Scalar52 = Scalar52([
        0x000d3fae55421564,
        0x000c2df24f65a4bc,
        0x0005b5587d69fb0b,
        0x00094c091b013b3b,
        0x00000acd25605473,
    ]);

    /// a+b = 0
    /// a-b = 4702830963113076907131374482398799845891318823599740229925345317690316127506
    pub static AB: Scalar52 = Scalar52([
        0x000a46d80f677d12,
        0x0003787a54cf8188,
        0x0004954f0555c7dc,
        0x000d67edc9fd8989,
        0x00000a65b53f5718,
    ]);

    /// c = (2^512 - 1) % l = 1627715501170711445284395025044413883736156588369414752970002579683115011840
    pub static C: Scalar52 = Scalar52([
        0x000611e3449c0f00,
        0x000a768859347a40,
        0x0007f5be65d00e1b,
        0x0009a3dceec73d21,
        0x00000399411b7c30,
    ]);

    #[test]
    fn mul_max() {
        assert_limbs_eq(&Scalar52::mul(&X, &X), &XX);
    }

    #[test]
    fn square_max() {
        assert_limbs_eq(&X.square(), &XX);
    }

    #[test]
    fn montgomery_mul_max() {
        assert_limbs_eq(&Scalar52::montgomery_mul(&X, &X), &XX_MONT);
    }

    #[test]
    fn montgomery_square_max() {
        assert_limbs_eq(&X.montgomery_square(), &XX_MONT);
    }

    #[test]
    fn mul() {
        assert_limbs_eq(&Scalar52::mul(&X, &Y), &XY);
    }

    #[test]
    fn montgomery_mul() {
        assert_limbs_eq(&Scalar52::montgomery_mul(&X, &Y), &XY_MONT);
    }

    #[test]
    fn add() {
        // a + b equals l exactly, so the reduced sum is zero.
        assert_limbs_eq(&Scalar52::add(&A, &B), &Scalar52::ZERO);
    }

    #[test]
    fn sub() {
        assert_limbs_eq(&Scalar52::sub(&A, &B), &AB);
    }

    #[test]
    fn from_bytes_wide() {
        let bignum = [255u8; 64]; // 2^512 - 1
        assert_limbs_eq(&Scalar52::from_bytes_wide(&bignum), &C);
    }
}