diff options
author | kxt <ktamas@fastmail.fm> | 2021-05-27 16:28:28 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-27 16:28:28 +0200 |
commit | 0c0355dbc6e0159a72b0f55c7aabb83d76c2312a (patch) | |
tree | eb09715c31950f918efe54bd189af65b5b2fdab0 | |
parent | 9bdb40b4c644c6a3a061dd0cc4683fc92d504201 (diff) |
refactors for #525 (#534)
* refactor(fakes): clean up add_terminal_input
* refactor(fakes): append whole buf to output_buffer in FakeStdoutWriter::write
* refactor(fakes): append whole buf to output_buffer in FakeInputOutput::write_to_tty_stdin
* fix(fakes): allow partial reads in read_from_tty_stdout
This patch fixes two bugs in read_from_tty_stdout:
* if there was a partial read (ie. `bytes.read_position` is not 0 but
less than `bytes.content.len()`), subsequent calls to would fill `buf`
starting at index `bytes.read_position` instead of 0, leaving range
0..`bytes.read_position` untouched.
* if `buf` was smaller than `bytes.content.len()`, a panic would occur.
* refactor(channels): use crossbeam instead of mpsc
This patch replaces mpsc with crossbeam channels because crossbeam
supports selecting on multiple channels which will be necessary in a
subsequent patch.
* refactor(threadbus): allow multiple receivers in Bus
This patch changes Bus to use multiple receivers. Method `recv` returns
data from all of them. This will be used in a subsequent patch for
receiving from bounded and unbounded queues at the same time.
* refactor(channels): remove SenderType enum
This enum has only one variant, so the entire enum can be replaced with
the innards of said variant.
* refactor(channels): remove Send+Sync trait implementations
The implementation of these traits is not necessary, as
SenderWithContext is automatically Send and Sync for every T and
ErrorContext that's Send and Sync.
-rw-r--r-- | Cargo.lock | 25 | ||||
-rw-r--r-- | src/tests/fakes.rs | 53 | ||||
-rw-r--r-- | zellij-client/src/lib.rs | 10 | ||||
-rw-r--r-- | zellij-server/src/lib.rs | 29 | ||||
-rw-r--r-- | zellij-server/src/thread_bus.rs | 27 | ||||
-rw-r--r-- | zellij-utils/Cargo.toml | 1 | ||||
-rw-r--r-- | zellij-utils/src/channels.rs | 37 |
7 files changed, 87 insertions, 95 deletions
diff --git a/Cargo.lock b/Cargo.lock index 5283550e6..dc862d534 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -418,6 +418,20 @@ dependencies = [ ] [[package]] +name = "crossbeam" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd01a6eb3daaafa260f6fc94c3a6c36390abc2080e38e3e34ced87393fb77d80" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] name = "crossbeam-channel" version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -452,6 +466,16 @@ dependencies = [ ] [[package]] +name = "crossbeam-queue" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f6cb3c7f5b8e51bc3ebb73a2327ad4abdbd119dc13223f14f961d2f38486756" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", +] + +[[package]] name = "crossbeam-utils" version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2355,6 +2379,7 @@ dependencies = [ "backtrace", "bincode", "colors-transform", + "crossbeam", "directories-next", "interprocess", "lazy_static", diff --git a/src/tests/fakes.rs b/src/tests/fakes.rs index 23f9f1824..8d11458b9 100644 --- a/src/tests/fakes.rs +++ b/src/tests/fakes.rs @@ -2,7 +2,7 @@ use std::collections::{HashMap, VecDeque}; use std::io::Write; use std::os::unix::io::RawFd; use std::path::PathBuf; -use std::sync::{mpsc, Arc, Condvar, Mutex}; +use std::sync::{Arc, Condvar, Mutex}; use std::time::{Duration, Instant}; use zellij_utils::{nix, zellij_tile}; @@ -14,7 +14,7 @@ use zellij_server::os_input_output::{async_trait, AsyncReader, Pid, ServerOsApi} use zellij_tile::data::Palette; use zellij_utils::{ async_std, - channels::{ChannelWithContext, SenderType, SenderWithContext}, + channels::{self, ChannelWithContext, SenderWithContext}, errors::ErrorContext, interprocess::local_socket::LocalSocketStream, ipc::{ClientToServerMsg, ServerToClientMsg}, @@ -52,13 +52,9 @@ impl FakeStdoutWriter { impl Write for FakeStdoutWriter { fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> { - let mut bytes_written = 0; let mut output_buffer = self.output_buffer.lock().unwrap(); - for byte in buf { - bytes_written += 1; - output_buffer.push(*byte); - } - Ok(bytes_written) + output_buffer.extend_from_slice(buf); + Ok(buf.len()) } fn flush(&mut self) -> Result<(), std::io::Error> { let mut output_buffer = self.output_buffer.lock().unwrap(); @@ -83,9 +79,11 @@ pub struct FakeInputOutput { possible_tty_inputs: HashMap<u16, Bytes>, last_snapshot_time: Arc<Mutex<Instant>>, send_instructions_to_client: SenderWithContext<ServerToClientMsg>, - receive_instructions_from_server: Arc<Mutex<mpsc::Receiver<(ServerToClientMsg, ErrorContext)>>>, + receive_instructions_from_server: + Arc<Mutex<channels::Receiver<(ServerToClientMsg, ErrorContext)>>>, send_instructions_to_server: SenderWithContext<ClientToServerMsg>, - receive_instructions_from_client: Arc<Mutex<mpsc::Receiver<(ClientToServerMsg, ErrorContext)>>>, + receive_instructions_from_client: + Arc<Mutex<channels::Receiver<(ClientToServerMsg, ErrorContext)>>>, should_trigger_sigwinch: Arc<(Mutex<bool>, Condvar)>, sigwinch_event: Option<PositionAndSize>, } @@ -96,11 +94,11 @@ impl FakeInputOutput { let last_snapshot_time = Arc::new(Mutex::new(Instant::now())); let stdout_writer = FakeStdoutWriter::new(last_snapshot_time.clone()); let (client_sender, client_receiver): ChannelWithContext<ServerToClientMsg> = - mpsc::channel(); - let send_instructions_to_client = SenderWithContext::new(SenderType::Sender(client_sender)); + channels::unbounded(); + let send_instructions_to_client = SenderWithContext::new(client_sender); let (server_sender, server_receiver): ChannelWithContext<ClientToServerMsg> = - mpsc::channel(); - let send_instructions_to_server = SenderWithContext::new(SenderType::Sender(server_sender)); + channels::unbounded(); + let send_instructions_to_server = SenderWithContext::new(server_sender); win_sizes.insert(0, winsize); // 0 is the current terminal FakeInputOutput { read_buffers: Arc::new(Mutex::new(HashMap::new())), @@ -125,10 +123,7 @@ impl FakeInputOutput { self } pub fn add_terminal_input(&mut self, input: &[&[u8]]) { - let mut stdin_commands: VecDeque<Vec<u8>> = VecDeque::new(); - for command in input.iter() { - stdin_commands.push_back(command.iter().copied().collect()) - } + let stdin_commands = input.iter().map(|i| i.to_vec()).collect(); self.stdin_commands = Arc::new(Mutex::new(stdin_commands)); } pub fn add_terminal(&self, fd: RawFd) { @@ -281,26 +276,18 @@ impl ServerOsApi for FakeInputOutput { fn write_to_tty_stdin(&self, pid: RawFd, buf: &[u8]) -> Result<usize, nix::Error> { let mut stdin_writes = self.stdin_writes.lock().unwrap(); let write_buffer = stdin_writes.get_mut(&pid).unwrap(); - let mut bytes_written = 0; - for byte in buf { - bytes_written += 1; - write_buffer.push(*byte); - } - Ok(bytes_written) + Ok(write_buffer.write(buf).unwrap()) } - fn read_from_tty_stdout(&self, pid: RawFd, buf: &mut [u8]) -> Result<usize, nix::Error> { + fn read_from_tty_stdout(&self, pid: RawFd, mut buf: &mut [u8]) -> Result<usize, nix::Error> { let mut read_buffers = self.read_buffers.lock().unwrap(); - let mut bytes_read = 0; match read_buffers.get_mut(&pid) { Some(bytes) => { - for i in bytes.read_position..bytes.content.len() { - bytes_read += 1; - buf[i] = bytes.content[i]; - } - if bytes_read > bytes.read_position { - bytes.set_read_position(bytes_read); + let available_range = bytes.read_position..bytes.content.len(); + let len = buf.write(&bytes.content[available_range]).unwrap(); + if len > bytes.read_position { + bytes.set_read_position(len); } - return Ok(bytes_read); + return Ok(len); } None => Err(nix::Error::Sys(nix::errno::Errno::EAGAIN)), } diff --git a/zellij-client/src/lib.rs b/zellij-client/src/lib.rs index f3b0f11d3..3817e7908 100644 --- a/zellij-client/src/lib.rs +++ b/zellij-client/src/lib.rs @@ -7,7 +7,6 @@ use std::env::current_exe; use std::io::{self, Write}; use std::path::Path; use std::process::Command; -use std::sync::mpsc; use std::thread; use crate::{ @@ -16,7 +15,7 @@ use crate::{ }; use zellij_utils::cli::CliArgs; use zellij_utils::{ - channels::{SenderType, SenderWithContext, SyncChannelWithContext}, + channels::{self, ChannelWithContext, SenderWithContext}, consts::{SESSION_NAME, ZELLIJ_IPC_PIPE}, errors::{ClientContext, ContextType, ErrorInstruction}, input::{actions::Action, config::Config, options::Options}, @@ -149,11 +148,10 @@ pub fn start_client( .write(bracketed_paste.as_bytes()) .unwrap(); - let (send_client_instructions, receive_client_instructions): SyncChannelWithContext< + let (send_client_instructions, receive_client_instructions): ChannelWithContext< ClientInstruction, - > = mpsc::sync_channel(50); - let send_client_instructions = - SenderWithContext::new(SenderType::SyncSender(send_client_instructions)); + > = channels::bounded(50); + let send_client_instructions = SenderWithContext::new(send_client_instructions); #[cfg(not(any(feature = "test", test)))] std::panic::set_hook({ diff --git a/zellij-server/src/lib.rs b/zellij-server/src/lib.rs index 348734957..6d33eddf1 100644 --- a/zellij-server/src/lib.rs +++ b/zellij-server/src/lib.rs @@ -11,9 +11,9 @@ mod wasm_vm; use zellij_utils::zellij_tile; +use std::path::PathBuf; use std::sync::{Arc, Mutex, RwLock}; use std::thread; -use std::{path::PathBuf, sync::mpsc}; use wasmer::Store; use zellij_tile::data::{Event, InputMode, PluginCapabilities}; @@ -27,7 +27,8 @@ use crate::{ }; use route::route_thread_main; use zellij_utils::{ - channels::{ChannelWithContext, SenderType, SenderWithContext, SyncChannelWithContext}, + channels, + channels::{ChannelWithContext, SenderWithContext}, cli::CliArgs, errors::{ContextType, ErrorInstruction, ServerContext}, input::{get_mode_info, options::Options}, @@ -117,9 +118,8 @@ pub fn start_server(os_input: Box<dyn ServerOsApi>, socket_path: PathBuf) { std::env::set_var(&"ZELLIJ", "0"); - let (to_server, server_receiver): SyncChannelWithContext<ServerInstruction> = - mpsc::sync_channel(50); - let to_server = SenderWithContext::new(SenderType::SyncSender(to_server)); + let (to_server, server_receiver): ChannelWithContext<ServerInstruction> = channels::bounded(50); + let to_server = SenderWithContext::new(to_server); let session_data: Arc<RwLock<Option<SessionMetaData>>> = Arc::new(RwLock::new(None)); let session_state = Arc::new(RwLock::new(SessionState::Uninitialized)); @@ -301,13 +301,12 @@ fn init_session( client_attributes: ClientAttributes, session_state: Arc<RwLock<SessionState>>, ) -> SessionMetaData { - let (to_screen, screen_receiver): ChannelWithContext<ScreenInstruction> = mpsc::channel(); - let to_screen = SenderWithContext::new(SenderType::Sender(to_screen)); - - let (to_plugin, plugin_receiver): ChannelWithContext<PluginInstruction> = mpsc::channel(); - let to_plugin = SenderWithContext::new(SenderType::Sender(to_plugin)); - let (to_pty, pty_receiver): ChannelWithContext<PtyInstruction> = mpsc::channel(); - let to_pty = SenderWithContext::new(SenderType::Sender(to_pty)); + let (to_screen, screen_receiver): ChannelWithContext<ScreenInstruction> = channels::unbounded(); + let to_screen = SenderWithContext::new(to_screen); + let (to_plugin, plugin_receiver): ChannelWithContext<PluginInstruction> = channels::unbounded(); + let to_plugin = SenderWithContext::new(to_plugin); + let (to_pty, pty_receiver): ChannelWithContext<PtyInstruction> = channels::unbounded(); + let to_pty = SenderWithContext::new(to_pty); // Determine and initialize the data directory let data_dir = opts.data_dir.unwrap_or_else(get_default_data_dir); @@ -334,7 +333,7 @@ fn init_session( .spawn({ let pty = Pty::new( Bus::new( - pty_receiver, + vec![pty_receiver], Some(&to_screen), None, Some(&to_plugin), @@ -352,7 +351,7 @@ fn init_session( .name("screen".to_string()) .spawn({ let screen_bus = Bus::new( - screen_receiver, + vec![screen_receiver], None, Some(&to_pty), Some(&to_plugin), @@ -377,7 +376,7 @@ fn init_session( .name("wasm".to_string()) .spawn({ let plugin_bus = Bus::new( - plugin_receiver, + vec![plugin_receiver], Some(&to_screen), Some(&to_pty), None, diff --git a/zellij-server/src/thread_bus.rs b/zellij-server/src/thread_bus.rs index afc1d875b..f7c1c1ebc 100644 --- a/zellij-server/src/thread_bus.rs +++ b/zellij-server/src/thread_bus.rs @@ -4,8 +4,7 @@ use crate::{ os_input_output::ServerOsApi, pty::PtyInstruction, screen::ScreenInstruction, wasm_vm::PluginInstruction, ServerInstruction, }; -use std::sync::mpsc; -use zellij_utils::{channels::SenderWithContext, errors::ErrorContext}; +use zellij_utils::{channels, channels::SenderWithContext, errors::ErrorContext}; /// A container for senders to the different threads in zellij on the server side #[derive(Clone)] @@ -20,42 +19,42 @@ impl ThreadSenders { pub fn send_to_screen( &self, instruction: ScreenInstruction, - ) -> Result<(), mpsc::SendError<(ScreenInstruction, ErrorContext)>> { + ) -> Result<(), channels::SendError<(ScreenInstruction, ErrorContext)>> { self.to_screen.as_ref().unwrap().send(instruction) } pub fn send_to_pty( &self, instruction: PtyInstruction, - ) -> Result<(), mpsc::SendError<(PtyInstruction, ErrorContext)>> { + ) -> Result<(), channels::SendError<(PtyInstruction, ErrorContext)>> { self.to_pty.as_ref().unwrap().send(instruction) } pub fn send_to_plugin( &self, instruction: PluginInstruction, - ) -> Result<(), mpsc::SendError<(PluginInstruction, ErrorContext)>> { + ) -> Result<(), channels::SendError<(PluginInstruction, ErrorContext)>> { self.to_plugin.as_ref().unwrap().send(instruction) } pub fn send_to_server( &self, instruction: ServerInstruction, - ) -> Result<(), mpsc::SendError<(ServerInstruction, ErrorContext)>> { + ) -> Result<(), channels::SendError<(ServerInstruction, ErrorContext)>> { self.to_server.as_ref().unwrap().send(instruction) } } /// A container for a receiver, OS input and the senders to a given thread pub(crate) struct Bus<T> { - pub receiver: mpsc::Receiver<(T, ErrorContext)>, + receivers: Vec<channels::Receiver<(T, ErrorContext)>>, pub senders: ThreadSenders, pub os_input: Option<Box<dyn ServerOsApi>>, } impl<T> Bus<T> { pub fn new( - receiver: mpsc::Receiver<(T, ErrorContext)>, + receivers: Vec<channels::Receiver<(T, ErrorContext)>>, to_screen: Option<&SenderWithContext<ScreenInstruction>>, to_pty: Option<&SenderWithContext<PtyInstruction>>, to_plugin: Option<&SenderWithContext<PluginInstruction>>, @@ -63,7 +62,7 @@ impl<T> Bus<T> { os_input: Option<Box<dyn ServerOsApi>>, ) -> Self { Bus { - receiver, + receivers, senders: ThreadSenders { to_screen: to_screen.cloned(), to_pty: to_pty.cloned(), @@ -74,7 +73,13 @@ impl<T> Bus<T> { } } - pub fn recv(&self) -> Result<(T, ErrorContext), mpsc::RecvError> { - self.receiver.recv() + pub fn recv(&self) -> Result<(T, ErrorContext), channels::RecvError> { + let mut selector = channels::Select::new(); + self.receivers.iter().for_each(|r| { + selector.recv(r); + }); + let oper = selector.select(); + let idx = oper.index(); + oper.recv(&self.receivers[idx]) } } diff --git a/zellij-utils/Cargo.toml b/zellij-utils/Cargo.toml index 61205d9d7..e93f99188 100644 --- a/zellij-utils/Cargo.toml +++ b/zellij-utils/Cargo.toml @@ -12,6 +12,7 @@ license = "MIT" backtrace = "0.3.55" bincode = "1.3.1" colors-transform = "0.2.5" +crossbeam = "0.8.0" directories-next = "2.0" interprocess = "1.1.1" lazy_static = "1.4.0" diff --git a/zellij-utils/src/channels.rs b/zellij-utils/src/channels.rs index 8a97fa7d8..26a271671 100644 --- a/zellij-utils/src/channels.rs +++ b/zellij-utils/src/channels.rs @@ -2,56 +2,33 @@ use async_std::task_local; use std::cell::RefCell; -use std::sync::mpsc; use crate::errors::{get_current_ctx, ErrorContext}; +pub use crossbeam::channel::{bounded, unbounded, Receiver, RecvError, Select, SendError, Sender}; /// An [MPSC](mpsc) asynchronous channel with added error context. -pub type ChannelWithContext<T> = ( - mpsc::Sender<(T, ErrorContext)>, - mpsc::Receiver<(T, ErrorContext)>, -); -/// An [MPSC](mpsc) synchronous channel with added error context. -pub type SyncChannelWithContext<T> = ( - mpsc::SyncSender<(T, ErrorContext)>, - mpsc::Receiver<(T, ErrorContext)>, -); - -/// Wrappers around the two standard [MPSC](mpsc) sender types, [`mpsc::Sender`] and [`mpsc::SyncSender`], with an additional [`ErrorContext`]. -#[derive(Clone)] -pub enum SenderType<T: Clone> { - /// A wrapper around an [`mpsc::Sender`], adding an [`ErrorContext`]. - Sender(mpsc::Sender<(T, ErrorContext)>), - /// A wrapper around an [`mpsc::SyncSender`], adding an [`ErrorContext`]. - SyncSender(mpsc::SyncSender<(T, ErrorContext)>), -} +pub type ChannelWithContext<T> = (Sender<(T, ErrorContext)>, Receiver<(T, ErrorContext)>); /// Sends messages on an [MPSC](std::sync::mpsc) channel, along with an [`ErrorContext`], /// synchronously or asynchronously depending on the underlying [`SenderType`]. #[derive(Clone)] -pub struct SenderWithContext<T: Clone> { - sender: SenderType<T>, +pub struct SenderWithContext<T> { + sender: Sender<(T, ErrorContext)>, } impl<T: Clone> SenderWithContext<T> { - pub fn new(sender: SenderType<T>) -> Self { + pub fn new(sender: Sender<(T, ErrorContext)>) -> Self { Self { sender } } /// Sends an event, along with the current [`ErrorContext`], on this /// [`SenderWithContext`]'s channel. - pub fn send(&self, event: T) -> Result<(), mpsc::SendError<(T, ErrorContext)>> { + pub fn send(&self, event: T) -> Result<(), SendError<(T, ErrorContext)>> { let err_ctx = get_current_ctx(); - match self.sender { - SenderType::Sender(ref s) => s.send((event, err_ctx)), - SenderType::SyncSender(ref s) => s.send((event, err_ctx)), - } + self.sender.send((event, err_ctx)) } } -unsafe impl<T: Clone> Send for SenderWithContext<T> {} -unsafe impl<T: Clone> Sync for SenderWithContext<T> {} - thread_local!( /// A key to some thread local storage (TLS) that holds a representation of the thread's call /// stack in the form of an [`ErrorContext`]. |