1use std_shims::{vec, vec::Vec};
2
3use rand_core::{RngCore, CryptoRng};
4use zeroize::{Zeroize, ZeroizeOnDrop};
5
6use curve25519_dalek::{scalar::Scalar, edwards::EdwardsPoint};
7use curve25519_dalek::edwards::CompressedEdwardsY;
8use monero_io::decompress_point;
9use monero_primitives::{INV_EIGHT, keccak256_to_scalar};
10use crate::{
11 core::{multiexp, multiexp_vartime, challenge_products},
12 batch_verifier::BulletproofsPlusBatchVerifier,
13 plus::{ScalarVector, PointVector, GeneratorsList, BpPlusGenerators, padded_pow_of_2},
14};
15
16#[derive(Clone, Debug)]
18pub(crate) struct WipStatement {
19 generators: BpPlusGenerators,
20 P: EdwardsPoint,
21 y: ScalarVector,
22}
23
24impl Zeroize for WipStatement {
25 fn zeroize(&mut self) {
26 self.P.zeroize();
27 self.y.zeroize();
28 }
29}
30
31#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
32pub(crate) struct WipWitness {
33 a: ScalarVector,
34 b: ScalarVector,
35 alpha: Scalar,
36}
37
38impl WipWitness {
39 pub(crate) fn new(mut a: ScalarVector, mut b: ScalarVector, alpha: Scalar) -> Option<Self> {
40 if a.0.is_empty() || (a.len() != b.len()) {
41 return None;
42 }
43
44 let missing = padded_pow_of_2(a.len()) - a.len();
46 a.0.reserve(missing);
47 b.0.reserve(missing);
48 for _ in 0 .. missing {
49 a.0.push(Scalar::ZERO);
50 b.0.push(Scalar::ZERO);
51 }
52
53 Some(Self { a, b, alpha })
54 }
55}
56
57#[derive(Clone, PartialEq, Eq, Debug, Zeroize)]
58pub(crate) struct WipProof {
59 pub(crate) L: Vec<CompressedEdwardsY>,
60 pub(crate) R: Vec<CompressedEdwardsY>,
61 pub(crate) A: CompressedEdwardsY,
62 pub(crate) B: CompressedEdwardsY,
63 pub(crate) r_answer: Scalar,
64 pub(crate) s_answer: Scalar,
65 pub(crate) delta_answer: Scalar,
66}
67
impl WipStatement {
  /// Creates a statement over `generators` for the commitment `P` and the challenge `y`.
  ///
  /// Stores the powers `[y, y^2, ..., y^n]`, where `n = generators.len()`. The generator count
  /// is expected to be padded already; debug builds assert that it equals
  /// `padded_pow_of_2(generators.len())`.
  pub(crate) fn new(generators: BpPlusGenerators, P: EdwardsPoint, y: Scalar) -> Self {
    debug_assert_eq!(generators.len(), padded_pow_of_2(generators.len()));

    // y_vec[i] = y^(i + 1)
    let mut y_vec = ScalarVector::new(generators.len());
    y_vec[0] = y;
    for i in 1 .. y_vec.len() {
      y_vec[i] = y_vec[i - 1] * y;
    }

    Self { generators, P, y: y_vec }
  }

  /// Absorbs one round's `L` and `R` into the transcript.
  ///
  /// The new transcript state is `keccak256_to_scalar(transcript || L || R)`, computed over the
  /// compressed encodings. That same value is returned as the round challenge.
  fn transcript_L_R(
    transcript: &mut Scalar,
    L: CompressedEdwardsY,
    R: CompressedEdwardsY,
  ) -> Scalar {
    let e = keccak256_to_scalar(
      [transcript.as_bytes().as_ref(), L.as_bytes().as_ref(), R.as_bytes().as_ref()].concat(),
    );
    *transcript = e;
    e
  }

  /// Absorbs the final-round `A` and `B` into the transcript.
  ///
  /// The new transcript state is `keccak256_to_scalar(transcript || A || B)`, computed over the
  /// compressed encodings. That same value is returned as the final challenge.
  fn transcript_A_B(
    transcript: &mut Scalar,
    A: CompressedEdwardsY,
    B: CompressedEdwardsY,
  ) -> Scalar {
    let e = keccak256_to_scalar(
      [transcript.as_bytes().as_ref(), A.as_bytes().as_ref(), B.as_bytes().as_ref()].concat(),
    );
    *transcript = e;
    e
  }

