diff options
author | Marcel Müller <neikos@neikos.email> | 2023-01-04 14:28:49 +0100 |
---|---|---|
committer | Marcel Müller <neikos@neikos.email> | 2023-01-04 14:28:49 +0100 |
commit | 80574cc1d6c7d4b922b570e451c58c5c1cf5f118 (patch) | |
tree | c0a0952268c819a7fe76860752ebd4bca8e53b94 | |
parent | 03ad09f9435309e8cfbc48f3132a315222b17101 (diff) |
Simplify handling subscription requests
Signed-off-by: Marcel Müller <neikos@neikos.email>
-rw-r--r-- | src/server/handler.rs | 12 | ||||
-rw-r--r-- | src/server/mod.rs | 27 | ||||
-rw-r--r-- | src/server/subscriptions.rs | 34 |
3 files changed, 59 insertions, 14 deletions
diff --git a/src/server/handler.rs b/src/server/handler.rs index 843ce14..b95fae4 100644 --- a/src/server/handler.rs +++ b/src/server/handler.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use mqtt_format::v3::{ connect_return::MConnectReturnCode, qos::MQualityOfService, - subscription_acks::MSubscriptionAck, subscription_request::MSubscriptionRequest, + subscription_request::MSubscriptionRequest, }; use crate::server::ClientId; @@ -64,7 +64,7 @@ pub trait SubscriptionHandler: Send + Sync + 'static { &self, client_id: Arc<ClientId>, subscription: MSubscriptionRequest<'_>, - ) -> MSubscriptionAck; + ) -> Option<MQualityOfService>; } /// A [`SubscriptionHandler`] that simply allows all subscription requests @@ -77,11 +77,7 @@ impl SubscriptionHandler for AllowAllSubscriptions { &self, _client_id: Arc<ClientId>, subscription: MSubscriptionRequest<'_>, - ) -> MSubscriptionAck { - match subscription.qos { - MQualityOfService::AtMostOnce => MSubscriptionAck::MaximumQualityAtMostOnce, - MQualityOfService::AtLeastOnce => MSubscriptionAck::MaximumQualityAtLeastOnce, - MQualityOfService::ExactlyOnce => MSubscriptionAck::MaximumQualityExactlyOnce, - } + ) -> Option<MQualityOfService> { + Some(subscription.qos) } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 33c11c5..6b482a9 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -164,6 +164,10 @@ impl MqttServer<AllowAllLogins, AllowAllSubscriptions> { impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { /// Switch the login handler with a new one + /// + /// ## Note + /// + /// You should only call this after instantiating the server, and before listening pub fn with_login_handler<NLH: LoginHandler>( self, new_login_handler: NLH, @@ -176,6 +180,29 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { } } + /// Resets the subscription handler to a new one + /// + /// ## Note + /// + /// You should only call this after instantiating the server, and before listening + pub fn with_subscription_handler<NSH: SubscriptionHandler>( + self, + new_subscription_handler: NSH, + ) -> MqttServer<LH, NSH> { + MqttServer { + clients: self.clients, + client_source: self.client_source, + auth_handler: self.auth_handler, + subscription_manager: Arc::new({ + let manager = Arc::try_unwrap(self.subscription_manager); + + manager + .unwrap_or_else(|_| panic!("Called after started listening")) + .with_subscription_handler(new_subscription_handler) + }), + } + } + /// Start accepting new clients connecting to the server pub async fn accept_new_clients(self: Arc<Self>) -> Result<(), MqttError> { let mut client_source = self diff --git a/src/server/subscriptions.rs b/src/server/subscriptions.rs index 16e96ef..ebc3c4c 100644 --- a/src/server/subscriptions.rs +++ b/src/server/subscriptions.rs @@ -112,6 +112,15 @@ impl SubscriptionManager<AllowAllSubscriptions> { } impl<SH: SubscriptionHandler> SubscriptionManager<SH> { + pub(crate) fn with_subscription_handler<NSH: SubscriptionHandler>( + self, + new_subscription_handler: NSH, + ) -> SubscriptionManager<NSH> { + SubscriptionManager { + subscriptions: self.subscriptions, + subscription_handler: new_subscription_handler, + } + } pub(crate) async fn subscribe( &self, client: Arc<ClientInformation>, @@ -125,16 +134,27 @@ impl<SH: SubscriptionHandler> SubscriptionManager<SH> { async move { let topic_levels: VecDeque<TopicFilter> = TopicFilter::parse_from(sub.topic.to_string()); - let client_sub = ClientSubscription { - qos: sub.qos, - client: client.clone(), - }; - let ack = self + let sub_resp = self .subscription_handler .allow_subscription(client.client_id.clone(), sub) .await; + let ack = match sub_resp { + None => MSubscriptionAck::Failure, + Some(MQualityOfService::AtMostOnce) => { + MSubscriptionAck::MaximumQualityAtMostOnce + } + Some(MQualityOfService::AtLeastOnce) => { + MSubscriptionAck::MaximumQualityAtLeastOnce + } + Some(MQualityOfService::ExactlyOnce) => { + MSubscriptionAck::MaximumQualityExactlyOnce + } + }; + + let client_sub = sub_resp.map(|qos| ClientSubscription { qos, client }); + (topic_levels, client_sub, ack) } }) @@ -146,7 +166,9 @@ impl<SH: SubscriptionHandler> SubscriptionManager<SH> { let mut subs = SubscriptionTopic::clone(old_table); for (topic, client, _) in sub_changes.clone() { - subs.add_subscription(topic, client); + if let Some(client) = client { + subs.add_subscription(topic, client); + } } subs |