diff options
author | bors[bot] <26634292+bors[bot]@users.noreply.github.com> | 2023-01-05 08:24:41 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-05 08:24:41 +0000 |
commit | dd49fdcc40e1d1b99ea69b2cc3866b357b0ece54 (patch) | |
tree | 30b3b86270b135dd2088138e8488b55edf7bd983 | |
parent | 807aee4e811b98819a9b1c8b626ad4d802a9489f (diff) | |
parent | b1a2c2d9a1723de358b34786f8194f66d17f42c2 (diff) |
Merge #118
118: Feature/add subscription handler r=TheNeikos a=TheNeikos
Closes #104
Co-authored-by: Marcel Müller <neikos@neikos.email>
-rw-r--r-- | src/bin/cloudmqtt-server.rs | 24 | ||||
-rw-r--r-- | src/server/handler.rs (renamed from src/server/login.rs) | 38 | ||||
-rw-r--r-- | src/server/mod.rs | 56 | ||||
-rw-r--r-- | src/server/subscriptions.rs | 76 |
4 files changed, 156 insertions, 38 deletions
diff --git a/src/bin/cloudmqtt-server.rs b/src/bin/cloudmqtt-server.rs index 6d17f94..37ad878 100644 --- a/src/bin/cloudmqtt-server.rs +++ b/src/bin/cloudmqtt-server.rs @@ -6,8 +6,10 @@ use std::sync::Arc; -use cloudmqtt::server::login::{LoginError, LoginHandler}; +use cloudmqtt::server::handler::{LoginError, LoginHandler, SubscriptionHandler}; use cloudmqtt::server::{ClientId, MqttServer}; +use mqtt_format::v3::qos::MQualityOfService; +use mqtt_format::v3::subscription_request::MSubscriptionRequest; use tracing::info; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -32,6 +34,23 @@ impl LoginHandler for SimpleLoginHandler { } } +struct SimpleSubscriptionHandler; + +#[async_trait::async_trait] +impl SubscriptionHandler for SimpleSubscriptionHandler { + async fn allow_subscription( + &self, + _client_id: Arc<ClientId>, + subscription: MSubscriptionRequest<'_>, + ) -> Option<MQualityOfService> { + if &*subscription.topic == "forbidden" { + return None; + } + + Some(subscription.qos) + } +} + #[tokio::main] async fn main() { let fmt_layer = tracing_subscriber::fmt::layer() @@ -50,7 +69,8 @@ async fn main() { let server = MqttServer::serve_v3_unsecured_tcp("0.0.0.0:1883") .await .unwrap() - .with_login_handler(SimpleLoginHandler); + .with_login_handler(SimpleLoginHandler) + .with_subscription_handler(SimpleSubscriptionHandler); Arc::new(server).accept_new_clients().await.unwrap(); } diff --git a/src/server/login.rs b/src/server/handler.rs index b1e2c94..b95fae4 100644 --- a/src/server/login.rs +++ b/src/server/handler.rs @@ -5,7 +5,10 @@ // use std::sync::Arc; -use mqtt_format::v3::connect_return::MConnectReturnCode; +use mqtt_format::v3::{ + connect_return::MConnectReturnCode, qos::MQualityOfService, + subscription_request::MSubscriptionRequest, +}; use crate::server::ClientId; @@ -28,7 +31,7 @@ impl LoginError { /// Objects that can handle authentication implement this trait #[async_trait::async_trait] -pub trait LoginHandler { +pub trait LoginHandler: Send + Sync + 'static { /// Check whether to allow this client to log in async fn allow_login( &self, @@ -38,8 +41,11 @@ pub trait LoginHandler { ) -> Result<(), LoginError>; } +/// A [`LoginHandler`] that simply allows all login attempts +pub struct AllowAllLogins; + #[async_trait::async_trait] -impl LoginHandler for () { +impl LoginHandler for AllowAllLogins { async fn allow_login( &self, _client_id: Arc<ClientId>, @@ -49,3 +55,29 @@ impl LoginHandler for () { Ok(()) } } + +/// Objects that can handle authentication implement this trait +#[async_trait::async_trait] +pub trait SubscriptionHandler: Send + Sync + 'static { + /// Check whether to allow this client to log in + async fn allow_subscription( + &self, + client_id: Arc<ClientId>, + subscription: MSubscriptionRequest<'_>, + ) -> Option<MQualityOfService>; +} + +/// A [`SubscriptionHandler`] that simply allows all subscription requests +pub struct AllowAllSubscriptions; + +#[async_trait::async_trait] +impl SubscriptionHandler for AllowAllSubscriptions { + /// Check whether to allow this client to log in + async fn allow_subscription( + &self, + _client_id: Arc<ClientId>, + subscription: MSubscriptionRequest<'_>, + ) -> Option<MQualityOfService> { + Some(subscription.qos) + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index dbf93c7..6b482a9 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -30,7 +30,7 @@ #![deny(missing_docs)] /// Authentication related functionality -pub mod login; +pub mod handler; mod message; mod state; mod subscriptions; @@ -60,7 +60,9 @@ use crate::{error::MqttError, mqtt_stream::MqttStream, PacketIOError}; use subscriptions::{ClientInformation, SubscriptionManager}; use self::{ - login::{LoginError, LoginHandler}, + handler::{ + AllowAllLogins, AllowAllSubscriptions, LoginError, LoginHandler, SubscriptionHandler, + }, message::MqttMessage, state::ClientState, }; @@ -137,14 +139,14 @@ impl ClientSource { /// /// Check out the server example for a working version. /// -pub struct MqttServer<LH = ()> { +pub struct MqttServer<LoginH, SubH> { clients: Arc<DashMap<ClientId, ClientState>>, client_source: Mutex<ClientSource>, - auth_handler: LH, - subscription_manager: SubscriptionManager, + auth_handler: LoginH, + subscription_manager: Arc<SubscriptionManager<SubH>>, } -impl MqttServer<()> { +impl MqttServer<AllowAllLogins, AllowAllSubscriptions> { /// Create a new MQTT server listening on the given `SocketAddr` pub async fn serve_v3_unsecured_tcp<Addr: ToSocketAddrs>( addr: Addr, @@ -154,15 +156,22 @@ impl MqttServer<()> { Ok(MqttServer { clients: Arc::new(DashMap::new()), client_source: Mutex::new(ClientSource::UnsecuredTcp(bind)), - auth_handler: (), - subscription_manager: SubscriptionManager::new(), + auth_handler: AllowAllLogins, + subscription_manager: Arc::new(SubscriptionManager::new()), }) } } -impl<LH: Send + Sync + LoginHandler + 'static> MqttServer<LH> { +impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> { /// Switch the login handler with a new one - pub fn with_login_handler<NLH: LoginHandler>(self, new_login_handler: NLH) -> MqttServer<NLH> { + /// + /// ## Note + /// + /// You should only call this after instantiating the server, and before listening + pub fn with_login_handler<NLH: LoginHandler>( + self, + new_login_handler: NLH, + ) -> MqttServer<NLH, SH> { MqttServer { clients: self.clients, client_source: self.client_source, @@ -171,6 +180,29 @@ impl<LH: Send + Sync + LoginHandler + 'static> MqttServer<LH> { } } + /// Resets the subscription handler to a new one + /// + /// ## Note + /// + /// You should only call this after instantiating the server, and before listening + pub fn with_subscription_handler<NSH: SubscriptionHandler>( + self, + new_subscription_handler: NSH, + ) -> MqttServer<LH, NSH> { + MqttServer { + clients: self.clients, + client_source: self.client_source, + auth_handler: self.auth_handler, + subscription_manager: Arc::new({ + let manager = Arc::try_unwrap(self.subscription_manager); + + manager + .unwrap_or_else(|_| panic!("Called after started listening")) + .with_subscription_handler(new_subscription_handler) + }), + } + } + /// Start accepting new clients connecting to the server pub async fn accept_new_clients(self: Arc<Self>) -> Result<(), MqttError> { let mut client_source = self @@ -213,8 +245,8 @@ impl<LH: Send + Sync + LoginHandler + 'static> MqttServer<LH> { } #[allow(clippy::too_many_arguments)] - async fn connect_client<'message, LH: LoginHandler>( - server: &MqttServer<LH>, + async fn connect_client<'message, LH: LoginHandler, SubH: SubscriptionHandler>( + server: &MqttServer<LH, SubH>, mut client: MqttStream, _protocol_name: MString<'message>, _protocol_level: u8, diff --git a/src/server/subscriptions.rs b/src/server/subscriptions.rs index 35946e1..ebc3c4c 100644 --- a/src/server/subscriptions.rs +++ b/src/server/subscriptions.rs @@ -10,6 +10,7 @@ use std::{ }; use arc_swap::ArcSwap; +use futures::{stream::FuturesUnordered, StreamExt}; use mqtt_format::v3::{ qos::MQualityOfService, subscription_acks::MSubscriptionAck, subscription_request::MSubscriptionRequests, @@ -18,6 +19,8 @@ use tracing::{debug, trace}; use crate::server::{ClientId, MqttMessage}; +use super::handler::{AllowAllSubscriptions, SubscriptionHandler}; + // foo/barr/# => vec![Named, Named, MultiWildcard] // /foo/barr/# => vec![Empty, ... ] // /devices/+/temperature @@ -94,16 +97,30 @@ impl TopicFilter { } } -#[derive(Debug, Clone, Default)] -pub(crate) struct SubscriptionManager { +pub(crate) struct SubscriptionManager<SubH> { + subscription_handler: SubH, subscriptions: Arc<ArcSwap<SubscriptionTopic>>, } -impl SubscriptionManager { - pub(crate) fn new() -> SubscriptionManager { - Default::default() +impl SubscriptionManager<AllowAllSubscriptions> { + pub(crate) fn new() -> Self { + SubscriptionManager { + subscription_handler: AllowAllSubscriptions, + subscriptions: Arc::default(), + } } +} +impl<SH: SubscriptionHandler> SubscriptionManager<SH> { + pub(crate) fn with_subscription_handler<NSH: SubscriptionHandler>( + self, + new_subscription_handler: NSH, + ) -> SubscriptionManager<NSH> { + SubscriptionManager { + subscriptions: self.subscriptions, + subscription_handler: new_subscription_handler, + } + } pub(crate) async fn subscribe( &self, client: Arc<ClientInformation>, @@ -113,28 +130,45 @@ impl SubscriptionManager { let sub_changes: Vec<_> = subscriptions .into_iter() .map(|sub| { - let topic_levels: VecDeque<TopicFilter> = - TopicFilter::parse_from(sub.topic.to_string()); - let client_sub = ClientSubscription { - qos: sub.qos, - client: client.clone(), - }; - - let ack = match sub.qos { - MQualityOfService::AtMostOnce => MSubscriptionAck::MaximumQualityAtMostOnce, - MQualityOfService::AtLeastOnce => MSubscriptionAck::MaximumQualityAtLeastOnce, - MQualityOfService::ExactlyOnce => MSubscriptionAck::MaximumQualityExactlyOnce, - }; - - (topic_levels, client_sub, ack) + let client = client.clone(); + async move { + let topic_levels: VecDeque<TopicFilter> = + TopicFilter::parse_from(sub.topic.to_string()); + + let sub_resp = self + .subscription_handler + .allow_subscription(client.client_id.clone(), sub) + .await; + + let ack = match sub_resp { + None => MSubscriptionAck::Failure, + Some(MQualityOfService::AtMostOnce) => { + MSubscriptionAck::MaximumQualityAtMostOnce + } + Some(MQualityOfService::AtLeastOnce) => { + MSubscriptionAck::MaximumQualityAtLeastOnce + } + Some(MQualityOfService::ExactlyOnce) => { + MSubscriptionAck::MaximumQualityExactlyOnce + } + }; + + let client_sub = sub_resp.map(|qos| ClientSubscription { qos, client }); + + (topic_levels, client_sub, ack) + } }) - .collect(); + .collect::<FuturesUnordered<_>>() + .collect() + .await; self.subscriptions.rcu(|old_table| { let mut subs = SubscriptionTopic::clone(old_table); for (topic, client, _) in sub_changes.clone() { - subs.add_subscription(topic, client); + if let Some(client) = client { + subs.add_subscription(topic, client); + } } subs |