  /// Prover-side folding step for one round.
  ///
  /// Derives the round challenge `e` from (`L`, `R`), then folds the generator halves:
  /// `G'[i] = e^-1 * G1[i] + (e * y_inv_n_hat) * G2[i]` and `H'[i] = e * H1[i] + e^-1 * H2[i]`.
  ///
  /// Returns `(e, e^-1, e^2, e^-2, G', H')`. The input vectors are drained. Only generators and
  /// transcript-derived challenges are involved, so the variable-time multiexp is used here.
  // Four generator halves plus L, R and the y^-n_hat weight exceed clippy's argument limit.
  #[allow(clippy::too_many_arguments)]
  fn next_G_H(
    transcript: &mut Scalar,
    mut g_bold1: PointVector,
    mut g_bold2: PointVector,
    mut h_bold1: PointVector,
    mut h_bold2: PointVector,
    L: CompressedEdwardsY,
    R: CompressedEdwardsY,
    y_inv_n_hat: Scalar,
  ) -> (Scalar, Scalar, Scalar, Scalar, PointVector, PointVector) {
    debug_assert_eq!(g_bold1.len(), g_bold2.len());
    debug_assert_eq!(h_bold1.len(), h_bold2.len());
    debug_assert_eq!(g_bold1.len(), h_bold1.len());

    let e = Self::transcript_L_R(transcript, L, R);
    let inv_e = e.invert();

    let mut new_g_bold = Vec::with_capacity(g_bold1.len());
    let e_y_inv = e * y_inv_n_hat;
    for g_bold in g_bold1.0.drain(..).zip(g_bold2.0.drain(..)) {
      new_g_bold.push(multiexp_vartime(&[(inv_e, g_bold.0), (e_y_inv, g_bold.1)]));
    }

    let mut new_h_bold = Vec::with_capacity(h_bold1.len());
    for h_bold in h_bold1.0.drain(..).zip(h_bold2.0.drain(..)) {
      new_h_bold.push(multiexp_vartime(&[(e, h_bold.0), (inv_e, h_bold.1)]));
    }

    let e_square = e * e;
    let inv_e_square = inv_e * inv_e;

    (e, inv_e, e_square, inv_e_square, PointVector(new_g_bold), PointVector(new_h_bold))
  }

