diff options
-rw-r--r-- | src/bin/cloudmqtt-server.rs | 4 | ||||
-rw-r--r-- | src/server/mod.rs | 153 |
2 files changed, 92 insertions, 65 deletions
diff --git a/src/bin/cloudmqtt-server.rs b/src/bin/cloudmqtt-server.rs index cfa1acc..e307d6a 100644 --- a/src/bin/cloudmqtt-server.rs +++ b/src/bin/cloudmqtt-server.rs @@ -72,9 +72,9 @@ async fn main() { .with_login_handler(SimpleLoginHandler) .with_subscription_handler(SimpleSubscriptionHandler); - let server = Arc::new(server); + let server = server; - tokio::spawn(server.clone().subscribe_to_message( + tokio::spawn(server.subscribe_to_message( vec![String::from("foo/bar"), String::from("bar/#")], |msg| async move { info!("Got message: {msg:?}"); diff --git a/src/server/mod.rs b/src/server/mod.rs index a252ca4..a1517ea 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -147,6 +147,18 @@ impl ClientSource { /// Check out the server example for a working version. /// pub struct MqttServer<LoginH, SubH> { + inner: Arc<InnerServer<LoginH, SubH>>, +} + +impl<LoginH, SubH> Clone for MqttServer<LoginH, SubH> { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +struct InnerServer<LoginH, SubH> { clients: Arc<DashMap<ClientId, ClientState>>, client_source: Mutex<ClientSource>, auth_handler: LoginH, @@ -164,11 +176,13 @@ impl MqttServer<AllowAllLogins, AllowAllSubscriptions> { let (extra_listeners, _) = tokio::sync::broadcast::channel(50); Ok(MqttServer { - clients: Arc::new(DashMap::new()), - client_source: Mutex::new(ClientSource::UnsecuredTcp(bind)), - auth_handler: AllowAllLogins, - extra_listeners, - subscription_manager: Arc::new(SubscriptionManager::new()), + inner: Arc::new(InnerServer { + clients: Arc::new(DashMap::new()), + client_source: Mutex::new(ClientSource::UnsecuredTcp(bind)), + auth_handler: AllowAllLogins, + extra_listeners, + subscription_manager: Arc::new(SubscriptionManager::new()), + }), }) } } @@ -183,12 +197,16 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { self, new_login_handler: NLH, ) -> MqttServer<NLH, SH> { + let inner = Arc::try_unwrap(self.inner) + .unwrap_or_else(|_| panic!("Called after started listening")); MqttServer { - clients: self.clients, - client_source: self.client_source, - auth_handler: new_login_handler, - extra_listeners: self.extra_listeners, - subscription_manager: self.subscription_manager, + inner: Arc::new(InnerServer { + clients: inner.clients, + client_source: inner.client_source, + auth_handler: new_login_handler, + extra_listeners: inner.extra_listeners, + subscription_manager: inner.subscription_manager, + }), } } @@ -201,31 +219,36 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { self, new_subscription_handler: NSH, ) -> MqttServer<LH, NSH> { + let inner = Arc::try_unwrap(self.inner) + .unwrap_or_else(|_| panic!("Called after started listening")); MqttServer { - clients: self.clients, - client_source: self.client_source, - auth_handler: self.auth_handler, - extra_listeners: self.extra_listeners, - subscription_manager: Arc::new({ - let manager = Arc::try_unwrap(self.subscription_manager); - - manager - .unwrap_or_else(|_| panic!("Called after started listening")) - .with_subscription_handler(new_subscription_handler) + inner: Arc::new(InnerServer { + clients: inner.clients, + client_source: inner.client_source, + auth_handler: inner.auth_handler, + extra_listeners: inner.extra_listeners, + subscription_manager: Arc::new({ + let manager = Arc::try_unwrap(inner.subscription_manager); + + manager + .unwrap_or_else(|_| panic!("Called after started listening")) + .with_subscription_handler(new_subscription_handler) + }), }), } } /// Start accepting new clients connecting to the server - pub async fn accept_new_clients(self: Arc<Self>) -> Result<(), MqttError> { + pub async fn accept_new_clients(&self) -> Result<(), MqttError> { let mut client_source = self + .inner .client_source .try_lock() .map_err(|_| MqttError::AlreadyListening)?; loop { + let server: MqttServer<LH, SH> = self.clone(); let client = client_source.accept().await?; - let server = self.clone(); tokio::spawn(async move { if let Err(client_error) = server.accept_client(client).await { tracing::error!("Client error: {}", client_error) @@ -235,58 +258,60 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { } /// Listen to messages sent to the given topic_paths - pub async fn subscribe_to_message< + pub fn subscribe_to_message< Fut: std::future::Future<Output = ()>, - CB: FnMut(MqttMessage) -> Fut, + CB: FnMut(MqttMessage) -> Fut + 'static, >( - self: Arc<Self>, + &self, topic_paths: Vec<String>, mut callback: CB, - ) -> Result<(), MqttError> { - let mut listener = self.extra_listeners.subscribe(); + ) -> impl std::future::Future<Output = Result<(), MqttError>> + 'static { + let mut listener = self.inner.extra_listeners.subscribe(); let topics = topic_paths .into_iter() .map(TopicFilter::parse_from) .collect::<Vec<_>>(); - loop { - let message = listener.recv().await; - - match message { - Ok(message) => { - if topics.iter().any(|topic| { - let msg_topic = TopicFilter::parse_from(message.topic().to_string()); - - let mut i = 0; - loop { - match (topic.get(i), msg_topic.get(i)) { - (None, None) => break true, - (None, Some(_)) => break false, - (Some(_), None) => break false, - (Some(TopicFilter::MultiWildcard), Some(_)) => break true, - (Some(TopicFilter::SingleWildcard), Some(_)) => (), - (Some(left), Some(right)) => { - if left != right { - break false; + async move { + loop { + let message = listener.recv().await; + + match message { + Ok(message) => { + if topics.iter().any(|topic| { + let msg_topic = TopicFilter::parse_from(message.topic().to_string()); + + let mut i = 0; + loop { + match (topic.get(i), msg_topic.get(i)) { + (None, None) => break true, + (None, Some(_)) => break false, + (Some(_), None) => break false, + (Some(TopicFilter::MultiWildcard), Some(_)) => break true, + (Some(TopicFilter::SingleWildcard), Some(_)) => (), + (Some(left), Some(right)) => { + if left != right { + break false; + } } } - } - i += 1; + i += 1; + } + }) { + callback(message).await; } - }) { - callback(message).await; } - } - Err(tokio::sync::broadcast::error::RecvError::Closed) => break, - Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => { - warn!("Subscriber lagged by {count} values") + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => { + warn!("Subscriber lagged by {count} values") + } } } - } - Ok(()) + Ok(()) + } } /// Accept a new client connected through the `client` stream @@ -343,15 +368,16 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { } let session_present = if clean_session { - let _ = server.clients.remove(&client_id); + let _ = server.inner.clients.remove(&client_id); false } else { - server.clients.contains_key(&client_id) + server.inner.clients.contains_key(&client_id) }; let client_id = Arc::new(client_id); if let Err(err) = server + .inner .auth_handler .allow_login(client_id.clone(), username.as_deref(), password) .await @@ -373,6 +399,7 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { { let client_state = server + .inner .clients .entry((*client_id).clone()) .or_insert_with(ClientState::default); @@ -385,13 +412,13 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { .as_ref() .map(|will| MqttMessage::from_last_will(will, client_id.clone())); - let published_packets = server.subscription_manager.clone(); + let published_packets = server.inner.subscription_manager.clone(); let (published_packets_send, mut published_packets_rec) = tokio::sync::mpsc::unbounded_channel::<MqttMessage>(); let send_loop = { let publisher_client_id = client_id.clone(); - let clients = server.clients.clone(); + let clients = server.inner.clients.clone(); tokio::spawn(async move { loop { match published_packets_rec.recv().await { @@ -422,10 +449,10 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { let read_loop = { let keep_alive = keep_alive; - let subscription_manager = server.subscription_manager.clone(); + let subscription_manager = server.inner.subscription_manager.clone(); let client_id = client_id.clone(); - let clients = server.clients.clone(); - let extra_listener = server.extra_listeners.clone(); + let clients = server.inner.clients.clone(); + let extra_listener = server.inner.extra_listeners.clone(); tokio::spawn(async move { let client_id = client_id; |