summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorFuyang Liu <liufuyang@users.noreply.github.com>2020-12-06 14:33:04 +0100
committerGitHub <noreply@github.com>2020-12-06 14:33:04 +0100
commit0707f4c19210d6dac620c663e94d34834714a7c9 (patch)
treea3aff2f279b1e560602b4752435e092b4a22424e
parent0dbba139848de6a8ee88350cc7fc48d0b05016c5 (diff)
net: add TcpStream::into_std (#3189)
-rw-r--r--tokio/src/io/poll_evented.rs8
-rw-r--r--tokio/src/net/tcp/stream.rs56
-rw-r--r--tokio/tests/tcp_into_std.rs44
3 files changed, 106 insertions, 2 deletions
diff --git a/tokio/src/io/poll_evented.rs b/tokio/src/io/poll_evented.rs
index 3a659610..0ecdb180 100644
--- a/tokio/src/io/poll_evented.rs
+++ b/tokio/src/io/poll_evented.rs
@@ -124,6 +124,14 @@ impl<E: Source> PollEvented<E> {
pub(crate) fn registration(&self) -> &Registration {
&self.registration
}
+
+ /// Deregister the inner io from the registration and returns a Result containing the inner io
+ #[cfg(feature = "net")]
+ pub(crate) fn into_inner(mut self) -> io::Result<E> {
+ let mut inner = self.io.take().unwrap(); // As io shouldn't ever be None, just unwrap here.
+ self.registration.deregister(&mut inner)?;
+ Ok(inner)
+ }
}
feature! {
diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs
index 28118f73..83e9f2a7 100644
--- a/tokio/src/net/tcp/stream.rs
+++ b/tokio/src/net/tcp/stream.rs
@@ -9,10 +9,10 @@ use std::fmt;
use std::io;
use std::net::{Shutdown, SocketAddr};
#[cfg(windows)]
-use std::os::windows::io::{AsRawSocket, FromRawSocket};
+use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
#[cfg(unix)]
-use std::os::unix::io::{AsRawFd, FromRawFd};
+use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
@@ -184,6 +184,58 @@ impl TcpStream {
Ok(TcpStream { io })
}
+ /// Turn a [`tokio::net::TcpStream`] into a [`std::net::TcpStream`].
+ ///
+ /// The returned [`std::net::TcpStream`] will have `nonblocking mode` set as `true`.
+ /// Use [`set_nonblocking`] to change the blocking mode if needed.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use std::error::Error;
+ /// use std::io::Read;
+ /// use tokio::net::TcpListener;
+ /// # use tokio::net::TcpStream;
+ /// # use tokio::io::AsyncWriteExt;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> Result<(), Box<dyn Error>> {
+ /// let mut data = [0u8; 12];
+ /// let listener = TcpListener::bind("127.0.0.1:34254").await?;
+ /// # let handle = tokio::spawn(async {
+ /// # let mut stream: TcpStream = TcpStream::connect("127.0.0.1:34254").await.unwrap();
+ /// # stream.write(b"Hello world!").await.unwrap();
+ /// # });
+ /// let (tokio_tcp_stream, _) = listener.accept().await?;
+ /// let mut std_tcp_stream = tokio_tcp_stream.into_std()?;
+ /// # handle.await.expect("The task being joined has panicked");
+ /// std_tcp_stream.set_nonblocking(false)?;
+ /// std_tcp_stream.read_exact(&mut data)?;
+ /// # assert_eq!(b"Hello world!", &data);
+ /// Ok(())
+ /// }
+ /// ```
+ /// [`tokio::net::TcpStream`]: TcpStream
+ /// [`std::net::TcpStream`]: std::net::TcpStream
+ /// [`set_nonblocking`]: fn@std::net::TcpStream::set_nonblocking
+ pub fn into_std(self) -> io::Result<std::net::TcpStream> {
+ #[cfg(unix)]
+ {
+ self.io
+ .into_inner()
+ .map(|io| io.into_raw_fd())
+ .map(|raw_fd| unsafe { std::net::TcpStream::from_raw_fd(raw_fd) })
+ }
+
+ #[cfg(windows)]
+ {
+ self.io
+ .into_inner()
+ .map(|io| io.into_raw_socket())
+ .map(|raw_socket| unsafe { std::net::TcpStream::from_raw_socket(raw_socket) })
+ }
+ }
+
/// Returns the local address that this stream is bound to.
///
/// # Examples
diff --git a/tokio/tests/tcp_into_std.rs b/tokio/tests/tcp_into_std.rs
new file mode 100644
index 00000000..a46aace7
--- /dev/null
+++ b/tokio/tests/tcp_into_std.rs
@@ -0,0 +1,44 @@
+#![warn(rust_2018_idioms)]
+#![cfg(feature = "full")]
+
+use std::io::Read;
+use std::io::Result;
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+use tokio::net::TcpListener;
+use tokio::net::TcpStream;
+
+#[tokio::test]
+async fn tcp_into_std() -> Result<()> {
+ let mut data = [0u8; 12];
+ let listener = TcpListener::bind("127.0.0.1:34254").await?;
+
+ let handle = tokio::spawn(async {
+ let stream: TcpStream = TcpStream::connect("127.0.0.1:34254").await.unwrap();
+ stream
+ });
+
+ let (tokio_tcp_stream, _) = listener.accept().await?;
+ let mut std_tcp_stream = tokio_tcp_stream.into_std()?;
+ std_tcp_stream
+ .set_nonblocking(false)
+ .expect("set_nonblocking call failed");
+
+ let mut client = handle.await.expect("The task being joined has panicked");
+ client.write_all(b"Hello world!").await?;
+
+ std_tcp_stream
+ .read_exact(&mut data)
+ .expect("std TcpStream read failed!");
+ assert_eq!(b"Hello world!", &data);
+
+ // test back to tokio stream
+ std_tcp_stream
+ .set_nonblocking(true)
+ .expect("set_nonblocking call failed");
+ let mut tokio_tcp_stream = TcpStream::from_std(std_tcp_stream)?;
+ client.write_all(b"Hello tokio!").await?;
+ let _size = tokio_tcp_stream.read_exact(&mut data).await?;
+ assert_eq!(b"Hello tokio!", &data);
+
+ Ok(())
+}