summaryrefslogtreecommitdiffstats
path: root/tokio
diff options
context:
space:
mode:
authorSean McArthur <sean@seanmonstar.com>2020-07-22 15:07:39 -0700
committerGitHub <noreply@github.com>2020-07-22 15:07:39 -0700
commit0e090b7ae2c79c35389adab5effaedf825590d87 (patch)
treec29a516a5c42e4450311defea55016a2219ad925 /tokio
parent21f726041cf9a4ca408d97394af220caf90312ed (diff)
io: add `io::duplex()` as bidirectional reader/writer (#2661)
`duplex` returns a pair of connected `DuplexStream`s. `DuplexStream` is a bidirectional type that can be used to simulate IO, but over an in-process piece of memory.
Diffstat (limited to 'tokio')
-rw-r--r--tokio/src/io/mod.rs4
-rw-r--r--tokio/src/io/util/mem.rs222
-rw-r--r--tokio/src/io/util/mod.rs3
-rw-r--r--tokio/tests/io_mem_stream.rs83
4 files changed, 310 insertions, 2 deletions
diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs
index 7b005560..7e91defd 100644
--- a/tokio/src/io/mod.rs
+++ b/tokio/src/io/mod.rs
@@ -226,8 +226,8 @@ cfg_io_util! {
pub(crate) mod util;
pub use util::{
- copy, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
- BufReader, BufStream, BufWriter, Copy, Empty, Lines, Repeat, Sink, Split, Take,
+ copy, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
+ BufReader, BufStream, BufWriter, DuplexStream, Copy, Empty, Lines, Repeat, Sink, Split, Take,
};
cfg_stream! {
diff --git a/tokio/src/io/util/mem.rs b/tokio/src/io/util/mem.rs
new file mode 100644
index 00000000..02ba6aa7
--- /dev/null
+++ b/tokio/src/io/util/mem.rs
@@ -0,0 +1,222 @@
+//! In-process memory IO types.
+
+use crate::io::{AsyncRead, AsyncWrite};
+use crate::loom::sync::Mutex;
+
+use bytes::{Buf, BytesMut};
+use std::{
+ pin::Pin,
+ sync::Arc,
+ task::{self, Poll, Waker},
+};
+
+/// A bidirectional pipe to read and write bytes in memory.
+///
+/// A pair of `DuplexStream`s are created together, and they act as a "channel"
+/// that can be used as in-memory IO types. Writing to one of the pairs will
+/// allow that data to be read from the other, and vice versa.
+///
+/// # Example
+///
+/// ```
+/// # async fn ex() -> std::io::Result<()> {
+/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
+/// let (mut client, mut server) = tokio::io::duplex(64);
+///
+/// client.write_all(b"ping").await?;
+///
+/// let mut buf = [0u8; 4];
+/// server.read_exact(&mut buf).await?;
+/// assert_eq!(&buf, b"ping");
+///
+/// server.write_all(b"pong").await?;
+///
+/// client.read_exact(&mut buf).await?;
+/// assert_eq!(&buf, b"pong");
+/// # Ok(())
+/// # }
+/// ```
+#[derive(Debug)]
+pub struct DuplexStream {
+ read: Arc<Mutex<Pipe>>,
+ write: Arc<Mutex<Pipe>>,
+}
+
+/// A unidirectional IO over a piece of memory.
+///
+/// Data can be written to the pipe, and reading will return that data.
+#[derive(Debug)]
+struct Pipe {
+ /// The buffer storing the bytes written, also read from.
+ ///
+ /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
+ /// functionality already. Additionally, it can try to copy data in the
+ /// same buffer if there read index has advanced far enough.
+ buffer: BytesMut,
+ /// Determines if the write side has been closed.
+ is_closed: bool,
+ /// The maximum amount of bytes that can be written before returning
+ /// `Poll::Pending`.
+ max_buf_size: usize,
+ /// If the `read` side has been polled and is pending, this is the waker
+ /// for that parked task.
+ read_waker: Option<Waker>,
+ /// If the `write` side has filled the `max_buf_size` and returned
+ /// `Poll::Pending`, this is the waker for that parked task.
+ write_waker: Option<Waker>,
+}
+
+// ===== impl DuplexStream =====
+
+/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
+///
+/// The `max_buf_size` argument is the maximum amount of bytes that can be
+/// written to a side before the write returns `Poll::Pending`.
+pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
+ let one = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
+ let two = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
+
+ (
+ DuplexStream {
+ read: one.clone(),
+ write: two.clone(),
+ },
+ DuplexStream {
+ read: two,
+ write: one,
+ },
+ )
+}
+
+impl AsyncRead for DuplexStream {
+ // Previous rustc required this `self` to be `mut`, even though newer
+ // versions recognize it isn't needed to call `lock()`. So for
+ // compatibility, we include the `mut` and `allow` the lint.
+ //
+ // See https://github.com/rust-lang/rust/issues/73592
+ #[allow(unused_mut)]
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &mut [u8],
+ ) -> Poll<std::io::Result<usize>> {
+ Pin::new(&mut *self.read.lock().unwrap()).poll_read(cx, buf)
+ }
+}
+
+impl AsyncWrite for DuplexStream {
+ #[allow(unused_mut)]
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &[u8],
+ ) -> Poll<std::io::Result<usize>> {
+ Pin::new(&mut *self.write.lock().unwrap()).poll_write(cx, buf)
+ }
+
+ #[allow(unused_mut)]
+ fn poll_flush(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ Pin::new(&mut *self.write.lock().unwrap()).poll_flush(cx)
+ }
+
+ #[allow(unused_mut)]
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ Pin::new(&mut *self.write.lock().unwrap()).poll_shutdown(cx)
+ }
+}
+
+impl Drop for DuplexStream {
+ fn drop(&mut self) {
+ // notify the other side of the closure
+ self.write.lock().unwrap().close();
+ }
+}
+
+// ===== impl Pipe =====
+
+impl Pipe {
+ fn new(max_buf_size: usize) -> Self {
+ Pipe {
+ buffer: BytesMut::new(),
+ is_closed: false,
+ max_buf_size,
+ read_waker: None,
+ write_waker: None,
+ }
+ }
+
+ fn close(&mut self) {
+ self.is_closed = true;
+ if let Some(waker) = self.read_waker.take() {
+ waker.wake();
+ }
+ }
+}
+
+impl AsyncRead for Pipe {
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &mut [u8],
+ ) -> Poll<std::io::Result<usize>> {
+ if self.buffer.has_remaining() {
+ let max = self.buffer.remaining().min(buf.len());
+ self.buffer.copy_to_slice(&mut buf[..max]);
+ if max > 0 {
+ // The passed `buf` might have been empty, don't wake up if
+ // no bytes have been moved.
+ if let Some(waker) = self.write_waker.take() {
+ waker.wake();
+ }
+ }
+ Poll::Ready(Ok(max))
+ } else if self.is_closed {
+ Poll::Ready(Ok(0))
+ } else {
+ self.read_waker = Some(cx.waker().clone());
+ Poll::Pending
+ }
+ }
+}
+
+impl AsyncWrite for Pipe {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &[u8],
+ ) -> Poll<std::io::Result<usize>> {
+ if self.is_closed {
+ return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
+ }
+ let avail = self.max_buf_size - self.buffer.len();
+ if avail == 0 {
+ self.write_waker = Some(cx.waker().clone());
+ return Poll::Pending;
+ }
+
+ let len = buf.len().min(avail);
+ self.buffer.extend_from_slice(&buf[..len]);
+ if let Some(waker) = self.read_waker.take() {
+ waker.wake();
+ }
+ Poll::Ready(Ok(len))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ _: &mut task::Context<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ self.close();
+ Poll::Ready(Ok(()))
+ }
+}
diff --git a/tokio/src/io/util/mod.rs b/tokio/src/io/util/mod.rs
index c4754abf..609ff238 100644
--- a/tokio/src/io/util/mod.rs
+++ b/tokio/src/io/util/mod.rs
@@ -35,6 +35,9 @@ cfg_io_util! {
mod lines;
pub use lines::Lines;
+ mod mem;
+ pub use mem::{duplex, DuplexStream};
+
mod read;
mod read_buf;
mod read_exact;
diff --git a/tokio/tests/io_mem_stream.rs b/tokio/tests/io_mem_stream.rs
new file mode 100644
index 00000000..3335214c
--- /dev/null
+++ b/tokio/tests/io_mem_stream.rs
@@ -0,0 +1,83 @@
+#![warn(rust_2018_idioms)]
+#![cfg(feature = "full")]
+
+use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
+
+#[tokio::test]
+async fn ping_pong() {
+ let (mut a, mut b) = duplex(32);
+
+ let mut buf = [0u8; 4];
+
+ a.write_all(b"ping").await.unwrap();
+ b.read_exact(&mut buf).await.unwrap();
+ assert_eq!(&buf, b"ping");
+
+ b.write_all(b"pong").await.unwrap();
+ a.read_exact(&mut buf).await.unwrap();
+ assert_eq!(&buf, b"pong");
+}
+
+#[tokio::test]
+async fn across_tasks() {
+ let (mut a, mut b) = duplex(32);
+
+ let t1 = tokio::spawn(async move {
+ a.write_all(b"ping").await.unwrap();
+ let mut buf = [0u8; 4];
+ a.read_exact(&mut buf).await.unwrap();
+ assert_eq!(&buf, b"pong");
+ });
+
+ let t2 = tokio::spawn(async move {
+ let mut buf = [0u8; 4];
+ b.read_exact(&mut buf).await.unwrap();
+ assert_eq!(&buf, b"ping");
+ b.write_all(b"pong").await.unwrap();
+ });
+
+ t1.await.unwrap();
+ t2.await.unwrap();
+}
+
+#[tokio::test]
+async fn disconnect() {
+ let (mut a, mut b) = duplex(32);
+
+ let t1 = tokio::spawn(async move {
+ a.write_all(b"ping").await.unwrap();
+ // and dropped
+ });
+
+ let t2 = tokio::spawn(async move {
+ let mut buf = [0u8; 32];
+ let n = b.read(&mut buf).await.unwrap();
+ assert_eq!(&buf[..n], b"ping");
+
+ let n = b.read(&mut buf).await.unwrap();
+ assert_eq!(n, 0);
+ });
+
+ t1.await.unwrap();
+ t2.await.unwrap();
+}
+
+#[tokio::test]
+async fn max_write_size() {
+ let (mut a, mut b) = duplex(32);
+
+ let t1 = tokio::spawn(async move {
+ let n = a.write(&[0u8; 64]).await.unwrap();
+ assert_eq!(n, 32);
+ let n = a.write(&[0u8; 64]).await.unwrap();
+ assert_eq!(n, 4);
+ });
+
+ let t2 = tokio::spawn(async move {
+ let mut buf = [0u8; 4];
+ b.read_exact(&mut buf).await.unwrap();
+ });
+
+ t1.await.unwrap();
+ t2.await.unwrap();
+}