monero_bulletproofs/plus/weighted_inner_product.rs

1use std_shims::{vec, vec::Vec};
2
3use rand_core::{RngCore, CryptoRng};
4use zeroize::{Zeroize, ZeroizeOnDrop};
5
6use curve25519_dalek::{scalar::Scalar, edwards::EdwardsPoint};
7use curve25519_dalek::edwards::CompressedEdwardsY;
8use monero_io::decompress_point;
9use monero_primitives::{INV_EIGHT, keccak256_to_scalar};
10use crate::{
11  core::{multiexp, multiexp_vartime, challenge_products},
12  batch_verifier::BulletproofsPlusBatchVerifier,
13  plus::{ScalarVector, PointVector, GeneratorsList, BpPlusGenerators, padded_pow_of_2},
14};
15
16// Figure 1 of the Bulletproofs+ paper
17#[derive(Clone, Debug)]
18pub(crate) struct WipStatement {
19  generators: BpPlusGenerators,
20  P: EdwardsPoint,
21  y: ScalarVector,
22}
23
24impl Zeroize for WipStatement {
25  fn zeroize(&mut self) {
26    self.P.zeroize();
27    self.y.zeroize();
28  }
29}
30
31#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
32pub(crate) struct WipWitness {
33  a: ScalarVector,
34  b: ScalarVector,
35  alpha: Scalar,
36}
37
38impl WipWitness {
39  pub(crate) fn new(mut a: ScalarVector, mut b: ScalarVector, alpha: Scalar) -> Option<Self> {
40    if a.0.is_empty() || (a.len() != b.len()) {
41      return None;
42    }
43
44    // Pad to the nearest power of 2
45    let missing = padded_pow_of_2(a.len()) - a.len();
46    a.0.reserve(missing);
47    b.0.reserve(missing);
48    for _ in 0 .. missing {
49      a.0.push(Scalar::ZERO);
50      b.0.push(Scalar::ZERO);
51    }
52
53    Some(Self { a, b, alpha })
54  }
55}
56
57#[derive(Clone, PartialEq, Eq, Debug, Zeroize)]
58pub(crate) struct WipProof {
59  pub(crate) L: Vec<CompressedEdwardsY>,
60  pub(crate) R: Vec<CompressedEdwardsY>,
61  pub(crate) A: CompressedEdwardsY,
62  pub(crate) B: CompressedEdwardsY,
63  pub(crate) r_answer: Scalar,
64  pub(crate) s_answer: Scalar,
65  pub(crate) delta_answer: Scalar,
66}
67
68impl WipStatement {
69  pub(crate) fn new(generators: BpPlusGenerators, P: EdwardsPoint, y: Scalar) -> Self {
70    debug_assert_eq!(generators.len(), padded_pow_of_2(generators.len()));
71
72    // y ** n
73    let mut y_vec = ScalarVector::new(generators.len());
74    y_vec[0] = y;
75    for i in 1 .. y_vec.len() {
76      y_vec[i] = y_vec[i - 1] * y;
77    }
78
79    Self { generators, P, y: y_vec }
80  }
81
82  fn transcript_L_R(
83    transcript: &mut Scalar,
84    L: CompressedEdwardsY,
85    R: CompressedEdwardsY,
86  ) -> Scalar {
87    let e = keccak256_to_scalar(
88      [transcript.as_bytes().as_ref(), L.as_bytes().as_ref(), R.as_bytes().as_ref()].concat(),
89    );
90    *transcript = e;
91    e
92  }
93
94  fn transcript_A_B(
95    transcript: &mut Scalar,
96    A: CompressedEdwardsY,
97    B: CompressedEdwardsY,
98  ) -> Scalar {
99    let e = keccak256_to_scalar(
100      [transcript.as_bytes().as_ref(), A.as_bytes().as_ref(), B.as_bytes().as_ref()].concat(),
101    );
102    *transcript = e;
103    e
104  }
105
106  // Prover's variant of the shared code block to calculate G/H/P when n > 1
107  // Returns each permutation of G/H since the prover needs to do operation on each permutation
108  // P is dropped as it's unused in the prover's path
109  #[allow(clippy::too_many_arguments)]
110  fn next_G_H(
111    transcript: &mut Scalar,
112    mut g_bold1: PointVector,
113    mut g_bold2: PointVector,
114    mut h_bold1: PointVector,
115    mut h_bold2: PointVector,
116    L: CompressedEdwardsY,
117    R: CompressedEdwardsY,
118    y_inv_n_hat: Scalar,
119  ) -> (Scalar, Scalar, Scalar, Scalar, PointVector, PointVector) {
120    debug_assert_eq!(g_bold1.len(), g_bold2.len());
121    debug_assert_eq!(h_bold1.len(), h_bold2.len());
122    debug_assert_eq!(g_bold1.len(), h_bold1.len());
123
124    let e = Self::transcript_L_R(transcript, L, R);
125    let inv_e = e.invert();
126
127    // This vartime is safe as all of these arguments are public
128    let mut new_g_bold = Vec::with_capacity(g_bold1.len());
129    let e_y_inv = e * y_inv_n_hat;
130    for g_bold in g_bold1.0.drain(..).zip(g_bold2.0.drain(..)) {
131      new_g_bold.push(multiexp_vartime(&[(inv_e, g_bold.0), (e_y_inv, g_bold.1)]));
132    }
133
134    let mut new_h_bold = Vec::with_capacity(h_bold1.len());
135    for h_bold in h_bold1.0.drain(..).zip(h_bold2.0.drain(..)) {
136      new_h_bold.push(multiexp_vartime(&[(e, h_bold.0), (inv_e, h_bold.1)]));
137    }
138
139    let e_square = e * e;
140    let inv_e_square = inv_e * inv_e;
141
142    (e, inv_e, e_square, inv_e_square, PointVector(new_g_bold), PointVector(new_h_bold))
143  }
144
145  pub(crate) fn prove<R: RngCore + CryptoRng>(
146    self,
147    rng: &mut R,
148    mut transcript: Scalar,
149    witness: &WipWitness,
150  ) -> Option<WipProof> {
151    let WipStatement { generators, P, mut y } = self;
152    #[cfg(not(debug_assertions))]
153    let _ = P;
154
155    if generators.len() != witness.a.len() {
156      return None;
157    }
158    let (g, h) = (BpPlusGenerators::g(), BpPlusGenerators::h());
159    let mut g_bold = vec![];
160    let mut h_bold = vec![];
161    for i in 0 .. generators.len() {
162      g_bold.push(generators.generator(GeneratorsList::GBold, i));
163      h_bold.push(generators.generator(GeneratorsList::HBold, i));
164    }
165    let mut g_bold = PointVector(g_bold);
166    let mut h_bold = PointVector(h_bold);
167
168    let mut y_inv = {
169      let mut i = 1;
170      let mut to_invert = vec![];
171      while i < g_bold.len() {
172        to_invert.push(y[i - 1]);
173        i *= 2;
174      }
175      Scalar::batch_invert(&mut to_invert);
176      to_invert
177    };
178
179    // Check P has the expected relationship
180    #[cfg(debug_assertions)]
181    {
182      let mut P_terms = witness
183        .a
184        .0
185        .iter()
186        .copied()
187        .zip(g_bold.0.iter().copied())
188        .chain(witness.b.0.iter().copied().zip(h_bold.0.iter().copied()))
189        .collect::<Vec<_>>();
190      P_terms.push((witness.a.clone().weighted_inner_product(&witness.b, &y), g));
191      P_terms.push((witness.alpha, h));
192      debug_assert_eq!(multiexp(&P_terms), P);
193      P_terms.zeroize();
194    }
195
196    let mut a = witness.a.clone();
197    let mut b = witness.b.clone();
198    let mut alpha = witness.alpha;
199
200    // From here on, g_bold.len() is used as n
201    debug_assert_eq!(g_bold.len(), a.len());
202
203    let mut L_vec = vec![];
204    let mut R_vec = vec![];
205
206    // else n > 1 case from figure 1
207    while g_bold.len() > 1 {
208      let (a1, a2) = a.clone().split();
209      let (b1, b2) = b.clone().split();
210      let (g_bold1, g_bold2) = g_bold.split();
211      let (h_bold1, h_bold2) = h_bold.split();
212
213      let n_hat = g_bold1.len();
214      debug_assert_eq!(a1.len(), n_hat);
215      debug_assert_eq!(a2.len(), n_hat);
216      debug_assert_eq!(b1.len(), n_hat);
217      debug_assert_eq!(b2.len(), n_hat);
218      debug_assert_eq!(g_bold1.len(), n_hat);
219      debug_assert_eq!(g_bold2.len(), n_hat);
220      debug_assert_eq!(h_bold1.len(), n_hat);
221      debug_assert_eq!(h_bold2.len(), n_hat);
222
223      let y_n_hat = y[n_hat - 1];
224      y.0.truncate(n_hat);
225
226      let d_l = Scalar::random(&mut *rng);
227      let d_r = Scalar::random(&mut *rng);
228
229      let c_l = a1.clone().weighted_inner_product(&b2, &y);
230      let c_r = (a2.clone() * y_n_hat).weighted_inner_product(&b1, &y);
231
232      let y_inv_n_hat = y_inv.pop().unwrap();
233
234      let mut L_terms = (a1.clone() * y_inv_n_hat)
235        .0
236        .drain(..)
237        .zip(g_bold2.0.iter().copied())
238        .chain(b2.0.iter().copied().zip(h_bold1.0.iter().copied()))
239        .collect::<Vec<_>>();
240      L_terms.push((c_l, g));
241      L_terms.push((d_l, h));
242      let L = (multiexp(&L_terms) * INV_EIGHT()).compress();
243      L_vec.push(L);
244      L_terms.zeroize();
245
246      let mut R_terms = (a2.clone() * y_n_hat)
247        .0
248        .drain(..)
249        .zip(g_bold1.0.iter().copied())
250        .chain(b1.0.iter().copied().zip(h_bold2.0.iter().copied()))
251        .collect::<Vec<_>>();
252      R_terms.push((c_r, g));
253      R_terms.push((d_r, h));
254      let R = (multiexp(&R_terms) * INV_EIGHT()).compress();
255      R_vec.push(R);
256      R_terms.zeroize();
257
258      let (e, inv_e, e_square, inv_e_square);
259      (e, inv_e, e_square, inv_e_square, g_bold, h_bold) =
260        Self::next_G_H(&mut transcript, g_bold1, g_bold2, h_bold1, h_bold2, L, R, y_inv_n_hat);
261
262      a = (a1 * e) + &(a2 * (y_n_hat * inv_e));
263      b = (b1 * inv_e) + &(b2 * e);
264      alpha += (d_l * e_square) + (d_r * inv_e_square);
265
266      debug_assert_eq!(g_bold.len(), a.len());
267      debug_assert_eq!(g_bold.len(), h_bold.len());
268      debug_assert_eq!(g_bold.len(), b.len());
269    }
270
271    // n == 1 case from figure 1
272    debug_assert_eq!(g_bold.len(), 1);
273    debug_assert_eq!(h_bold.len(), 1);
274
275    debug_assert_eq!(a.len(), 1);
276    debug_assert_eq!(b.len(), 1);
277
278    let r = Scalar::random(&mut *rng);
279    let s = Scalar::random(&mut *rng);
280    let delta = Scalar::random(&mut *rng);
281    let eta = Scalar::random(&mut *rng);
282
283    let ry = r * y[0];
284
285    let mut A_terms =
286      vec![(r, g_bold[0]), (s, h_bold[0]), ((ry * b[0]) + (s * y[0] * a[0]), g), (delta, h)];
287    let A = (multiexp(&A_terms) * INV_EIGHT()).compress();
288    A_terms.zeroize();
289
290    let mut B_terms = vec![(ry * s, g), (eta, h)];
291    let B = (multiexp(&B_terms) * INV_EIGHT()).compress();
292    B_terms.zeroize();
293
294    let e = Self::transcript_A_B(&mut transcript, A, B);
295
296    let r_answer = r + (a[0] * e);
297    let s_answer = s + (b[0] * e);
298    let delta_answer = eta + (delta * e) + (alpha * (e * e));
299
300    Some(WipProof { L: L_vec, R: R_vec, A, B, r_answer, s_answer, delta_answer })
301  }
302
303  pub(crate) fn verify<R: RngCore + CryptoRng>(
304    self,
305    rng: &mut R,
306    verifier: &mut BulletproofsPlusBatchVerifier,
307    mut transcript: Scalar,
308    proof: WipProof,
309  ) -> bool {
310    let verifier_weight = Scalar::random(rng);
311
312    let WipStatement { generators, P, y } = self;
313
314    // Verify the L/R lengths
315    {
316      let mut lr_len = 0;
317      while (1 << lr_len) < generators.len() {
318        lr_len += 1;
319      }
320      if (proof.L.len() != lr_len) ||
321        (proof.R.len() != lr_len) ||
322        (generators.len() != (1 << lr_len))
323      {
324        return false;
325      }
326    }
327
328    let inv_y = {
329      let inv_y = y[0].invert();
330      let mut res = Vec::with_capacity(y.len());
331      res.push(inv_y);
332      while res.len() < y.len() {
333        res.push(inv_y * res.last().unwrap());
334      }
335      res
336    };
337
338    let mut e_is = Vec::with_capacity(proof.L.len());
339    let mut L = Vec::with_capacity(proof.L.len());
340    let mut R = Vec::with_capacity(proof.R.len());
341
342    for (L_i, R_i) in proof.L.into_iter().zip(proof.R.into_iter()) {
343      e_is.push(Self::transcript_L_R(&mut transcript, L_i, R_i));
344
345      let Some(L_i) = decompress_point(L_i).map(|p| EdwardsPoint::mul_by_cofactor(&p)) else {
346        return false;
347      };
348      let Some(R_i) = decompress_point(R_i).map(|p| EdwardsPoint::mul_by_cofactor(&p)) else {
349        return false;
350      };
351
352      L.push(L_i);
353      R.push(R_i);
354    }
355
356    let e = Self::transcript_A_B(&mut transcript, proof.A, proof.B);
357    let Some(A) = decompress_point(proof.A).map(|p| EdwardsPoint::mul_by_cofactor(&p)) else {
358      return false;
359    };
360    let Some(B) = decompress_point(proof.B).map(|p| EdwardsPoint::mul_by_cofactor(&p)) else {
361      return false;
362    };
363    let neg_e_square = verifier_weight * -(e * e);
364
365    verifier.0.other.push((neg_e_square, P));
366
367    let mut challenges = Vec::with_capacity(L.len());
368    let product_cache = {
369      let mut inv_e_is = e_is.clone();
370      Scalar::batch_invert(&mut inv_e_is);
371
372      debug_assert_eq!(e_is.len(), inv_e_is.len());
373      debug_assert_eq!(e_is.len(), L.len());
374      debug_assert_eq!(e_is.len(), R.len());
375      for ((e_i, inv_e_i), (L, R)) in
376        e_is.drain(..).zip(inv_e_is.drain(..)).zip(L.iter().zip(R.iter()))
377      {
378        debug_assert_eq!(e_i.invert(), inv_e_i);
379
380        challenges.push((e_i, inv_e_i));
381
382        let e_i_square = e_i * e_i;
383        let inv_e_i_square = inv_e_i * inv_e_i;
384        verifier.0.other.push((neg_e_square * e_i_square, *L));
385        verifier.0.other.push((neg_e_square * inv_e_i_square, *R));
386      }
387
388      challenge_products(&challenges)
389    };
390
391    while verifier.0.g_bold.len() < generators.len() {
392      verifier.0.g_bold.push(Scalar::ZERO);
393    }
394    while verifier.0.h_bold.len() < generators.len() {
395      verifier.0.h_bold.push(Scalar::ZERO);
396    }
397
398    let re = proof.r_answer * e;
399    for i in 0 .. generators.len() {
400      let mut scalar = product_cache[i] * re;
401      if i > 0 {
402        scalar *= inv_y[i - 1];
403      }
404      verifier.0.g_bold[i] += verifier_weight * scalar;
405    }
406
407    let se = proof.s_answer * e;
408    for i in 0 .. generators.len() {
409      verifier.0.h_bold[i] += verifier_weight * (se * product_cache[product_cache.len() - 1 - i]);
410    }
411
412    verifier.0.other.push((verifier_weight * -e, A));
413    verifier.0.g += verifier_weight * (proof.r_answer * y[0] * proof.s_answer);
414    verifier.0.h += verifier_weight * proof.delta_answer;
415    verifier.0.other.push((-verifier_weight, B));
416
417    true
418  }
419}