summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorbors[bot] <26634292+bors[bot]@users.noreply.github.com>2023-01-05 10:03:17 +0000
committerGitHub <noreply@github.com>2023-01-05 10:03:17 +0000
commita1b312fb0e2c0e7a8c8596674cbf70a5cf8a6e8d (patch)
tree945d4cc900638fc9c67357b5b6905ea776a72c27
parentdd49fdcc40e1d1b99ea69b2cc3866b357b0ece54 (diff)
parentf4d9d50e1e9cd22056af27f5108e3b6cebda8578 (diff)
Merge #119
119: Feature/add manual subscriptions r=TheNeikos a=TheNeikos Co-authored-by: Marcel Müller <neikos@neikos.email>
-rw-r--r--src/bin/cloudmqtt-server.rs11
-rw-r--r--src/server/mod.rs67
-rw-r--r--src/server/subscriptions.rs4
3 files changed, 78 insertions, 4 deletions
diff --git a/src/bin/cloudmqtt-server.rs b/src/bin/cloudmqtt-server.rs
index 37ad878..cfa1acc 100644
--- a/src/bin/cloudmqtt-server.rs
+++ b/src/bin/cloudmqtt-server.rs
@@ -72,5 +72,14 @@ async fn main() {
.with_login_handler(SimpleLoginHandler)
.with_subscription_handler(SimpleSubscriptionHandler);
- Arc::new(server).accept_new_clients().await.unwrap();
+ let server = Arc::new(server);
+
+ tokio::spawn(server.clone().subscribe_to_message(
+ vec![String::from("foo/bar"), String::from("bar/#")],
+ |msg| async move {
+ info!("Got message: {msg:?}");
+ },
+ ));
+
+ server.accept_new_clients().await.unwrap();
}
diff --git a/src/server/mod.rs b/src/server/mod.rs
index 6b482a9..17f1f97 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -52,9 +52,10 @@ use mqtt_format::v3::{
use tokio::{
io::{AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf},
net::{TcpListener, ToSocketAddrs},
+ sync::broadcast::Sender as BroadcastSender,
sync::Mutex,
};
-use tracing::{debug, error, info, trace};
+use tracing::{debug, error, info, trace, warn};
use crate::{error::MqttError, mqtt_stream::MqttStream, PacketIOError};
use subscriptions::{ClientInformation, SubscriptionManager};
@@ -65,6 +66,7 @@ use self::{
},
message::MqttMessage,
state::ClientState,
+ subscriptions::TopicFilter,
};
/// The unique id (per server) of a connecting client
@@ -143,6 +145,7 @@ pub struct MqttServer<LoginH, SubH> {
clients: Arc<DashMap<ClientId, ClientState>>,
client_source: Mutex<ClientSource>,
auth_handler: LoginH,
+ extra_listeners: BroadcastSender<MqttMessage>,
subscription_manager: Arc<SubscriptionManager<SubH>>,
}
@@ -153,10 +156,13 @@ impl MqttServer<AllowAllLogins, AllowAllSubscriptions> {
) -> Result<Self, MqttError> {
let bind = TcpListener::bind(addr).await?;
+ let (extra_listeners, _) = tokio::sync::broadcast::channel(50);
+
Ok(MqttServer {
clients: Arc::new(DashMap::new()),
client_source: Mutex::new(ClientSource::UnsecuredTcp(bind)),
auth_handler: AllowAllLogins,
+ extra_listeners,
subscription_manager: Arc::new(SubscriptionManager::new()),
})
}
@@ -176,6 +182,7 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
clients: self.clients,
client_source: self.client_source,
auth_handler: new_login_handler,
+ extra_listeners: self.extra_listeners,
subscription_manager: self.subscription_manager,
}
}
@@ -193,6 +200,7 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
clients: self.clients,
client_source: self.client_source,
auth_handler: self.auth_handler,
+ extra_listeners: self.extra_listeners,
subscription_manager: Arc::new({
let manager = Arc::try_unwrap(self.subscription_manager);
@@ -221,6 +229,61 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
}
}
+ /// Listen to messages sent to the given topic_paths
+ pub async fn subscribe_to_message<
+ Fut: std::future::Future<Output = ()>,
+ CB: FnMut(MqttMessage) -> Fut,
+ >(
+ self: Arc<Self>,
+ topic_paths: Vec<String>,
+ mut callback: CB,
+ ) -> Result<(), MqttError> {
+ let mut listener = self.extra_listeners.subscribe();
+
+ let topics = topic_paths
+ .into_iter()
+ .map(TopicFilter::parse_from)
+ .collect::<Vec<_>>();
+
+ loop {
+ let message = listener.recv().await;
+
+ match message {
+ Ok(message) => {
+ if topics.iter().any(|topic| {
+ let msg_topic = TopicFilter::parse_from(message.topic().to_string());
+
+ let mut i = 0;
+ loop {
+ match (topic.get(i), msg_topic.get(i)) {
+ (None, None) => break true,
+ (None, Some(_)) => break false,
+ (Some(_), None) => break false,
+ (Some(TopicFilter::MultiWildcard), Some(_)) => break true,
+ (Some(TopicFilter::SingleWildcard), Some(_)) => (),
+ (Some(left), Some(right)) => {
+ if left != right {
+ break false;
+ }
+ }
+ }
+
+ i += 1;
+ }
+ }) {
+ callback(message).await;
+ }
+ }
+ Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
+ Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
+ warn!("Subscriber lagged by {count} values")
+ }
+ }
+ }
+
+ Ok(())
+ }
+
/// Accept a new client connected through the `client` stream
///
/// This does multiple things:
@@ -357,6 +420,7 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
let subscription_manager = server.subscription_manager.clone();
let client_id = client_id.clone();
let clients = server.clients.clone();
+ let extra_listener = server.extra_listeners.clone();
tokio::spawn(async move {
let client_id = client_id;
@@ -399,6 +463,7 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
*qos,
);
+ let _ = extra_listener.send(message.clone());
subscription_manager.route_message(message).await;
// Handle QoS 1/AtLeastOnce response
diff --git a/src/server/subscriptions.rs b/src/server/subscriptions.rs
index ebc3c4c..a45be93 100644
--- a/src/server/subscriptions.rs
+++ b/src/server/subscriptions.rs
@@ -78,14 +78,14 @@ impl TopicName {
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-enum TopicFilter {
+pub enum TopicFilter {
MultiWildcard,
SingleWildcard,
Named(String),
}
impl TopicFilter {
- fn parse_from(topic: String) -> VecDeque<TopicFilter> {
+ pub fn parse_from(topic: String) -> VecDeque<TopicFilter> {
topic
.split('/')
.map(|piece| match piece {