monero_bulletproofs/original/
inner_product.rs1use std_shims::{vec, vec::Vec};
2
3use zeroize::Zeroize;
4
5use curve25519_dalek::{Scalar, EdwardsPoint};
6use monero_generators::H;
7use monero_io::{CompressedPoint};
8use monero_primitives::{INV_EIGHT, keccak256_to_scalar};
9use crate::{
10 core::{multiexp_vartime, challenge_products},
11 scalar_vector::ScalarVector,
12 point_vector::PointVector,
13 BulletproofsBatchVerifier,
14};
15
/// Errors which can occur while proving or verifying an inner-product argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum IpError {
  /// The number of generators (or generator weights / `L` terms) doesn't match what is expected.
  IncorrectAmountOfGenerators,
  /// The proof's `L` and `R` vectors have different lengths.
  DifferingLrLengths,
  /// An `L` or `R` element failed to decompress to a curve point.
  InvalidPoint,
}
23
24#[derive(Clone, Debug)]
28pub(crate) struct IpStatement {
29 h_bold_weights: ScalarVector,
31 u: Scalar,
33}
34
35#[derive(Clone, Debug)]
37pub(crate) struct IpWitness {
38 a: ScalarVector,
40 b: ScalarVector,
42}
43
44impl IpWitness {
45 pub(crate) fn new(a: ScalarVector, b: ScalarVector) -> Option<Self> {
50 if a.0.is_empty() || (a.len() != b.len()) {
51 None?;
52 }
53
54 let mut power_of_2 = 1;
55 while power_of_2 < a.len() {
56 power_of_2 <<= 1;
57 }
58 if power_of_2 != a.len() {
59 None?;
60 }
61
62 Some(Self { a, b })
63 }
64}
65
66#[derive(Clone, PartialEq, Eq, Debug, Zeroize)]
68pub(crate) struct IpProof {
69 pub(crate) L: Vec<CompressedPoint>,
70 pub(crate) R: Vec<CompressedPoint>,
71 pub(crate) a: Scalar,
72 pub(crate) b: Scalar,
73}
74
75impl IpStatement {
76 pub(crate) fn new_without_P_transcript(h_bold_weights: ScalarVector, u: Scalar) -> Self {
80 Self { h_bold_weights, u }
81 }
82
83 fn transcript_L_R(transcript: Scalar, L: CompressedPoint, R: CompressedPoint) -> Scalar {
85 let mut transcript = transcript.to_bytes().to_vec();
86 transcript.extend_from_slice(L.as_bytes());
87 transcript.extend_from_slice(R.as_bytes());
88 keccak256_to_scalar(transcript)
89 }
90
91 pub(crate) fn prove(
96 self,
97 mut transcript: Scalar,
98 witness: IpWitness,
99 ) -> Result<IpProof, IpError> {
100 let generators = &crate::original::GENERATORS;
101 let g_bold_slice = &generators.G[.. witness.a.len()];
102 let h_bold_slice = &generators.H[.. witness.a.len()];
103
104 let (mut g_bold, mut h_bold, u, mut a, mut b) = {
105 let IpStatement { h_bold_weights, u } = self;
106 let u = *H * u;
107
108 if h_bold_weights.len() != g_bold_slice.len() {
110 Err(IpError::IncorrectAmountOfGenerators)?;
111 }
112 let g_bold = PointVector(g_bold_slice.to_vec());
114 let h_bold = PointVector(h_bold_slice.to_vec()).mul_vec(&h_bold_weights);
115
116 let IpWitness { a, b } = witness;
117
118 (g_bold, h_bold, u, a, b)
119 };
120
121 let mut L_vec = vec![];
122 let mut R_vec = vec![];
123
124 while g_bold.len() > 1 {
127 let (a1, a2) = a.clone().split();
129 let (b1, b2) = b.clone().split();
130
131 let (g_bold1, g_bold2) = g_bold.split();
132 let (h_bold1, h_bold2) = h_bold.split();
133
134 let n_hat = g_bold1.len();
135
136 debug_assert_eq!(a1.len(), n_hat);
138 debug_assert_eq!(a2.len(), n_hat);
139 debug_assert_eq!(b1.len(), n_hat);
140 debug_assert_eq!(b2.len(), n_hat);
141 debug_assert_eq!(g_bold1.len(), n_hat);
142 debug_assert_eq!(g_bold2.len(), n_hat);
143 debug_assert_eq!(h_bold1.len(), n_hat);
144 debug_assert_eq!(h_bold2.len(), n_hat);
145
146 let cl = a1.clone().inner_product(&b2);
148 let cr = a2.clone().inner_product(&b1);
149
150 let L = {
151 let mut L_terms = Vec::with_capacity(1 + (2 * g_bold1.len()));
152 for (a, g) in a1.0.iter().zip(g_bold2.0.iter()) {
153 L_terms.push((*a, *g));
154 }
155 for (b, h) in b2.0.iter().zip(h_bold1.0.iter()) {
156 L_terms.push((*b, *h));
157 }
158 L_terms.push((cl, u));
159 multiexp_vartime(&L_terms)
161 };
162 L_vec.push(CompressedPoint::from((L * INV_EIGHT()).compress()));
163
164 let R = {
165 let mut R_terms = Vec::with_capacity(1 + (2 * g_bold1.len()));
166 for (a, g) in a2.0.iter().zip(g_bold1.0.iter()) {
167 R_terms.push((*a, *g));
168 }
169 for (b, h) in b1.0.iter().zip(h_bold2.0.iter()) {
170 R_terms.push((*b, *h));
171 }
172 R_terms.push((cr, u));
173 multiexp_vartime(&R_terms)
174 };
175 R_vec.push(CompressedPoint::from((R * INV_EIGHT()).compress()));
176
177 transcript = Self::transcript_L_R(
179 transcript,
180 *L_vec.last().expect("couldn't get last L_vec despite always being non-empty"),
181 *R_vec.last().expect("couldn't get last R_vec despite always being non-empty"),
182 );
183 let x = transcript;
184 let x_inv = x.invert();
185
186 g_bold = PointVector(Vec::with_capacity(g_bold1.len()));
188 for (a, b) in g_bold1.0.into_iter().zip(g_bold2.0.into_iter()) {
189 g_bold.0.push(multiexp_vartime(&[(x_inv, a), (x, b)]));
190 }
191 h_bold = PointVector(Vec::with_capacity(h_bold1.len()));
192 for (a, b) in h_bold1.0.into_iter().zip(h_bold2.0.into_iter()) {
193 h_bold.0.push(multiexp_vartime(&[(x, a), (x_inv, b)]));
194 }
195
196 a = (a1 * x) + &(a2 * x_inv);
198 b = (b1 * x_inv) + &(b2 * x);
199 }
200
201 debug_assert_eq!(g_bold.len(), 1);
205 debug_assert_eq!(h_bold.len(), 1);
206 debug_assert_eq!(a.len(), 1);
207 debug_assert_eq!(b.len(), 1);
208
209 Ok(IpProof { L: L_vec, R: R_vec, a: a[0], b: b[0] })
211 }
212
213 pub(crate) fn verify(
219 self,
220 verifier: &mut BulletproofsBatchVerifier,
221 ip_rows: usize,
222 mut transcript: Scalar,
223 verifier_weight: Scalar,
224 proof: IpProof,
225 ) -> Result<(), IpError> {
226 let generators = &crate::original::GENERATORS;
227 let g_bold_slice = &generators.G[.. ip_rows];
228 let h_bold_slice = &generators.H[.. ip_rows];
229
230 let IpStatement { h_bold_weights, u } = self;
231
232 {
234 let mut lr_len = 0;
236 while (1 << lr_len) < g_bold_slice.len() {
237 lr_len += 1;
238 }
239
240 if proof.L.len() != lr_len {
242 Err(IpError::IncorrectAmountOfGenerators)?;
243 }
244 if proof.L.len() != proof.R.len() {
245 Err(IpError::DifferingLrLengths)?;
246 }
247 }
248
249 let mut xs = Vec::with_capacity(proof.L.len());
253 for (L, R) in proof.L.iter().zip(proof.R.iter()) {
254 transcript = Self::transcript_L_R(transcript, *L, *R);
255 xs.push(transcript);
256 }
257
258 let mut x_invs = xs.clone();
260 Scalar::batch_invert(&mut x_invs);
261
262 let product_cache = {
273 let mut challenges = Vec::with_capacity(proof.L.len());
274
275 let x_iter = xs.into_iter().zip(x_invs);
276 let lr_iter = proof.L.into_iter().zip(proof.R);
277 for ((x, x_inv), (L, R)) in x_iter.zip(lr_iter) {
278 challenges.push((x, x_inv));
279
280 let L =
281 L.decompress().map(|p| EdwardsPoint::mul_by_cofactor(&p)).ok_or(IpError::InvalidPoint)?;
282 let R =
283 R.decompress().map(|p| EdwardsPoint::mul_by_cofactor(&p)).ok_or(IpError::InvalidPoint)?;
284
285 verifier.0.other.push((verifier_weight * (x * x), L));
286 verifier.0.other.push((verifier_weight * (x_inv * x_inv), R));
287 }
288
289 challenge_products(&challenges)
290 };
291
292 let c = proof.a * proof.b;
294
295 #[allow(clippy::needless_range_loop)]
301 for i in 0 .. g_bold_slice.len() {
302 verifier.0.g_bold[i] -= verifier_weight * product_cache[i] * proof.a;
303 }
304 for i in 0 .. h_bold_slice.len() {
306 verifier.0.h_bold[i] -=
307 verifier_weight * product_cache[product_cache.len() - 1 - i] * proof.b * h_bold_weights[i];
308 }
309 verifier.0.h -= verifier_weight * c * u;
311
312 Ok(())
313 }
314}