dalek_ff_group/
field.rs

1use core::{
2  ops::{Add, AddAssign, Sub, SubAssign, Neg, Mul, MulAssign},
3  iter::{Sum, Product},
4};
5
6use zeroize::Zeroize;
7use rand_core::RngCore;
8
9use subtle::{
10  Choice, CtOption, ConstantTimeEq, ConstantTimeLess, ConditionallyNegatable,
11  ConditionallySelectable,
12};
13
14use crypto_bigint::{
15  Integer, NonZero, Encoding, U256, U512,
16  modular::constant_mod::{ResidueParams, Residue},
17  impl_modulus,
18};
19
20use group::ff::{Field, PrimeField, FieldBits, PrimeFieldBits};
21
22use crate::{u8_from_bool, constant_time, math_op, math};
23
24// 2 ** 255 - 19
25// Uses saturating_sub because checked_sub isn't available at compile time
26const MODULUS: U256 = U256::from_u8(1).shl_vartime(255).saturating_sub(&U256::from_u8(19));
27const WIDE_MODULUS: U512 = U256::ZERO.concat(&MODULUS);
28
29impl_modulus!(
30  FieldModulus,
31  U256,
32  // 2 ** 255 - 19
33  "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed"
34);
35type ResidueType = Residue<FieldModulus, { FieldModulus::LIMBS }>;
36
/// A constant-time implementation of the Ed25519 field.
///
/// Wraps a residue modulo 2 ** 255 - 19 (`MODULUS`). The arithmetic operators and constant-time
/// traits are supplied by the `math!` and `constant_time!` invocations in this file.
///
/// NOTE(review): the derived `PartialEq` delegates to the inner residue's `PartialEq`, which may
/// not be constant-time. Prefer `ct_eq` when comparing secret values.
#[derive(Clone, Copy, PartialEq, Eq, Default, Debug)]
pub struct FieldElement(ResidueType);
40
41// Square root of -1.
42// Formula from RFC-8032 (modp_sqrt_m1/sqrt8k5 z)
43// 2 ** ((MODULUS - 1) // 4) % MODULUS
44const SQRT_M1: FieldElement = FieldElement(
45  ResidueType::new(&U256::from_u8(2))
46    .pow(&MODULUS.saturating_sub(&U256::ONE).wrapping_div(&U256::from_u8(4))),
47);
48
49// Constant useful in calculating square roots (RFC-8032 sqrt8k5's exponent used to calculate y)
50const MOD_3_8: FieldElement = FieldElement(ResidueType::new(
51  &MODULUS.saturating_add(&U256::from_u8(3)).wrapping_div(&U256::from_u8(8)),
52));
53
54// Constant useful in sqrt_ratio_i (sqrt(u / v))
55const MOD_5_8: FieldElement = FieldElement(ResidueType::sub(&MOD_3_8.0, &ResidueType::ONE));
56
57fn reduce(x: U512) -> ResidueType {
58  ResidueType::new(&U256::from_le_slice(
59    &x.rem(&NonZero::new(WIDE_MODULUS).unwrap()).to_le_bytes()[.. 32],
60  ))
61}
62
// Derive FieldElement's constant-time traits from the inner ResidueType via the crate's
// `constant_time!` macro. Presumably this supplies the `ct_eq` / `conditional_select` used below;
// see the macro definition to confirm.
constant_time!(FieldElement, ResidueType);
// Implement FieldElement's arithmetic operators via the crate's `math!` macro.
// The three closures define addition, subtraction, and multiplication on the inner residues.
math!(
  FieldElement,
  FieldElement,
  |x: ResidueType, y: ResidueType| x.add(&y),
  |x: ResidueType, y: ResidueType| x.sub(&y),
  |x: ResidueType, y: ResidueType| x.mul(&y)
);
71
72macro_rules! from_wrapper {
73  ($uint: ident) => {
74    impl From<$uint> for FieldElement {
75      fn from(a: $uint) -> FieldElement {
76        Self(ResidueType::new(&U256::from(a)))
77      }
78    }
79  };
80}
81
82from_wrapper!(u8);
83from_wrapper!(u16);
84from_wrapper!(u32);
85from_wrapper!(u64);
86from_wrapper!(u128);
87
88impl Neg for FieldElement {
89  type Output = Self;
90  fn neg(self) -> Self::Output {
91    Self(self.0.neg())
92  }
93}
94
95impl<'a> Neg for &'a FieldElement {
96  type Output = FieldElement;
97  fn neg(self) -> Self::Output {
98    (*self).neg()
99  }
100}
101
102impl Field for FieldElement {
103  const ZERO: Self = Self(ResidueType::ZERO);
104  const ONE: Self = Self(ResidueType::ONE);
105
106  fn random(mut rng: impl RngCore) -> Self {
107    let mut bytes = [0; 64];
108    rng.fill_bytes(&mut bytes);
109    FieldElement(reduce(U512::from_le_bytes(bytes)))
110  }
111
112  fn square(&self) -> Self {
113    FieldElement(self.0.square())
114  }
115  fn double(&self) -> Self {
116    FieldElement(self.0.add(&self.0))
117  }
118
119  fn invert(&self) -> CtOption<Self> {
120    const NEG_2: FieldElement =
121      FieldElement(ResidueType::new(&MODULUS.saturating_sub(&U256::from_u8(2))));
122    CtOption::new(self.pow(NEG_2), !self.is_zero())
123  }
124
125  // RFC-8032 sqrt8k5
126  fn sqrt(&self) -> CtOption<Self> {
127    let tv1 = self.pow(MOD_3_8);
128    let tv2 = tv1 * SQRT_M1;
129    let candidate = Self::conditional_select(&tv2, &tv1, tv1.square().ct_eq(self));
130    CtOption::new(candidate, candidate.square().ct_eq(self))
131  }
132
133  fn sqrt_ratio(u: &FieldElement, v: &FieldElement) -> (Choice, FieldElement) {
134    let i = SQRT_M1;
135
136    let u = *u;
137    let v = *v;
138
139    let v3 = v.square() * v;
140    let v7 = v3.square() * v;
141    let mut r = (u * v3) * (u * v7).pow(MOD_5_8);
142
143    let check = v * r.square();
144    let correct_sign = check.ct_eq(&u);
145    let flipped_sign = check.ct_eq(&(-u));
146    let flipped_sign_i = check.ct_eq(&((-u) * i));
147
148    r.conditional_assign(&(r * i), flipped_sign | flipped_sign_i);
149
150    let r_is_negative = r.is_odd();
151    r.conditional_negate(r_is_negative);
152
153    (correct_sign | flipped_sign, r)
154  }
155}
156
157impl PrimeField for FieldElement {
158  type Repr = [u8; 32];
159
160  // Big endian representation of the modulus
161  const MODULUS: &'static str = "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed";
162
163  const NUM_BITS: u32 = 255;
164  const CAPACITY: u32 = 254;
165
166  const TWO_INV: Self = FieldElement(ResidueType::new(&U256::from_u8(2)).invert().0);
167
168  // This was calculated with the method from the ff crate docs
169  // SageMath GF(modulus).primitive_element()
170  const MULTIPLICATIVE_GENERATOR: Self = Self(ResidueType::new(&U256::from_u8(2)));
171  // This was set per the specification in the ff crate docs
172  // The number of leading zero bits in the little-endian bit representation of (modulus - 1)
173  const S: u32 = 2;
174
175  // This was calculated via the formula from the ff crate docs
176  // Self::MULTIPLICATIVE_GENERATOR ** ((modulus - 1) >> Self::S)
177  const ROOT_OF_UNITY: Self = FieldElement(ResidueType::new(&U256::from_be_hex(
178    "2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b0",
179  )));
180  // Self::ROOT_OF_UNITY.invert()
181  const ROOT_OF_UNITY_INV: Self = FieldElement(Self::ROOT_OF_UNITY.0.invert().0);
182
183  // This was calculated via the formula from the ff crate docs
184  // Self::MULTIPLICATIVE_GENERATOR ** (2 ** Self::S)
185  const DELTA: Self = FieldElement(ResidueType::new(&U256::from_be_hex(
186    "0000000000000000000000000000000000000000000000000000000000000010",
187  )));
188
189  fn from_repr(bytes: [u8; 32]) -> CtOption<Self> {
190    let res = U256::from_le_bytes(bytes);
191    CtOption::new(Self(ResidueType::new(&res)), res.ct_lt(&MODULUS))
192  }
193  fn to_repr(&self) -> [u8; 32] {
194    self.0.retrieve().to_le_bytes()
195  }
196
197  fn is_odd(&self) -> Choice {
198    self.0.retrieve().is_odd()
199  }
200
201  fn from_u128(num: u128) -> Self {
202    Self::from(num)
203  }
204}
205
206impl PrimeFieldBits for FieldElement {
207  type ReprBits = [u8; 32];
208
209  fn to_le_bits(&self) -> FieldBits<Self::ReprBits> {
210    self.to_repr().into()
211  }
212
213  fn char_le_bits() -> FieldBits<Self::ReprBits> {
214    MODULUS.to_le_bytes().into()
215  }
216}
217
impl FieldElement {
  /// Interpret the value as a little-endian integer, square it, and reduce it into a FieldElement.
  ///
  /// The square is the full 512-bit integer product, reduced modulo `MODULUS` afterwards, so the
  /// input isn't required to be below the modulus.
  pub fn from_square(value: [u8; 32]) -> FieldElement {
    let value = U256::from_le_bytes(value);
    FieldElement(reduce(U512::from(value.mul_wide(&value))))
  }

