summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMarcel Müller <neikos@neikos.email>2024-04-04 12:28:14 +0200
committerGitHub <noreply@github.com>2024-04-04 12:28:14 +0200
commitde260daf73add9c28ceedaef1345a250c241b1de (patch)
tree385c71654cf90501c1ef59cccee443dcb7dded87
parentae07b4c04a27be8e21e0da36a45e80ed6b9ce498 (diff)
parent8dccfc038dce577f6e7ff4e17fb86debfd97b3cd (diff)
Merge pull request #273 from TheNeikos/feature/expand_client
Refactor callbacks to remove hashmap and enum matching
-rw-r--r--src/client/mod.rs7
-rw-r--r--src/client/receive.rs75
-rw-r--r--src/client/send.rs104
3 files changed, 89 insertions, 97 deletions
diff --git a/src/client/mod.rs b/src/client/mod.rs
index 12175bb..72780b3 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -14,9 +14,8 @@ use std::sync::Arc;
use futures::lock::Mutex;
use self::send::Acknowledge;
-use self::send::CallbackState;
+use self::send::Callbacks;
use self::send::ClientHandlers;
-use self::send::Id;
use self::state::ConnectState;
use self::state::SessionState;
@@ -24,7 +23,7 @@ struct InnerClient {
connection_state: Option<ConnectState>,
session_state: Option<SessionState>,
default_handlers: ClientHandlers,
- outstanding_completions: std::collections::HashMap<Id, CallbackState>,
+ outstanding_callbacks: Callbacks,
}
pub struct MqttClient {
@@ -41,7 +40,7 @@ impl MqttClient {
on_packet_recv: Box::new(|_| ()),
handle_acknowledge: Box::new(|_| Acknowledge::Yes),
},
- outstanding_completions: std::collections::HashMap::new(),
+ outstanding_callbacks: Callbacks::new(),
})),
}
}
diff --git a/src/client/receive.rs b/src/client/receive.rs
index e7d55c2..0231fd7 100644
--- a/src/client/receive.rs
+++ b/src/client/receive.rs
@@ -15,8 +15,6 @@ use tracing::Instrument;
use yoke::Yoke;
use super::InnerClient;
-use crate::client::CallbackState;
-use crate::client::Id;
use crate::codecs::MqttPacketCodec;
use crate::packets::MqttPacket;
use crate::packets::MqttWriter;
@@ -107,21 +105,12 @@ async fn handle_pingresp(
let mut inner = inner.lock().await;
let inner = &mut *inner;
- if let Some(callback) = inner.outstanding_completions.get_mut(&Id::PingReq) {
- match callback {
- CallbackState::PingReq { on_pingresp } => {
- if let Some(cb) = on_pingresp.pop_front() {
- if cb.send(()).is_err() {
- tracing::debug!(
- "PingReq completion handler was dropped before receiving response"
- )
- }
- } else {
- tracing::warn!("Received an unwarranted PingResp from the server, continuing")
- }
- }
- _ => todo!("Had non pingreq in pingreq callback state"),
+ if let Some(cb) = inner.outstanding_callbacks.take_ping_req() {
+ if cb.send(()).is_err() {
+ tracing::debug!("PingReq completion handler was dropped before receiving response")
}
+ } else {
+ tracing::warn!("Received an unwarranted PingResp from the server, continuing")
}
Ok(())
@@ -170,22 +159,12 @@ async fn handle_pubcomp(
session_state.outstanding_packets.remove_by_id(pident);
tracing::trace!("Removed packet id from outstanding packets");
- if let Some(callback) = inner
- .outstanding_completions
- .get_mut(&Id::PacketIdentifier(pident))
- {
- match callback {
- CallbackState::Qos2 { on_complete, .. } => {
- if let Some(on_complete) = on_complete.take() {
- if let Err(_) = on_complete.send(packet.clone()) {
- tracing::trace!("Could not send ack, receiver was dropped.")
- }
- } else {
- todo!("Invariant broken: Double on_complete for a single pid: {pident}")
- }
- }
- _ => todo!(),
+ if let Some(callback) = inner.outstanding_callbacks.take_qos2_complete(pident) {
+ if let Err(_) = callback.on_complete.send(packet.clone()) {
+ tracing::trace!("Could not send ack, receiver was dropped.")
}
+ } else {
+ todo!("Invariant broken: Received on_complete for unknown packet")
}
}
}
@@ -221,17 +200,9 @@ async fn handle_puback(
session_state.outstanding_packets.remove_by_id(pident);
tracing::trace!("Removed packet id from outstanding packets");
- if let Some(callback) = inner
- .outstanding_completions
- .remove(&Id::PacketIdentifier(pident))
- {
- match callback {
- CallbackState::Qos1 { on_acknowledge } => {
- if let Err(_) = on_acknowledge.send(packet.clone()) {
- tracing::trace!("Could not send ack, receiver was dropped.")
- }
- }
- _ => todo!(),
+ if let Some(callback) = inner.outstanding_callbacks.take_qos1(pident) {
+ if let Err(_) = callback.on_acknowledge.send(packet.clone()) {
+ tracing::trace!("Could not send ack, receiver was dropped.")
}
}
} else {
@@ -296,22 +267,12 @@ async fn handle_pubrec(
tracing::trace!("Update packet from outstanding packets");
conn_state.conn_write.send(pubrel).await.map_err(drop)?;
- if let Some(callback) = inner
- .outstanding_completions
- .get_mut(&Id::PacketIdentifier(pident))
- {
- match callback {
- CallbackState::Qos2 { on_receive, .. } => {
- if let Some(on_receive) = on_receive.take() {
- if let Err(_) = on_receive.send(packet.clone()) {
- tracing::trace!("Could not send ack, receiver was dropped.")
- }
- } else {
- todo!("Invariant broken: Double on_receive for a single pid: {pident}")
- }
- }
- _ => todo!(),
+ if let Some(callback) = inner.outstanding_callbacks.take_qos2_receive(pident) {
+ if let Err(_) = callback.on_receive.send(packet.clone()) {
+ tracing::trace!("Could not send ack, receiver was dropped.")
}
+ } else {
+ todo!("Invariant broken: Receive PubRec for unawaited {pident}")
}
}
}
diff --git a/src/client/send.rs b/src/client/send.rs
index 0223551..052d0a3 100644
--- a/src/client/send.rs
+++ b/src/client/send.rs
@@ -4,6 +4,7 @@
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
+use std::collections::HashMap;
use std::collections::VecDeque;
use std::num::NonZeroU16;
@@ -104,21 +105,18 @@ impl MqttClient {
QualityOfService::AtMostOnce => unreachable!(),
QualityOfService::AtLeastOnce => {
let (on_acknowledge, recv) = futures::channel::oneshot::channel();
- inner.outstanding_completions.insert(
- Id::PacketIdentifier(pi),
- CallbackState::Qos1 { on_acknowledge },
- );
+ inner
+ .outstanding_callbacks
+ .add_qos1(pi, Qos1Callbacks { on_acknowledge });
published_recv = PublishedReceiver::Once(PublishedQos1 { recv });
}
QualityOfService::ExactlyOnce => {
let (on_receive, recv) = futures::channel::oneshot::channel();
let (on_complete, comp_recv) = futures::channel::oneshot::channel();
- inner.outstanding_completions.insert(
- Id::PacketIdentifier(pi),
- CallbackState::Qos2 {
- on_receive: Some(on_receive),
- on_complete: Some(on_complete),
- },
+ inner.outstanding_callbacks.add_qos2(
+ pi,
+ Qos2ReceiveCallback { on_receive },
+ Qos2CompleteCallback { on_complete },
);
published_recv =
PublishedReceiver::Twice(PublishedQos2Received { recv, comp_recv });
@@ -229,23 +227,67 @@ pub(crate) enum Acknowledge {
YesWithProps {},
}
-#[derive(Debug, Hash, PartialEq, Eq)]
-pub(crate) enum Id {
- PingReq,
- PacketIdentifier(NonZeroU16),
+pub(crate) struct Callbacks {
+ ping_req: VecDeque<futures::channel::oneshot::Sender<()>>,
+ qos1: HashMap<NonZeroU16, Qos1Callbacks>,
+ qos2_receive: HashMap<NonZeroU16, Qos2ReceiveCallback>,
+ qos2_complete: HashMap<NonZeroU16, Qos2CompleteCallback>,
+}
+
+impl Callbacks {
+ pub(crate) fn new() -> Callbacks {
+ Callbacks {
+ ping_req: Default::default(),
+ qos1: HashMap::default(),
+ qos2_receive: HashMap::default(),
+ qos2_complete: HashMap::default(),
+ }
+ }
+
+ pub(crate) fn add_ping_req(&mut self, cb: futures::channel::oneshot::Sender<()>) {
+ self.ping_req.push_back(cb);
+ }
+
+ pub(crate) fn add_qos1(&mut self, id: NonZeroU16, cb: Qos1Callbacks) {
+ self.qos1.insert(id, cb);
+ }
+
+ pub(crate) fn add_qos2(
+ &mut self,
+ id: NonZeroU16,
+ rec: Qos2ReceiveCallback,
+ comp: Qos2CompleteCallback,
+ ) {
+ self.qos2_receive.insert(id, rec);
+ self.qos2_complete.insert(id, comp);
+ }
+
+ pub(crate) fn take_ping_req(&mut self) -> Option<futures::channel::oneshot::Sender<()>> {
+ self.ping_req.pop_front()
+ }
+
+ pub(crate) fn take_qos1(&mut self, id: NonZeroU16) -> Option<Qos1Callbacks> {
+ self.qos1.remove(&id)
+ }
+
+ pub(crate) fn take_qos2_receive(&mut self, id: NonZeroU16) -> Option<Qos2ReceiveCallback> {
+ self.qos2_receive.remove(&id)
+ }
+
+ pub(crate) fn take_qos2_complete(&mut self, id: NonZeroU16) -> Option<Qos2CompleteCallback> {
+ self.qos2_complete.remove(&id)
+ }
+}
+
+pub(crate) struct Qos1Callbacks {
+ pub(crate) on_acknowledge: futures::channel::oneshot::Sender<crate::packets::MqttPacket>,
}
-pub(crate) enum CallbackState {
- PingReq {
- on_pingresp: VecDeque<futures::channel::oneshot::Sender<()>>,
- },
- Qos1 {
- on_acknowledge: futures::channel::oneshot::Sender<crate::packets::MqttPacket>,
- },
- Qos2 {
- on_receive: Option<futures::channel::oneshot::Sender<crate::packets::MqttPacket>>,
- on_complete: Option<futures::channel::oneshot::Sender<crate::packets::MqttPacket>>,
- },
+pub(crate) struct Qos2ReceiveCallback {
+ pub(crate) on_receive: futures::channel::oneshot::Sender<crate::packets::MqttPacket>,
+}
+pub(crate) struct Qos2CompleteCallback {
+ pub(crate) on_complete: futures::channel::oneshot::Sender<crate::packets::MqttPacket>,
}
pub struct Publish {
@@ -365,17 +407,7 @@ impl MqttClient {
let (sender, recv) = futures::channel::oneshot::channel();
- let cbs = inner
- .outstanding_completions
- .entry(Id::PingReq)
- .or_insert_with(|| CallbackState::PingReq {
- on_pingresp: Default::default(),
- });
-
- match cbs {
- CallbackState::PingReq { on_pingresp } => on_pingresp.push_back(sender),
- _ => unreachable!("Had a non-pingreq in a pingreq response"),
- }
+ inner.outstanding_callbacks.add_ping_req(sender);
conn_state.conn_write.send(packet).await.map_err(drop)?;