// path: root/tokio/src/runtime/basic_scheduler.rs
// blob: f809db419ae2fb9429e3b8ac6b7aa5567619220f (plain)
use crate::park::{Park, Unpark};
use crate::task::{self, queue::MpscQueues, JoinHandle, Schedule, ScheduleSendOnly, Task};

use std::cell::Cell;
use std::fmt;
use std::future::Future;
use std::mem::ManuallyDrop;
use std::ptr;
use std::sync::Arc;
use std::task::{RawWaker, RawWakerVTable, Waker};
use std::time::Duration;

/// A scheduler that executes its tasks on the thread that drives it.
///
/// The value is split in two: a shareable half kept behind an `Arc` (also
/// referenced by every `Spawner`), and a half owned exclusively by this
/// value.
#[derive(Debug)]
pub(crate) struct BasicScheduler<P: Park> {
    /// Shared scheduler state: the task queues and the unpark handle.
    scheduler: Arc<SchedulerPriv>,

    /// State that is only touched through `&mut self`: the tick counter and
    /// the park handle.
    local: LocalState<P>,
}

/// Cloneable handle used to spawn tasks onto a `BasicScheduler`.
///
/// Holds a reference-counted pointer to the scheduler's shared state.
#[derive(Debug, Clone)]
pub(crate) struct Spawner {
    /// Shared scheduler state that spawned tasks are submitted to.
    scheduler: Arc<SchedulerPriv>,
}

/// The scheduler component.
///
/// This is the part of the scheduler that is shared, via `Arc`, between the
/// `BasicScheduler` and every `Spawner` created from it.
pub(super) struct SchedulerPriv {
    /// Task queues belonging to this scheduler.
    queues: MpscQueues<Self>,
    /// Unpark the blocked thread
    unpark: Box<dyn Unpark>,
}

// NOTE(review): `Send` and `Sync` are asserted by hand here rather than being
// derived from the fields. The invariants that make sharing `SchedulerPriv`
// across threads sound are not visible in this file — confirm against the
// definitions of `MpscQueues` and `Unpark`.
unsafe impl Send for SchedulerPriv {}
unsafe impl Sync for SchedulerPriv {}

/// Local state
///
/// Kept outside the shared `SchedulerPriv` and handed to
/// `SchedulerPriv::tick` by mutable reference.
#[derive(Debug)]
struct LocalState<P> {
    /// Current tick. Advanced with wrapping arithmetic on every loop
    /// iteration of `SchedulerPriv::tick`, and passed to the queues when the
    /// next task is selected.
    tick: u8,

    /// Thread park handle
    park: P,
}

/// Max number of tasks to poll per tick.
///
/// `SchedulerPriv::tick` performs at most this many loop iterations before
/// calling the parker with a zero-duration timeout.
const MAX_TASKS_PER_TICK: usize = 61;

thread_local! {
    // Pointer to the scheduler currently inside `block_on` on this thread;
    // starts out null. `block_on` installs its scheduler here and restores
    // the previous value on exit. `SchedulerPriv::schedule` compares this
    // against `self` to choose between the local and the remote queue.
    static ACTIVE: Cell<*const SchedulerPriv> = Cell::new(ptr::null())
}

impl<P: Park> BasicScheduler<P> {
    /// Builds a scheduler that parks the driving thread through `park`.
    ///
    /// The matching unpark handle is obtained first and stored (boxed) in the
    /// shared state, so that `SchedulerPriv::schedule` can wake the scheduler.
    pub(crate) fn new(park: P) -> BasicScheduler<P> {
        let unpark = park.unpark();

        let shared = Arc::new(SchedulerPriv {
            queues: MpscQueues::new(),
            unpark: Box::new(unpark),
        });
        let local = LocalState { tick: 0, park };

        BasicScheduler {
            scheduler: shared,
            local,
        }
    }

    /// Returns a new `Spawner` that shares this scheduler's state.
    pub(crate) fn spawner(&self) -> Spawner {
        let scheduler = Arc::clone(&self.scheduler);
        Spawner { scheduler }
    }

    /// Spawns a future onto this scheduler.
    ///
    /// The future is wrapped in a joinable task and scheduled as a new spawn;
    /// the returned `JoinHandle` refers to that task.
    pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        let (task, join_handle) = task::joinable(future);
        self.scheduler.schedule(task, true);
        join_handle
    }

    /// Drives `future` to completion on the calling thread and returns its
    /// output.
    ///
    /// While this runs, the scheduler is installed in the `ACTIVE`
    /// thread-local; the previous value is put back when the function exits.
    /// Each loop iteration polls `future` once; if it is still pending, one
    /// `tick` of scheduled tasks is run and pending task drops are drained.
    pub(crate) fn block_on<F>(&mut self, mut future: F) -> F::Output
    where
        F: Future,
    {
        use crate::runtime;
        use std::pin::Pin;
        use std::task::{Context, Poll};

        /// Restores the previously active scheduler pointer when dropped.
        struct RestoreActive(*const SchedulerPriv);

        impl Drop for RestoreActive {
            fn drop(&mut self) {
                let previous = self.0;
                ACTIVE.with(|cell| cell.set(previous));
            }
        }

        let local = &mut self.local;
        let scheduler = &*self.scheduler;

        // Install this scheduler as the active one, remembering whatever was
        // active before so it can be restored.
        let _restore = ACTIVE
            .with(|cell| RestoreActive(cell.replace(scheduler as *const SchedulerPriv)));

        let _enter = runtime::enter();

        // The waker handed to `future` carries a plain pointer to the
        // scheduler. Its by-value wake and drop hooks are no-ops.
        let vtable: &'static RawWakerVTable =
            &RawWakerVTable::new(sched_clone_waker, sched_noop, sched_wake_by_ref, sched_noop);
        let raw_waker = RawWaker::new(scheduler as *const SchedulerPriv as *const (), vtable);

        // Wrapped in `ManuallyDrop`, so this waker is never dropped.
        let waker = ManuallyDrop::new(unsafe { Waker::from_raw(raw_waker) });
        let mut cx = Context::from_waker(&waker);

        // SAFETY: `block_on` owns `future`. Shadowing the binding with the
        // pinned reference makes the original binding unreachable, so the
        // value cannot be moved after being pinned.
        let mut future = unsafe { Pin::new_unchecked(&mut future) };

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(output) => return output,
                Poll::Pending => {}
            }

            scheduler.tick(local);

            // Maintenance work.
            //
            // SAFETY: this may only be called from the thread the basic
            // scheduler is running on, which is the current thread.
            unsafe {
                scheduler.queues.drain_pending_drop();
            }
        }
    }
}