  /// Perform an exponentiation.
  ///
  /// Computes `self ** other`. The exponent is the canonical integer representative of `other`,
  /// so only exponents below `MODULUS` can be expressed.
  ///
  /// Uses a fixed 4-bit window. First a table of `self ** 0 ..= self ** 15` is precomputed. The
  /// exponent's 256 bits are then consumed most-significant first, four at a time. Every window
  /// performs the same squarings and scans the entire table with `conditional_select`, so the
  /// control flow doesn't depend on the exponent's value.
  pub fn pow(&self, other: FieldElement) -> FieldElement {
    // table[i] = self ** i
    let mut table = [FieldElement::ONE; 16];
    table[1] = *self;
    for i in 2 .. 16 {
      table[i] = table[i - 1] * self;
    }

    let mut res = FieldElement::ONE;
    // The bits of the current window, accumulated most-significant first
    let mut bits = 0;
    // Walk the exponent's bits from most to least significant
    for (i, mut bit) in other.to_le_bits().iter_mut().rev().enumerate() {
      bits <<= 1;
      let mut bit = u8_from_bool(&mut bit);
      bits |= bit;
      // Clear the local copy of this exponent bit
      bit.zeroize();

      // A 4-bit window is complete
      if ((i + 1) % 4) == 0 {
        // Raise the accumulator to the 16th power (four squarings), making room for this window.
        // Skipped for the first window, where res is still ONE.
        if i != 3 {
          for _ in 0 .. 4 {
            res *= res;
          }
        }

        // Select table[bits] by scanning every entry, rather than indexing with the secret bits
        let mut scale_by = FieldElement::ONE;
        #[allow(clippy::needless_range_loop)]
        for i in 0 .. 16 {
          #[allow(clippy::cast_possible_truncation)] // Safe since 0 .. 16
          {
            scale_by = <_>::conditional_select(&scale_by, &table[i], bits.ct_eq(&(i as u8)));
          }
        }
        res *= scale_by;
        bits = 0;
      }
    }
    res
  }

