diff options
Diffstat (limited to 'tokio-util/src/codec/framed_write.rs')
-rw-r--r-- | tokio-util/src/codec/framed_write.rs | 236 |
1 files changed, 21 insertions, 215 deletions
diff --git a/tokio-util/src/codec/framed_write.rs b/tokio-util/src/codec/framed_write.rs index c0049b2d..834eb6ed 100644 --- a/tokio-util/src/codec/framed_write.rs +++ b/tokio-util/src/codec/framed_write.rs @@ -1,20 +1,12 @@ -use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; -use crate::codec::framed::{Fuse, ProjectFuse}; +use crate::codec::framed_impl::{FramedImpl, WriteFrame}; -use tokio::{ - io::{AsyncBufRead, AsyncRead, AsyncWrite}, - stream::Stream, -}; +use tokio::{io::AsyncWrite, stream::Stream}; -use bytes::{Buf, BytesMut}; -use futures_core::ready; use futures_sink::Sink; -use log::trace; use pin_project_lite::pin_project; use std::fmt; -use std::io::{self, BufRead, Read}; -use std::mem::MaybeUninit; +use std::io; use std::pin::Pin; use std::task::{Context, Poll}; @@ -24,21 +16,10 @@ pin_project! { /// [`Sink`]: futures_sink::Sink pub struct FramedWrite<T, E> { #[pin] - inner: FramedWrite2<Fuse<T, E>>, + inner: FramedImpl<T, E, WriteFrame>, } } -pin_project! { - pub(crate) struct FramedWrite2<T> { - #[pin] - inner: T, - buffer: BytesMut, - } -} - -const INITIAL_CAPACITY: usize = 8 * 1024; -const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; - impl<T, E> FramedWrite<T, E> where T: AsyncWrite, @@ -46,10 +27,11 @@ where /// Creates a new `FramedWrite` with the given `encoder`. pub fn new(inner: T, encoder: E) -> FramedWrite<T, E> { FramedWrite { - inner: framed_write2(Fuse { - io: inner, + inner: FramedImpl { + inner, codec: encoder, - }), + state: WriteFrame::default(), + }, } } } @@ -62,7 +44,7 @@ impl<T, E> FramedWrite<T, E> { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_ref(&self) -> &T { - &self.inner.inner.io + &self.inner.inner } /// Returns a mutable reference to the underlying I/O stream wrapped by @@ -72,7 +54,7 @@ impl<T, E> FramedWrite<T, E> { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_mut(&mut self) -> &mut T { - &mut self.inner.inner.io + &mut self.inner.inner } /// Consumes the `FramedWrite`, returning its underlying I/O stream. @@ -81,21 +63,21 @@ impl<T, E> FramedWrite<T, E> { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_inner(self) -> T { - self.inner.inner.io + self.inner.inner } - /// Returns a reference to the underlying decoder. + /// Returns a reference to the underlying encoder. pub fn encoder(&self) -> &E { - &self.inner.inner.codec + &self.inner.codec } - /// Returns a mutable reference to the underlying decoder. + /// Returns a mutable reference to the underlying encoder. pub fn encoder_mut(&mut self) -> &mut E { - &mut self.inner.inner.codec + &mut self.inner.codec } } -// This impl just defers to the underlying FramedWrite2 +// This impl just defers to the underlying FramedImpl impl<T, I, E> Sink<I> for FramedWrite<T, E> where T: AsyncWrite, @@ -121,6 +103,7 @@ where } } +// This impl just defers to the underlying T: Stream impl<T, D> Stream for FramedWrite<T, D> where T: Stream, @@ -128,13 +111,7 @@ where type Item = T::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - self.project() - .inner - .project() - .inner - .project() - .io - .poll_next(cx) + self.project().inner.project().inner.poll_next(cx) } } @@ -145,180 +122,9 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FramedWrite") - .field("inner", &self.inner.get_ref().io) - .field("encoder", &self.inner.get_ref().codec) - .field("buffer", &self.inner.buffer) + .field("inner", &self.get_ref()) + .field("encoder", &self.encoder()) + .field("buffer", &self.inner.state.buffer) .finish() } } - -// ===== impl FramedWrite2 ===== - -pub(crate) fn framed_write2<T>(inner: T) -> FramedWrite2<T> { - FramedWrite2 { - inner, - buffer: BytesMut::with_capacity(INITIAL_CAPACITY), - } -} - -pub(crate) fn framed_write2_with_buffer<T>(inner: T, mut buf: BytesMut) -> FramedWrite2<T> { - if buf.capacity() < INITIAL_CAPACITY { - let bytes_to_reserve = INITIAL_CAPACITY - buf.capacity(); - buf.reserve(bytes_to_reserve); - } - FramedWrite2 { inner, buffer: buf } -} - -impl<T> FramedWrite2<T> { - pub(crate) fn get_ref(&self) -> &T { - &self.inner - } - - pub(crate) fn into_inner(self) -> T { - self.inner - } - - pub(crate) fn into_parts(self) -> (T, BytesMut) { - (self.inner, self.buffer) - } - - pub(crate) fn get_mut(&mut self) -> &mut T { - &mut self.inner - } -} - -impl<I, T> Sink<I> for FramedWrite2<T> -where - T: ProjectFuse + AsyncWrite, - T::Codec: Encoder<I>, -{ - type Error = <T::Codec as Encoder<I>>::Error; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - // If the buffer is already over 8KiB, then attempt to flush it. If after flushing it's - // *still* over 8KiB, then apply backpressure (reject the send). - if self.buffer.len() >= BACKPRESSURE_BOUNDARY { - match self.as_mut().poll_flush(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(())) => (), - }; - - if self.buffer.len() >= BACKPRESSURE_BOUNDARY { - return Poll::Pending; - } - } - Poll::Ready(Ok(())) - } - - fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { - let mut pinned = self.project(); - pinned - .inner - .project() - .codec - .encode(item, &mut pinned.buffer)?; - Ok(()) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - trace!("flushing framed transport"); - let mut pinned = self.project(); - - while !pinned.buffer.is_empty() { - trace!("writing; remaining={}", pinned.buffer.len()); - - let buf = &pinned.buffer; - let n = ready!(pinned.inner.as_mut().poll_write(cx, &buf))?; - - if n == 0 { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to \ - write frame to transport", - ) - .into())); - } - - pinned.buffer.advance(n); - } - - // Try flushing the underlying IO - ready!(pinned.inner.poll_flush(cx))?; - - trace!("framed transport flushed"); - Poll::Ready(Ok(())) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - ready!(self.as_mut().poll_flush(cx))?; - ready!(self.project().inner.poll_shutdown(cx))?; - - Poll::Ready(Ok(())) - } -} - -impl<T: Decoder> Decoder for FramedWrite2<T> { - type Item = T::Item; - type Error = T::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result<Option<T::Item>, T::Error> { - self.inner.decode(src) - } - - fn decode_eof(&mut self, src: &mut BytesMut) -> Result<Option<T::Item>, T::Error> { - self.inner.decode_eof(src) - } -} - -impl<T: Read> Read for FramedWrite2<T> { - fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { - self.inner.read(dst) - } -} - -impl<T: BufRead> BufRead for FramedWrite2<T> { - fn fill_buf(&mut self) -> io::Result<&[u8]> { - self.inner.fill_buf() - } - - fn consume(&mut self, amt: usize) { - self.inner.consume(amt) - } -} - -impl<T: AsyncRead> AsyncRead for FramedWrite2<T> { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<Result<usize, io::Error>> { - self.project().inner.poll_read(cx, buf) - } -} - -impl<T: AsyncBufRead> AsyncBufRead for FramedWrite2<T> { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { - self.project().inner.poll_fill_buf(cx) - } - - fn consume(self: Pin<&mut Self>, amt: usize) { - self.project().inner.consume(amt) - } -} - -impl<T> ProjectFuse for FramedWrite2<T> -where - T: ProjectFuse, -{ - type Io = T::Io; - type Codec = T::Codec; - - fn project(self: Pin<&mut Self>) -> Fuse<Pin<&mut Self::Io>, &mut Self::Codec> { - self.project().inner.project() - } -} |