diff options
author | Carl Lerche <me@carllerche.com> | 2020-10-07 13:02:29 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-07 13:02:29 -0700 |
commit | a9a59ea90eb0fc242bef3bed12986425b66206ee (patch) | |
tree | cdab62ca76028fe5652c5d67fc89feacb4ab1e92 | |
parent | c248167173a6e7ecfa8f596a82dca37041aa5132 (diff) |
net: add `TcpSocket` for configuring a socket (#2920)
This enables the caller to configure the socket and to explicitly bind
the socket before converting it to a `TcpStream` or `TcpListener`.
Closes: #2902
-rw-r--r-- | tokio/Cargo.toml | 2 | ||||
-rw-r--r-- | tokio/src/net/mod.rs | 1 | ||||
-rw-r--r-- | tokio/src/net/tcp/listener.rs | 2 | ||||
-rw-r--r-- | tokio/src/net/tcp/mod.rs | 2 | ||||
-rw-r--r-- | tokio/src/net/tcp/socket.rs | 349 | ||||
-rw-r--r-- | tokio/src/net/tcp/stream.rs | 4 | ||||
-rw-r--r-- | tokio/tests/tcp_socket.rs | 60 |
7 files changed, 418 insertions, 2 deletions
diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index 6d8377c2..4d5f833c 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -98,7 +98,7 @@ fnv = { version = "1.0.6", optional = true } futures-core = { version = "0.3.0", optional = true } lazy_static = { version = "1.0.2", optional = true } memchr = { version = "2.2", optional = true } -mio = { version = "0.7.2", optional = true } +mio = { version = "0.7.3", optional = true } num_cpus = { version = "1.8.0", optional = true } parking_lot = { version = "0.11.0", optional = true } # Not in full slab = { version = "0.4.1", optional = true } diff --git a/tokio/src/net/mod.rs b/tokio/src/net/mod.rs index e3fb2c73..a91085f8 100644 --- a/tokio/src/net/mod.rs +++ b/tokio/src/net/mod.rs @@ -35,6 +35,7 @@ cfg_dns! { cfg_tcp! { pub mod tcp; pub use tcp::listener::TcpListener; + pub use tcp::socket::TcpSocket; pub use tcp::stream::TcpStream; } diff --git a/tokio/src/net/tcp/listener.rs b/tokio/src/net/tcp/listener.rs index 0ac03632..133852d2 100644 --- a/tokio/src/net/tcp/listener.rs +++ b/tokio/src/net/tcp/listener.rs @@ -261,7 +261,7 @@ impl TcpListener { Ok(TcpListener { io }) } - fn new(listener: mio::net::TcpListener) -> io::Result<TcpListener> { + pub(crate) fn new(listener: mio::net::TcpListener) -> io::Result<TcpListener> { let io = PollEvented::new(listener)?; Ok(TcpListener { io }) } diff --git a/tokio/src/net/tcp/mod.rs b/tokio/src/net/tcp/mod.rs index 7ad36eb0..c27038f9 100644 --- a/tokio/src/net/tcp/mod.rs +++ b/tokio/src/net/tcp/mod.rs @@ -6,6 +6,8 @@ pub(crate) use listener::TcpListener; mod incoming; pub use incoming::Incoming; +pub(crate) mod socket; + mod split; pub use split::{ReadHalf, WriteHalf}; diff --git a/tokio/src/net/tcp/socket.rs b/tokio/src/net/tcp/socket.rs new file mode 100644 index 00000000..5b0f802a --- /dev/null +++ b/tokio/src/net/tcp/socket.rs @@ -0,0 +1,349 @@ +use crate::net::{TcpListener, TcpStream}; + +use std::fmt; +use std::io; +use std::net::SocketAddr; + +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; + +/// A TCP socket that has not yet been converted to a `TcpStream` or +/// `TcpListener`. +/// +/// `TcpSocket` wraps an operating system socket and enables the caller to +/// configure the socket before establishing a TCP connection or accepting +/// inbound connections. The caller is able to set socket option and explicitly +/// bind the socket with a socket address. +/// +/// The underlying socket is closed when the `TcpSocket` value is dropped. +/// +/// `TcpSocket` should only be used directly if the default configuration used +/// by `TcpStream::connect` and `TcpListener::bind` does not meet the required +/// use case. +/// +/// Calling `TcpStream::connect("127.0.0.1:8080")` is equivalent to: +/// +/// ```no_run +/// use tokio::net::TcpSocket; +/// +/// use std::io; +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// let addr = "127.0.0.1:8080".parse().unwrap(); +/// +/// let socket = TcpSocket::new_v4()?; +/// let stream = socket.connect(addr).await?; +/// # drop(stream); +/// +/// Ok(()) +/// } +/// ``` +/// +/// Calling `TcpListener::bind("127.0.0.1:8080")` is equivalent to: +/// +/// ```no_run +/// use tokio::net::TcpSocket; +/// +/// use std::io; +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// let addr = "127.0.0.1:8080".parse().unwrap(); +/// +/// let socket = TcpSocket::new_v4()?; +/// // On platforms with Berkeley-derived sockets, this allows to quickly +/// // rebind a socket, without needing to wait for the OS to clean up the +/// // previous one. +/// // +/// // On Windows, this allows rebinding sockets which are actively in use, +/// // which allows “socket hijacking”, so we explicitly don't set it here. +/// // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse +/// socket.set_reuseaddr(true)?; +/// socket.bind(addr)?; +/// +/// let listener = socket.listen(1024)?; +/// # drop(listener); +/// +/// Ok(()) +/// } +/// ``` +/// +/// Setting socket options not explicitly provided by `TcpSocket` may be done by +/// accessing the `RawFd`/`RawSocket` using [`AsRawFd`]/[`AsRawSocket`] and +/// setting the option with a crate like [`socket2`]. +/// +/// [`RawFd`]: https://doc.rust-lang.org/std/os/unix/io/type.RawFd.html +/// [`RawSocket`]: https://doc.rust-lang.org/std/os/windows/io/type.RawSocket.html +/// [`AsRawFd`]: https://doc.rust-lang.org/std/os/unix/io/trait.AsRawFd.html +/// [`AsRawSocket`]: https://doc.rust-lang.org/std/os/windows/io/trait.AsRawSocket.html +/// [`socket2`]: https://docs.rs/socket2/ +pub struct TcpSocket { + inner: mio::net::TcpSocket, +} + +impl TcpSocket { + /// Create a new socket configured for IPv4. + /// + /// Calls `socket(2)` with `AF_INET` and `SOCK_STREAM`. + /// + /// # Returns + /// + /// On success, the newly created `TcpSocket` is returned. If an error is + /// encountered, it is returned instead. + /// + /// # Examples + /// + /// Create a new IPv4 socket and start listening. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// let socket = TcpSocket::new_v4()?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(128)?; + /// # drop(listener); + /// Ok(()) + /// } + /// ``` + pub fn new_v4() -> io::Result<TcpSocket> { + let inner = mio::net::TcpSocket::new_v4()?; + Ok(TcpSocket { inner }) + } + + /// Create a new socket configured for IPv6. + /// + /// Calls `socket(2)` with `AF_INET6` and `SOCK_STREAM`. + /// + /// # Returns + /// + /// On success, the newly created `TcpSocket` is returned. If an error is + /// encountered, it is returned instead. + /// + /// # Examples + /// + /// Create a new IPv6 socket and start listening. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "[::1]:8080".parse().unwrap(); + /// let socket = TcpSocket::new_v6()?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(128)?; + /// # drop(listener); + /// Ok(()) + /// } + /// ``` + pub fn new_v6() -> io::Result<TcpSocket> { + let inner = mio::net::TcpSocket::new_v6()?; + Ok(TcpSocket { inner }) + } + + /// Allow the socket to bind to an in-use address. + /// + /// Behavior is platform specific. Refer to the target platform's + /// documentation for more details. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.set_reuseaddr(true)?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// # drop(listener); + /// + /// Ok(()) + /// } + /// ``` + pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> { + self.inner.set_reuseaddr(reuseaddr) + } + + /// Bind the socket to the given address. + /// + /// This calls the `bind(2)` operating-system function. Behavior is + /// platform specific. Refer to the target platform's documentation for more + /// details. + /// + /// # Examples + /// + /// Bind a socket before listening. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// # drop(listener); + /// + /// Ok(()) + /// } + /// ``` + pub fn bind(&self, addr: SocketAddr) -> io::Result<()> { + self.inner.bind(addr) + } + + /// Establish a TCP connection with a peer at the specified socket address. + /// + /// The `TcpSocket` is consumed. Once the connection is established, a + /// connected [`TcpStream`] is returned. If the connection fails, the + /// encountered error is returned. + /// + /// [`TcpStream`]: TcpStream + /// + /// This calls the `connect(2)` operating-system function. Behavior is + /// platform specific. Refer to the target platform's documentation for more + /// details. + /// + /// # Examples + /// + /// Connecting to a peer. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// let stream = socket.connect(addr).await?; + /// # drop(stream); + /// + /// Ok(()) + /// } + /// ``` + pub async fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> { + let mio = self.inner.connect(addr)?; + TcpStream::connect_mio(mio).await + } + + /// Convert the socket into a `TcpListener`. + /// + /// `backlog` defines the maximum number of pending connections are queued + /// by the operating system at any given time. Connection are removed from + /// the queue with [`TcpListener::accept`]. When the queue is full, the + /// operationg-system will start rejecting connections. + /// + /// [`TcpListener::accept`]: TcpListener::accept + /// + /// This calls the `listen(2)` operating-system function, marking the socket + /// as a passive socket. Behavior is platform specific. Refer to the target + /// platform's documentation for more details. + /// + /// # Examples + /// + /// Create a `TcpListener`. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// # drop(listener); + /// + /// Ok(()) + /// } + /// ``` + pub fn listen(self, backlog: u32) -> io::Result<TcpListener> { + let mio = self.inner.listen(backlog)?; + TcpListener::new(mio) + } +} + +impl fmt::Debug for TcpSocket { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(fmt) + } +} + +#[cfg(unix)] +impl AsRawFd for TcpSocket { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } +} + +#[cfg(unix)] +impl FromRawFd for TcpSocket { + /// Converts a `RawFd` to a `TcpSocket`. + /// + /// # Notes + /// + /// The caller is responsible for ensuring that the socket is in + /// non-blocking mode. + unsafe fn from_raw_fd(fd: RawFd) -> TcpSocket { + let inner = mio::net::TcpSocket::from_raw_fd(fd); + TcpSocket { inner } + } +} + +#[cfg(windows)] +impl IntoRawSocket for TcpSocket { + fn into_raw_socket(self) -> RawSocket { + self.inner.into_raw_socket() + } +} + +#[cfg(windows)] +impl AsRawSocket for TcpSocket { + fn as_raw_socket(&self) -> RawSocket { + self.inner.as_raw_socket() + } +} + +#[cfg(windows)] +impl FromRawSocket for TcpSocket { + /// Converts a `RawSocket` to a `TcpStream`. + /// + /// # Notes + /// + /// The caller is responsible for ensuring that the socket is in + /// non-blocking mode. + unsafe fn from_raw_socket(socket: RawSocket) -> TcpSocket { + let inner = mio::net::TcpSocket::from_raw_socket(socket); + TcpSocket { inner } + } +} diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index 9141d981..4349ea80 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -137,6 +137,10 @@ impl TcpStream { /// Establishes a connection to the specified `addr`. async fn connect_addr(addr: SocketAddr) -> io::Result<TcpStream> { let sys = mio::net::TcpStream::connect(addr)?; + TcpStream::connect_mio(sys).await + } + + pub(crate) async fn connect_mio(sys: mio::net::TcpStream) -> io::Result<TcpStream> { let stream = TcpStream::new(sys)?; // Once we've connected, wait for the stream to be writable as diff --git a/tokio/tests/tcp_socket.rs b/tokio/tests/tcp_socket.rs new file mode 100644 index 00000000..993a1e0c --- /dev/null +++ b/tokio/tests/tcp_socket.rs @@ -0,0 +1,60 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::net::TcpSocket; +use tokio_test::assert_ok; + +#[tokio::test] +async fn basic_usage_v4() { + // Create server + let addr = assert_ok!("127.0.0.1:0".parse()); + let srv = assert_ok!(TcpSocket::new_v4()); + assert_ok!(srv.bind(addr)); + + let mut srv = assert_ok!(srv.listen(128)); + + // Create client & connect + let addr = srv.local_addr().unwrap(); + let cli = assert_ok!(TcpSocket::new_v4()); + let _cli = assert_ok!(cli.connect(addr).await); + + // Accept + let _ = assert_ok!(srv.accept().await); +} + +#[tokio::test] +async fn basic_usage_v6() { + // Create server + let addr = assert_ok!("[::1]:0".parse()); + let srv = assert_ok!(TcpSocket::new_v6()); + assert_ok!(srv.bind(addr)); + + let mut srv = assert_ok!(srv.listen(128)); + + // Create client & connect + let addr = srv.local_addr().unwrap(); + let cli = assert_ok!(TcpSocket::new_v6()); + let _cli = assert_ok!(cli.connect(addr).await); + + // Accept + let _ = assert_ok!(srv.accept().await); +} + +#[tokio::test] +async fn bind_before_connect() { + // Create server + let any_addr = assert_ok!("127.0.0.1:0".parse()); + let srv = assert_ok!(TcpSocket::new_v4()); + assert_ok!(srv.bind(any_addr)); + + let mut srv = assert_ok!(srv.listen(128)); + + // Create client & connect + let addr = srv.local_addr().unwrap(); + let cli = assert_ok!(TcpSocket::new_v4()); + assert_ok!(cli.bind(any_addr)); + let _cli = assert_ok!(cli.connect(addr).await); + + // Accept + let _ = assert_ok!(srv.accept().await); +} |