tor_proto/stream/
incoming.rs1use bitvec::prelude::*;
4
5use super::{AnyCmdChecker, DataStream, StreamStatus};
6use crate::circuit::{ClientCircSyncView, StreamComponents};
7use crate::tunnel::reactor::CloseStreamBehavior;
8use crate::{Error, Result};
9use derive_deftly::Deftly;
10use oneshot_fused_workaround as oneshot;
11use tor_cell::relaycell::{msg, RelayCmd, UnparsedRelayMsg};
12use tor_cell::restricted_msg;
13use tor_error::internal;
14use tor_memquota::derive_deftly_template_HasMemoryCost;
15use tor_rtcompat::DynTimeProvider;
16
17#[derive(Debug)]
27pub struct IncomingStream {
28 time_provider: DynTimeProvider,
30 request: IncomingStreamRequest,
32 components: StreamComponents,
34}
35
36impl IncomingStream {
37 pub(crate) fn new(
39 time_provider: DynTimeProvider,
40 request: IncomingStreamRequest,
41 components: StreamComponents,
42 ) -> Self {
43 Self {
44 time_provider,
45 request,
46 components,
47 }
48 }
49
50 pub fn request(&self) -> &IncomingStreamRequest {
52 &self.request
53 }
54
55 pub async fn accept_data(self, message: msg::Connected) -> Result<DataStream> {
58 let Self {
59 time_provider,
60 request,
61 components:
62 StreamComponents {
63 mut target,
64 stream_receiver,
65 xon_xoff_reader_ctrl,
66 memquota,
67 },
68 } = self;
69
70 match request {
71 IncomingStreamRequest::Begin(_) | IncomingStreamRequest::BeginDir(_) => {
72 target.send(message.into()).await?;
73 Ok(DataStream::new_connected(
74 time_provider,
75 stream_receiver,
76 xon_xoff_reader_ctrl,
77 target,
78 memquota,
79 ))
80 }
81 IncomingStreamRequest::Resolve(_) => {
82 Err(internal!("Cannot accept data on a RESOLVE stream").into())
83 }
84 }
85 }
86
87 pub async fn reject(mut self, message: msg::End) -> Result<()> {
89 let rx = self.reject_inner(CloseStreamBehavior::SendEnd(message))?;
90
91 rx.await.map_err(|_| Error::CircuitClosed)?.map(|_| ())
92 }
93
94 fn reject_inner(
98 &mut self,
99 message: CloseStreamBehavior,
100 ) -> Result<oneshot::Receiver<Result<()>>> {
101 self.components.target.close_pending(message)
102 }
103
104 pub async fn discard(mut self) -> Result<()> {
110 let rx = self.reject_inner(CloseStreamBehavior::SendNothing)?;
111
112 rx.await.map_err(|_| Error::CircuitClosed)?.map(|_| ())
113 }
114}
115
116restricted_msg! {
122 #[derive(Clone, Debug, Deftly)]
124 #[derive_deftly(HasMemoryCost)]
125 #[non_exhaustive]
126 pub enum IncomingStreamRequest: RelayMsg {
127 Begin,
129 BeginDir,
131 Resolve,
133 }
134}
135
136type RelayCmdSet = bitvec::BitArr!(for 256);
141
142#[derive(Debug)]
145pub(crate) struct IncomingCmdChecker {
146 allow_commands: RelayCmdSet,
154}
155
156impl IncomingCmdChecker {
157 pub(crate) fn new_any(allow_commands: &[RelayCmd]) -> AnyCmdChecker {
159 let mut array = BitArray::ZERO;
160 for c in allow_commands {
161 array.set(u8::from(*c) as usize, true);
162 }
163 Box::new(Self {
164 allow_commands: array,
165 })
166 }
167}
168
169impl super::CmdChecker for IncomingCmdChecker {
170 fn check_msg(&mut self, msg: &UnparsedRelayMsg) -> Result<StreamStatus> {
171 if self.allow_commands[u8::from(msg.cmd()) as usize] {
172 Ok(StreamStatus::Open)
173 } else {
174 Err(Error::StreamProto(format!(
175 "Unexpected {} on incoming stream",
176 msg.cmd()
177 )))
178 }
179 }
180
181 fn consume_checked_msg(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
182 let _ = msg
183 .decode::<IncomingStreamRequest>()
184 .map_err(|err| Error::from_bytes_err(err, "invalid message on incoming stream"))?;
185
186 Ok(())
187 }
188}
189
190pub trait IncomingStreamRequestFilter: Send + 'static {
197 fn disposition(
201 &mut self,
202 ctx: &IncomingStreamRequestContext<'_>,
203 circ: &ClientCircSyncView<'_>,
204 ) -> Result<IncomingStreamRequestDisposition>;
205}
206
207#[derive(Clone, Debug)]
209#[non_exhaustive]
210pub enum IncomingStreamRequestDisposition {
211 Accept,
214 CloseCircuit,
216 RejectRequest(msg::End),
218}
219
220pub struct IncomingStreamRequestContext<'a> {
222 pub(crate) request: &'a IncomingStreamRequest,
224}
225
226impl<'a> IncomingStreamRequestContext<'a> {
227 pub fn request(&self) -> &'a IncomingStreamRequest {
229 self.request
230 }
231}
232
#[cfg(test)]
mod test {
    // Standard test-module lint allowances (kept identical to the original list).
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use tor_cell::relaycell::{
        msg::{Begin, BeginDir, Data, Resolve},
        AnyRelayMsgOuter, RelayCellFormat,
    };

    use super::*;

    #[test]
    fn incoming_cmd_checker() {
        // Round-trip each message through a V0 cell body. This gives the
        // unparsed form that a command checker actually receives.
        let unparsed = |m| {
            let encoded = AnyRelayMsgOuter::new(None, m)
                .encode(RelayCellFormat::V0, &mut rand::rng())
                .unwrap();
            UnparsedRelayMsg::from_singleton_body(RelayCellFormat::V0, encoded).unwrap()
        };
        let begin = unparsed(Begin::new("allium.example.com", 443, 0).unwrap().into());
        let begin_dir = unparsed(BeginDir::default().into());
        let resolve = unparsed(Resolve::new("allium.example.com").into());
        let data = unparsed(Data::new(&[1, 2, 3]).unwrap().into());

        // Nothing allowed: every message is refused.
        let mut checker = IncomingCmdChecker::new_any(&[]);
        for m in [&begin, &begin_dir, &resolve, &data] {
            assert!(checker.check_msg(m).is_err());
        }

        // Only BEGIN allowed: BEGIN passes, everything else is refused.
        let mut checker = IncomingCmdChecker::new_any(&[RelayCmd::BEGIN]);
        assert_eq!(checker.check_msg(&begin).unwrap(), StreamStatus::Open);
        for m in [&begin_dir, &resolve, &data] {
            assert!(checker.check_msg(m).is_err());
        }

        // All request kinds allowed: they pass, but DATA is still refused.
        let mut checker = IncomingCmdChecker::new_any(&[
            RelayCmd::BEGIN,
            RelayCmd::BEGIN_DIR,
            RelayCmd::RESOLVE,
        ]);
        for m in [&begin, &begin_dir, &resolve] {
            assert_eq!(checker.check_msg(m).unwrap(), StreamStatus::Open);
        }
        assert!(checker.check_msg(&data).is_err());
    }
}