  /// Produces a proof that `witness` opens this statement's commitment `P`.
  ///
  /// `transcript` is the running challenge state. It is advanced with every (`L`, `R`) pair and
  /// finally with (`A`, `B`). Returns `None` if the witness length differs from the generator
  /// count.
  ///
  /// Randomness is drawn from `rng` in a fixed order: `d_l`, `d_r` in each round, then `r`, `s`,
  /// `delta`, `eta`. All multiexps over witness-derived scalars use the non-`vartime` `multiexp`.
  pub(crate) fn prove<R: RngCore + CryptoRng>(
    self,
    rng: &mut R,
    mut transcript: Scalar,
    witness: &WipWitness,
  ) -> Option<WipProof> {
    let WipStatement { generators, P, mut y } = self;
    // `P` is only read by the debug-build consistency check below.
    #[cfg(not(debug_assertions))]
    let _ = P;

    if generators.len() != witness.a.len() {
      return None;
    }
    let (g, h) = (BpPlusGenerators::g(), BpPlusGenerators::h());
    let mut g_bold = vec![];
    let mut h_bold = vec![];
    for i in 0 .. generators.len() {
      g_bold.push(generators.generator(GeneratorsList::GBold, i));
      h_bold.push(generators.generator(GeneratorsList::HBold, i));
    }
    let mut g_bold = PointVector(g_bold);
    let mut h_bold = PointVector(h_bold);

    // Batch-inverted y^1, y^2, y^4, ... (every power-of-two exponent below n). Each folding
    // round pops the last entry, which is y^-n_hat for that round.
    let mut y_inv = {
      let mut i = 1;
      let mut to_invert = vec![];
      while i < g_bold.len() {
        to_invert.push(y[i - 1]);
        i *= 2;
      }
      Scalar::batch_invert(&mut to_invert);
      to_invert
    };

    // Debug-only sanity check:
    // P == <a, G_bold> + <b, H_bold> + a.weighted_inner_product(b, y) * g + alpha * h.
    #[cfg(debug_assertions)]
    {
      let mut P_terms = witness
        .a
        .0
        .iter()
        .copied()
        .zip(g_bold.0.iter().copied())
        .chain(witness.b.0.iter().copied().zip(h_bold.0.iter().copied()))
        .collect::<Vec<_>>();
      P_terms.push((witness.a.clone().weighted_inner_product(&witness.b, &y), g));
      P_terms.push((witness.alpha, h));
      debug_assert_eq!(multiexp(&P_terms), P);
      P_terms.zeroize();
    }

    // Working copies of the witness, folded in place round by round.
    let mut a = witness.a.clone();
    let mut b = witness.b.clone();
    let mut alpha = witness.alpha;

    debug_assert_eq!(g_bold.len(), a.len());

    let mut L_vec = vec![];
    let mut R_vec = vec![];

    // Each round splits every vector into two halves of length n_hat, publishes one (L, R)
    // pair, and folds everything down to length n_hat, until a single element remains.
    while g_bold.len() > 1 {
      let (a1, a2) = a.clone().split();
      let (b1, b2) = b.clone().split();
      let (g_bold1, g_bold2) = g_bold.split();
      let (h_bold1, h_bold2) = h_bold.split();

      let n_hat = g_bold1.len();
      debug_assert_eq!(a1.len(), n_hat);
      debug_assert_eq!(a2.len(), n_hat);
      debug_assert_eq!(b1.len(), n_hat);
      debug_assert_eq!(b2.len(), n_hat);
      debug_assert_eq!(g_bold1.len(), n_hat);
      debug_assert_eq!(g_bold2.len(), n_hat);
      debug_assert_eq!(h_bold1.len(), n_hat);
      debug_assert_eq!(h_bold2.len(), n_hat);

      // y^n_hat for this round. Then keep only y^1..=y^n_hat, which is all the next round uses.
      let y_n_hat = y[n_hat - 1];
      y.0.truncate(n_hat);

      // Fresh blinding scalars (coefficients of h) for this round's L and R.
      let d_l = Scalar::random(&mut *rng);
      let d_r = Scalar::random(&mut *rng);

      // Cross terms: c_l = wip(a1, b2; y) and c_r = wip(y^n_hat * a2, b1; y).
      let c_l = a1.clone().weighted_inner_product(&b2, &y);
      let c_r = (a2.clone() * y_n_hat).weighted_inner_product(&b1, &y);

      // y^-n_hat (see the construction of `y_inv` above).
      let y_inv_n_hat = y_inv.pop().unwrap();

      // L = (y^-n_hat * a1) * G2 + b2 * H1 + c_l * g + d_l * h, scaled by 1/8 before compression.
      // The verifier multiplies decompressed points by the cofactor. The term list holds
      // secret-derived scalars, so it is wiped after use.
      let mut L_terms = (a1.clone() * y_inv_n_hat)
        .0
        .drain(..)
        .zip(g_bold2.0.iter().copied())
        .chain(b2.0.iter().copied().zip(h_bold1.0.iter().copied()))
        .collect::<Vec<_>>();
      L_terms.push((c_l, g));
      L_terms.push((d_l, h));
      let L = (multiexp(&L_terms) * INV_EIGHT()).compress();
      L_vec.push(L);
      L_terms.zeroize();

      // R = (y^n_hat * a2) * G1 + b1 * H2 + c_r * g + d_r * h, scaled and wiped the same way.
      let mut R_terms = (a2.clone() * y_n_hat)
        .0
        .drain(..)
        .zip(g_bold1.0.iter().copied())
        .chain(b1.0.iter().copied().zip(h_bold2.0.iter().copied()))
        .collect::<Vec<_>>();
      R_terms.push((c_r, g));
      R_terms.push((d_r, h));
      let R = (multiexp(&R_terms) * INV_EIGHT()).compress();
      R_vec.push(R);
      R_terms.zeroize();

      // Absorb (L, R), obtain the round challenge e, and fold the generators.
      let (e, inv_e, e_square, inv_e_square);
      (e, inv_e, e_square, inv_e_square, g_bold, h_bold) =
        Self::next_G_H(&mut transcript, g_bold1, g_bold2, h_bold1, h_bold2, L, R, y_inv_n_hat);

      // Fold the witness with the same challenge:
      // a' = e * a1 + (y^n_hat / e) * a2, b' = b1 / e + e * b2, alpha' = alpha + e^2 d_l + e^-2 d_r.
      a = (a1 * e) + &(a2 * (y_n_hat * inv_e));
      b = (b1 * inv_e) + &(b2 * e);
      alpha += (d_l * e_square) + (d_r * inv_e_square);

      debug_assert_eq!(g_bold.len(), a.len());
      debug_assert_eq!(g_bold.len(), h_bold.len());
      debug_assert_eq!(g_bold.len(), b.len());
    }

    // Final round (n == 1). `y` has been truncated to [y], so y[0] is y itself.
    debug_assert_eq!(g_bold.len(), 1);
    debug_assert_eq!(h_bold.len(), 1);

    debug_assert_eq!(a.len(), 1);
    debug_assert_eq!(b.len(), 1);

    // Masks for the final-round commitments.
    let r = Scalar::random(&mut *rng);
    let s = Scalar::random(&mut *rng);
    let delta = Scalar::random(&mut *rng);
    let eta = Scalar::random(&mut *rng);

    let ry = r * y[0];

    // A = r * G_bold[0] + s * H_bold[0] + (r*y*b + s*y*a) * g + delta * h, scaled by 1/8.
    let mut A_terms =
      vec![(r, g_bold[0]), (s, h_bold[0]), ((ry * b[0]) + (s * y[0] * a[0]), g), (delta, h)];
    let A = (multiexp(&A_terms) * INV_EIGHT()).compress();
    A_terms.zeroize();

    // B = (r*y*s) * g + eta * h, scaled by 1/8.
    let mut B_terms = vec![(ry * s, g), (eta, h)];
    let B = (multiexp(&B_terms) * INV_EIGHT()).compress();
    B_terms.zeroize();

    let e = Self::transcript_A_B(&mut transcript, A, B);

    // Responses under the final challenge e.
    let r_answer = r + (a[0] * e);
    let s_answer = s + (b[0] * e);
    let delta_answer = eta + (delta * e) + (alpha * (e * e));

    Some(WipProof { L: L_vec, R: R_vec, A, B, r_answer, s_answer, delta_answer })
  }

