summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMarcel Müller <neikos@neikos.email>2023-01-04 10:41:40 +0100
committerMarcel Müller <neikos@neikos.email>2023-01-04 10:58:49 +0100
commitf5b2e8f9b609ba30c552ec6e5d1d4849a678e632 (patch)
treea5688ce5095c10fdd7bb959fa15105168808237c
parentd95ee4a11a04443e215218c8c019b261c0767534 (diff)
Add login handling to server
Signed-off-by: Marcel Müller <neikos@neikos.email>
-rw-r--r--src/server/login.rs18
-rw-r--r--src/server/mod.rs55
2 files changed, 59 insertions, 14 deletions
diff --git a/src/server/login.rs b/src/server/login.rs
index 123f4aa..c964f39 100644
--- a/src/server/login.rs
+++ b/src/server/login.rs
@@ -8,6 +8,7 @@ use std::sync::Arc;
use crate::server::ClientId;
/// Errors that can occur during login
+#[derive(Debug, thiserror::Error)]
pub enum LoginError {}
/// Objects that can handle authentication implement this trait
@@ -15,8 +16,21 @@ pub enum LoginError {}
pub trait LoginHandler {
/// Check whether to allow this client to log in
async fn allow_login(
+ &self,
client_id: Arc<ClientId>,
- username: &str,
- password: &str,
+ username: Option<&str>,
+ password: Option<&[u8]>,
) -> Result<(), LoginError>;
}
+
+#[async_trait::async_trait]
+impl LoginHandler for () {
+ async fn allow_login(
+ &self,
+ _client_id: Arc<ClientId>,
+ _username: Option<&str>,
+ _password: Option<&[u8]>,
+ ) -> Result<(), LoginError> {
+ Ok(())
+ }
+}
diff --git a/src/server/mod.rs b/src/server/mod.rs
index 6884f24..581157a 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -29,11 +29,11 @@
//! [MQTT Spec]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html
#![deny(missing_docs)]
+/// Authentication related functionality
+pub mod login;
mod message;
mod state;
mod subscriptions;
-/// Authentication related functionality
-pub mod login;
use std::{sync::Arc, time::Duration};
@@ -58,7 +58,11 @@ use tracing::{debug, error, info, trace};
use crate::{error::MqttError, mqtt_stream::MqttStream, PacketIOError};
use subscriptions::{ClientInformation, SubscriptionManager};
-use self::{message::MqttMessage, state::ClientState};
+use self::{
+ login::{LoginError, LoginHandler},
+ message::MqttMessage,
+ state::ClientState,
+};
/// The unique id (per server) of a connecting client
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -85,6 +89,9 @@ pub enum ClientError {
/// An error occurred during the sending/receiving of a packet
#[error("An error occured during the handling of a packet")]
Packet(#[from] PacketIOError),
+ /// An authentication was rejected
+ #[error("An authentication was rejected")]
+ Authentication(#[from] LoginError),
}
#[derive(Debug)]
@@ -129,13 +136,14 @@ impl ClientSource {
///
/// Check out the server example for a working version.
///
-pub struct MqttServer {
+pub struct MqttServer<LH = ()> {
clients: Arc<DashMap<ClientId, ClientState>>,
client_source: Mutex<ClientSource>,
+ auth_handler: LH,
subscription_manager: SubscriptionManager,
}
-impl MqttServer {
+impl MqttServer<()> {
/// Create a new MQTT server listening on the given `SocketAddr`
pub async fn serve_v3_unsecured_tcp<Addr: ToSocketAddrs>(
addr: Addr,
@@ -145,9 +153,22 @@ impl MqttServer {
Ok(MqttServer {
clients: Arc::new(DashMap::new()),
client_source: Mutex::new(ClientSource::UnsecuredTcp(bind)),
+ auth_handler: (),
subscription_manager: SubscriptionManager::new(),
})
}
+}
+
+impl<LH: Send + Sync + LoginHandler + 'static> MqttServer<LH> {
+ /// Switch the login handler with a new one
+ pub fn with_login_handler<NLH: LoginHandler>(self, new_login_handler: NLH) -> MqttServer<NLH> {
+ MqttServer {
+ clients: self.clients,
+ client_source: self.client_source,
+ auth_handler: new_login_handler,
+ subscription_manager: self.subscription_manager,
+ }
+ }
/// Start accepting new clients connecting to the server
pub async fn accept_new_clients(self: Arc<Self>) -> Result<(), MqttError> {
@@ -191,15 +212,15 @@ impl MqttServer {
}
#[allow(clippy::too_many_arguments)]
- async fn connect_client<'message>(
- server: &MqttServer,
+ async fn connect_client<'message, LH: LoginHandler>(
+ server: &MqttServer<LH>,
mut client: MqttStream,
_protocol_name: MString<'message>,
_protocol_level: u8,
clean_session: bool,
will: Option<MLastWill<'message>>,
- _username: Option<MString<'message>>,
- _password: Option<&'message [u8]>,
+ username: Option<MString<'message>>,
+ password: Option<&'message [u8]>,
keep_alive: u16,
client_id: MString<'message>,
) -> Result<(), ClientError> {
@@ -227,6 +248,18 @@ impl MqttServer {
server.clients.contains_key(&client_id)
};
+ let client_id = Arc::new(client_id);
+
+ if let Err(err) = server
+ .auth_handler
+ .allow_login(client_id.clone(), username.as_deref(), password)
+ .await
+ {
+ send_connack(session_present, MConnectReturnCode::Accepted, &mut client).await?;
+
+ return Err(ClientError::Authentication(err));
+ }
+
send_connack(session_present, MConnectReturnCode::Accepted, &mut client).await?;
debug!(?client_id, "Accepted new connection");
@@ -240,15 +273,13 @@ impl MqttServer {
{
let client_state = server
.clients
- .entry(client_id.clone())
+ .entry((*client_id).clone())
.or_insert_with(ClientState::default);
client_state
.set_new_connection(client_connection.clone())
.await;
}
- let client_id = Arc::new(client_id);
-
let mut last_will: Option<MqttMessage> = will
.as_ref()
.map(|will| MqttMessage::from_last_will(will, client_id.clone()));