summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAlice Ryhl <alice@ryhl.io>2020-04-19 19:00:44 +0200
committerGitHub <noreply@github.com>2020-04-19 19:00:44 +0200
commit8f3a26597270871a2adbf4f8da80a82961e9e296 (patch)
tree6f541e33084ec33425beb326fca4c87ce5089623
parent800574b4e0c7d276866fc8c7b0efb2c8474e0315 (diff)
net: introduce owned split on TcpStream (#2270)
-rw-r--r--tokio/src/net/tcp/mod.rs3
-rw-r--r--tokio/src/net/tcp/split.rs4
-rw-r--r--tokio/src/net/tcp/split_owned.rs251
-rw-r--r--tokio/src/net/tcp/stream.rs17
-rw-r--r--tokio/tests/tcp_into_split.rs139
5 files changed, 412 insertions, 2 deletions
diff --git a/tokio/src/net/tcp/mod.rs b/tokio/src/net/tcp/mod.rs
index d5354b38..7ad36eb0 100644
--- a/tokio/src/net/tcp/mod.rs
+++ b/tokio/src/net/tcp/mod.rs
@@ -9,5 +9,8 @@ pub use incoming::Incoming;
mod split;
pub use split::{ReadHalf, WriteHalf};
+mod split_owned;
+pub use split_owned::{OwnedReadHalf, OwnedWriteHalf, ReuniteError};
+
pub(crate) mod stream;
pub(crate) use stream::TcpStream;
diff --git a/tokio/src/net/tcp/split.rs b/tokio/src/net/tcp/split.rs
index cce50f6a..39dca996 100644
--- a/tokio/src/net/tcp/split.rs
+++ b/tokio/src/net/tcp/split.rs
@@ -25,8 +25,8 @@ pub struct ReadHalf<'a>(&'a TcpStream);
/// Write half of a `TcpStream`.
///
-/// Note that in the `AsyncWrite` implemenation of `TcpStreamWriteHalf`,
-/// `poll_shutdown` actually shuts down the TCP stream in the write direction.
+/// Note that in the `AsyncWrite` implemenation of this type, `poll_shutdown` will
+/// shut down the TCP stream in the write direction.
#[derive(Debug)]
pub struct WriteHalf<'a>(&'a TcpStream);
diff --git a/tokio/src/net/tcp/split_owned.rs b/tokio/src/net/tcp/split_owned.rs
new file mode 100644
index 00000000..908a39e2
--- /dev/null
+++ b/tokio/src/net/tcp/split_owned.rs
@@ -0,0 +1,251 @@
+//! `TcpStream` owned split support.
+//!
+//! A `TcpStream` can be split into an `OwnedReadHalf` and a `OwnedWriteHalf`
+//! with the `TcpStream::into_split` method. `OwnedReadHalf` implements
+//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`.
+//!
+//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
+//! split has no associated overhead and enforces all invariants at the type
+//! level.
+
+use crate::future::poll_fn;
+use crate::io::{AsyncRead, AsyncWrite};
+use crate::net::TcpStream;
+
+use bytes::Buf;
+use std::error::Error;
+use std::mem::MaybeUninit;
+use std::net::Shutdown;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::{fmt, io};
+
+/// Owned read half of a [`TcpStream`], created by [`into_split`].
+///
+/// [`TcpStream`]: TcpStream
+/// [`into_split`]: TcpStream::into_split()
+#[derive(Debug)]
+pub struct OwnedReadHalf {
+ inner: Arc<TcpStream>,
+}
+
+/// Owned write half of a [`TcpStream`], created by [`into_split`].
+///
+/// Note that in the `AsyncWrite` implemenation of this type, `poll_shutdown` will
+/// shut down the TCP stream in the write direction.
+///
+/// Dropping the write half will close the TCP stream in both directions.
+///
+/// [`TcpStream`]: TcpStream
+/// [`into_split`]: TcpStream::into_split()
+#[derive(Debug)]
+pub struct OwnedWriteHalf {
+ inner: Arc<TcpStream>,
+ shutdown_on_drop: bool,
+}
+
+pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) {
+ let arc = Arc::new(stream);
+ let read = OwnedReadHalf {
+ inner: Arc::clone(&arc),
+ };
+ let write = OwnedWriteHalf {
+ inner: arc,
+ shutdown_on_drop: true,
+ };
+ (read, write)
+}
+
+pub(crate) fn reunite(
+ read: OwnedReadHalf,
+ write: OwnedWriteHalf,
+) -> Result<TcpStream, ReuniteError> {
+ if Arc::ptr_eq(&read.inner, &write.inner) {
+ write.forget();
+ // This unwrap cannot fail as the api does not allow creating more than two Arcs,
+ // and we just dropped the other half.
+ Ok(Arc::try_unwrap(read.inner).expect("Too many handles to Arc"))
+ } else {
+ Err(ReuniteError(read, write))
+ }
+}
+
+/// Error indicating two halves were not from the same socket, and thus could
+/// not be reunited.
+#[derive(Debug)]
+pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
+
+impl fmt::Display for ReuniteError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(
+ f,
+ "tried to reunite halves that are not from the same socket"
+ )
+ }
+}
+
+impl Error for ReuniteError {}
+
+impl OwnedReadHalf {
+ /// Attempts to put the two halves of a `TcpStream` back together and
+ /// recover the original socket. Succeeds only if the two halves
+ /// originated from the same call to [`into_split`].
+ ///
+ /// [`into_split`]: TcpStream::into_split()
+ pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
+ reunite(self, other)
+ }
+
+ /// Attempt to receive data on the socket, without removing that data from
+ /// the queue, registering the current task for wakeup if data is not yet
+ /// available.
+ ///
+ /// See the [`TcpStream::poll_peek`] level documenation for more details.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::io;
+ /// use tokio::net::TcpStream;
+ ///
+ /// use futures::future::poll_fn;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let stream = TcpStream::connect("127.0.0.1:8000").await?;
+ /// let (mut read_half, _) = stream.into_split();
+ /// let mut buf = [0; 10];
+ ///
+ /// poll_fn(|cx| {
+ /// read_half.poll_peek(cx, &mut buf)
+ /// }).await?;
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ ///
+ /// [`TcpStream::poll_peek`]: TcpStream::poll_peek
+ pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
+ self.inner.poll_peek2(cx, buf)
+ }
+
+ /// Receives data on the socket from the remote address to which it is
+ /// connected, without removing that data from the queue. On success,
+ /// returns the number of bytes peeked.
+ ///
+ /// See the [`TcpStream::peek`] level documenation for more details.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpStream;
+ /// use tokio::prelude::*;
+ /// use std::error::Error;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> Result<(), Box<dyn Error>> {
+ /// // Connect to a peer
+ /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
+ /// let (mut read_half, _) = stream.into_split();
+ ///
+ /// let mut b1 = [0; 10];
+ /// let mut b2 = [0; 10];
+ ///
+ /// // Peek at the data
+ /// let n = read_half.peek(&mut b1).await?;
+ ///
+ /// // Read the data
+ /// assert_eq!(n, read_half.read(&mut b2[..n]).await?);
+ /// assert_eq!(&b1[..n], &b2[..n]);
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ ///
+ /// [`TcpStream::peek`]: TcpStream::peek
+ pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ poll_fn(|cx| self.poll_peek(cx, buf)).await
+ }
+}
+
+impl AsyncRead for OwnedReadHalf {
+ unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
+ false
+ }
+
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut [u8],
+ ) -> Poll<io::Result<usize>> {
+ self.inner.poll_read_priv(cx, buf)
+ }
+}
+
+impl OwnedWriteHalf {
+ /// Attempts to put the two halves of a `TcpStream` back together and
+ /// recover the original socket. Succeeds only if the two halves
+ /// originated from the same call to [`into_split`].
+ ///
+ /// [`into_split`]: TcpStream::into_split()
+ pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
+ reunite(other, self)
+ }
+ /// Destroy the write half, but don't close the stream until the read half
+ /// is dropped. If the read half has already been dropped, this closes the
+ /// stream.
+ pub fn forget(mut self) {
+ self.shutdown_on_drop = false;
+ drop(self);
+ }
+}
+
+impl Drop for OwnedWriteHalf {
+ fn drop(&mut self) {
+ if self.shutdown_on_drop {
+ let _ = self.inner.shutdown(Shutdown::Both);
+ }
+ }
+}
+
+impl AsyncWrite for OwnedWriteHalf {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ self.inner.poll_write_priv(cx, buf)
+ }
+
+ fn poll_write_buf<B: Buf>(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut B,
+ ) -> Poll<io::Result<usize>> {
+ self.inner.poll_write_buf_priv(cx, buf)
+ }
+
+ #[inline]
+ fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+ // tcp flush is a no-op
+ Poll::Ready(Ok(()))
+ }
+
+ // `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
+ fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+ self.inner.shutdown(Shutdown::Write).into()
+ }
+}
+
+impl AsRef<TcpStream> for OwnedReadHalf {
+ fn as_ref(&self) -> &TcpStream {
+ &*self.inner
+ }
+}
+
+impl AsRef<TcpStream> for OwnedWriteHalf {
+ fn as_ref(&self) -> &TcpStream {
+ &*self.inner
+ }
+}
diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs
index 732c0ca3..03489152 100644
--- a/tokio/src/net/tcp/stream.rs
+++ b/tokio/src/net/tcp/stream.rs
@@ -1,6 +1,7 @@
use crate::future::poll_fn;
use crate::io::{AsyncRead, AsyncWrite, PollEvented};
use crate::net::tcp::split::{split, ReadHalf, WriteHalf};
+use crate::net::tcp::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf};
use crate::net::ToSocketAddrs;
use bytes::Buf;
@@ -614,10 +615,26 @@ impl TcpStream {
/// Splits a `TcpStream` into a read half and a write half, which can be used
/// to read and write the stream concurrently.
+ ///
+ /// This method is more efficient than [`into_split`], but the halves cannot be
+ /// moved into independently spawned tasks.
+ ///
+ /// [`into_split`]: TcpStream::into_split()
pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
split(self)
}
+ /// Splits a `TcpStream` into a read half and a write half, which can be used
+ /// to read and write the stream concurrently.
+ ///
+ /// Unlike [`split`], the owned halves can be moved to separate tasks, however
+ /// this comes at the cost of a heap allocation.
+ ///
+ /// [`split`]: TcpStream::split()
+ pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
+ split_owned(self)
+ }
+
// == Poll IO functions that takes `&self` ==
//
// They are not public because (taken from the doc of `PollEvented`):
diff --git a/tokio/tests/tcp_into_split.rs b/tokio/tests/tcp_into_split.rs
new file mode 100644
index 00000000..6561fa30
--- /dev/null
+++ b/tokio/tests/tcp_into_split.rs
@@ -0,0 +1,139 @@
+#![warn(rust_2018_idioms)]
+#![cfg(feature = "full")]
+
+use std::io::{Error, ErrorKind, Result};
+use std::io::{Read, Write};
+use std::sync::{Arc, Barrier};
+use std::{net, thread};
+
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+use tokio::net::{TcpListener, TcpStream};
+use tokio::try_join;
+
+#[tokio::test]
+async fn split() -> Result<()> {
+ const MSG: &[u8] = b"split";
+
+ let mut listener = TcpListener::bind("127.0.0.1:0").await?;
+ let addr = listener.local_addr()?;
+
+ let (stream1, (mut stream2, _)) = try_join! {
+ TcpStream::connect(&addr),
+ listener.accept(),
+ }?;
+ let (mut read_half, mut write_half) = stream1.into_split();
+
+ let ((), (), ()) = try_join! {
+ async {
+ let len = stream2.write(MSG).await?;
+ assert_eq!(len, MSG.len());
+
+ let mut read_buf = vec![0u8; 32];
+ let read_len = stream2.read(&mut read_buf).await?;
+ assert_eq!(&read_buf[..read_len], MSG);
+ Result::Ok(())
+ },
+ async {
+ let len = write_half.write(MSG).await?;
+ assert_eq!(len, MSG.len());
+ Ok(())
+ },
+ async {
+ let mut read_buf = vec![0u8; 32];
+ let peek_len1 = read_half.peek(&mut read_buf[..]).await?;
+ let peek_len2 = read_half.peek(&mut read_buf[..]).await?;
+ assert_eq!(peek_len1, peek_len2);
+
+ let read_len = read_half.read(&mut read_buf[..]).await?;
+ assert_eq!(peek_len1, read_len);
+ assert_eq!(&read_buf[..read_len], MSG);
+ Ok(())
+ },
+ }?;
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn reunite() -> Result<()> {
+ let listener = net::TcpListener::bind("127.0.0.1:0")?;
+ let addr = listener.local_addr()?;
+
+ let handle = thread::spawn(move || {
+ drop(listener.accept().unwrap());
+ drop(listener.accept().unwrap());
+ });
+
+ let stream1 = TcpStream::connect(&addr).await?;
+ let (read1, write1) = stream1.into_split();
+
+ let stream2 = TcpStream::connect(&addr).await?;
+ let (_, write2) = stream2.into_split();
+
+ let read1 = match read1.reunite(write2) {
+ Ok(_) => panic!("Reunite should not succeed"),
+ Err(err) => err.0,
+ };
+
+ read1.reunite(write1).expect("Reunite should succeed");
+
+ handle.join().unwrap();
+ Ok(())
+}
+
+/// Test that dropping the write half actually closes the stream.
+#[tokio::test]
+async fn drop_write() -> Result<()> {
+ const MSG: &[u8] = b"split";
+
+ let listener = net::TcpListener::bind("127.0.0.1:0")?;
+ let addr = listener.local_addr()?;
+
+ let barrier = Arc::new(Barrier::new(2));
+ let barrier2 = barrier.clone();
+
+ let handle = thread::spawn(move || {
+ let (mut stream, _) = listener.accept().unwrap();
+ stream.write(MSG).unwrap();
+
+ let mut read_buf = [0u8; 32];
+ let res = match stream.read(&mut read_buf) {
+ Ok(0) => Ok(()),
+ Ok(len) => Err(Error::new(
+ ErrorKind::Other,
+ format!("Unexpected read: {} bytes.", len),
+ )),
+ Err(err) => Err(err),
+ };
+
+ barrier2.wait();
+
+ drop(stream);
+
+ res
+ });
+
+ let stream = TcpStream::connect(&addr).await?;
+ let (mut read_half, write_half) = stream.into_split();
+
+ let mut read_buf = [0u8; 32];
+ let read_len = read_half.read(&mut read_buf[..]).await?;
+ assert_eq!(&read_buf[..read_len], MSG);
+
+ // drop it while the read is in progress
+ std::thread::spawn(move || {
+ thread::sleep(std::time::Duration::from_millis(50));
+ drop(write_half);
+ });
+
+ match read_half.read(&mut read_buf[..]).await {
+ Ok(0) => {}
+ Ok(len) => panic!("Unexpected read: {} bytes.", len),
+ Err(err) => panic!("Unexpected error: {}.", err),
+ }
+
+ barrier.wait();
+
+ handle.join().unwrap().unwrap();
+ Ok(())
+}