diff options
author | Marcel Müller <neikos@neikos.email> | 2024-04-04 12:22:15 +0200 |
---|---|---|
committer | Marcel Müller <neikos@neikos.email> | 2024-04-04 12:22:15 +0200 |
commit | 8dccfc038dce577f6e7ff4e17fb86debfd97b3cd (patch) | |
tree | 385c71654cf90501c1ef59cccee443dcb7dded87 | |
parent | 9a034a6b5d0a84d8c3e665ef4f089e4eada4502a (diff) |
Refactor Callbacks to use methods instead of direct
Signed-off-by: Marcel Müller <neikos@neikos.email>
-rw-r--r-- | src/client/receive.rs | 28 | ||||
-rw-r--r-- | src/client/send.rs | 65 |
2 files changed, 62 insertions, 31 deletions
diff --git a/src/client/receive.rs b/src/client/receive.rs index 0773636..0231fd7 100644 --- a/src/client/receive.rs +++ b/src/client/receive.rs @@ -105,7 +105,7 @@ async fn handle_pingresp( let mut inner = inner.lock().await; let inner = &mut *inner; - if let Some(cb) = inner.outstanding_callbacks.ping_req.pop_front() { + if let Some(cb) = inner.outstanding_callbacks.take_ping_req() { if cb.send(()).is_err() { tracing::debug!("PingReq completion handler was dropped before receiving response") } @@ -159,14 +159,12 @@ async fn handle_pubcomp( session_state.outstanding_packets.remove_by_id(pident); tracing::trace!("Removed packet id from outstanding packets"); - if let Some(callback) = inner.outstanding_callbacks.qos2.get_mut(&pident) { - if let Some(on_complete) = callback.on_complete.take() { - if let Err(_) = on_complete.send(packet.clone()) { - tracing::trace!("Could not send ack, receiver was dropped.") - } - } else { - todo!("Invariant broken: Double on_complete for a single pid: {pident}") + if let Some(callback) = inner.outstanding_callbacks.take_qos2_complete(pident) { + if let Err(_) = callback.on_complete.send(packet.clone()) { + tracing::trace!("Could not send ack, receiver was dropped.") } + } else { + todo!("Invariant broken: Received on_complete for unknown packet") } } } @@ -202,7 +200,7 @@ async fn handle_puback( session_state.outstanding_packets.remove_by_id(pident); tracing::trace!("Removed packet id from outstanding packets"); - if let Some(callback) = inner.outstanding_callbacks.qos1.remove(&pident) { + if let Some(callback) = inner.outstanding_callbacks.take_qos1(pident) { if let Err(_) = callback.on_acknowledge.send(packet.clone()) { tracing::trace!("Could not send ack, receiver was dropped.") } @@ -269,14 +267,12 @@ async fn handle_pubrec( tracing::trace!("Update packet from outstanding packets"); conn_state.conn_write.send(pubrel).await.map_err(drop)?; - if let Some(callback) = inner.outstanding_callbacks.qos2.get_mut(&pident) { - if let Some(on_receive) = callback.on_receive.take() { - if let Err(_) = on_receive.send(packet.clone()) { - tracing::trace!("Could not send ack, receiver was dropped.") - } - } else { - todo!("Invariant broken: Double on_receive for a single pid: {pident}") + if let Some(callback) = inner.outstanding_callbacks.take_qos2_receive(pident) { + if let Err(_) = callback.on_receive.send(packet.clone()) { + tracing::trace!("Could not send ack, receiver was dropped.") } + } else { + todo!("Invariant broken: Receive PubRec for unawaited {pident}") } } } diff --git a/src/client/send.rs b/src/client/send.rs index 4e8e23f..052d0a3 100644 --- a/src/client/send.rs +++ b/src/client/send.rs @@ -107,19 +107,16 @@ impl MqttClient { let (on_acknowledge, recv) = futures::channel::oneshot::channel(); inner .outstanding_callbacks - .qos1 - .insert(pi, Qos1Callbacks { on_acknowledge }); + .add_qos1(pi, Qos1Callbacks { on_acknowledge }); published_recv = PublishedReceiver::Once(PublishedQos1 { recv }); } QualityOfService::ExactlyOnce => { let (on_receive, recv) = futures::channel::oneshot::channel(); let (on_complete, comp_recv) = futures::channel::oneshot::channel(); - inner.outstanding_callbacks.qos2.insert( + inner.outstanding_callbacks.add_qos2( pi, - Qos2Callbacks { - on_receive: Some(on_receive), - on_complete: Some(on_complete), - }, + Qos2ReceiveCallback { on_receive }, + Qos2CompleteCallback { on_complete }, ); published_recv = PublishedReceiver::Twice(PublishedQos2Received { recv, comp_recv }); @@ -231,9 +228,10 @@ pub(crate) enum Acknowledge { } pub(crate) struct Callbacks { - pub(crate) ping_req: VecDeque<futures::channel::oneshot::Sender<()>>, - pub(crate) qos1: HashMap<NonZeroU16, Qos1Callbacks>, - pub(crate) qos2: HashMap<NonZeroU16, Qos2Callbacks>, + ping_req: VecDeque<futures::channel::oneshot::Sender<()>>, + qos1: HashMap<NonZeroU16, Qos1Callbacks>, + qos2_receive: HashMap<NonZeroU16, Qos2ReceiveCallback>, + qos2_complete: HashMap<NonZeroU16, Qos2CompleteCallback>, } impl Callbacks { @@ -241,18 +239,55 @@ impl Callbacks { Callbacks { ping_req: Default::default(), qos1: HashMap::default(), - qos2: HashMap::default(), + qos2_receive: HashMap::default(), + qos2_complete: HashMap::default(), } } + + pub(crate) fn add_ping_req(&mut self, cb: futures::channel::oneshot::Sender<()>) { + self.ping_req.push_back(cb); + } + + pub(crate) fn add_qos1(&mut self, id: NonZeroU16, cb: Qos1Callbacks) { + self.qos1.insert(id, cb); + } + + pub(crate) fn add_qos2( + &mut self, + id: NonZeroU16, + rec: Qos2ReceiveCallback, + comp: Qos2CompleteCallback, + ) { + self.qos2_receive.insert(id, rec); + self.qos2_complete.insert(id, comp); + } + + pub(crate) fn take_ping_req(&mut self) -> Option<futures::channel::oneshot::Sender<()>> { + self.ping_req.pop_front() + } + + pub(crate) fn take_qos1(&mut self, id: NonZeroU16) -> Option<Qos1Callbacks> { + self.qos1.remove(&id) + } + + pub(crate) fn take_qos2_receive(&mut self, id: NonZeroU16) -> Option<Qos2ReceiveCallback> { + self.qos2_receive.remove(&id) + } + + pub(crate) fn take_qos2_complete(&mut self, id: NonZeroU16) -> Option<Qos2CompleteCallback> { + self.qos2_complete.remove(&id) + } } pub(crate) struct Qos1Callbacks { pub(crate) on_acknowledge: futures::channel::oneshot::Sender<crate::packets::MqttPacket>, } -pub(crate) struct Qos2Callbacks { - pub(crate) on_receive: Option<futures::channel::oneshot::Sender<crate::packets::MqttPacket>>, - pub(crate) on_complete: Option<futures::channel::oneshot::Sender<crate::packets::MqttPacket>>, +pub(crate) struct Qos2ReceiveCallback { + pub(crate) on_receive: futures::channel::oneshot::Sender<crate::packets::MqttPacket>, +} +pub(crate) struct Qos2CompleteCallback { + pub(crate) on_complete: futures::channel::oneshot::Sender<crate::packets::MqttPacket>, } pub struct Publish { @@ -372,7 +407,7 @@ impl MqttClient { let (sender, recv) = futures::channel::oneshot::channel(); - inner.outstanding_callbacks.ping_req.push_back(sender); + inner.outstanding_callbacks.add_ping_req(sender); conn_state.conn_write.send(packet).await.map_err(drop)?; |