From 066965cd59d01fd9d999152e32169a24dfe434fa Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Thu, 8 Oct 2020 12:12:56 -0700 Subject: net: use &self with TcpListener::accept (#2919) Uses the infrastructure added by #2828 to enable switching `TcpListener::accept` to use `&self`. This also switches `poll_accept` to use `&self`. While doing introduces a hazard, `poll_*` style functions are considered low-level. Most users will use the `async fn` variants which are more misuse-resistant. TcpListener::incoming() is temporarily removed as it has the same problem as `TcpSocket::by_ref()` and will be implemented later. --- tokio/tests/buffered.rs | 2 +- tokio/tests/io_driver.rs | 2 +- tokio/tests/io_driver_drop.rs | 4 +- tokio/tests/rt_common.rs | 12 +++--- tokio/tests/rt_threaded.rs | 2 +- tokio/tests/tcp_accept.rs | 90 +++++++++++++++++++++++++++++++++++-------- tokio/tests/tcp_connect.rs | 16 ++++---- tokio/tests/tcp_echo.rs | 2 +- tokio/tests/tcp_into_split.rs | 2 +- tokio/tests/tcp_shutdown.rs | 2 +- tokio/tests/tcp_socket.rs | 6 +-- 11 files changed, 98 insertions(+), 42 deletions(-) (limited to 'tokio/tests') diff --git a/tokio/tests/buffered.rs b/tokio/tests/buffered.rs index 595f855a..97ba00cd 100644 --- a/tokio/tests/buffered.rs +++ b/tokio/tests/buffered.rs @@ -13,7 +13,7 @@ use std::thread; async fn echo_server() { const N: usize = 1024; - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let msg = "foo bar baz"; diff --git a/tokio/tests/io_driver.rs b/tokio/tests/io_driver.rs index d4f4f8d4..01be3659 100644 --- a/tokio/tests/io_driver.rs +++ b/tokio/tests/io_driver.rs @@ -56,7 +56,7 @@ fn test_drop_on_notify() { // Define a task that just drains the listener let task = Arc::new(Task::new(async move { // Create a listener - let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); // Send the address let addr = listener.local_addr().unwrap(); diff --git a/tokio/tests/io_driver_drop.rs b/tokio/tests/io_driver_drop.rs index 0a5ce625..2ee02a42 100644 --- a/tokio/tests/io_driver_drop.rs +++ b/tokio/tests/io_driver_drop.rs @@ -9,7 +9,7 @@ use tokio_test::{assert_err, assert_pending, assert_ready, task}; fn tcp_doesnt_block() { let rt = rt(); - let mut listener = rt.enter(|| { + let listener = rt.enter(|| { let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); TcpListener::from_std(listener).unwrap() }); @@ -27,7 +27,7 @@ fn tcp_doesnt_block() { fn drop_wakes() { let rt = rt(); - let mut listener = rt.enter(|| { + let listener = rt.enter(|| { let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); TcpListener::from_std(listener).unwrap() }); diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs index 3e95c2aa..93d6a44e 100644 --- a/tokio/tests/rt_common.rs +++ b/tokio/tests/rt_common.rs @@ -471,7 +471,7 @@ rt_test! { rt.block_on(async move { let (tx, rx) = oneshot::channel(); - let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { @@ -539,7 +539,7 @@ rt_test! { let rt = rt(); rt.block_on(async move { - let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(listener.local_addr()); let peer = tokio::task::spawn_blocking(move || { @@ -634,7 +634,7 @@ rt_test! { // Do some I/O work rt.block_on(async { - let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(listener.local_addr()); let srv = tokio::spawn(async move { @@ -912,7 +912,7 @@ rt_test! { } async fn client_server(tx: mpsc::Sender<()>) { - let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await); // Get the assigned address let addr = assert_ok!(server.local_addr()); @@ -943,7 +943,7 @@ rt_test! { local.block_on(&rt, async move { let (tx, rx) = oneshot::channel(); - let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); task::spawn_local(async move { @@ -970,7 +970,7 @@ rt_test! { } async fn client_server_local(tx: mpsc::Sender<()>) { - let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await); // Get the assigned address let addr = assert_ok!(server.local_addr()); diff --git a/tokio/tests/rt_threaded.rs b/tokio/tests/rt_threaded.rs index 2c7cfb80..1ac6ed32 100644 --- a/tokio/tests/rt_threaded.rs +++ b/tokio/tests/rt_threaded.rs @@ -139,7 +139,7 @@ fn spawn_shutdown() { } async fn client_server(tx: mpsc::Sender<()>) { - let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await); // Get the assigned address let addr = assert_ok!(server.local_addr()); diff --git a/tokio/tests/tcp_accept.rs b/tokio/tests/tcp_accept.rs index 9f5b4414..4c0d6822 100644 --- a/tokio/tests/tcp_accept.rs +++ b/tokio/tests/tcp_accept.rs @@ -5,6 +5,7 @@ use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{mpsc, oneshot}; use tokio_test::assert_ok; +use std::io; use std::net::{IpAddr, SocketAddr}; macro_rules! test_accept { @@ -12,7 +13,7 @@ macro_rules! test_accept { $( #[tokio::test] async fn $ident() { - let mut listener = assert_ok!(TcpListener::bind($target).await); + let listener = assert_ok!(TcpListener::bind($target).await); let addr = listener.local_addr().unwrap(); let (tx, rx) = oneshot::channel(); @@ -39,7 +40,6 @@ test_accept! { (ip_port_tuple, ("127.0.0.1".parse::().unwrap(), 0)), } -use pin_project_lite::pin_project; use std::pin::Pin; use std::sync::{ atomic::{AtomicUsize, Ordering::SeqCst}, @@ -48,23 +48,17 @@ use std::sync::{ use std::task::{Context, Poll}; use tokio::stream::{Stream, StreamExt}; -pin_project! { - struct TrackPolls { - npolls: Arc, - #[pin] - s: S, - } +struct TrackPolls<'a> { + npolls: Arc, + listener: &'a mut TcpListener, } -impl Stream for TrackPolls -where - S: Stream, -{ - type Item = S::Item; +impl<'a> Stream for TrackPolls<'a> { + type Item = io::Result<(TcpStream, SocketAddr)>; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - this.npolls.fetch_add(1, SeqCst); - this.s.poll_next(cx) + self.npolls.fetch_add(1, SeqCst); + self.listener.poll_accept(cx).map(Some) } } @@ -79,7 +73,7 @@ async fn no_extra_poll() { tokio::spawn(async move { let mut incoming = TrackPolls { npolls: Arc::new(AtomicUsize::new(0)), - s: listener.incoming(), + listener: &mut listener, }; assert_ok!(tx.send(Arc::clone(&incoming.npolls))); while incoming.next().await.is_some() { @@ -99,3 +93,65 @@ async fn no_extra_poll() { // should have been polled twice more: once to yield Some(), then once to yield Pending assert_eq!(npolls.load(SeqCst), 1 + 2); } + +#[tokio::test] +async fn accept_many() { + use futures::future::poll_fn; + use std::future::Future; + use std::sync::atomic::AtomicBool; + + const N: usize = 50; + + let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let listener = Arc::new(listener); + let addr = listener.local_addr().unwrap(); + let connected = Arc::new(AtomicBool::new(false)); + + let (pending_tx, mut pending_rx) = mpsc::unbounded_channel(); + let (notified_tx, mut notified_rx) = mpsc::unbounded_channel(); + + for _ in 0..N { + let listener = listener.clone(); + let connected = connected.clone(); + let pending_tx = pending_tx.clone(); + let notified_tx = notified_tx.clone(); + + tokio::spawn(async move { + let accept = listener.accept(); + tokio::pin!(accept); + + let mut polled = false; + + poll_fn(|cx| { + if !polled { + polled = true; + assert!(Pin::new(&mut accept).poll(cx).is_pending()); + pending_tx.send(()).unwrap(); + Poll::Pending + } else if connected.load(SeqCst) { + notified_tx.send(()).unwrap(); + Poll::Ready(()) + } else { + Poll::Pending + } + }) + .await; + + pending_tx.send(()).unwrap(); + }); + } + + // Wait for all tasks to have polled at least once + for _ in 0..N { + pending_rx.recv().await.unwrap(); + } + + // Establish a TCP connection + connected.store(true, SeqCst); + let _sock = TcpStream::connect(addr).await.unwrap(); + + // Wait for all notifications + for _ in 0..N { + notified_rx.recv().await.unwrap(); + } +} diff --git a/tokio/tests/tcp_connect.rs b/tokio/tests/tcp_connect.rs index de1cead8..44942c4e 100644 --- a/tokio/tests/tcp_connect.rs +++ b/tokio/tests/tcp_connect.rs @@ -9,7 +9,7 @@ use futures::join; #[tokio::test] async fn connect_v4() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); assert!(addr.is_ipv4()); @@ -36,7 +36,7 @@ async fn connect_v4() { #[tokio::test] async fn connect_v6() { - let mut srv = assert_ok!(TcpListener::bind("[::1]:0").await); + let srv = assert_ok!(TcpListener::bind("[::1]:0").await); let addr = assert_ok!(srv.local_addr()); assert!(addr.is_ipv6()); @@ -63,7 +63,7 @@ async fn connect_v6() { #[tokio::test] async fn connect_addr_ip_string() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let addr = format!("127.0.0.1:{}", addr.port()); @@ -80,7 +80,7 @@ async fn connect_addr_ip_string() { #[tokio::test] async fn connect_addr_ip_str_slice() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let addr = format!("127.0.0.1:{}", addr.port()); @@ -97,7 +97,7 @@ async fn connect_addr_ip_str_slice() { #[tokio::test] async fn connect_addr_host_string() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let addr = format!("localhost:{}", addr.port()); @@ -114,7 +114,7 @@ async fn connect_addr_host_string() { #[tokio::test] async fn connect_addr_ip_port_tuple() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let addr = (addr.ip(), addr.port()); @@ -131,7 +131,7 @@ async fn connect_addr_ip_port_tuple() { #[tokio::test] async fn connect_addr_ip_str_port_tuple() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let addr = ("127.0.0.1", addr.port()); @@ -148,7 +148,7 @@ async fn connect_addr_ip_str_port_tuple() { #[tokio::test] async fn connect_addr_host_str_port_tuple() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let addr = ("localhost", addr.port()); diff --git a/tokio/tests/tcp_echo.rs b/tokio/tests/tcp_echo.rs index 1feba63e..d9cb456f 100644 --- a/tokio/tests/tcp_echo.rs +++ b/tokio/tests/tcp_echo.rs @@ -12,7 +12,7 @@ async fn echo_server() { let (tx, rx) = oneshot::channel(); - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let msg = "foo bar baz"; diff --git a/tokio/tests/tcp_into_split.rs b/tokio/tests/tcp_into_split.rs index 86ed4619..b4bb2eeb 100644 --- a/tokio/tests/tcp_into_split.rs +++ b/tokio/tests/tcp_into_split.rs @@ -13,7 +13,7 @@ use tokio::try_join; async fn split() -> Result<()> { const MSG: &[u8] = b"split"; - let mut listener = TcpListener::bind("127.0.0.1:0").await?; + let listener = TcpListener::bind("127.0.0.1:0").await?; let addr = listener.local_addr()?; let (stream1, (mut stream2, _)) = try_join! { diff --git a/tokio/tests/tcp_shutdown.rs b/tokio/tests/tcp_shutdown.rs index bd43e143..615855f1 100644 --- a/tokio/tests/tcp_shutdown.rs +++ b/tokio/tests/tcp_shutdown.rs @@ -8,7 +8,7 @@ use tokio_test::assert_ok; #[tokio::test] async fn shutdown() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); tokio::spawn(async move { diff --git a/tokio/tests/tcp_socket.rs b/tokio/tests/tcp_socket.rs index 993a1e0c..9258864d 100644 --- a/tokio/tests/tcp_socket.rs +++ b/tokio/tests/tcp_socket.rs @@ -11,7 +11,7 @@ async fn basic_usage_v4() { let srv = assert_ok!(TcpSocket::new_v4()); assert_ok!(srv.bind(addr)); - let mut srv = assert_ok!(srv.listen(128)); + let srv = assert_ok!(srv.listen(128)); // Create client & connect let addr = srv.local_addr().unwrap(); @@ -29,7 +29,7 @@ async fn basic_usage_v6() { let srv = assert_ok!(TcpSocket::new_v6()); assert_ok!(srv.bind(addr)); - let mut srv = assert_ok!(srv.listen(128)); + let srv = assert_ok!(srv.listen(128)); // Create client & connect let addr = srv.local_addr().unwrap(); @@ -47,7 +47,7 @@ async fn bind_before_connect() { let srv = assert_ok!(TcpSocket::new_v4()); assert_ok!(srv.bind(any_addr)); - let mut srv = assert_ok!(srv.listen(128)); + let srv = assert_ok!(srv.listen(128)); // Create client & connect let addr = srv.local_addr().unwrap(); -- cgit v1.2.3