// monero_bulletproofs/plus/weighted_inner_product.rs

1use std_shims::{vec, vec::Vec};
2
3use rand_core::{RngCore, CryptoRng};
4use zeroize::{Zeroize, ZeroizeOnDrop};
5
6use curve25519_dalek::{scalar::Scalar, edwards::EdwardsPoint};
7
8use monero_io::CompressedPoint;
9use monero_primitives::{INV_EIGHT, keccak256_to_scalar};
10use crate::{
11  core::{multiexp, multiexp_vartime, challenge_products},
12  batch_verifier::BulletproofsPlusBatchVerifier,
13  plus::{ScalarVector, PointVector, GeneratorsList, BpPlusGenerators, padded_pow_of_2},
14};
15
16// Figure 1 of the Bulletproofs+ paper
17#[derive(Clone, Debug)]
18pub(crate) struct WipStatement {
19  generators: BpPlusGenerators,
20  P: EdwardsPoint,
21  y: ScalarVector,
22}
23
24impl Zeroize for WipStatement {
25  fn zeroize(&mut self) {
26    self.P.zeroize();
27    self.y.zeroize();
28  }
29}
30
31#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
32pub(crate) struct WipWitness {
33  a: ScalarVector,
34  b: ScalarVector,
35  alpha: Scalar,
36}
37
38impl WipWitness {
39  pub(crate) fn new(mut a: ScalarVector, mut b: ScalarVector, alpha: Scalar) -> Option<Self> {
40    if a.0.is_empty() || (a.len() != b.len()) {
41      return None;
42    }
43
44    // Pad to the nearest power of 2
45    let missing = padded_pow_of_2(a.len()) - a.len();
46    a.0.reserve(missing);
47    b.0.reserve(missing);
48    for _ in 0 .. missing {
49      a.0.push(Scalar::ZERO);
50      b.0.push(Scalar::ZERO);
51    }
52
53    Some(Self { a, b, alpha })
54  }
55}
56
57#[derive(Clone, PartialEq, Eq, Debug, Zeroize)]
58pub(crate) struct WipProof {
59  pub(crate) L: Vec<CompressedPoint>,
60  pub(crate) R: Vec<CompressedPoint>,
61  pub(crate) A: CompressedPoint,
62  pub(crate) B: CompressedPoint,
63  pub(crate) r_answer: Scalar,
64  pub(crate) s_answer: Scalar,
65  pub(crate) delta_answer: Scalar,
66}
67
68impl WipStatement {
69  pub(crate) fn new(generators: BpPlusGenerators, P: EdwardsPoint, y: Scalar) -> Self {
70    debug_assert_eq!(generators.len(), padded_pow_of_2(generators.len()));
71
72    // y ** n
73    let mut y_vec = ScalarVector::new(generators.len());
74    y_vec[0] = y;
75    for i in 1 .. y_vec.len() {
76      y_vec[i] = y_vec[i - 1] * y;
77    }
78
79    Self { generators, P, y: y_vec }
80  }
81
82  fn transcript_L_R(transcript: &mut Scalar, L: CompressedPoint, R: CompressedPoint) -> Scalar {
83    let e = keccak256_to_scalar([transcript.to_bytes(), L.to_bytes(), R.to_bytes()].concat());
84    *transcript = e;
85    e
86  }
87
88  fn transcript_A_B(transcript: &mut Scalar, A: CompressedPoint, B: CompressedPoint) -> Scalar {
89    let e = keccak256_to_scalar([transcript.to_bytes(), A.to_bytes(), B.to_bytes()].concat());
90    *transcript = e;
91    e
92  }
93
94  // Prover's variant of the shared code block to calculate G/H/P when n > 1
95  // Returns each permutation of G/H since the prover needs to do operation on each permutation
96  // P is dropped as it's unused in the prover's path
97  #[allow(clippy::too_many_arguments)]
98  fn next_G_H(
99    transcript: &mut Scalar,
100    mut g_bold1: PointVector,
101    mut g_bold2: PointVector,
102    mut h_bold1: PointVector,
103    mut h_bold2: PointVector,
104    L: CompressedPoint,
105    R: CompressedPoint,
106    y_inv_n_hat: Scalar,
107  ) -> (Scalar, Scalar, Scalar, Scalar, PointVector, PointVector) {
108    debug_assert_eq!(g_bold1.len(), g_bold2.len());
109    debug_assert_eq!(h_bold1.len(), h_bold2.len());
110    debug_assert_eq!(g_bold1.len(), h_bold1.len());
111
112    let e = Self::transcript_L_R(transcript, L, R);
113    let inv_e = e.invert();
114
115    // This vartime is safe as all of these arguments are public
116    let mut new_g_bold = Vec::with_capacity(g_bold1.len());
117    let e_y_inv = e * y_inv_n_hat;
118    for g_bold in g_bold1.0.drain(..).zip(g_bold2.0.drain(..)) {
119      new_g_bold.push(multiexp_vartime(&[(inv_e, g_bold.0), (e_y_inv, g_bold.1)]));
120    }
121
122    let mut new_h_bold = Vec::with_capacity(h_bold1.len());
123    for h_bold in h_bold1.0.drain(..).zip(h_bold2.0.drain(..)) {
124      new_h_bold.push(multiexp_vartime(&[(e, h_bold.0), (inv_e, h_bold.1)]));
125    }
126
127    let e_square = e * e;
128    let inv_e_square = inv_e * inv_e;
129
130    (e, inv_e, e_square, inv_e_square, PointVector(new_g_bold), PointVector(new_h_bold))
131  }
132
133  pub(crate) fn prove<R: RngCore + CryptoRng>(
134    self,
135    rng: &mut R,
136    mut transcript: Scalar,
137    witness: &WipWitness,
138  ) -> Option<WipProof> {
139    let WipStatement { generators, P, mut y } = self;
140    #[cfg(not(debug_assertions))]
141    let _ = P;
142
143    if generators.len() != witness.a.len() {
144      return None;
145    }
146    let (g, h) = (BpPlusGenerators::g(), BpPlusGenerators::h());
147    let mut g_bold = vec![];
148    let mut h_bold = vec![];
149    for i in 0 .. generators.len() {
150      g_bold.push(generators.generator(GeneratorsList::GBold, i));
151      h_bold.push(generators.generator(GeneratorsList::HBold, i));
152    }
153    let mut g_bold = PointVector(g_bold);
154    let mut h_bold = PointVector(h_bold);
155
156    let mut y_inv = {
157      let mut i = 1;
158      let mut to_invert = vec![];
159      while i < g_bold.len() {
160        to_invert.push(y[i - 1]);
161        i *= 2;
162      }
163      Scalar::batch_invert(&mut to_invert);
164      to_invert
165    };
166
167    // Check P has the expected relationship
168    #[cfg(debug_assertions)]
169    {
170      let mut P_terms = witness
171        .a
172        .0
173        .iter()
174        .copied()
175        .zip(g_bold.0.iter().copied())
176        .chain(witness.b.0.iter().copied().zip(h_bold.0.iter().copied()))
177        .collect::<Vec<_>>();
178      P_terms.push((witness.a.clone().weighted_inner_product(&witness.b, &y), g));
179      P_terms.push((witness.alpha, h));
180      debug_assert_eq!(multiexp(&P_terms), P);
181      P_terms.zeroize();
182    }
183
184    let mut a = witness.a.clone();
185    let mut b = witness.b.clone();
186    let mut alpha = witness.alpha;
187
188    // From here on, g_bold.len() is used as n
189    debug_assert_eq!(g_bold.len(), a.len());
190
191    let mut L_vec = vec![];
192    let mut R_vec = vec![];
193
194    // else n > 1 case from figure 1
195    while g_bold.len() > 1 {
196      let (a1, a2) = a.clone().split();
197      let (b1, b2) = b.clone().split();
198      let (g_bold1, g_bold2) = g_bold.split();
199      let (h_bold1, h_bold2) = h_bold.split();
200
201      let n_hat = g_bold1.len();
202      debug_assert_eq!(a1.len(), n_hat);
203      debug_assert_eq!(a2.len(), n_hat);
204      debug_assert_eq!(b1.len(), n_hat);
205      debug_assert_eq!(b2.len(), n_hat);
206      debug_assert_eq!(g_bold1.len(), n_hat);
207      debug_assert_eq!(g_bold2.len(), n_hat);
208      debug_assert_eq!(h_bold1.len(), n_hat);
209      debug_assert_eq!(h_bold2.len(), n_hat);
210
211      let y_n_hat = y[n_hat - 1];
212      y.0.truncate(n_hat);
213
214      let d_l = Scalar::random(&mut *rng);
215      let d_r = Scalar::random(&mut *rng);
216
217      let c_l = a1.clone().weighted_inner_product(&b2, &y);
218      let c_r = (a2.clone() * y_n_hat).weighted_inner_product(&b1, &y);
219
220      let y_inv_n_hat = y_inv
221        .pop()
222        .expect("couldn't pop y_inv despite y_inv being of same length as times iterated");
223
224      let mut L_terms = (a1.clone() * y_inv_n_hat)
225        .0
226        .drain(..)
227        .zip(g_bold2.0.iter().copied())
228        .chain(b2.0.iter().copied().zip(h_bold1.0.iter().copied()))
229        .collect::<Vec<_>>();
230      L_terms.push((c_l, g));
231      L_terms.push((d_l, h));
232      let L = CompressedPoint::from((multiexp(&L_terms) * INV_EIGHT()).compress());
233      L_vec.push(L);
234      L_terms.zeroize();
235
236      let mut R_terms = (a2.clone() * y_n_hat)
237        .0
238        .drain(..)
239        .zip(g_bold1.0.iter().copied())
240        .chain(b1.0.iter().copied().zip(h_bold2.0.iter().copied()))
241        .collect::<Vec<_>>();
242      R_terms.push((c_r, g));
243      R_terms.push((d_r, h));
244      let R = CompressedPoint::from((multiexp(&R_terms) * INV_EIGHT()).compress());
245      R_vec.push(R);
246      R_terms.zeroize();
247
248      let (e, inv_e, e_square, inv_e_square);
249      (e, inv_e, e_square, inv_e_square, g_bold, h_bold) =
250        Self::next_G_H(&mut transcript, g_bold1, g_bold2, h_bold1, h_bold2, L, R, y_inv_n_hat);
251
252      a = (a1 * e) + &(a2 * (y_n_hat * inv_e));
253      b = (b1 * inv_e) + &(b2 * e);
254      alpha += (d_l * e_square) + (d_r * inv_e_square);
255
256      debug_assert_eq!(g_bold.len(), a.len());
257      debug_assert_eq!(g_bold.len(), h_bold.len());
258      debug_assert_eq!(g_bold.len(), b.len());
259    }
260
261    // n == 1 case from figure 1
262    debug_assert_eq!(g_bold.len(), 1);
263    debug_assert_eq!(h_bold.len(), 1);
264
265    debug_assert_eq!(a.len(), 1);
266    debug_assert_eq!(b.len(), 1);
267
268    let r = Scalar::random(&mut *rng);
269    let s = Scalar::random(&mut *rng);
270    let delta = Scalar::random(&mut *rng);
271    let eta = Scalar::random(&mut *rng);
272
273    let ry = r * y[0];
274
275    let mut A_terms =
276      vec![(r, g_bold[0]), (s, h_bold[0]), ((ry * b[0]) + (s * y[0] * a[0]), g), (delta, h)];
277    let A = CompressedPoint::from((multiexp(&A_terms) * INV_EIGHT()).compress());
278    A_terms.zeroize();
279
280    let mut B_terms = vec![(ry * s, g), (eta, h)];
281    let B = CompressedPoint::from((multiexp(&B_terms) * INV_EIGHT()).compress());
282    B_terms.zeroize();
283
284    let e = Self::transcript_A_B(&mut transcript, A, B);
285
286    let r_answer = r + (a[0] * e);
287    let s_answer = s + (b[0] * e);
288    let delta_answer = eta + (delta * e) + (alpha * (e * e));
289
290    Some(WipProof { L: L_vec, R: R_vec, A, B, r_answer, s_answer, delta_answer })
291  }
292
293  pub(crate) fn verify<R: RngCore + CryptoRng>(
294    self,
295    rng: &mut R,
296    verifier: &mut BulletproofsPlusBatchVerifier,
297    mut transcript: Scalar,
298    WipProof { L, R, A, B, r_answer, s_answer, delta_answer }: WipProof,
299  ) -> bool {
300    let verifier_weight = Scalar::random(rng);
301
302    let WipStatement { generators, P, y } = self;
303
304    // Verify the L/R lengths
305    {
306      let mut lr_len = 0;
307      while (1 << lr_len) < generators.len() {
308        lr_len += 1;
309      }
310      if (L.len() != lr_len) || (R.len() != lr_len) || (generators.len() != (1 << lr_len)) {
311        return false;
312      }
313    }
314
315    let inv_y = {
316      let inv_y = y[0].invert();
317      let mut res = Vec::with_capacity(y.len());
318      res.push(inv_y);
319      while res.len() < y.len() {
320        res.push(
321          inv_y * res.last().expect("couldn't get last inv_y despite inv_y always being non-empty"),
322        );
323      }
324      res
325    };
326
327    let mut e_is = Vec::with_capacity(L.len());
328    let mut L_decomp = Vec::with_capacity(L.len());
329    let mut R_decomp = Vec::with_capacity(R.len());
330
331    let decomp_mul_cofactor =
332      |p| CompressedPoint::decompress(&p).map(|p| EdwardsPoint::mul_by_cofactor(&p));
333
334    for (L_i, R_i) in L.into_iter().zip(R.into_iter()) {
335      e_is.push(Self::transcript_L_R(&mut transcript, L_i, R_i));
336
337      let (Some(L_i), Some(R_i)) = (decomp_mul_cofactor(L_i), decomp_mul_cofactor(R_i)) else {
338        return false;
339      };
340
341      L_decomp.push(L_i);
342      R_decomp.push(R_i);
343    }
344
345    let L = L_decomp;
346    let R = R_decomp;
347
348    let e = Self::transcript_A_B(&mut transcript, A, B);
349
350    let (Some(A), Some(B)) = (decomp_mul_cofactor(A), decomp_mul_cofactor(B)) else {
351      return false;
352    };
353
354    let neg_e_square = verifier_weight * -(e * e);
355
356    verifier.0.other.push((neg_e_square, P));
357
358    let mut challenges = Vec::with_capacity(L.len());
359    let product_cache = {
360      let mut inv_e_is = e_is.clone();
361      Scalar::batch_invert(&mut inv_e_is);
362
363      debug_assert_eq!(e_is.len(), inv_e_is.len());
364      debug_assert_eq!(e_is.len(), L.len());
365      debug_assert_eq!(e_is.len(), R.len());
366      for ((e_i, inv_e_i), (L, R)) in
367        e_is.drain(..).zip(inv_e_is.drain(..)).zip(L.iter().zip(R.iter()))
368      {
369        debug_assert_eq!(e_i.invert(), inv_e_i);
370
371        challenges.push((e_i, inv_e_i));
372
373        let e_i_square = e_i * e_i;
374        let inv_e_i_square = inv_e_i * inv_e_i;
375        verifier.0.other.push((neg_e_square * e_i_square, *L));
376        verifier.0.other.push((neg_e_square * inv_e_i_square, *R));
377      }
378
379      challenge_products(&challenges)
380    };
381
382    while verifier.0.g_bold.len() < generators.len() {
383      verifier.0.g_bold.push(Scalar::ZERO);
384    }
385    while verifier.0.h_bold.len() < generators.len() {
386      verifier.0.h_bold.push(Scalar::ZERO);
387    }
388
389    let re = r_answer * e;
390    for i in 0 .. generators.len() {
391      let mut scalar = product_cache[i] * re;
392      if i > 0 {
393        scalar *= inv_y[i - 1];
394      }
395      verifier.0.g_bold[i] += verifier_weight * scalar;
396    }
397
398    let se = s_answer * e;
399    for i in 0 .. generators.len() {
400      verifier.0.h_bold[i] += verifier_weight * (se * product_cache[product_cache.len() - 1 - i]);
401    }
402
403    verifier.0.other.push((verifier_weight * -e, A));
404    verifier.0.g += verifier_weight * (r_answer * y[0] * s_answer);
405    verifier.0.h += verifier_weight * delta_answer;
406    verifier.0.other.push((-verifier_weight, B));
407
408    true
409  }
410}