summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--common/mqtt_client/src/lib.rs31
-rw-r--r--common/mqtt_client/tests/mqtt_pub_sub_test.rs85
2 files changed, 104 insertions, 12 deletions
diff --git a/common/mqtt_client/src/lib.rs b/common/mqtt_client/src/lib.rs
index cd95c086..92fc66af 100644
--- a/common/mqtt_client/src/lib.rs
+++ b/common/mqtt_client/src/lib.rs
@@ -362,10 +362,10 @@ impl MqttClient for Client {
&self,
filter: TopicFilter,
) -> Result<Box<dyn MqttMessageStream>, MqttClientError> {
- let () = self
- .mqtt_client
- .subscribe(&filter.pattern, filter.qos)
- .await?;
+ let qos = filter.qos;
+ for pattern in filter.patterns.iter() {
+ let () = self.mqtt_client.subscribe(pattern, qos).await?;
+ }
Ok(Box::new(MessageStream::new(
filter,
@@ -507,7 +507,7 @@ impl Topic {
/// Build a topic filter filtering only that topic
pub fn filter(&self) -> TopicFilter {
TopicFilter {
- pattern: self.name.clone(),
+ patterns: vec![self.name.clone()],
qos: QoS::AtLeastOnce,
}
}
@@ -516,7 +516,7 @@ impl Topic {
/// An MQTT topic filter
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct TopicFilter {
- pub pattern: String,
+ pub patterns: Vec<String>,
pub qos: QoS,
}
@@ -526,7 +526,20 @@ impl TopicFilter {
let pattern = String::from(pattern);
let qos = QoS::AtLeastOnce;
if rumqttc::valid_filter(&pattern) {
- Ok(TopicFilter { pattern, qos })
+ Ok(TopicFilter {
+ patterns: vec![pattern],
+ qos,
+ })
+ } else {
+ Err(MqttClientError::InvalidFilter { pattern })
+ }
+ }
+
+ /// Check if the pattern is valid and at it to this topic filter.
+ pub fn add(&mut self, pattern: &str) -> Result<(), MqttClientError> {
+ let pattern = String::from(pattern);
+ if rumqttc::valid_filter(&pattern) {
+ Ok(self.patterns.push(pattern))
} else {
Err(MqttClientError::InvalidFilter { pattern })
}
@@ -534,7 +547,9 @@ impl TopicFilter {
/// Check if the given topic matches this filter pattern.
fn accept(&self, topic: &Topic) -> bool {
- rumqttc::matches(&topic.name, &self.pattern)
+ self.patterns
+ .iter()
+ .any(|pattern| rumqttc::matches(&topic.name, &pattern))
}
pub fn qos(self, qos: QoS) -> Self {
diff --git a/common/mqtt_client/tests/mqtt_pub_sub_test.rs b/common/mqtt_client/tests/mqtt_pub_sub_test.rs
index 29d43e25..71beb205 100644
--- a/common/mqtt_client/tests/mqtt_pub_sub_test.rs
+++ b/common/mqtt_client/tests/mqtt_pub_sub_test.rs
@@ -1,13 +1,12 @@
mod rumqttd_broker;
+use mqtt_client::{Client, Message, MqttClient, Topic, TopicFilter};
+use std::time::Duration;
+use tokio::time::sleep;
const MQTTTESTPORT: u16 = 58586;
#[test]
fn sending_and_receiving_a_message() {
- use mqtt_client::{Client, Message, MqttClient, Topic};
- use std::time::Duration;
- use tokio::time::sleep;
-
async fn scenario(payload: String) -> Result<Option<Message>, mqtt_client::MqttClientError> {
let _mqtt_server_handle =
tokio::spawn(async { rumqttd_broker::start_broker_local(MQTTTESTPORT).await });
@@ -40,3 +39,81 @@ fn sending_and_receiving_a_message() {
Err(e) => panic!("Got an error: {}", e),
}
}
+
+#[tokio::test]
+async fn subscribing_to_many_topics() -> Result<(), anyhow::Error> {
+ // Given an MQTT broker
+ let mqtt_port: u16 = 55555;
+ let _mqtt_server_handle =
+ tokio::spawn(async move { rumqttd_broker::start_broker_local(mqtt_port).await });
+
+ // And an MQTT client connected to that server
+ let subscriber = Client::connect(
+ "client_subscribing_to_many_topics",
+ &mqtt_client::Config::default().with_port(mqtt_port),
+ )
+ .await?;
+
+ // The client can subscribe to many topics
+ let mut topic_filter = TopicFilter::new("/a/first/topic")?;
+ topic_filter.add("/a/second/topic")?;
+ topic_filter.add("/a/+/pattern")?; // one can use + pattern
+ topic_filter.add("/any/#")?; // one can use # pattern
+
+ // The messages for these topics will all be received on the same message stream
+ let mut messages = subscriber.subscribe(topic_filter).await?;
+
+ // So let us create another MQTT client publishing messages.
+ let publisher = Client::connect(
+ "client_publishing_to_many_topics",
+ &mqtt_client::Config::default().with_port(mqtt_port),
+ )
+ .await?;
+
+ // A message published on any of the subscribed topics must be received
+ for (topic_name, payload) in vec![
+ ("/a/first/topic", "a first message"),
+ ("/a/second/topic", "a second message"),
+ ("/a/plus/pattern", "a third message"),
+ ("/any/sub/topic", "a fourth message"),
+ ]
+ .into_iter()
+ {
+ let topic = Topic::new(topic_name)?;
+ let message = Message::new(&topic, payload);
+ let () = publisher.publish(message).await?;
+
+ tokio::select! {
+ maybe_msg = messages.next() => {
+ let msg = maybe_msg.expect("Unexpected end of stream");
+ assert_eq!(msg.topic, topic);
+ assert_eq!(msg.payload_str()?, payload);
+ }
+ _ = sleep(Duration::from_millis(1000)) => {
+ assert!(false, "No message received after a second");
+ }
+ }
+ }
+
+ // No message should be received from un-subscribed topics
+ for (topic_name, payload) in vec![
+ ("/a/third/topic", "unrelated message"),
+ ("/unrelated/topic", "unrelated message"),
+ ]
+ .into_iter()
+ {
+ let topic = Topic::new(topic_name)?;
+ let message = Message::new(&topic, payload);
+ let () = publisher.publish(message).await?;
+
+ tokio::select! {
+ _ = messages.next() => {
+ assert!(false, "Unrelated message received");
+ }
+ _ = sleep(Duration::from_millis(1000)) => {
+ }
+ }
+ }
+
+ Ok(())
+}