dalek_ff_group/
lib.rs

1#![allow(deprecated)]
2#![cfg_attr(docsrs, feature(doc_auto_cfg))]
3#![no_std] // Prevents writing new code, in what should be a simple wrapper, which requires std
4#![doc = include_str!("../README.md")]
5#![allow(clippy::redundant_closure_call)]
6
7use core::{
8  borrow::Borrow,
9  ops::{Deref, Add, AddAssign, Sub, SubAssign, Neg, Mul, MulAssign},
10  iter::{Iterator, Sum, Product},
11  hash::{Hash, Hasher},
12};
13
14use zeroize::Zeroize;
15use subtle::{ConstantTimeEq, ConditionallySelectable};
16
17use rand_core::RngCore;
18use digest::{consts::U64, Digest, HashMarker};
19
20use subtle::{Choice, CtOption};
21
22pub use curve25519_dalek as dalek;
23
24use dalek::{
25  constants::{self, BASEPOINT_ORDER},
26  scalar::Scalar as DScalar,
27  edwards::{EdwardsPoint as DEdwardsPoint, EdwardsBasepointTable, CompressedEdwardsY},
28  ristretto::{RistrettoPoint as DRistrettoPoint, RistrettoBasepointTable, CompressedRistretto},
29};
30pub use constants::{ED25519_BASEPOINT_TABLE, RISTRETTO_BASEPOINT_TABLE};
31
32use group::{
33  ff::{Field, PrimeField, FieldBits, PrimeFieldBits},
34  Group, GroupEncoding,
35  prime::PrimeGroup,
36};
37
38mod field;
39pub use field::FieldElement;
40
41// Use black_box when possible
42#[rustversion::since(1.66)]
43use core::hint::black_box;
44#[rustversion::before(1.66)]
45fn black_box<T>(val: T) -> T {
46  val
47}
48
49fn u8_from_bool(bit_ref: &mut bool) -> u8 {
50  let bit_ref = black_box(bit_ref);
51
52  let mut bit = black_box(*bit_ref);
53  #[allow(clippy::cast_lossless)]
54  let res = black_box(bit as u8);
55  bit.zeroize();
56  debug_assert!((res | 1) == 1);
57
58  bit_ref.zeroize();
59  res
60}
61
62// Convert a boolean to a Choice in a *presumably* constant time manner
63fn choice(mut value: bool) -> Choice {
64  Choice::from(u8_from_bool(&mut value))
65}
66
/// Implements `Borrow<$Target>` for both `$Source` and `&$Source`, plus `Deref<Target = $Target>`
/// for `$Source`, all exposing the wrapped `.0` field of the tuple struct `$Source`.
macro_rules! deref_borrow {
  ($Source: ident, $Target: ident) => {
    impl Borrow<$Target> for $Source {
      fn borrow(&self) -> &$Target {
        &self.0
      }
    }

    impl Borrow<$Target> for &$Source {
      fn borrow(&self) -> &$Target {
        &self.0
      }
    }

    impl Deref for $Source {
      type Target = $Target;

      fn deref(&self) -> &$Target {
        &self.0
      }
    }
  };
}
90
/// Implements `ConditionallySelectable` and `ConstantTimeEq` for the tuple struct `$Value` by
/// delegating to its wrapped `$Inner` field.
macro_rules! constant_time {
  ($Value: ident, $Inner: ident) => {
    impl ConditionallySelectable for $Value {
      fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let selected = $Inner::conditional_select(&a.0, &b.0, choice);
        Self(selected)
      }
    }

    impl ConstantTimeEq for $Value {
      fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs.ct_eq(rhs)
      }
    }
  };
}
pub(crate) use constant_time;
107
/// Implements a binary operator (`$Op`/`$op_fn`) and its compound-assignment form
/// (`$Assign`/`$assign_fn`) for the tuple struct `$Value`, with a right-hand side of either
/// `$Other` or `&$Other`.
///
/// `$function` is applied to the two wrapped `.0` fields and its result becomes the new wrapped
/// value. As those fields are read out of references, they must be `Copy`.
macro_rules! math_op {
  (
    $Value: ident,
    $Other: ident,
    $Op: ident,
    $op_fn: ident,
    $Assign: ident,
    $assign_fn: ident,
    $function: expr
  ) => {
    // Binary operator, with the right-hand side by value or by reference
    impl $Op<$Other> for $Value {
      type Output = $Value;
      fn $op_fn(self, rhs: $Other) -> $Value {
        $Value($function(self.0, rhs.0))
      }
    }
    impl<'a> $Op<&'a $Other> for $Value {
      type Output = $Value;
      fn $op_fn(self, rhs: &'a $Other) -> $Value {
        $Value($function(self.0, rhs.0))
      }
    }

    // Compound assignment, with the right-hand side by value or by reference
    impl $Assign<$Other> for $Value {
      fn $assign_fn(&mut self, rhs: $Other) {
        self.0 = $function(self.0, rhs.0);
      }
    }
    impl<'a> $Assign<&'a $Other> for $Value {
      fn $assign_fn(&mut self, rhs: &'a $Other) {
        self.0 = $function(self.0, rhs.0);
      }
    }
  };
}
pub(crate) use math_op;
143
/// Implements, via `math_op!`, `Mul` by `$Factor` plus `Add`/`Sub` by `$Value` (and their
/// assignment forms) for `$Value`.
///
/// `$mul`, `$add`, and `$sub` are applied to the operands' wrapped `.0` fields.
macro_rules! math {
  ($Value: ident, $Factor: ident, $add: expr, $sub: expr, $mul: expr) => {
    math_op!($Value, $Factor, Mul, mul, MulAssign, mul_assign, $mul);
    math_op!($Value, $Value, Add, add, AddAssign, add_assign, $add);
    math_op!($Value, $Value, Sub, sub, SubAssign, sub_assign, $sub);
  };
}
pub(crate) use math;
152
/// Everything `math!` provides, plus `Neg` for `$Value`, which negates the wrapped `.0` field.
macro_rules! math_neg {
  ($Value: ident, $Factor: ident, $add: expr, $sub: expr, $mul: expr) => {
    impl Neg for $Value {
      type Output = $Value;
      fn neg(self) -> $Value {
        $Value(-self.0)
      }
    }

    math!($Value, $Factor, $add, $sub, $mul);
  };
}
165
166/// Wrapper around the dalek Scalar type.
167#[derive(Clone, Copy, PartialEq, Eq, Default, Debug, Zeroize)]
168pub struct Scalar(pub DScalar);
169deref_borrow!(Scalar, DScalar);
170constant_time!(Scalar, DScalar);
171math_neg!(Scalar, Scalar, DScalar::add, DScalar::sub, DScalar::mul);
172
173macro_rules! from_wrapper {
174  ($uint: ident) => {
175    impl From<$uint> for Scalar {
176      fn from(a: $uint) -> Scalar {
177        Scalar(DScalar::from(a))
178      }
179    }
180  };
181}
182
183from_wrapper!(u8);
184from_wrapper!(u16);
185from_wrapper!(u32);
186from_wrapper!(u64);
187from_wrapper!(u128);
188
189impl Scalar {
190  pub fn pow(&self, other: Scalar) -> Scalar {
191    let mut table = [Scalar::ONE; 16];
192    table[1] = *self;
193    for i in 2 .. 16 {
194      table[i] = table[i - 1] * self;
195    }
196
197    let mut res = Scalar::ONE;
198    let mut bits = 0;
199    for (i, mut bit) in other.to_le_bits().iter_mut().rev().enumerate() {
200      bits <<= 1;
201      let mut bit = u8_from_bool(&mut bit);
202      bits |= bit;
203      bit.zeroize();
204
205      if ((i + 1) % 4) == 0 {
206        if i != 3 {
207          for _ in 0 .. 4 {
208            res *= res;
209          }
210        }
211
212        let mut scale_by = Scalar::ONE;
213        #[allow(clippy::needless_range_loop)]
214        for i in 0 .. 16 {
215          #[allow(clippy::cast_possible_truncation)] // Safe since 0 .. 16
216          {
217            scale_by = <_>::conditional_select(&scale_by, &table[i], bits.ct_eq(&(i as u8)));
218          }
219        }
220        res *= scale_by;
221        bits = 0;
222      }
223    }
224    res
225  }
226
227  /// Perform wide reduction on a 64-byte array to create a Scalar without bias.
228  pub fn from_bytes_mod_order_wide(bytes: &[u8; 64]) -> Scalar {
229    Self(DScalar::from_bytes_mod_order_wide(bytes))
230  }
231
232  /// Derive a Scalar without bias from a digest via wide reduction.
233  pub fn from_hash<D: Digest<OutputSize = U64> + HashMarker>(hash: D) -> Scalar {
234    let mut output = [0u8; 64];
235    output.copy_from_slice(&hash.finalize());
236    let res = Scalar(DScalar::from_bytes_mod_order_wide(&output));
237    output.zeroize();
238    res
239  }
240}
241
242impl Field for Scalar {
243  const ZERO: Scalar = Scalar(DScalar::ZERO);
244  const ONE: Scalar = Scalar(DScalar::ONE);
245
246  fn random(rng: impl RngCore) -> Self {
247    Self(<DScalar as Field>::random(rng))
248  }
249
250  fn square(&self) -> Self {
251    Self(self.0.square())
252  }
253  fn double(&self) -> Self {
254    Self(self.0.double())
255  }
256  fn invert(&self) -> CtOption<Self> {
257    <DScalar as Field>::invert(&self.0).map(Self)
258  }
259
260  fn sqrt(&self) -> CtOption<Self> {
261    self.0.sqrt().map(Self)
262  }
263
264  fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
265    let (choice, res) = DScalar::sqrt_ratio(num, div);
266    (choice, Self(res))
267  }
268}
269
270impl PrimeField for Scalar {
271  type Repr = [u8; 32];
272
273  const MODULUS: &'static str = <DScalar as PrimeField>::MODULUS;
274
275  const NUM_BITS: u32 = <DScalar as PrimeField>::NUM_BITS;
276  const CAPACITY: u32 = <DScalar as PrimeField>::CAPACITY;
277
278  const TWO_INV: Scalar = Scalar(<DScalar as PrimeField>::TWO_INV);
279
280  const MULTIPLICATIVE_GENERATOR: Scalar =
281    Scalar(<DScalar as PrimeField>::MULTIPLICATIVE_GENERATOR);
282  const S: u32 = <DScalar as PrimeField>::S;
283
284  const ROOT_OF_UNITY: Scalar = Scalar(<DScalar as PrimeField>::ROOT_OF_UNITY);
285  const ROOT_OF_UNITY_INV: Scalar = Scalar(<DScalar as PrimeField>::ROOT_OF_UNITY_INV);
286
287  const DELTA: Scalar = Scalar(<DScalar as PrimeField>::DELTA);
288
289  fn from_repr(bytes: [u8; 32]) -> CtOption<Self> {
290    <DScalar as PrimeField>::from_repr(bytes).map(Scalar)
291  }
292  fn to_repr(&self) -> [u8; 32] {
293    self.0.to_repr()
294  }
295
296  fn is_odd(&self) -> Choice {
297    self.0.is_odd()
298  }
299
300  fn from_u128(num: u128) -> Self {
301    Scalar(DScalar::from_u128(num))
302  }
303}
304
305impl PrimeFieldBits for Scalar {
306  type ReprBits = [u8; 32];
307
308  fn to_le_bits(&self) -> FieldBits<Self::ReprBits> {
309    self.to_repr().into()
310  }
311
312  fn char_le_bits() -> FieldBits<Self::ReprBits> {
313    BASEPOINT_ORDER.to_bytes().into()
314  }
315}
316
317impl Sum<Scalar> for Scalar {
318  fn sum<I: Iterator<Item = Scalar>>(iter: I) -> Scalar {
319    Self(DScalar::sum(iter))
320  }
321}
322
323impl<'a> Sum<&'a Scalar> for Scalar {
324  fn sum<I: Iterator<Item = &'a Scalar>>(iter: I) -> Scalar {
325    Self(DScalar::sum(iter))
326  }
327}
328
329impl Product<Scalar> for Scalar {
330  fn product<I: Iterator<Item = Scalar>>(iter: I) -> Scalar {
331    Self(DScalar::product(iter))
332  }
333}
334
335impl<'a> Product<&'a Scalar> for Scalar {
336  fn product<I: Iterator<Item = &'a Scalar>>(iter: I) -> Scalar {
337    Self(DScalar::product(iter))
338  }
339}
340
/// Defines `$Point`, a newtype around the dalek point type `$DPoint`, along with:
///
/// - `Deref`/`Borrow` to `$DPoint`, constant-time equality/selection, and arithmetic with `Scalar`
/// - the public constant `$BASEPOINT_POINT`, wrapping dalek's constant of the same name
/// - `Sum`, `Group`, `GroupEncoding`, `PrimeGroup`, and `Hash` implementations
/// - multiplication of a `&$Table` by a `Scalar`
///
/// `$torsion_free` is called with each decompressed `$DPoint` and must return a `bool`;
/// `from_bytes` only accepts points for which it returns `true`. `$DCompressed` is the compressed
/// type whose `decompress` yields a `$DPoint`. `$BASEPOINT_TABLE` is accepted but not used by this
/// expansion.
macro_rules! dalek_group {
  (
    $Point: ident,
    $DPoint: ident,
    $torsion_free: expr,

    $Table: ident,

    $DCompressed: ident,

    $BASEPOINT_POINT: ident,
    $BASEPOINT_TABLE: ident
  ) => {
    /// Wrapper around the dalek Point type. For Ed25519, this is restricted to the prime subgroup.
    #[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroize)]
    pub struct $Point(pub $DPoint);
    deref_borrow!($Point, $DPoint);
    constant_time!($Point, $DPoint);
    math_neg!($Point, Scalar, $DPoint::add, $DPoint::sub, $DPoint::mul);

    /// The basepoint for this curve.
    pub const $BASEPOINT_POINT: $Point = $Point(constants::$BASEPOINT_POINT);

    impl Sum<$Point> for $Point {
      fn sum<I: Iterator<Item = $Point>>(iter: I) -> $Point {
        Self($DPoint::sum(iter))
      }
    }
    impl<'a> Sum<&'a $Point> for $Point {
      fn sum<I: Iterator<Item = &'a $Point>>(iter: I) -> $Point {
        Self($DPoint::sum(iter))
      }
    }

    impl Group for $Point {
      type Scalar = Scalar;

      /// Draws random 32-byte strings until one decodes, via `from_bytes`, to a non-identity point.
      ///
      /// The number of iterations depends on the bytes drawn, so this does not run in constant
      /// time.
      fn random(mut rng: impl RngCore) -> Self {
        loop {
          let mut bytes = [0; 32];
          rng.fill_bytes(&mut bytes);
          let Some(point) = Option::<$Point>::from($Point::from_bytes(&bytes)) else {
            continue;
          };
          // Ban identity, per the trait specification
          if !bool::from(point.is_identity()) {
            return point;
          }
        }
      }
      fn identity() -> Self {
        Self($DPoint::identity())
      }
      fn generator() -> Self {
        $BASEPOINT_POINT
      }
      fn is_identity(&self) -> Choice {
        self.0.ct_eq(&$DPoint::identity())
      }
      fn double(&self) -> Self {
        Self(self.0.double())
      }
    }

    impl GroupEncoding for $Point {
      type Repr = [u8; 32];

      /// Decompresses `bytes`. The result is `Some` only if decompression succeeds and the decoded
      /// point passes `$torsion_free`.
      ///
      /// NOTE(review): the result isn't compared against `to_bytes()`, so whether non-canonical
      /// encodings are rejected is left entirely to dalek's decompression. Confirm this before
      /// relying on encodings being unique.
      fn from_bytes(bytes: &Self::Repr) -> CtOption<Self> {
        let decompressed = $DCompressed(*bytes).decompress();
        // Fall back to the identity on failure. This means there's always a point to run the
        // torsion check on and to place in the `CtOption`. That `CtOption` is `None` whenever
        // decompression failed, so the fallback is never returned as a valid point.
        // TODO: `Option::unwrap_or`, and the `Option` returned by `decompress`, aren't guaranteed
        // to be constant time; `black_box` below is only a best-effort mitigation.
        let point = decompressed.unwrap_or($DPoint::identity());
        CtOption::new(
          $Point(point),
          choice(black_box(decompressed).is_some()) & choice($torsion_free(point)),
        )
      }

      /// Performs exactly the same validation as `from_bytes`.
      fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption<Self> {
        $Point::from_bytes(bytes)
      }

      fn to_bytes(&self) -> Self::Repr {
        self.0.to_bytes()
      }
    }

    impl PrimeGroup for $Point {}

    impl Mul<Scalar> for &$Table {
      type Output = $Point;
      fn mul(self, b: Scalar) -> $Point {
        $Point(&b.0 * self)
      }
    }

    // Support being used as a key in a table, by hashing the point's 32-byte encoding.
    // This is expensive for a key, due to the field operations required to encode, yet there are
    // frequently use cases for public key -> value lookups.
    // The clippy lint was renamed, so both names are allowed, and the lints about unknown or
    // renamed lints are silenced so either toolchain accepts this.
    #[allow(unknown_lints, renamed_and_removed_lints)]
    #[allow(clippy::derived_hash_with_manual_eq, clippy::derive_hash_xor_eq)]
    impl Hash for $Point {
      fn hash<H: Hasher>(&self, state: &mut H) {
        self.to_bytes().hash(state);
      }
    }
  };
}
447
448dalek_group!(
449  EdwardsPoint,
450  DEdwardsPoint,
451  |point: DEdwardsPoint| point.is_torsion_free(),
452  EdwardsBasepointTable,
453  CompressedEdwardsY,
454  ED25519_BASEPOINT_POINT,
455  ED25519_BASEPOINT_TABLE
456);
457
458impl EdwardsPoint {
459  pub fn mul_by_cofactor(&self) -> EdwardsPoint {
460    EdwardsPoint(self.0.mul_by_cofactor())
461  }
462}
463
464dalek_group!(
465  RistrettoPoint,
466  DRistrettoPoint,
467  |_| true,
468  RistrettoBasepointTable,
469  CompressedRistretto,
470  RISTRETTO_BASEPOINT_POINT,
471  RISTRETTO_BASEPOINT_TABLE
472);
473
// Run `ff_group_tests::group::test_prime_group_bits` against the Ed25519 wrapper.
#[test]
fn test_ed25519_group() {
  let mut rng = rand_core::OsRng;
  ff_group_tests::group::test_prime_group_bits::<_, EdwardsPoint>(&mut rng);
}

// Run `ff_group_tests::group::test_prime_group_bits` against the Ristretto wrapper.
#[test]
fn test_ristretto_group() {
  let mut rng = rand_core::OsRng;
  ff_group_tests::group::test_prime_group_bits::<_, RistrettoPoint>(&mut rng);
}