summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorCarl Lerche <me@carllerche.com>2020-10-07 13:02:29 -0700
committerGitHub <noreply@github.com>2020-10-07 13:02:29 -0700
commita9a59ea90eb0fc242bef3bed12986425b66206ee (patch)
treecdab62ca76028fe5652c5d67fc89feacb4ab1e92
parentc248167173a6e7ecfa8f596a82dca37041aa5132 (diff)
net: add `TcpSocket` for configuring a socket (#2920)
This enables the caller to configure the socket and to explicitly bind the socket before converting it to a `TcpStream` or `TcpListener`. Closes: #2902
-rw-r--r--tokio/Cargo.toml2
-rw-r--r--tokio/src/net/mod.rs1
-rw-r--r--tokio/src/net/tcp/listener.rs2
-rw-r--r--tokio/src/net/tcp/mod.rs2
-rw-r--r--tokio/src/net/tcp/socket.rs349
-rw-r--r--tokio/src/net/tcp/stream.rs4
-rw-r--r--tokio/tests/tcp_socket.rs60
7 files changed, 418 insertions, 2 deletions
diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml
index 6d8377c2..4d5f833c 100644
--- a/tokio/Cargo.toml
+++ b/tokio/Cargo.toml
@@ -98,7 +98,7 @@ fnv = { version = "1.0.6", optional = true }
futures-core = { version = "0.3.0", optional = true }
lazy_static = { version = "1.0.2", optional = true }
memchr = { version = "2.2", optional = true }
-mio = { version = "0.7.2", optional = true }
+mio = { version = "0.7.3", optional = true }
num_cpus = { version = "1.8.0", optional = true }
parking_lot = { version = "0.11.0", optional = true } # Not in full
slab = { version = "0.4.1", optional = true }
diff --git a/tokio/src/net/mod.rs b/tokio/src/net/mod.rs
index e3fb2c73..a91085f8 100644
--- a/tokio/src/net/mod.rs
+++ b/tokio/src/net/mod.rs
@@ -35,6 +35,7 @@ cfg_dns! {
cfg_tcp! {
pub mod tcp;
pub use tcp::listener::TcpListener;
+ pub use tcp::socket::TcpSocket;
pub use tcp::stream::TcpStream;
}
diff --git a/tokio/src/net/tcp/listener.rs b/tokio/src/net/tcp/listener.rs
index 0ac03632..133852d2 100644
--- a/tokio/src/net/tcp/listener.rs
+++ b/tokio/src/net/tcp/listener.rs
@@ -261,7 +261,7 @@ impl TcpListener {
Ok(TcpListener { io })
}
- fn new(listener: mio::net::TcpListener) -> io::Result<TcpListener> {
+ pub(crate) fn new(listener: mio::net::TcpListener) -> io::Result<TcpListener> {
let io = PollEvented::new(listener)?;
Ok(TcpListener { io })
}
diff --git a/tokio/src/net/tcp/mod.rs b/tokio/src/net/tcp/mod.rs
index 7ad36eb0..c27038f9 100644
--- a/tokio/src/net/tcp/mod.rs
+++ b/tokio/src/net/tcp/mod.rs
@@ -6,6 +6,8 @@ pub(crate) use listener::TcpListener;
mod incoming;
pub use incoming::Incoming;
+pub(crate) mod socket;
+
mod split;
pub use split::{ReadHalf, WriteHalf};
diff --git a/tokio/src/net/tcp/socket.rs b/tokio/src/net/tcp/socket.rs
new file mode 100644
index 00000000..5b0f802a
--- /dev/null
+++ b/tokio/src/net/tcp/socket.rs
@@ -0,0 +1,349 @@
+use crate::net::{TcpListener, TcpStream};
+
+use std::fmt;
+use std::io;
+use std::net::SocketAddr;
+
+#[cfg(unix)]
+use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+#[cfg(windows)]
+use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
+
+/// A TCP socket that has not yet been converted to a `TcpStream` or
+/// `TcpListener`.
+///
+/// `TcpSocket` wraps an operating system socket and enables the caller to
+/// configure the socket before establishing a TCP connection or accepting
+/// inbound connections. The caller is able to set socket option and explicitly
+/// bind the socket with a socket address.
+///
+/// The underlying socket is closed when the `TcpSocket` value is dropped.
+///
+/// `TcpSocket` should only be used directly if the default configuration used
+/// by `TcpStream::connect` and `TcpListener::bind` does not meet the required
+/// use case.
+///
+/// Calling `TcpStream::connect("127.0.0.1:8080")` is equivalent to:
+///
+/// ```no_run
+/// use tokio::net::TcpSocket;
+///
+/// use std::io;
+///
+/// #[tokio::main]
+/// async fn main() -> io::Result<()> {
+/// let addr = "127.0.0.1:8080".parse().unwrap();
+///
+/// let socket = TcpSocket::new_v4()?;
+/// let stream = socket.connect(addr).await?;
+/// # drop(stream);
+///
+/// Ok(())
+/// }
+/// ```
+///
+/// Calling `TcpListener::bind("127.0.0.1:8080")` is equivalent to:
+///
+/// ```no_run
+/// use tokio::net::TcpSocket;
+///
+/// use std::io;
+///
+/// #[tokio::main]
+/// async fn main() -> io::Result<()> {
+/// let addr = "127.0.0.1:8080".parse().unwrap();
+///
+/// let socket = TcpSocket::new_v4()?;
+/// // On platforms with Berkeley-derived sockets, this allows to quickly
+/// // rebind a socket, without needing to wait for the OS to clean up the
+/// // previous one.
+/// //
+/// // On Windows, this allows rebinding sockets which are actively in use,
+/// // which allows “socket hijacking”, so we explicitly don't set it here.
+/// // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
+/// socket.set_reuseaddr(true)?;
+/// socket.bind(addr)?;
+///
+/// let listener = socket.listen(1024)?;
+/// # drop(listener);
+///
+/// Ok(())
+/// }
+/// ```
+///
+/// Setting socket options not explicitly provided by `TcpSocket` may be done by
+/// accessing the `RawFd`/`RawSocket` using [`AsRawFd`]/[`AsRawSocket`] and
+/// setting the option with a crate like [`socket2`].
+///
+/// [`RawFd`]: https://doc.rust-lang.org/std/os/unix/io/type.RawFd.html
+/// [`RawSocket`]: https://doc.rust-lang.org/std/os/windows/io/type.RawSocket.html
+/// [`AsRawFd`]: https://doc.rust-lang.org/std/os/unix/io/trait.AsRawFd.html
+/// [`AsRawSocket`]: https://doc.rust-lang.org/std/os/windows/io/trait.AsRawSocket.html
+/// [`socket2`]: https://docs.rs/socket2/
+pub struct TcpSocket {
+ inner: mio::net::TcpSocket,
+}
+
+impl TcpSocket {
+ /// Create a new socket configured for IPv4.
+ ///
+ /// Calls `socket(2)` with `AF_INET` and `SOCK_STREAM`.
+ ///
+ /// # Returns
+ ///
+ /// On success, the newly created `TcpSocket` is returned. If an error is
+ /// encountered, it is returned instead.
+ ///
+ /// # Examples
+ ///
+ /// Create a new IPv4 socket and start listening.
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpSocket;
+ ///
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let addr = "127.0.0.1:8080".parse().unwrap();
+ /// let socket = TcpSocket::new_v4()?;
+ /// socket.bind(addr)?;
+ ///
+ /// let listener = socket.listen(128)?;
+ /// # drop(listener);
+ /// Ok(())
+ /// }
+ /// ```
+ pub fn new_v4() -> io::Result<TcpSocket> {
+ let inner = mio::net::TcpSocket::new_v4()?;
+ Ok(TcpSocket { inner })
+ }
+
+ /// Create a new socket configured for IPv6.
+ ///
+ /// Calls `socket(2)` with `AF_INET6` and `SOCK_STREAM`.
+ ///
+ /// # Returns
+ ///
+ /// On success, the newly created `TcpSocket` is returned. If an error is
+ /// encountered, it is returned instead.
+ ///
+ /// # Examples
+ ///
+ /// Create a new IPv6 socket and start listening.
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpSocket;
+ ///
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let addr = "[::1]:8080".parse().unwrap();
+ /// let socket = TcpSocket::new_v6()?;
+ /// socket.bind(addr)?;
+ ///
+ /// let listener = socket.listen(128)?;
+ /// # drop(listener);
+ /// Ok(())
+ /// }
+ /// ```
+ pub fn new_v6() -> io::Result<TcpSocket> {
+ let inner = mio::net::TcpSocket::new_v6()?;
+ Ok(TcpSocket { inner })
+ }
+
+ /// Allow the socket to bind to an in-use address.
+ ///
+ /// Behavior is platform specific. Refer to the target platform's
+ /// documentation for more details.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpSocket;
+ ///
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let addr = "127.0.0.1:8080".parse().unwrap();
+ ///
+ /// let socket = TcpSocket::new_v4()?;
+ /// socket.set_reuseaddr(true)?;
+ /// socket.bind(addr)?;
+ ///
+ /// let listener = socket.listen(1024)?;
+ /// # drop(listener);
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> {
+ self.inner.set_reuseaddr(reuseaddr)
+ }
+
+ /// Bind the socket to the given address.
+ ///
+ /// This calls the `bind(2)` operating-system function. Behavior is
+ /// platform specific. Refer to the target platform's documentation for more
+ /// details.
+ ///
+ /// # Examples
+ ///
+ /// Bind a socket before listening.
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpSocket;
+ ///
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let addr = "127.0.0.1:8080".parse().unwrap();
+ ///
+ /// let socket = TcpSocket::new_v4()?;
+ /// socket.bind(addr)?;
+ ///
+ /// let listener = socket.listen(1024)?;
+ /// # drop(listener);
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
+ self.inner.bind(addr)
+ }
+
+ /// Establish a TCP connection with a peer at the specified socket address.
+ ///
+ /// The `TcpSocket` is consumed. Once the connection is established, a
+ /// connected [`TcpStream`] is returned. If the connection fails, the
+ /// encountered error is returned.
+ ///
+ /// [`TcpStream`]: TcpStream
+ ///
+ /// This calls the `connect(2)` operating-system function. Behavior is
+ /// platform specific. Refer to the target platform's documentation for more
+ /// details.
+ ///
+ /// # Examples
+ ///
+ /// Connecting to a peer.
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpSocket;
+ ///
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let addr = "127.0.0.1:8080".parse().unwrap();
+ ///
+ /// let socket = TcpSocket::new_v4()?;
+ /// let stream = socket.connect(addr).await?;
+ /// # drop(stream);
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ pub async fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> {
+ let mio = self.inner.connect(addr)?;
+ TcpStream::connect_mio(mio).await
+ }
+
+ /// Convert the socket into a `TcpListener`.
+ ///
+ /// `backlog` defines the maximum number of pending connections are queued
+ /// by the operating system at any given time. Connection are removed from
+ /// the queue with [`TcpListener::accept`]. When the queue is full, the
+ /// operationg-system will start rejecting connections.
+ ///
+ /// [`TcpListener::accept`]: TcpListener::accept
+ ///
+ /// This calls the `listen(2)` operating-system function, marking the socket
+ /// as a passive socket. Behavior is platform specific. Refer to the target
+ /// platform's documentation for more details.
+ ///
+ /// # Examples
+ ///
+ /// Create a `TcpListener`.
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpSocket;
+ ///
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let addr = "127.0.0.1:8080".parse().unwrap();
+ ///
+ /// let socket = TcpSocket::new_v4()?;
+ /// socket.bind(addr)?;
+ ///
+ /// let listener = socket.listen(1024)?;
+ /// # drop(listener);
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ pub fn listen(self, backlog: u32) -> io::Result<TcpListener> {
+ let mio = self.inner.listen(backlog)?;
+ TcpListener::new(mio)
+ }
+}
+
+impl fmt::Debug for TcpSocket {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.inner.fmt(fmt)
+ }
+}
+
+#[cfg(unix)]
+impl AsRawFd for TcpSocket {
+ fn as_raw_fd(&self) -> RawFd {
+ self.inner.as_raw_fd()
+ }
+}
+
+#[cfg(unix)]
+impl FromRawFd for TcpSocket {
+ /// Converts a `RawFd` to a `TcpSocket`.
+ ///
+ /// # Notes
+ ///
+ /// The caller is responsible for ensuring that the socket is in
+ /// non-blocking mode.
+ unsafe fn from_raw_fd(fd: RawFd) -> TcpSocket {
+ let inner = mio::net::TcpSocket::from_raw_fd(fd);
+ TcpSocket { inner }
+ }
+}
+
+#[cfg(windows)]
+impl IntoRawSocket for TcpSocket {
+ fn into_raw_socket(self) -> RawSocket {
+ self.inner.into_raw_socket()
+ }
+}
+
+#[cfg(windows)]
+impl AsRawSocket for TcpSocket {
+ fn as_raw_socket(&self) -> RawSocket {
+ self.inner.as_raw_socket()
+ }
+}
+
+#[cfg(windows)]
+impl FromRawSocket for TcpSocket {
+ /// Converts a `RawSocket` to a `TcpStream`.
+ ///
+ /// # Notes
+ ///
+ /// The caller is responsible for ensuring that the socket is in
+ /// non-blocking mode.
+ unsafe fn from_raw_socket(socket: RawSocket) -> TcpSocket {
+ let inner = mio::net::TcpSocket::from_raw_socket(socket);
+ TcpSocket { inner }
+ }
+}
diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs
index 9141d981..4349ea80 100644
--- a/tokio/src/net/tcp/stream.rs
+++ b/tokio/src/net/tcp/stream.rs
@@ -137,6 +137,10 @@ impl TcpStream {
/// Establishes a connection to the specified `addr`.
async fn connect_addr(addr: SocketAddr) -> io::Result<TcpStream> {
let sys = mio::net::TcpStream::connect(addr)?;
+ TcpStream::connect_mio(sys).await
+ }
+
+ pub(crate) async fn connect_mio(sys: mio::net::TcpStream) -> io::Result<TcpStream> {
let stream = TcpStream::new(sys)?;
// Once we've connected, wait for the stream to be writable as
diff --git a/tokio/tests/tcp_socket.rs b/tokio/tests/tcp_socket.rs
new file mode 100644
index 00000000..993a1e0c
--- /dev/null
+++ b/tokio/tests/tcp_socket.rs
@@ -0,0 +1,60 @@
+#![warn(rust_2018_idioms)]
+#![cfg(feature = "full")]
+
+use tokio::net::TcpSocket;
+use tokio_test::assert_ok;
+
+#[tokio::test]
+async fn basic_usage_v4() {
+ // Create server
+ let addr = assert_ok!("127.0.0.1:0".parse());
+ let srv = assert_ok!(TcpSocket::new_v4());
+ assert_ok!(srv.bind(addr));
+
+ let mut srv = assert_ok!(srv.listen(128));
+
+ // Create client & connect
+ let addr = srv.local_addr().unwrap();
+ let cli = assert_ok!(TcpSocket::new_v4());
+ let _cli = assert_ok!(cli.connect(addr).await);
+
+ // Accept
+ let _ = assert_ok!(srv.accept().await);
+}
+
+#[tokio::test]
+async fn basic_usage_v6() {
+ // Create server
+ let addr = assert_ok!("[::1]:0".parse());
+ let srv = assert_ok!(TcpSocket::new_v6());
+ assert_ok!(srv.bind(addr));
+
+ let mut srv = assert_ok!(srv.listen(128));
+
+ // Create client & connect
+ let addr = srv.local_addr().unwrap();
+ let cli = assert_ok!(TcpSocket::new_v6());
+ let _cli = assert_ok!(cli.connect(addr).await);
+
+ // Accept
+ let _ = assert_ok!(srv.accept().await);
+}
+
+#[tokio::test]
+async fn bind_before_connect() {
+ // Create server
+ let any_addr = assert_ok!("127.0.0.1:0".parse());
+ let srv = assert_ok!(TcpSocket::new_v4());
+ assert_ok!(srv.bind(any_addr));
+
+ let mut srv = assert_ok!(srv.listen(128));
+
+ // Create client & connect
+ let addr = srv.local_addr().unwrap();
+ let cli = assert_ok!(TcpSocket::new_v4());
+ assert_ok!(cli.bind(any_addr));
+ let _cli = assert_ok!(cli.connect(addr).await);
+
+ // Accept
+ let _ = assert_ok!(srv.accept().await);
+}