tokio/signal/registry.rs

1use crate::signal::os::{OsExtraData, OsStorage};
2use crate::sync::watch;
3use crate::util::once_cell::OnceCell;
4
5use std::ops;
6use std::sync::atomic::{AtomicBool, Ordering};
7
/// Identifier of an event, used to look up its `EventInfo` in a `Storage`.
pub(crate) type EventId = usize;
9
/// State for a specific event: whether a notification is pending delivery,
/// and the channel through which registered listeners are notified.
#[derive(Debug)]
pub(crate) struct EventInfo {
    // Set by `Registry::record_event`; atomically cleared by `Registry::broadcast`.
    pending: AtomicBool,
    // Sender half of a watch channel; listeners obtain receivers via `subscribe`.
    tx: watch::Sender<()>,
}
17
18impl Default for EventInfo {
19    fn default() -> Self {
20        let (tx, _rx) = watch::channel(());
21
22        Self {
23            pending: AtomicBool::new(false),
24            tx,
25        }
26    }
27}
28
/// An interface for retrieving the `EventInfo` for a particular `EventId`.
pub(crate) trait Storage {
    /// Gets the `EventInfo` for `id`, or `None` if no entry exists for it.
    fn event_info(&self, id: EventId) -> Option<&EventInfo>;

    /// Invokes `f` once for each defined `EventInfo` in this storage.
    fn for_each<'a, F>(&'a self, f: F)
    where
        F: FnMut(&'a EventInfo);
}
39
40impl Storage for Vec<EventInfo> {
41    fn event_info(&self, id: EventId) -> Option<&EventInfo> {
42        self.get(id)
43    }
44
45    fn for_each<'a, F>(&'a self, f: F)
46    where
47        F: FnMut(&'a EventInfo),
48    {
49        self.iter().for_each(f);
50    }
51}
52
/// An interface for initializing a type. Useful for situations where we cannot
/// inject a configured instance in the constructor of another type.
///
/// `globals_init` relies on this to build both `OsExtraData` and `OsStorage`.
pub(crate) trait Init {
    /// Creates a new, initialized instance of `Self`.
    fn init() -> Self;
}
58
/// Manages and distributes event notifications to any registered listeners.
///
/// Generic over the underlying storage to allow for domain specific
/// optimizations (e.g. `EventId`s may or may not be contiguous).
#[derive(Debug)]
pub(crate) struct Registry<S> {
    // Per-event state; see the `Storage` trait for how entries are looked up.
    storage: S,
}
67
68impl<S> Registry<S> {
69    fn new(storage: S) -> Self {
70        Self { storage }
71    }
72}
73
74impl<S: Storage> Registry<S> {
75    /// Registers a new listener for `event_id`.
76    fn register_listener(&self, event_id: EventId) -> watch::Receiver<()> {
77        self.storage
78            .event_info(event_id)
79            .unwrap_or_else(|| panic!("invalid event_id: {event_id}"))
80            .tx
81            .subscribe()
82    }
83
84    /// Marks `event_id` as having been delivered, without broadcasting it to
85    /// any listeners.
86    fn record_event(&self, event_id: EventId) {
87        if let Some(event_info) = self.storage.event_info(event_id) {
88            event_info.pending.store(true, Ordering::SeqCst);
89        }
90    }
91
92    /// Broadcasts all previously recorded events to their respective listeners.
93    ///
94    /// Returns `true` if an event was delivered to at least one listener.
95    fn broadcast(&self) -> bool {
96        let mut did_notify = false;
97        self.storage.for_each(|event_info| {
98            // Any signal of this kind arrived since we checked last?
99            if !event_info.pending.swap(false, Ordering::SeqCst) {
100                return;
101            }
102
103            // Ignore errors if there are no listeners
104            if event_info.tx.send(()).is_ok() {
105                did_notify = true;
106            }
107        });
108
109        did_notify
110    }
111}
112
/// Process-wide signal state, reachable through `globals()`.
///
/// Dereferences to the OS-specific `OsExtraData`; event bookkeeping is
/// forwarded to the inner `Registry`.
pub(crate) struct Globals {
    // Platform-specific data exposed through the `Deref` impl.
    extra: OsExtraData,
    // Tracks pending events and their listeners.
    registry: Registry<OsStorage>,
}
117
118impl ops::Deref for Globals {
119    type Target = OsExtraData;
120
121    fn deref(&self) -> &Self::Target {
122        &self.extra
123    }
124}
125
126impl Globals {
127    /// Registers a new listener for `event_id`.
128    pub(crate) fn register_listener(&self, event_id: EventId) -> watch::Receiver<()> {
129        self.registry.register_listener(event_id)
130    }
131
132    /// Marks `event_id` as having been delivered, without broadcasting it to
133    /// any listeners.
134    pub(crate) fn record_event(&self, event_id: EventId) {
135        self.registry.record_event(event_id);
136    }
137
138    /// Broadcasts all previously recorded events to their respective listeners.
139    ///
140    /// Returns `true` if an event was delivered to at least one listener.
141    pub(crate) fn broadcast(&self) -> bool {
142        self.registry.broadcast()
143    }
144
145    #[cfg(unix)]
146    pub(crate) fn storage(&self) -> &OsStorage {
147        &self.registry.storage
148    }
149}
150
151fn globals_init() -> Globals
152where
153    OsExtraData: 'static + Send + Sync + Init,
154    OsStorage: 'static + Send + Sync + Init,
155{
156    Globals {
157        extra: OsExtraData::init(),
158        registry: Registry::new(OsStorage::init()),
159    }
160}
161
/// Returns the shared, lazily-initialized `Globals` instance.
///
/// The value lives in a function-local `static` `OnceCell` and is built by
/// `globals_init`. Per the `OnceCell` contract this presumably happens only
/// once; confirm against `crate::util::once_cell`. The `where` clauses name
/// concrete types, so they act as compile-time checks that the OS-specific
/// types are `Send + Sync + Init`.
pub(crate) fn globals() -> &'static Globals
where
    OsExtraData: 'static + Send + Sync + Init,
    OsStorage: 'static + Send + Sync + Init,
{
    static GLOBALS: OnceCell<Globals> = OnceCell::new();

    GLOBALS.get(globals_init)
}
171
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use crate::runtime::{self, Runtime};
    use crate::sync::{oneshot, watch};

    use futures::future;

    /// End-to-end check that repeated `record_event` calls are coalesced into
    /// one notification per `broadcast`, and that listeners for events that
    /// are never recorded receive nothing.
    #[test]
    fn smoke() {
        let rt = rt();
        rt.block_on(async move {
            let registry = Registry::new(vec![
                EventInfo::default(),
                EventInfo::default(),
                EventInfo::default(),
            ]);

            let first = registry.register_listener(0);
            let second = registry.register_listener(1);
            let third = registry.register_listener(2);

            let (fire, wait) = oneshot::channel();

            crate::spawn(async {
                wait.await.expect("wait failed");

                // Record some events which should get coalesced
                registry.record_event(0);
                registry.record_event(0);
                registry.record_event(1);
                registry.record_event(1);
                registry.broadcast();

                // Yield so the previous broadcast can get received
                //
                // This yields many times since the block_on task is only polled every 61
                // ticks.
                for _ in 0..100 {
                    crate::task::yield_now().await;
                }

                // Send subsequent signal
                registry.record_event(0);
                registry.broadcast();

                // Dropping the registry drops every watch sender, which makes
                // each `collect` loop below terminate.
                drop(registry);
            });

            let _ = fire.send(());
            let all = future::join3(collect(first), collect(second), collect(third));

            let (first_results, second_results, third_results) = all.await;
            // Event 0 was broadcast twice, event 1 once, and event 2 never.
            assert_eq!(2, first_results.len());
            assert_eq!(1, second_results.len());
            assert_eq!(0, third_results.len());
        });
    }

    /// Only id 0 exists, so subscribing to id 1 must panic.
    #[test]
    #[should_panic = "invalid event_id: 1"]
    fn register_panics_on_invalid_input() {
        let registry = Registry::new(vec![EventInfo::default()]);

        registry.register_listener(1);
    }

    /// Recording an unknown id is a silent no-op rather than a panic.
    #[test]
    fn record_invalid_event_does_nothing() {
        let registry = Registry::new(vec![EventInfo::default()]);
        registry.record_event(1302);
    }

    /// `broadcast` reports `true` only when a pending event reaches a live
    /// listener.
    #[test]
    fn broadcast_returns_if_at_least_one_event_fired() {
        let registry = Registry::new(vec![EventInfo::default(), EventInfo::default()]);

        // No listeners yet: the send fails, so nothing counts as delivered.
        registry.record_event(0);
        assert!(!registry.broadcast());

        let first = registry.register_listener(0);
        let second = registry.register_listener(1);

        // Event 0 now has a listener, so delivery succeeds.
        registry.record_event(0);
        assert!(registry.broadcast());

        // With event 0's only listener gone, delivery fails again even though
        // event 1 still has a (never-notified) listener.
        drop(first);
        registry.record_event(0);
        assert!(!registry.broadcast());

        drop(second);
    }

    /// Single-threaded runtime with the time driver enabled.
    fn rt() -> Runtime {
        runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .unwrap()
    }

    /// Collects one entry per observed change until `changed()` returns an error.
    async fn collect(mut rx: watch::Receiver<()>) -> Vec<()> {
        let mut ret = vec![];

        while let Ok(v) = rx.changed().await {
            ret.push(v);
        }

        ret
    }
}