1#![cfg_attr(not(feature = "full"), allow(dead_code))]
2
3use crate::runtime::context;
33
/// A budget value: either a count of remaining units or no limit at all.
///
/// `Some(n)` holds the number of units still available; `None` is the
/// unconstrained budget.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Budget(Option<u8>);
38
/// Outcome of an attempt to spend one unit of a budget.
pub(crate) struct BudgetDecrement {
    /// `true` when this attempt took the count from non-zero down to zero.
    hit_zero: bool,
    /// `true` when the attempt was allowed to go ahead.
    success: bool,
}
43
44impl Budget {
45 const fn initial() -> Budget {
56 Budget(Some(128))
57 }
58
59 pub(super) const fn unconstrained() -> Budget {
61 Budget(None)
62 }
63
64 fn has_remaining(self) -> bool {
65 self.0.map_or(true, |budget| budget > 0)
66 }
67}
68
69#[inline(always)]
72pub(crate) fn budget<R>(f: impl FnOnce() -> R) -> R {
73 with_budget(Budget::initial(), f)
74}
75
76#[inline(always)]
79pub(crate) fn with_unconstrained<R>(f: impl FnOnce() -> R) -> R {
80 with_budget(Budget::unconstrained(), f)
81}
82
83#[inline(always)]
84fn with_budget<R>(budget: Budget, f: impl FnOnce() -> R) -> R {
85 struct ResetGuard {
86 prev: Budget,
87 }
88
89 impl Drop for ResetGuard {
90 fn drop(&mut self) {
91 let _ = context::budget(|cell| {
92 cell.set(self.prev);
93 });
94 }
95 }
96
97 #[allow(unused_variables)]
98 let maybe_guard = context::budget(|cell| {
99 let prev = cell.get();
100 cell.set(budget);
101
102 ResetGuard { prev }
103 });
104
105 f()
108}
109
110#[inline(always)]
111pub(crate) fn has_budget_remaining() -> bool {
112 context::budget(|cell| cell.get().has_remaining()).unwrap_or(true)
115}
116
117cfg_rt_multi_thread! {
118 pub(crate) fn set(budget: Budget) {
120 let _ = context::budget(|cell| cell.set(budget));
121 }
122}
123
124cfg_rt! {
125 pub(crate) fn stop() -> Budget {
129 context::budget(|cell| {
130 let prev = cell.get();
131 cell.set(Budget::unconstrained());
132 prev
133 }).unwrap_or(Budget::unconstrained())
134 }
135}
136
137cfg_coop! {
138 use pin_project_lite::pin_project;
139 use std::cell::Cell;
140 use std::future::Future;
141 use std::pin::Pin;
142 use std::task::{ready, Context, Poll};
143
144 #[must_use]
145 pub(crate) struct RestoreOnPending(Cell<Budget>);
146
147 impl RestoreOnPending {
148 pub(crate) fn made_progress(&self) {
149 self.0.set(Budget::unconstrained());
150 }
151 }
152
153 impl Drop for RestoreOnPending {
154 fn drop(&mut self) {
155 let budget = self.0.get();
158 if !budget.is_unconstrained() {
159 let _ = context::budget(|cell| {
160 cell.set(budget);
161 });
162 }
163 }
164 }
165
166 #[inline]
179 pub(crate) fn poll_proceed(cx: &mut Context<'_>) -> Poll<RestoreOnPending> {
180 context::budget(|cell| {
181 let mut budget = cell.get();
182
183 let decrement = budget.decrement();
184
185 if decrement.success {
186 let restore = RestoreOnPending(Cell::new(cell.get()));
187 cell.set(budget);
188
189 if decrement.hit_zero {
191 inc_budget_forced_yield_count();
192 }
193
194 Poll::Ready(restore)
195 } else {
196 cx.waker().wake_by_ref();
197 Poll::Pending
198 }
199 }).unwrap_or(Poll::Ready(RestoreOnPending(Cell::new(Budget::unconstrained()))))
200 }
201
202 cfg_rt! {
203 cfg_unstable_metrics! {
204 #[inline(always)]
205 fn inc_budget_forced_yield_count() {
206 let _ = context::with_current(|handle| {
207 handle.scheduler_metrics().inc_budget_forced_yield_count();
208 });
209 }
210 }
211
212 cfg_not_unstable_metrics! {
213 #[inline(always)]
214 fn inc_budget_forced_yield_count() {}
215 }
216 }
217
218 cfg_not_rt! {
219 #[inline(always)]
220 fn inc_budget_forced_yield_count() {}
221 }
222
223 impl Budget {
224 fn decrement(&mut self) -> BudgetDecrement {
227 if let Some(num) = &mut self.0 {
228 if *num > 0 {
229 *num -= 1;
230
231 let hit_zero = *num == 0;
232
233 BudgetDecrement { success: true, hit_zero }
234 } else {
235 BudgetDecrement { success: false, hit_zero: false }
236 }
237 } else {
238 BudgetDecrement { success: true, hit_zero: false }
239 }
240 }
241
242 fn is_unconstrained(self) -> bool {
243 self.0.is_none()
244 }
245 }
246
247 pin_project! {
248 #[must_use = "futures do nothing unless polled"]
256 pub(crate) struct Coop<F: Future> {
257 #[pin]
258 pub(crate) fut: F,
259 }
260 }
261
262 impl<F: Future> Future for Coop<F> {
263 type Output = F::Output;
264
265 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
266 let coop = ready!(poll_proceed(cx));
267 let me = self.project();
268 if let Poll::Ready(ret) = me.fut.poll(cx) {
269 coop.made_progress();
270 Poll::Ready(ret)
271 } else {
272 Poll::Pending
273 }
274 }
275 }
276
277 #[inline]
281 pub(crate) fn cooperative<F: Future>(fut: F) -> Coop<F> {
282 Coop { fut }
283 }
284}
285
#[cfg(all(test, not(loom)))]
mod test {
    use super::*;