impl Spawner {
    /// Spawn a future onto the thread pool
    pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        let (task, handle) = task::joinable(future);
        self.scheduler.schedule(task, true);
        handle
    }
}

// === impl SchedulerPriv ===

impl SchedulerPriv {
    fn tick(&self, local: &mut LocalState<impl Park>) {
        for _ in 0..MAX_TASKS_PER_TICK {
            // Get the current tick
            let tick = local.tick;

            // Increment the tick
            local.tick = tick.wrapping_add(1);
            let next = unsafe {
                // safety: this function is safe to call only from the
                // thread the basic scheduler is running on. The `LocalState`
                // parameter to this method implies that we are on that thread.
                self.queues.next_task(tick)
            };

            let task = match next {
                Some(task) => task,
                None => {
                    local.park.park().ok().expect("failed to park");
                    return;
                }
            };

            if let Some(task) = task.run(&mut || Some(self.into())) {
                unsafe {
                    // safety: this function is safe to call only from the
                    // thread the basic scheduler is running on. The `LocalState`
                    // parameter to this method implies that we are on that thread.
                    self.queues.push_local(task);
                }
            }
        }

        local
            .park
            .park_timeout(Duration::from_millis(0))
            .ok()
            .expect("failed to park");
    }

    /// Schedule the provided task on the scheduler.
    ///
    /// If this scheduler is the `ACTIVE` scheduler, enqueue this task on the local queue, otherwise
    /// the task is enqueued on the remote queue.
    fn schedule(&self, task: Task<Self>, spawn: bool) {
        let is_current = ACTIVE.with(|cell| cell.get() == self as *const SchedulerPriv);

        if is_current {
            unsafe {
                // safety: this function is safe to call only from the
                // thread the basic scheduler is running on. If `is_current` is
                // then we are on that thread.
                self.queues.push_local(task)
            };
        } else {
            let mut lock = self.queues.remote();
            lock.schedule(task, spawn);

            // while locked, call unpark
            self.unpark.unpark();

            drop(lock);
        }
    }
}

/// `Schedule` implementation for the basic scheduler; every method forwards
/// to the scheduler's queues or to the inherent `SchedulerPriv::schedule`.
impl Schedule for SchedulerPriv {
    /// Registers `task` with this scheduler's queues.
    fn bind(&self, task: &Task<Self>) {
        unsafe {
            // safety: `add_task` is only safe to call from the thread
            // that owns the queues (the thread the scheduler is running on).
            // `bind` is called when polling a task that
            // doesn't have a scheduler set. We will only poll new tasks from
            // the thread that the scheduler is running on. Therefore, this is
            // safe to call.
            self.queues.add_task(task);
        }
    }

    /// Releases `task` through the queues' `release_remote`. Per the note in
    /// `release_local`, this is the method to use when releasing a task from
    /// another thread.
    fn release(&self, task: Task<Self>) {
        self.queues.release_remote(task);
    }

    /// Releases `task` through the queues' `release_local`.
    fn release_local(&self, task: &Task<Self>) {
        unsafe {
            // safety: `release_local` is only called from the
            // thread that the scheduler is running on. The `Schedule` trait's
            // contract is that releasing a task from another thread should call
            // `release` rather than `release_local`.
            self.queues.release_local(task);
        }
    }

    /// Forwards to the inherent `SchedulerPriv::schedule` with
    /// `spawn = false`.
    fn schedule(&self, task: Task<Self>) {
        SchedulerPriv::schedule(self, task, false);
    }
}

// Impl with an empty body.
// NOTE(review): presumably this declares that the scheduler only runs `Send`
// tasks (both `spawn` methods in this file require `F: Send`) — confirm
// against the `ScheduleSendOnly` definition.
impl ScheduleSendOnly for SchedulerPriv {}

impl<P: Park> Drop for BasicScheduler<P> {
    /// Shuts the task queues down, then blocks until no tasks remain.
    ///
    /// Each round drains pending drops and the queues. If tasks are still
    /// outstanding, the thread parks and the loop repeats after it is
    /// unparked. Panics with "park failed" if parking reports an error.
    fn drop(&mut self) {
        let queues = &self.scheduler.queues;

        // SAFETY: the `Drop` impl owns the scheduler's queues. They are only
        // accessed while running the scheduler, and it can no longer be run
        // because it is in the process of being dropped.
        unsafe {
            // Shut down the task queues.
            queues.shutdown();
        }

        // Wait until all tasks have been released.
        loop {
            // SAFETY: same reasoning as above — the scheduler can no longer
            // be run, so nothing else accesses the queues' local state.
            let tasks_remain = unsafe {
                queues.drain_pending_drop();
                queues.drain_queues();
                queues.has_tasks_remaining()
            };

            if !tasks_remain {
                break;
            }

            self.local.park.park().ok().expect("park failed");
        }
    }
}

impl fmt::Debug for