//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
use bytes::{BufMut, BytesMut};
use miette::IntoDiagnostic;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Wrapper around a [`tokio::process::Command`] that spawns the process and
/// drives it through a scripted exchange over its stdin and stdout.
#[derive(Debug)]
pub struct Command {
    /// The configured, not yet spawned, process.
    inner: tokio::process::Command,
}
/// Predicate applied to bytes received from the client.
///
/// It is a boxed `FnOnce`, so it is consumed by its single invocation; it
/// returns `true` when the bytes are acceptable.
pub type CheckBytesFn = Box<dyn FnOnce(&[u8]) -> bool>;
/// One step of the scripted exchange executed by `Command::wait_for_write`.
pub enum ClientCommand {
/// Write these bytes to the client's stdin.
Send(Vec<u8>),
/// Read exactly as many bytes from the client's stdout as this vector holds
/// and fail the run unless they are equal to it.
// NOTE(review): presumably allowed because not every user of this module
// constructs this variant — confirm.
#[allow(unused)]
WaitFor(Vec<u8>),
/// Read one length-prefixed packet from the client's stdout and hand its
/// bytes to the closure; the run fails when the closure returns `false`.
// NOTE(review): presumably allowed because not every user of this module
// constructs this variant — confirm.
#[allow(unused)]
WaitAndCheck(CheckBytesFn),
}
impl Command {
pub fn new(inner: tokio::process::Command) -> Self {
Self { inner }
}
pub async fn wait_for_write<C>(
mut self,
commands: C,
) -> Result<std::process::Output, miette::Error>
where
C: IntoIterator<Item = ClientCommand>,
{
let mut client = self.inner.spawn().into_diagnostic()?;
let mut to_client = client.stdin.take().unwrap();
let mut from_client = client.stdout.take().unwrap();
for command in commands {
match command {
ClientCommand::Send(bytes) => {
to_client.write_all(&bytes).await.into_diagnostic()?
}
ClientCommand::WaitFor(expected_bytes) => {
let mut buf = vec![0; expected_bytes.len()];
match tokio::time::timeout(
std::time::Duration::from_millis(100),
from_client.read_exact(&mut buf),
)
.await
{
Ok(Ok(_)) => {
if buf != expected_bytes {
return Err(miette::miette!(
"Received Bytes did not match expected bytes: {:?} != {:?}",
buf,
expected_bytes
));
}
}
Ok(Err(e)) => return Err(e).into_diagnostic(),
Err(_elapsed) => {
return Err(miette::miette!("Did not hear from server until timeout"))
}
}
}
ClientCommand::WaitAndCheck(check) => {
match tokio::time::timeout(std::time::Duration::from_millis(100), async {
let mut buffer = BytesMut::new();
buffer.put_u16(from_client.read_u16().await.into_diagnostic()?);
buffer.put_u8(from_client.read_u8().await.into_diagnostic()?);
if buffer[1] & 0b1000_0000 != 0 {
buffer.put_u8(from_client.read_u8().await.into_diagnostic()?);
if buffer[2] & 0b1000_0000 != 0 {
buffer.put_u8(from_client.read_u8().await.into_diagnostic()?);
if buffer[3] & 0b1000_0000 != 0 {
buffer.put_u8(from_client.read_u8().await.into_diagnostic()?);
}
}
}
let rest_len = buffer[1..].iter().enumerate().fold(0, |val, (exp, len)| {
val + (*len as u32 & 0b0111_1111) * 128u32.pow(exp as u32)
});
let mut rest_buf = buffer.limit(rest_len as usize);
from_client
.read_buf(&mut rest_buf)
.await
.into_diagnostic()?;
Ok::<_, miette::Error>(rest_buf.into_inner())
})
.await
{
Ok(Ok(buffer)) => {
if !check(&buffer) {
return Err(miette::miette!("Check failed for Bytes {:?}", buffer));
}
}
Ok(Err(e)) => return Err(e),
Err(_elapsed) => {
return Err(miette::miette!("Did not hear from server until timeout"))
}
}
}
}
}
client.wait_with_output().await.into_diagnostic()
}
}