  /// Queues this proof's verification equation into `verifier`.
  ///
  /// Returns `false` if the proof is malformed in either of these ways:
  /// - `L`/`R` do not hold exactly `log2(n)` points each (`n` must be a power of two);
  /// - any of `L`, `R`, `A`, `B` fails to decompress.
  ///
  /// Nothing is pushed into `verifier` until all of these checks have passed. Returning `true`
  /// only means the terms were queued. Presumably the proof is accepted or rejected when the
  /// batch verifier is evaluated elsewhere (not shown here).
  ///
  /// With `c = challenge_products(..)` and `w = verifier_weight`, the queued terms are
  /// `w * [ -e^2 * (P + sum e_i^2 L_i + sum e_i^-2 R_i) - e*A - B
  ///   + r'*e * sum c_i y^-i G_i + s'*e * sum c_{n-1-i} H_i + r'*y*s' * g + delta' * h ]`.
  /// This assumes the batch verifier checks that the weighted sum of everything queued is the
  /// identity — TODO confirm.
  pub(crate) fn verify<R: RngCore + CryptoRng>(
    self,
    rng: &mut R,
    verifier: &mut BulletproofsPlusBatchVerifier,
    mut transcript: Scalar,
    proof: WipProof,
  ) -> bool {
    // Random weight applied to every term this proof contributes to the batch.
    let verifier_weight = Scalar::random(rng);

    let WipStatement { generators, P, y } = self;

    // Reject proofs whose L/R counts don't match log2 of the generator count, and generator
    // counts that aren't a power of two.
    {
      let mut lr_len = 0;
      while (1 << lr_len) < generators.len() {
        lr_len += 1;
      }
      if (proof.L.len() != lr_len) ||
        (proof.R.len() != lr_len) ||
        (generators.len() != (1 << lr_len))
      {
        return false;
      }
    }

    // inv_y[i] = y^-(i + 1)
    let inv_y = {
      let inv_y = y[0].invert();
      let mut res = Vec::with_capacity(y.len());
      res.push(inv_y);
      while res.len() < y.len() {
        res.push(inv_y * res.last().unwrap());
      }
      res
    };

    let mut e_is = Vec::with_capacity(proof.L.len());
    let mut L = Vec::with_capacity(proof.L.len());
    let mut R = Vec::with_capacity(proof.R.len());

    // Replay the prover's transcript over the compressed encodings, then decompress. Multiplying
    // by the cofactor undoes the prover's 1/8 scaling.
    for (L_i, R_i) in proof.L.into_iter().zip(proof.R.into_iter()) {
      e_is.push(Self::transcript_L_R(&mut transcript, L_i, R_i));

      let Some(L_i) = decompress_point(L_i).map(|p| EdwardsPoint::mul_by_cofactor(&p)) else {
        return false;
      };
      let Some(R_i) = decompress_point(R_i).map(|p| EdwardsPoint::mul_by_cofactor(&p)) else {
        return false;
      };

      L.push(L_i);
      R.push(R_i);
    }

    let e = Self::transcript_A_B(&mut transcript, proof.A, proof.B);
    let Some(A) = decompress_point(proof.A).map(|p| EdwardsPoint::mul_by_cofactor(&p)) else {
      return false;
    };
    let Some(B) = decompress_point(proof.B).map(|p| EdwardsPoint::mul_by_cofactor(&p)) else {
      return false;
    };
    let neg_e_square = verifier_weight * -(e * e);

    // From here on the proof is well-formed; start queuing terms.
    verifier.0.other.push((neg_e_square, P));

    // Queue the L/R terms and compute the per-generator products of the round challenges.
    let mut challenges = Vec::with_capacity(L.len());
    let product_cache = {
      let mut inv_e_is = e_is.clone();
      Scalar::batch_invert(&mut inv_e_is);

      debug_assert_eq!(e_is.len(), inv_e_is.len());
      debug_assert_eq!(e_is.len(), L.len());
      debug_assert_eq!(e_is.len(), R.len());
      for ((e_i, inv_e_i), (L, R)) in
        e_is.drain(..).zip(inv_e_is.drain(..)).zip(L.iter().zip(R.iter()))
      {
        debug_assert_eq!(e_i.invert(), inv_e_i);

        challenges.push((e_i, inv_e_i));

        let e_i_square = e_i * e_i;
        let inv_e_i_square = inv_e_i * inv_e_i;
        verifier.0.other.push((neg_e_square * e_i_square, *L));
        verifier.0.other.push((neg_e_square * inv_e_i_square, *R));
      }

      challenge_products(&challenges)
    };

    // Make sure the shared g_bold/h_bold accumulators cover this proof's generator count.
    while verifier.0.g_bold.len() < generators.len() {
      verifier.0.g_bold.push(Scalar::ZERO);
    }
    while verifier.0.h_bold.len() < generators.len() {
      verifier.0.h_bold.push(Scalar::ZERO);
    }

    // G_bold[i] coefficient: r' * e * c_i * y^-i (no y factor for i == 0).
    let re = proof.r_answer * e;
    for i in 0 .. generators.len() {
      let mut scalar = product_cache[i] * re;
      if i > 0 {
        scalar *= inv_y[i - 1];
      }
      verifier.0.g_bold[i] += verifier_weight * scalar;
    }

    // H_bold[i] coefficient: s' * e * c_{n-1-i} (the challenge products in reverse order).
    let se = proof.s_answer * e;
    for i in 0 .. generators.len() {
      verifier.0.h_bold[i] += verifier_weight * (se * product_cache[product_cache.len() - 1 - i]);
    }

    // Final-round terms: -e*A + (r'*y*s')*g + delta'*h - B.
    verifier.0.other.push((verifier_weight * -e, A));
    verifier.0.g += verifier_weight * (proof.r_answer * y[0] * proof.s_answer);
    verifier.0.h += verifier_weight * proof.delta_answer;
    verifier.0.other.push((-verifier_weight, B));

    true
  }
}
419}