1use std_shims::{vec, vec::Vec};
2
3use rand_core::{RngCore, CryptoRng};
4use zeroize::{Zeroize, ZeroizeOnDrop};
5
6use curve25519_dalek::{scalar::Scalar, edwards::EdwardsPoint};
7
8use monero_io::CompressedPoint;
9use monero_primitives::{INV_EIGHT, keccak256_to_scalar};
10use crate::{
11 core::{multiexp, multiexp_vartime, challenge_products},
12 batch_verifier::BulletproofsPlusBatchVerifier,
13 plus::{ScalarVector, PointVector, GeneratorsList, BpPlusGenerators, padded_pow_of_2},
14};
15
/// A weighted inner-product statement: the generators, the commitment `P` being opened, and the
/// powers of `y` that weight the inner product.
#[derive(Clone, Debug)]
pub(crate) struct WipStatement {
  generators: BpPlusGenerators,
  // Expected to equal <a, g_bold> + <b, h_bold> + (a weighted-by-y inner product b) * g + alpha * h
  // for the witness (a, b, alpha); debug builds of `prove` check this.
  P: EdwardsPoint,
  // [y, y^2, ..., y^n], as expanded by `WipStatement::new`
  y: ScalarVector,
}
23
24impl Zeroize for WipStatement {
25 fn zeroize(&mut self) {
26 self.P.zeroize();
27 self.y.zeroize();
28 }
29}
30
/// The prover's secret witness.
///
/// It holds the vectors `a` and `b`, plus `alpha`, the coefficient of `h` in the statement's
/// commitment `P`.
///
/// Zeroized when dropped.
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
pub(crate) struct WipWitness {
  a: ScalarVector,
  b: ScalarVector,
  alpha: Scalar,
}
37
38impl WipWitness {
39 pub(crate) fn new(mut a: ScalarVector, mut b: ScalarVector, alpha: Scalar) -> Option<Self> {
40 if a.0.is_empty() || (a.len() != b.len()) {
41 return None;
42 }
43
44 let missing = padded_pow_of_2(a.len()) - a.len();
46 a.0.reserve(missing);
47 b.0.reserve(missing);
48 for _ in 0 .. missing {
49 a.0.push(Scalar::ZERO);
50 b.0.push(Scalar::ZERO);
51 }
52
53 Some(Self { a, b, alpha })
54 }
55}
56
/// A weighted inner-product proof.
///
/// - `L` and `R` hold one point per folding round.
/// - `A`, `B` and the three `*_answer` scalars come from the final round, once the vectors are down
///   to a single element.
///
/// As produced by `WipStatement::prove`, every point is the compression of its committed value
/// multiplied by 1/8.
#[derive(Clone, PartialEq, Eq, Debug, Zeroize)]
pub(crate) struct WipProof {
  pub(crate) L: Vec<CompressedPoint>,
  pub(crate) R: Vec<CompressedPoint>,
  pub(crate) A: CompressedPoint,
  pub(crate) B: CompressedPoint,
  pub(crate) r_answer: Scalar,
  pub(crate) s_answer: Scalar,
  pub(crate) delta_answer: Scalar,
}
66}
67
impl WipStatement {
  /// Build a statement over `generators` with commitment `P` and challenge `y`.
  ///
  /// `y` is expanded into `[y, y^2, ..., y^n]`, where `n = generators.len()`.
  ///
  /// `generators.len()` is expected to equal `padded_pow_of_2(generators.len())`; this is checked
  /// only in debug builds. `generators` is assumed to be non-empty, since `y_vec[0]` is written
  /// unconditionally.
  pub(crate) fn new(generators: BpPlusGenerators, P: EdwardsPoint, y: Scalar) -> Self {
    debug_assert_eq!(generators.len(), padded_pow_of_2(generators.len()));

    // y_vec[i] = y^(i + 1)
    let mut y_vec = ScalarVector::new(generators.len());
    y_vec[0] = y;
    for i in 1 .. y_vec.len() {
      y_vec[i] = y_vec[i - 1] * y;
    }

    Self { generators, P, y: y_vec }
  }

