monero_bulletproofs/original/inner_product.rs

1use std_shims::{vec, vec::Vec};
2
3use zeroize::Zeroize;
4
5use curve25519_dalek::{Scalar, EdwardsPoint};
6use curve25519_dalek::edwards::CompressedEdwardsY;
7use monero_generators::H;
8use monero_io::decompress_point;
9use monero_primitives::{INV_EIGHT, keccak256_to_scalar};
10use crate::{
11  core::{multiexp_vartime, challenge_products},
12  scalar_vector::ScalarVector,
13  point_vector::PointVector,
14  BulletproofsBatchVerifier,
15};
16
/// An error raised while proving or verifying an Inner-Product statement.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq)]
pub(crate) enum IpError {
  /// The amount of generators, weights, or L terms doesn't match what the statement requires.
  IncorrectAmountOfGenerators,
  /// The proof's L and R vectors differ in length.
  ///
  /// NOTE(review): the verifier currently also reports undecodable L/R points with this variant.
  DifferingLrLengths,
}
23
24/// The Bulletproofs Inner-Product statement.
25///
26/// This is for usage with Protocol 2 from the Bulletproofs paper.
27#[derive(Clone, Debug)]
28pub(crate) struct IpStatement {
29  // Weights for h_bold
30  h_bold_weights: ScalarVector,
31  // u as the discrete logarithm of G
32  u: Scalar,
33}
34
35/// The witness for the Bulletproofs Inner-Product statement.
36#[derive(Clone, Debug)]
37pub(crate) struct IpWitness {
38  // a
39  a: ScalarVector,
40  // b
41  b: ScalarVector,
42}
43
44impl IpWitness {
45  /// Construct a new witness for an Inner-Product statement.
46  ///
47  /// This functions return None if the lengths of a, b are mismatched, not a power of two, or are
48  /// empty.
49  pub(crate) fn new(a: ScalarVector, b: ScalarVector) -> Option<Self> {
50    if a.0.is_empty() || (a.len() != b.len()) {
51      None?;
52    }
53
54    let mut power_of_2 = 1;
55    while power_of_2 < a.len() {
56      power_of_2 <<= 1;
57    }
58    if power_of_2 != a.len() {
59      None?;
60    }
61
62    Some(Self { a, b })
63  }
64}
65
66/// A proof for the Bulletproofs Inner-Product statement.
67#[derive(Clone, PartialEq, Eq, Debug, Zeroize)]
68pub(crate) struct IpProof {
69  pub(crate) L: Vec<CompressedEdwardsY>,
70  pub(crate) R: Vec<CompressedEdwardsY>,
71  pub(crate) a: Scalar,
72  pub(crate) b: Scalar,
73}
74
75impl IpStatement {
76  /// Create a new Inner-Product statement which won't transcript P.
77  ///
78  /// This MUST only be called when P is deterministic to already transcripted elements.
79  pub(crate) fn new_without_P_transcript(h_bold_weights: ScalarVector, u: Scalar) -> Self {
80    Self { h_bold_weights, u }
81  }
82
83  // Transcript a round of the protocol
84  fn transcript_L_R(transcript: Scalar, L: CompressedEdwardsY, R: CompressedEdwardsY) -> Scalar {
85    let mut transcript = transcript.to_bytes().to_vec();
86    transcript.extend_from_slice(L.as_bytes());
87    transcript.extend_from_slice(R.as_bytes());
88    keccak256_to_scalar(transcript)
89  }
90
91  /// Prove for this Inner-Product statement.
92  ///
93  /// Returns an error if this statement couldn't be proven for (such as if the witness isn't
94  /// consistent).
95  pub(crate) fn prove(
96    self,
97    mut transcript: Scalar,
98    witness: IpWitness,
99  ) -> Result<IpProof, IpError> {
100    let generators = &crate::original::GENERATORS;
101    let g_bold_slice = &generators.G[.. witness.a.len()];
102    let h_bold_slice = &generators.H[.. witness.a.len()];
103
104    let (mut g_bold, mut h_bold, u, mut a, mut b) = {
105      let IpStatement { h_bold_weights, u } = self;
106      let u = *H * u;
107
108      // Ensure we have the exact amount of weights
109      if h_bold_weights.len() != g_bold_slice.len() {
110        Err(IpError::IncorrectAmountOfGenerators)?;
111      }
112      // Acquire a local copy of the generators
113      let g_bold = PointVector(g_bold_slice.to_vec());
114      let h_bold = PointVector(h_bold_slice.to_vec()).mul_vec(&h_bold_weights);
115
116      let IpWitness { a, b } = witness;
117
118      (g_bold, h_bold, u, a, b)
119    };
120
121    let mut L_vec = vec![];
122    let mut R_vec = vec![];
123
124    // `else: (n > 1)` case, lines 18-35 of the Bulletproofs paper
125    // This interprets `g_bold.len()` as `n`
126    while g_bold.len() > 1 {
127      // Split a, b, g_bold, h_bold as needed for lines 20-24
128      let (a1, a2) = a.clone().split();
129      let (b1, b2) = b.clone().split();
130
131      let (g_bold1, g_bold2) = g_bold.split();
132      let (h_bold1, h_bold2) = h_bold.split();
133
134      let n_hat = g_bold1.len();
135
136      // Sanity
137      debug_assert_eq!(a1.len(), n_hat);
138      debug_assert_eq!(a2.len(), n_hat);
139      debug_assert_eq!(b1.len(), n_hat);
140      debug_assert_eq!(b2.len(), n_hat);
141      debug_assert_eq!(g_bold1.len(), n_hat);
142      debug_assert_eq!(g_bold2.len(), n_hat);
143      debug_assert_eq!(h_bold1.len(), n_hat);
144      debug_assert_eq!(h_bold2.len(), n_hat);
145
146      // cl, cr, lines 21-22
147      let cl = a1.clone().inner_product(&b2);
148      let cr = a2.clone().inner_product(&b1);
149
150      let L = {
151        let mut L_terms = Vec::with_capacity(1 + (2 * g_bold1.len()));
152        for (a, g) in a1.0.iter().zip(g_bold2.0.iter()) {
153          L_terms.push((*a, *g));
154        }
155        for (b, h) in b2.0.iter().zip(h_bold1.0.iter()) {
156          L_terms.push((*b, *h));
157        }
158        L_terms.push((cl, u));
159        // Uses vartime since this isn't a ZK proof
160        multiexp_vartime(&L_terms)
161      };
162      L_vec.push((L * INV_EIGHT()).compress());
163
164      let R = {
165        let mut R_terms = Vec::with_capacity(1 + (2 * g_bold1.len()));
166        for (a, g) in a2.0.iter().zip(g_bold1.0.iter()) {
167          R_terms.push((*a, *g));
168        }
169        for (b, h) in b1.0.iter().zip(h_bold2.0.iter()) {
170          R_terms.push((*b, *h));
171        }
172        R_terms.push((cr, u));
173        multiexp_vartime(&R_terms)
174      };
175      R_vec.push((R * INV_EIGHT()).compress());
176
177      // Now that we've calculate L, R, transcript them to receive x (26-27)
178      transcript = Self::transcript_L_R(transcript, *L_vec.last().unwrap(), *R_vec.last().unwrap());
179      let x = transcript;
180      let x_inv = x.invert();
181
182      // The prover and verifier now calculate the following (28-31)
183      g_bold = PointVector(Vec::with_capacity(g_bold1.len()));
184      for (a, b) in g_bold1.0.into_iter().zip(g_bold2.0.into_iter()) {
185        g_bold.0.push(multiexp_vartime(&[(x_inv, a), (x, b)]));
186      }
187      h_bold = PointVector(Vec::with_capacity(h_bold1.len()));
188      for (a, b) in h_bold1.0.into_iter().zip(h_bold2.0.into_iter()) {
189        h_bold.0.push(multiexp_vartime(&[(x, a), (x_inv, b)]));
190      }
191
192      // 32-34
193      a = (a1 * x) + &(a2 * x_inv);
194      b = (b1 * x_inv) + &(b2 * x);
195    }
196
197    // `if n = 1` case from line 14-17
198
199    // Sanity
200    debug_assert_eq!(g_bold.len(), 1);
201    debug_assert_eq!(h_bold.len(), 1);
202    debug_assert_eq!(a.len(), 1);
203    debug_assert_eq!(b.len(), 1);
204
205    // We simply send a/b
206    Ok(IpProof { L: L_vec, R: R_vec, a: a[0], b: b[0] })
207  }
208
209  /// Queue an Inner-Product proof for batch verification.
210  ///
211  /// This will return Err if there is an error. This will return Ok if the proof was successfully
212  /// queued for batch verification. The caller is required to verify the batch in order to ensure
213  /// the proof is actually correct.
214  pub(crate) fn verify(
215    self,
216    verifier: &mut BulletproofsBatchVerifier,
217    ip_rows: usize,
218    mut transcript: Scalar,
219    verifier_weight: Scalar,
220    proof: IpProof,
221  ) -> Result<(), IpError> {
222    let generators = &crate::original::GENERATORS;
223    let g_bold_slice = &generators.G[.. ip_rows];
224    let h_bold_slice = &generators.H[.. ip_rows];
225
226    let IpStatement { h_bold_weights, u } = self;
227
228    // Verify the L/R lengths
229    {
230      // Calculate the discrete log w.r.t. 2 for the amount of generators present
231      let mut lr_len = 0;
232      while (1 << lr_len) < g_bold_slice.len() {
233        lr_len += 1;
234      }
235
236      // This proof has less/more terms than the passed in generators are for
237      if proof.L.len() != lr_len {
238        Err(IpError::IncorrectAmountOfGenerators)?;
239      }
240      if proof.L.len() != proof.R.len() {
241        Err(IpError::DifferingLrLengths)?;
242      }
243    }
244
245    // Again, we start with the `else: (n > 1)` case
246
247    // We need x, x_inv per lines 25-27 for lines 28-31
248    let mut xs = Vec::with_capacity(proof.L.len());
249    for (L, R) in proof.L.iter().zip(proof.R.iter()) {
250      transcript = Self::transcript_L_R(transcript, *L, *R);
251      xs.push(transcript);
252    }
253
254    // We calculate their inverse in batch
255    let mut x_invs = xs.clone();
256    Scalar::batch_invert(&mut x_invs);
257
258    // Now, with x and x_inv, we need to calculate g_bold', h_bold', P'
259    //
260    // For the sake of performance, we solely want to calculate all of these in terms of scalings
261    // for g_bold, h_bold, P, and don't want to actually perform intermediary scalings of the
262    // points
263    //
264    // L and R are easy, as it's simply x**2, x**-2
265    //
266    // For the series of g_bold, h_bold, we use the `challenge_products` function
267    // For how that works, please see its own documentation
268    let product_cache = {
269      let mut challenges = Vec::with_capacity(proof.L.len());
270
271      let x_iter = xs.into_iter().zip(x_invs);
272      let lr_iter = proof.L.into_iter().zip(proof.R);
273      for ((x, x_inv), (L, R)) in x_iter.zip(lr_iter) {
274        challenges.push((x, x_inv));
275
276        // TODO: create proper error
277        let L = decompress_point(L)
278          .map(|p| EdwardsPoint::mul_by_cofactor(&p))
279          .ok_or(IpError::DifferingLrLengths)?;
280        let R = decompress_point(R)
281          .map(|p| EdwardsPoint::mul_by_cofactor(&p))
282          .ok_or(IpError::DifferingLrLengths)?;
283
284        verifier.0.other.push((verifier_weight * (x * x), L));
285        verifier.0.other.push((verifier_weight * (x_inv * x_inv), R));
286      }
287
288      challenge_products(&challenges)
289    };
290
291    // And now for the `if n = 1` case
292    let c = proof.a * proof.b;
293
294    // The multiexp of these terms equate to the final permutation of P
295    // We now add terms for a * g_bold' + b * h_bold' b + c * u, with the scalars negative such
296    // that the terms sum to 0 for an honest prover
297
298    // The g_bold * a term case from line 16
299    #[allow(clippy::needless_range_loop)]
300    for i in 0 .. g_bold_slice.len() {
301      verifier.0.g_bold[i] -= verifier_weight * product_cache[i] * proof.a;
302    }
303    // The h_bold * b term case from line 16
304    for i in 0 .. h_bold_slice.len() {
305      verifier.0.h_bold[i] -=
306        verifier_weight * product_cache[product_cache.len() - 1 - i] * proof.b * h_bold_weights[i];
307    }
308    // The c * u term case from line 16
309    verifier.0.h -= verifier_weight * c * u;
310
311    Ok(())
312  }
313}