  /// The square root of u/v, as used for Ed25519 point decoding (RFC 8032 5.1.3) and within
  /// Ristretto (5.1 Extracting an Inverse Square Root).
  ///
  /// The result is only a valid square root if the Choice is true.
  /// RFC 8032 simply fails if there isn't a square root, leaving any return value undefined.
  /// Ristretto explicitly returns 0 or sqrt((SQRT_M1 * u) / v).
  ///
  /// The Choice is true when `v * r ** 2` equals `u` or `-u` for the candidate root `r`. This
  /// includes `u == 0`, where the candidate is 0. The returned root is always even
  /// (`is_odd()` is false).
  pub fn sqrt_ratio_i(u: FieldElement, v: FieldElement) -> (Choice, FieldElement) {
    let i = SQRT_M1;

    let v3 = v.square() * v;
    let v7 = v3.square() * v;
    // Candidate root
    let mut r = (u * v3) * (u * v7).pow(MOD_5_8);

    // 8032 3.1
    let check = v * r.square();
    let correct_sign = check.ct_eq(&u);
    // 8032 3.2 conditional
    let neg_u = -u;
    let flipped_sign = check.ct_eq(&neg_u);
    // Ristretto Step 5
    let flipped_sign_i = check.ct_eq(&(neg_u * i));

    // 3.2 set
    r.conditional_assign(&(r * i), flipped_sign | flipped_sign_i);

    // Always return the even root, per Ristretto
    // This doesn't break Ed25519 point decoding as that doesn't expect these steps to return a
    // specific root
    // Ed25519 points include a dedicated sign bit to determine which root to use, so at worst
    // this is a pointless inefficiency
    r.conditional_negate(r.is_odd());

    (correct_sign | flipped_sign, r)
  }
}
299
300impl Sum<FieldElement> for FieldElement {
301  fn sum<I: Iterator<Item = FieldElement>>(iter: I) -> FieldElement {
302    let mut res = FieldElement::ZERO;
303    for item in iter {
304      res += item;
305    }
306    res
307  }
308}
309
310impl<'a> Sum<&'a FieldElement> for FieldElement {
311  fn sum<I: Iterator<Item = &'a FieldElement>>(iter: I) -> FieldElement {
312    iter.copied().sum()
313  }
314}
315
316impl Product<FieldElement> for FieldElement {
317  fn product<I: Iterator<Item = FieldElement>>(iter: I) -> FieldElement {
318    let mut res = FieldElement::ONE;
319    for item in iter {
320      res *= item;
321    }
322    res
323  }
324}
325
326impl<'a> Product<&'a FieldElement> for FieldElement {
327  fn product<I: Iterator<Item = &'a FieldElement>>(iter: I) -> FieldElement {
328    iter.copied().product()
329  }
330}
331
#[test]
fn test_wide_modulus() {
  // WIDE_MODULUS must be MODULUS in its low 32 bytes and zero in its high 32 bytes
  let wide = WIDE_MODULUS.to_le_bytes();
  let (low, high) = wide.split_at(32);
  assert_eq!(low, &MODULUS.to_le_bytes()[..]);
  assert!(high.iter().all(|byte| *byte == 0));
}
338
#[test]
fn test_sqrt_m1() {
  // Test equivalence against the known constant value
  const SQRT_M1_MAGIC: U256 =
    U256::from_be_hex("2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b0");
  assert_eq!(SQRT_M1.0.retrieve(), SQRT_M1_MAGIC);

  // Also test equivalence against the result of the formula from RFC-8032 (modp_sqrt_m1/sqrt8k5 z)
  // 2 ** ((MODULUS - 1) // 4) % MODULUS
  // MODULUS - 1 is obtained via field arithmetic (0 - 1), not from the MODULUS constant
  let modulus_minus_one = (FieldElement::ZERO - FieldElement::ONE).0.retrieve();
  let exponent = FieldElement(ResidueType::new(&modulus_minus_one.wrapping_div(&U256::from(4u8))));
  assert_eq!(SQRT_M1, FieldElement::from(2u8).pow(exponent));

  // Finally, check the defining property directly: SQRT_M1 ** 2 == -1
  assert_eq!(SQRT_M1.square(), -FieldElement::ONE);
}
355
#[test]
fn test_field() {
  // Run ff-group-tests' generic PrimeField/PrimeFieldBits suite against this field
  let mut rng = rand_core::OsRng;
  ff_group_tests::prime_field::test_prime_field_bits::<_, FieldElement>(&mut rng);
}