  /// Absorb one round's `L`/`R` pair into the transcript.
  ///
  /// The transcript becomes `keccak256_to_scalar(transcript || L || R)`, hashed over the
  /// compressed encodings. That value is also returned as the round challenge.
  fn transcript_L_R(transcript: &mut Scalar, L: CompressedPoint, R: CompressedPoint) -> Scalar {
    let e = keccak256_to_scalar([transcript.to_bytes(), L.to_bytes(), R.to_bytes()].concat());
    *transcript = e;
    e
  }

  /// Absorb the final-round `A`/`B` pair into the transcript.
  ///
  /// The transcript becomes `keccak256_to_scalar(transcript || A || B)`, hashed over the
  /// compressed encodings. That value is also returned as the final challenge.
  fn transcript_A_B(transcript: &mut Scalar, A: CompressedPoint, B: CompressedPoint) -> Scalar {
    let e = keccak256_to_scalar([transcript.to_bytes(), A.to_bytes(), B.to_bytes()].concat());
    *transcript = e;
    e
  }

  /// Derive a round's challenge from `(L, R)` and fold the prover's generator vectors with it.
  ///
  /// Returns `(e, e^-1, e^2, e^-2, g_bold', h_bold')`, where:
  /// - `g_bold'[i] = e^-1 * g_bold1[i] + (e * y_inv_n_hat) * g_bold2[i]`
  /// - `h_bold'[i] = e * h_bold1[i] + e^-1 * h_bold2[i]`
  ///
  /// Variable-time multiexponentiation is used. Its scalars are transcript challenges and a power
  /// of the statement's `y`, and its points are generators; none of these is witness data.
  // Eight inputs are needed: the transcript, the four generator halves, L, R and y^-n_hat
  #[allow(clippy::too_many_arguments)]
  fn next_G_H(
    transcript: &mut Scalar,
    mut g_bold1: PointVector,
    mut g_bold2: PointVector,
    mut h_bold1: PointVector,
    mut h_bold2: PointVector,
    L: CompressedPoint,
    R: CompressedPoint,
    y_inv_n_hat: Scalar,
  ) -> (Scalar, Scalar, Scalar, Scalar, PointVector, PointVector) {
    debug_assert_eq!(g_bold1.len(), g_bold2.len());
    debug_assert_eq!(h_bold1.len(), h_bold2.len());
    debug_assert_eq!(g_bold1.len(), h_bold1.len());

    let e = Self::transcript_L_R(transcript, L, R);
    let inv_e = e.invert();

    let mut new_g_bold = Vec::with_capacity(g_bold1.len());
    let e_y_inv = e * y_inv_n_hat;
    for g_bold in g_bold1.0.drain(..).zip(g_bold2.0.drain(..)) {
      new_g_bold.push(multiexp_vartime(&[(inv_e, g_bold.0), (e_y_inv, g_bold.1)]));
    }

    let mut new_h_bold = Vec::with_capacity(h_bold1.len());
    for h_bold in h_bold1.0.drain(..).zip(h_bold2.0.drain(..)) {
      new_h_bold.push(multiexp_vartime(&[(e, h_bold.0), (inv_e, h_bold.1)]));
    }

    let e_square = e * e;
    let inv_e_square = inv_e * inv_e;

    (e, inv_e, e_square, inv_e_square, PointVector(new_g_bold), PointVector(new_h_bold))
  }

