diff options
author | Kevin Leimkuhler <kevin@kleimkuhler.com> | 2020-01-22 13:22:10 -0800 |
---|---|---|
committer | Carl Lerche <me@carllerche.com> | 2020-01-22 13:22:10 -0800 |
commit | 7f580071f3e5d475db200d2101ff35be0b4f6efe (patch) | |
tree | 34da374a2d15e13f7e7e8b309f8d6d14cfcd128a | |
parent | 5fe2df0fbac79d96913d65612d88bb41c640a1ca (diff) |
net: add `ReadHalf::{poll,poll_peak}` (#2151)
The `&mut self` requirements for `TcpStream` methods ensure that there are at
most two tasks using the stream--one for reading and one for writing.
`TcpStream::split` allows two separate tasks to hold a reference to a single
`TcpStream`. `TcpStream::{peek,poll_peek}` only poll for read readiness, and
therefore are safe to use with a `ReadHalf`.
Instead of duplicating `TcpStream::poll_peek`, a private method is now used by
both `poll_peek` methods that uses the fact that only a `&TcpStream` is
required.
Closes #2136
-rw-r--r-- | tokio/src/net/tcp/split.rs | 74 | ||||
-rw-r--r-- | tokio/src/net/tcp/stream.rs | 8 | ||||
-rw-r--r-- | tokio/tests/tcp_split.rs | 43 |
3 files changed, 124 insertions, 1 deletions
diff --git a/tokio/src/net/tcp/split.rs b/tokio/src/net/tcp/split.rs index 6034d4ef..cce50f6a 100644 --- a/tokio/src/net/tcp/split.rs +++ b/tokio/src/net/tcp/split.rs @@ -8,6 +8,7 @@ //! split has no associated overhead and enforces all invariants at the type //! level. +use crate::future::poll_fn; use crate::io::{AsyncRead, AsyncWrite}; use crate::net::TcpStream; @@ -33,6 +34,79 @@ pub(crate) fn split(stream: &mut TcpStream) -> (ReadHalf<'_>, WriteHalf<'_>) { (ReadHalf(&*stream), WriteHalf(&*stream)) } +impl ReadHalf<'_> { + /// Attempt to receive data on the socket, without removing that data from + /// the queue, registering the current task for wakeup if data is not yet + /// available. + /// + /// See the [`TcpStream::poll_peek`] level documenation for more details. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::io; + /// use tokio::net::TcpStream; + /// + /// use futures::future::poll_fn; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut stream = TcpStream::connect("127.0.0.1:8000").await?; + /// let (mut read_half, _) = stream.split(); + /// let mut buf = [0; 10]; + /// + /// poll_fn(|cx| { + /// read_half.poll_peek(cx, &mut buf) + /// }).await?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// [`TcpStream::poll_peek`]: TcpStream::poll_peek + pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { + self.0.poll_peek2(cx, buf) + } + + /// Receives data on the socket from the remote address to which it is + /// connected, without removing that data from the queue. On success, + /// returns the number of bytes peeked. + /// + /// See the [`TcpStream::peek`] level documenation for more details. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use tokio::prelude::*; + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let mut stream = TcpStream::connect("127.0.0.1:8080").await?; + /// let (mut read_half, _) = stream.split(); + /// + /// let mut b1 = [0; 10]; + /// let mut b2 = [0; 10]; + /// + /// // Peek at the data + /// let n = read_half.peek(&mut b1).await?; + /// + /// // Read the data + /// assert_eq!(n, read_half.read(&mut b2[..n]).await?); + /// assert_eq!(&b1[..n], &b2[..n]); + /// + /// Ok(()) + /// } + /// ``` + /// + /// [`TcpStream::peek`]: TcpStream::peek + pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> { + poll_fn(|cx| self.poll_peek(cx, buf)).await + } +} + impl AsyncRead for ReadHalf<'_> { unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool { false diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index f32a6c4a..f3fb880b 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -258,6 +258,14 @@ impl TcpStream { /// } /// ``` pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { + self.poll_peek2(cx, buf) + } + + pub(super) fn poll_peek2( + &self, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<usize>> { ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; match self.io.get_ref().peek(buf) { diff --git a/tokio/tests/tcp_split.rs b/tokio/tests/tcp_split.rs index ae5f249c..42f79770 100644 --- a/tokio/tests/tcp_split.rs +++ b/tokio/tests/tcp_split.rs @@ -1 +1,42 @@ -// TODO: write tests using TcpStream::split() +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use std::io::Result; +use std::io::{Read, Write}; +use std::{net, thread}; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +#[tokio::test] +async fn split() -> Result<()> { + const MSG: &[u8] = b"split"; + + let listener = net::TcpListener::bind("127.0.0.1:0")?; + let addr = listener.local_addr()?; + + let handle = thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + stream.write(MSG).unwrap(); + + let mut read_buf = [0u8; 32]; + let read_len = stream.read(&mut read_buf).unwrap(); + assert_eq!(&read_buf[..read_len], MSG); + }); + + let mut stream = TcpStream::connect(&addr).await?; + let (mut read_half, mut write_half) = stream.split(); + + let mut read_buf = [0u8; 32]; + let peek_len1 = read_half.peek(&mut read_buf[..]).await?; + let peek_len2 = read_half.peek(&mut read_buf[..]).await?; + assert_eq!(peek_len1, peek_len2); + + let read_len = read_half.read(&mut read_buf[..]).await?; + assert_eq!(peek_len1, read_len); + assert_eq!(&read_buf[..read_len], MSG); + + write_half.write(MSG).await?; + handle.join().unwrap(); + Ok(()) +} |