diff options
author | bors[bot] <26634292+bors[bot]@users.noreply.github.com> | 2023-01-05 10:03:17 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-05 10:03:17 +0000 |
commit | a1b312fb0e2c0e7a8c8596674cbf70a5cf8a6e8d (patch) | |
tree | 945d4cc900638fc9c67357b5b6905ea776a72c27 | |
parent | dd49fdcc40e1d1b99ea69b2cc3866b357b0ece54 (diff) | |
parent | f4d9d50e1e9cd22056af27f5108e3b6cebda8578 (diff) |
Merge #119
119: Feature/add manual subscriptions r=TheNeikos a=TheNeikos
Co-authored-by: Marcel Müller <neikos@neikos.email>
-rw-r--r-- | src/bin/cloudmqtt-server.rs | 11 | ||||
-rw-r--r-- | src/server/mod.rs | 67 | ||||
-rw-r--r-- | src/server/subscriptions.rs | 4 |
3 files changed, 78 insertions, 4 deletions
diff --git a/src/bin/cloudmqtt-server.rs b/src/bin/cloudmqtt-server.rs index 37ad878..cfa1acc 100644 --- a/src/bin/cloudmqtt-server.rs +++ b/src/bin/cloudmqtt-server.rs @@ -72,5 +72,14 @@ async fn main() { .with_login_handler(SimpleLoginHandler) .with_subscription_handler(SimpleSubscriptionHandler); - Arc::new(server).accept_new_clients().await.unwrap(); + let server = Arc::new(server); + + tokio::spawn(server.clone().subscribe_to_message( + vec![String::from("foo/bar"), String::from("bar/#")], + |msg| async move { + info!("Got message: {msg:?}"); + }, + )); + + server.accept_new_clients().await.unwrap(); } diff --git a/src/server/mod.rs b/src/server/mod.rs index 6b482a9..17f1f97 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -52,9 +52,10 @@ use mqtt_format::v3::{ use tokio::{ io::{AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf}, net::{TcpListener, ToSocketAddrs}, + sync::broadcast::Sender as BroadcastSender, sync::Mutex, }; -use tracing::{debug, error, info, trace}; +use tracing::{debug, error, info, trace, warn}; use crate::{error::MqttError, mqtt_stream::MqttStream, PacketIOError}; use subscriptions::{ClientInformation, SubscriptionManager}; @@ -65,6 +66,7 @@ use self::{ }, message::MqttMessage, state::ClientState, + subscriptions::TopicFilter, }; /// The unique id (per server) of a connecting client @@ -143,6 +145,7 @@ pub struct MqttServer<LoginH, SubH> { clients: Arc<DashMap<ClientId, ClientState>>, client_source: Mutex<ClientSource>, auth_handler: LoginH, + extra_listeners: BroadcastSender<MqttMessage>, subscription_manager: Arc<SubscriptionManager<SubH>>, } @@ -153,10 +156,13 @@ impl MqttServer<AllowAllLogins, AllowAllSubscriptions> { ) -> Result<Self, MqttError> { let bind = TcpListener::bind(addr).await?; + let (extra_listeners, _) = tokio::sync::broadcast::channel(50); + Ok(MqttServer { clients: Arc::new(DashMap::new()), client_source: Mutex::new(ClientSource::UnsecuredTcp(bind)), auth_handler: AllowAllLogins, + extra_listeners, subscription_manager: Arc::new(SubscriptionManager::new()), }) } @@ -176,6 +182,7 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { clients: self.clients, client_source: self.client_source, auth_handler: new_login_handler, + extra_listeners: self.extra_listeners, subscription_manager: self.subscription_manager, } } @@ -193,6 +200,7 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { clients: self.clients, client_source: self.client_source, auth_handler: self.auth_handler, + extra_listeners: self.extra_listeners, subscription_manager: Arc::new({ let manager = Arc::try_unwrap(self.subscription_manager); @@ -221,6 +229,61 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { } } + /// Listen to messages sent to the given topic_paths + pub async fn subscribe_to_message< + Fut: std::future::Future<Output = ()>, + CB: FnMut(MqttMessage) -> Fut, + >( + self: Arc<Self>, + topic_paths: Vec<String>, + mut callback: CB, + ) -> Result<(), MqttError> { + let mut listener = self.extra_listeners.subscribe(); + + let topics = topic_paths + .into_iter() + .map(TopicFilter::parse_from) + .collect::<Vec<_>>(); + + loop { + let message = listener.recv().await; + + match message { + Ok(message) => { + if topics.iter().any(|topic| { + let msg_topic = TopicFilter::parse_from(message.topic().to_string()); + + let mut i = 0; + loop { + match (topic.get(i), msg_topic.get(i)) { + (None, None) => break true, + (None, Some(_)) => break false, + (Some(_), None) => break false, + (Some(TopicFilter::MultiWildcard), Some(_)) => break true, + (Some(TopicFilter::SingleWildcard), Some(_)) => (), + (Some(left), Some(right)) => { + if left != right { + break false; + } + } + } + + i += 1; + } + }) { + callback(message).await; + } + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => { + warn!("Subscriber lagged by {count} values") + } + } + } + + Ok(()) + } + /// Accept a new client connected through the `client` stream /// /// This does multiple things: @@ -357,6 +420,7 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { let subscription_manager = server.subscription_manager.clone(); let client_id = client_id.clone(); let clients = server.clients.clone(); + let extra_listener = server.extra_listeners.clone(); tokio::spawn(async move { let client_id = client_id; @@ -399,6 +463,7 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { *qos, ); + let _ = extra_listener.send(message.clone()); subscription_manager.route_message(message).await; // Handle QoS 1/AtLeastOnce response diff --git a/src/server/subscriptions.rs b/src/server/subscriptions.rs index ebc3c4c..a45be93 100644 --- a/src/server/subscriptions.rs +++ b/src/server/subscriptions.rs @@ -78,14 +78,14 @@ impl TopicName { } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -enum TopicFilter { +pub enum TopicFilter { MultiWildcard, SingleWildcard, Named(String), } impl TopicFilter { - fn parse_from(topic: String) -> VecDeque<TopicFilter> { + pub fn parse_from(topic: String) -> VecDeque<TopicFilter> { topic .split('/') .map(|piece| match piece { |