  /// Prove knowledge of `witness` for this statement.
  ///
  /// `transcript` is the running transcript state. Every `L`/`R` pair, and finally `A`/`B`, is
  /// hashed into it to derive the challenges. `rng` supplies the blinding scalars.
  ///
  /// Returns `None` if the witness length differs from the number of generators.
  pub(crate) fn prove<R: RngCore + CryptoRng>(
    self,
    rng: &mut R,
    mut transcript: Scalar,
    witness: &WipWitness,
  ) -> Option<WipProof> {
    let WipStatement { generators, P, mut y } = self;
    // `P` is only read by the debug-build consistency check below
    #[cfg(not(debug_assertions))]
    let _ = P;

    if generators.len() != witness.a.len() {
      return None;
    }
    let (g, h) = (BpPlusGenerators::g(), BpPlusGenerators::h());
    let mut g_bold = vec![];
    let mut h_bold = vec![];
    for i in 0 .. generators.len() {
      g_bold.push(generators.generator(GeneratorsList::GBold, i));
      h_bold.push(generators.generator(GeneratorsList::HBold, i));
    }
    let mut g_bold = PointVector(g_bold);
    let mut h_bold = PointVector(h_bold);

    // Inverses of y^1, y^2, y^4, ... (y[i - 1] = y^i), one per folding round, computed with a
    // single batch inversion. Each round pops from the end, so the largest power is used first.
    let mut y_inv = {
      let mut i = 1;
      let mut to_invert = vec![];
      while i < g_bold.len() {
        to_invert.push(y[i - 1]);
        i *= 2;
      }
      Scalar::batch_invert(&mut to_invert);
      to_invert
    };

    // Debug builds only: check that
    // P == <a, g_bold> + <b, h_bold> + (a weighted-by-y inner product b) * g + alpha * h.
    #[cfg(debug_assertions)]
    {
      let mut P_terms = witness
        .a
        .0
        .iter()
        .copied()
        .zip(g_bold.0.iter().copied())
        .chain(witness.b.0.iter().copied().zip(h_bold.0.iter().copied()))
        .collect::<Vec<_>>();
      P_terms.push((witness.a.clone().weighted_inner_product(&witness.b, &y), g));
      P_terms.push((witness.alpha, h));
      debug_assert_eq!(multiexp(&P_terms), P);
      P_terms.zeroize();
    }

    // Working copies of the witness, folded each round
    let mut a = witness.a.clone();
    let mut b = witness.b.clone();
    let mut alpha = witness.alpha;

    // From here on, g_bold.len() is the current vector length
    debug_assert_eq!(g_bold.len(), a.len());

    let mut L_vec = vec![];
    let mut R_vec = vec![];

    // Each iteration halves every vector, until a single element remains
    while g_bold.len() > 1 {
      let (a1, a2) = a.clone().split();
      let (b1, b2) = b.clone().split();
      let (g_bold1, g_bold2) = g_bold.split();
      let (h_bold1, h_bold2) = h_bold.split();

      let n_hat = g_bold1.len();
      debug_assert_eq!(a1.len(), n_hat);
      debug_assert_eq!(a2.len(), n_hat);
      debug_assert_eq!(b1.len(), n_hat);
      debug_assert_eq!(b2.len(), n_hat);
      debug_assert_eq!(g_bold1.len(), n_hat);
      debug_assert_eq!(g_bold2.len(), n_hat);
      debug_assert_eq!(h_bold1.len(), n_hat);
      debug_assert_eq!(h_bold2.len(), n_hat);

      // y^n_hat. The halved vectors only need y^1 ..= y^n_hat, so the rest are dropped.
      let y_n_hat = y[n_hat - 1];
      y.0.truncate(n_hat);

      // Blinding scalars for L and R
      let d_l = Scalar::random(&mut *rng);
      let d_r = Scalar::random(&mut *rng);

      // Cross terms between the two halves
      let c_l = a1.clone().weighted_inner_product(&b2, &y);
      let c_r = (a2.clone() * y_n_hat).weighted_inner_product(&b1, &y);

      // y^-n_hat; y_inv holds exactly one entry per remaining round
      let y_inv_n_hat = y_inv
        .pop()
        .expect("couldn't pop y_inv despite y_inv being of same length as times iterated");

      // L = (y^-n_hat * a1) * g_bold2 + b2 * h_bold1 + c_l * g + d_l * h, scaled by 1/8.
      // The verifier multiplies decompressed points by the cofactor, undoing the 1/8.
      let mut L_terms = (a1.clone() * y_inv_n_hat)
        .0
        .drain(..)
        .zip(g_bold2.0.iter().copied())
        .chain(b2.0.iter().copied().zip(h_bold1.0.iter().copied()))
        .collect::<Vec<_>>();
      L_terms.push((c_l, g));
      L_terms.push((d_l, h));
      let L = CompressedPoint::from((multiexp(&L_terms) * INV_EIGHT()).compress());
      L_vec.push(L);
      // The terms contain witness-derived scalars
      L_terms.zeroize();

      // R = (y^n_hat * a2) * g_bold1 + b1 * h_bold2 + c_r * g + d_r * h, scaled by 1/8
      let mut R_terms = (a2.clone() * y_n_hat)
        .0
        .drain(..)
        .zip(g_bold1.0.iter().copied())
        .chain(b1.0.iter().copied().zip(h_bold2.0.iter().copied()))
        .collect::<Vec<_>>();
      R_terms.push((c_r, g));
      R_terms.push((d_r, h));
      let R = CompressedPoint::from((multiexp(&R_terms) * INV_EIGHT()).compress());
      R_vec.push(R);
      R_terms.zeroize();

      // Derive this round's challenge from (L, R) and fold the generators with it
      let (e, inv_e, e_square, inv_e_square);
      (e, inv_e, e_square, inv_e_square, g_bold, h_bold) =
        Self::next_G_H(&mut transcript, g_bold1, g_bold2, h_bold1, h_bold2, L, R, y_inv_n_hat);

      // Fold the witness with the same challenge
      a = (a1 * e) + &(a2 * (y_n_hat * inv_e));
      b = (b1 * inv_e) + &(b2 * e);
      alpha += (d_l * e_square) + (d_r * inv_e_square);

      debug_assert_eq!(g_bold.len(), a.len());
      debug_assert_eq!(g_bold.len(), h_bold.len());
      debug_assert_eq!(g_bold.len(), b.len());
    }

    debug_assert_eq!(g_bold.len(), 1);
    debug_assert_eq!(h_bold.len(), 1);

    debug_assert_eq!(a.len(), 1);
    debug_assert_eq!(b.len(), 1);

    // Final round (n = 1): blinding scalars for A and B
    let r = Scalar::random(&mut *rng);
    let s = Scalar::random(&mut *rng);
    let delta = Scalar::random(&mut *rng);
    let eta = Scalar::random(&mut *rng);

    let ry = r * y[0];

    // A = r * g_bold[0] + s * h_bold[0] + (r*y*b[0] + s*y*a[0]) * g + delta * h, scaled by 1/8
    let mut A_terms =
      vec![(r, g_bold[0]), (s, h_bold[0]), ((ry * b[0]) + (s * y[0] * a[0]), g), (delta, h)];
    let A = CompressedPoint::from((multiexp(&A_terms) * INV_EIGHT()).compress());
    A_terms.zeroize();

    // B = (r * y * s) * g + eta * h, scaled by 1/8
    let mut B_terms = vec![(ry * s, g), (eta, h)];
    let B = CompressedPoint::from((multiexp(&B_terms) * INV_EIGHT()).compress());
    B_terms.zeroize();

    // Final challenge, bound to A and B
    let e = Self::transcript_A_B(&mut transcript, A, B);

    let r_answer = r + (a[0] * e);
    let s_answer = s + (b[0] * e);
    let delta_answer = eta + (delta * e) + (alpha * (e * e));

    Some(WipProof { L: L_vec, R: R_vec, A, B, r_answer, s_answer, delta_answer })
  }

