1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![cfg_attr(not(feature = "std"), no_std)]
5
6use core::fmt::Debug;
7use std_shims::{
8 vec,
9 vec::Vec,
10 io::{self, Read, Write},
11};
12
13use curve25519_dalek::{
14 scalar::Scalar,
15 edwards::{EdwardsPoint, CompressedEdwardsY},
16};
17
/// The most-significant bit of a VarInt byte, set when further bytes follow.
const VARINT_CONTINUATION_MASK: u8 = 1 << 7;
19
mod sealed {
  /// Integer types usable as a VarInt.
  ///
  /// The trait lives in a private module, so code outside this crate cannot name it
  /// or implement it for further types.
  pub trait VarInt: TryInto<u64> + TryFrom<u64> + Copy {
    /// The width of the implementing type, in bits.
    const BITS: usize;
  }

  // Every implementation has the same shape: the width is the type's size in bytes times eight.
  macro_rules! impl_varint {
    ($($int:ty),* $(,)?) => {
      $(
        impl VarInt for $int {
          const BITS: usize = core::mem::size_of::<$int>() * 8;
        }
      )*
    };
  }

  impl_varint!(u8, u32, u64, usize);
}
41
42pub fn varint_len<V: sealed::VarInt>(varint: V) -> usize {
46 let varint_u64: u64 = varint.try_into().map_err(|_| "varint exceeded u64").unwrap();
47 ((usize::try_from(u64::BITS - varint_u64.leading_zeros()).unwrap().saturating_sub(1)) / 7) + 1
48}
49
/// Write a single byte to `w`.
///
/// The byte is taken by reference so this function has the `Fn(&T, &mut W)` shape
/// accepted by `write_vec` and `write_raw_vec`.
///
/// # Errors
///
/// Returns any error produced by the writer.
pub fn write_byte<W: Write>(byte: &u8, w: &mut W) -> io::Result<()> {
  let buf = [*byte];
  w.write_all(&buf)
}
56
57pub fn write_varint<W: Write, U: sealed::VarInt>(varint: &U, w: &mut W) -> io::Result<()> {
61 let mut varint: u64 = (*varint).try_into().map_err(|_| "varint exceeded u64").unwrap();
62 while {
63 let mut b = u8::try_from(varint & u64::from(!VARINT_CONTINUATION_MASK)).unwrap();
64 varint >>= 7;
65 if varint != 0 {
66 b |= VARINT_CONTINUATION_MASK;
67 }
68 write_byte(&b, w)?;
69 varint != 0
70 } {}
71 Ok(())
72}
73
74pub fn write_scalar<W: Write>(scalar: &Scalar, w: &mut W) -> io::Result<()> {
76 w.write_all(&scalar.to_bytes())
77}
78
79pub fn write_point<W: Write>(point: &EdwardsPoint, w: &mut W) -> io::Result<()> {
81 w.write_all(&point.compress().to_bytes())
82}
83
84pub fn write_compressed_point<W: Write>(point: &CompressedEdwardsY, w: &mut W) -> io::Result<()> {
86 w.write_all(&point.0)
87}
88
/// Write every item of `values` to `w` using `f`, in order, without a length prefix.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by `f`.
pub fn write_raw_vec<T, W: Write, F: Fn(&T, &mut W) -> io::Result<()>>(
  f: F,
  values: &[T],
  w: &mut W,
) -> io::Result<()> {
  values.iter().try_for_each(|value| f(value, w))
}
100
101pub fn write_vec<T, W: Write, F: Fn(&T, &mut W) -> io::Result<()>>(
103 f: F,
104 values: &[T],
105 w: &mut W,
106) -> io::Result<()> {
107 write_varint(&values.len(), w)?;
108 write_raw_vec(f, values, w)
109}
110
/// Read exactly `N` bytes from `r` into a fixed-size array.
///
/// # Errors
///
/// Returns the error from `read_exact` if `N` bytes could not be read.
pub fn read_bytes<R: Read, const N: usize>(r: &mut R) -> io::Result<[u8; N]> {
  let mut buf = [0u8; N];
  r.read_exact(&mut buf).map(|()| buf)
}
117
118pub fn read_byte<R: Read>(r: &mut R) -> io::Result<u8> {
120 Ok(read_bytes::<_, 1>(r)?[0])
121}
122
123pub fn read_u16<R: Read>(r: &mut R) -> io::Result<u16> {
125 read_bytes(r).map(u16::from_le_bytes)
126}
127
128pub fn read_u32<R: Read>(r: &mut R) -> io::Result<u32> {
130 read_bytes(r).map(u32::from_le_bytes)
131}
132
133pub fn read_u64<R: Read>(r: &mut R) -> io::Result<u64> {
135 read_bytes(r).map(u64::from_le_bytes)
136}
137
138pub fn read_varint<R: Read, U: sealed::VarInt>(r: &mut R) -> io::Result<U> {
140 let mut bits = 0;
141 let mut res = 0;
142 while {
143 let b = read_byte(r)?;
144 if (bits != 0) && (b == 0) {
145 Err(io::Error::other("non-canonical varint"))?;
146 }
147 if ((bits + 7) >= U::BITS) && (b >= (1 << (U::BITS - bits))) {
148 Err(io::Error::other("varint overflow"))?;
149 }
150
151 res += u64::from(b & (!VARINT_CONTINUATION_MASK)) << bits;
152 bits += 7;
153 b & VARINT_CONTINUATION_MASK == VARINT_CONTINUATION_MASK
154 } {}
155 res.try_into().map_err(|_| io::Error::other("VarInt does not fit into integer type"))
156}
157
158pub fn read_scalar<R: Read>(r: &mut R) -> io::Result<Scalar> {
163 Option::from(Scalar::from_canonical_bytes(read_bytes(r)?))
164 .ok_or_else(|| io::Error::other("unreduced scalar"))
165}
166
167pub fn decompress_point(compressed: CompressedEdwardsY) -> Option<EdwardsPoint> {
177 compressed
178 .decompress()
179 .filter(|point| point.compress().to_bytes() == compressed.to_bytes())
181}
182
183pub fn read_point<R: Read>(r: &mut R) -> io::Result<EdwardsPoint> {
188 let bytes = read_bytes(r)?;
189 decompress_point(CompressedEdwardsY(bytes)).ok_or_else(|| io::Error::other("invalid point"))
190}
191
192pub fn read_torsion_free_point<R: Read>(r: &mut R) -> io::Result<EdwardsPoint> {
194 read_point(r)
195 .ok()
196 .filter(EdwardsPoint::is_torsion_free)
197 .ok_or_else(|| io::Error::other("invalid point"))
198}
199
200pub fn read_compressed_point<R: Read>(r: &mut R) -> io::Result<CompressedEdwardsY> {
202 Ok(CompressedEdwardsY(read_bytes(r)?))
203}
204
/// Read `len` items from `r` using `f`, without reading a length prefix.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by `f`.
pub fn read_raw_vec<R: Read, T, F: Fn(&mut R) -> io::Result<T>>(
  f: F,
  len: usize,
  r: &mut R,
) -> io::Result<Vec<T>> {
  // Collecting into a `Result` short-circuits on the first failed read
  (0 .. len).map(|_| f(r)).collect()
}
217
218pub fn read_array<R: Read, T: Debug, F: Fn(&mut R) -> io::Result<T>, const N: usize>(
220 f: F,
221 r: &mut R,
222) -> io::Result<[T; N]> {
223 read_raw_vec(f, N, r).map(|vec| vec.try_into().unwrap())
224}
225
226pub fn read_vec<R: Read, T, F: Fn(&mut R) -> io::Result<T>>(f: F, r: &mut R) -> io::Result<Vec<T>> {
228 read_raw_vec(f, read_varint(r)?, r)
229}