    #[cfg(all(target_family = "wasm", not(target_os = "wasi")))]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// Reads the current budget, falling back to the unconstrained budget
    /// when `context::budget` yields no value.
    fn get() -> Budget {
        let current = context::budget(|cell| cell.get());
        current.unwrap_or(Budget::unconstrained())
    }

    #[test]
    fn budgeting() {
        use std::future::poll_fn;
        use tokio_test::*;

        // One `poll_proceed` call that must come back ready.
        let proceed = || assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
        // Units left in the current (limited) budget.
        let left = || get().0.unwrap();
        // Size of a freshly installed budget.
        let full = Budget::initial().0.unwrap();

        // Outside of any `budget` scope everything is unconstrained.
        assert!(get().0.is_none());

        let coop = proceed();

        assert!(get().0.is_none());
        drop(coop);
        assert!(get().0.is_none());

        budget(|| {
            assert_eq!(get().0, Budget::initial().0);

            // Dropping the guard without progress gives the unit back.
            let coop = proceed();
            assert_eq!(left(), full - 1);
            drop(coop);
            assert_eq!(get().0, Budget::initial().0);

            // After `made_progress` the unit stays spent.
            let coop = proceed();
            assert_eq!(left(), full - 1);
            coop.made_progress();
            drop(coop);
            assert_eq!(left(), full - 1);

            let coop = proceed();
            assert_eq!(left(), full - 2);
            coop.made_progress();
            drop(coop);
            assert_eq!(left(), full - 2);

            // A nested scope starts from a fresh budget.
            budget(|| {
                assert_eq!(get().0, Budget::initial().0);

                let coop = proceed();
                assert_eq!(left(), full - 1);
                coop.made_progress();
                drop(coop);
                assert_eq!(left(), full - 1);
            });

            // Leaving the nested scope brings the outer budget back.
            assert_eq!(left(), full - 2);
        });

        assert!(get().0.is_none());

        budget(|| {
            let n = left();

            // Spend every unit.
            for _ in 0..n {
                let coop = proceed();
                coop.made_progress();
            }

            let mut task = task::spawn(poll_fn(|cx| {
                let coop = std::task::ready!(poll_proceed(cx));
                coop.made_progress();
                Poll::Ready(())
            }));

            // With nothing left, the next attempt must be pending.
            assert_pending!(task.poll());
        });
    }
}
364}