summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorbors[bot] <26634292+bors[bot]@users.noreply.github.com>2023-01-05 08:24:41 +0000
committerGitHub <noreply@github.com>2023-01-05 08:24:41 +0000
commitdd49fdcc40e1d1b99ea69b2cc3866b357b0ece54 (patch)
tree30b3b86270b135dd2088138e8488b55edf7bd983
parent807aee4e811b98819a9b1c8b626ad4d802a9489f (diff)
parentb1a2c2d9a1723de358b34786f8194f66d17f42c2 (diff)
Merge #118
118: Feature/add subscription handler r=TheNeikos a=TheNeikos Closes #104 Co-authored-by: Marcel Müller <neikos@neikos.email>
-rw-r--r--src/bin/cloudmqtt-server.rs24
-rw-r--r--src/server/handler.rs (renamed from src/server/login.rs)38
-rw-r--r--src/server/mod.rs56
-rw-r--r--src/server/subscriptions.rs76
4 files changed, 156 insertions, 38 deletions
diff --git a/src/bin/cloudmqtt-server.rs b/src/bin/cloudmqtt-server.rs
index 6d17f94..37ad878 100644
--- a/src/bin/cloudmqtt-server.rs
+++ b/src/bin/cloudmqtt-server.rs
@@ -6,8 +6,10 @@
use std::sync::Arc;
-use cloudmqtt::server::login::{LoginError, LoginHandler};
+use cloudmqtt::server::handler::{LoginError, LoginHandler, SubscriptionHandler};
use cloudmqtt::server::{ClientId, MqttServer};
+use mqtt_format::v3::qos::MQualityOfService;
+use mqtt_format::v3::subscription_request::MSubscriptionRequest;
use tracing::info;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
@@ -32,6 +34,23 @@ impl LoginHandler for SimpleLoginHandler {
}
}
+struct SimpleSubscriptionHandler;
+
+#[async_trait::async_trait]
+impl SubscriptionHandler for SimpleSubscriptionHandler {
+ async fn allow_subscription(
+ &self,
+ _client_id: Arc<ClientId>,
+ subscription: MSubscriptionRequest<'_>,
+ ) -> Option<MQualityOfService> {
+ if &*subscription.topic == "forbidden" {
+ return None;
+ }
+
+ Some(subscription.qos)
+ }
+}
+
#[tokio::main]
async fn main() {
let fmt_layer = tracing_subscriber::fmt::layer()
@@ -50,7 +69,8 @@ async fn main() {
let server = MqttServer::serve_v3_unsecured_tcp("0.0.0.0:1883")
.await
.unwrap()
- .with_login_handler(SimpleLoginHandler);
+ .with_login_handler(SimpleLoginHandler)
+ .with_subscription_handler(SimpleSubscriptionHandler);
Arc::new(server).accept_new_clients().await.unwrap();
}
diff --git a/src/server/login.rs b/src/server/handler.rs
index b1e2c94..b95fae4 100644
--- a/src/server/login.rs
+++ b/src/server/handler.rs
@@ -5,7 +5,10 @@
//
use std::sync::Arc;
-use mqtt_format::v3::connect_return::MConnectReturnCode;
+use mqtt_format::v3::{
+ connect_return::MConnectReturnCode, qos::MQualityOfService,
+ subscription_request::MSubscriptionRequest,
+};
use crate::server::ClientId;
@@ -28,7 +31,7 @@ impl LoginError {
/// Objects that can handle authentication implement this trait
#[async_trait::async_trait]
-pub trait LoginHandler {
+pub trait LoginHandler: Send + Sync + 'static {
/// Check whether to allow this client to log in
async fn allow_login(
&self,
@@ -38,8 +41,11 @@ pub trait LoginHandler {
) -> Result<(), LoginError>;
}
+/// A [`LoginHandler`] that simply allows all login attempts
+pub struct AllowAllLogins;
+
#[async_trait::async_trait]
-impl LoginHandler for () {
+impl LoginHandler for AllowAllLogins {
async fn allow_login(
&self,
_client_id: Arc<ClientId>,
@@ -49,3 +55,29 @@ impl LoginHandler for () {
Ok(())
}
}
+
+/// Objects that can handle authentication implement this trait
+#[async_trait::async_trait]
+pub trait SubscriptionHandler: Send + Sync + 'static {
+ /// Check whether to allow this client to log in
+ async fn allow_subscription(
+ &self,
+ client_id: Arc<ClientId>,
+ subscription: MSubscriptionRequest<'_>,
+ ) -> Option<MQualityOfService>;
+}
+
+/// A [`SubscriptionHandler`] that simply allows all subscription requests
+pub struct AllowAllSubscriptions;
+
+#[async_trait::async_trait]
+impl SubscriptionHandler for AllowAllSubscriptions {
+ /// Check whether to allow this client to log in
+ async fn allow_subscription(
+ &self,
+ _client_id: Arc<ClientId>,
+ subscription: MSubscriptionRequest<'_>,
+ ) -> Option<MQualityOfService> {
+ Some(subscription.qos)
+ }
+}
diff --git a/src/server/mod.rs b/src/server/mod.rs
index dbf93c7..6b482a9 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -30,7 +30,7 @@
#![deny(missing_docs)]
/// Authentication related functionality
-pub mod login;
+pub mod handler;
mod message;
mod state;
mod subscriptions;
@@ -60,7 +60,9 @@ use crate::{error::MqttError, mqtt_stream::MqttStream, PacketIOError};
use subscriptions::{ClientInformation, SubscriptionManager};
use self::{
- login::{LoginError, LoginHandler},
+ handler::{
+ AllowAllLogins, AllowAllSubscriptions, LoginError, LoginHandler, SubscriptionHandler,
+ },
message::MqttMessage,
state::ClientState,
};
@@ -137,14 +139,14 @@ impl ClientSource {
///
/// Check out the server example for a working version.
///
-pub struct MqttServer<LH = ()> {
+pub struct MqttServer<LoginH, SubH> {
clients: Arc<DashMap<ClientId, ClientState>>,
client_source: Mutex<ClientSource>,
- auth_handler: LH,
- subscription_manager: SubscriptionManager,
+ auth_handler: LoginH,
+ subscription_manager: Arc<SubscriptionManager<SubH>>,
}
-impl MqttServer<()> {
+impl MqttServer<AllowAllLogins, AllowAllSubscriptions> {
/// Create a new MQTT server listening on the given `SocketAddr`
pub async fn serve_v3_unsecured_tcp<Addr: ToSocketAddrs>(
addr: Addr,
@@ -154,15 +156,22 @@ impl MqttServer<()> {
Ok(MqttServer {
clients: Arc::new(DashMap::new()),
client_source: Mutex::new(ClientSource::UnsecuredTcp(bind)),
- auth_handler: (),
- subscription_manager: SubscriptionManager::new(),
+ auth_handler: AllowAllLogins,
+ subscription_manager: Arc::new(SubscriptionManager::new()),
})
}
}
-impl<LH: Send + Sync + LoginHandler + 'static> MqttServer<LH> {
+impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
/// Switch the login handler with a new one
- pub fn with_login_handler<NLH: LoginHandler>(self, new_login_handler: NLH) -> MqttServer<NLH> {
+ ///
+ /// ## Note
+ ///
+ /// You should only call this after instantiating the server, and before listening
+ pub fn with_login_handler<NLH: LoginHandler>(
+ self,
+ new_login_handler: NLH,
+ ) -> MqttServer<NLH, SH> {
MqttServer {
clients: self.clients,
client_source: self.client_source,
@@ -171,6 +180,29 @@ impl<LH: Send + Sync + LoginHandler + 'static> MqttServer<LH> {
}
}
+ /// Resets the subscription handler to a new one
+ ///
+ /// ## Note
+ ///
+ /// You should only call this after instantiating the server, and before listening
+ pub fn with_subscription_handler<NSH: SubscriptionHandler>(
+ self,
+ new_subscription_handler: NSH,
+ ) -> MqttServer<LH, NSH> {
+ MqttServer {
+ clients: self.clients,
+ client_source: self.client_source,
+ auth_handler: self.auth_handler,
+ subscription_manager: Arc::new({
+ let manager = Arc::try_unwrap(self.subscription_manager);
+
+ manager
+ .unwrap_or_else(|_| panic!("Called after started listening"))
+ .with_subscription_handler(new_subscription_handler)
+ }),
+ }
+ }
+
/// Start accepting new clients connecting to the server
pub async fn accept_new_clients(self: Arc<Self>) -> Result<(), MqttError> {
let mut client_source = self
@@ -213,8 +245,8 @@ impl<LH: Send + Sync + LoginHandler + 'static> MqttServer<LH> {
}
#[allow(clippy::too_many_arguments)]
- async fn connect_client<'message, LH: LoginHandler>(
- server: &MqttServer<LH>,
+ async fn connect_client<'message, LH: LoginHandler, SubH: SubscriptionHandler>(
+ server: &MqttServer<LH, SubH>,
mut client: MqttStream,
_protocol_name: MString<'message>,
_protocol_level: u8,
diff --git a/src/server/subscriptions.rs b/src/server/subscriptions.rs
index 35946e1..ebc3c4c 100644
--- a/src/server/subscriptions.rs
+++ b/src/server/subscriptions.rs
@@ -10,6 +10,7 @@ use std::{
};
use arc_swap::ArcSwap;
+use futures::{stream::FuturesUnordered, StreamExt};
use mqtt_format::v3::{
qos::MQualityOfService, subscription_acks::MSubscriptionAck,
subscription_request::MSubscriptionRequests,
@@ -18,6 +19,8 @@ use tracing::{debug, trace};
use crate::server::{ClientId, MqttMessage};
+use super::handler::{AllowAllSubscriptions, SubscriptionHandler};
+
// foo/barr/# => vec![Named, Named, MultiWildcard]
// /foo/barr/# => vec![Empty, ... ]
// /devices/+/temperature
@@ -94,16 +97,30 @@ impl TopicFilter {
}
}
-#[derive(Debug, Clone, Default)]
-pub(crate) struct SubscriptionManager {
+pub(crate) struct SubscriptionManager<SubH> {
+ subscription_handler: SubH,
subscriptions: Arc<ArcSwap<SubscriptionTopic>>,
}
-impl SubscriptionManager {
- pub(crate) fn new() -> SubscriptionManager {
- Default::default()
+impl SubscriptionManager<AllowAllSubscriptions> {
+ pub(crate) fn new() -> Self {
+ SubscriptionManager {
+ subscription_handler: AllowAllSubscriptions,
+ subscriptions: Arc::default(),
+ }
}
+}
+impl<SH: SubscriptionHandler> SubscriptionManager<SH> {
+ pub(crate) fn with_subscription_handler<NSH: SubscriptionHandler>(
+ self,
+ new_subscription_handler: NSH,
+ ) -> SubscriptionManager<NSH> {
+ SubscriptionManager {
+ subscriptions: self.subscriptions,
+ subscription_handler: new_subscription_handler,
+ }
+ }
pub(crate) async fn subscribe(
&self,
client: Arc<ClientInformation>,
@@ -113,28 +130,45 @@ impl SubscriptionManager {
let sub_changes: Vec<_> = subscriptions
.into_iter()
.map(|sub| {
- let topic_levels: VecDeque<TopicFilter> =
- TopicFilter::parse_from(sub.topic.to_string());
- let client_sub = ClientSubscription {
- qos: sub.qos,
- client: client.clone(),
- };
-
- let ack = match sub.qos {
- MQualityOfService::AtMostOnce => MSubscriptionAck::MaximumQualityAtMostOnce,
- MQualityOfService::AtLeastOnce => MSubscriptionAck::MaximumQualityAtLeastOnce,
- MQualityOfService::ExactlyOnce => MSubscriptionAck::MaximumQualityExactlyOnce,
- };
-
- (topic_levels, client_sub, ack)
+ let client = client.clone();
+ async move {
+ let topic_levels: VecDeque<TopicFilter> =
+ TopicFilter::parse_from(sub.topic.to_string());
+
+ let sub_resp = self
+ .subscription_handler
+ .allow_subscription(client.client_id.clone(), sub)
+ .await;
+
+ let ack = match sub_resp {
+ None => MSubscriptionAck::Failure,
+ Some(MQualityOfService::AtMostOnce) => {
+ MSubscriptionAck::MaximumQualityAtMostOnce
+ }
+ Some(MQualityOfService::AtLeastOnce) => {
+ MSubscriptionAck::MaximumQualityAtLeastOnce
+ }
+ Some(MQualityOfService::ExactlyOnce) => {
+ MSubscriptionAck::MaximumQualityExactlyOnce
+ }
+ };
+
+ let client_sub = sub_resp.map(|qos| ClientSubscription { qos, client });
+
+ (topic_levels, client_sub, ack)
+ }
})
- .collect();
+ .collect::<FuturesUnordered<_>>()
+ .collect()
+ .await;
self.subscriptions.rcu(|old_table| {
let mut subs = SubscriptionTopic::clone(old_table);
for (topic, client, _) in sub_changes.clone() {
- subs.add_subscription(topic, client);
+ if let Some(client) = client {
+ subs.add_subscription(topic, client);
+ }
}
subs