  /// Queue this proof's verification equation into the batch `verifier`.
  ///
  /// `transcript` must be the same transcript state the prover started from, since every challenge
  /// is recomputed from it. All terms are scaled by a fresh random weight drawn from `rng` before
  /// being added to `verifier`.
  ///
  /// Returns `false`, leaving `verifier` untouched, in any of these cases:
  /// - the number of `L`/`R` pairs isn't `log2(generators.len())`;
  /// - `generators.len()` isn't a power of two;
  /// - any proof point fails to decompress.
  ///
  /// Returning `true` only means the terms were queued. This function does not itself evaluate
  /// the equation.
  pub(crate) fn verify<R: RngCore + CryptoRng>(
    self,
    rng: &mut R,
    verifier: &mut BulletproofsPlusBatchVerifier,
    mut transcript: Scalar,
    WipProof { L, R, A, B, r_answer, s_answer, delta_answer }: WipProof,
  ) -> bool {
    let verifier_weight = Scalar::random(rng);

    let WipStatement { generators, P, y } = self;

    // Check the L/R lengths: n must equal 2^lr_len, with one L/R pair per halving
    {
      let mut lr_len = 0;
      while (1 << lr_len) < generators.len() {
        lr_len += 1;
      }
      if (L.len() != lr_len) || (R.len() != lr_len) || (generators.len() != (1 << lr_len)) {
        return false;
      }
    }

    // inv_y[i] = y^-(i + 1)
    let inv_y = {
      let inv_y = y[0].invert();
      let mut res = Vec::with_capacity(y.len());
      res.push(inv_y);
      while res.len() < y.len() {
        res.push(
          inv_y * res.last().expect("couldn't get last inv_y despite inv_y always being non-empty"),
        );
      }
      res
    };

    let mut e_is = Vec::with_capacity(L.len());
    let mut L_decomp = Vec::with_capacity(L.len());
    let mut R_decomp = Vec::with_capacity(R.len());

    // Decompress a point and multiply it by the cofactor, undoing the prover's 1/8 scaling
    let decomp_mul_cofactor =
      |p| CompressedPoint::decompress(&p).map(|p| EdwardsPoint::mul_by_cofactor(&p));

    // Recompute each round's challenge. As in the prover, the compressed encodings are
    // transcripted, and decompression happens afterwards.
    for (L_i, R_i) in L.into_iter().zip(R.into_iter()) {
      e_is.push(Self::transcript_L_R(&mut transcript, L_i, R_i));

      let (Some(L_i), Some(R_i)) = (decomp_mul_cofactor(L_i), decomp_mul_cofactor(R_i)) else {
        return false;
      };

      L_decomp.push(L_i);
      R_decomp.push(R_i);
    }

    let L = L_decomp;
    let R = R_decomp;

    let e = Self::transcript_A_B(&mut transcript, A, B);

    let (Some(A), Some(B)) = (decomp_mul_cofactor(A), decomp_mul_cofactor(B)) else {
      return false;
    };

    let neg_e_square = verifier_weight * -(e * e);

    verifier.0.other.push((neg_e_square, P));

    // For each round, add -e^2 * (e_i^2 * L_i + e_i^-2 * R_i). Also collect (e_i, e_i^-1), from
    // which `challenge_products` derives the per-generator scalars.
    let mut challenges = Vec::with_capacity(L.len());
    let product_cache = {
      let mut inv_e_is = e_is.clone();
      Scalar::batch_invert(&mut inv_e_is);

      debug_assert_eq!(e_is.len(), inv_e_is.len());
      debug_assert_eq!(e_is.len(), L.len());
      debug_assert_eq!(e_is.len(), R.len());
      for ((e_i, inv_e_i), (L, R)) in
        e_is.drain(..).zip(inv_e_is.drain(..)).zip(L.iter().zip(R.iter()))
      {
        debug_assert_eq!(e_i.invert(), inv_e_i);

        challenges.push((e_i, inv_e_i));

        let e_i_square = e_i * e_i;
        let inv_e_i_square = inv_e_i * inv_e_i;
        verifier.0.other.push((neg_e_square * e_i_square, *L));
        verifier.0.other.push((neg_e_square * inv_e_i_square, *R));
      }

      challenge_products(&challenges)
    };

    // Make sure the batch verifier has a slot for every generator
    while verifier.0.g_bold.len() < generators.len() {
      verifier.0.g_bold.push(Scalar::ZERO);
    }
    while verifier.0.h_bold.len() < generators.len() {
      verifier.0.h_bold.push(Scalar::ZERO);
    }

    // g_bold[i] += weight * r_answer * e * product_cache[i] * y^-i
    let re = r_answer * e;
    for i in 0 .. generators.len() {
      let mut scalar = product_cache[i] * re;
      if i > 0 {
        scalar *= inv_y[i - 1];
      }
      verifier.0.g_bold[i] += verifier_weight * scalar;
    }

    // h_bold[i] += weight * s_answer * e * product_cache[n - 1 - i]
    let se = s_answer * e;
    for i in 0 .. generators.len() {
      verifier.0.h_bold[i] += verifier_weight * (se * product_cache[product_cache.len() - 1 - i]);
    }

    // The remaining weighted terms are -e * A, (r_answer * y * s_answer) * g, delta_answer * h
    // and -B
    verifier.0.other.push((verifier_weight * -e, A));
    verifier.0.g += verifier_weight * (r_answer * y[0] * s_answer);
    verifier.0.h += verifier_weight * delta_answer;
    verifier.0.other.push((-verifier_weight, B));

    true
  }
}