From b01b2dacf2e4136c0237977dac27a3688467d2ea Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Fri, 11 Dec 2020 23:40:24 -0500 Subject: net: update `TcpStream::poll_peek` to use `ReadBuf` (#3259) Closes #2987 --- tokio/src/net/tcp/split.rs | 12 +++++++++--- tokio/src/net/tcp/split_owned.rs | 12 +++++++++--- tokio/src/net/tcp/stream.rs | 21 +++++++++++++++++---- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/tokio/src/net/tcp/split.rs b/tokio/src/net/tcp/split.rs index 28c94eb4..296b469d 100644 --- a/tokio/src/net/tcp/split.rs +++ b/tokio/src/net/tcp/split.rs @@ -60,7 +60,7 @@ impl ReadHalf<'_> { /// # Examples /// /// ```no_run - /// use tokio::io; + /// use tokio::io::{self, ReadBuf}; /// use tokio::net::TcpStream; /// /// use futures::future::poll_fn; @@ -70,6 +70,7 @@ impl ReadHalf<'_> { /// let mut stream = TcpStream::connect("127.0.0.1:8000").await?; /// let (mut read_half, _) = stream.split(); /// let mut buf = [0; 10]; + /// let mut buf = ReadBuf::new(&mut buf); /// /// poll_fn(|cx| { /// read_half.poll_peek(cx, &mut buf) @@ -80,7 +81,11 @@ impl ReadHalf<'_> { /// ``` /// /// [`TcpStream::poll_peek`]: TcpStream::poll_peek - pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + pub fn poll_peek( + &mut self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { self.0.poll_peek(cx, buf) } @@ -124,7 +129,8 @@ impl ReadHalf<'_> { /// [`read`]: fn@crate::io::AsyncReadExt::read /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result { - poll_fn(|cx| self.poll_peek(cx, buf)).await + let mut buf = ReadBuf::new(buf); + poll_fn(|cx| self.poll_peek(cx, &mut buf)).await } } diff --git a/tokio/src/net/tcp/split_owned.rs b/tokio/src/net/tcp/split_owned.rs index 8d77c8ca..725d7411 100644 --- a/tokio/src/net/tcp/split_owned.rs +++ b/tokio/src/net/tcp/split_owned.rs @@ -115,7 +115,7 @@ impl OwnedReadHalf { /// # Examples /// /// ```no_run - /// use tokio::io; + /// use tokio::io::{self, ReadBuf}; /// use tokio::net::TcpStream; /// /// use futures::future::poll_fn; @@ -125,6 +125,7 @@ impl OwnedReadHalf { /// let stream = TcpStream::connect("127.0.0.1:8000").await?; /// let (mut read_half, _) = stream.into_split(); /// let mut buf = [0; 10]; + /// let mut buf = ReadBuf::new(&mut buf); /// /// poll_fn(|cx| { /// read_half.poll_peek(cx, &mut buf) @@ -135,7 +136,11 @@ impl OwnedReadHalf { /// ``` /// /// [`TcpStream::poll_peek`]: TcpStream::poll_peek - pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + pub fn poll_peek( + &mut self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { self.inner.poll_peek(cx, buf) } @@ -179,7 +184,8 @@ impl OwnedReadHalf { /// [`read`]: fn@crate::io::AsyncReadExt::read /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result { - poll_fn(|cx| self.poll_peek(cx, buf)).await + let mut buf = ReadBuf::new(buf); + poll_fn(|cx| self.poll_peek(cx, &mut buf)).await } } diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index 83e9f2a7..c4a0d12c 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -291,7 +291,7 @@ impl TcpStream { /// # Examples /// /// ```no_run - /// use tokio::io; + /// use tokio::io::{self, ReadBuf}; /// use tokio::net::TcpStream; /// /// use futures::future::poll_fn; @@ -300,6 +300,7 @@ impl TcpStream { /// async fn main() -> io::Result<()> { /// let stream = TcpStream::connect("127.0.0.1:8000").await?; /// let mut buf = [0; 10]; + /// let mut buf = ReadBuf::new(&mut buf); /// /// poll_fn(|cx| { /// stream.poll_peek(cx, &mut buf) @@ -308,12 +309,24 @@ impl TcpStream { /// Ok(()) /// } /// ``` - pub fn poll_peek(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + pub fn poll_peek( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { loop { let ev = ready!(self.io.registration().poll_read_ready(cx))?; - match self.io.peek(buf) { - Ok(ret) => return Poll::Ready(Ok(ret)), + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]) + }; + + match self.io.peek(b) { + Ok(ret) => { + unsafe { buf.assume_init(ret) }; + buf.advance(ret); + return Poll::Ready(Ok(ret)); + } Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { self.io.registration().clear_readiness(ev); } -- cgit v1.2.3