1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![cfg_attr(not(feature = "std"), no_std)]
5
6use core::fmt::Debug;
7#[allow(unused_imports)]
8use std_shims::prelude::*;
9use std_shims::{
10 vec,
11 vec::Vec,
12 io::{self, Read, Write},
13};
14
15use curve25519_dalek::{scalar::Scalar, edwards::EdwardsPoint};
16
17mod compressed_point;
18pub use compressed_point::CompressedPoint;
19
/// Bit set on every byte of an encoded VarInt except the final one.
///
/// The other seven (low) bits of each byte carry a group of the value, least-significant group
/// first.
const VARINT_CONTINUATION_MASK: u8 = 1 << 7;
21
mod sealed {
  /// Unsigned integer types which may be encoded as a VarInt.
  ///
  /// Kept in a private module so the set of implementors is fixed to the types below.
  pub trait VarInt: TryFrom<u64> + Copy {
    /// The width of this type, in bits.
    const BITS: usize;
    /// Widen this value to a `u64`.
    fn into_u64(self) -> u64;
  }

  impl VarInt for u8 {
    const BITS: usize = 8;
    fn into_u64(self) -> u64 {
      u64::from(self)
    }
  }

  impl VarInt for u32 {
    const BITS: usize = 32;
    fn into_u64(self) -> u64 {
      u64::from(self)
    }
  }

  impl VarInt for u64 {
    const BITS: usize = 64;
    fn into_u64(self) -> u64 {
      self
    }
  }

  // The `usize` impl below widens into a `u64`, which is only lossless if `usize` is at most 64
  // bits wide. Refuse to compile on any platform where that doesn't hold.
  const _: () = assert!(usize::BITS <= u64::BITS);

