monero_bulletproofs/
scalar_vector.rs

1use core::{
2  borrow::Borrow,
3  ops::{Index, IndexMut, Add, Sub, Mul},
4};
5use std_shims::{vec, vec::Vec};
6
7use zeroize::{Zeroize, ZeroizeOnDrop};
8
9use curve25519_dalek::{scalar::Scalar, edwards::EdwardsPoint};
10
11use crate::core::multiexp;
12
13#[derive(Clone, PartialEq, Eq, Debug, Zeroize, ZeroizeOnDrop)]
14pub(crate) struct ScalarVector(pub(crate) Vec<Scalar>);
15
16impl Index<usize> for ScalarVector {
17  type Output = Scalar;
18  fn index(&self, index: usize) -> &Scalar {
19    &self.0[index]
20  }
21}
22impl IndexMut<usize> for ScalarVector {
23  fn index_mut(&mut self, index: usize) -> &mut Scalar {
24    &mut self.0[index]
25  }
26}
27
28impl<S: Borrow<Scalar>> Add<S> for ScalarVector {
29  type Output = ScalarVector;
30  fn add(mut self, scalar: S) -> ScalarVector {
31    for s in &mut self.0 {
32      *s += scalar.borrow();
33    }
34    self
35  }
36}
37impl<S: Borrow<Scalar>> Sub<S> for ScalarVector {
38  type Output = ScalarVector;
39  fn sub(mut self, scalar: S) -> ScalarVector {
40    for s in &mut self.0 {
41      *s -= scalar.borrow();
42    }
43    self
44  }
45}
46impl<S: Borrow<Scalar>> Mul<S> for ScalarVector {
47  type Output = ScalarVector;
48  fn mul(mut self, scalar: S) -> ScalarVector {
49    for s in &mut self.0 {
50      *s *= scalar.borrow();
51    }
52    self
53  }
54}
55
56impl Add<&ScalarVector> for ScalarVector {
57  type Output = ScalarVector;
58  fn add(mut self, other: &ScalarVector) -> ScalarVector {
59    debug_assert_eq!(self.len(), other.len());
60    for (s, o) in self.0.iter_mut().zip(other.0.iter()) {
61      *s += o;
62    }
63    self
64  }
65}
66impl Sub<&ScalarVector> for ScalarVector {
67  type Output = ScalarVector;
68  fn sub(mut self, other: &ScalarVector) -> ScalarVector {
69    debug_assert_eq!(self.len(), other.len());
70    for (s, o) in self.0.iter_mut().zip(other.0.iter()) {
71      *s -= o;
72    }
73    self
74  }
75}
76impl Mul<&ScalarVector> for ScalarVector {
77  type Output = ScalarVector;
78  fn mul(mut self, other: &ScalarVector) -> ScalarVector {
79    debug_assert_eq!(self.len(), other.len());
80    for (s, o) in self.0.iter_mut().zip(other.0.iter()) {
81      *s *= o;
82    }
83    self
84  }
85}
86
87impl Mul<&[EdwardsPoint]> for &ScalarVector {
88  type Output = EdwardsPoint;
89  fn mul(self, b: &[EdwardsPoint]) -> EdwardsPoint {
90    debug_assert_eq!(self.len(), b.len());
91    let mut multiexp_args = self.0.iter().copied().zip(b.iter().copied()).collect::<Vec<_>>();
92    let res = multiexp(&multiexp_args);
93    multiexp_args.zeroize();
94    res
95  }
96}
97
98impl ScalarVector {
99  pub(crate) fn new(len: usize) -> Self {
100    ScalarVector(vec![Scalar::ZERO; len])
101  }
102
103  pub(crate) fn powers(x: Scalar, len: usize) -> Self {
104    debug_assert!(len != 0);
105
106    let mut res = Vec::with_capacity(len);
107    res.push(Scalar::ONE);
108    res.push(x);
109    for i in 2 .. len {
110      res.push(res[i - 1] * x);
111    }
112    res.truncate(len);
113    ScalarVector(res)
114  }
115
116  pub(crate) fn len(&self) -> usize {
117    self.0.len()
118  }
119
120  pub(crate) fn sum(mut self) -> Scalar {
121    self.0.drain(..).sum()
122  }
123
124  pub(crate) fn inner_product(self, vector: &Self) -> Scalar {
125    (self * vector).sum()
126  }
127
128  pub(crate) fn weighted_inner_product(self, vector: &Self, y: &Self) -> Scalar {
129    (self * vector * y).sum()
130  }
131
132  pub(crate) fn split(mut self) -> (Self, Self) {
133    debug_assert!(self.len() > 1);
134    let r = self.0.split_off(self.0.len() / 2);
135    debug_assert_eq!(self.len(), r.len());
136    (self, ScalarVector(r))
137  }
138}