safelog/
flags.rs

1//! Code for turning safelogging on and off.
2//!
3//! By default, safelogging is on.  There are two ways to turn it off: Globally
4//! (with [`disable_safe_logging`]) and locally (with
5//! [`with_safe_logging_suppressed`]).
6
7use crate::{Error, Result};
8use fluid_let::fluid_let;
9use std::sync::atomic::{AtomicIsize, Ordering};
10
/// Global counter of the live [`Guard`]s that enforce or disable safe logging.
///
/// The value is positive while `enforce_safe_logging` guards exist (plus one per
/// guard), negative while `disable_safe_logging` guards exist (minus one per
/// guard), and 0 when no guard is alive.  The two kinds never coexist: creating
/// a guard of one kind fails while guards of the other kind are alive.
static LOGGING_STATE: AtomicIsize = AtomicIsize::new(0);
17
fluid_let!(
    /// A dynamic variable used to temporarily disable safe-logging.
    ///
    /// [`with_safe_logging_suppressed`] binds it to `true` for as long as its
    /// closure runs (it is backed by thread-local storage, so the binding is
    /// only visible on the calling thread).  [`unsafe_logging_enabled`] treats
    /// a bound value of `true` as "safe logging suppressed"; when the variable
    /// is unbound (or bound to `false`), safe logging is not suppressed.
    static SAFE_LOGGING_SUPPRESSED_IN_THREAD: bool
);
22
23/// Returns true if we are displaying sensitive values, false otherwise.
24#[doc(hidden)]
25pub fn unsafe_logging_enabled() -> bool {
26    LOGGING_STATE.load(Ordering::Relaxed) < 0
27        || SAFE_LOGGING_SUPPRESSED_IN_THREAD.get(|v| v == Some(&true))
28}
29
30/// Run a given function with the regular `safelog` functionality suppressed.
31///
32/// The provided function, and everything it calls, will display
33/// [`Sensitive`](crate::Sensitive) values as if they were not sensitive.
34///
35/// # Examples
36///
37/// ```
38/// use safelog::{Sensitive, with_safe_logging_suppressed};
39///
40/// let string = Sensitive::new("swordfish");
41///
42/// // Ordinarily, the string isn't displayed as normal
43/// assert_eq!(format!("The value is {}", string),
44///            "The value is [scrubbed]");
45///
46/// // But you can override that:
47/// assert_eq!(
48///     with_safe_logging_suppressed(|| format!("The value is {}", string)),
49///     "The value is swordfish"
50/// );
51/// ```
52pub fn with_safe_logging_suppressed<F, V>(func: F) -> V
53where
54    F: FnOnce() -> V,
55{
56    // This sets the value of the variable to Some(true) temporarily, for as
57    // long as `func` is being called.  It uses thread-local variables
58    // internally.
59    SAFE_LOGGING_SUPPRESSED_IN_THREAD.set(true, func)
60}
61
/// Enum to describe what kind of a [`Guard`] we've created.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum GuardKind {
    /// We are forcing safe-logging to be enabled, so that nobody
    /// can turn it off with `disable_safe_logging`.
    Safe,
    /// We are turning safe-logging off with `disable_safe_logging`.
    Unsafe,
}
71
72/// A guard object used to enforce safe logging, or turn it off.
73///
74/// For as long as this object exists, the chosen behavior will be enforced.
75//
76// TODO: Should there be different types for "keep safe logging on" and "turn
77// safe logging off"?  Having the same type makes it easier to write code that
78// does stuff like this:
79//
80//     let g = if cfg.safe {
81//         enforce_safe_logging()
82//     } else {
83//         disable_safe_logging()
84//     };
85#[derive(Debug)]
86#[must_use = "If you drop the guard immediately, it won't do anything."]
87pub struct Guard {
88    /// What kind of guard is this?
89    kind: GuardKind,
90}
91
92impl GuardKind {
93    /// Return an error if `val` (as a value of `LOGGING_STATE`) indicates that
94    /// intended kind of guard cannot be created.
95    fn check(&self, val: isize) -> Result<()> {
96        match self {
97            GuardKind::Safe => {
98                if val < 0 {
99                    return Err(Error::AlreadyUnsafe);
100                }
101            }
102            GuardKind::Unsafe => {
103                if val > 0 {
104                    return Err(Error::AlreadySafe);
105                }
106            }
107        }
108        Ok(())
109    }
110    /// Return the value by which `LOGGING_STATE` should change while a guard of
111    /// this type exists.
112    fn increment(&self) -> isize {
113        match self {
114            GuardKind::Safe => 1,
115            GuardKind::Unsafe => -1,
116        }
117    }
118}
119
120impl Guard {
121    /// Helper: Create a guard of a given kind.
122    fn new(kind: GuardKind) -> Result<Self> {
123        let inc = kind.increment();
124        loop {
125            // Find the current value of LOGGING_STATE and see if this guard can
126            // be created.
127            let old_val = LOGGING_STATE.load(Ordering::SeqCst);
128            // Exit if this guard can't be created.
129            kind.check(old_val)?;
130            // Otherwise, try changing LOGGING_STATE to the new value that it
131            // _should_ have when this guard exists.
132            let new_val = match old_val.checked_add(inc) {
133                Some(v) => v,
134                None => return Err(Error::Overflow),
135            };
136            if let Ok(v) =
137                LOGGING_STATE.compare_exchange(old_val, new_val, Ordering::SeqCst, Ordering::SeqCst)
138            {
139                // Great, we set the value to what it should be; we're done.
140                debug_assert_eq!(v, old_val);
141                return Ok(Self { kind });
142            }
143            // Otherwise, somebody else altered this value concurrently: try
144            // again.
145        }
146    }
147}
148
149impl Drop for Guard {
150    fn drop(&mut self) {
151        let inc = self.kind.increment();
152        LOGGING_STATE.fetch_sub(inc, Ordering::SeqCst);
153    }
154}
155
156/// Create a new [`Guard`] to prevent anyone else from disabling safe logging.
157///
158/// Until the resulting `Guard` is dropped, any attempts to call
159/// `disable_safe_logging` will give an error.  This guard does _not_ affect
160/// calls to [`with_safe_logging_suppressed`].
161///
162/// This call will return an error if safe logging is _already_ disabled.
163///
164/// Note that this function is called "enforce", not "enable", since safe
165/// logging is enabled by default.  Its purpose is to make sure that nothing
166/// _else_ has called disable_safe_logging().
167pub fn enforce_safe_logging() -> Result<Guard> {
168    Guard::new(GuardKind::Safe)
169}
170
171/// Create a new [`Guard`] to disable safe logging.
172///
173/// Until the resulting `Guard` is dropped, all [`Sensitive`](crate::Sensitive)
174/// values will be displayed as if they were not sensitive.
175///
176/// This call will return an error if safe logging has been enforced with
177/// [`enforce_safe_logging`].
178pub fn disable_safe_logging() -> Result<Guard> {
179    Guard::new(GuardKind::Unsafe)
180}
181
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    // These tests mutate process-wide state, so `serial_test` forces them to
    // run one at a time.
    use serial_test::serial;
    use std::thread::{self, yield_now};

    /// How many iterations each thread performs in the concurrency tests.
    const ROUNDS: usize = 10_000;

    #[test]
    #[serial]
    fn guards() {
        // Single-threaded tour of enforcing and disabling guards.
        assert!(!unsafe_logging_enabled());

        let enforcers = [
            enforce_safe_logging().unwrap(),
            enforce_safe_logging().unwrap(),
        ];
        assert!(!unsafe_logging_enabled());

        // Disabling must be refused while safe logging is enforced.
        assert!(matches!(disable_safe_logging(), Err(Error::AlreadySafe)));
        assert!(!unsafe_logging_enabled());

        drop(enforcers);

        let _disabler_a = disable_safe_logging().unwrap();
        assert!(unsafe_logging_enabled());

        // Enforcing must be refused while safe logging is disabled.
        assert!(matches!(enforce_safe_logging(), Err(Error::AlreadyUnsafe)));
        assert!(unsafe_logging_enabled());

        // Several disabling guards may coexist.
        let _disabler_b = disable_safe_logging().unwrap();
        assert!(unsafe_logging_enabled());
    }

    #[test]
    #[serial]
    fn suppress() {
        // Inside `with_safe_logging_suppressed`, unsafe logging must be on,
        // whatever the global state happens to be.
        let suppressed = || with_safe_logging_suppressed(unsafe_logging_enabled);

        {
            let _enforcer = enforce_safe_logging().unwrap();
            assert!(suppressed());
            assert!(!unsafe_logging_enabled());
        }

        assert!(!unsafe_logging_enabled());
        assert!(suppressed());
        assert!(!unsafe_logging_enabled());

        {
            let _disabler = disable_safe_logging().unwrap();
            assert!(unsafe_logging_enabled());
            assert!(suppressed());
        }
    }

    #[test]
    #[serial]
    fn interfere_1() {
        // One thread enforces safe logging while another disables it.  They
        // may block each other, but neither may observe an inconsistent state.
        let enforcer = thread::spawn(|| {
            for _ in 0..ROUNDS {
                if let Ok(_guard) = enforce_safe_logging() {
                    assert!(!unsafe_logging_enabled());
                    yield_now();
                    assert!(disable_safe_logging().is_err());
                }
                yield_now();
            }
        });

        let disabler = thread::spawn(|| {
            for _ in 0..ROUNDS {
                if let Ok(_guard) = disable_safe_logging() {
                    assert!(unsafe_logging_enabled());
                    yield_now();
                    assert!(enforce_safe_logging().is_err());
                }
                yield_now();
            }
        });

        enforcer.join().unwrap();
        disabler.join().unwrap();
    }

    /// Repeatedly acquire and release a `disable_safe_logging` guard, checking
    /// that unsafe logging is on while the guard is held.
    fn churn_disable_guard() {
        for _ in 0..ROUNDS {
            let guard = disable_safe_logging().unwrap();
            assert!(unsafe_logging_enabled());
            yield_now();
            drop(guard);
            yield_now();
        }
    }

    #[test]
    #[serial]
    fn interfere_2() {
        // Two threads that both disable safe logging must never get in each
        // other's way.
        let workers = [
            thread::spawn(churn_disable_guard),
            thread::spawn(churn_disable_guard),
        ];
        for worker in workers {
            worker.join().unwrap();
        }
    }

    #[test]
    #[serial]
    fn interfere_3() {
        // `with_safe_logging_suppressed` must only affect its own thread.
        let observer = thread::spawn(|| {
            for _ in 0..ROUNDS {
                assert!(!unsafe_logging_enabled());
                yield_now();
            }
        });

        let suppressor = thread::spawn(|| {
            for _ in 0..ROUNDS {
                assert!(!unsafe_logging_enabled());
                with_safe_logging_suppressed(|| {
                    assert!(unsafe_logging_enabled());
                    yield_now();
                });
            }
        });

        observer.join().unwrap();
        suppressor.join().unwrap();
    }
}