From 358e4f9f8029b6b289f8ef5a54bd7c6eae5bf969 Mon Sep 17 00:00:00 2001 From: Evan Cameron Date: Thu, 22 Oct 2020 12:58:00 -0400 Subject: tokio: add back poll_* for udp (#2981) --- tokio/src/net/udp/socket.rs | 259 +++++++++++++++++++++++++++++++++++++++++++- tokio/tests/udp.rs | 131 +++++++++++++++++++++- 2 files changed, 388 insertions(+), 2 deletions(-) diff --git a/tokio/src/net/udp/socket.rs b/tokio/src/net/udp/socket.rs index d13e92bc..beb2252c 100644 --- a/tokio/src/net/udp/socket.rs +++ b/tokio/src/net/udp/socket.rs @@ -1,10 +1,11 @@ -use crate::io::PollEvented; +use crate::io::{PollEvented, ReadBuf}; use crate::net::{to_socket_addrs, ToSocketAddrs}; use std::convert::TryFrom; use std::fmt; use std::io; use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::task::{Context, Poll}; cfg_net! { /// A UDP socket @@ -272,6 +273,42 @@ impl UdpSocket { .await } + /// Attempts to send data on the socket to the remote address to which it was previously + /// `connect`ed. + /// + /// The [`connect`] method will connect this socket to a remote address. The future + /// will resolve to an error if the socket is not connected. + /// + /// Note that on multiple calls to a `poll_*` method in the send direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not available to write + /// * `Poll::Ready(Ok(n))` `n` is the number of bytes sent + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`connect`]: method@Self::connect + pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + loop { + let ev = ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().send(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + x => return Poll::Ready(x), + } + } + } + /// Try to send data on the socket to the remote address to which it is /// connected. /// @@ -304,6 +341,55 @@ impl UdpSocket { .await } + /// Attempts to receive a single datagram message on the socket from the remote + /// address to which it is `connect`ed. + /// + /// The [`connect`] method will connect this socket to a remote address. This method + /// resolves to an error if the socket is not connected. + /// + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok(()))` reads data `ReadBuf` if the socket is ready + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`connect`]: method@Self::connect + pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + loop { + let ev = ready!(self.io.poll_read_ready(cx))?; + + // Safety: will not read the maybe uinitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]) + }; + match self.io.get_ref().recv(b) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), + Ok(n) => { + // Safety: We trust `recv` to have filled up `n` bytes + // in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + return Poll::Ready(Ok(())); + } + } + } + } + /// Returns a future that sends data on the socket to the given address. /// On success, the future will resolve to the number of bytes written. /// @@ -337,6 +423,41 @@ impl UdpSocket { } } + /// Attempts to send data on the socket to a given address. + /// + /// Note that on multiple calls to a `poll_*` method in the send direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to write + /// * `Poll::Ready(Ok(n))` `n` is the number of bytes sent. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_send_to( + &self, + cx: &mut Context<'_>, + buf: &[u8], + target: &SocketAddr, + ) -> Poll> { + loop { + let ev = ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().send_to(buf, *target) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + x => return Poll::Ready(x), + } + } + } + /// Try to send data on the socket to the given address, but if the send is blocked /// this will return right away. /// @@ -403,6 +524,142 @@ impl UdpSocket { .await } + /// Attempts to receive a single datagram on the socket. + /// + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok(addr))` reads data from `addr` into `ReadBuf` if the socket is ready + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + loop { + let ev = ready!(self.io.poll_read_ready(cx))?; + + // Safety: will not read the maybe uinitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]) + }; + match self.io.get_ref().recv_from(b) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), + Ok((n, addr)) => { + // Safety: We trust `recv` to have filled up `n` bytes + // in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + return Poll::Ready(Ok(addr)); + } + } + } + } + + /// Receives data from the socket, without removing it from the input queue. + /// On success, returns the number of bytes read and the address from whence + /// the data came. + /// + /// # Notes + /// + /// On Windows, if the data is larger than the buffer specified, the buffer + /// is filled with the first part of the data, and peek_from returns the error + /// WSAEMSGSIZE(10040). The excess data is lost. + /// Make sure to always use a sufficiently large buffer to hold the + /// maximum UDP packet size, which can be up to 65536 bytes in size. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::{io, net::SocketAddr}; + /// + /// # #[tokio::main] + /// # async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::().unwrap()).await?; + /// let mut buf = [0u8; 32]; + /// let (len, addr) = sock.peek_from(&mut buf).await?; + /// println!("peeked {:?} bytes from {:?}", len, addr); + /// # Ok(()) + /// # } + /// ``` + pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.io + .async_io(mio::Interest::READABLE, |sock| sock.peek_from(buf)) + .await + } + + /// Receives data from the socket, without removing it from the input queue. + /// On success, returns the number of bytes read. + /// + /// # Notes + /// + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup + /// + /// On Windows, if the data is larger than the buffer specified, the buffer + /// is filled with the first part of the data, and peek returns the error + /// WSAEMSGSIZE(10040). The excess data is lost. + /// Make sure to always use a sufficiently large buffer to hold the + /// maximum UDP packet size, which can be up to 65536 bytes in size. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok(addr))` reads data from `addr` into `ReadBuf` if the socket is ready + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_peek_from( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + loop { + let ev = ready!(self.io.poll_read_ready(cx))?; + + // Safety: will not read the maybe uinitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]) + }; + match self.io.get_ref().peek_from(b) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), + Ok((n, addr)) => { + // Safety: We trust `recv` to have filled up `n` bytes + // in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + return Poll::Ready(Ok(addr)); + } + } + } + } + /// Gets the value of the `SO_BROADCAST` option for this socket. /// /// For more information about this option, see [`set_broadcast`]. diff --git a/tokio/tests/udp.rs b/tokio/tests/udp.rs index 0bea83aa..8b79cb85 100644 --- a/tokio/tests/udp.rs +++ b/tokio/tests/udp.rs @@ -1,8 +1,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] +use futures::future::poll_fn; use std::sync::Arc; -use tokio::net::UdpSocket; +use tokio::{io::ReadBuf, net::UdpSocket}; const MSG: &[u8] = b"hello"; const MSG_LEN: usize = MSG.len(); @@ -24,6 +25,24 @@ async fn send_recv() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn send_recv_poll() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + sender.connect(receiver.local_addr()?).await?; + receiver.connect(sender.local_addr()?).await?; + + poll_fn(|cx| sender.poll_send(cx, MSG)).await?; + + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + let _len = poll_fn(|cx| receiver.poll_recv(cx, &mut read)).await?; + + assert_eq!(read.filled(), MSG); + Ok(()) +} + #[tokio::test] async fn send_to_recv_from() -> std::io::Result<()> { let sender = UdpSocket::bind("127.0.0.1:0").await?; @@ -40,6 +59,79 @@ async fn send_to_recv_from() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn send_to_recv_from_poll() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + let receiver_addr = receiver.local_addr()?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?; + + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + let addr = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?; + + assert_eq!(read.filled(), MSG); + assert_eq!(addr, sender.local_addr()?); + Ok(()) +} + +#[tokio::test] +async fn send_to_peek_from() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + let receiver_addr = receiver.local_addr()?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?; + + // peek + let mut recv_buf = [0u8; 32]; + let (n, addr) = receiver.peek_from(&mut recv_buf).await?; + assert_eq!(&recv_buf[..n], MSG); + assert_eq!(addr, sender.local_addr()?); + + // peek + let mut recv_buf = [0u8; 32]; + let (n, addr) = receiver.peek_from(&mut recv_buf).await?; + assert_eq!(&recv_buf[..n], MSG); + assert_eq!(addr, sender.local_addr()?); + + let mut recv_buf = [0u8; 32]; + let (n, addr) = receiver.recv_from(&mut recv_buf).await?; + assert_eq!(&recv_buf[..n], MSG); + assert_eq!(addr, sender.local_addr()?); + + Ok(()) +} + +#[tokio::test] +async fn send_to_peek_from_poll() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + let receiver_addr = receiver.local_addr()?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?; + + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + let addr = poll_fn(|cx| receiver.poll_peek_from(cx, &mut read)).await?; + + assert_eq!(read.filled(), MSG); + assert_eq!(addr, sender.local_addr()?); + + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + poll_fn(|cx| receiver.poll_peek_from(cx, &mut read)).await?; + + assert_eq!(read.filled(), MSG); + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + + poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?; + assert_eq!(read.filled(), MSG); + Ok(()) +} + #[tokio::test] async fn split() -> std::io::Result<()> { let socket = UdpSocket::bind("127.0.0.1:0").await?; @@ -88,6 +180,43 @@ async fn split_chan() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn split_chan_poll() -> std::io::Result<()> { + // setup UdpSocket that will echo all sent items + let socket = UdpSocket::bind("127.0.0.1:0").await?; + let addr = socket.local_addr().unwrap(); + let s = Arc::new(socket); + let r = s.clone(); + + let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec, std::net::SocketAddr)>(1_000); + tokio::spawn(async move { + while let Some((bytes, addr)) = rx.recv().await { + poll_fn(|cx| s.poll_send_to(cx, &bytes, &addr)) + .await + .unwrap(); + } + }); + + tokio::spawn(async move { + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + loop { + let addr = poll_fn(|cx| r.poll_recv_from(cx, &mut read)).await.unwrap(); + tx.send((read.filled().to_vec(), addr)).await.unwrap(); + } + }); + + // test that we can send a value and get back some response + let sender = UdpSocket::bind("127.0.0.1:0").await?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, &addr)).await?; + + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + let _ = poll_fn(|cx| sender.poll_recv_from(cx, &mut read)).await?; + assert_eq!(read.filled(), MSG); + Ok(()) +} + // # Note // // This test is purposely written such that each time `sender` sends data on -- cgit v1.2.3