summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authormasnagam <masnagam@gmail.com>2020-11-17 02:51:06 +0900
committerGitHub <noreply@github.com>2020-11-16 09:51:06 -0800
commit4e39c9b818eb8af064bb9f45f47e3cfc6593de95 (patch)
tree2222dd2f8638fb64f228badef84814d2f4079a82
parent97c2c4203cd7c42960cac895987c43a17dff052e (diff)
net: restore TcpStream::{poll_read_ready, poll_write_ready} (#2743)
-rw-r--r--tokio/src/net/tcp/stream.rs22
-rw-r--r--tokio/tests/tcp_stream.rs112
2 files changed, 132 insertions, 2 deletions
diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs
index 2ac37a2b..8a157e1c 100644
--- a/tokio/src/net/tcp/stream.rs
+++ b/tokio/src/net/tcp/stream.rs
@@ -356,6 +356,17 @@ impl TcpStream {
Ok(())
}
+ /// Polls for read readiness.
+ ///
+ /// This function is intended for cases where creating and pinning a future
+ /// via [`readable`] is not feasible. Where possible, using [`readable`] is
+ /// preferred, as this supports polling from multiple tasks at once.
+ ///
+ /// [`readable`]: method@Self::readable
+ pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ self.io.registration().poll_read_ready(cx).map_ok(|_| ())
+ }
+
/// Try to read data from the stream into the provided buffer, returning how
/// many bytes were read.
///
@@ -467,6 +478,17 @@ impl TcpStream {
Ok(())
}
+ /// Polls for write readiness.
+ ///
+ /// This function is intended for cases where creating and pinning a future
+ /// via [`writable`] is not feasible. Where possible, using [`writable`] is
+ /// preferred, as this supports polling from multiple tasks at once.
+ ///
+ /// [`writable`]: method@Self::writable
+ pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ self.io.registration().poll_write_ready(cx).map_ok(|_| ())
+ }
+
/// Try to write a buffer to the stream, returning how many bytes were
/// written.
///
diff --git a/tokio/tests/tcp_stream.rs b/tokio/tests/tcp_stream.rs
index 784ade8a..84d58dc5 100644
--- a/tokio/tests/tcp_stream.rs
+++ b/tokio/tests/tcp_stream.rs
@@ -1,12 +1,16 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]
-use tokio::io::Interest;
+use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest};
use tokio::net::{TcpListener, TcpStream};
+use tokio::try_join;
use tokio_test::task;
-use tokio_test::{assert_pending, assert_ready_ok};
+use tokio_test::{assert_ok, assert_pending, assert_ready_ok};
use std::io;
+use std::task::Poll;
+
+use futures::future::poll_fn;
#[tokio::test]
async fn try_read_write() {
@@ -110,3 +114,107 @@ fn buffer_not_included_in_future() {
let n = mem::size_of_val(&fut);
assert!(n < 1000);
}
+
+macro_rules! assert_readable_by_polling {
+ ($stream:expr) => {
+ assert_ok!(poll_fn(|cx| $stream.poll_read_ready(cx)).await);
+ };
+}
+
+macro_rules! assert_not_readable_by_polling {
+ ($stream:expr) => {
+ poll_fn(|cx| {
+ assert_pending!($stream.poll_read_ready(cx));
+ Poll::Ready(())
+ })
+ .await;
+ };
+}
+
+macro_rules! assert_writable_by_polling {
+ ($stream:expr) => {
+ assert_ok!(poll_fn(|cx| $stream.poll_write_ready(cx)).await);
+ };
+}
+
+macro_rules! assert_not_writable_by_polling {
+ ($stream:expr) => {
+ poll_fn(|cx| {
+ assert_pending!($stream.poll_write_ready(cx));
+ Poll::Ready(())
+ })
+ .await;
+ };
+}
+
+#[tokio::test]
+async fn poll_read_ready() {
+ let (mut client, mut server) = create_pair().await;
+
+ // Initial state - not readable.
+ assert_not_readable_by_polling!(server);
+
+ // There is data in the buffer - readable.
+ assert_ok!(client.write_all(b"ping").await);
+ assert_readable_by_polling!(server);
+
+ // Readable until calls to `poll_read` return `Poll::Pending`.
+ let mut buf = [0u8; 4];
+ assert_ok!(server.read_exact(&mut buf).await);
+ assert_readable_by_polling!(server);
+ read_until_pending(&mut server);
+ assert_not_readable_by_polling!(server);
+
+ // Detect the client disconnect.
+ drop(client);
+ assert_readable_by_polling!(server);
+}
+
+#[tokio::test]
+async fn poll_write_ready() {
+ let (mut client, server) = create_pair().await;
+
+ // Initial state - writable.
+ assert_writable_by_polling!(client);
+
+ // No space to write - not writable.
+ write_until_pending(&mut client);
+ assert_not_writable_by_polling!(client);
+
+ // Detect the server disconnect.
+ drop(server);
+ assert_writable_by_polling!(client);
+}
+
+async fn create_pair() -> (TcpStream, TcpStream) {
+ let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
+ let addr = assert_ok!(listener.local_addr());
+ let (client, (server, _)) = assert_ok!(try_join!(TcpStream::connect(&addr), listener.accept()));
+ (client, server)
+}
+
+fn read_until_pending(stream: &mut TcpStream) {
+ let mut buf = vec![0u8; 1024 * 1024];
+ loop {
+ match stream.try_read(&mut buf) {
+ Ok(_) => (),
+ Err(err) => {
+ assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
+ break;
+ }
+ }
+ }
+}
+
+fn write_until_pending(stream: &mut TcpStream) {
+ let buf = vec![0u8; 1024 * 1024];
+ loop {
+ match stream.try_write(&buf) {
+ Ok(_) => (),
+ Err(err) => {
+ assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
+ break;
+ }
+ }
+ }
+}