1use std_shims::{vec, vec::Vec};
23use zeroize::Zeroize;
45use curve25519_dalek::{Scalar, EdwardsPoint};
6use curve25519_dalek::edwards::CompressedEdwardsY;
7use monero_generators::H;
8use monero_io::decompress_point;
9use monero_primitives::{INV_EIGHT, keccak256_to_scalar};
10use crate::{
11 core::{multiexp_vartime, challenge_products},
12 scalar_vector::ScalarVector,
13 point_vector::PointVector,
14 BulletproofsBatchVerifier,
15};
/// An error from proving/verifying Inner-Product statements.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum IpError {
  /// The amount of generators (or generator weights) didn't match what the statement requires.
  IncorrectAmountOfGenerators,
  /// The proof's `L` and `R` vectors had differing lengths.
  ///
  /// This is also currently returned when an `L`/`R` point fails to decompress.
  DifferingLrLengths,
}
2324/// The Bulletproofs Inner-Product statement.
25///
26/// This is for usage with Protocol 2 from the Bulletproofs paper.
27#[derive(Clone, Debug)]
28pub(crate) struct IpStatement {
29// Weights for h_bold
30h_bold_weights: ScalarVector,
31// u as the discrete logarithm of G
32u: Scalar,
33}
3435/// The witness for the Bulletproofs Inner-Product statement.
36#[derive(Clone, Debug)]
37pub(crate) struct IpWitness {
38// a
39a: ScalarVector,
40// b
41b: ScalarVector,
42}
4344impl IpWitness {
45/// Construct a new witness for an Inner-Product statement.
46 ///
47 /// This functions return None if the lengths of a, b are mismatched, not a power of two, or are
48 /// empty.
49pub(crate) fn new(a: ScalarVector, b: ScalarVector) -> Option<Self> {
50if a.0.is_empty() || (a.len() != b.len()) {
51None?;
52 }
5354let mut power_of_2 = 1;
55while power_of_2 < a.len() {
56 power_of_2 <<= 1;
57 }
58if power_of_2 != a.len() {
59None?;
60 }
6162Some(Self { a, b })
63 }
64}
6566/// A proof for the Bulletproofs Inner-Product statement.
67#[derive(Clone, PartialEq, Eq, Debug, Zeroize)]
68pub(crate) struct IpProof {
69pub(crate) L: Vec<CompressedEdwardsY>,
70pub(crate) R: Vec<CompressedEdwardsY>,
71pub(crate) a: Scalar,
72pub(crate) b: Scalar,
73}
7475impl IpStatement {
76/// Create a new Inner-Product statement which won't transcript P.
77 ///
78 /// This MUST only be called when P is deterministic to already transcripted elements.
79pub(crate) fn new_without_P_transcript(h_bold_weights: ScalarVector, u: Scalar) -> Self {
80Self { h_bold_weights, u }
81 }
8283// Transcript a round of the protocol
84fn transcript_L_R(transcript: Scalar, L: CompressedEdwardsY, R: CompressedEdwardsY) -> Scalar {
85let mut transcript = transcript.to_bytes().to_vec();
86 transcript.extend_from_slice(L.as_bytes());
87 transcript.extend_from_slice(R.as_bytes());
88 keccak256_to_scalar(transcript)
89 }
9091/// Prove for this Inner-Product statement.
92 ///
93 /// Returns an error if this statement couldn't be proven for (such as if the witness isn't
94 /// consistent).
95pub(crate) fn prove(
96self,
97mut transcript: Scalar,
98 witness: IpWitness,
99 ) -> Result<IpProof, IpError> {
100let generators = &crate::original::GENERATORS;
101let g_bold_slice = &generators.G[.. witness.a.len()];
102let h_bold_slice = &generators.H[.. witness.a.len()];
103104let (mut g_bold, mut h_bold, u, mut a, mut b) = {
105let IpStatement { h_bold_weights, u } = self;
106let u = *H * u;
107108// Ensure we have the exact amount of weights
109if h_bold_weights.len() != g_bold_slice.len() {
110Err(IpError::IncorrectAmountOfGenerators)?;
111 }
112// Acquire a local copy of the generators
113let g_bold = PointVector(g_bold_slice.to_vec());
114let h_bold = PointVector(h_bold_slice.to_vec()).mul_vec(&h_bold_weights);
115116let IpWitness { a, b } = witness;
117118 (g_bold, h_bold, u, a, b)
119 };
120121let mut L_vec = vec![];
122let mut R_vec = vec![];
123124// `else: (n > 1)` case, lines 18-35 of the Bulletproofs paper
125 // This interprets `g_bold.len()` as `n`
126while g_bold.len() > 1 {
127// Split a, b, g_bold, h_bold as needed for lines 20-24
128let (a1, a2) = a.clone().split();
129let (b1, b2) = b.clone().split();
130131let (g_bold1, g_bold2) = g_bold.split();
132let (h_bold1, h_bold2) = h_bold.split();
133134let n_hat = g_bold1.len();
135136// Sanity
137debug_assert_eq!(a1.len(), n_hat);
138debug_assert_eq!(a2.len(), n_hat);
139debug_assert_eq!(b1.len(), n_hat);
140debug_assert_eq!(b2.len(), n_hat);
141debug_assert_eq!(g_bold1.len(), n_hat);
142debug_assert_eq!(g_bold2.len(), n_hat);
143debug_assert_eq!(h_bold1.len(), n_hat);
144debug_assert_eq!(h_bold2.len(), n_hat);
145146// cl, cr, lines 21-22
147let cl = a1.clone().inner_product(&b2);
148let cr = a2.clone().inner_product(&b1);
149150let L = {
151let mut L_terms = Vec::with_capacity(1 + (2 * g_bold1.len()));
152for (a, g) in a1.0.iter().zip(g_bold2.0.iter()) {
153 L_terms.push((*a, *g));
154 }
155for (b, h) in b2.0.iter().zip(h_bold1.0.iter()) {
156 L_terms.push((*b, *h));
157 }
158 L_terms.push((cl, u));
159// Uses vartime since this isn't a ZK proof
160multiexp_vartime(&L_terms)
161 };
162 L_vec.push((L * INV_EIGHT()).compress());
163164let R = {
165let mut R_terms = Vec::with_capacity(1 + (2 * g_bold1.len()));
166for (a, g) in a2.0.iter().zip(g_bold1.0.iter()) {
167 R_terms.push((*a, *g));
168 }
169for (b, h) in b1.0.iter().zip(h_bold2.0.iter()) {
170 R_terms.push((*b, *h));
171 }
172 R_terms.push((cr, u));
173 multiexp_vartime(&R_terms)
174 };
175 R_vec.push((R * INV_EIGHT()).compress());
176177// Now that we've calculate L, R, transcript them to receive x (26-27)
178transcript = Self::transcript_L_R(transcript, *L_vec.last().unwrap(), *R_vec.last().unwrap());
179let x = transcript;
180let x_inv = x.invert();
181182// The prover and verifier now calculate the following (28-31)
183g_bold = PointVector(Vec::with_capacity(g_bold1.len()));
184for (a, b) in g_bold1.0.into_iter().zip(g_bold2.0.into_iter()) {
185 g_bold.0.push(multiexp_vartime(&[(x_inv, a), (x, b)]));
186 }
187 h_bold = PointVector(Vec::with_capacity(h_bold1.len()));
188for (a, b) in h_bold1.0.into_iter().zip(h_bold2.0.into_iter()) {
189 h_bold.0.push(multiexp_vartime(&[(x, a), (x_inv, b)]));
190 }
191192// 32-34
193a = (a1 * x) + &(a2 * x_inv);
194 b = (b1 * x_inv) + &(b2 * x);
195 }
196197// `if n = 1` case from line 14-17
198199 // Sanity
200debug_assert_eq!(g_bold.len(), 1);
201debug_assert_eq!(h_bold.len(), 1);
202debug_assert_eq!(a.len(), 1);
203debug_assert_eq!(b.len(), 1);
204205// We simply send a/b
206Ok(IpProof { L: L_vec, R: R_vec, a: a[0], b: b[0] })
207 }
208209/// Queue an Inner-Product proof for batch verification.
210 ///
211 /// This will return Err if there is an error. This will return Ok if the proof was successfully
212 /// queued for batch verification. The caller is required to verify the batch in order to ensure
213 /// the proof is actually correct.
214pub(crate) fn verify(
215self,
216 verifier: &mut BulletproofsBatchVerifier,
217 ip_rows: usize,
218mut transcript: Scalar,
219 verifier_weight: Scalar,
220 proof: IpProof,
221 ) -> Result<(), IpError> {
222let generators = &crate::original::GENERATORS;
223let g_bold_slice = &generators.G[.. ip_rows];
224let h_bold_slice = &generators.H[.. ip_rows];
225226let IpStatement { h_bold_weights, u } = self;
227228// Verify the L/R lengths
229{
230// Calculate the discrete log w.r.t. 2 for the amount of generators present
231let mut lr_len = 0;
232while (1 << lr_len) < g_bold_slice.len() {
233 lr_len += 1;
234 }
235236// This proof has less/more terms than the passed in generators are for
237if proof.L.len() != lr_len {
238Err(IpError::IncorrectAmountOfGenerators)?;
239 }
240if proof.L.len() != proof.R.len() {
241Err(IpError::DifferingLrLengths)?;
242 }
243 }
244245// Again, we start with the `else: (n > 1)` case
246247 // We need x, x_inv per lines 25-27 for lines 28-31
248let mut xs = Vec::with_capacity(proof.L.len());
249for (L, R) in proof.L.iter().zip(proof.R.iter()) {
250 transcript = Self::transcript_L_R(transcript, *L, *R);
251 xs.push(transcript);
252 }
253254// We calculate their inverse in batch
255let mut x_invs = xs.clone();
256 Scalar::batch_invert(&mut x_invs);
257258// Now, with x and x_inv, we need to calculate g_bold', h_bold', P'
259 //
260 // For the sake of performance, we solely want to calculate all of these in terms of scalings
261 // for g_bold, h_bold, P, and don't want to actually perform intermediary scalings of the
262 // points
263 //
264 // L and R are easy, as it's simply x**2, x**-2
265 //
266 // For the series of g_bold, h_bold, we use the `challenge_products` function
267 // For how that works, please see its own documentation
268let product_cache = {
269let mut challenges = Vec::with_capacity(proof.L.len());
270271let x_iter = xs.into_iter().zip(x_invs);
272let lr_iter = proof.L.into_iter().zip(proof.R);
273for ((x, x_inv), (L, R)) in x_iter.zip(lr_iter) {
274 challenges.push((x, x_inv));
275276// TODO: create proper error
277let L = decompress_point(L)
278 .map(|p| EdwardsPoint::mul_by_cofactor(&p))
279 .ok_or(IpError::DifferingLrLengths)?;
280let R = decompress_point(R)
281 .map(|p| EdwardsPoint::mul_by_cofactor(&p))
282 .ok_or(IpError::DifferingLrLengths)?;
283284 verifier.0.other.push((verifier_weight * (x * x), L));
285 verifier.0.other.push((verifier_weight * (x_inv * x_inv), R));
286 }
287288 challenge_products(&challenges)
289 };
290291// And now for the `if n = 1` case
292let c = proof.a * proof.b;
293294// The multiexp of these terms equate to the final permutation of P
295 // We now add terms for a * g_bold' + b * h_bold' b + c * u, with the scalars negative such
296 // that the terms sum to 0 for an honest prover
297298 // The g_bold * a term case from line 16
299#[allow(clippy::needless_range_loop)]
300for i in 0 .. g_bold_slice.len() {
301 verifier.0.g_bold[i] -= verifier_weight * product_cache[i] * proof.a;
302 }
303// The h_bold * b term case from line 16
304for i in 0 .. h_bold_slice.len() {
305 verifier.0.h_bold[i] -=
306 verifier_weight * product_cache[product_cache.len() - 1 - i] * proof.b * h_bold_weights[i];
307 }
308// The c * u term case from line 16
309verifier.0.h -= verifier_weight * c * u;
310311Ok(())
312 }
313}