//! Tests for `TcpListener` accepting connections, both through `accept()`
//! with several forms of bind-address argument and through the `incoming()`
//! stream.
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot};
use tokio_test::assert_ok;

use std::net::{IpAddr, SocketAddr};

/// Generates one `#[tokio::test]` per `(name, target)` pair.
///
/// Each generated test binds a `TcpListener` to `target`, accepts a single
/// connection on a spawned task, connects a client to the listener's local
/// address, and asserts that the client's local address equals the accepted
/// socket's peer address.
macro_rules! test_accept {
    ($(($ident:ident, $target:expr),)*) => {
        $(
            #[tokio::test]
            async fn $ident() {
                let mut listener = assert_ok!(TcpListener::bind($target).await);
                let addr = listener.local_addr().unwrap();

                // The accepted server-side socket is handed back to the test body.
                let (tx, rx) = oneshot::channel();

                tokio::spawn(async move {
                    let (socket, _) = assert_ok!(listener.accept().await);
                    assert_ok!(tx.send(socket));
                });

                let cli = assert_ok!(TcpStream::connect(&addr).await);
                let srv = assert_ok!(rx.await);

                assert_eq!(cli.local_addr().unwrap(), srv.peer_addr().unwrap());
            }
        )*
    }
}

test_accept! {
    (ip_str, "127.0.0.1:0"),
    (host_str, "localhost:0"),
    (socket_addr, "127.0.0.1:0".parse::<SocketAddr>().unwrap()),
    (str_port_tuple, ("127.0.0.1", 0)),
    (ip_port_tuple, ("127.0.0.1".parse::<IpAddr>().unwrap(), 0)),
}

use pin_project_lite::pin_project;
use std::pin::Pin;
use std::sync::{
    atomic::{AtomicUsize, Ordering::SeqCst},
    Arc,
};
use std::task::{Context, Poll};
use tokio::stream::{Stream, StreamExt};

// Stream wrapper that increments `npolls` every time `poll_next` is called and
// then delegates to the inner stream `s`, so a test can observe exactly how
// many times the inner stream was polled. `npolls` is shared via `Arc` so the
// count can be read from another task.
pin_project! {
    struct TrackPolls<S> {
        npolls: Arc<AtomicUsize>,
        #[pin]
        s: S,
    }
}

impl<S> Stream for TrackPolls<S>
where
    S: Stream,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.npolls.fetch_add(1, SeqCst);
        this.s.poll_next(cx)
    }
}

/// Checks that the listener's `incoming()` stream is polled no more often than
/// needed: exactly once before any connection arrives, then exactly twice more
/// (once yielding the accepted connection, once returning `Pending`) after a
/// single client connects.
#[tokio::test]
async fn no_extra_poll() {
    let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
    let addr = listener.local_addr().unwrap();

    // `tx` hands the shared poll counter to the test body; `accepted_tx`
    // signals each connection yielded by the stream.
    let (tx, rx) = oneshot::channel();
    let (accepted_tx, mut accepted_rx) = mpsc::unbounded_channel();

    tokio::spawn(async move {
        let mut incoming = TrackPolls {
            npolls: Arc::new(AtomicUsize::new(0)),
            s: listener.incoming(),
        };
        assert_ok!(tx.send(Arc::clone(&incoming.npolls)));
        while incoming.next().await.is_some() {
            accepted_tx.send(()).unwrap();
        }
    });

    let npolls = assert_ok!(rx.await);
    // Yield so the spawned task can reach its first `next()` before the count
    // is checked.
    tokio::task::yield_now().await;

    // should have been polled exactly once: the initial poll
    assert_eq!(npolls.load(SeqCst), 1);

    let _ = assert_ok!(TcpStream::connect(&addr).await);
    accepted_rx.next().await.unwrap();

    // should have been polled twice more: once to yield Some(), then once to yield Pending
    assert_eq!(npolls.load(SeqCst), 1 + 2);
}