summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMarcel Müller <neikos@neikos.email>2023-01-04 14:28:49 +0100
committerMarcel Müller <neikos@neikos.email>2023-01-04 14:28:49 +0100
commit80574cc1d6c7d4b922b570e451c58c5c1cf5f118 (patch)
treec0a0952268c819a7fe76860752ebd4bca8e53b94
parent03ad09f9435309e8cfbc48f3132a315222b17101 (diff)
Simplify handling subscription requests
Signed-off-by: Marcel Müller <neikos@neikos.email>
-rw-r--r--src/server/handler.rs12
-rw-r--r--src/server/mod.rs27
-rw-r--r--src/server/subscriptions.rs34
3 files changed, 59 insertions, 14 deletions
diff --git a/src/server/handler.rs b/src/server/handler.rs
index 843ce14..b95fae4 100644
--- a/src/server/handler.rs
+++ b/src/server/handler.rs
@@ -7,7 +7,7 @@ use std::sync::Arc;
use mqtt_format::v3::{
connect_return::MConnectReturnCode, qos::MQualityOfService,
- subscription_acks::MSubscriptionAck, subscription_request::MSubscriptionRequest,
+ subscription_request::MSubscriptionRequest,
};
use crate::server::ClientId;
@@ -64,7 +64,7 @@ pub trait SubscriptionHandler: Send + Sync + 'static {
&self,
client_id: Arc<ClientId>,
subscription: MSubscriptionRequest<'_>,
- ) -> MSubscriptionAck;
+ ) -> Option<MQualityOfService>;
}
/// A [`SubscriptionHandler`] that simply allows all subscription requests
@@ -77,11 +77,7 @@ impl SubscriptionHandler for AllowAllSubscriptions {
&self,
_client_id: Arc<ClientId>,
subscription: MSubscriptionRequest<'_>,
- ) -> MSubscriptionAck {
- match subscription.qos {
- MQualityOfService::AtMostOnce => MSubscriptionAck::MaximumQualityAtMostOnce,
- MQualityOfService::AtLeastOnce => MSubscriptionAck::MaximumQualityAtLeastOnce,
- MQualityOfService::ExactlyOnce => MSubscriptionAck::MaximumQualityExactlyOnce,
- }
+ ) -> Option<MQualityOfService> {
+ Some(subscription.qos)
}
}
diff --git a/src/server/mod.rs b/src/server/mod.rs
index 33c11c5..6b482a9 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -164,6 +164,10 @@ impl MqttServer<AllowAllLogins, AllowAllSubscriptions> {
impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
/// Switch the login handler with a new one
+ ///
+ /// ## Note
+ ///
+ /// You should only call this after instantiating the server, and before listening
pub fn with_login_handler<NLH: LoginHandler>(
self,
new_login_handler: NLH,
@@ -176,6 +180,29 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
}
}
+ /// Resets the subscription handler to a new one
+ ///
+ /// ## Note
+ ///
+ /// You should only call this after instantiating the server, and before listening
+ pub fn with_subscription_handler<NSH: SubscriptionHandler>(
+ self,
+ new_subscription_handler: NSH,
+ ) -> MqttServer<LH, NSH> {
+ MqttServer {
+ clients: self.clients,
+ client_source: self.client_source,
+ auth_handler: self.auth_handler,
+ subscription_manager: Arc::new({
+ let manager = Arc::try_unwrap(self.subscription_manager);
+
+ manager
+ .unwrap_or_else(|_| panic!("Called after started listening"))
+ .with_subscription_handler(new_subscription_handler)
+ }),
+ }
+ }
+
/// Start accepting new clients connecting to the server
pub async fn accept_new_clients(self: Arc<Self>) -> Result<(), MqttError> {
let mut client_source = self
diff --git a/src/server/subscriptions.rs b/src/server/subscriptions.rs
index 16e96ef..ebc3c4c 100644
--- a/src/server/subscriptions.rs
+++ b/src/server/subscriptions.rs
@@ -112,6 +112,15 @@ impl SubscriptionManager<AllowAllSubscriptions> {
}
impl<SH: SubscriptionHandler> SubscriptionManager<SH> {
+ pub(crate) fn with_subscription_handler<NSH: SubscriptionHandler>(
+ self,
+ new_subscription_handler: NSH,
+ ) -> SubscriptionManager<NSH> {
+ SubscriptionManager {
+ subscriptions: self.subscriptions,
+ subscription_handler: new_subscription_handler,
+ }
+ }
pub(crate) async fn subscribe(
&self,
client: Arc<ClientInformation>,
@@ -125,16 +134,27 @@ impl<SH: SubscriptionHandler> SubscriptionManager<SH> {
async move {
let topic_levels: VecDeque<TopicFilter> =
TopicFilter::parse_from(sub.topic.to_string());
- let client_sub = ClientSubscription {
- qos: sub.qos,
- client: client.clone(),
- };
- let ack = self
+ let sub_resp = self
.subscription_handler
.allow_subscription(client.client_id.clone(), sub)
.await;
+ let ack = match sub_resp {
+ None => MSubscriptionAck::Failure,
+ Some(MQualityOfService::AtMostOnce) => {
+ MSubscriptionAck::MaximumQualityAtMostOnce
+ }
+ Some(MQualityOfService::AtLeastOnce) => {
+ MSubscriptionAck::MaximumQualityAtLeastOnce
+ }
+ Some(MQualityOfService::ExactlyOnce) => {
+ MSubscriptionAck::MaximumQualityExactlyOnce
+ }
+ };
+
+ let client_sub = sub_resp.map(|qos| ClientSubscription { qos, client });
+
(topic_levels, client_sub, ack)
}
})
@@ -146,7 +166,9 @@ impl<SH: SubscriptionHandler> SubscriptionManager<SH> {
let mut subs = SubscriptionTopic::clone(old_table);
for (topic, client, _) in sub_changes.clone() {
- subs.add_subscription(topic, client);
+ if let Some(client) = client {
+ subs.add_subscription(topic, client);
+ }
}
subs