diff options
-rw-r--r-- | common/mqtt_client/src/lib.rs | 31 | ||||
-rw-r--r-- | common/mqtt_client/tests/mqtt_pub_sub_test.rs | 85 |
2 files changed, 104 insertions, 12 deletions
diff --git a/common/mqtt_client/src/lib.rs b/common/mqtt_client/src/lib.rs index cd95c086..92fc66af 100644 --- a/common/mqtt_client/src/lib.rs +++ b/common/mqtt_client/src/lib.rs @@ -362,10 +362,10 @@ impl MqttClient for Client { &self, filter: TopicFilter, ) -> Result<Box<dyn MqttMessageStream>, MqttClientError> { - let () = self - .mqtt_client - .subscribe(&filter.pattern, filter.qos) - .await?; + let qos = filter.qos; + for pattern in filter.patterns.iter() { + let () = self.mqtt_client.subscribe(pattern, qos).await?; + } Ok(Box::new(MessageStream::new( filter, @@ -507,7 +507,7 @@ impl Topic { /// Build a topic filter filtering only that topic pub fn filter(&self) -> TopicFilter { TopicFilter { - pattern: self.name.clone(), + patterns: vec![self.name.clone()], qos: QoS::AtLeastOnce, } } @@ -516,7 +516,7 @@ impl Topic { /// An MQTT topic filter #[derive(Debug, Clone, Eq, PartialEq)] pub struct TopicFilter { - pub pattern: String, + pub patterns: Vec<String>, pub qos: QoS, } @@ -526,7 +526,20 @@ impl TopicFilter { let pattern = String::from(pattern); let qos = QoS::AtLeastOnce; if rumqttc::valid_filter(&pattern) { - Ok(TopicFilter { pattern, qos }) + Ok(TopicFilter { + patterns: vec![pattern], + qos, + }) + } else { + Err(MqttClientError::InvalidFilter { pattern }) + } + } + + /// Check if the pattern is valid and at it to this topic filter. + pub fn add(&mut self, pattern: &str) -> Result<(), MqttClientError> { + let pattern = String::from(pattern); + if rumqttc::valid_filter(&pattern) { + Ok(self.patterns.push(pattern)) } else { Err(MqttClientError::InvalidFilter { pattern }) } @@ -534,7 +547,9 @@ impl TopicFilter { /// Check if the given topic matches this filter pattern. fn accept(&self, topic: &Topic) -> bool { - rumqttc::matches(&topic.name, &self.pattern) + self.patterns + .iter() + .any(|pattern| rumqttc::matches(&topic.name, &pattern)) } pub fn qos(self, qos: QoS) -> Self { diff --git a/common/mqtt_client/tests/mqtt_pub_sub_test.rs b/common/mqtt_client/tests/mqtt_pub_sub_test.rs index 29d43e25..71beb205 100644 --- a/common/mqtt_client/tests/mqtt_pub_sub_test.rs +++ b/common/mqtt_client/tests/mqtt_pub_sub_test.rs @@ -1,13 +1,12 @@ mod rumqttd_broker; +use mqtt_client::{Client, Message, MqttClient, Topic, TopicFilter}; +use std::time::Duration; +use tokio::time::sleep; const MQTTTESTPORT: u16 = 58586; #[test] fn sending_and_receiving_a_message() { - use mqtt_client::{Client, Message, MqttClient, Topic}; - use std::time::Duration; - use tokio::time::sleep; - async fn scenario(payload: String) -> Result<Option<Message>, mqtt_client::MqttClientError> { let _mqtt_server_handle = tokio::spawn(async { rumqttd_broker::start_broker_local(MQTTTESTPORT).await }); @@ -40,3 +39,81 @@ fn sending_and_receiving_a_message() { Err(e) => panic!("Got an error: {}", e), } } + +#[tokio::test] +async fn subscribing_to_many_topics() -> Result<(), anyhow::Error> { + // Given an MQTT broker + let mqtt_port: u16 = 55555; + let _mqtt_server_handle = + tokio::spawn(async move { rumqttd_broker::start_broker_local(mqtt_port).await }); + + // And an MQTT client connected to that server + let subscriber = Client::connect( + "client_subscribing_to_many_topics", + &mqtt_client::Config::default().with_port(mqtt_port), + ) + .await?; + + // The client can subscribe to many topics + let mut topic_filter = TopicFilter::new("/a/first/topic")?; + topic_filter.add("/a/second/topic")?; + topic_filter.add("/a/+/pattern")?; // one can use + pattern + topic_filter.add("/any/#")?; // one can use # pattern + + // The messages for these topics will all be received on the same message stream + let mut messages = subscriber.subscribe(topic_filter).await?; + + // So let us create another MQTT client publishing messages. + let publisher = Client::connect( + "client_publishing_to_many_topics", + &mqtt_client::Config::default().with_port(mqtt_port), + ) + .await?; + + // A message published on any of the subscribed topics must be received + for (topic_name, payload) in vec![ + ("/a/first/topic", "a first message"), + ("/a/second/topic", "a second message"), + ("/a/plus/pattern", "a third message"), + ("/any/sub/topic", "a fourth message"), + ] + .into_iter() + { + let topic = Topic::new(topic_name)?; + let message = Message::new(&topic, payload); + let () = publisher.publish(message).await?; + + tokio::select! { + maybe_msg = messages.next() => { + let msg = maybe_msg.expect("Unexpected end of stream"); + assert_eq!(msg.topic, topic); + assert_eq!(msg.payload_str()?, payload); + } + _ = sleep(Duration::from_millis(1000)) => { + assert!(false, "No message received after a second"); + } + } + } + + // No message should be received from un-subscribed topics + for (topic_name, payload) in vec![ + ("/a/third/topic", "unrelated message"), + ("/unrelated/topic", "unrelated message"), + ] + .into_iter() + { + let topic = Topic::new(topic_name)?; + let message = Message::new(&topic, payload); + let () = publisher.publish(message).await?; + + tokio::select! { + _ = messages.next() => { + assert!(false, "Unrelated message received"); + } + _ = sleep(Duration::from_millis(1000)) => { + } + } + } + + Ok(()) +} |