summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSean McArthur <sean@seanmonstar.com>2019-12-13 10:25:27 -0800
committerGitHub <noreply@github.com>2019-12-13 10:25:27 -0800
commit8abaf89e5f3d3371f600777751741a3bb0a1047a (patch)
treec7e0f18ddfa8bcc0a69042d2bf37cecaa06f5aac
parentb560df9e669f73d4663915d9e6dd17af9ab227b9 (diff)
Re-enable writev support in TcpStreams (#1956)
-rw-r--r--tokio/Cargo.toml3
-rw-r--r--tokio/src/net/tcp/split.rs9
-rw-r--r--tokio/src/net/tcp/stream.rs67
3 files changed, 77 insertions, 2 deletions
diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml
index 586a756d..d4008966 100644
--- a/tokio/Cargo.toml
+++ b/tokio/Cargo.toml
@@ -85,7 +85,7 @@ signal = [
stream = ["futures-core"]
sync = ["fnv"]
test-util = []
-tcp = ["io-driver"]
+tcp = ["io-driver", "iovec"]
time = ["slab"]
udp = ["io-driver"]
uds = ["io-driver", "mio-uds", "libc"]
@@ -103,6 +103,7 @@ futures-core = { version = "0.3.0", optional = true }
lazy_static = { version = "1.0.2", optional = true }
memchr = { version = "2.2", optional = true }
mio = { version = "0.6.20", optional = true }
+iovec = { version = "0.1.4", optional = true }
num_cpus = { version = "1.8.0", optional = true }
# Backs `DelayQueue`
slab = { version = "0.4.1", optional = true }
diff --git a/tokio/src/net/tcp/split.rs b/tokio/src/net/tcp/split.rs
index 2b337c08..6034d4ef 100644
--- a/tokio/src/net/tcp/split.rs
+++ b/tokio/src/net/tcp/split.rs
@@ -11,6 +11,7 @@
use crate::io::{AsyncRead, AsyncWrite};
use crate::net::TcpStream;
+use bytes::Buf;
use std::io;
use std::mem::MaybeUninit;
use std::net::Shutdown;
@@ -55,6 +56,14 @@ impl AsyncWrite for WriteHalf<'_> {
self.0.poll_write_priv(cx, buf)
}
+ fn poll_write_buf<B: Buf>(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut B,
+ ) -> Poll<io::Result<usize>> {
+ self.0.poll_write_buf_priv(cx, buf)
+ }
+
#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op
diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs
index 35add8ea..343c6c5d 100644
--- a/tokio/src/net/tcp/stream.rs
+++ b/tokio/src/net/tcp/stream.rs
@@ -3,6 +3,8 @@ use crate::io::{AsyncRead, AsyncWrite, PollEvented};
use crate::net::tcp::split::{split, ReadHalf, WriteHalf};
use crate::net::ToSocketAddrs;
+use bytes::Buf;
+use iovec::IoVec;
use std::convert::TryFrom;
use std::fmt;
use std::io::{self, Read, Write};
@@ -639,7 +641,7 @@ impl TcpStream {
}
}
- pub(crate) fn poll_write_priv(
+ pub(super) fn poll_write_priv(
&self,
cx: &mut Context<'_>,
buf: &[u8],
@@ -654,6 +656,61 @@ impl TcpStream {
x => Poll::Ready(x),
}
}
+
+ pub(super) fn poll_write_buf_priv<B: Buf>(
+ &self,
+ cx: &mut Context<'_>,
+ buf: &mut B,
+ ) -> Poll<io::Result<usize>> {
+ use std::io::IoSlice;
+
+ ready!(self.io.poll_write_ready(cx))?;
+
+ // The `IoVec` (v0.1.x) type can't have a zero-length size, so create
+ // a dummy version from a 1-length slice which we'll overwrite with
+ // the `bytes_vectored` method.
+ static S: &[u8] = &[0];
+ const MAX_BUFS: usize = 64;
+
+ // IoSlice isn't Copy, so we must expand this manually ;_;
+ let mut slices: [IoSlice<'_>; MAX_BUFS] = [
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ IoSlice::new(S), IoSlice::new(S), IoSlice::new(S), IoSlice::new(S),
+ ];
+ let cnt = buf.bytes_vectored(&mut slices);
+
+ let iovec = <&IoVec>::from(S);
+ let mut vecs = [iovec; MAX_BUFS];
+ for i in 0..cnt {
+ vecs[i] = (*slices[i]).into();
+ }
+
+ match self.io.get_ref().write_bufs(&vecs[..cnt]) {
+ Ok(n) => {
+ buf.advance(n);
+ Poll::Ready(Ok(n))
+ },
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ self.io.clear_write_ready(cx)?;
+ Poll::Pending
+ },
+ Err(e) => Poll::Ready(Err(e)),
+ }
+ }
}
impl TryFrom<TcpStream> for mio::net::TcpStream {
@@ -707,6 +764,14 @@ impl AsyncWrite for TcpStream {
self.poll_write_priv(cx, buf)
}
+ fn poll_write_buf<B: Buf>(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut B,
+ ) -> Poll<io::Result<usize>> {
+ self.poll_write_buf_priv(cx, buf)
+ }
+
#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op