diff options
author | Marcel Müller <neikos@neikos.email> | 2023-01-04 10:41:40 +0100 |
---|---|---|
committer | Marcel Müller <neikos@neikos.email> | 2023-01-04 10:58:49 +0100 |
commit | f5b2e8f9b609ba30c552ec6e5d1d4849a678e632 (patch) | |
tree | a5688ce5095c10fdd7bb959fa15105168808237c | |
parent | d95ee4a11a04443e215218c8c019b261c0767534 (diff) |
Add login handling to server
Signed-off-by: Marcel Müller <neikos@neikos.email>
-rw-r--r-- | src/server/login.rs | 18 | ||||
-rw-r--r-- | src/server/mod.rs | 55 |
2 files changed, 59 insertions, 14 deletions
diff --git a/src/server/login.rs b/src/server/login.rs index 123f4aa..c964f39 100644 --- a/src/server/login.rs +++ b/src/server/login.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use crate::server::ClientId; /// Errors that can occur during login +#[derive(Debug, thiserror::Error)] pub enum LoginError {} /// Objects that can handle authentication implement this trait @@ -15,8 +16,21 @@ pub enum LoginError {} pub trait LoginHandler { /// Check whether to allow this client to log in async fn allow_login( + &self, client_id: Arc<ClientId>, - username: &str, - password: &str, + username: Option<&str>, + password: Option<&[u8]>, ) -> Result<(), LoginError>; } + +#[async_trait::async_trait] +impl LoginHandler for () { + async fn allow_login( + &self, + _client_id: Arc<ClientId>, + _username: Option<&str>, + _password: Option<&[u8]>, + ) -> Result<(), LoginError> { + Ok(()) + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 6884f24..581157a 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -29,11 +29,11 @@ //! [MQTT Spec]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html #![deny(missing_docs)] +/// Authentication related functionality +pub mod login; mod message; mod state; mod subscriptions; -/// Authentication related functionality -pub mod login; use std::{sync::Arc, time::Duration}; @@ -58,7 +58,11 @@ use tracing::{debug, error, info, trace}; use crate::{error::MqttError, mqtt_stream::MqttStream, PacketIOError}; use subscriptions::{ClientInformation, SubscriptionManager}; -use self::{message::MqttMessage, state::ClientState}; +use self::{ + login::{LoginError, LoginHandler}, + message::MqttMessage, + state::ClientState, +}; /// The unique id (per server) of a connecting client #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -85,6 +89,9 @@ pub enum ClientError { /// An error occurred during the sending/receiving of a packet #[error("An error occured during the handling of a packet")] Packet(#[from] PacketIOError), + /// An authentication was rejected + #[error("An authentication was rejected")] + Authentication(#[from] LoginError), } #[derive(Debug)] @@ -129,13 +136,14 @@ impl ClientSource { /// /// Check out the server example for a working version. /// -pub struct MqttServer { +pub struct MqttServer<LH = ()> { clients: Arc<DashMap<ClientId, ClientState>>, client_source: Mutex<ClientSource>, + auth_handler: LH, subscription_manager: SubscriptionManager, } -impl MqttServer { +impl MqttServer<()> { /// Create a new MQTT server listening on the given `SocketAddr` pub async fn serve_v3_unsecured_tcp<Addr: ToSocketAddrs>( addr: Addr, @@ -145,9 +153,22 @@ impl MqttServer { Ok(MqttServer { clients: Arc::new(DashMap::new()), client_source: Mutex::new(ClientSource::UnsecuredTcp(bind)), + auth_handler: (), subscription_manager: SubscriptionManager::new(), }) } +} + +impl<LH: Send + Sync + LoginHandler + 'static> MqttServer<LH> { + /// Switch the login handler with a new one + pub fn with_login_handler<NLH: LoginHandler>(self, new_login_handler: NLH) -> MqttServer<NLH> { + MqttServer { + clients: self.clients, + client_source: self.client_source, + auth_handler: new_login_handler, + subscription_manager: self.subscription_manager, + } + } /// Start accepting new clients connecting to the server pub async fn accept_new_clients(self: Arc<Self>) -> Result<(), MqttError> { @@ -191,15 +212,15 @@ impl MqttServer { } #[allow(clippy::too_many_arguments)] - async fn connect_client<'message>( - server: &MqttServer, + async fn connect_client<'message, LH: LoginHandler>( + server: &MqttServer<LH>, mut client: MqttStream, _protocol_name: MString<'message>, _protocol_level: u8, clean_session: bool, will: Option<MLastWill<'message>>, - _username: Option<MString<'message>>, - _password: Option<&'message [u8]>, + username: Option<MString<'message>>, + password: Option<&'message [u8]>, keep_alive: u16, client_id: MString<'message>, ) -> Result<(), ClientError> { @@ -227,6 +248,18 @@ impl MqttServer { server.clients.contains_key(&client_id) }; + let client_id = Arc::new(client_id); + + if let Err(err) = server + .auth_handler + .allow_login(client_id.clone(), username.as_deref(), password) + .await + { + send_connack(session_present, MConnectReturnCode::Accepted, &mut client).await?; + + return Err(ClientError::Authentication(err)); + } + send_connack(session_present, MConnectReturnCode::Accepted, &mut client).await?; debug!(?client_id, "Accepted new connection"); @@ -240,15 +273,13 @@ impl MqttServer { { let client_state = server .clients - .entry(client_id.clone()) + .entry((*client_id).clone()) .or_insert_with(ClientState::default); client_state .set_new_connection(client_connection.clone()) .await; } - let client_id = Arc::new(client_id); - let mut last_will: Option<MqttMessage> = will .as_ref() .map(|will| MqttMessage::from_last_will(will, client_id.clone())); |