summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKevin Leimkuhler <kevin@kleimkuhler.com>2020-01-22 13:22:10 -0800
committerCarl Lerche <me@carllerche.com>2020-01-22 13:22:10 -0800
commit7f580071f3e5d475db200d2101ff35be0b4f6efe (patch)
tree34da374a2d15e13f7e7e8b309f8d6d14cfcd128a
parent5fe2df0fbac79d96913d65612d88bb41c640a1ca (diff)
net: add `ReadHalf::{poll,poll_peak}` (#2151)
The `&mut self` requirements for `TcpStream` methods ensure that there are at most two tasks using the stream--one for reading and one for writing. `TcpStream::split` allows two separate tasks to hold a reference to a single `TcpStream`. `TcpStream::{peek,poll_peek}` only poll for read readiness, and therefore are safe to use with a `ReadHalf`. Instead of duplicating `TcpStream::poll_peek`, a private method is now used by both `poll_peek` methods that uses the fact that only a `&TcpStream` is required. Closes #2136
-rw-r--r--tokio/src/net/tcp/split.rs74
-rw-r--r--tokio/src/net/tcp/stream.rs8
-rw-r--r--tokio/tests/tcp_split.rs43
3 files changed, 124 insertions, 1 deletions
diff --git a/tokio/src/net/tcp/split.rs b/tokio/src/net/tcp/split.rs
index 6034d4ef..cce50f6a 100644
--- a/tokio/src/net/tcp/split.rs
+++ b/tokio/src/net/tcp/split.rs
@@ -8,6 +8,7 @@
//! split has no associated overhead and enforces all invariants at the type
//! level.
+use crate::future::poll_fn;
use crate::io::{AsyncRead, AsyncWrite};
use crate::net::TcpStream;
@@ -33,6 +34,79 @@ pub(crate) fn split(stream: &mut TcpStream) -> (ReadHalf<'_>, WriteHalf<'_>) {
(ReadHalf(&*stream), WriteHalf(&*stream))
}
+impl ReadHalf<'_> {
+ /// Attempt to receive data on the socket, without removing that data from
+ /// the queue, registering the current task for wakeup if data is not yet
+ /// available.
+ ///
+ /// See the [`TcpStream::poll_peek`] level documenation for more details.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::io;
+ /// use tokio::net::TcpStream;
+ ///
+ /// use futures::future::poll_fn;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let mut stream = TcpStream::connect("127.0.0.1:8000").await?;
+ /// let (mut read_half, _) = stream.split();
+ /// let mut buf = [0; 10];
+ ///
+ /// poll_fn(|cx| {
+ /// read_half.poll_peek(cx, &mut buf)
+ /// }).await?;
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ ///
+ /// [`TcpStream::poll_peek`]: TcpStream::poll_peek
+ pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
+ self.0.poll_peek2(cx, buf)
+ }
+
+ /// Receives data on the socket from the remote address to which it is
+ /// connected, without removing that data from the queue. On success,
+ /// returns the number of bytes peeked.
+ ///
+ /// See the [`TcpStream::peek`] level documenation for more details.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpStream;
+ /// use tokio::prelude::*;
+ /// use std::error::Error;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> Result<(), Box<dyn Error>> {
+ /// // Connect to a peer
+ /// let mut stream = TcpStream::connect("127.0.0.1:8080").await?;
+ /// let (mut read_half, _) = stream.split();
+ ///
+ /// let mut b1 = [0; 10];
+ /// let mut b2 = [0; 10];
+ ///
+ /// // Peek at the data
+ /// let n = read_half.peek(&mut b1).await?;
+ ///
+ /// // Read the data
+ /// assert_eq!(n, read_half.read(&mut b2[..n]).await?);
+ /// assert_eq!(&b1[..n], &b2[..n]);
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ ///
+ /// [`TcpStream::peek`]: TcpStream::peek
+ pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ poll_fn(|cx| self.poll_peek(cx, buf)).await
+ }
+}
+
impl AsyncRead for ReadHalf<'_> {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
false
diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs
index f32a6c4a..f3fb880b 100644
--- a/tokio/src/net/tcp/stream.rs
+++ b/tokio/src/net/tcp/stream.rs
@@ -258,6 +258,14 @@ impl TcpStream {
/// }
/// ```
pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
+ self.poll_peek2(cx, buf)
+ }
+
+ pub(super) fn poll_peek2(
+ &self,
+ cx: &mut Context<'_>,
+ buf: &mut [u8],
+ ) -> Poll<io::Result<usize>> {
ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
match self.io.get_ref().peek(buf) {
diff --git a/tokio/tests/tcp_split.rs b/tokio/tests/tcp_split.rs
index ae5f249c..42f79770 100644
--- a/tokio/tests/tcp_split.rs
+++ b/tokio/tests/tcp_split.rs
@@ -1 +1,42 @@
-// TODO: write tests using TcpStream::split()
+#![warn(rust_2018_idioms)]
+#![cfg(feature = "full")]
+
+use std::io::Result;
+use std::io::{Read, Write};
+use std::{net, thread};
+
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+use tokio::net::TcpStream;
+
+#[tokio::test]
+async fn split() -> Result<()> {
+ const MSG: &[u8] = b"split";
+
+ let listener = net::TcpListener::bind("127.0.0.1:0")?;
+ let addr = listener.local_addr()?;
+
+ let handle = thread::spawn(move || {
+ let (mut stream, _) = listener.accept().unwrap();
+ stream.write(MSG).unwrap();
+
+ let mut read_buf = [0u8; 32];
+ let read_len = stream.read(&mut read_buf).unwrap();
+ assert_eq!(&read_buf[..read_len], MSG);
+ });
+
+ let mut stream = TcpStream::connect(&addr).await?;
+ let (mut read_half, mut write_half) = stream.split();
+
+ let mut read_buf = [0u8; 32];
+ let peek_len1 = read_half.peek(&mut read_buf[..]).await?;
+ let peek_len2 = read_half.peek(&mut read_buf[..]).await?;
+ assert_eq!(peek_len1, peek_len2);
+
+ let read_len = read_half.read(&mut read_buf[..]).await?;
+ assert_eq!(peek_len1, read_len);
+ assert_eq!(&read_buf[..read_len], MSG);
+
+ write_half.write(MSG).await?;
+ handle.join().unwrap();
+ Ok(())
+}