summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMarcel Müller <neikos@neikos.email>2024-04-04 12:22:15 +0200
committerMarcel Müller <neikos@neikos.email>2024-04-04 12:22:15 +0200
commit8dccfc038dce577f6e7ff4e17fb86debfd97b3cd (patch)
tree385c71654cf90501c1ef59cccee443dcb7dded87
parent9a034a6b5d0a84d8c3e665ef4f089e4eada4502a (diff)
Refactor Callbacks to use methods instead of direct
Signed-off-by: Marcel Müller <neikos@neikos.email>
-rw-r--r--src/client/receive.rs28
-rw-r--r--src/client/send.rs65
2 files changed, 62 insertions, 31 deletions
diff --git a/src/client/receive.rs b/src/client/receive.rs
index 0773636..0231fd7 100644
--- a/src/client/receive.rs
+++ b/src/client/receive.rs
@@ -105,7 +105,7 @@ async fn handle_pingresp(
let mut inner = inner.lock().await;
let inner = &mut *inner;
- if let Some(cb) = inner.outstanding_callbacks.ping_req.pop_front() {
+ if let Some(cb) = inner.outstanding_callbacks.take_ping_req() {
if cb.send(()).is_err() {
tracing::debug!("PingReq completion handler was dropped before receiving response")
}
@@ -159,14 +159,12 @@ async fn handle_pubcomp(
session_state.outstanding_packets.remove_by_id(pident);
tracing::trace!("Removed packet id from outstanding packets");
- if let Some(callback) = inner.outstanding_callbacks.qos2.get_mut(&pident) {
- if let Some(on_complete) = callback.on_complete.take() {
- if let Err(_) = on_complete.send(packet.clone()) {
- tracing::trace!("Could not send ack, receiver was dropped.")
- }
- } else {
- todo!("Invariant broken: Double on_complete for a single pid: {pident}")
+ if let Some(callback) = inner.outstanding_callbacks.take_qos2_complete(pident) {
+ if let Err(_) = callback.on_complete.send(packet.clone()) {
+ tracing::trace!("Could not send ack, receiver was dropped.")
}
+ } else {
+ todo!("Invariant broken: Received on_complete for unknown packet")
}
}
}
@@ -202,7 +200,7 @@ async fn handle_puback(
session_state.outstanding_packets.remove_by_id(pident);
tracing::trace!("Removed packet id from outstanding packets");
- if let Some(callback) = inner.outstanding_callbacks.qos1.remove(&pident) {
+ if let Some(callback) = inner.outstanding_callbacks.take_qos1(pident) {
if let Err(_) = callback.on_acknowledge.send(packet.clone()) {
tracing::trace!("Could not send ack, receiver was dropped.")
}
@@ -269,14 +267,12 @@ async fn handle_pubrec(
tracing::trace!("Update packet from outstanding packets");
conn_state.conn_write.send(pubrel).await.map_err(drop)?;
- if let Some(callback) = inner.outstanding_callbacks.qos2.get_mut(&pident) {
- if let Some(on_receive) = callback.on_receive.take() {
- if let Err(_) = on_receive.send(packet.clone()) {
- tracing::trace!("Could not send ack, receiver was dropped.")
- }
- } else {
- todo!("Invariant broken: Double on_receive for a single pid: {pident}")
+ if let Some(callback) = inner.outstanding_callbacks.take_qos2_receive(pident) {
+ if let Err(_) = callback.on_receive.send(packet.clone()) {
+ tracing::trace!("Could not send ack, receiver was dropped.")
}
+ } else {
+ todo!("Invariant broken: Receive PubRec for unawaited {pident}")
}
}
}
diff --git a/src/client/send.rs b/src/client/send.rs
index 4e8e23f..052d0a3 100644
--- a/src/client/send.rs
+++ b/src/client/send.rs
@@ -107,19 +107,16 @@ impl MqttClient {
let (on_acknowledge, recv) = futures::channel::oneshot::channel();
inner
.outstanding_callbacks
- .qos1
- .insert(pi, Qos1Callbacks { on_acknowledge });
+ .add_qos1(pi, Qos1Callbacks { on_acknowledge });
published_recv = PublishedReceiver::Once(PublishedQos1 { recv });
}
QualityOfService::ExactlyOnce => {
let (on_receive, recv) = futures::channel::oneshot::channel();
let (on_complete, comp_recv) = futures::channel::oneshot::channel();
- inner.outstanding_callbacks.qos2.insert(
+ inner.outstanding_callbacks.add_qos2(
pi,
- Qos2Callbacks {
- on_receive: Some(on_receive),
- on_complete: Some(on_complete),
- },
+ Qos2ReceiveCallback { on_receive },
+ Qos2CompleteCallback { on_complete },
);
published_recv =
PublishedReceiver::Twice(PublishedQos2Received { recv, comp_recv });
@@ -231,9 +228,10 @@ pub(crate) enum Acknowledge {
}
pub(crate) struct Callbacks {
- pub(crate) ping_req: VecDeque<futures::channel::oneshot::Sender<()>>,
- pub(crate) qos1: HashMap<NonZeroU16, Qos1Callbacks>,
- pub(crate) qos2: HashMap<NonZeroU16, Qos2Callbacks>,
+ ping_req: VecDeque<futures::channel::oneshot::Sender<()>>,
+ qos1: HashMap<NonZeroU16, Qos1Callbacks>,
+ qos2_receive: HashMap<NonZeroU16, Qos2ReceiveCallback>,
+ qos2_complete: HashMap<NonZeroU16, Qos2CompleteCallback>,
}
impl Callbacks {
@@ -241,18 +239,55 @@ impl Callbacks {
Callbacks {
ping_req: Default::default(),
qos1: HashMap::default(),
- qos2: HashMap::default(),
+ qos2_receive: HashMap::default(),
+ qos2_complete: HashMap::default(),
}
}
+
+ pub(crate) fn add_ping_req(&mut self, cb: futures::channel::oneshot::Sender<()>) {
+ self.ping_req.push_back(cb);
+ }
+
+ pub(crate) fn add_qos1(&mut self, id: NonZeroU16, cb: Qos1Callbacks) {
+ self.qos1.insert(id, cb);
+ }
+
+ pub(crate) fn add_qos2(
+ &mut self,
+ id: NonZeroU16,
+ rec: Qos2ReceiveCallback,
+ comp: Qos2CompleteCallback,
+ ) {
+ self.qos2_receive.insert(id, rec);
+ self.qos2_complete.insert(id, comp);
+ }
+
+ pub(crate) fn take_ping_req(&mut self) -> Option<futures::channel::oneshot::Sender<()>> {
+ self.ping_req.pop_front()
+ }
+
+ pub(crate) fn take_qos1(&mut self, id: NonZeroU16) -> Option<Qos1Callbacks> {
+ self.qos1.remove(&id)
+ }
+
+ pub(crate) fn take_qos2_receive(&mut self, id: NonZeroU16) -> Option<Qos2ReceiveCallback> {
+ self.qos2_receive.remove(&id)
+ }
+
+ pub(crate) fn take_qos2_complete(&mut self, id: NonZeroU16) -> Option<Qos2CompleteCallback> {
+ self.qos2_complete.remove(&id)
+ }
}
pub(crate) struct Qos1Callbacks {
pub(crate) on_acknowledge: futures::channel::oneshot::Sender<crate::packets::MqttPacket>,
}
-pub(crate) struct Qos2Callbacks {
- pub(crate) on_receive: Option<futures::channel::oneshot::Sender<crate::packets::MqttPacket>>,
- pub(crate) on_complete: Option<futures::channel::oneshot::Sender<crate::packets::MqttPacket>>,
+pub(crate) struct Qos2ReceiveCallback {
+ pub(crate) on_receive: futures::channel::oneshot::Sender<crate::packets::MqttPacket>,
+}
+pub(crate) struct Qos2CompleteCallback {
+ pub(crate) on_complete: futures::channel::oneshot::Sender<crate::packets::MqttPacket>,
}
pub struct Publish {
@@ -372,7 +407,7 @@ impl MqttClient {
let (sender, recv) = futures::channel::oneshot::channel();
- inner.outstanding_callbacks.ping_req.push_back(sender);
+ inner.outstanding_callbacks.add_ping_req(sender);
conn_state.conn_write.send(packet).await.map_err(drop)?;