summaryrefslogtreecommitdiffstats
path: root/mqtt-tester/src/command.rs
blob: 7e122611779e563bc7adc2b059984a7f15510095 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
//
//   This Source Code Form is subject to the terms of the Mozilla Public
//   License, v. 2.0. If a copy of the MPL was not distributed with this
//   file, You can obtain one at http://mozilla.org/MPL/2.0/.
//

use bytes::{BufMut, BytesMut};
use miette::IntoDiagnostic;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

/// A child process that is driven through its standard streams.
///
/// Wraps a [`tokio::process::Command`]; the process is only spawned once
/// [`Command::wait_for_write`] is called.
pub struct Command {
    // Process builder handed in through `Command::new` and spawned by
    // `wait_for_write`.
    inner: tokio::process::Command,
}

/// Boxed one-shot predicate over received bytes.
///
/// Used by [`ClientCommand::WaitAndCheck`]: returning `false` makes the step
/// (and with it the whole command run) fail.
pub type CheckBytesFn = Box<dyn FnOnce(&[u8]) -> bool>;

/// One step of the scripted conversation that [`Command::wait_for_write`]
/// runs against the spawned process.
pub enum ClientCommand {
    /// Write these bytes to the process' stdin.
    Send(Vec<u8>),
    /// Read exactly as many bytes from the process' stdout as given here and
    /// fail if they differ from the given bytes.
    // NOTE(review): the `allow` suggests this variant is not constructed
    // anywhere yet — confirm before removing it.
    #[allow(unused)]
    WaitFor(Vec<u8>),
    /// Read bytes from the process' stdout and hand them to the closure; the
    /// step fails when the closure returns `false`.
    // NOTE(review): the `allow` suggests this variant is not constructed
    // anywhere yet — confirm before removing it.
    #[allow(unused)]
    WaitAndCheck(CheckBytesFn),
}

impl Command {
    pub fn new(inner: tokio::process::Command) -> Self {
        Self { inner }
    }

    pub async fn wait_for_write<C>(
        mut self,
        commands: C,
    ) -> Result<std::process::Output, miette::Error>
    where
        C: IntoIterator<Item = ClientCommand>,
    {
        let mut client = self.inner.spawn().into_diagnostic()?;

        let mut to_client = client.stdin.take().unwrap();
        let mut from_client = client.stdout.take().unwrap();

        for command in commands {
            match command {
                ClientCommand::Send(bytes) => {
                    to_client.write_all(&bytes).await.into_diagnostic()?
                }
                ClientCommand::WaitFor(expected_bytes) => {
                    let mut buf = vec![0; expected_bytes.len()];
                    match tokio::time::timeout(
                        std::time::Duration::from_millis(100),
                        from_client.read_exact(&mut buf),
                    )
                    .await
                    {
                        Ok(Ok(_)) => {
                            if buf != expected_bytes {
                                return Err(miette::miette!(
                                    "Received Bytes did not match expected bytes: {:?} != {:?}",
                                    buf,
                                    expected_bytes
                                ));
                            }
                        }
                        Ok(Err(e)) => return Err(e).into_diagnostic(),
                        Err(_elapsed) => {
                            return Err(miette::miette!("Did not hear from server until timeout"))
                        }
                    }
                }
                ClientCommand::WaitAndCheck(check) => {
                    match tokio::time::timeout(std::time::Duration::from_millis(100), async {
                        let mut buffer = BytesMut::new();
                        buffer.put_u16(from_client.read_u16().await.into_diagnostic()?);
                        buffer.put_u8(from_client.read_u8().await.into_diagnostic()?);

                        if buffer[1] & 0b1000_0000 != 0 {
                            buffer.put_u8(from_client.read_u8().await.into_diagnostic()?);
                            if buffer[2] & 0b1000_0000 != 0 {
                                buffer.put_u8(from_client.read_u8().await.into_diagnostic()?);
                                if buffer[3] & 0b1000_0000 != 0 {
                                    buffer.put_u8(from_client.read_u8().await.into_diagnostic()?);
                                }
                            }
                        }

                        let rest_len = buffer[1..].iter().enumerate().fold(0, |val, (exp, len)| {
                            val + (*len as u32 & 0b0111_1111) * 128u32.pow(exp as u32)
                        });

                        let mut rest_buf = buffer.limit(rest_len as usize);
                        from_client
                            .read_buf(&mut rest_buf)
                            .await
                            .into_diagnostic()?;
                        Ok::<_, miette::Error>(rest_buf.into_inner())
                    })
                    .await
                    {
                        Ok(Ok(buffer)) => {
                            if !check(&buffer) {
                                return Err(miette::miette!("Check failed for Bytes {:?}", buffer));
                            }
                        }
                        Ok(Err(e)) => return Err(e),
                        Err(_elapsed) => {
                            return Err(miette::miette!("Did not hear from server until timeout"))
                        }
                    }
                }
            }
        }

        client.wait_with_output().await.into_diagnostic()
    }
}