monero_bulletproofs/original/
inner_product.rs

1use std_shims::{vec, vec::Vec};
2
3use zeroize::Zeroize;
4
5use curve25519_dalek::{Scalar, EdwardsPoint};
6use monero_generators::H;
7use monero_io::{CompressedPoint};
8use monero_primitives::{INV_EIGHT, keccak256_to_scalar};
9use crate::{
10  core::{multiexp_vartime, challenge_products},
11  scalar_vector::ScalarVector,
12  point_vector::PointVector,
13  BulletproofsBatchVerifier,
14};
15
/// Errors which may occur while proving or verifying an Inner-Product statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum IpError {
  /// The amount of `h_bold` weights, or of `L` terms within a proof, doesn't match the amount of
  /// generators in use.
  IncorrectAmountOfGenerators,
  /// A proof's `L` and `R` vectors have different lengths.
  DifferingLrLengths,
  /// An `L` or `R` point within a proof failed to decompress.
  InvalidPoint,
}
23
24/// The Bulletproofs Inner-Product statement.
25///
26/// This is for usage with Protocol 2 from the Bulletproofs paper.
27#[derive(Clone, Debug)]
28pub(crate) struct IpStatement {
29  // Weights for h_bold
30  h_bold_weights: ScalarVector,
31  // u as the discrete logarithm of G
32  u: Scalar,
33}
34
35/// The witness for the Bulletproofs Inner-Product statement.
36#[derive(Clone, Debug)]
37pub(crate) struct IpWitness {
38  // a
39  a: ScalarVector,
40  // b
41  b: ScalarVector,
42}
43
44impl IpWitness {
45  /// Construct a new witness for an Inner-Product statement.
46  ///
47  /// This functions return None if the lengths of a, b are mismatched, not a power of two, or are
48  /// empty.
49  pub(crate) fn new(a: ScalarVector, b: ScalarVector) -> Option<Self> {
50    if a.0.is_empty() || (a.len() != b.len()) {
51      None?;
52    }
53
54    let mut power_of_2 = 1;
55    while power_of_2 < a.len() {
56      power_of_2 <<= 1;
57    }
58    if power_of_2 != a.len() {
59      None?;
60    }
61
62    Some(Self { a, b })
63  }
64}
65
66/// A proof for the Bulletproofs Inner-Product statement.
67#[derive(Clone, PartialEq, Eq, Debug, Zeroize)]
68pub(crate) struct IpProof {
69  pub(crate) L: Vec<CompressedPoint>,
70  pub(crate) R: Vec<CompressedPoint>,
71  pub(crate) a: Scalar,
72  pub(crate) b: Scalar,
73}
74
75impl IpStatement {
76  /// Create a new Inner-Product statement which won't transcript P.
77  ///
78  /// This MUST only be called when P is deterministic to already transcripted elements.
79  pub(crate) fn new_without_P_transcript(h_bold_weights: ScalarVector, u: Scalar) -> Self {
80    Self { h_bold_weights, u }
81  }
82
83  // Transcript a round of the protocol
84  fn transcript_L_R(transcript: Scalar, L: CompressedPoint, R: CompressedPoint) -> Scalar {
85    let mut transcript = transcript.to_bytes().to_vec();
86    transcript.extend_from_slice(L.as_bytes());
87    transcript.extend_from_slice(R.as_bytes());
88    keccak256_to_scalar(transcript)
89  }
90
91  /// Prove for this Inner-Product statement.
92  ///
93  /// Returns an error if this statement couldn't be proven for (such as if the witness isn't
94  /// consistent).
95  pub(crate) fn prove(
96    self,
97    mut transcript: Scalar,
98    witness: IpWitness,
99  ) -> Result<IpProof, IpError> {
100    let generators = &crate::original::GENERATORS;
101    let g_bold_slice = &generators.G[.. witness.a.len()];
102    let h_bold_slice = &generators.H[.. witness.a.len()];
103
104    let (mut g_bold, mut h_bold, u, mut a, mut b) = {
105      let IpStatement { h_bold_weights, u } = self;
106      let u = *H * u;
107
108      // Ensure we have the exact amount of weights
109      if h_bold_weights.len() != g_bold_slice.len() {
110        Err(IpError::IncorrectAmountOfGenerators)?;
111      }
112      // Acquire a local copy of the generators
113      let g_bold = PointVector(g_bold_slice.to_vec());
114      let h_bold = PointVector(h_bold_slice.to_vec()).mul_vec(&h_bold_weights);
115
116      let IpWitness { a, b } = witness;
117
118      (g_bold, h_bold, u, a, b)
119    };
120
121    let mut L_vec = vec![];
122    let mut R_vec = vec![];
123
124    // `else: (n > 1)` case, lines 18-35 of the Bulletproofs paper
125    // This interprets `g_bold.len()` as `n`
126    while g_bold.len() > 1 {
127      // Split a, b, g_bold, h_bold as needed for lines 20-24
128      let (a1, a2) = a.clone().split();
129      let (b1, b2) = b.clone().split();
130
131      let (g_bold1, g_bold2) = g_bold.split();
132      let (h_bold1, h_bold2) = h_bold.split();
133
134      let n_hat = g_bold1.len();
135
136      // Sanity
137      debug_assert_eq!(a1.len(), n_hat);
138      debug_assert_eq!(a2.len(), n_hat);
139      debug_assert_eq!(b1.len(), n_hat);
140      debug_assert_eq!(b2.len(), n_hat);
141      debug_assert_eq!(g_bold1.len(), n_hat);
142      debug_assert_eq!(g_bold2.len(), n_hat);
143      debug_assert_eq!(h_bold1.len(), n_hat);
144      debug_assert_eq!(h_bold2.len(), n_hat);
145
146      // cl, cr, lines 21-22
147      let cl = a1.clone().inner_product(&b2);
148      let cr = a2.clone().inner_product(&b1);
149
150      let L = {
151        let mut L_terms = Vec::with_capacity(1 + (2 * g_bold1.len()));
152        for (a, g) in a1.0.iter().zip(g_bold2.0.iter()) {
153          L_terms.push((*a, *g));
154        }
155        for (b, h) in b2.0.iter().zip(h_bold1.0.iter()) {
156          L_terms.push((*b, *h));
157        }
158        L_terms.push((cl, u));
159        // Uses vartime since this isn't a ZK proof
160        multiexp_vartime(&L_terms)
161      };
162      L_vec.push(CompressedPoint::from((L * INV_EIGHT()).compress()));
163
164      let R = {
165        let mut R_terms = Vec::with_capacity(1 + (2 * g_bold1.len()));
166        for (a, g) in a2.0.iter().zip(g_bold1.0.iter()) {
167          R_terms.push((*a, *g));
168        }
169        for (b, h) in b1.0.iter().zip(h_bold2.0.iter()) {
170          R_terms.push((*b, *h));
171        }
172        R_terms.push((cr, u));
173        multiexp_vartime(&R_terms)
174      };
175      R_vec.push(CompressedPoint::from((R * INV_EIGHT()).compress()));
176
177      // Now that we've calculate L, R, transcript them to receive x (26-27)
178      transcript = Self::transcript_L_R(
179        transcript,
180        *L_vec.last().expect("couldn't get last L_vec despite always being non-empty"),
181        *R_vec.last().expect("couldn't get last R_vec despite always being non-empty"),
182      );
183      let x = transcript;
184      let x_inv = x.invert();
185
186      // The prover and verifier now calculate the following (28-31)
187      g_bold = PointVector(Vec::with_capacity(g_bold1.len()));
188      for (a, b) in g_bold1.0.into_iter().zip(g_bold2.0.into_iter()) {
189        g_bold.0.push(multiexp_vartime(&[(x_inv, a), (x, b)]));
190      }
191      h_bold = PointVector(Vec::with_capacity(h_bold1.len()));
192      for (a, b) in h_bold1.0.into_iter().zip(h_bold2.0.into_iter()) {
193        h_bold.0.push(multiexp_vartime(&[(x, a), (x_inv, b)]));
194      }
195
196      // 32-34
197      a = (a1 * x) + &(a2 * x_inv);
198      b = (b1 * x_inv) + &(b2 * x);
199    }
200
201    // `if n = 1` case from line 14-17
202
203    // Sanity
204    debug_assert_eq!(g_bold.len(), 1);
205    debug_assert_eq!(h_bold.len(), 1);
206    debug_assert_eq!(a.len(), 1);
207    debug_assert_eq!(b.len(), 1);
208
209    // We simply send a/b
210    Ok(IpProof { L: L_vec, R: R_vec, a: a[0], b: b[0] })
211  }
212
213  /// Queue an Inner-Product proof for batch verification.
214  ///
215  /// This will return Err if there is an error. This will return Ok if the proof was successfully
216  /// queued for batch verification. The caller is required to verify the batch in order to ensure
217  /// the proof is actually correct.
218  pub(crate) fn verify(
219    self,
220    verifier: &mut BulletproofsBatchVerifier,
221    ip_rows: usize,
222    mut transcript: Scalar,
223    verifier_weight: Scalar,
224    proof: IpProof,
225  ) -> Result<(), IpError> {
226    let generators = &crate::original::GENERATORS;
227    let g_bold_slice = &generators.G[.. ip_rows];
228    let h_bold_slice = &generators.H[.. ip_rows];
229
230    let IpStatement { h_bold_weights, u } = self;
231
232    // Verify the L/R lengths
233    {
234      // Calculate the discrete log w.r.t. 2 for the amount of generators present
235      let mut lr_len = 0;
236      while (1 << lr_len) < g_bold_slice.len() {
237        lr_len += 1;
238      }
239
240      // This proof has less/more terms than the passed in generators are for
241      if proof.L.len() != lr_len {
242        Err(IpError::IncorrectAmountOfGenerators)?;
243      }
244      if proof.L.len() != proof.R.len() {
245        Err(IpError::DifferingLrLengths)?;
246      }
247    }
248
249    // Again, we start with the `else: (n > 1)` case
250
251    // We need x, x_inv per lines 25-27 for lines 28-31
252    let mut xs = Vec::with_capacity(proof.L.len());
253    for (L, R) in proof.L.iter().zip(proof.R.iter()) {
254      transcript = Self::transcript_L_R(transcript, *L, *R);
255      xs.push(transcript);
256    }
257
258    // We calculate their inverse in batch
259    let mut x_invs = xs.clone();
260    Scalar::batch_invert(&mut x_invs);
261
262    // Now, with x and x_inv, we need to calculate g_bold', h_bold', P'
263    //
264    // For the sake of performance, we solely want to calculate all of these in terms of scalings
265    // for g_bold, h_bold, P, and don't want to actually perform intermediary scalings of the
266    // points
267    //
268    // L and R are easy, as it's simply x**2, x**-2
269    //
270    // For the series of g_bold, h_bold, we use the `challenge_products` function
271    // For how that works, please see its own documentation
272    let product_cache = {
273      let mut challenges = Vec::with_capacity(proof.L.len());
274
275      let x_iter = xs.into_iter().zip(x_invs);
276      let lr_iter = proof.L.into_iter().zip(proof.R);
277      for ((x, x_inv), (L, R)) in x_iter.zip(lr_iter) {
278        challenges.push((x, x_inv));
279
280        let L =
281          L.decompress().map(|p| EdwardsPoint::mul_by_cofactor(&p)).ok_or(IpError::InvalidPoint)?;
282        let R =
283          R.decompress().map(|p| EdwardsPoint::mul_by_cofactor(&p)).ok_or(IpError::InvalidPoint)?;
284
285        verifier.0.other.push((verifier_weight * (x * x), L));
286        verifier.0.other.push((verifier_weight * (x_inv * x_inv), R));
287      }
288
289      challenge_products(&challenges)
290    };
291
292    // And now for the `if n = 1` case
293    let c = proof.a * proof.b;
294
295    // The multiexp of these terms equate to the final permutation of P
296    // We now add terms for a * g_bold' + b * h_bold' b + c * u, with the scalars negative such
297    // that the terms sum to 0 for an honest prover
298
299    // The g_bold * a term case from line 16
300    #[allow(clippy::needless_range_loop)]
301    for i in 0 .. g_bold_slice.len() {
302      verifier.0.g_bold[i] -= verifier_weight * product_cache[i] * proof.a;
303    }
304    // The h_bold * b term case from line 16
305    for i in 0 .. h_bold_slice.len() {
306      verifier.0.h_bold[i] -=
307        verifier_weight * product_cache[product_cache.len() - 1 - i] * proof.b * h_bold_weights[i];
308    }
309    // The c * u term case from line 16
310    verifier.0.h -= verifier_weight * c * u;
311
312    Ok(())
313  }
314}