diff options
-rw-r--r-- | tokio/src/io/mod.rs | 4 | ||||
-rw-r--r-- | tokio/src/io/util/mem.rs | 222 | ||||
-rw-r--r-- | tokio/src/io/util/mod.rs | 3 | ||||
-rw-r--r-- | tokio/tests/io_mem_stream.rs | 83 |
4 files changed, 310 insertions, 2 deletions
diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index 7b005560..7e91defd 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -226,8 +226,8 @@ cfg_io_util! { pub(crate) mod util; pub use util::{ - copy, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, - BufReader, BufStream, BufWriter, Copy, Empty, Lines, Repeat, Sink, Split, Take, + copy, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, + BufReader, BufStream, BufWriter, DuplexStream, Copy, Empty, Lines, Repeat, Sink, Split, Take, }; cfg_stream! { diff --git a/tokio/src/io/util/mem.rs b/tokio/src/io/util/mem.rs new file mode 100644 index 00000000..02ba6aa7 --- /dev/null +++ b/tokio/src/io/util/mem.rs @@ -0,0 +1,222 @@ +//! In-process memory IO types. + +use crate::io::{AsyncRead, AsyncWrite}; +use crate::loom::sync::Mutex; + +use bytes::{Buf, BytesMut}; +use std::{ + pin::Pin, + sync::Arc, + task::{self, Poll, Waker}, +}; + +/// A bidirectional pipe to read and write bytes in memory. +/// +/// A pair of `DuplexStream`s are created together, and they act as a "channel" +/// that can be used as in-memory IO types. Writing to one of the pairs will +/// allow that data to be read from the other, and vice versa. +/// +/// # Example +/// +/// ``` +/// # async fn ex() -> std::io::Result<()> { +/// # use tokio::io::{AsyncReadExt, AsyncWriteExt}; +/// let (mut client, mut server) = tokio::io::duplex(64); +/// +/// client.write_all(b"ping").await?; +/// +/// let mut buf = [0u8; 4]; +/// server.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"ping"); +/// +/// server.write_all(b"pong").await?; +/// +/// client.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"pong"); +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct DuplexStream { + read: Arc<Mutex<Pipe>>, + write: Arc<Mutex<Pipe>>, +} + +/// A unidirectional IO over a piece of memory. +/// +/// Data can be written to the pipe, and reading will return that data. +#[derive(Debug)] +struct Pipe { + /// The buffer storing the bytes written, also read from. + /// + /// Using a `BytesMut` because it has efficient `Buf` and `BufMut` + /// functionality already. Additionally, it can try to copy data in the + /// same buffer if there read index has advanced far enough. + buffer: BytesMut, + /// Determines if the write side has been closed. + is_closed: bool, + /// The maximum amount of bytes that can be written before returning + /// `Poll::Pending`. + max_buf_size: usize, + /// If the `read` side has been polled and is pending, this is the waker + /// for that parked task. + read_waker: Option<Waker>, + /// If the `write` side has filled the `max_buf_size` and returned + /// `Poll::Pending`, this is the waker for that parked task. + write_waker: Option<Waker>, +} + +// ===== impl DuplexStream ===== + +/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets. +/// +/// The `max_buf_size` argument is the maximum amount of bytes that can be +/// written to a side before the write returns `Poll::Pending`. +pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) { + let one = Arc::new(Mutex::new(Pipe::new(max_buf_size))); + let two = Arc::new(Mutex::new(Pipe::new(max_buf_size))); + + ( + DuplexStream { + read: one.clone(), + write: two.clone(), + }, + DuplexStream { + read: two, + write: one, + }, + ) +} + +impl AsyncRead for DuplexStream { + // Previous rustc required this `self` to be `mut`, even though newer + // versions recognize it isn't needed to call `lock()`. So for + // compatibility, we include the `mut` and `allow` the lint. + // + // See https://github.com/rust-lang/rust/issues/73592 + #[allow(unused_mut)] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut [u8], + ) -> Poll<std::io::Result<usize>> { + Pin::new(&mut *self.read.lock().unwrap()).poll_read(cx, buf) + } +} + +impl AsyncWrite for DuplexStream { + #[allow(unused_mut)] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + Pin::new(&mut *self.write.lock().unwrap()).poll_write(cx, buf) + } + + #[allow(unused_mut)] + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.write.lock().unwrap()).poll_flush(cx) + } + + #[allow(unused_mut)] + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.write.lock().unwrap()).poll_shutdown(cx) + } +} + +impl Drop for DuplexStream { + fn drop(&mut self) { + // notify the other side of the closure + self.write.lock().unwrap().close(); + } +} + +// ===== impl Pipe ===== + +impl Pipe { + fn new(max_buf_size: usize) -> Self { + Pipe { + buffer: BytesMut::new(), + is_closed: false, + max_buf_size, + read_waker: None, + write_waker: None, + } + } + + fn close(&mut self) { + self.is_closed = true; + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + } +} + +impl AsyncRead for Pipe { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut [u8], + ) -> Poll<std::io::Result<usize>> { + if self.buffer.has_remaining() { + let max = self.buffer.remaining().min(buf.len()); + self.buffer.copy_to_slice(&mut buf[..max]); + if max > 0 { + // The passed `buf` might have been empty, don't wake up if + // no bytes have been moved. + if let Some(waker) = self.write_waker.take() { + waker.wake(); + } + } + Poll::Ready(Ok(max)) + } else if self.is_closed { + Poll::Ready(Ok(0)) + } else { + self.read_waker = Some(cx.waker().clone()); + Poll::Pending + } + } +} + +impl AsyncWrite for Pipe { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + if self.is_closed { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + let avail = self.max_buf_size - self.buffer.len(); + if avail == 0 { + self.write_waker = Some(cx.waker().clone()); + return Poll::Pending; + } + + let len = buf.len().min(avail); + self.buffer.extend_from_slice(&buf[..len]); + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + Poll::Ready(Ok(len)) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + _: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + self.close(); + Poll::Ready(Ok(())) + } +} diff --git a/tokio/src/io/util/mod.rs b/tokio/src/io/util/mod.rs index c4754abf..609ff238 100644 --- a/tokio/src/io/util/mod.rs +++ b/tokio/src/io/util/mod.rs @@ -35,6 +35,9 @@ cfg_io_util! { mod lines; pub use lines::Lines; + mod mem; + pub use mem::{duplex, DuplexStream}; + mod read; mod read_buf; mod read_exact; diff --git a/tokio/tests/io_mem_stream.rs b/tokio/tests/io_mem_stream.rs new file mode 100644 index 00000000..3335214c --- /dev/null +++ b/tokio/tests/io_mem_stream.rs @@ -0,0 +1,83 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; + +#[tokio::test] +async fn ping_pong() { + let (mut a, mut b) = duplex(32); + + let mut buf = [0u8; 4]; + + a.write_all(b"ping").await.unwrap(); + b.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"ping"); + + b.write_all(b"pong").await.unwrap(); + a.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"pong"); +} + +#[tokio::test] +async fn across_tasks() { + let (mut a, mut b) = duplex(32); + + let t1 = tokio::spawn(async move { + a.write_all(b"ping").await.unwrap(); + let mut buf = [0u8; 4]; + a.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"pong"); + }); + + let t2 = tokio::spawn(async move { + let mut buf = [0u8; 4]; + b.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"ping"); + b.write_all(b"pong").await.unwrap(); + }); + + t1.await.unwrap(); + t2.await.unwrap(); +} + +#[tokio::test] +async fn disconnect() { + let (mut a, mut b) = duplex(32); + + let t1 = tokio::spawn(async move { + a.write_all(b"ping").await.unwrap(); + // and dropped + }); + + let t2 = tokio::spawn(async move { + let mut buf = [0u8; 32]; + let n = b.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..n], b"ping"); + + let n = b.read(&mut buf).await.unwrap(); + assert_eq!(n, 0); + }); + + t1.await.unwrap(); + t2.await.unwrap(); +} + +#[tokio::test] +async fn max_write_size() { + let (mut a, mut b) = duplex(32); + + let t1 = tokio::spawn(async move { + let n = a.write(&[0u8; 64]).await.unwrap(); + assert_eq!(n, 32); + let n = a.write(&[0u8; 64]).await.unwrap(); + assert_eq!(n, 4); + }); + + let t2 = tokio::spawn(async move { + let mut buf = [0u8; 4]; + b.read_exact(&mut buf).await.unwrap(); + }); + + t1.await.unwrap(); + t2.await.unwrap(); +} |