summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMarcel Müller <neikos@neikos.email>2023-01-23 08:57:09 +0100
committerMarcel Müller <neikos@neikos.email>2023-01-23 08:57:09 +0100
commitaacd7050fa1bd2620f8dc69cb2b189f55a206286 (patch)
treec206fbf0e2c99e65d70c93e0e84fa66b0e0c1145
parent78e21316274c25423bb501efe43931417633fca4 (diff)
Remove outer Arc from MqttServer
Signed-off-by: Marcel Müller <neikos@neikos.email>
-rw-r--r--src/bin/cloudmqtt-server.rs4
-rw-r--r--src/server/mod.rs153
2 files changed, 92 insertions, 65 deletions
diff --git a/src/bin/cloudmqtt-server.rs b/src/bin/cloudmqtt-server.rs
index cfa1acc..e307d6a 100644
--- a/src/bin/cloudmqtt-server.rs
+++ b/src/bin/cloudmqtt-server.rs
@@ -72,9 +72,9 @@ async fn main() {
.with_login_handler(SimpleLoginHandler)
.with_subscription_handler(SimpleSubscriptionHandler);
- let server = Arc::new(server);
+ let server = server;
- tokio::spawn(server.clone().subscribe_to_message(
+ tokio::spawn(server.subscribe_to_message(
vec![String::from("foo/bar"), String::from("bar/#")],
|msg| async move {
info!("Got message: {msg:?}");
diff --git a/src/server/mod.rs b/src/server/mod.rs
index a252ca4..a1517ea 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -147,6 +147,18 @@ impl ClientSource {
/// Check out the server example for a working version.
///
pub struct MqttServer<LoginH, SubH> {
+ inner: Arc<InnerServer<LoginH, SubH>>,
+}
+
+impl<LoginH, SubH> Clone for MqttServer<LoginH, SubH> {
+ fn clone(&self) -> Self {
+ Self {
+ inner: self.inner.clone(),
+ }
+ }
+}
+
+struct InnerServer<LoginH, SubH> {
clients: Arc<DashMap<ClientId, ClientState>>,
client_source: Mutex<ClientSource>,
auth_handler: LoginH,
@@ -164,11 +176,13 @@ impl MqttServer<AllowAllLogins, AllowAllSubscriptions> {
let (extra_listeners, _) = tokio::sync::broadcast::channel(50);
Ok(MqttServer {
- clients: Arc::new(DashMap::new()),
- client_source: Mutex::new(ClientSource::UnsecuredTcp(bind)),
- auth_handler: AllowAllLogins,
- extra_listeners,
- subscription_manager: Arc::new(SubscriptionManager::new()),
+ inner: Arc::new(InnerServer {
+ clients: Arc::new(DashMap::new()),
+ client_source: Mutex::new(ClientSource::UnsecuredTcp(bind)),
+ auth_handler: AllowAllLogins,
+ extra_listeners,
+ subscription_manager: Arc::new(SubscriptionManager::new()),
+ }),
})
}
}
@@ -183,12 +197,16 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
self,
new_login_handler: NLH,
) -> MqttServer<NLH, SH> {
+ let inner = Arc::try_unwrap(self.inner)
+ .unwrap_or_else(|_| panic!("Called after started listening"));
MqttServer {
- clients: self.clients,
- client_source: self.client_source,
- auth_handler: new_login_handler,
- extra_listeners: self.extra_listeners,
- subscription_manager: self.subscription_manager,
+ inner: Arc::new(InnerServer {
+ clients: inner.clients,
+ client_source: inner.client_source,
+ auth_handler: new_login_handler,
+ extra_listeners: inner.extra_listeners,
+ subscription_manager: inner.subscription_manager,
+ }),
}
}
@@ -201,31 +219,36 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
self,
new_subscription_handler: NSH,
) -> MqttServer<LH, NSH> {
+ let inner = Arc::try_unwrap(self.inner)
+ .unwrap_or_else(|_| panic!("Called after started listening"));
MqttServer {
- clients: self.clients,
- client_source: self.client_source,
- auth_handler: self.auth_handler,
- extra_listeners: self.extra_listeners,
- subscription_manager: Arc::new({
- let manager = Arc::try_unwrap(self.subscription_manager);
-
- manager
- .unwrap_or_else(|_| panic!("Called after started listening"))
- .with_subscription_handler(new_subscription_handler)
+ inner: Arc::new(InnerServer {
+ clients: inner.clients,
+ client_source: inner.client_source,
+ auth_handler: inner.auth_handler,
+ extra_listeners: inner.extra_listeners,
+ subscription_manager: Arc::new({
+ let manager = Arc::try_unwrap(inner.subscription_manager);
+
+ manager
+ .unwrap_or_else(|_| panic!("Called after started listening"))
+ .with_subscription_handler(new_subscription_handler)
+ }),
}),
}
}
/// Start accepting new clients connecting to the server
- pub async fn accept_new_clients(self: Arc<Self>) -> Result<(), MqttError> {
+ pub async fn accept_new_clients(&self) -> Result<(), MqttError> {
let mut client_source = self
+ .inner
.client_source
.try_lock()
.map_err(|_| MqttError::AlreadyListening)?;
loop {
+ let server: MqttServer<LH, SH> = self.clone();
let client = client_source.accept().await?;
- let server = self.clone();
tokio::spawn(async move {
if let Err(client_error) = server.accept_client(client).await {
tracing::error!("Client error: {}", client_error)
@@ -235,58 +258,60 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
}
/// Listen to messages sent to the given topic_paths
- pub async fn subscribe_to_message<
+ pub fn subscribe_to_message<
Fut: std::future::Future<Output = ()>,
- CB: FnMut(MqttMessage) -> Fut,
+ CB: FnMut(MqttMessage) -> Fut + 'static,
>(
- self: Arc<Self>,
+ &self,
topic_paths: Vec<String>,
mut callback: CB,
- ) -> Result<(), MqttError> {
- let mut listener = self.extra_listeners.subscribe();
+ ) -> impl std::future::Future<Output = Result<(), MqttError>> + 'static {
+ let mut listener = self.inner.extra_listeners.subscribe();
let topics = topic_paths
.into_iter()
.map(TopicFilter::parse_from)
.collect::<Vec<_>>();
- loop {
- let message = listener.recv().await;
-
- match message {
- Ok(message) => {
- if topics.iter().any(|topic| {
- let msg_topic = TopicFilter::parse_from(message.topic().to_string());
-
- let mut i = 0;
- loop {
- match (topic.get(i), msg_topic.get(i)) {
- (None, None) => break true,
- (None, Some(_)) => break false,
- (Some(_), None) => break false,
- (Some(TopicFilter::MultiWildcard), Some(_)) => break true,
- (Some(TopicFilter::SingleWildcard), Some(_)) => (),
- (Some(left), Some(right)) => {
- if left != right {
- break false;
+ async move {
+ loop {
+ let message = listener.recv().await;
+
+ match message {
+ Ok(message) => {
+ if topics.iter().any(|topic| {
+ let msg_topic = TopicFilter::parse_from(message.topic().to_string());
+
+ let mut i = 0;
+ loop {
+ match (topic.get(i), msg_topic.get(i)) {
+ (None, None) => break true,
+ (None, Some(_)) => break false,
+ (Some(_), None) => break false,
+ (Some(TopicFilter::MultiWildcard), Some(_)) => break true,
+ (Some(TopicFilter::SingleWildcard), Some(_)) => (),
+ (Some(left), Some(right)) => {
+ if left != right {
+ break false;
+ }
}
}
- }
- i += 1;
+ i += 1;
+ }
+ }) {
+ callback(message).await;
}
- }) {
- callback(message).await;
}
- }
- Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
- Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
- warn!("Subscriber lagged by {count} values")
+ Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
+ Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
+ warn!("Subscriber lagged by {count} values")
+ }
}
}
- }
- Ok(())
+ Ok(())
+ }
}
/// Accept a new client connected through the `client` stream
@@ -343,15 +368,16 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
}
let session_present = if clean_session {
- let _ = server.clients.remove(&client_id);
+ let _ = server.inner.clients.remove(&client_id);
false
} else {
- server.clients.contains_key(&client_id)
+ server.inner.clients.contains_key(&client_id)
};
let client_id = Arc::new(client_id);
if let Err(err) = server
+ .inner
.auth_handler
.allow_login(client_id.clone(), username.as_deref(), password)
.await
@@ -373,6 +399,7 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
{
let client_state = server
+ .inner
.clients
.entry((*client_id).clone())
.or_insert_with(ClientState::default);
@@ -385,13 +412,13 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
.as_ref()
.map(|will| MqttMessage::from_last_will(will, client_id.clone()));
- let published_packets = server.subscription_manager.clone();
+ let published_packets = server.inner.subscription_manager.clone();
let (published_packets_send, mut published_packets_rec) =
tokio::sync::mpsc::unbounded_channel::<MqttMessage>();
let send_loop = {
let publisher_client_id = client_id.clone();
- let clients = server.clients.clone();
+ let clients = server.inner.clients.clone();
tokio::spawn(async move {
loop {
match published_packets_rec.recv().await {
@@ -422,10 +449,10 @@ impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
let read_loop = {
let keep_alive = keep_alive;
- let subscription_manager = server.subscription_manager.clone();
+ let subscription_manager = server.inner.subscription_manager.clone();
let client_id = client_id.clone();
- let clients = server.clients.clone();
- let extra_listener = server.extra_listeners.clone();
+ let clients = server.inner.clients.clone();
+ let extra_listener = server.inner.extra_listeners.clone();
tokio::spawn(async move {
let client_id = client_id;