From 0707f4c19210d6dac620c663e94d34834714a7c9 Mon Sep 17 00:00:00 2001 From: Fuyang Liu Date: Sun, 6 Dec 2020 14:33:04 +0100 Subject: net: add TcpStream::into_std (#3189) --- tokio/src/io/poll_evented.rs | 8 +++++++ tokio/src/net/tcp/stream.rs | 56 ++++++++++++++++++++++++++++++++++++++++++-- tokio/tests/tcp_into_std.rs | 44 ++++++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 tokio/tests/tcp_into_std.rs diff --git a/tokio/src/io/poll_evented.rs b/tokio/src/io/poll_evented.rs index 3a659610..0ecdb180 100644 --- a/tokio/src/io/poll_evented.rs +++ b/tokio/src/io/poll_evented.rs @@ -124,6 +124,14 @@ impl PollEvented { pub(crate) fn registration(&self) -> &Registration { &self.registration } + + /// Deregister the inner io from the registration and returns a Result containing the inner io + #[cfg(feature = "net")] + pub(crate) fn into_inner(mut self) -> io::Result { + let mut inner = self.io.take().unwrap(); // As io shouldn't ever be None, just unwrap here. + self.registration.deregister(&mut inner)?; + Ok(inner) + } } feature! { diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index 28118f73..83e9f2a7 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -9,10 +9,10 @@ use std::fmt; use std::io; use std::net::{Shutdown, SocketAddr}; #[cfg(windows)] -use std::os::windows::io::{AsRawSocket, FromRawSocket}; +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket}; #[cfg(unix)] -use std::os::unix::io::{AsRawFd, FromRawFd}; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; @@ -184,6 +184,58 @@ impl TcpStream { Ok(TcpStream { io }) } + /// Turn a [`tokio::net::TcpStream`] into a [`std::net::TcpStream`]. + /// + /// The returned [`std::net::TcpStream`] will have `nonblocking mode` set as `true`. + /// Use [`set_nonblocking`] to change the blocking mode if needed. + /// + /// # Examples + /// + /// ``` + /// use std::error::Error; + /// use std::io::Read; + /// use tokio::net::TcpListener; + /// # use tokio::net::TcpStream; + /// # use tokio::io::AsyncWriteExt; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// let mut data = [0u8; 12]; + /// let listener = TcpListener::bind("127.0.0.1:34254").await?; + /// # let handle = tokio::spawn(async { + /// # let mut stream: TcpStream = TcpStream::connect("127.0.0.1:34254").await.unwrap(); + /// # stream.write(b"Hello world!").await.unwrap(); + /// # }); + /// let (tokio_tcp_stream, _) = listener.accept().await?; + /// let mut std_tcp_stream = tokio_tcp_stream.into_std()?; + /// # handle.await.expect("The task being joined has panicked"); + /// std_tcp_stream.set_nonblocking(false)?; + /// std_tcp_stream.read_exact(&mut data)?; + /// # assert_eq!(b"Hello world!", &data); + /// Ok(()) + /// } + /// ``` + /// [`tokio::net::TcpStream`]: TcpStream + /// [`std::net::TcpStream`]: std::net::TcpStream + /// [`set_nonblocking`]: fn@std::net::TcpStream::set_nonblocking + pub fn into_std(self) -> io::Result { + #[cfg(unix)] + { + self.io + .into_inner() + .map(|io| io.into_raw_fd()) + .map(|raw_fd| unsafe { std::net::TcpStream::from_raw_fd(raw_fd) }) + } + + #[cfg(windows)] + { + self.io + .into_inner() + .map(|io| io.into_raw_socket()) + .map(|raw_socket| unsafe { std::net::TcpStream::from_raw_socket(raw_socket) }) + } + } + /// Returns the local address that this stream is bound to. /// /// # Examples diff --git a/tokio/tests/tcp_into_std.rs b/tokio/tests/tcp_into_std.rs new file mode 100644 index 00000000..a46aace7 --- /dev/null +++ b/tokio/tests/tcp_into_std.rs @@ -0,0 +1,44 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use std::io::Read; +use std::io::Result; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::net::TcpStream; + +#[tokio::test] +async fn tcp_into_std() -> Result<()> { + let mut data = [0u8; 12]; + let listener = TcpListener::bind("127.0.0.1:34254").await?; + + let handle = tokio::spawn(async { + let stream: TcpStream = TcpStream::connect("127.0.0.1:34254").await.unwrap(); + stream + }); + + let (tokio_tcp_stream, _) = listener.accept().await?; + let mut std_tcp_stream = tokio_tcp_stream.into_std()?; + std_tcp_stream + .set_nonblocking(false) + .expect("set_nonblocking call failed"); + + let mut client = handle.await.expect("The task being joined has panicked"); + client.write_all(b"Hello world!").await?; + + std_tcp_stream + .read_exact(&mut data) + .expect("std TcpStream read failed!"); + assert_eq!(b"Hello world!", &data); + + // test back to tokio stream + std_tcp_stream + .set_nonblocking(true) + .expect("set_nonblocking call failed"); + let mut tokio_tcp_stream = TcpStream::from_std(std_tcp_stream)?; + client.write_all(b"Hello tokio!").await?; + let _size = tokio_tcp_stream.read_exact(&mut data).await?; + assert_eq!(b"Hello tokio!", &data); + + Ok(()) +} -- cgit v1.2.3