summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorkxt <ktamas@fastmail.fm>2021-05-27 16:28:28 +0200
committerGitHub <noreply@github.com>2021-05-27 16:28:28 +0200
commit0c0355dbc6e0159a72b0f55c7aabb83d76c2312a (patch)
treeeb09715c31950f918efe54bd189af65b5b2fdab0 /src
parent9bdb40b4c644c6a3a061dd0cc4683fc92d504201 (diff)
refactors for #525 (#534)
* refactor(fakes): clean up add_terminal_input * refactor(fakes): append whole buf to output_buffer in FakeStdoutWriter::write * refactor(fakes): append whole buf to output_buffer in FakeInputOutput::write_to_tty_stdin * fix(fakes): allow partial reads in read_from_tty_stdout This patch fixes two bugs in read_from_tty_stdout: * if there was a partial read (ie. `bytes.read_position` is not 0 but less than `bytes.content.len()`), subsequent calls to would fill `buf` starting at index `bytes.read_position` instead of 0, leaving range 0..`bytes.read_position` untouched. * if `buf` was smaller than `bytes.content.len()`, a panic would occur. * refactor(channels): use crossbeam instead of mpsc This patch replaces mpsc with crossbeam channels because crossbeam supports selecting on multiple channels which will be necessary in a subsequent patch. * refactor(threadbus): allow multiple receivers in Bus This patch changes Bus to use multiple receivers. Method `recv` returns data from all of them. This will be used in a subsequent patch for receiving from bounded and unbounded queues at the same time. * refactor(channels): remove SenderType enum This enum has only one variant, so the entire enum can be replaced with the innards of said variant. * refactor(channels): remove Send+Sync trait implementations The implementation of these traits is not necessary, as SenderWithContext is automatically Send and Sync for every T and ErrorContext that's Send and Sync.
Diffstat (limited to 'src')
-rw-r--r--src/tests/fakes.rs53
1 files changed, 20 insertions, 33 deletions
diff --git a/src/tests/fakes.rs b/src/tests/fakes.rs
index 23f9f1824..8d11458b9 100644
--- a/src/tests/fakes.rs
+++ b/src/tests/fakes.rs
@@ -2,7 +2,7 @@ use std::collections::{HashMap, VecDeque};
use std::io::Write;
use std::os::unix::io::RawFd;
use std::path::PathBuf;
-use std::sync::{mpsc, Arc, Condvar, Mutex};
+use std::sync::{Arc, Condvar, Mutex};
use std::time::{Duration, Instant};
use zellij_utils::{nix, zellij_tile};
@@ -14,7 +14,7 @@ use zellij_server::os_input_output::{async_trait, AsyncReader, Pid, ServerOsApi}
use zellij_tile::data::Palette;
use zellij_utils::{
async_std,
- channels::{ChannelWithContext, SenderType, SenderWithContext},
+ channels::{self, ChannelWithContext, SenderWithContext},
errors::ErrorContext,
interprocess::local_socket::LocalSocketStream,
ipc::{ClientToServerMsg, ServerToClientMsg},
@@ -52,13 +52,9 @@ impl FakeStdoutWriter {
impl Write for FakeStdoutWriter {
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
- let mut bytes_written = 0;
let mut output_buffer = self.output_buffer.lock().unwrap();
- for byte in buf {
- bytes_written += 1;
- output_buffer.push(*byte);
- }
- Ok(bytes_written)
+ output_buffer.extend_from_slice(buf);
+ Ok(buf.len())
}
fn flush(&mut self) -> Result<(), std::io::Error> {
let mut output_buffer = self.output_buffer.lock().unwrap();
@@ -83,9 +79,11 @@ pub struct FakeInputOutput {
possible_tty_inputs: HashMap<u16, Bytes>,
last_snapshot_time: Arc<Mutex<Instant>>,
send_instructions_to_client: SenderWithContext<ServerToClientMsg>,
- receive_instructions_from_server: Arc<Mutex<mpsc::Receiver<(ServerToClientMsg, ErrorContext)>>>,
+ receive_instructions_from_server:
+ Arc<Mutex<channels::Receiver<(ServerToClientMsg, ErrorContext)>>>,
send_instructions_to_server: SenderWithContext<ClientToServerMsg>,
- receive_instructions_from_client: Arc<Mutex<mpsc::Receiver<(ClientToServerMsg, ErrorContext)>>>,
+ receive_instructions_from_client:
+ Arc<Mutex<channels::Receiver<(ClientToServerMsg, ErrorContext)>>>,
should_trigger_sigwinch: Arc<(Mutex<bool>, Condvar)>,
sigwinch_event: Option<PositionAndSize>,
}
@@ -96,11 +94,11 @@ impl FakeInputOutput {
let last_snapshot_time = Arc::new(Mutex::new(Instant::now()));
let stdout_writer = FakeStdoutWriter::new(last_snapshot_time.clone());
let (client_sender, client_receiver): ChannelWithContext<ServerToClientMsg> =
- mpsc::channel();
- let send_instructions_to_client = SenderWithContext::new(SenderType::Sender(client_sender));
+ channels::unbounded();
+ let send_instructions_to_client = SenderWithContext::new(client_sender);
let (server_sender, server_receiver): ChannelWithContext<ClientToServerMsg> =
- mpsc::channel();
- let send_instructions_to_server = SenderWithContext::new(SenderType::Sender(server_sender));
+ channels::unbounded();
+ let send_instructions_to_server = SenderWithContext::new(server_sender);
win_sizes.insert(0, winsize); // 0 is the current terminal
FakeInputOutput {
read_buffers: Arc::new(Mutex::new(HashMap::new())),
@@ -125,10 +123,7 @@ impl FakeInputOutput {
self
}
pub fn add_terminal_input(&mut self, input: &[&[u8]]) {
- let mut stdin_commands: VecDeque<Vec<u8>> = VecDeque::new();
- for command in input.iter() {
- stdin_commands.push_back(command.iter().copied().collect())
- }
+ let stdin_commands = input.iter().map(|i| i.to_vec()).collect();
self.stdin_commands = Arc::new(Mutex::new(stdin_commands));
}
pub fn add_terminal(&self, fd: RawFd) {
@@ -281,26 +276,18 @@ impl ServerOsApi for FakeInputOutput {
fn write_to_tty_stdin(&self, pid: RawFd, buf: &[u8]) -> Result<usize, nix::Error> {
let mut stdin_writes = self.stdin_writes.lock().unwrap();
let write_buffer = stdin_writes.get_mut(&pid).unwrap();
- let mut bytes_written = 0;
- for byte in buf {
- bytes_written += 1;
- write_buffer.push(*byte);
- }
- Ok(bytes_written)
+ Ok(write_buffer.write(buf).unwrap())
}
- fn read_from_tty_stdout(&self, pid: RawFd, buf: &mut [u8]) -> Result<usize, nix::Error> {
+ fn read_from_tty_stdout(&self, pid: RawFd, mut buf: &mut [u8]) -> Result<usize, nix::Error> {
let mut read_buffers = self.read_buffers.lock().unwrap();
- let mut bytes_read = 0;
match read_buffers.get_mut(&pid) {
Some(bytes) => {
- for i in bytes.read_position..bytes.content.len() {
- bytes_read += 1;
- buf[i] = bytes.content[i];
- }
- if bytes_read > bytes.read_position {
- bytes.set_read_position(bytes_read);
+ let available_range = bytes.read_position..bytes.content.len();
+ let len = buf.write(&bytes.content[available_range]).unwrap();
+ if len > bytes.read_position {
+ bytes.set_read_position(len);
}
- return Ok(bytes_read);
+ return Ok(len);
}
None => Err(nix::Error::Sys(nix::errno::Errno::EAGAIN)),
}