summaryrefslogtreecommitdiffstats
path: root/tokio-util
diff options
context:
space:
mode:
authorJohn Doneth <doneth7@gmail.com>2020-07-23 11:27:43 -0400
committerGitHub <noreply@github.com>2020-07-23 11:27:43 -0400
commit94b64cd70d936cfc96fbb2a3289c3f02cd163be6 (patch)
tree62dbab67d943af5b349d1657d17518088fe29b33 /tokio-util
parentb5d2b0d05b3fde22fbcbe19bfeca44ee1b846d87 (diff)
udp: Fix `UdpFramed` with regards to `Decode` (#1445)
Diffstat (limited to 'tokio-util')
-rw-r--r--tokio-util/src/udp/frame.rs53
-rw-r--r--tokio-util/tests/udp.rs25
2 files changed, 58 insertions, 20 deletions
diff --git a/tokio-util/src/udp/frame.rs b/tokio-util/src/udp/frame.rs
index 5b098bd4..560f35c9 100644
--- a/tokio-util/src/udp/frame.rs
+++ b/tokio-util/src/udp/frame.rs
@@ -6,6 +6,7 @@ use bytes::{BufMut, BytesMut};
use futures_core::ready;
use futures_sink::Sink;
use std::io;
+use std::mem::MaybeUninit;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -36,6 +37,8 @@ pub struct UdpFramed<C> {
wr: BytesMut,
out_addr: SocketAddr,
flushed: bool,
+ is_readable: bool,
+ current_addr: Option<SocketAddr>,
}
impl<C: Decoder + Unpin> Stream for UdpFramed<C> {
@@ -46,27 +49,39 @@ impl<C: Decoder + Unpin> Stream for UdpFramed<C> {
pin.rd.reserve(INITIAL_RD_CAPACITY);
- let (_n, addr) = unsafe {
- // Read into the buffer without having to initialize the memory.
- //
- // safety: we know tokio::net::UdpSocket never reads from the memory
- // during a recv
- let res = {
- let bytes = &mut *(pin.rd.bytes_mut() as *mut _ as *mut [u8]);
- ready!(Pin::new(&mut pin.socket).poll_recv_from(cx, bytes))
- };
+ loop {
+ // Are there are still bytes left in the read buffer to decode?
+ if pin.is_readable {
+ if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? {
+ let current_addr = pin
+ .current_addr
+ .expect("will always be set before this line is called");
- let (n, addr) = res?;
- pin.rd.advance_mut(n);
- (n, addr)
- };
+ return Poll::Ready(Some(Ok((frame, current_addr))));
+ }
+
+ // if this line has been reached then decode has returned `None`.
+ pin.is_readable = false;
+ pin.rd.clear();
+ }
- let frame_res = pin.codec.decode(&mut pin.rd);
- pin.rd.clear();
- let frame = frame_res?;
- let result = frame.map(|frame| Ok((frame, addr))); // frame -> (frame, addr)
+ // We're out of data. Try and fetch more data to decode
+ let addr = unsafe {
+ // Convert `&mut [MaybeUnit<u8>]` to `&mut [u8]` because we will be
+ // writing to it via `poll_recv_from` and therefore initializing the memory.
+ let buf: &mut [u8] =
+ &mut *(pin.rd.bytes_mut() as *mut [MaybeUninit<u8>] as *mut [u8]);
- Poll::Ready(result)
+ let res = ready!(Pin::new(&mut pin.socket).poll_recv_from(cx, buf));
+
+ let (n, addr) = res?;
+ pin.rd.advance_mut(n);
+ addr
+ };
+
+ pin.current_addr = Some(addr);
+ pin.is_readable = true;
+ }
}
}
@@ -148,6 +163,8 @@ impl<C> UdpFramed<C> {
rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
flushed: true,
+ is_readable: false,
+ current_addr: None,
}
}
diff --git a/tokio-util/tests/udp.rs b/tokio-util/tests/udp.rs
index 0ba05742..d0320beb 100644
--- a/tokio-util/tests/udp.rs
+++ b/tokio-util/tests/udp.rs
@@ -1,5 +1,5 @@
use tokio::{net::UdpSocket, stream::StreamExt};
-use tokio_util::codec::{Decoder, Encoder};
+use tokio_util::codec::{Decoder, Encoder, LinesCodec};
use tokio_util::udp::UdpFramed;
use bytes::{BufMut, BytesMut};
@@ -10,7 +10,7 @@ use std::io;
#[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))]
#[tokio::test]
-async fn send_framed() -> std::io::Result<()> {
+async fn send_framed_byte_codec() -> std::io::Result<()> {
let mut a_soc = UdpSocket::bind("127.0.0.1:0").await?;
let mut b_soc = UdpSocket::bind("127.0.0.1:0").await?;
@@ -77,3 +77,24 @@ impl Encoder<&[u8]> for ByteCodec {
Ok(())
}
}
+
+#[tokio::test]
+async fn send_framed_lines_codec() -> std::io::Result<()> {
+ let a_soc = UdpSocket::bind("127.0.0.1:0").await?;
+ let b_soc = UdpSocket::bind("127.0.0.1:0").await?;
+
+ let a_addr = a_soc.local_addr()?;
+ let b_addr = b_soc.local_addr()?;
+
+ let mut a = UdpFramed::new(a_soc, ByteCodec);
+ let mut b = UdpFramed::new(b_soc, LinesCodec::new());
+
+ let msg = b"1\r\n2\r\n3\r\n".to_vec();
+ a.send((&msg, b_addr)).await?;
+
+ assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr));
+
+ Ok(())
+}