diff options
author | cssivision <cssivision@gmail.com> | 2020-07-24 13:03:47 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-07-23 22:03:47 -0700 |
commit | ff7125ec7b4fe009b80cc033edbf338883eb0435 (patch) | |
tree | ead45fdae5a766df7283b7ddf0edd56ba357c193 | |
parent | 7a60a0b362337c9986a8d8bf59f458691685bf3e (diff) |
net: introduce split on UnixDatagram (#2557)
-rw-r--r-- | tokio/src/net/unix/datagram.rs | 138 | ||||
-rw-r--r-- | tokio/src/net/unix/mod.rs | 1 | ||||
-rw-r--r-- | tokio/tests/uds_datagram.rs | 50 |
3 files changed, 189 insertions, 0 deletions
diff --git a/tokio/src/net/unix/datagram.rs b/tokio/src/net/unix/datagram.rs index ff0f4241..de450e24 100644 --- a/tokio/src/net/unix/datagram.rs +++ b/tokio/src/net/unix/datagram.rs @@ -2,12 +2,14 @@ use crate::future::poll_fn; use crate::io::PollEvented; use std::convert::TryFrom; +use std::error::Error; use std::fmt; use std::io; use std::net::Shutdown; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::{self, SocketAddr}; use std::path::Path; +use std::sync::Arc; use std::task::{Context, Poll}; cfg_uds! { @@ -201,6 +203,12 @@ impl UnixDatagram { pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { self.io.get_ref().shutdown(how) } + + /// Split a `UnixDatagram` into a receive half and a send half, which can be used + /// to receice and send the datagram concurrently. + pub fn into_split(self) -> (OwnedRecvHalf, OwnedSendHalf) { + split_owned(self) + } } impl TryFrom<UnixDatagram> for mio_uds::UnixDatagram { @@ -240,3 +248,133 @@ impl AsRawFd for UnixDatagram { self.io.get_ref().as_raw_fd() } } + +fn split_owned(socket: UnixDatagram) -> (OwnedRecvHalf, OwnedSendHalf) { + let shared = Arc::new(socket); + let send = shared.clone(); + let recv = shared; + ( + OwnedRecvHalf { inner: recv }, + OwnedSendHalf { + inner: send, + shutdown_on_drop: true, + }, + ) +} + +/// The send half after [`split`](UnixDatagram::into_split). +/// +/// Use [`send_to`](#method.send_to) or [`send`](#method.send) to send +/// datagrams. +#[derive(Debug)] +pub struct OwnedSendHalf { + inner: Arc<UnixDatagram>, + shutdown_on_drop: bool, +} + +/// The recv half after [`split`](UnixDatagram::into_split). +/// +/// Use [`recv_from`](#method.recv_from) or [`recv`](#method.recv) to receive +/// datagrams. +#[derive(Debug)] +pub struct OwnedRecvHalf { + inner: Arc<UnixDatagram>, +} + +/// Error indicating two halves were not from the same socket, and thus could +/// not be `reunite`d. +#[derive(Debug)] +pub struct ReuniteError(pub OwnedSendHalf, pub OwnedRecvHalf); + +impl fmt::Display for ReuniteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "tried to reunite halves that are not from the same socket" + ) + } +} + +impl Error for ReuniteError {} + +fn reunite(s: OwnedSendHalf, r: OwnedRecvHalf) -> Result<UnixDatagram, ReuniteError> { + if Arc::ptr_eq(&s.inner, &r.inner) { + s.forget(); + // Only two instances of the `Arc` are ever created, one for the + // receiver and one for the sender, and those `Arc`s are never exposed + // externally. And so when we drop one here, the other one must be the + // only remaining one. + Ok(Arc::try_unwrap(r.inner).expect("unixdatagram: try_unwrap failed in reunite")) + } else { + Err(ReuniteError(s, r)) + } +} + +impl OwnedRecvHalf { + /// Attempts to put the two "halves" of a `UnixDatagram` back together and + /// recover the original socket. Succeeds only if the two "halves" + /// originated from the same call to `UnixDatagram::split`. + pub fn reunite(self, other: OwnedSendHalf) -> Result<UnixDatagram, ReuniteError> { + reunite(other, self) + } + + /// Receives data from the socket. + pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + poll_fn(|cx| self.inner.poll_recv_from_priv(cx, buf)).await + } + + /// Receives data from the socket. + pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> { + poll_fn(|cx| self.inner.poll_recv_priv(cx, buf)).await + } +} + +impl OwnedSendHalf { + /// Attempts to put the two "halves" of a `UnixDatagram` back together and + /// recover the original socket. Succeeds only if the two "halves" + /// originated from the same call to `UnixDatagram::split`. + pub fn reunite(self, other: OwnedRecvHalf) -> Result<UnixDatagram, ReuniteError> { + reunite(self, other) + } + + /// Sends data on the socket to the specified address. + pub async fn send_to<P>(&mut self, buf: &[u8], target: P) -> io::Result<usize> + where + P: AsRef<Path> + Unpin, + { + poll_fn(|cx| self.inner.poll_send_to_priv(cx, buf, target.as_ref())).await + } + + /// Sends data on the socket to the socket's peer. + pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> { + poll_fn(|cx| self.inner.poll_send_priv(cx, buf)).await + } + + /// Destroy the send half, but don't close the stream until the recvice half + /// is dropped. If the read half has already been dropped, this closes the + /// stream. + pub fn forget(mut self) { + self.shutdown_on_drop = false; + drop(self); + } +} + +impl Drop for OwnedSendHalf { + fn drop(&mut self) { + if self.shutdown_on_drop { + let _ = self.inner.shutdown(Shutdown::Both); + } + } +} + +impl AsRef<UnixDatagram> for OwnedSendHalf { + fn as_ref(&self) -> &UnixDatagram { + &self.inner + } +} + +impl AsRef<UnixDatagram> for OwnedRecvHalf { + fn as_ref(&self) -> &UnixDatagram { + &self.inner + } +} diff --git a/tokio/src/net/unix/mod.rs b/tokio/src/net/unix/mod.rs index ddba60d1..f063b74b 100644 --- a/tokio/src/net/unix/mod.rs +++ b/tokio/src/net/unix/mod.rs @@ -1,6 +1,7 @@ //! Unix domain socket utility types pub(crate) mod datagram; +pub use datagram::{OwnedRecvHalf, OwnedSendHalf, ReuniteError}; mod incoming; pub use incoming::Incoming; diff --git a/tokio/tests/uds_datagram.rs b/tokio/tests/uds_datagram.rs index dd995237..cfb1c649 100644 --- a/tokio/tests/uds_datagram.rs +++ b/tokio/tests/uds_datagram.rs @@ -3,6 +3,7 @@ #![cfg(unix)] use tokio::net::UnixDatagram; +use tokio::try_join; use std::io; @@ -41,3 +42,52 @@ async fn echo() -> io::Result<()> { Ok(()) } + +#[tokio::test] +async fn split() -> std::io::Result<()> { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("split.sock"); + let socket = UnixDatagram::bind(path.clone())?; + let (mut r, mut s) = socket.into_split(); + + let msg = b"hello"; + let ((), ()) = try_join! { + async { + s.send_to(msg, path).await?; + io::Result::Ok(()) + }, + async { + let mut recv_buf = [0u8; 32]; + let (len, _) = r.recv_from(&mut recv_buf[..]).await?; + assert_eq!(&recv_buf[..len], msg); + Ok(()) + }, + }?; + + Ok(()) +} + +#[tokio::test] +async fn reunite() -> std::io::Result<()> { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("reunite.sock"); + let socket = UnixDatagram::bind(path)?; + let (s, r) = socket.into_split(); + assert!(s.reunite(r).is_ok()); + Ok(()) +} + +#[tokio::test] +async fn reunite_error() -> std::io::Result<()> { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("reunit.sock"); + let dir = tempfile::tempdir().unwrap(); + let path1 = dir.path().join("reunit.sock"); + let socket = UnixDatagram::bind(path)?; + let socket1 = UnixDatagram::bind(path1)?; + + let (s, _) = socket.into_split(); + let (_, r1) = socket1.into_split(); + assert!(s.reunite(r1).is_err()); + Ok(()) +} |