diff options
author | Carl Lerche <me@carllerche.com> | 2019-07-09 12:37:14 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-07-09 12:37:14 -0700 |
commit | 64343f1b786b386bb4fbd6169fc4850cce9e245a (patch) | |
tree | cc541e2ad896afdd44cd7c3d43000deab4b41009 | |
parent | 82795184c1a8ce136c4f0fee2a6c9127ff495565 (diff) |
tokio: add AsyncWriteExt::write_all (#1277)
-rw-r--r-- | tokio/src/io/async_write_ext.rs | 20 | ||||
-rw-r--r-- | tokio/src/io/mod.rs | 1 | ||||
-rw-r--r-- | tokio/src/io/write_all.rs | 46 | ||||
-rw-r--r-- | tokio/tests/io_write.rs | 18 | ||||
-rw-r--r-- | tokio/tests/io_write_all.rs | 51 |
5 files changed, 127 insertions, 9 deletions
diff --git a/tokio/src/io/async_write_ext.rs b/tokio/src/io/async_write_ext.rs index ffb5794f..1972238f 100644 --- a/tokio/src/io/async_write_ext.rs +++ b/tokio/src/io/async_write_ext.rs @@ -1,13 +1,11 @@ use crate::io::write::{write, Write}; +use crate::io::write_all::{write_all, WriteAll}; use tokio_io::AsyncWrite; /// An extension trait which adds utility methods to `AsyncWrite` types. pub trait AsyncWriteExt: AsyncWrite { - /// Write the provided data into `self`. - /// - /// The returned future will resolve to the number of bytes written once the - /// write operation is completed. + /// Write a buffer into this writter, returning how many bytes were written. /// /// # Examples /// @@ -20,6 +18,20 @@ pub trait AsyncWriteExt: AsyncWrite { { write(self, src) } + + /// Attempt to write an entire buffer into this writter. + /// + /// # Examples + /// + /// ``` + /// unimplemented!(); + /// ``` + fn write_all<'a>(&'a mut self, src: &'a [u8]) -> WriteAll<'a, Self> + where + Self: Unpin, + { + write_all(self, src) + } } impl<W: AsyncWrite + ?Sized> AsyncWriteExt for W {} diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index 99e4954f..48dbdf16 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -42,6 +42,7 @@ mod copy; mod read; mod read_exact; mod write; +mod write_all; pub use self::async_read_ext::AsyncReadExt; pub use self::async_write_ext::AsyncWriteExt; diff --git a/tokio/src/io/write_all.rs b/tokio/src/io/write_all.rs new file mode 100644 index 00000000..c337ee78 --- /dev/null +++ b/tokio/src/io/write_all.rs @@ -0,0 +1,46 @@ +use tokio_io::AsyncWrite; + +use std::future::Future; +use std::io; +use std::mem; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct WriteAll<'a, W: ?Sized> { + writer: &'a mut W, + buf: &'a [u8], +} + +pub(crate) fn write_all<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteAll<'a, W> +where + W: AsyncWrite + Unpin + ?Sized, +{ + WriteAll { writer, buf } +} + +impl<W: ?Sized + Unpin> Unpin for WriteAll<'_, W> {} + +impl<W> Future for WriteAll<'_, W> +where + W: AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + let me = &mut *self; + while !me.buf.is_empty() { + let n = ready!(Pin::new(&mut me.writer).poll_write(cx, me.buf))?; + { + let (_, rest) = mem::replace(&mut me.buf, &[]).split_at(n); + me.buf = rest; + } + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + + Poll::Ready(Ok(())) + } +} diff --git a/tokio/tests/io_write.rs b/tokio/tests/io_write.rs index 990abdf2..482f21ba 100644 --- a/tokio/tests/io_write.rs +++ b/tokio/tests/io_write.rs @@ -11,7 +11,10 @@ use std::task::{Context, Poll}; #[tokio::test] async fn write() { - struct Wr(BytesMut); + struct Wr { + buf: BytesMut, + cnt: usize, + } impl AsyncWrite for Wr { fn poll_write( @@ -19,8 +22,9 @@ async fn write() { _cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>> { - self.0.extend(buf); - Ok(buf.len()).into() + assert_eq!(self.cnt, 0); + self.buf.extend(&buf[0..4]); + Ok(4).into() } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { @@ -32,8 +36,12 @@ async fn write() { } } - let mut wr = Wr(BytesMut::with_capacity(64)); + let mut wr = Wr { + buf: BytesMut::with_capacity(64), + cnt: 0, + }; let n = assert_ok!(wr.write(b"hello world").await); - assert_eq!(n, 11); + assert_eq!(n, 4); + assert_eq!(wr.buf, b"hell"[..]); } diff --git a/tokio/tests/io_write_all.rs b/tokio/tests/io_write_all.rs new file mode 100644 index 00000000..b12e0f81 --- /dev/null +++ b/tokio/tests/io_write_all.rs @@ -0,0 +1,51 @@ +#![deny(warnings, rust_2018_idioms)] +#![feature(async_await)] + +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio_test::assert_ok; + +use bytes::BytesMut; +use std::cmp; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[tokio::test] +async fn write_all() { + struct Wr { + buf: BytesMut, + cnt: usize, + } + + impl AsyncWrite for Wr { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + let n = cmp::min(4, buf.len()); + let buf = &buf[0..n]; + + self.cnt += 1; + self.buf.extend(buf); + Ok(buf.len()).into() + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Ok(()).into() + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Ok(()).into() + } + } + + let mut wr = Wr { + buf: BytesMut::with_capacity(64), + cnt: 0, + }; + + assert_ok!(wr.write_all(b"hello world").await); + assert_eq!(wr.buf, b"hello world"[..]); + assert_eq!(wr.cnt, 3); +} |