summaryrefslogtreecommitdiffstats
path: root/tokio
diff options
context:
space:
mode:
authorMikail Bagishov <bagishov.mikail@yandex.ru>2020-10-05 17:07:46 +0300
committerGitHub <noreply@github.com>2020-10-05 16:07:46 +0200
commit1684e1c80921f13600ee6c4576662b7b587443c6 (patch)
treeaed5ae60b9e3d2398b7807cef099d8e03fb59ae9 /tokio
parent0ed4127d5cab264c84e0669f4ac168eb754f8d23 (diff)
io: optimize writing large buffers to windows stdio (#2888)
Diffstat (limited to 'tokio')
-rw-r--r--tokio/src/io/stdio_common.rs170
1 files changed, 130 insertions, 40 deletions
diff --git a/tokio/src/io/stdio_common.rs b/tokio/src/io/stdio_common.rs
index a86ca375..03800fcb 100644
--- a/tokio/src/io/stdio_common.rs
+++ b/tokio/src/io/stdio_common.rs
@@ -3,7 +3,8 @@ use crate::io::AsyncWrite;
use std::pin::Pin;
use std::task::{Context, Poll};
/// # Windows
-/// AsyncWrite adapter that finds last char boundary in given buffer and does not write the rest.
+/// AsyncWrite adapter that finds last char boundary in given buffer and does not write the rest,
+/// if buffer contents seems to be utf8. Otherwise it only trims buffer down to MAX_BUF.
/// That's why, wrapped writer will always receive well-formed utf-8 bytes.
/// # Other platforms
/// passes data to `inner` as is
@@ -18,6 +19,12 @@ impl<W> SplitByUtf8BoundaryIfWindows<W> {
}
}
+// this constant is defined by Unicode standard.
+const MAX_BYTES_PER_CHAR: usize = 4;
+
+// Subject for tweaking here
+const MAGIC_CONST: usize = 8;
+
impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W>
where
W: AsyncWrite + Unpin,
@@ -25,46 +32,62 @@ where
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &[u8],
+ mut buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
- // following two ifs are enabled only on windows targets, because
- // on other targets we do not have problems with incomplete utf8 chars
-
- // ensure buffer is not longer than MAX_BUF
- #[cfg(any(target_os = "windows", test))]
- let buf = if buf.len() > crate::io::blocking::MAX_BUF {
- &buf[..crate::io::blocking::MAX_BUF]
- } else {
- buf
- };
- // now remove possible trailing incomplete character
- #[cfg(any(target_os = "windows", test))]
- let buf = match std::str::from_utf8(buf) {
- // `buf` is already utf-8, no need to trim it futher
- Ok(_) => buf,
+ // just a closure to avoid repetitive code
+ let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf);
+
+ // 1. Only windows stdio can suffer from non-utf8.
+ // We also check for `test` so that we can write some tests
+ // for further code. Since `AsyncWrite` can always shrink
+ // buffer at its discretion, excessive (i.e. in tests) shrinking
+ // does not break correctness.
+ // 2. If buffer is small, it will not be shrinked.
+ // That's why, it's "textness" will not change, so we don't have
+ // to fixup it.
+ if cfg!(not(any(target_os = "windows", test))) || buf.len() <= crate::io::blocking::MAX_BUF
+ {
+ return call_inner(buf);
+ }
+
+ buf = &buf[..crate::io::blocking::MAX_BUF];
+
+ // Now there are two possibilites.
+ // If caller gave is binary buffer, we **should not** shrink it
+ // anymore, because excessive shrinking hits performance.
+ // If caller gave as binary buffer, we **must** additionaly
+ // shrink it to strip incomplete char at the end of buffer.
+ // that's why check we will perform now is allowed to have
+ // false-positive.
+
+ // Now let's look at the first MAX_BYTES_PER_CHAR * MAGIC_CONST bytes.
+ // if they are (possibly incomplete) utf8, then we can be quite sure
+ // that input buffer was utf8.
+
+ let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) {
+ Ok(_) => true,
Err(err) => {
- let bad_bytes = buf.len() - err.valid_up_to();
- // TODO: this is too conservative
- const MAX_BYTES_PER_CHAR: usize = 8;
-
- if bad_bytes <= MAX_BYTES_PER_CHAR && err.valid_up_to() > 0 {
- // Input data is probably UTF-8, but last char was split
- // after trimming.
- // let's exclude this character from the buf
- &buf[..err.valid_up_to()]
- } else {
- // UTF-8 violation could not be caused by trimming.
- // Let's pass buffer to underlying writer as is.
- // Why do not we return error here? It is possible
- // that stdout is not console. Such streams allow
- // non-utf8 data. That's why, let's defer to underlying
- // writer and let it return error if needed
- buf
- }
+ let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to();
+ incomplete_bytes < MAX_BYTES_PER_CHAR
}
};
- // now pass trimmed input buffer to inner writer
- Pin::new(&mut self.inner).poll_write(cx, buf)
+
+ if have_to_fix_up {
+ // We must pop several bytes at the end which form incomplete
+ // character. To achieve it, we exploit UTF8 encoding:
+ // for any code point, all bytes except first start with 0b10 prefix.
+ // see https://en.wikipedia.org/wiki/UTF-8#Encoding for details
+ let trailing_incomplete_char_size = buf
+ .iter()
+ .rev()
+ .take(MAX_BYTES_PER_CHAR)
+ .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000)
+ .unwrap_or(0)
+ + 1;
+ buf = &buf[..buf.len() - trailing_incomplete_char_size];
+ }
+
+ call_inner(buf)
}
fn poll_flush(
@@ -92,8 +115,10 @@ mod tests {
use std::task::Poll;
const MAX_BUF: usize = 16 * 1024;
- struct MockWriter;
- impl crate::io::AsyncWrite for MockWriter {
+
+ struct TextMockWriter;
+
+ impl crate::io::AsyncWrite for TextMockWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
@@ -116,10 +141,45 @@ mod tests {
}
}
+ struct LoggingMockWriter {
+ write_history: Vec<usize>,
+ }
+
+ impl LoggingMockWriter {
+ fn new() -> Self {
+ LoggingMockWriter {
+ write_history: Vec::new(),
+ }
+ }
+ }
+
+ impl crate::io::AsyncWrite for LoggingMockWriter {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ assert!(buf.len() <= MAX_BUF);
+ self.write_history.push(buf.len());
+ Poll::Ready(Ok(buf.len()))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ ) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+ }
+
#[test]
fn test_splitter() {
let data = str::repeat("█", MAX_BUF);
- let mut wr = super::SplitByUtf8BoundaryIfWindows::new(MockWriter);
+ let mut wr = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
let fut = async move {
wr.write_all(data.as_bytes()).await.unwrap();
};
@@ -129,4 +189,34 @@ mod tests {
.unwrap()
.block_on(fut);
}
+
+ #[test]
+ fn test_pseudo_text() {
+ // In this test we write a piece of binary data, whose beginning is
+ // text though. We then validate that even in this corner case buffer
+ // was not shrinked too much.
+ let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;
+ let mut data: Vec<u8> = str::repeat("a", checked_count).into();
+ data.extend(std::iter::repeat(0b1010_1010).take(MAX_BUF - checked_count + 1));
+ let mut writer = LoggingMockWriter::new();
+ let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
+ crate::runtime::Builder::new()
+ .basic_scheduler()
+ .build()
+ .unwrap()
+ .block_on(async {
+ splitter.write_all(&data).await.unwrap();
+ });
+ // Check that at most two writes were performed
+ assert!(writer.write_history.len() <= 2);
+ // Check that all has been written
+ assert_eq!(
+ writer.write_history.iter().copied().sum::<usize>(),
+ data.len()
+ );
+ // Check that at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes were shrinked
+ // from the buffer: one because it was outside of MAX_BUF boundary, and
+ // up to one "utf8 code point".
+ assert!(data.len() - writer.write_history[0] <= super::MAX_BYTES_PER_CHAR + 1);
+ }
}