summaryrefslogtreecommitdiffstats
path: root/mqtt-tester/src/behaviour/wait_for_connect.rs
blob: 8563f22eca59d68011cfbd5512f478492054a824 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
//
//   This Source Code Form is subject to the terms of the Mozilla Public
//   License, v. 2.0. If a copy of the MPL was not distributed with this
//   file, You can obtain one at http://mozilla.org/MPL/2.0/.
//

use miette::Context;

use crate::{
    behaviour_test::BehaviourTest,
    command::{Input, Output},
    executable::ClientExecutableCommand,
    report::ReportResult,
};

/// Behaviour test that waits for the client under test to send its CONNECT
/// packet and inspects the packet's connect flags.
pub struct WaitForConnect;

#[async_trait::async_trait]
impl BehaviourTest for WaitForConnect {
    fn commands(&self) -> Vec<Box<dyn ClientExecutableCommand>> {
        vec![]
    }

    /// Waits for bytes from the client and checks them with
    /// [`find_connect_flags`] and [`connect_flags_are_consistent`].
    ///
    /// The check fails if the bytes do not start with a CONNECT packet that
    /// can be parsed far enough to reach the connect flags, or if the flags
    /// carry a Password flag without a User Name flag.
    ///
    /// # Errors
    ///
    /// Propagates (with context) any error from waiting on the client output.
    #[tracing::instrument(skip_all)]
    async fn execute(
        &self,
        _input: Input,
        mut output: Output,
    ) -> Result<ReportResult, miette::Error> {
        let check_result = output
            .wait_and_check(
                &(|bytes: &[u8]| -> bool {
                    match find_connect_flags(bytes) {
                        Some(connect_flags) => connect_flags_are_consistent(connect_flags),
                        None => false,
                    }
                }),
            )
            .await
            .context("Waiting for bytes to check")?;

        tracing::trace!(?check_result, "result of check");
        Ok(check_result)
    }

    fn report_name(&self) -> &str {
        "Wait for client to connect"
    }

    fn report_desc(&self) -> &str {
        "A client should send a CONNECT packet to connect to the server"
    }

    fn report_normative(&self) -> &str {
        "none"
    }

    fn translate_client_exit_code(&self, success: bool) -> ReportResult {
        if success {
            ReportResult::Success
        } else {
            ReportResult::Failure
        }
    }
}

/// First byte of a CONNECT packet: packet type 1 with all fixed-header flags zero.
const CONNECT_PACKET_TYPE: u8 = 0b0001_0000;
/// Bit mask of the User Name flag within the CONNECT flags byte.
const USERNAME_FLAG: u8 = 0b1000_0000;
/// Bit mask of the Password flag within the CONNECT flags byte.
const PASSWORD_FLAG: u8 = 0b0100_0000;

/// Returns `true` if the username/password bits of a CONNECT flags byte form
/// an allowed combination.
///
/// MQTT 3.1.1 [MQTT-3.1.2-22]: if the User Name flag is 0, the Password flag
/// must be 0. Username with password, username only, and neither are all valid.
fn connect_flags_are_consistent(connect_flags: u8) -> bool {
    tracing::trace!(?connect_flags, "Connect flags");

    let username_flag_set = (connect_flags & USERNAME_FLAG) != 0;
    let password_flag_set = (connect_flags & PASSWORD_FLAG) != 0;
    tracing::trace!(?username_flag_set, "username flag");
    tracing::trace!(?password_flag_set, "password flag");

    // The only forbidden combination is a password without a username.
    username_flag_set || !password_flag_set
}

/// Extracts the CONNECT flags byte from a raw packet.
///
/// Expected layout: packet-type byte (`0x10`), Remaining Length (variable
/// byte integer, 1-4 bytes), protocol-name length (u16 big-endian), protocol
/// name, protocol level (1 byte), then the connect flags byte.
///
/// Returns `None` if `bytes` does not start with a CONNECT packet type, the
/// Remaining Length is malformed, or `bytes` is too short to reach the flags.
fn find_connect_flags(bytes: &[u8]) -> Option<u8> {
    if bytes.first().copied()? != CONNECT_PACKET_TYPE {
        tracing::trace!("Not a CONNECT packet");
        return None;
    }

    // The fixed header is the type byte plus the variable-length Remaining Length.
    let variable_header_start = 1 + remaining_length_field_len(bytes.get(1..)?)?;

    let name_len_msb = bytes.get(variable_header_start).copied()?;
    let name_len_lsb = bytes.get(variable_header_start + 1).copied()?;
    let str_len = u16::from_be_bytes([name_len_msb, name_len_lsb]);
    tracing::trace!(?str_len, "Length of protocol name");

    // Skip the 2-byte length prefix, the protocol name and the 1-byte protocol level.
    let connect_flag_position = variable_header_start + 2 + usize::from(str_len) + 1;
    tracing::trace!(?connect_flag_position, "Position of CONNECT flags");

    let flags = bytes.get(connect_flag_position).copied()?;
    tracing::trace!(?flags, "CONNECT flags");

    Some(flags)
}

/// Returns how many bytes the Remaining Length field at the start of `bytes`
/// occupies.
///
/// Returns `None` if the field is truncated or does not terminate within the
/// 4 bytes the variable byte integer encoding allows.
fn remaining_length_field_len(bytes: &[u8]) -> Option<usize> {
    // Each byte carries 7 value bits; a clear high bit marks the last byte.
    bytes
        .iter()
        .take(4)
        .position(|&b| b & 0x80 == 0)
        .map(|idx| idx + 1)
}