From 8f3a26597270871a2adbf4f8da80a82961e9e296 Mon Sep 17 00:00:00 2001 From: Alice Ryhl Date: Sun, 19 Apr 2020 19:00:44 +0200 Subject: net: introduce owned split on TcpStream (#2270) --- tokio/src/net/tcp/mod.rs | 3 + tokio/src/net/tcp/split.rs | 4 +- tokio/src/net/tcp/split_owned.rs | 251 +++++++++++++++++++++++++++++++++++++++ tokio/src/net/tcp/stream.rs | 17 +++ tokio/tests/tcp_into_split.rs | 139 ++++++++++++++++++++++ 5 files changed, 412 insertions(+), 2 deletions(-) create mode 100644 tokio/src/net/tcp/split_owned.rs create mode 100644 tokio/tests/tcp_into_split.rs diff --git a/tokio/src/net/tcp/mod.rs b/tokio/src/net/tcp/mod.rs index d5354b38..7ad36eb0 100644 --- a/tokio/src/net/tcp/mod.rs +++ b/tokio/src/net/tcp/mod.rs @@ -9,5 +9,8 @@ pub use incoming::Incoming; mod split; pub use split::{ReadHalf, WriteHalf}; +mod split_owned; +pub use split_owned::{OwnedReadHalf, OwnedWriteHalf, ReuniteError}; + pub(crate) mod stream; pub(crate) use stream::TcpStream; diff --git a/tokio/src/net/tcp/split.rs b/tokio/src/net/tcp/split.rs index cce50f6a..39dca996 100644 --- a/tokio/src/net/tcp/split.rs +++ b/tokio/src/net/tcp/split.rs @@ -25,8 +25,8 @@ pub struct ReadHalf<'a>(&'a TcpStream); /// Write half of a `TcpStream`. /// -/// Note that in the `AsyncWrite` implemenation of `TcpStreamWriteHalf`, -/// `poll_shutdown` actually shuts down the TCP stream in the write direction. +/// Note that in the `AsyncWrite` implemenation of this type, `poll_shutdown` will +/// shut down the TCP stream in the write direction. #[derive(Debug)] pub struct WriteHalf<'a>(&'a TcpStream); diff --git a/tokio/src/net/tcp/split_owned.rs b/tokio/src/net/tcp/split_owned.rs new file mode 100644 index 00000000..908a39e2 --- /dev/null +++ b/tokio/src/net/tcp/split_owned.rs @@ -0,0 +1,251 @@ +//! `TcpStream` owned split support. +//! +//! A `TcpStream` can be split into an `OwnedReadHalf` and a `OwnedWriteHalf` +//! with the `TcpStream::into_split` method. `OwnedReadHalf` implements +//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`. +//! +//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized +//! split has no associated overhead and enforces all invariants at the type +//! level. + +use crate::future::poll_fn; +use crate::io::{AsyncRead, AsyncWrite}; +use crate::net::TcpStream; + +use bytes::Buf; +use std::error::Error; +use std::mem::MaybeUninit; +use std::net::Shutdown; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::{fmt, io}; + +/// Owned read half of a [`TcpStream`], created by [`into_split`]. +/// +/// [`TcpStream`]: TcpStream +/// [`into_split`]: TcpStream::into_split() +#[derive(Debug)] +pub struct OwnedReadHalf { + inner: Arc, +} + +/// Owned write half of a [`TcpStream`], created by [`into_split`]. +/// +/// Note that in the `AsyncWrite` implemenation of this type, `poll_shutdown` will +/// shut down the TCP stream in the write direction. +/// +/// Dropping the write half will close the TCP stream in both directions. +/// +/// [`TcpStream`]: TcpStream +/// [`into_split`]: TcpStream::into_split() +#[derive(Debug)] +pub struct OwnedWriteHalf { + inner: Arc, + shutdown_on_drop: bool, +} + +pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) { + let arc = Arc::new(stream); + let read = OwnedReadHalf { + inner: Arc::clone(&arc), + }; + let write = OwnedWriteHalf { + inner: arc, + shutdown_on_drop: true, + }; + (read, write) +} + +pub(crate) fn reunite( + read: OwnedReadHalf, + write: OwnedWriteHalf, +) -> Result { + if Arc::ptr_eq(&read.inner, &write.inner) { + write.forget(); + // This unwrap cannot fail as the api does not allow creating more than two Arcs, + // and we just dropped the other half. + Ok(Arc::try_unwrap(read.inner).expect("Too many handles to Arc")) + } else { + Err(ReuniteError(read, write)) + } +} + +/// Error indicating two halves were not from the same socket, and thus could +/// not be reunited. +#[derive(Debug)] +pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf); + +impl fmt::Display for ReuniteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "tried to reunite halves that are not from the same socket" + ) + } +} + +impl Error for ReuniteError {} + +impl OwnedReadHalf { + /// Attempts to put the two halves of a `TcpStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: TcpStream::into_split() + pub fn reunite(self, other: OwnedWriteHalf) -> Result { + reunite(self, other) + } + + /// Attempt to receive data on the socket, without removing that data from + /// the queue, registering the current task for wakeup if data is not yet + /// available. + /// + /// See the [`TcpStream::poll_peek`] level documenation for more details. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::io; + /// use tokio::net::TcpStream; + /// + /// use futures::future::poll_fn; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let stream = TcpStream::connect("127.0.0.1:8000").await?; + /// let (mut read_half, _) = stream.into_split(); + /// let mut buf = [0; 10]; + /// + /// poll_fn(|cx| { + /// read_half.poll_peek(cx, &mut buf) + /// }).await?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// [`TcpStream::poll_peek`]: TcpStream::poll_peek + pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + self.inner.poll_peek2(cx, buf) + } + + /// Receives data on the socket from the remote address to which it is + /// connected, without removing that data from the queue. On success, + /// returns the number of bytes peeked. + /// + /// See the [`TcpStream::peek`] level documenation for more details. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use tokio::prelude::*; + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// let (mut read_half, _) = stream.into_split(); + /// + /// let mut b1 = [0; 10]; + /// let mut b2 = [0; 10]; + /// + /// // Peek at the data + /// let n = read_half.peek(&mut b1).await?; + /// + /// // Read the data + /// assert_eq!(n, read_half.read(&mut b2[..n]).await?); + /// assert_eq!(&b1[..n], &b2[..n]); + /// + /// Ok(()) + /// } + /// ``` + /// + /// [`TcpStream::peek`]: TcpStream::peek + pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result { + poll_fn(|cx| self.poll_peek(cx, buf)).await + } +} + +impl AsyncRead for OwnedReadHalf { + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { + false + } + + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.inner.poll_read_priv(cx, buf) + } +} + +impl OwnedWriteHalf { + /// Attempts to put the two halves of a `TcpStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: TcpStream::into_split() + pub fn reunite(self, other: OwnedReadHalf) -> Result { + reunite(other, self) + } + /// Destroy the write half, but don't close the stream until the read half + /// is dropped. If the read half has already been dropped, this closes the + /// stream. + pub fn forget(mut self) { + self.shutdown_on_drop = false; + drop(self); + } +} + +impl Drop for OwnedWriteHalf { + fn drop(&mut self) { + if self.shutdown_on_drop { + let _ = self.inner.shutdown(Shutdown::Both); + } + } +} + +impl AsyncWrite for OwnedWriteHalf { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.inner.poll_write_priv(cx, buf) + } + + fn poll_write_buf( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> { + self.inner.poll_write_buf_priv(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + // tcp flush is a no-op + Poll::Ready(Ok(())) + } + + // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + self.inner.shutdown(Shutdown::Write).into() + } +} + +impl AsRef for OwnedReadHalf { + fn as_ref(&self) -> &TcpStream { + &*self.inner + } +} + +impl AsRef for OwnedWriteHalf { + fn as_ref(&self) -> &TcpStream { + &*self.inner + } +} diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index 732c0ca3..03489152 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -1,6 +1,7 @@ use crate::future::poll_fn; use crate::io::{AsyncRead, AsyncWrite, PollEvented}; use crate::net::tcp::split::{split, ReadHalf, WriteHalf}; +use crate::net::tcp::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf}; use crate::net::ToSocketAddrs; use bytes::Buf; @@ -614,10 +615,26 @@ impl TcpStream { /// Splits a `TcpStream` into a read half and a write half, which can be used /// to read and write the stream concurrently. + /// + /// This method is more efficient than [`into_split`], but the halves cannot be + /// moved into independently spawned tasks. + /// + /// [`into_split`]: TcpStream::into_split() pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) { split(self) } + /// Splits a `TcpStream` into a read half and a write half, which can be used + /// to read and write the stream concurrently. + /// + /// Unlike [`split`], the owned halves can be moved to separate tasks, however + /// this comes at the cost of a heap allocation. + /// + /// [`split`]: TcpStream::split() + pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { + split_owned(self) + } + // == Poll IO functions that takes `&self` == // // They are not public because (taken from the doc of `PollEvented`): diff --git a/tokio/tests/tcp_into_split.rs b/tokio/tests/tcp_into_split.rs new file mode 100644 index 00000000..6561fa30 --- /dev/null +++ b/tokio/tests/tcp_into_split.rs @@ -0,0 +1,139 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use std::io::{Error, ErrorKind, Result}; +use std::io::{Read, Write}; +use std::sync::{Arc, Barrier}; +use std::{net, thread}; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::try_join; + +#[tokio::test] +async fn split() -> Result<()> { + const MSG: &[u8] = b"split"; + + let mut listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let (stream1, (mut stream2, _)) = try_join! { + TcpStream::connect(&addr), + listener.accept(), + }?; + let (mut read_half, mut write_half) = stream1.into_split(); + + let ((), (), ()) = try_join! { + async { + let len = stream2.write(MSG).await?; + assert_eq!(len, MSG.len()); + + let mut read_buf = vec![0u8; 32]; + let read_len = stream2.read(&mut read_buf).await?; + assert_eq!(&read_buf[..read_len], MSG); + Result::Ok(()) + }, + async { + let len = write_half.write(MSG).await?; + assert_eq!(len, MSG.len()); + Ok(()) + }, + async { + let mut read_buf = vec![0u8; 32]; + let peek_len1 = read_half.peek(&mut read_buf[..]).await?; + let peek_len2 = read_half.peek(&mut read_buf[..]).await?; + assert_eq!(peek_len1, peek_len2); + + let read_len = read_half.read(&mut read_buf[..]).await?; + assert_eq!(peek_len1, read_len); + assert_eq!(&read_buf[..read_len], MSG); + Ok(()) + }, + }?; + + Ok(()) +} + +#[tokio::test] +async fn reunite() -> Result<()> { + let listener = net::TcpListener::bind("127.0.0.1:0")?; + let addr = listener.local_addr()?; + + let handle = thread::spawn(move || { + drop(listener.accept().unwrap()); + drop(listener.accept().unwrap()); + }); + + let stream1 = TcpStream::connect(&addr).await?; + let (read1, write1) = stream1.into_split(); + + let stream2 = TcpStream::connect(&addr).await?; + let (_, write2) = stream2.into_split(); + + let read1 = match read1.reunite(write2) { + Ok(_) => panic!("Reunite should not succeed"), + Err(err) => err.0, + }; + + read1.reunite(write1).expect("Reunite should succeed"); + + handle.join().unwrap(); + Ok(()) +} + +/// Test that dropping the write half actually closes the stream. +#[tokio::test] +async fn drop_write() -> Result<()> { + const MSG: &[u8] = b"split"; + + let listener = net::TcpListener::bind("127.0.0.1:0")?; + let addr = listener.local_addr()?; + + let barrier = Arc::new(Barrier::new(2)); + let barrier2 = barrier.clone(); + + let handle = thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + stream.write(MSG).unwrap(); + + let mut read_buf = [0u8; 32]; + let res = match stream.read(&mut read_buf) { + Ok(0) => Ok(()), + Ok(len) => Err(Error::new( + ErrorKind::Other, + format!("Unexpected read: {} bytes.", len), + )), + Err(err) => Err(err), + }; + + barrier2.wait(); + + drop(stream); + + res + }); + + let stream = TcpStream::connect(&addr).await?; + let (mut read_half, write_half) = stream.into_split(); + + let mut read_buf = [0u8; 32]; + let read_len = read_half.read(&mut read_buf[..]).await?; + assert_eq!(&read_buf[..read_len], MSG); + + // drop it while the read is in progress + std::thread::spawn(move || { + thread::sleep(std::time::Duration::from_millis(50)); + drop(write_half); + }); + + match read_half.read(&mut read_buf[..]).await { + Ok(0) => {} + Ok(len) => panic!("Unexpected read: {} bytes.", len), + Err(err) => panic!("Unexpected error: {}.", err), + } + + barrier.wait(); + + handle.join().unwrap().unwrap(); + Ok(()) +} -- cgit v1.2.3