summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorCarl Lerche <me@carllerche.com>2020-10-08 12:12:56 -0700
committerGitHub <noreply@github.com>2020-10-08 12:12:56 -0700
commit066965cd59d01fd9d999152e32169a24dfe434fa (patch)
treeeef03ca071b8d9f285954a1f98fd85e3e188c98b
parent6259893094ebcdfecb107fcf3682eaad1bd1903b (diff)
net: use &self with TcpListener::accept (#2919)
Uses the infrastructure added by #2828 to enable switching `TcpListener::accept` to use `&self`. This also switches `poll_accept` to use `&self`. While doing introduces a hazard, `poll_*` style functions are considered low-level. Most users will use the `async fn` variants which are more misuse-resistant. TcpListener::incoming() is temporarily removed as it has the same problem as `TcpSocket::by_ref()` and will be implemented later.
-rw-r--r--examples/chat.rs2
-rw-r--r--examples/echo.rs2
-rw-r--r--examples/print_each_packet.rs2
-rw-r--r--examples/proxy.rs2
-rw-r--r--examples/tinydb.rs2
-rw-r--r--examples/tinyhttp.rs8
-rw-r--r--tokio/src/io/driver/scheduled_io.rs67
-rw-r--r--tokio/src/io/registration.rs15
-rw-r--r--tokio/src/lib.rs2
-rw-r--r--tokio/src/macros/cfg.rs2
-rw-r--r--tokio/src/net/tcp/incoming.rs42
-rw-r--r--tokio/src/net/tcp/listener.rs75
-rw-r--r--tokio/src/net/tcp/mod.rs4
-rw-r--r--tokio/src/net/tcp/stream.rs4
-rw-r--r--tokio/src/runtime/mod.rs4
-rw-r--r--tokio/src/task/spawn.rs2
-rw-r--r--tokio/tests/buffered.rs2
-rw-r--r--tokio/tests/io_driver.rs2
-rw-r--r--tokio/tests/io_driver_drop.rs4
-rw-r--r--tokio/tests/rt_common.rs12
-rw-r--r--tokio/tests/rt_threaded.rs2
-rw-r--r--tokio/tests/tcp_accept.rs90
-rw-r--r--tokio/tests/tcp_connect.rs16
-rw-r--r--tokio/tests/tcp_echo.rs2
-rw-r--r--tokio/tests/tcp_into_split.rs2
-rw-r--r--tokio/tests/tcp_shutdown.rs2
-rw-r--r--tokio/tests/tcp_socket.rs6
27 files changed, 201 insertions, 174 deletions
diff --git a/examples/chat.rs b/examples/chat.rs
index c4b8c6a2..3f945039 100644
--- a/examples/chat.rs
+++ b/examples/chat.rs
@@ -77,7 +77,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Bind a TCP listener to the socket address.
//
// Note that this is the Tokio TcpListener, which is fully async.
- let mut listener = TcpListener::bind(&addr).await?;
+ let listener = TcpListener::bind(&addr).await?;
tracing::info!("server running on {}", addr);
diff --git a/examples/echo.rs b/examples/echo.rs
index f3068074..d492e07e 100644
--- a/examples/echo.rs
+++ b/examples/echo.rs
@@ -39,7 +39,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Next up we create a TCP listener which will listen for incoming
// connections. This TCP listener is bound to the address we determined
// above and must be associated with an event loop.
- let mut listener = TcpListener::bind(&addr).await?;
+ let listener = TcpListener::bind(&addr).await?;
println!("Listening on: {}", addr);
loop {
diff --git a/examples/print_each_packet.rs b/examples/print_each_packet.rs
index d650b5bd..b3e1b17e 100644
--- a/examples/print_each_packet.rs
+++ b/examples/print_each_packet.rs
@@ -74,7 +74,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// above and must be associated with an event loop, so we pass in a handle
// to our event loop. After the socket's created we inform that we're ready
// to go and start accepting connections.
- let mut listener = TcpListener::bind(&addr).await?;
+ let listener = TcpListener::bind(&addr).await?;
println!("Listening on: {}", addr);
loop {
diff --git a/examples/proxy.rs b/examples/proxy.rs
index 144f0179..2d9b7ce3 100644
--- a/examples/proxy.rs
+++ b/examples/proxy.rs
@@ -43,7 +43,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
println!("Listening on: {}", listen_addr);
println!("Proxying to: {}", server_addr);
- let mut listener = TcpListener::bind(listen_addr).await?;
+ let listener = TcpListener::bind(listen_addr).await?;
while let Ok((inbound, _)) = listener.accept().await {
let transfer = transfer(inbound, server_addr.clone()).map(|r| {
diff --git a/examples/tinydb.rs b/examples/tinydb.rs
index c1af2541..f0db7fa8 100644
--- a/examples/tinydb.rs
+++ b/examples/tinydb.rs
@@ -89,7 +89,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
.nth(1)
.unwrap_or_else(|| "127.0.0.1:8080".to_string());
- let mut listener = TcpListener::bind(&addr).await?;
+ let listener = TcpListener::bind(&addr).await?;
println!("Listening on: {}", addr);
// Create the shared state of this server that will be shared amongst all
diff --git a/examples/tinyhttp.rs b/examples/tinyhttp.rs
index 4870aea2..c561bbd3 100644
--- a/examples/tinyhttp.rs
+++ b/examples/tinyhttp.rs
@@ -30,19 +30,17 @@ async fn main() -> Result<(), Box<dyn Error>> {
let addr = env::args()
.nth(1)
.unwrap_or_else(|| "127.0.0.1:8080".to_string());
- let mut server = TcpListener::bind(&addr).await?;
- let mut incoming = server.incoming();
+ let server = TcpListener::bind(&addr).await?;
println!("Listening on: {}", addr);
- while let Some(Ok(stream)) = incoming.next().await {
+ loop {
+ let (stream, _) = server.accept().await?;
tokio::spawn(async move {
if let Err(e) = process(stream).await {
println!("failed to process connection; error = {}", e);
}
});
}
-
- Ok(())
}
async fn process(stream: TcpStream) -> Result<(), Box<dyn Error>> {
diff --git a/tokio/src/io/driver/scheduled_io.rs b/tokio/src/io/driver/scheduled_io.rs
index bdf21798..0c0448c3 100644
--- a/tokio/src/io/driver/scheduled_io.rs
+++ b/tokio/src/io/driver/scheduled_io.rs
@@ -32,7 +32,7 @@ cfg_io_readiness! {
#[derive(Debug, Default)]
struct Waiters {
- #[cfg(any(feature = "udp", feature = "uds"))]
+ #[cfg(any(feature = "tcp", feature = "udp", feature = "uds"))]
/// List of all current waiters
list: WaitList,
@@ -186,33 +186,78 @@ impl ScheduledIo {
}
}
+ /// Notifies all pending waiters that have registered interest in `ready`.
+ ///
+ /// There may be many waiters to notify. Waking the pending task **must** be
+ /// done from outside of the lock otherwise there is a potential for a
+ /// deadlock.
+ ///
+ /// A stack array of wakers is created and filled with wakers to notify, the
+ /// lock is released, and the wakers are notified. Because there may be more
+ /// than 32 wakers to notify, if the stack array fills up, the lock is
+ /// released, the array is cleared, and the iteration continues.
pub(super) fn wake(&self, ready: Ready) {
+ const NUM_WAKERS: usize = 32;
+
+ let mut wakers: [Option<Waker>; NUM_WAKERS] = Default::default();
+ let mut curr = 0;
+
let mut waiters = self.waiters.lock();
// check for AsyncRead slot
if ready.is_readable() {
if let Some(waker) = waiters.reader.take() {
- waker.wake();
+ wakers[curr] = Some(waker);
+ curr += 1;
}
}
// check for AsyncWrite slot
if ready.is_writable() {
if let Some(waker) = waiters.writer.take() {
- waker.wake();
+ wakers[curr] = Some(waker);
+ curr += 1;
}
}
- #[cfg(any(feature = "udp", feature = "uds"))]
- {
- // check list of waiters
- for waiter in waiters.list.drain_filter(|w| ready.satisfies(w.interest)) {
- let waiter = unsafe { &mut *waiter.as_ptr() };
- if let Some(waker) = waiter.waker.take() {
- waiter.is_ready = true;
- waker.wake();
+ #[cfg(any(feature = "tcp", feature = "udp", feature = "uds"))]
+ 'outer: loop {
+ let mut iter = waiters.list.drain_filter(|w| ready.satisfies(w.interest));
+
+ while curr < NUM_WAKERS {
+ match iter.next() {
+ Some(waiter) => {
+ let waiter = unsafe { &mut *waiter.as_ptr() };
+
+ if let Some(waker) = waiter.waker.take() {
+ waiter.is_ready = true;
+ wakers[curr] = Some(waker);
+ curr += 1;
+ }
+ }
+ None => {
+ break 'outer;
+ }
}
}
+
+ drop(waiters);
+
+ for waker in wakers.iter_mut().take(curr) {
+ waker.take().unwrap().wake();
+ }
+
+ curr = 0;
+
+ // Acquire the lock again.
+ waiters = self.waiters.lock();
+ }
+
+ // Release the lock before notifying
+ drop(waiters);
+
+ for waker in wakers.iter_mut().take(curr) {
+ waker.take().unwrap().wake();
}
}
diff --git a/tokio/src/io/registration.rs b/tokio/src/io/registration.rs
index 03221b60..ce6cffda 100644
--- a/tokio/src/io/registration.rs
+++ b/tokio/src/io/registration.rs
@@ -132,8 +132,19 @@ impl Registration {
cfg_io_readiness! {
impl Registration {
pub(super) async fn readiness(&self, interest: mio::Interest) -> io::Result<ReadyEvent> {
- // TODO: does this need to return a `Result`?
- Ok(self.shared.readiness(interest).await)
+ use std::future::Future;
+ use std::pin::Pin;
+
+ let fut = self.shared.readiness(interest);
+ pin!(fut);
+
+ crate::future::poll_fn(|cx| {
+ if self.handle.inner().is_none() {
+ return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "reactor gone")));
+ }
+
+ Pin::new(&mut fut).poll(cx).map(Ok)
+ }).await
}
}
}
diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs
index 1b0dad5d..948ac888 100644
--- a/tokio/src/lib.rs
+++ b/tokio/src/lib.rs
@@ -306,7 +306,7 @@
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
-//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
+//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
//!
//! loop {
//! let (mut socket, _) = listener.accept().await?;
diff --git a/tokio/src/macros/cfg.rs b/tokio/src/macros/cfg.rs
index 328f3230..8f1536f8 100644
--- a/tokio/src/macros/cfg.rs
+++ b/tokio/src/macros/cfg.rs
@@ -176,7 +176,7 @@ macro_rules! cfg_not_io_driver {
macro_rules! cfg_io_readiness {
($($item:item)*) => {
$(
- #[cfg(any(feature = "udp", feature = "uds"))]
+ #[cfg(any(feature = "udp", feature = "uds", feature = "tcp"))]
$item
)*
}
diff --git a/tokio/src/net/tcp/incoming.rs b/tokio/src/net/tcp/incoming.rs
deleted file mode 100644
index 062be1e9..00000000
--- a/tokio/src/net/tcp/incoming.rs
+++ /dev/null
@@ -1,42 +0,0 @@
-use crate::net::tcp::{TcpListener, TcpStream};
-
-use std::io;
-use std::pin::Pin;
-use std::task::{Context, Poll};
-
-/// Stream returned by the `TcpListener::incoming` function representing the
-/// stream of sockets received from a listener.
-#[must_use = "streams do nothing unless polled"]
-#[derive(Debug)]
-pub struct Incoming<'a> {
- inner: &'a mut TcpListener,
-}
-
-impl Incoming<'_> {
- pub(crate) fn new(listener: &mut TcpListener) -> Incoming<'_> {
- Incoming { inner: listener }
- }
-
- /// Attempts to poll `TcpStream` by polling inner `TcpListener` to accept
- /// connection.
- ///
- /// If `TcpListener` isn't ready yet, `Poll::Pending` is returned and
- /// current task will be notified by a waker.
- pub fn poll_accept(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<io::Result<TcpStream>> {
- let (socket, _) = ready!(self.inner.poll_accept(cx))?;
- Poll::Ready(Ok(socket))
- }
-}
-
-#[cfg(feature = "stream")]
-impl crate::stream::Stream for Incoming<'_> {
- type Item = io::Result<TcpStream>;
-
- fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- let (socket, _) = ready!(self.inner.poll_accept(cx))?;
- Poll::Ready(Some(Ok(socket)))
- }
-}
diff --git a/tokio/src/net/tcp/listener.rs b/tokio/src/net/tcp/listener.rs
index 133852d2..98c8961e 100644
--- a/tokio/src/net/tcp/listener.rs
+++ b/tokio/src/net/tcp/listener.rs
@@ -1,6 +1,5 @@
-use crate::future::poll_fn;
use crate::io::PollEvented;
-use crate::net::tcp::{Incoming, TcpStream};
+use crate::net::tcp::TcpStream;
use crate::net::{to_socket_addrs, ToSocketAddrs};
use std::convert::TryFrom;
@@ -40,7 +39,7 @@ cfg_tcp! {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
- /// let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
+ /// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// loop {
/// let (socket, _) = listener.accept().await?;
@@ -171,7 +170,7 @@ impl TcpListener {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
- /// let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
+ /// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// match listener.accept().await {
/// Ok((_socket, addr)) => println!("new client: {:?}", addr),
@@ -181,18 +180,25 @@ impl TcpListener {
/// Ok(())
/// }
/// ```
- pub async fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> {
- poll_fn(|cx| self.poll_accept(cx)).await
+ pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
+ let (mio, addr) = self
+ .io
+ .async_io(mio::Interest::READABLE, |sock| sock.accept())
+ .await?;
+
+ let stream = TcpStream::new(mio)?;
+ Ok((stream, addr))
}
/// Polls to accept a new incoming connection to this listener.
///
- /// If there is no connection to accept, `Poll::Pending` is returned and
- /// the current task will be notified by a waker.
- pub fn poll_accept(
- &mut self,
- cx: &mut Context<'_>,
- ) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
+ /// If there is no connection to accept, `Poll::Pending` is returned and the
+ /// current task will be notified by a waker.
+ ///
+ /// When ready, the most recent task that called `poll_accept` is notified.
+ /// The caller is responsble to ensure that `poll_accept` is called from a
+ /// single task. Failing to do this could result in tasks hanging.
+ pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
loop {
let ev = ready!(self.io.poll_read_ready(cx))?;
@@ -293,46 +299,6 @@ impl TcpListener {
self.io.get_ref().local_addr()
}
- /// Returns a stream over the connections being received on this listener.
- ///
- /// Note that `TcpListener` also directly implements `Stream`.
- ///
- /// The returned stream will never return `None` and will also not yield the
- /// peer's `SocketAddr` structure. Iterating over it is equivalent to
- /// calling accept in a loop.
- ///
- /// # Errors
- ///
- /// Note that accepting a connection can lead to various errors and not all
- /// of them are necessarily fatal ‒ for example having too many open file
- /// descriptors or the other side closing the connection while it waits in
- /// an accept queue. These would terminate the stream if not handled in any
- /// way.
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::{net::TcpListener, stream::StreamExt};
- ///
- /// #[tokio::main]
- /// async fn main() {
- /// let mut listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
- /// let mut incoming = listener.incoming();
- ///
- /// while let Some(stream) = incoming.next().await {
- /// match stream {
- /// Ok(stream) => {
- /// println!("new client!");
- /// }
- /// Err(e) => { /* connection failed */ }
- /// }
- /// }
- /// }
- /// ```
- pub fn incoming(&mut self) -> Incoming<'_> {
- Incoming::new(self)
- }
-
/// Gets the value of the `IP_TTL` option for this socket.
///
/// For more information about this option, see [`set_ttl`].
@@ -390,10 +356,7 @@ impl TcpListener {
impl crate::stream::Stream for TcpListener {
type Item = io::Result<TcpStream>;
- fn poll_next(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Option<Self::Item>> {
+ fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let (socket, _) = ready!(self.poll_accept(cx))?;
Poll::Ready(Some(Ok(socket)))
}
diff --git a/tokio/src/net/tcp/mod.rs b/tokio/src/net/tcp/mod.rs
index c27038f9..7f0f6d91 100644
--- a/tokio/src/net/tcp/mod.rs
+++ b/tokio/src/net/tcp/mod.rs
@@ -1,10 +1,6 @@
//! TCP utility types
pub(crate) mod listener;
-pub(crate) use listener::TcpListener;
-
-mod incoming;
-pub use incoming::Incoming;
pub(crate) mod socket;
diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs
index 4349ea80..3f9d6670 100644
--- a/tokio/src/net/tcp/stream.rs
+++ b/tokio/src/net/tcp/stream.rs
@@ -22,8 +22,8 @@ cfg_tcp! {
/// traits. Examples import these traits through [the prelude].
///
/// [`connect`]: method@TcpStream::connect
- /// [accepting]: method@super::TcpListener::accept
- /// [listener]: struct@super::TcpListener
+ /// [accepting]: method@crate::net::TcpListener::accept
+ /// [listener]: struct@crate::net::TcpListener
/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
/// [the prelude]: crate::prelude
diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs
index a6a739be..22109f7d 100644
--- a/tokio/src/runtime/mod.rs
+++ b/tokio/src/runtime/mod.rs
@@ -25,7 +25,7 @@
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
-//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
+//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
//!
//! loop {
//! let (mut socket, _) = listener.accept().await?;
@@ -73,7 +73,7 @@
//!
//! // Spawn the root task
//! rt.block_on(async {
-//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
+//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
//!
//! loop {
//! let (mut socket, _) = listener.accept().await?;
diff --git a/tokio/src/task/spawn.rs b/tokio/src/task/spawn.rs
index 280e90ea..d7aca572 100644
--- a/tokio/src/task/spawn.rs
+++ b/tokio/src/task/spawn.rs
@@ -37,7 +37,7 @@ doc_rt_core! {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
- /// let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
+ /// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// loop {
/// let (socket, _) = listener.accept().await?;
diff --git a/tokio/tests/buffered.rs b/tokio/tests/buffered.rs
index 595f855a..97ba00cd 100644
--- a/tokio/tests/buffered.rs
+++ b/tokio/tests/buffered.rs
@@ -13,7 +13,7 @@ use std::thread;
async fn echo_server() {
const N: usize = 1024;
- let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
+ let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(srv.local_addr());
let msg = "foo bar baz";
diff --git a/tokio/tests/io_driver.rs b/tokio/tests/io_driver.rs
index d4f4f8d4..01be3659 100644
--- a/tokio/tests/io_driver.rs
+++ b/tokio/tests/io_driver.rs
@@ -56,7 +56,7 @@ fn test_drop_on_notify() {
// Define a task that just drains the listener
let task = Arc::new(Task::new(async move {
// Create a listener
- let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
+ let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
// Send the address
let addr = listener.local_addr().unwrap();
diff --git a/tokio/tests/io_driver_drop.rs b/tokio/tests/io_driver_drop.rs
index 0a5ce625..2ee02a42 100644
--- a/tokio/tests/io_driver_drop.rs
+++ b/tokio/tests/io_driver_drop.rs
@@ -9,7 +9,7 @@ use tokio_test::{assert_err, assert_pending, assert_ready, task};
fn tcp_doesnt_block() {
let rt = rt();
- let mut listener = rt.enter(|| {
+ let listener = rt.enter(|| {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
TcpListener::from_std(listener).unwrap()
});
@@ -27,7 +27,7 @@ fn tcp_doesnt_block() {
fn drop_wakes() {
let rt = rt();
- let mut listener = rt.enter(|| {
+ let listener = rt.enter(|| {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
TcpListener::from_std(listener).unwrap()
});
diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs
index 3e95c2aa..93d6a44e 100644
--- a/tokio/tests/rt_common.rs
+++ b/tokio/tests/rt_common.rs
@@ -471,7 +471,7 @@ rt_test! {
rt.block_on(async move {
let (tx, rx) = oneshot::channel();
- let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
@@ -539,7 +539,7 @@ rt_test! {
let rt = rt();
rt.block_on(async move {
- let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
+ let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(listener.local_addr());
let peer = tokio::task::spawn_blocking(move || {
@@ -634,7 +634,7 @@ rt_test! {
// Do some I/O work
rt.block_on(async {
- let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
+ let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(listener.local_addr());
let srv = tokio::spawn(async move {
@@ -912,7 +912,7 @@ rt_test! {
}
async fn client_server(tx: mpsc::Sender<()>) {
- let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
+ let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
// Get the assigned address
let addr = assert_ok!(server.local_addr());
@@ -943,7 +943,7 @@ rt_test! {
local.block_on(&rt, async move {
let (tx, rx) = oneshot::channel();
- let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
task::spawn_local(async move {
@@ -970,7 +970,7 @@ rt_test! {
}
async fn client_server_local(tx: mpsc::Sender<()>) {
- let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
+ let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
// Get the assigned address
let addr = assert_ok!(server.local_addr());
diff --git a/tokio/tests/rt_threaded.rs b/tokio/tests/rt_threaded.rs
index 2c7cfb80..1ac6ed32 100644
--- a/tokio/tests/rt_threaded.rs
+++ b/tokio/tests/rt_threaded.rs
@@ -139,7 +139,7 @@ fn spawn_shutdown() {
}
async fn client_server(tx: mpsc::Sender<()>) {
- let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
+ let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
// Get the assigned address
let addr = assert_ok!(server.local_addr());
diff --git a/tokio/tests/tcp_accept.rs b/tokio/tests/tcp_accept.rs
index 9f5b4414..4c0d6822 100644
--- a/tokio/tests/tcp_accept.rs
+++ b/tokio/tests/tcp_accept.rs
@@ -5,6 +5,7 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot};
use tokio_test::assert_ok;
+use std::io;
use std::net::{IpAddr, SocketAddr};
macro_rules! test_accept {
@@ -12,7 +13,7 @@ macro_rules! test_accept {
$(
#[tokio::test]
async fn $ident() {
- let mut listener = assert_ok!(TcpListener::bind($target).await);
+ let listener = assert_ok!(TcpListener::bind($target).await);
let addr = listener.local_addr().unwrap();
let (tx, rx) = oneshot::channel();
@@ -39,7 +40,6 @@ test_accept! {
(ip_port_tuple, ("127.0.0.1".parse::<IpAddr>().unwrap(), 0)),
}
-use pin_project_lite::pin_project;
use std::pin::Pin;
use std::sync::{
atomic::{AtomicUsize, Ordering::SeqCst},
@@ -48,23 +48,17 @@ use std::sync::{
use std::task::{Context, Poll};
use tokio::stream::{Stream, StreamExt};
-pin_project! {
- struct TrackPolls<S> {
- npolls: Arc<AtomicUsize>,
- #[pin]
- s: S,
- }
+struct TrackPolls<'a> {
+ npolls: Arc<AtomicUsize>,
+ listener: &'a mut TcpListener,
}
-impl<S> Stream for TrackPolls<S>
-where
- S: Stream,
-{
- type Item = S::Item;
+impl<'a> Stream for TrackPolls<'a> {
+ type Item = io::Result<(TcpStream, SocketAddr)>;
+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- let this = self.project();
- this.npolls.fetch_add(1, SeqCst);
- this.s.poll_next(cx)
+ self.npolls.fetch_add(1, SeqCst);
+ self.listener.poll_accept(cx).map(Some)
}
}
@@ -79,7 +73,7 @@ async fn no_extra_poll() {
tokio::spawn(async move {
let mut incoming = TrackPolls {
npolls: Arc::new(AtomicUsize::new(0)),
- s: listener.incoming(),
+ listener: &mut listener,
};
assert_ok!(tx.send(Arc::clone(&incoming.npolls)));
while incoming.next().await.is_some() {
@@ -99,3 +93,65 @@ async fn no_extra_poll() {
// should have been polled twice more: once to yield Some(), then once to yield Pending
assert_eq!(npolls.load(SeqCst), 1 + 2);
}
+
+#[tokio::test]
+async fn accept_many() {
+ use futures::future::poll_fn;
+ use std::future::Future;
+ use std::sync::atomic::AtomicBool;
+
+ const N: usize = 50;
+
+ let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
+ let listener = Arc::new(listener);
+ let addr = listener.local_addr().unwrap();
+ let connected = Arc::new(AtomicBool::new(false));
+
+ let (pending_tx, mut pending_rx) = mpsc::unbounded_channel();
+ let (notified_tx, mut notified_rx) = mpsc::unbounded_channel();
+
+ for _ in 0..N {
+ let listener = listener.clone();
+ let connected = connected.clone();
+ let pending_tx = pending_tx.clone();
+ let notified_tx = notified_tx.clone();
+
+ tokio::spawn(async move {
+ let accept = listener.accept();
+ tokio::pin!(accept);
+
+ let mut polled = false;
+
+ poll_fn(|cx| {
+ if !polled {
+ polled = true;
+ assert!(Pin::new(&mut accept).poll(cx).is_pending());
+ pending_tx.send(()).unwrap();
+ Poll::Pending
+ } else if connected.load(SeqCst) {
+ notified_tx.send(()).unwrap();
+ Poll::Ready(())
+ } else {
+ Poll::Pending
+ }
+ })
+ .await;
+
+ pending_tx.send(()).unwrap();
+ });
+ }
+
+ // Wait for all tasks to have polled at least once
+ for _ in 0..N {
+ pending_rx.recv().await.unwrap();
+ }
+
+ // Establish a TCP connection
+ connected.store(true, SeqCst);
+ let _sock = TcpStream::connect(addr).await.unwrap();
+
+ // Wait for all notifications
+ for _ in 0..N {
+ notified_rx.recv().await.unwrap();
+ }
+}
diff --git a/tokio/tests/tcp_connect.rs b/tokio/tests/tcp_connect.rs
index de1cead8..44942c4e 100644
--- a/tokio/tests/tcp_connect.rs
+++ b/tokio/tests/tcp_connect.rs
@@ -9,7 +9,7 @@ use futures::join;