  impl VarInt for usize {
    const BITS: usize = 8 * core::mem::size_of::<usize>();
    fn into_u64(self) -> u64 {
      u64::try_from(self)
        .expect("compiling on platform with <64-bit usize yet value didn't fit in u64")
    }
  }
}
60
61pub fn varint_len<V: sealed::VarInt>(varint: V) -> usize {
65 let varint_u64 = varint.into_u64();
66 ((usize::try_from(u64::BITS - varint_u64.leading_zeros())
67 .expect("64 > usize::MAX")
68 .saturating_sub(1)) /
69 7) +
70 1
71}
72
/// Write a single byte.
///
/// Its `(&u8, &mut W)` signature matches the element writer accepted by [`write_raw_vec`] and
/// [`write_vec`].
pub fn write_byte<W: Write>(byte: &u8, w: &mut W) -> io::Result<()> {
  w.write_all(core::slice::from_ref(byte))
}
79
80pub fn write_varint<W: Write, U: sealed::VarInt>(varint: &U, w: &mut W) -> io::Result<()> {
84 let mut varint: u64 = varint.into_u64();
85 while {
86 let mut b = u8::try_from(varint & u64::from(!VARINT_CONTINUATION_MASK))
87 .expect("& eight_bit_mask left more than 8 bits set");
88 varint >>= 7;
89 if varint != 0 {
90 b |= VARINT_CONTINUATION_MASK;
91 }
92 write_byte(&b, w)?;
93 varint != 0
94 } {}
95 Ok(())
96}
97
98pub fn write_scalar<W: Write>(scalar: &Scalar, w: &mut W) -> io::Result<()> {
100 w.write_all(&scalar.to_bytes())
101}
102
103pub fn write_point<W: Write>(point: &EdwardsPoint, w: &mut W) -> io::Result<()> {
105 CompressedPoint(point.compress().to_bytes()).write(w)
106}
107
/// Write every element of `values` with `f`, back to back.
///
/// No length prefix is written, so the reader must learn the element count some other way. The
/// first error returned by `f` stops the write and is propagated.
pub fn write_raw_vec<T, W: Write, F: Fn(&T, &mut W) -> io::Result<()>>(
  f: F,
  values: &[T],
  w: &mut W,
) -> io::Result<()> {
  values.iter().try_for_each(|value| f(value, w))
}
119
120pub fn write_vec<T, W: Write, F: Fn(&T, &mut W) -> io::Result<()>>(
122 f: F,
123 values: &[T],
124 w: &mut W,
125) -> io::Result<()> {
126 write_varint(&values.len(), w)?;
127 write_raw_vec(f, values, w)
128}
129
/// Read exactly `N` bytes into a fixed-size array.
///
/// Returns the error from `Read::read_exact` if the reader runs out of data first.
pub fn read_bytes<R: Read, const N: usize>(r: &mut R) -> io::Result<[u8; N]> {
  let mut buf = [0u8; N];
  r.read_exact(&mut buf).map(|()| buf)
}
136
137pub fn read_byte<R: Read>(r: &mut R) -> io::Result<u8> {
139 Ok(read_bytes::<_, 1>(r)?[0])
140}
141
142pub fn read_u16<R: Read>(r: &mut R) -> io::Result<u16> {
144 read_bytes(r).map(u16::from_le_bytes)
145}
146
147pub fn read_u32<R: Read>(r: &mut R) -> io::Result<u32> {
149 read_bytes(r).map(u32::from_le_bytes)
150}
151
152pub fn read_u64<R: Read>(r: &mut R) -> io::Result<u64> {
154 read_bytes(r).map(u64::from_le_bytes)
155}
156
157pub fn read_varint<R: Read, U: sealed::VarInt>(r: &mut R) -> io::Result<U> {
159 let mut bits = 0;
160 let mut res = 0;
161 while {
162 let b = read_byte(r)?;
163 if (bits != 0) && (b == 0) {
164 Err(io::Error::other("non-canonical varint"))?;
165 }
166 if ((bits + 7) >= U::BITS) && (b >= (1 << (U::BITS - bits))) {
167 Err(io::Error::other("varint overflow"))?;
168 }
169
170 res += u64::from(b & (!VARINT_CONTINUATION_MASK)) << bits;
171 bits += 7;
172 b & VARINT_CONTINUATION_MASK == VARINT_CONTINUATION_MASK
173 } {}
174 res.try_into().map_err(|_| io::Error::other("VarInt does not fit into integer type"))
175}
176
177pub fn read_scalar<R: Read>(r: &mut R) -> io::Result<Scalar> {
182 Option::from(Scalar::from_canonical_bytes(read_bytes(r)?))
183 .ok_or_else(|| io::Error::other("unreduced scalar"))
184}
185
186pub fn read_point<R: Read>(r: &mut R) -> io::Result<EdwardsPoint> {
191 CompressedPoint::read(r)?.decompress().ok_or_else(|| io::Error::other("invalid point"))
192}
193
/// Read `len` elements, each decoded by `f`, with no length prefix.
///
/// Reading stops at the first error from `f`, which is returned.
pub fn read_raw_vec<R: Read, T, F: Fn(&mut R) -> io::Result<T>>(
  f: F,
  len: usize,
  r: &mut R,
) -> io::Result<Vec<T>> {
  // Start from an empty vector and grow per element, rather than reserving `len` up front, as
  // `len` may be a length declared by the input itself (as in `read_vec`).
  (0 .. len).try_fold(Vec::new(), |mut decoded, _| {
    decoded.push(f(r)?);
    Ok(decoded)
  })
}
206
207pub fn read_array<R: Read, T: Debug, F: Fn(&mut R) -> io::Result<T>, const N: usize>(
209 f: F,
210 r: &mut R,
211) -> io::Result<[T; N]> {
212 read_raw_vec(f, N, r).map(|vec| {
213 vec.try_into().expect(
214 "read vector of specific length yet couldn't transform to an array of the same length",
215 )
216 })
217}
218
219pub fn read_vec<R: Read, T, F: Fn(&mut R) -> io::Result<T>>(
225 f: F,
226 length_bound: Option<usize>,
227 r: &mut R,
228) -> io::Result<Vec<T>> {
229 let declared_length: usize = read_varint(r)?;
230 if let Some(length_bound) = length_bound {
231 if declared_length > length_bound {
232 Err(io::Error::other("vector exceeds bound on length"))?;
233 }
234 }
235 read_raw_vec(f, declared_length, r)
236}