diff options
Diffstat (limited to 'tokio/src/net/udp/socket.rs')
-rw-r--r-- | tokio/src/net/udp/socket.rs | 419 |
1 files changed, 419 insertions, 0 deletions
diff --git a/tokio/src/net/udp/socket.rs b/tokio/src/net/udp/socket.rs new file mode 100644 index 00000000..9cd861c8 --- /dev/null +++ b/tokio/src/net/udp/socket.rs @@ -0,0 +1,419 @@ +use crate::net::udp::split::{split, UdpSocketRecvHalf, UdpSocketSendHalf}; +use crate::net::util::PollEvented; +use crate::net::ToSocketAddrs; + +use futures_core::ready; +use futures_util::future::poll_fn; +use std::convert::TryFrom; +use std::fmt; +use std::io; +use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::task::{Context, Poll}; + +/// An I/O object representing a UDP socket. +pub struct UdpSocket { + io: PollEvented<mio::net::UdpSocket>, +} + +impl UdpSocket { + /// This function will create a new UDP socket and attempt to bind it to + /// the `addr` provided. + pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> { + let addrs = addr.to_socket_addrs().await?; + let mut last_err = None; + + for addr in addrs { + match UdpSocket::bind_addr(addr) { + Ok(socket) => return Ok(socket), + Err(e) => last_err = Some(e), + } + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any addresses", + ) + })) + } + + fn bind_addr(addr: SocketAddr) -> io::Result<UdpSocket> { + let sys = mio::net::UdpSocket::bind(&addr)?; + UdpSocket::new(sys) + } + + fn new(socket: mio::net::UdpSocket) -> io::Result<UdpSocket> { + let io = PollEvented::new(socket)?; + Ok(UdpSocket { io }) + } + + /// Creates a new `UdpSocket` from the previously bound socket provided. + /// + /// The socket given will be registered with the event loop that `handle` + /// is associated with. This function requires that `socket` has previously + /// been bound to an address to work correctly. + /// + /// This can be used in conjunction with net2's `UdpBuilder` interface to + /// configure a socket before it's handed off, such as setting options like + /// `reuse_address` or binding to multiple addresses. + pub fn from_std(socket: net::UdpSocket) -> io::Result<UdpSocket> { + let io = mio::net::UdpSocket::from_socket(socket)?; + let io = PollEvented::new(io)?; + Ok(UdpSocket { io }) + } + + /// Split the `UdpSocket` into a receive half and a send half. The two parts + /// can be used to receive and send datagrams concurrently, even from two + /// different tasks. + /// + /// See the module level documenation of [`split`](super::split) for more + /// details. + pub fn split(self) -> (UdpSocketRecvHalf, UdpSocketSendHalf) { + split(self) + } + + /// Returns the local address that this socket is bound to. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.io.get_ref().local_addr() + } + + /// Connects the UDP socket setting the default destination for send() and + /// limiting packets that are read via recv from the address specified in + /// `addr`. + pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<()> { + let addrs = addr.to_socket_addrs().await?; + let mut last_err = None; + + for addr in addrs { + match self.io.get_ref().connect(addr) { + Ok(_) => return Ok(()), + Err(e) => last_err = Some(e), + } + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any addresses", + ) + })) + } + + /// Returns a future that sends data on the socket to the remote address to which it is connected. + /// On success, the future will resolve to the number of bytes written. + /// + /// The [`connect`] method will connect this socket to a remote address. The future + /// will resolve to an error if the socket is not connected. + /// + /// [`connect`]: #method.connect + pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> { + poll_fn(|cx| self.poll_send(cx, buf)).await + } + + // Poll IO functions that takes `&self` are provided for the split API. + // + // They are not public because (taken from the doc of `PollEvented`): + // + // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the + // caller must ensure that there are at most two tasks that use a + // `PollEvented` instance concurrently. One for reading and one for writing. + // While violating this requirement is "safe" from a Rust memory model point + // of view, it will result in unexpected behavior in the form of lost + // notifications and tasks hanging. + #[doc(hidden)] + pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> { + ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().send(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_write_ready(cx)?; + Poll::Pending + } + x => Poll::Ready(x), + } + } + + /// Returns a future that receives a single datagram message on the socket from + /// the remote address to which it is connected. On success, the future will resolve + /// to the number of bytes read. + /// + /// The function must be called with valid byte array `buf` of sufficient size to + /// hold the message bytes. If a message is too long to fit in the supplied buffer, + /// excess bytes may be discarded. + /// + /// The [`connect`] method will connect this socket to a remote address. The future + /// will fail if the socket is not connected. + /// + /// [`connect`]: #method.connect + pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> { + poll_fn(|cx| self.poll_recv(cx, buf)).await + } + + #[doc(hidden)] + pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { + ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; + + match self.io.get_ref().recv(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_read_ready(cx, mio::Ready::readable())?; + Poll::Pending + } + x => Poll::Ready(x), + } + } + + /// Returns a future that sends data on the socket to the given address. + /// On success, the future will resolve to the number of bytes written. + /// + /// The future will resolve to an error if the IP version of the socket does + /// not match that of `target`. + pub async fn send_to<A: ToSocketAddrs>(&mut self, buf: &[u8], target: A) -> io::Result<usize> { + let mut addrs = target.to_socket_addrs().await?; + + match addrs.next() { + Some(target) => poll_fn(|cx| self.poll_send_to(cx, buf, &target)).await, + None => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "no addresses to send data to", + )), + } + } + + // TODO: Public or not? + #[doc(hidden)] + pub fn poll_send_to( + &self, + cx: &mut Context<'_>, + buf: &[u8], + target: &SocketAddr, + ) -> Poll<io::Result<usize>> { + ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().send_to(buf, target) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_write_ready(cx)?; + Poll::Pending + } + x => Poll::Ready(x), + } + } + + /// Returns a future that receives a single datagram on the socket. On success, + /// the future resolves to the number of bytes read and the origin. + /// + /// The function must be called with valid byte array `buf` of sufficient size + /// to hold the message bytes. If a message is too long to fit in the supplied + /// buffer, excess bytes may be discarded. + pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + poll_fn(|cx| self.poll_recv_from(cx, buf)).await + } + + #[doc(hidden)] + pub fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll<Result<(usize, SocketAddr), io::Error>> { + ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; + + match self.io.get_ref().recv_from(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_read_ready(cx, mio::Ready::readable())?; + Poll::Pending + } + x => Poll::Ready(x), + } + } + + /// Gets the value of the `SO_BROADCAST` option for this socket. + /// + /// For more information about this option, see [`set_broadcast`]. + /// + /// [`set_broadcast`]: #method.set_broadcast + pub fn broadcast(&self) -> io::Result<bool> { + self.io.get_ref().broadcast() + } + + /// Sets the value of the `SO_BROADCAST` option for this socket. + /// + /// When enabled, this socket is allowed to send packets to a broadcast + /// address. + pub fn set_broadcast(&self, on: bool) -> io::Result<()> { + self.io.get_ref().set_broadcast(on) + } + + /// Gets the value of the `IP_MULTICAST_LOOP` option for this socket. + /// + /// For more information about this option, see [`set_multicast_loop_v4`]. + /// + /// [`set_multicast_loop_v4`]: #method.set_multicast_loop_v4 + pub fn multicast_loop_v4(&self) -> io::Result<bool> { + self.io.get_ref().multicast_loop_v4() + } + + /// Sets the value of the `IP_MULTICAST_LOOP` option for this socket. + /// + /// If enabled, multicast packets will be looped back to the local socket. + /// + /// # Note + /// + /// This may not have any affect on IPv6 sockets. + pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> { + self.io.get_ref().set_multicast_loop_v4(on) + } + + /// Gets the value of the `IP_MULTICAST_TTL` option for this socket. + /// + /// For more information about this option, see [`set_multicast_ttl_v4`]. + /// + /// [`set_multicast_ttl_v4`]: #method.set_multicast_ttl_v4 + pub fn multicast_ttl_v4(&self) -> io::Result<u32> { + self.io.get_ref().multicast_ttl_v4() + } + + /// Sets the value of the `IP_MULTICAST_TTL` option for this socket. + /// + /// Indicates the time-to-live value of outgoing multicast packets for + /// this socket. The default value is 1 which means that multicast packets + /// don't leave the local network unless explicitly requested. + /// + /// # Note + /// + /// This may not have any affect on IPv6 sockets. + pub fn set_multicast_ttl_v4(&self, ttl: u32) -> io::Result<()> { + self.io.get_ref().set_multicast_ttl_v4(ttl) + } + + /// Gets the value of the `IPV6_MULTICAST_LOOP` option for this socket. + /// + /// For more information about this option, see [`set_multicast_loop_v6`]. + /// + /// [`set_multicast_loop_v6`]: #method.set_multicast_loop_v6 + pub fn multicast_loop_v6(&self) -> io::Result<bool> { + self.io.get_ref().multicast_loop_v6() + } + + /// Sets the value of the `IPV6_MULTICAST_LOOP` option for this socket. + /// + /// Controls whether this socket sees the multicast packets it sends itself. + /// + /// # Note + /// + /// This may not have any affect on IPv4 sockets. + pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> { + self.io.get_ref().set_multicast_loop_v6(on) + } + + /// Gets the value of the `IP_TTL` option for this socket. + /// + /// For more information about this option, see [`set_ttl`]. + /// + /// [`set_ttl`]: #method.set_ttl + pub fn ttl(&self) -> io::Result<u32> { + self.io.get_ref().ttl() + } + + /// Sets the value for the `IP_TTL` option on this socket. + /// + /// This value sets the time-to-live field that is used in every packet sent + /// from this socket. + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.io.get_ref().set_ttl(ttl) + } + + /// Executes an operation of the `IP_ADD_MEMBERSHIP` type. + /// + /// This function specifies a new multicast group for this socket to join. + /// The address must be a valid multicast address, and `interface` is the + /// address of the local interface with which the system should join the + /// multicast group. If it's equal to `INADDR_ANY` then an appropriate + /// interface is chosen by the system. + pub fn join_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> { + self.io.get_ref().join_multicast_v4(&multiaddr, &interface) + } + + /// Executes an operation of the `IPV6_ADD_MEMBERSHIP` type. + /// + /// This function specifies a new multicast group for this socket to join. + /// The address must be a valid multicast address, and `interface` is the + /// index of the interface to join/leave (or 0 to indicate any interface). + pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { + self.io.get_ref().join_multicast_v6(multiaddr, interface) + } + + /// Executes an operation of the `IP_DROP_MEMBERSHIP` type. + /// + /// For more information about this option, see [`join_multicast_v4`]. + /// + /// [`join_multicast_v4`]: #method.join_multicast_v4 + pub fn leave_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> { + self.io.get_ref().leave_multicast_v4(&multiaddr, &interface) + } + + /// Executes an operation of the `IPV6_DROP_MEMBERSHIP` type. + /// + /// For more information about this option, see [`join_multicast_v6`]. + /// + /// [`join_multicast_v6`]: #method.join_multicast_v6 + pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { + self.io.get_ref().leave_multicast_v6(multiaddr, interface) + } +} + +impl TryFrom<UdpSocket> for mio::net::UdpSocket { + type Error = io::Error; + + /// Consumes value, returning the mio I/O object. + /// + /// See [`PollEvented::into_inner`] for more details about + /// resource deregistration that happens during the call. + /// + /// [`PollEvented::into_inner`]: crate::util::PollEvented::into_inner + fn try_from(value: UdpSocket) -> Result<Self, Self::Error> { + value.io.into_inner() + } +} + +impl TryFrom<net::UdpSocket> for UdpSocket { + type Error = io::Error; + + /// Consumes stream, returning the tokio I/O object. + /// + /// This is equivalent to + /// [`UdpSocket::from_std(stream)`](UdpSocket::from_std). + fn try_from(stream: net::UdpSocket) -> Result<Self, Self::Error> { + Self::from_std(stream) + } +} + +impl fmt::Debug for UdpSocket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.io.get_ref().fmt(f) + } +} + +#[cfg(all(unix))] +mod sys { + use super::UdpSocket; + use std::os::unix::prelude::*; + + impl AsRawFd for UdpSocket { + fn as_raw_fd(&self) -> RawFd { + self.io.get_ref().as_raw_fd() + } + } +} + +#[cfg(windows)] +mod sys { + // TODO: let's land these upstream with mio and then we can add them here. + // + // use std::os::windows::prelude::*; + // use super::UdpSocket; + // + // impl AsRawHandle for UdpSocket { + // fn as_raw_handle(&self) -> RawHandle { + // self.io.get_ref().as_raw_handle() + // } + // } +} |