diff options
Diffstat (limited to 'ipc/src/lib.rs')
-rw-r--r-- | ipc/src/lib.rs | 120 |
1 files changed, 101 insertions, 19 deletions
diff --git a/ipc/src/lib.rs b/ipc/src/lib.rs index 623aa66d..7e6a62ab 100644 --- a/ipc/src/lib.rs +++ b/ipc/src/lib.rs @@ -41,7 +41,7 @@ use std::io::{self, Read, Write}; use std::net::{Ipv4Addr, SocketAddr, TcpStream, TcpListener}; use std::path::PathBuf; -use anyhow::Result; +use anyhow::{anyhow, Result}; use fs2::FileExt; use futures::{Future, Stream}; @@ -52,11 +52,12 @@ use tokio_io::AsyncRead; use capnp_rpc::{RpcSystem, twoparty}; use capnp_rpc::rpc_twoparty_capnp::Side; -/* Unix-specific options. */ -use std::os::unix::io::{IntoRawFd, FromRawFd}; -use std::os::unix::fs::OpenOptionsExt; - -/* XXX: Implement Windows support. */ +#[cfg(unix)] +use std::os::unix::{io::{IntoRawFd, FromRawFd}, fs::OpenOptionsExt}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, IntoRawSocket, FromRawSocket}; +#[cfg(windows)] +use winapi::um::winsock2; use std::process::{Command, Stdio}; use std::thread; @@ -68,6 +69,21 @@ use sequoia_core as core; pub mod assuan; pub mod gnupg; +macro_rules! platform { + { unix => { $($unix:tt)* }, windows => { $($windows:tt)* } } => { + if cfg!(unix) { + #[cfg(unix)] { $($unix)* } + #[cfg(not(unix))] { unreachable!() } + } else if cfg!(windows) { + #[cfg(windows)] { $($windows)* } + #[cfg(not(windows))] { unreachable!() } + } else { + #[cfg(not(any(unix, windows)))] compile_error!("Unsupported platform"); + unreachable!() + } + } +} + /// Servers need to implement this trait. pub trait Handler { /// Called on every connection. @@ -139,12 +155,14 @@ impl Descriptor { }; fs::create_dir_all(self.ctx.home())?; - let mut file = fs::OpenOptions::new() + let mut file = fs::OpenOptions::new(); + file .read(true) .write(true) - .create(true) - .mode(0o600) - .open(&self.rendezvous)?; + .create(true); + #[cfg(unix)] + file.mode(0o600); + let mut file = file.open(&self.rendezvous)?; file.lock_exclusive()?; let mut c = vec![]; @@ -205,17 +223,44 @@ impl Descriptor { } fn fork(&self, listener: TcpListener) -> Result<()> { - Command::new(&self.executable) + let mut cmd = Command::new(&self.executable); + cmd .arg("--home") .arg(self.ctx.home()) .arg("--lib") .arg(self.ctx.lib()) .arg("--ephemeral") .arg(self.ctx.ephemeral().to_string()) - .stdin(unsafe { Stdio::from_raw_fd(listener.into_raw_fd()) }) .stdout(Stdio::null()) - .stderr(Stdio::null()) - .spawn()?; + .stderr(Stdio::null()); + + platform! { + unix => { + // Pass the listening TCP socket as child stdin. + cmd.stdin(unsafe { Stdio::from_raw_fd(listener.into_raw_fd()) }); + }, + windows => { + // Sockets for `TcpListener` are not inheritable by default, so + // let's make them so, since we'll pass them to a child process. + unsafe { + match winapi::um::handleapi::SetHandleInformation( + listener.as_raw_socket() as _, + winapi::um::winbase::HANDLE_FLAG_INHERIT, + winapi::um::winbase::HANDLE_FLAG_INHERIT, + ) { + 0 => Err(std::io::Error::last_os_error()), + _ => Ok(()) + }? + }; + // We can't pass the socket to stdin directly on Windows, since + // non-overlapped (blocking) I/O handles can be redirected there. + // We use Tokio (async I/O), so we just pass it via env var rather + // than establishing a separate channel to pass the socket through. + cmd.env("SOCKET", format!("{}", listener.into_raw_socket())); + } + } + + cmd.spawn()?; Ok(()) } @@ -253,7 +298,7 @@ impl Server { if args.len() != 7 || args[1] != "--home" || args[3] != "--lib" || args[5] != "--ephemeral" { - return Err(anyhow::anyhow!( + return Err(anyhow!( "Usage: {} --home <HOMEDIR> --lib <LIBDIR> \ --ephemeral true|false", args[0])); } @@ -266,7 +311,7 @@ impl Server { cfg.set_ephemeral(); } } else { - return Err(anyhow::anyhow!( + return Err(anyhow!( "Expected 'true' or 'false' for --ephemeral, got: {}", args[6])); } @@ -276,8 +321,11 @@ impl Server { /// Turns this process into a server. /// - /// External servers must call this early on. Expects 'stdin' to - /// be a listening TCP socket. + /// External servers must call this early on. + /// + /// On Linux expects 'stdin' to be a listening TCP socket. + /// On Windows this expects `SOCKET` env var to be set to a listening socket + /// of the Windows Sockets API `SOCKET` value. /// /// # Example /// @@ -299,7 +347,14 @@ impl Server { /// } /// ``` pub fn serve(&mut self) -> Result<()> { - self.serve_listener(unsafe { TcpListener::from_raw_fd(0) }) + let listener = platform! { + unix => { unsafe { TcpListener::from_raw_fd(0) } }, + windows => { + let socket = std::env::var("SOCKET")?.parse()?; + unsafe { TcpListener::from_raw_socket(socket) } + } + }; + self.serve_listener(listener) } fn serve_listener(&mut self, l: TcpListener) -> Result<()> { @@ -427,3 +482,30 @@ pub enum Error { #[error("Connection closed unexpectedly.")] ConnectionClosed(Vec<u8>), } + +// Global initialization and cleanup of the Windows Sockets API (WSA) module. +// NOTE: This has to be top-level in order for `ctor::{ctor, dtor}` to work. +#[cfg(windows)] +use std::sync::atomic::{AtomicBool, Ordering}; +#[cfg(windows)] +static WSA_INITED: AtomicBool = AtomicBool::new(false); + +#[cfg(windows)] +#[ctor::ctor] +fn wsa_startup() { + unsafe { + let ret = winsock2::WSAStartup( + 0x202, // version 2.2 + &mut std::mem::zeroed(), + ); + WSA_INITED.store(ret != 0, Ordering::SeqCst); + } +} + +#[cfg(windows)] +#[ctor::dtor] +fn wsa_cleanup() { + if WSA_INITED.load(Ordering::SeqCst) { + let _ = unsafe { winsock2::WSACleanup() }; + } +} |