summaryrefslogtreecommitdiffstats
path: root/tokio/src/runtime/shell.rs
blob: 486d4fa5bbe0be6ead132c7821f55099eabf4a90 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#![allow(clippy::redundant_clone)]

use crate::future::poll_fn;
use crate::park::{Park, Unpark};
use crate::runtime::driver::Driver;
use crate::sync::Notify;
use crate::util::{waker_ref, Wake};

use std::sync::{Arc, Mutex};
use std::task::Context;
use std::task::Poll::{Pending, Ready};
use std::{future::Future, sync::PoisonError};

/// Runtime shell that lets callers block on futures while sharing a single
/// `Driver` between threads.
///
/// Whichever caller checks the driver out (via a `DriverGuard`) drives it;
/// other callers wait on `notify` until the driver is handed back.
#[derive(Debug)]
pub(super) struct Shell {
    /// The driver when it is available, or `None` while a `DriverGuard`
    /// has it checked out.
    driver: Mutex<Option<Driver>>,

    /// Notified (one waiter at a time) when a `DriverGuard` returns the
    /// driver, so a blocked `block_on` caller can retry taking it.
    notify: Notify,

    /// Unpark handle for the driver, used to build the waker when a future
    /// is driven on the driver.
    ///
    /// TODO: don't store this
    unpark: Arc<Handle>,
}

/// Wrapper around the driver's unpark handle so it can be shared in an `Arc`
/// and used as a waker through the `Wake` implementation below.
#[derive(Debug)]
struct Handle(<Driver as Park>::Unpark);

impl Shell {
    pub(super) fn new(driver: Driver) -> Shell {
        let unpark = Arc::new(Handle(driver.unpark()));

        Shell {
            driver: Mutex::new(Some(driver)),
            notify: Notify::new(),
            unpark,
        }
    }

    pub(super) fn block_on<F>(&self, f: F) -> F::Output
    where
        F: Future,
    {
        let mut enter = crate::runtime::enter(true);

        pin!(f);

        loop {
            if let Some(driver) = &mut self.take_driver() {
                return driver.block_on(f);
            } else {
                let notified = self.notify.notified();
                pin!(notified);

                if let Some(out) = enter
                    .block_on(poll_fn(|cx| {
                        if notified.as_mut().poll(cx).is_ready() {
                            return Ready(None);
                        }

                        if let Ready(out) = f.as_mut().poll(cx) {
                            return Ready(Some(out));
                        }

                        Pending
                    }))
                    .expect("Failed to `Enter::block_on`")
                {
                    return out;
                }
            }
        }
    }

    fn take_driver(&self) -> Option<DriverGuard<'_>> {
        let mut lock = self.driver.lock().unwrap();
        let driver = lock.take()?;

        Some(DriverGuard {
            inner: Some(driver),
            shell: &self,
        })
    }
}

impl Wake for Handle {
    /// Wake by value
    fn wake(self: Arc<Self>) {
        Wake::wake_by_ref(&self);
    }

    /// Wake by reference
    fn wake_by_ref(arc_self: &Arc<Self>) {
        arc_self.0.unpark();
    }
}

/// Holds the driver while it is checked out of a `Shell`; the `Drop` impl
/// puts it back into the shell and notifies one waiter.
struct DriverGuard<'a> {
    /// The checked-out driver. `Some` until `Drop` moves it back into `shell`.
    inner: Option<Driver>,
    /// The shell that owns the driver slot this guard returns to.
    shell: &'a Shell,
}

impl DriverGuard<'_> {
    fn block_on<F: Future>(&mut self, f: F) -> F::Output {
        let driver = self.inner.as_mut().unwrap();

        pin!(f);

        let waker = waker_ref(&self.shell.unpark);
        let mut cx = Context::from_waker(&waker);

        loop {
            if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) {
                return v;
            }

            driver.park().unwrap();
        }
    }
}

impl Drop for DriverGuard<'_> {
    /// Returns the driver to the shell, if this guard still holds it, and
    /// wakes one caller waiting on the shell's `notify`.
    fn drop(&mut self) {
        let driver = match self.inner.take() {
            Some(driver) => driver,
            None => return,
        };

        // A poisoned lock is recovered rather than propagated; the slot is
        // just an `Option`.
        let mut slot = self
            .shell
            .driver
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        slot.replace(driver);
        // Release the lock before notifying, so the woken caller can take it.
        drop(slot);

        self.shell.notify.notify_one();
    }
}