1use core::{
2 ops::{Add, AddAssign, Sub, SubAssign, Neg, Mul, MulAssign},
3 iter::{Sum, Product},
4};
5
6use zeroize::Zeroize;
7use rand_core::RngCore;
8
9use subtle::{
10 Choice, CtOption, ConstantTimeEq, ConstantTimeLess, ConditionallyNegatable,
11 ConditionallySelectable,
12};
13
14use crypto_bigint::{
15 Integer, NonZero, Encoding, U256, U512,
16 modular::constant_mod::{ResidueParams, Residue},
17 impl_modulus,
18};
19
20use group::ff::{Field, PrimeField, FieldBits, PrimeFieldBits};
21
22use crate::{u8_from_bool, constant_time, math_op, math};
23
24const MODULUS: U256 = U256::from_u8(1).shl_vartime(255).saturating_sub(&U256::from_u8(19));
27const WIDE_MODULUS: U512 = U256::ZERO.concat(&MODULUS);
28
// Defines `FieldModulus`, the crypto-bigint constant-modulus parameters for
// p = 2^255 - 19 (given here as big-endian hex; must agree with `MODULUS`).
impl_modulus!(
  FieldModulus,
  U256,
  "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed"
);
// Residue modulo `FieldModulus`; this is the representation wrapped by `FieldElement`.
type ResidueType = Residue<FieldModulus, { FieldModulus::LIMBS }>;
36
37#[derive(Clone, Copy, PartialEq, Eq, Default, Debug)]
39pub struct FieldElement(ResidueType);
40
41const SQRT_M1: FieldElement = FieldElement(
45 ResidueType::new(&U256::from_u8(2))
46 .pow(&MODULUS.saturating_sub(&U256::ONE).wrapping_div(&U256::from_u8(4))),
47);
48
49const MOD_3_8: FieldElement = FieldElement(ResidueType::new(
51 &MODULUS.saturating_add(&U256::from_u8(3)).wrapping_div(&U256::from_u8(8)),
52));
53
54const MOD_5_8: FieldElement = FieldElement(ResidueType::sub(&MOD_3_8.0, &ResidueType::ONE));
56
57fn reduce(x: U512) -> ResidueType {
58 ResidueType::new(&U256::from_le_slice(
59 &x.rem(&NonZero::new(WIDE_MODULUS).unwrap()).to_le_bytes()[.. 32],
60 ))
61}
62
// Crate-local macro (defined outside this file). Judging by its name and by this file's use of
// `ct_eq` / `conditional_select` / `conditional_assign` on `FieldElement`, it presumably provides
// the `subtle` constant-time impls by delegating to `ResidueType` -- TODO confirm.
constant_time!(FieldElement, ResidueType);
// Crate-local macro (defined outside this file). It is given add/sub/mul closures over the inner
// `ResidueType`; this file relies on `+=`, `-`, `*` and `*=` for `FieldElement`, which
// presumably come from here -- TODO confirm.
math!(
  FieldElement,
  FieldElement,
  |x: ResidueType, y: ResidueType| x.add(&y),
  |x: ResidueType, y: ResidueType| x.sub(&y),
  |x: ResidueType, y: ResidueType| x.mul(&y)
);
71
72macro_rules! from_wrapper {
73 ($uint: ident) => {
74 impl From<$uint> for FieldElement {
75 fn from(a: $uint) -> FieldElement {
76 Self(ResidueType::new(&U256::from(a)))
77 }
78 }
79 };
80}
81
82from_wrapper!(u8);
83from_wrapper!(u16);
84from_wrapper!(u32);
85from_wrapper!(u64);
86from_wrapper!(u128);
87
88impl Neg for FieldElement {
89 type Output = Self;
90 fn neg(self) -> Self::Output {
91 Self(self.0.neg())
92 }
93}
94
95impl<'a> Neg for &'a FieldElement {
96 type Output = FieldElement;
97 fn neg(self) -> Self::Output {
98 (*self).neg()
99 }
100}
101
102impl Field for FieldElement {
103 const ZERO: Self = Self(ResidueType::ZERO);
104 const ONE: Self = Self(ResidueType::ONE);
105
106 fn random(mut rng: impl RngCore) -> Self {
107 let mut bytes = [0; 64];
108 rng.fill_bytes(&mut bytes);
109 FieldElement(reduce(U512::from_le_bytes(bytes)))
110 }
111
112 fn square(&self) -> Self {
113 FieldElement(self.0.square())
114 }
115 fn double(&self) -> Self {
116 FieldElement(self.0.add(&self.0))
117 }
118
119 fn invert(&self) -> CtOption<Self> {
120 const NEG_2: FieldElement =
121 FieldElement(ResidueType::new(&MODULUS.saturating_sub(&U256::from_u8(2))));
122 CtOption::new(self.pow(NEG_2), !self.is_zero())
123 }
124
125 fn sqrt(&self) -> CtOption<Self> {
127 let tv1 = self.pow(MOD_3_8);
128 let tv2 = tv1 * SQRT_M1;
129 let candidate = Self::conditional_select(&tv2, &tv1, tv1.square().ct_eq(self));
130 CtOption::new(candidate, candidate.square().ct_eq(self))
131 }
132
133 fn sqrt_ratio(u: &FieldElement, v: &FieldElement) -> (Choice, FieldElement) {
134 let i = SQRT_M1;
135
136 let u = *u;
137 let v = *v;
138
139 let v3 = v.square() * v;
140 let v7 = v3.square() * v;
141 let mut r = (u * v3) * (u * v7).pow(MOD_5_8);
142
143 let check = v * r.square();
144 let correct_sign = check.ct_eq(&u);
145 let flipped_sign = check.ct_eq(&(-u));
146 let flipped_sign_i = check.ct_eq(&((-u) * i));
147
148 r.conditional_assign(&(r * i), flipped_sign | flipped_sign_i);
149
150 let r_is_negative = r.is_odd();
151 r.conditional_negate(r_is_negative);
152
153 (correct_sign | flipped_sign, r)
154 }
155}
156
157impl PrimeField for FieldElement {
158 type Repr = [u8; 32];
159
160 const MODULUS: &'static str = "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed";
162
163 const NUM_BITS: u32 = 255;
164 const CAPACITY: u32 = 254;
165
166 const TWO_INV: Self = FieldElement(ResidueType::new(&U256::from_u8(2)).invert().0);
167
168 const MULTIPLICATIVE_GENERATOR: Self = Self(ResidueType::new(&U256::from_u8(2)));
171 const S: u32 = 2;
174
175 const ROOT_OF_UNITY: Self = FieldElement(ResidueType::new(&U256::from_be_hex(
178 "2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b0",
179 )));
180 const ROOT_OF_UNITY_INV: Self = FieldElement(Self::ROOT_OF_UNITY.0.invert().0);
182
183 const DELTA: Self = FieldElement(ResidueType::new(&U256::from_be_hex(
186 "0000000000000000000000000000000000000000000000000000000000000010",
187 )));
188
189 fn from_repr(bytes: [u8; 32]) -> CtOption<Self> {
190 let res = U256::from_le_bytes(bytes);
191 CtOption::new(Self(ResidueType::new(&res)), res.ct_lt(&MODULUS))
192 }
193 fn to_repr(&self) -> [u8; 32] {
194 self.0.retrieve().to_le_bytes()
195 }
196
197 fn is_odd(&self) -> Choice {
198 self.0.retrieve().is_odd()
199 }
200
201 fn from_u128(num: u128) -> Self {
202 Self::from(num)
203 }
204}
205
206impl PrimeFieldBits for FieldElement {
207 type ReprBits = [u8; 32];
208
209 fn to_le_bits(&self) -> FieldBits<Self::ReprBits> {
210 self.to_repr().into()
211 }
212
213 fn char_le_bits() -> FieldBits<Self::ReprBits> {
214 MODULUS.to_le_bytes().into()
215 }
216}
217
218impl FieldElement {
219 pub fn from_square(value: [u8; 32]) -> FieldElement {
221 let value = U256::from_le_bytes(value);
222 FieldElement(reduce(U512::from(value.mul_wide(&value))))
223 }
224
225 pub fn pow(&self, other: FieldElement) -> FieldElement {
227 let mut table = [FieldElement::ONE; 16];
228 table[1] = *self;
229 for i in 2 .. 16 {
230 table[i] = table[i - 1] * self;
231 }
232
233 let mut res = FieldElement::ONE;
234 let mut bits = 0;
235 for (i, mut bit) in other.to_le_bits().iter_mut().rev().enumerate() {
236 bits <<= 1;
237 let mut bit = u8_from_bool(&mut bit);
238 bits |= bit;
239 bit.zeroize();
240
241 if ((i + 1) % 4) == 0 {
242 if i != 3 {
243 for _ in 0 .. 4 {
244 res *= res;
245 }
246 }
247
248 let mut scale_by = FieldElement::ONE;
249 #[allow(clippy::needless_range_loop)]
250 for i in 0 .. 16 {
251 #[allow(clippy::cast_possible_truncation)] {
253 scale_by = <_>::conditional_select(&scale_by, &table[i], bits.ct_eq(&(i as u8)));
254 }
255 }
256 res *= scale_by;
257 bits = 0;
258 }
259 }
260 res
261 }
262
263 pub fn sqrt_ratio_i(u: FieldElement, v: FieldElement) -> (Choice, FieldElement) {
270 let i = SQRT_M1;
271
272 let v3 = v.square() * v;
273 let v7 = v3.square() * v;
274 let mut r = (u * v3) * (u * v7).pow(MOD_5_8);
276
277 let check = v * r.square();
279 let correct_sign = check.ct_eq(&u);
280 let neg_u = -u;
282 let flipped_sign = check.ct_eq(&neg_u);
283 let flipped_sign_i = check.ct_eq(&(neg_u * i));
285
286 r.conditional_assign(&(r * i), flipped_sign | flipped_sign_i);
288
289 r.conditional_negate(r.is_odd());
295
296 (correct_sign | flipped_sign, r)
297 }
298}
299
300impl Sum<FieldElement> for FieldElement {
301 fn sum<I: Iterator<Item = FieldElement>>(iter: I) -> FieldElement {
302 let mut res = FieldElement::ZERO;
303 for item in iter {
304 res += item;
305 }
306 res
307 }
308}
309
310impl<'a> Sum<&'a FieldElement> for FieldElement {
311 fn sum<I: Iterator<Item = &'a FieldElement>>(iter: I) -> FieldElement {
312 iter.copied().sum()
313 }
314}
315
316impl Product<FieldElement> for FieldElement {
317 fn product<I: Iterator<Item = FieldElement>>(iter: I) -> FieldElement {
318 let mut res = FieldElement::ONE;
319 for item in iter {
320 res *= item;
321 }
322 res
323 }
324}
325
326impl<'a> Product<&'a FieldElement> for FieldElement {
327 fn product<I: Iterator<Item = &'a FieldElement>>(iter: I) -> FieldElement {
328 iter.copied().product()
329 }
330}
331
#[test]
fn test_wide_modulus() {
  // WIDE_MODULUS must be MODULUS in its low 32 little-endian bytes and zero above them.
  let wide = WIDE_MODULUS.to_le_bytes();
  let (low, high) = wide.split_at(32);
  assert_eq!(low, &MODULUS.to_le_bytes()[..]);
  assert!(high.iter().all(|byte| *byte == 0));
}
338
#[test]
fn test_sqrt_m1() {
  // Known big-endian encoding of sqrt(-1) mod 2^255 - 19.
  const SQRT_M1_MAGIC: U256 =
    U256::from_be_hex("2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b0");
  assert_eq!(SQRT_M1.0.retrieve(), SQRT_M1_MAGIC);

  // Recompute 2^((p - 1) / 4) at runtime, obtaining p - 1 as the canonical value of -1.
  let p_minus_one = (FieldElement::ZERO - FieldElement::ONE).0.retrieve();
  let quarter = p_minus_one.wrapping_div(&U256::from(4u8));
  let exponent = FieldElement(ResidueType::new(&quarter));
  assert_eq!(SQRT_M1, FieldElement::from(2u8).pow(exponent));
}
355
#[test]
fn test_field() {
  // Run the shared ff prime-field (and PrimeFieldBits) test suite against FieldElement.
  let mut rng = rand_core::OsRng;
  ff_group_tests::prime_field::test_prime_field_bits::<_, FieldElement>(&mut rng);
}